Set configuration using untiled operations. (#8446)

With tile and distribute moved out of Flow, the backends can now use the untiled operations to decide the configuration. Make that change and move the TileAndDistributeToWorkgroups pass to run after the configuration is set on all backends (a condensed sketch of the new flow is shown after the notes below). This makes the `getUntiledShape` and `getUntiledResultShape` methods unnecessary, so they are removed. `LoopTileAndDistributionInfo` is still needed since it is load-bearing for some `affine.min` canonicalizations on the SPIR-V side.
- On the CPU side, the configuration selection made heavy use of the existence of tiled loops to decide the configuration in an op-agnostic way. Adapt the configuration selection to not rely on these tiled loops and instead use the untiled operations directly.
- One of the CUDA tests checking for an illegal configuration seems to be off and is being looked at in #8456. The test is disabled for now and will have to be fixed in that PR.
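
A minimal sketch of the new flow, condensed from the `tile_and_distribute_to_workgroups.mlir` test in this change (the operand names `%lhs`, `%rhs`, `%init` are illustrative only):

```mlir
// Configuration selection attaches first-level tile sizes via lowering.config.
#config = #iree_codegen.lowering.config<tile_sizes = [[64, 64, 0], [16, 4, 64], [4, 4, 4]],
                                        native_vector_size = [4, 4, 4]>

// Inside the dispatch:
%gemm = linalg.matmul {lowering.config = #config}
    ins(%lhs, %rhs : tensor<?x?xf32>, tensor<?x?xf32>)
    outs(%init : tensor<?x?xf32>) -> tensor<?x?xf32>

// After -iree-codegen-tile-and-distribute-to-workgroups, the distributed tile
// sizes [64, 64] are recorded on the entry point as
//   #iree_codegen.translation.info<"CPUTileFuseAndVectorize", workload_per_wg = [64, 64]>
// and the workgroup count region returns
//   (ceildiv(workload_x, 64), ceildiv(workload_y, 64), 1).
```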

Co-authored-by: Lei Zhang <antiagainst@google.com>
Co-authored-by: Thomas Raoux <thomasraoux@google.com>
diff --git a/iree/compiler/Codegen/Common/ConvertToDestinationPassingStylePass.cpp b/iree/compiler/Codegen/Common/ConvertToDestinationPassingStylePass.cpp
index f0d6fb7..a4f538a 100644
--- a/iree/compiler/Codegen/Common/ConvertToDestinationPassingStylePass.cpp
+++ b/iree/compiler/Codegen/Common/ConvertToDestinationPassingStylePass.cpp
@@ -294,12 +294,13 @@
                                             linalg::InitTensorOp initTensorOp) {
   OpBuilder::InsertionGuard g(b);
   b.setInsertionPoint(initTensorOp);
-  for (auto &use : llvm::make_range(std::next(initTensorOp->use_begin()),
-                                    initTensorOp->use_end())) {
+  SmallVector<OpOperand *> uses = llvm::to_vector(llvm::map_range(
+      initTensorOp->getUses(), [](OpOperand &use) { return &use; }));
+  for (auto use : llvm::make_range(std::next(uses.begin()), uses.end())) {
     auto newOp =
         cast<linalg::InitTensorOp>(b.clone(*initTensorOp.getOperation()));
-    Operation *user = use.getOwner();
-    user->setOperand(use.getOperandNumber(), newOp);
+    Operation *user = use->getOwner();
+    user->setOperand(use->getOperandNumber(), newOp);
   }
   return success();
 }
diff --git a/iree/compiler/Codegen/Common/TileAndDistributeToWorkgroupsPass.cpp b/iree/compiler/Codegen/Common/TileAndDistributeToWorkgroupsPass.cpp
index ebbcb3e..9decf95 100644
--- a/iree/compiler/Codegen/Common/TileAndDistributeToWorkgroupsPass.cpp
+++ b/iree/compiler/Codegen/Common/TileAndDistributeToWorkgroupsPass.cpp
@@ -16,8 +16,10 @@
 #include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
 #include "iree-dialects/Dialect/LinalgExt/Transforms/Transforms.h"
 #include "iree/compiler/Codegen/Common/DestructiveUpdateUtils.h"
+#include "iree/compiler/Codegen/Dialect/LoweringConfig.h"
 #include "iree/compiler/Codegen/PassDetail.h"
 #include "iree/compiler/Codegen/Passes.h"
+#include "iree/compiler/Codegen/Transforms/Transforms.h"
 #include "iree/compiler/Codegen/Utils/Utils.h"
 #include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
 #include "iree/compiler/Dialect/Flow/IR/PartitionableLoopsInterface.h"
@@ -36,6 +38,122 @@
 // Patterns and methods for tile and distribute of Linalg ops to workgroups.
 //===---------------------------------------------------------------------===//
 
+// Get the tile sizes for the first level of distribution by inspecting the
+// lowering configuration set on the compute operations within the dispatch.
+static FailureOr<SmallVector<int64_t>> getTileSizesFromLoweringConfig(
+    ArrayRef<Operation *> computeOps, MLIRContext *context) {
+  if (computeOps.empty()) return SmallVector<int64_t>{};
+
+  Optional<SmallVector<int64_t>> distributedTileSizes;
+  for (auto op : computeOps) {
+    auto partitionableLoopInterface =
+        dyn_cast<IREE::Flow::PartitionableLoopsInterface>(op);
+    if (!partitionableLoopInterface) continue;
+    IREE::Codegen::LoweringConfigAttr currLoweringConfig =
+        getLoweringConfig(op);
+    if (!currLoweringConfig) continue;
+    SmallVector<unsigned> partitionableLoops =
+        partitionableLoopInterface.getPartitionableLoops(kNumMaxParallelDims);
+    SmallVector<int64_t> tileSizes = currLoweringConfig.getTileSizeVals(0);
+    SmallVector<int64_t> currDistributedTileSizes;
+    if (!partitionableLoops.empty()) {
+      currDistributedTileSizes.resize(partitionableLoops.back() + 1, 0);
+    }
+    for (auto loopID : partitionableLoops) {
+      if (loopID < tileSizes.size()) {
+        currDistributedTileSizes[loopID] = tileSizes[loopID];
+      }
+    }
+    if (distributedTileSizes) {
+      if (currDistributedTileSizes != distributedTileSizes) {
+        // Inconsistent distributed tile sizes. Abort.
+        return static_cast<LogicalResult>(
+            computeOps.front()->emitOpError("inconsistent distribution of ops "
+                                            "for first level of distribution"));
+      }
+    } else {
+      distributedTileSizes = currDistributedTileSizes;
+    }
+  }
+  if (distributedTileSizes) {
+    return distributedTileSizes.getValue();
+  }
+  return SmallVector<int64_t>{};
+}
+
+/// Compute the workload per workgroup to use based on the tile sizes passed.
+static SmallVector<int64_t> getWorkloadPerWorkgroup(
+    ArrayRef<int64_t> distributedLoopTileSizes) {
+  // TODO(ravishankarm): This for now assumes that we can just drop all the
+  // zero tile sizes. We need to eventually change this so that we don't have
+  // to do this. It is implicitly linked to the dispatch region workload
+  // carrying consistent information. That needs to be changed to take the
+  // entire iteration domain size as the argument, and then we can use the
+  // distributed loop tile sizes directly.
+  SmallVector<int64_t> nonZeroTileSizes;
+  for (auto tileSize : distributedLoopTileSizes) {
+    if (!tileSize) continue;
+    nonZeroTileSizes.push_back(tileSize);
+  }
+  return llvm::to_vector(llvm::reverse(nonZeroTileSizes));
+}
+
+/// Defines the workgroup count region if the tile sizes for the distributed
+/// loops are known.
+static LogicalResult defineWorkgroupCountRegion(
+    FuncOp entryPointFn, ArrayRef<int64_t> workloadPerWorkgroup) {
+  if (workloadPerWorkgroup.size() > kNumMaxParallelDims) {
+    // For now error out here.
+    return entryPointFn.emitOpError(
+               "expected workload per workgroup to be less than or equal to ")
+           << kNumMaxParallelDims;
+  }
+  WorkgroupCountRegionBuilder regionBuilder =
+      [&workloadPerWorkgroup](
+          OpBuilder &b, Location loc,
+          std::array<Value, 3> workload) -> std::array<Value, 3> {
+    Value one = b.create<arith::ConstantIndexOp>(loc, 1);
+    std::array<Value, 3> numWorkgroups = {one, one, one};
+    for (auto it : llvm::enumerate(workloadPerWorkgroup)) {
+      // If the tile size is 0, this loop is not tiled, and the number of
+      // workgroups is 1, i.e. the default.
+      if (it.value() == 0) continue;
+      numWorkgroups[it.index()] = applyMapToValues(
+          b, loc,
+          AffineMap::get(0, 1, b.getAffineSymbolExpr(0).ceilDiv(it.value())),
+          workload[it.index()])[0];
+    }
+    return numWorkgroups;
+  };
+  OpBuilder builder(entryPointFn.getContext());
+  return defineWorkgroupCountRegion(builder, entryPointFn, regionBuilder);
+}
+
+/// Update the workload_per_wg value on the TranslationInfoAttr.
+// TODO(ravishankarm): The workload_per_wg field should be deprecated. This
+// is just a transition until all dependencies on it can be removed.
+static LogicalResult updateTranslationInfoAttr(
+    FuncOp entryPointFn, ArrayRef<int64_t> workloadPerWorkgroup) {
+  auto entryPointOp = getEntryPoint(entryPointFn);
+  if (!entryPointOp) {
+    return entryPointFn.emitOpError("expected entry point function");
+  }
+  IREE::Codegen::DispatchLoweringPassPipeline passPipeline =
+      IREE::Codegen::DispatchLoweringPassPipeline::CPUDefault;
+  if (auto translationInfo = getTranslationInfo(entryPointOp)) {
+    // Expect the `workload_per_wg` to be empty.
+    if (!translationInfo.getWorkloadPerWorkgroupVals().empty()) {
+      return entryPointFn.emitOpError(
+          "expected workload_per_wg to be empty at this stage");
+    }
+    passPipeline = translationInfo.getDispatchLoweringPassPipeline();
+  }
+  auto newTranslationInfoAttr = IREE::Codegen::TranslationInfoAttr::get(
+      entryPointFn.getContext(), passPipeline, workloadPerWorkgroup);
+  setTranslationInfo(entryPointOp, newTranslationInfoAttr);
+  return success();
+}
+
 // Pull in producers into the tiled operation.
 static void pullInProducers(linalg::LinalgOp tiledOp,
                             ValueRange untiledOperands,
@@ -125,6 +243,21 @@
     return;
   }
 
+  // Get the tile sizes to use from lowering configuration if set.
+  FailureOr<SmallVector<int64_t>> configTileSizes =
+      getTileSizesFromLoweringConfig(computeOps, context);
+  if (failed(configTileSizes)) {
+    return signalPassFailure();
+  }
+  ArrayRef<int64_t> configTileSizesRef(configTileSizes.getValue());
+
+  SmallVector<int64_t> workloadPerWorkgroup =
+      getWorkloadPerWorkgroup(configTileSizesRef);
+  if (failed(defineWorkgroupCountRegion(funcOp, workloadPerWorkgroup)) ||
+      failed(updateTranslationInfoAttr(funcOp, workloadPerWorkgroup))) {
+    return signalPassFailure();
+  }
+
   // Add a marker to the last operation in the list.
   auto marker = StringAttr::get(context, "__workgroup_tiling__");
   computeOps.back()->setAttr(linalg::LinalgTransforms::kLinalgTransformMarker,
@@ -151,40 +284,14 @@
       DenseMap<StringRef,
                std::function<linalg::ProcInfo(OpBuilder &, Location)>>()};
 
-  // Tile size selection function. Sets the tile size now to
-  // hal.interface.workgroup.size op, with 0 for the innermost parallel loop
-  // partitioned, 1 for the next outermost loop partitioned and so on.  Use the
-  // workgroup size as a proxy for tile size here. At the flow level this
-  // represents the "workload" per processors and is not necessarily tied to the
-  // workgroup size.
-
-  // TODO(#...): Refactor this pass to directly take the tile sizes from lower
-  // configuration for the first level of tiling.
+  // Tile size selection function.
   auto tileSizeFn = [&](OpBuilder &builder,
                         Operation *op) -> SmallVector<Value, 4> {
-    auto interfaceOp = dyn_cast<IREE::Flow::PartitionableLoopsInterface>(op);
-    if (!interfaceOp) return {};
-    SmallVector<unsigned> partitionedLoops =
-        interfaceOp.getPartitionableLoops(kNumMaxParallelDims);
-    if (partitionedLoops.empty()) return {};
-    unsigned maxDepth = partitionedLoops.back() + 1;
-
-    // Set all loops not partitioned to tile size 0. and those partitioned to
-    // `flow.workgroup.size`.
-    auto zero = builder.create<arith::ConstantIndexOp>(op->getLoc(), 0);
-    SmallVector<Value, 4> useTileSizes(maxDepth, zero);
-    llvm::DenseSet<unsigned> partitionedLoopsSet;
-    partitionedLoopsSet.insert(partitionedLoops.begin(),
-                               partitionedLoops.end());
-    unsigned currFlowDim = 0;
-    for (size_t dim = maxDepth; dim > 0; dim--) {
-      if (partitionedLoopsSet.count(dim - 1)) {
-        useTileSizes[dim - 1] =
-            buildHALWorkgroupInfoOp<IREE::HAL::InterfaceWorkgroupSizeOp>(
-                builder, currFlowDim++);
-      }
-    }
-    return useTileSizes;
+    // Check if tile sizes are deduced from the configuration. If so, use those.
+    return llvm::to_vector<4>(
+        llvm::map_range(configTileSizesRef, [&](int64_t ts) -> Value {
+          return builder.create<arith::ConstantIndexOp>(op->getLoc(), ts);
+        }));
   };
 
   auto linalgTilingOptions =
diff --git a/iree/compiler/Codegen/Common/test/convert_to_destination_passing_style.mlir b/iree/compiler/Codegen/Common/test/convert_to_destination_passing_style.mlir
index 2df64a0..19d84d1 100644
--- a/iree/compiler/Codegen/Common/test/convert_to_destination_passing_style.mlir
+++ b/iree/compiler/Codegen/Common/test/convert_to_destination_passing_style.mlir
@@ -468,3 +468,62 @@
 //       CHECK:   linalg.generic
 //  CHECK-SAME:       ins(%[[IN_VIEW]], %[[INIT]]
 //  CHECK-SAME:       outs(%[[OUT_VIEW]]
+
+// -----
+
+func @three_init_tensor_uses() {
+  %c6400 = arith.constant 6400 : index
+  %c64 = arith.constant 64 : index
+  %c1654784 = arith.constant 1654784 : index
+  %c1638400 = arith.constant 1638400 : index
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 3.40282347E+38 : f32
+  %cst_0 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<64xf32>
+  %cst_1 = arith.constant 0.000000e+00 : f32
+  %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(32) : !flow.dispatch.tensor<readonly:6400x64xf32>
+  %1 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c1638400) alignment(32) : !flow.dispatch.tensor<readonly:64x64xf32>
+  %2 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c1654784) alignment(32) : !flow.dispatch.tensor<writeonly:6400x64xf32>
+  %workgroup_id_x = hal.interface.workgroup.id[0] : index
+  %workgroup_count_x = hal.interface.workgroup.count[0] : index
+  %workgroup_id_y = hal.interface.workgroup.id[1] : index
+  %workgroup_count_y = hal.interface.workgroup.count[1] : index
+  %3 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%workgroup_id_y]
+  %4 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%workgroup_count_y]
+  scf.for %arg0 = %3 to %c6400 step %4 {
+    %5 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%workgroup_id_x]
+    %6 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%workgroup_count_x]
+    scf.for %arg1 = %5 to %c64 step %6 {
+      %7 = linalg.init_tensor [64, 64] : tensor<64x64xf32>
+      %8 = tensor.extract_slice %cst_0[%arg1] [64] [1] : tensor<64xf32> to tensor<64xf32>
+      %9 = flow.dispatch.tensor.load %0, offsets = [%arg0, 0], sizes = [64, 64], strides = [1, 1] : !flow.dispatch.tensor<readonly:6400x64xf32> -> tensor<64x64xf32>
+      %10 = flow.dispatch.tensor.load %1, offsets = [0, %arg1], sizes = [64, 64], strides = [1, 1] : !flow.dispatch.tensor<readonly:64x64xf32> -> tensor<64x64xf32>
+      %11 = linalg.fill(%cst_1, %7) : f32, tensor<64x64xf32> -> tensor<64x64xf32>
+      %12 = linalg.matmul {lowering.config = #iree_codegen.lowering.config<tile_sizes = [[64, 64, 0], [8, 32, 0], [0, 0, 16]], native_vector_size = []>} ins(%9, %10 : tensor<64x64xf32>, tensor<64x64xf32>) outs(%11 : tensor<64x64xf32>) -> tensor<64x64xf32>
+      %13 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%8, %12 : tensor<64xf32>, tensor<64x64xf32>) outs(%7 : tensor<64x64xf32>) {
+      ^bb0(%arg2: f32, %arg3: f32, %arg4: f32):
+        %15 = arith.addf %arg2, %arg3 : f32
+        linalg.yield %15 : f32
+      } -> tensor<64x64xf32>
+      %14 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%13 : tensor<64x64xf32>) outs(%7 : tensor<64x64xf32>) {
+      ^bb0(%arg2: f32, %arg3: f32):
+        %15 = arith.cmpf olt, %arg2, %cst_1 : f32
+        %16 = arith.select %15, %cst_1, %arg2 : f32
+        %17 = arith.cmpf olt, %cst, %arg2 : f32
+        %18 = arith.select %17, %cst, %16 : f32
+        linalg.yield %18 : f32
+      } -> tensor<64x64xf32>
+      flow.dispatch.tensor.store %14, %2, offsets = [%arg0, %arg1], sizes = [64, 64], strides = [1, 1] : tensor<64x64xf32> -> !flow.dispatch.tensor<writeonly:6400x64xf32>
+    }
+  }
+  return
+}
+// CHECK-LABEL: func @three_init_tensor_uses()
+//       CHECK: %[[OUTPUT:.+]] = hal.interface.binding.subspan set(0) binding(1)
+//   CHECK-NOT:   linalg.init_tensor
+//       CHECK:   %[[LOAD:.+]] = flow.dispatch.tensor.load %[[OUTPUT]]
+//   CHECK-NOT:   linalg.init_tensor
+//       CHECK:   linalg.fill(%{{.+}}, %[[LOAD]])
+//       CHECK:   linalg.generic
+//  CHECK-SAME:       outs(%[[LOAD]] :
+//       CHECK:   linalg.generic
+//  CHECK-SAME:       outs(%[[LOAD]] :
diff --git a/iree/compiler/Codegen/Common/test/remove_trivial_loops.mlir b/iree/compiler/Codegen/Common/test/remove_trivial_loops.mlir
index 4833f6b..7a47dde 100644
--- a/iree/compiler/Codegen/Common/test/remove_trivial_loops.mlir
+++ b/iree/compiler/Codegen/Common/test/remove_trivial_loops.mlir
@@ -10,7 +10,7 @@
 // CHECK-LABEL: func @dispatch_0()
 hal.executable private @dispatch_0  {
   hal.executable.variant @cuda, target = #hal.executable.target<"cuda", "cuda-nvptx-fb"> {
-    hal.executable.entry_point @dispatch_0 layout(#executable_layout) attributes {
+    hal.executable.entry_point @dispatch_0 layout(#executable_layout) {
       workgroup_size = [64: index, 1: index, 1:index]
     }
     builtin.module {
@@ -57,7 +57,7 @@
 #translation = #iree_codegen.translation.info<"LLVMGPUDistribute", workload_per_wg = [32]>
 hal.executable private @workgroup_tile_loop  {
   hal.executable.variant @cuda, target = #hal.executable.target<"cuda", "cuda-nvptx-fb"> {
-    hal.executable.entry_point @workgroup_tile_loop layout(#executable_layout) attributes {
+    hal.executable.entry_point @workgroup_tile_loop layout(#executable_layout) {
       translation.info = #translation
     }
     builtin.module {
@@ -91,7 +91,7 @@
 #translation = #iree_codegen.translation.info<"LLVMGPUDistribute", workload_per_wg = [16]>
 hal.executable private @workgroup_tile_loop_negative  {
   hal.executable.variant @cuda, target = #hal.executable.target<"cuda", "cuda-nvptx-fb"> {
-    hal.executable.entry_point @workgroup_tile_loop_negative layout(#executable_layout) attributes {
+    hal.executable.entry_point @workgroup_tile_loop_negative layout(#executable_layout) {
       translation.info = #translation
     }
     builtin.module {
@@ -127,7 +127,7 @@
 #translation = #iree_codegen.translation.info<"LLVMGPUDistribute", workload_per_wg = [32, 8, 1]>
 hal.executable private @both_workgroup_and_workitem  {
   hal.executable.variant @cuda, target = #hal.executable.target<"cuda", "cuda-nvptx-fb"> {
-    hal.executable.entry_point @both_workgroup_and_workitem layout(#executable_layout) attributes {
+    hal.executable.entry_point @both_workgroup_and_workitem layout(#executable_layout) {
       translation.info = #translation,
       workgroup_size = [8: index, 2: index, 1: index]
     }
diff --git a/iree/compiler/Codegen/Common/test/tile_and_distribute_to_workgroups.mlir b/iree/compiler/Codegen/Common/test/tile_and_distribute_to_workgroups.mlir
index 147fc4b..6757346 100644
--- a/iree/compiler/Codegen/Common/test/tile_and_distribute_to_workgroups.mlir
+++ b/iree/compiler/Codegen/Common/test/tile_and_distribute_to_workgroups.mlir
@@ -1,717 +1,1366 @@
-// RUN: iree-opt -iree-codegen-tile-and-distribute-to-workgroups -cse -split-input-file %s | FileCheck %s
+// RUN: iree-opt -pass-pipeline='hal.executable(hal.executable.variant(builtin.module(builtin.func(iree-codegen-tile-and-distribute-to-workgroups)))), canonicalize, cse' -split-input-file %s | FileCheck %s
 
-func @simple_gemm_dynamic() {
-  %cst = arith.constant 0.000000e+00 : f32
-  %c0 = arith.constant 0 : index
-  %0 = hal.interface.constant.load[0] : i32
-  %1 = hal.interface.constant.load[1] : i32
-  %2 = hal.interface.constant.load[2] : i32
-  %4 = arith.index_cast %0 : i32 to index
-  %5 = arith.index_cast %1 : i32 to index
-  %6 = arith.index_cast %2 : i32 to index
-  %8 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(32) : !flow.dispatch.tensor<readonly:?x?xf32>{%4, %5}
-  %9 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(32) : !flow.dispatch.tensor<readonly:?x?xf32>{%5, %6}
-  %10 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) offset(%c0) alignment(32) : !flow.dispatch.tensor<writeonly:?x?xf32>{%4, %6}
-  %11 = flow.dispatch.tensor.load %8, offsets = [0, 0], sizes = [%4, %5], strides = [1, 1] : !flow.dispatch.tensor<readonly:?x?xf32>{%4, %5} -> tensor<?x?xf32>
-  %12 = flow.dispatch.tensor.load %9, offsets = [0, 0], sizes = [%5, %6], strides = [1, 1] : !flow.dispatch.tensor<readonly:?x?xf32>{%5, %6} -> tensor<?x?xf32>
-  %13 = linalg.init_tensor [%4, %6] : tensor<?x?xf32>
-  %14 = linalg.fill(%cst, %13) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
-  %15 = linalg.matmul ins(%11, %12 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%14 : tensor<?x?xf32>) -> tensor<?x?xf32>
-  flow.dispatch.tensor.store %15, %10, offsets = [0, 0], sizes = [%4, %6], strides = [1, 1] : tensor<?x?xf32> -> !flow.dispatch.tensor<writeonly:?x?xf32>{%4, %6}
-  return
+#config = #iree_codegen.lowering.config<tile_sizes = [[64, 64, 0], [16, 4, 64], [4, 4, 4]], native_vector_size = [4, 4, 4]>
+#executable_layout = #hal.executable.layout<push_constants = 0, sets = [
+  #hal.descriptor_set.layout<0, bindings = [
+    #hal.descriptor_set.binding<0, storage_buffer>,
+    #hal.descriptor_set.binding<1, storage_buffer>,
+    #hal.descriptor_set.binding<2, storage_buffer>,
+    #hal.descriptor_set.binding<3, storage_buffer>
+  ]>
+]>
+#executable_target_embedded_elf_arm_64_ = #hal.executable.target<"llvm", "embedded-elf-arm_64", {
+  data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128",
+  native_vector_size = 16 : index,
+  target_triple = "aarch64-unknown-unknown-eabi-elf"
+}>
+#translation = #iree_codegen.translation.info<"CPUTileFuseAndVectorize", workload_per_wg = []>
+hal.executable private @matmul_tensors {
+  hal.executable.variant public @llvm, target = #executable_target_embedded_elf_arm_64_ {
+    hal.executable.entry_point public @matmul_tensors layout(#executable_layout) {translation.info = #translation}
+    builtin.module {
+      func @matmul_tensors() {
+        %0 = hal.interface.constant.load[0] : index
+        %1 = hal.interface.constant.load[1] : index
+        %2 = hal.interface.constant.load[2] : index
+        %3 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer)
+            : !flow.dispatch.tensor<readonly:?x?xf32>{%0, %2}
+        %4 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer)
+            : !flow.dispatch.tensor<readonly:?x?xf32>{%2, %1}
+        %5 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer)
+            : !flow.dispatch.tensor<readonly:?x?xf32>{%0, %1}
+        %6 = hal.interface.binding.subspan set(0) binding(3) type(storage_buffer)
+            : !flow.dispatch.tensor<writeonly:?x?xf32>{%0, %1}
+        %7 = flow.dispatch.tensor.load %3, offsets = [0, 0], sizes = [%0, %2], strides = [1, 1]
+            : !flow.dispatch.tensor<readonly:?x?xf32>{%0, %2} -> tensor<?x?xf32>
+        %8 = flow.dispatch.tensor.load %4, offsets = [0, 0], sizes = [%2, %1], strides = [1, 1]
+            : !flow.dispatch.tensor<readonly:?x?xf32>{%2, %1} -> tensor<?x?xf32>
+        %9 = flow.dispatch.tensor.load %5, offsets = [0, 0], sizes = [%0, %1], strides = [1, 1]
+            : !flow.dispatch.tensor<readonly:?x?xf32>{%0, %1} -> tensor<?x?xf32>
+        %10 = linalg.matmul {lowering.config = #config}
+            ins(%7, %8 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%9 : tensor<?x?xf32>) -> tensor<?x?xf32>
+        flow.dispatch.tensor.store %10, %6, offsets = [0, 0], sizes = [%0, %1], strides = [1, 1]
+            : tensor<?x?xf32> -> !flow.dispatch.tensor<writeonly:?x?xf32>{%0, %1}
+        return
+      }
+    }
+  }
 }
-//   CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> (s0 * s1)>
-//   CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (s1, -d0 + s0)>
-//   CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0, s1] -> (-d0 + s0, s1)>
-//       CHECK: func @simple_gemm_dynamic()
-//   CHECK-DAG:   %[[MVAL:.+]] = hal.interface.constant.load[0] : i32
-//   CHECK-DAG:   %[[KVAL:.+]] = hal.interface.constant.load[1] : i32
-//   CHECK-DAG:   %[[NVAL:.+]] = hal.interface.constant.load[2] : i32
-//   CHECK-DAG:   %[[M:.+]] = arith.index_cast %[[MVAL]]
-//   CHECK-DAG:   %[[K:.+]] = arith.index_cast %[[KVAL]]
-//   CHECK-DAG:   %[[N:.+]] = arith.index_cast %[[NVAL]]
-//   CHECK-DAG:   %[[LHS_BINDING:.+]] = hal.interface.binding.subspan set(0) binding(0)
-//   CHECK-DAG:   %[[RHS_BINDING:.+]] = hal.interface.binding.subspan set(0) binding(1)
-//   CHECK-DAG:   %[[OUT_BINDING:.+]] = hal.interface.binding.subspan set(0) binding(2)
-//   CHECK-DAG:   %[[WG_SIZE_X:.+]] = hal.interface.workgroup.size[0]
-//   CHECK-DAG:   %[[WG_SIZE_Y:.+]] = hal.interface.workgroup.size[1]
-//   CHECK-DAG:   %[[WG_ID_X:.+]] = hal.interface.workgroup.id[0]
-//   CHECK-DAG:   %[[WG_COUNT_X:.+]] = hal.interface.workgroup.count[0]
-//   CHECK-DAG:   %[[WG_ID_Y:.+]] = hal.interface.workgroup.id[1]
-//   CHECK-DAG:   %[[WG_COUNT_Y:.+]] = hal.interface.workgroup.count[1]
-//   CHECK-DAG:   %[[LB_Y:.+]] = affine.apply #[[MAP0]]()[%[[WG_ID_Y]], %[[WG_SIZE_Y]]]
-//   CHECK-DAG:   %[[STEP_Y:.+]] = affine.apply #[[MAP0]]()[%[[WG_COUNT_Y]], %[[WG_SIZE_Y]]]
-//       CHECK:   scf.for %[[IV0:.+]] = %[[LB_Y]] to %[[M]] step %[[STEP_Y]]
-//   CHECK-DAG:     %[[LB_X:.+]] = affine.apply #[[MAP0]]()[%[[WG_ID_X]], %[[WG_SIZE_X]]]
-//   CHECK-DAG:     %[[STEP_X:.+]] = affine.apply #[[MAP0]]()[%[[WG_COUNT_X]], %[[WG_SIZE_X]]]
-//       CHECK:     scf.for %[[IV1:.+]] = %[[LB_X]] to %[[N]] step %[[STEP_X]]
-//   CHECK-DAG:       %[[M_TILE:.+]] = affine.min #[[MAP1]](%[[IV0]])[%[[M]], %[[WG_SIZE_Y]]]
-//   CHECK-DAG:       %[[LHS:.+]] = flow.dispatch.tensor.load %[[LHS_BINDING]]
-//  CHECK-SAME:           offsets = [%[[IV0]], 0], sizes = [%[[M_TILE]], %[[K]]]
-//   CHECK-DAG:       %[[N_TILE:.+]] = affine.min #[[MAP1]](%[[IV1]])[%[[N]], %[[WG_SIZE_X]]]
-//   CHECK-DAG:       %[[RHS:.+]] = flow.dispatch.tensor.load %[[RHS_BINDING]]
-//  CHECK-SAME:           offsets = [0, %[[IV1]]], sizes = [%[[K]], %[[N_TILE]]]
-//   CHECK-DAG:       %[[M_TILE2:.+]] = affine.min #[[MAP2]](%[[IV0]])[%[[M]], %[[WG_SIZE_Y]]]
-//   CHECK-DAG:       %[[N_TILE2:.+]] = affine.min #[[MAP2]](%[[IV1]])[%[[N]], %[[WG_SIZE_X]]]
-//   CHECK-DAG:       %[[INIT:.+]] = linalg.init_tensor [%[[M_TILE2]], %[[N_TILE2]]]
-//   CHECK-DAG:       %[[FILL:.+]] = linalg.fill(%{{.+}}, %[[INIT]])
-//   CHECK-DAG:       %[[GEMM:.+]] = linalg.matmul ins(%[[LHS]], %[[RHS]] :
-//  CHECK-SAME:           outs(%[[FILL]] :
-//       CHECK:       flow.dispatch.tensor.store %[[GEMM]], %[[OUT_BINDING]]
-//  CHECK-SAME:           offsets = [%[[IV0]], %[[IV1]]], sizes = [%[[M_TILE]], %[[N_TILE]]]
-
-// -----
-
-func @generic_op_alone() {
-  %cst = arith.constant 0.000000e+00 : f32
-  %c0 = arith.constant 0 : index
-  %0 = hal.interface.constant.load[0] : i32
-  %1 = hal.interface.constant.load[1] : i32
-  %4 = arith.index_cast %0 : i32 to index
-  %5 = arith.index_cast %1 : i32 to index
-  %8 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(32) : !flow.dispatch.tensor<readonly:?x?xf32>{%4, %5}
-  %9 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(32) : !flow.dispatch.tensor<readonly:?xf32>{%5}
-  %10 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) offset(%c0) alignment(32) : !flow.dispatch.tensor<writeonly:?x?xf32>{%4, %5}
-  %11 = flow.dispatch.tensor.load %8, offsets = [0, 0], sizes = [%4, %5], strides = [1, 1] : !flow.dispatch.tensor<readonly:?x?xf32>{%4, %5} -> tensor<?x?xf32>
-  %12 = flow.dispatch.tensor.load %9, offsets = [0], sizes = [%5], strides = [1] : !flow.dispatch.tensor<readonly:?xf32>{%5} -> tensor<?xf32>
-  %13 = linalg.init_tensor [%4, %5] : tensor<?x?xf32>
-  %15 = linalg.generic {
-    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
-                     affine_map<(d0, d1) -> (d1)>,
-                     affine_map<(d0, d1) -> (d0, d1)>],
-    iterator_types = ["parallel", "parallel"]}
-    ins (%11, %12: tensor<?x?xf32>, tensor<?xf32>)
-    outs (%13 : tensor<?x?xf32>) {
-      ^bb0(%arg0 : f32, %arg1 : f32, %arg2 : f32):
-        %2 = arith.addf %arg0, %arg1 : f32
-        linalg.yield %2 : f32
-    } -> tensor<?x?xf32>
-  flow.dispatch.tensor.store %15, %10, offsets = [0, 0], sizes = [%4, %5], strides = [1, 1] : tensor<?x?xf32> -> !flow.dispatch.tensor<writeonly:?x?xf32>{%4, %5}
-  return
-}
-// CHECK-LABEL: func @generic_op_alone()
-//   CHECK-DAG:   %[[INPUT1:.+]] = hal.interface.binding.subspan set(0) binding(0)
-//   CHECK-DAG:   %[[INPUT2:.+]] = hal.interface.binding.subspan set(0) binding(1)
-//   CHECK-DAG:   %[[OUTPUT:.+]] = hal.interface.binding.subspan set(0) binding(2)
-//   CHECK-DAG:   %[[INIT:.+]] = linalg.init_tensor
-//       CHECK:   scf.for
-//       CHECK:     scf.for
-//   CHECK-DAG:       %[[LHS:.+]] = flow.dispatch.tensor.load %[[INPUT1]]
-//   CHECK-DAG:       %[[RHS:.+]] = flow.dispatch.tensor.load %[[INPUT2]]
-//   CHECK-DAG:       %[[SLICE:.+]] = tensor.extract_slice %[[INIT]]
-//       CHECK:       %[[GENERIC:.+]] = linalg.generic
-//  CHECK-SAME:           ins(%[[LHS]], %[[RHS]] :
-//  CHECK-SAME:           outs(%[[SLICE]] :
-//       CHECK:       flow.dispatch.tensor.store %[[GENERIC]], %[[OUTPUT]]
-
-// -----
-
-func @generic_op_4D() {
-  %cst = arith.constant 0.000000e+00 : f32
-  %c0 = arith.constant 0 : index
-  %0 = hal.interface.constant.load[0] : i32
-  %1 = hal.interface.constant.load[1] : i32
-  %2 = hal.interface.constant.load[2] : i32
-  %3 = hal.interface.constant.load[3] : i32
-  %4 = arith.index_cast %0 : i32 to index
-  %5 = arith.index_cast %1 : i32 to index
-  %6 = arith.index_cast %2 : i32 to index
-  %7 = arith.index_cast %3 : i32 to index
-  %8 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(32) : !flow.dispatch.tensor<readonly:?x?x?x?xf32>{%4, %5, %6, %7}
-  %9 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(32) : !flow.dispatch.tensor<readonly:?x?x?x?xf32>{%4, %5, %6, %7}
-  %10 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) offset(%c0) alignment(32) : !flow.dispatch.tensor<writeonly:?x?x?x?xf32>{%4, %5, %6, %7}
-  %11 = flow.dispatch.tensor.load %8, offsets = [0, 0, 0, 0], sizes = [%4, %5, %6, %7], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:?x?x?x?xf32>{%4, %5, %6, %7} -> tensor<?x?x?x?xf32>
-  %12 = flow.dispatch.tensor.load %9, offsets = [0, 0, 0, 0], sizes = [%4, %5, %6, %7], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:?x?x?x?xf32>{%4, %5, %6, %7} -> tensor<?x?x?x?xf32>
-  %13 = linalg.init_tensor [%4, %5, %6, %7] : tensor<?x?x?x?xf32>
-  %15 = linalg.generic {
-    indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
-                     affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
-                     affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
-    iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
-    ins (%11, %12: tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
-    outs (%13 : tensor<?x?x?x?xf32>) {
-      ^bb0(%arg0 : f32, %arg1 : f32, %arg2 : f32):
-        %14 = arith.addf %arg0, %arg1 : f32
-        linalg.yield %14 : f32
-    } -> tensor<?x?x?x?xf32>
-  flow.dispatch.tensor.store %15, %10, offsets = [0, 0, 0, 0], sizes = [%4, %5, %6, %7], strides = [1, 1, 1, 1] : tensor<?x?x?x?xf32> -> !flow.dispatch.tensor<writeonly:?x?x?x?xf32>{%4, %5, %6, %7}
-  return
-}
-//  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> (s0 * s1)>
-//  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (s1, -d0 + s0)>
-//      CHECK: func @generic_op_4D()
-//  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
-//  CHECK-DAG:   %[[D0VAL:.+]] = hal.interface.constant.load[0] : i32
-//  CHECK-DAG:   %[[D1VAL:.+]] = hal.interface.constant.load[1] : i32
-//  CHECK-DAG:   %[[D2VAL:.+]] = hal.interface.constant.load[2] : i32
-//  CHECK-DAG:   %[[D3VAL:.+]] = hal.interface.constant.load[3] : i32
-//  CHECK-DAG:   %[[D0:.+]] = arith.index_cast %[[D0VAL]]
-//  CHECK-DAG:   %[[D1:.+]] = arith.index_cast %[[D1VAL]]
-//  CHECK-DAG:   %[[D2:.+]] = arith.index_cast %[[D2VAL]]
-//  CHECK-DAG:   %[[D3:.+]] = arith.index_cast %[[D3VAL]]
-//  CHECK-DAG:   %[[INPUT1:.+]] = hal.interface.binding.subspan set(0) binding(0)
-//  CHECK-DAG:   %[[INPUT2:.+]] = hal.interface.binding.subspan set(0) binding(1)
-//  CHECK-DAG:   %[[OUTPUT1:.+]] = hal.interface.binding.subspan set(0) binding(2)
-//  CHECK-DAG:   %[[INIT:.+]] = linalg.init_tensor [%[[D0]], %[[D1]], %[[D2]], %[[D3]]]
-//  CHECK-DAG:   %[[WG_SIZE_X:.+]] = hal.interface.workgroup.size[0] : index
-//  CHECK-DAG:   %[[WG_SIZE_Y:.+]] = hal.interface.workgroup.size[1] : index
-//  CHECK-DAG:   %[[WG_SIZE_Z:.+]] = hal.interface.workgroup.size[2] : index
-//  CHECK-DAG:   %[[WG_ID_X:.+]] = hal.interface.workgroup.id[0] : index
-//  CHECK-DAG:   %[[WG_COUNT_X:.+]] = hal.interface.workgroup.count[0] : index
-//  CHECK-DAG:   %[[WG_ID_Y:.+]] = hal.interface.workgroup.id[1] : index
-//  CHECK-DAG:   %[[WG_COUNT_Y:.+]] = hal.interface.workgroup.count[1] : index
-//  CHECK-DAG:   %[[WG_ID_Z:.+]] = hal.interface.workgroup.id[2] : index
-//  CHECK-DAG:   %[[WG_COUNT_Z:.+]] = hal.interface.workgroup.count[2] : index
-//  CHECK-DAG:   %[[LB_Z:.+]] = affine.apply #[[MAP0]]()[%[[WG_ID_Z]], %[[WG_SIZE_Z]]]
-//  CHECK-DAG:   %[[STEP_Z:.+]] = affine.apply #[[MAP0]]()[%[[WG_COUNT_Z]], %[[WG_SIZE_Z]]]
-//      CHECK:   scf.for %[[IV0:.+]] = %[[LB_Z]] to %[[D1]] step %[[STEP_Z]]
-//  CHECK-DAG:     %[[LB_Y:.+]] = affine.apply #[[MAP0]]()[%[[WG_ID_Y]], %[[WG_SIZE_Y]]]
-//  CHECK-DAG:     %[[STEP_Y:.+]] = affine.apply #[[MAP0]]()[%[[WG_COUNT_Y]], %[[WG_SIZE_Y]]]
-//      CHECK:     scf.for %[[IV1:.+]] = %[[LB_Y]] to %[[D2]] step %[[STEP_Y]]
-//  CHECK-DAG:       %[[LB_X:.+]] = affine.apply #[[MAP0]]()[%[[WG_ID_X]], %[[WG_SIZE_X]]]
-//  CHECK-DAG:       %[[STEP_X:.+]] = affine.apply #[[MAP0]]()[%[[WG_COUNT_X]], %[[WG_SIZE_X]]]
-//      CHECK:       scf.for %[[IV2:.+]] = %[[LB_X]] to %[[D3]] step %[[STEP_X]]
-//  CHECK-DAG:         %[[TILE_Z:.+]] = affine.min #[[MAP1]](%[[IV0]])[%[[D1]], %[[WG_SIZE_Z]]]
-//  CHECK-DAG:         %[[TILE_Y:.+]] = affine.min #[[MAP1]](%[[IV1]])[%[[D2]], %[[WG_SIZE_Y]]]
-//  CHECK-DAG:         %[[TILE_X:.+]] = affine.min #[[MAP1]](%[[IV2]])[%[[D3]], %[[WG_SIZE_X]]]
-//      CHECK:         flow.dispatch.tensor.load %[[INPUT1]]
-// CHECK-SAME:             offsets = [0, %[[IV0]], %[[IV1]], %[[IV2]]]
-// CHECK-SAME:             sizes = [%[[D0]], %[[TILE_Z]], %[[TILE_Y]], %[[TILE_X]]]
-//      CHECK:         flow.dispatch.tensor.load %[[INPUT2]]
-// CHECK-SAME:             offsets = [0, %[[IV0]], %[[IV1]], %[[IV2]]]
-// CHECK-SAME:             sizes = [%[[D0]], %[[TILE_Z]], %[[TILE_Y]], %[[TILE_X]]]
-//      CHECK:         %[[D02:[a-zA-Z0-9]+]] = tensor.dim %[[INIT]], %[[C0]] : tensor<?x?x?x?xf32>
-//      CHECK:         tensor.extract_slice %[[INIT]][0, %[[IV0]], %[[IV1]], %[[IV2]]]
-// CHECK-SAME:             [%[[D02]], %[[TILE_Z]], %[[TILE_Y]], %[[TILE_X]]]
-//      CHECK:         flow.dispatch.tensor.store
-// CHECK-SAME:             offsets = [0, %[[IV0]], %[[IV1]], %[[IV2]]]
-// CHECK-SAME:             sizes = [%[[D02]], %[[TILE_Z]], %[[TILE_Y]], %[[TILE_X]]]
-
-
-// -----
-
-func @conv2d() {
-  %cst = arith.constant 0.000000e+00 : f32
-  %c0 = arith.constant 0 : index
-  %8 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(32) : !flow.dispatch.tensor<readonly:1x225x225x16xf32>
-  %9 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(32) : !flow.dispatch.tensor<readonly:3x3x16x32xf32>
-  %10 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) offset(%c0) alignment(32) : !flow.dispatch.tensor<writeonly:1x112x112x32xf32>
-  %11 = flow.dispatch.tensor.load %8, offsets = [0, 0, 0, 0], sizes = [1, 225, 225, 16], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:1x225x225x16xf32> -> tensor<1x225x225x16xf32>
-  %12 = flow.dispatch.tensor.load %9, offsets = [0, 0, 0, 0], sizes = [3, 3, 16, 32], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:3x3x16x32xf32> -> tensor<3x3x16x32xf32>
-  %13 = linalg.init_tensor [1, 112, 112, 32] : tensor<1x112x112x32xf32>
-  %15 = linalg.conv_2d_nhwc_hwcf
-         {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>}
-         ins(%11, %12 : tensor<1x225x225x16xf32>, tensor<3x3x16x32xf32>)
-         outs(%13 : tensor<1x112x112x32xf32>)
-         -> tensor<1x112x112x32xf32>
-  flow.dispatch.tensor.store %15, %10, offsets = [0, 0, 0, 0], sizes = [1, 112, 112, 32], strides = [1, 1, 1, 1] : tensor<1x112x112x32xf32> -> !flow.dispatch.tensor<writeonly:1x112x112x32xf32>
-  return
-}
-// CHECK-LABEL: func @conv2d()
-//   CHECK-DAG:   %[[C112:.+]] = arith.constant 112 : index
-//   CHECK-DAG:   %[[C32:.+]] = arith.constant 32 : index
-//       CHECK:   scf.for %{{.+}} = %{{.+}} to %[[C112]]
-//       CHECK:     scf.for %{{.+}} = %{{.+}} to %[[C112]]
-//       CHECK;       scf.for %{{.+}} = %{{.+}} to %[[C32]]
-
-// -----
-
-func @depthwise_conv2d() {
-  %cst = arith.constant 0.000000e+00 : f32
-  %c0 = arith.constant 0 : index
-  %8 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(32) : !flow.dispatch.tensor<readonly:1x113x113x96xf32>
-  %9 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(32) : !flow.dispatch.tensor<readonly:3x3x96xf32>
-  %10 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) offset(%c0) alignment(32) : !flow.dispatch.tensor<writeonly:1x56x56x96xf32>
-  %11 = flow.dispatch.tensor.load %8, offsets = [0, 0, 0, 0], sizes = [1, 113, 113, 96], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:1x113x113x96xf32> -> tensor<1x113x113x96xf32>
-  %12 = flow.dispatch.tensor.load %9, offsets = [0, 0, 0], sizes = [3, 3, 96], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:3x3x96xf32> -> tensor<3x3x96xf32>
-  %13 = linalg.init_tensor [1, 56, 56, 96] : tensor<1x56x56x96xf32>
-  %15 = linalg.depthwise_conv_2d_nhwc_hwc
-      {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>}
-      ins(%11, %12 : tensor<1x113x113x96xf32>, tensor<3x3x96xf32>)
-      outs(%13 : tensor<1x56x56x96xf32>) -> tensor<1x56x56x96xf32>
-  flow.dispatch.tensor.store %15, %10, offsets = [0, 0, 0, 0], sizes = [1, 56, 56, 96], strides = [1, 1, 1, 1] : tensor<1x56x56x96xf32> -> !flow.dispatch.tensor<writeonly:1x56x56x96xf32>
-  return
-}
-// CHECK-LABEL: func @depthwise_conv2d()
-//   CHECK-DAG:   %[[C56:.+]] = arith.constant 56 : index
-//   CHECK-DAG:   %[[C96:.+]] = arith.constant 96 : index
-//       CHECK:   scf.for %{{.+}} = %{{.+}} to %[[C56]]
-//       CHECK:     scf.for %{{.+}} = %{{.+}} to %[[C56]]
-//       CHECK;       scf.for %{{.+}} = %{{.+}} to %[[C96]]
-
-// -----
-
-func @subtensor_insert() {
-  %c0 = arith.constant 0 : index
-  %offset_y_i32 = hal.interface.constant.load[0] : i32
-  %offset_x_i32 = hal.interface.constant.load[1] : i32
-  %size_y_i32 = hal.interface.constant.load[2] : i32
-  %size_x_i32 = hal.interface.constant.load[3] : i32
-  %dest_size_y_i32 = hal.interface.constant.load[4] : i32
-  %dest_size_x_i32 = hal.interface.constant.load[5] : i32
-  %offset_y = arith.index_cast %offset_y_i32 : i32 to index
-  %offset_x = arith.index_cast %offset_x_i32 : i32 to index
-  %size_y = arith.index_cast %size_y_i32 : i32 to index
-  %size_x = arith.index_cast %size_x_i32 : i32 to index
-  %dest_size_y = arith.index_cast %dest_size_y_i32 : i32 to index
-  %dest_size_x = arith.index_cast %dest_size_x_i32 : i32 to index
-  %source_binding = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(32)
-      : !flow.dispatch.tensor<readonly:?x?xf32>{%size_y, %size_x}
-  %dest_binding = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(32)
-      : !flow.dispatch.tensor<readwrite:?x?xf32>{%dest_size_y, %dest_size_x}
-  %source = flow.dispatch.tensor.load %source_binding, offsets = [0, 0], sizes = [%size_y, %size_x], strides = [1, 1]
-      : !flow.dispatch.tensor<readonly:?x?xf32>{%size_y, %size_x} -> tensor<?x?xf32>
-  %dest = flow.dispatch.tensor.load %dest_binding, offsets = [0, 0], sizes = [%dest_size_y, %dest_size_x], strides = [1, 1]
-      : !flow.dispatch.tensor<readwrite:?x?xf32>{%dest_size_y, %dest_size_x} -> tensor<?x?xf32>
-  %insert = tensor.insert_slice %source into %dest[%offset_y, %offset_x] [%size_y, %size_x] [1, 1]
-      : tensor<?x?xf32> into tensor<?x?xf32>
-  flow.dispatch.tensor.store %insert, %dest_binding, offsets = [0, 0], sizes = [%dest_size_y, %dest_size_x], strides = [1, 1]
-      : tensor<?x?xf32> -> !flow.dispatch.tensor<readwrite:?x?xf32>{%dest_size_y, %dest_size_x}
-  return
-}
-//  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> (s1 * s0)>
-//  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (s0, -d0 + s1)>
-//  CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0] -> (d0 + s0)>
-//      CHECK: func @subtensor_insert()
-//  CHECK-DAG:   %[[OFFSET_Y_VAL:.+]] = hal.interface.constant.load[0]
-//  CHECK-DAG:   %[[OFFSET_X_VAL:.+]] = hal.interface.constant.load[1]
-//  CHECK-DAG:   %[[SIZE_Y_VAL:.+]] = hal.interface.constant.load[2]
-//  CHECK-DAG:   %[[SIZE_X_VAL:.+]] = hal.interface.constant.load[3]
-//  CHECK-DAG:   %[[DEST_SIZE_Y_VAL:.+]] = hal.interface.constant.load[4]
-//  CHECK-DAG:   %[[DEST_SIZE_X_VAL:.+]] = hal.interface.constant.load[5]
-//  CHECK-DAG:   %[[OFFSET_Y:.+]] = arith.index_cast %[[OFFSET_Y_VAL]]
-//  CHECK-DAG:   %[[OFFSET_X:.+]] = arith.index_cast %[[OFFSET_X_VAL]]
-//  CHECK-DAG:   %[[SIZE_Y:.+]] = arith.index_cast %[[SIZE_Y_VAL]]
-//  CHECK-DAG:   %[[SIZE_X:.+]] = arith.index_cast %[[SIZE_X_VAL]]
-//  CHECK-DAG:   %[[DEST_SIZE_Y:.+]] = arith.index_cast %[[DEST_SIZE_Y_VAL]]
-//  CHECK-DAG:   %[[DEST_SIZE_X:.+]] = arith.index_cast %[[DEST_SIZE_X_VAL]]
-//  CHECK-DAG:   %[[SOURCE:.+]] = hal.interface.binding.subspan set(0) binding(0)
-//  CHECK-DAG:   %[[DEST:.+]] = hal.interface.binding.subspan set(0) binding(1)
-//  CHECK-DAG:   %[[WG_SIZE_X:.+]] = hal.interface.workgroup.size[0]
-//  CHECK-DAG:   %[[WG_SIZE_Y:.+]] = hal.interface.workgroup.size[1]
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 64)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 64)>
+//  CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0] -> (64, -d0 + s0)>
+//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"CPUTileFuseAndVectorize", workload_per_wg = [64, 64]>
+//      CHECK: hal.executable.entry_point public @matmul_tensors
+// CHECK-SAME:   translation.info = #[[TRANSLATION]]
+// CHECK-NEXT:   (%[[ARG0:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:    %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:    %[[ARG2:[a-zA-Z0-9_]+]]: index)
+//  CHECK-DAG:    %[[C1:.+]] = arith.constant 1 : index
+//  CHECK-DAG:    %[[D0:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]]
+//  CHECK-DAG:    %[[D1:.+]] = affine.apply #[[MAP0]]()[%[[ARG1]]]
+//      CHECK:    hal.return %[[D0]], %[[D1]], %[[C1]] : index, index, index
+//      CHECK: func @matmul_tensors()
+//  CHECK-DAG:   %[[M:.+]] = hal.interface.constant.load[0]
+//  CHECK-DAG:   %[[N:.+]] = hal.interface.constant.load[1]
+//  CHECK-DAG:   %[[K:.+]] = hal.interface.constant.load[2]
+//  CHECK-DAG:   %[[LHS_BINDING:.+]] = hal.interface.binding.subspan set(0) binding(0)
+//  CHECK-DAG:   %[[RHS_BINDING:.+]] = hal.interface.binding.subspan set(0) binding(1)
+//  CHECK-DAG:   %[[INIT_BINDING:.+]] = hal.interface.binding.subspan set(0) binding(2)
+//  CHECK-DAG:   %[[OUT_BINDING:.+]] = hal.interface.binding.subspan set(0) binding(3)
 //  CHECK-DAG:   %[[WG_ID_X:.+]] = hal.interface.workgroup.id[0]
 //  CHECK-DAG:   %[[WG_COUNT_X:.+]] = hal.interface.workgroup.count[0]
 //  CHECK-DAG:   %[[WG_ID_Y:.+]] = hal.interface.workgroup.id[1]
 //  CHECK-DAG:   %[[WG_COUNT_Y:.+]] = hal.interface.workgroup.count[1]
-//  CHECK-DAG:   %[[LB_Y:.+]] = affine.apply #[[MAP0]]()[%[[WG_SIZE_Y]], %[[WG_ID_Y]]]
-//  CHECK-DAG:   %[[STEP_Y:.+]] = affine.apply #[[MAP0]]()[%[[WG_SIZE_Y]], %[[WG_COUNT_Y]]]
-//      CHECK:   scf.for %[[IV0:.+]] = %[[LB_Y]] to %[[SIZE_Y]] step %[[STEP_Y]]
-//  CHECK-DAG:     %[[TILE_Y:.+]] = affine.min #[[MAP1]](%[[IV0]])[%[[WG_SIZE_Y]], %[[SIZE_Y]]]
-//  CHECK-DAG:     %[[LB_X:.+]] = affine.apply #[[MAP0]]()[%[[WG_SIZE_X]], %[[WG_ID_X]]]
-//  CHECK-DAG:     %[[STEP_X:.+]] = affine.apply #[[MAP0]]()[%[[WG_SIZE_X]], %[[WG_COUNT_X]]]
-//      CHECK:     scf.for %[[IV1:.+]] = %[[LB_X]] to %[[SIZE_X]] step %[[STEP_X]]
-//  CHECK-DAG:       %[[TILE_X:.+]] = affine.min #[[MAP1]](%[[IV1]])[%[[WG_SIZE_X]], %[[SIZE_X]]]
-//  CHECK-DAG:       %[[SOURCE_TILE:.+]] = flow.dispatch.tensor.load %[[SOURCE]]
-// CHECK-SAME:           offsets = [%[[IV0]], %[[IV1]]], sizes = [%[[TILE_Y]], %[[TILE_X]]]
-//  CHECK-DAG:       %[[DEST_OFFSET_Y:.+]] = affine.apply #[[MAP2]](%[[IV0]])[%[[OFFSET_Y]]]
-//  CHECK-DAG:       %[[DEST_OFFSET_X:.+]] = affine.apply #[[MAP2]](%[[IV1]])[%[[OFFSET_X]]]
-//      CHECK:       flow.dispatch.tensor.store %[[SOURCE_TILE]], %[[DEST]]
-// CHECK-SAME:           offsets = [%[[DEST_OFFSET_Y]], %[[DEST_OFFSET_X]]]
-// CHECK-SAME:           sizes = [%[[TILE_Y]], %[[TILE_X]]]
-
+//  CHECK-DAG:   %[[LB_Y:.+]] = affine.apply #[[MAP1]]()[%[[WG_ID_Y]]]
+//  CHECK-DAG:   %[[STEP_Y:.+]] = affine.apply #[[MAP1]]()[%[[WG_COUNT_Y]]]
+//      CHECK:   scf.for %[[IV0:.+]] = %[[LB_Y]] to %[[M]] step %[[STEP_Y]]
+//  CHECK-DAG:     %[[LB_X:.+]] = affine.apply #[[MAP1]]()[%[[WG_ID_X]]]
+//  CHECK-DAG:     %[[STEP_X:.+]] = affine.apply #[[MAP1]]()[%[[WG_COUNT_X]]]
+//      CHECK:     scf.for %[[IV1:.+]] = %[[LB_X]] to %[[N]] step %[[STEP_X]]
+//  CHECK-DAG:       %[[TILESIZE_M:.+]] = affine.min #[[MAP2]](%[[IV0]])[%[[M]]]
+//  CHECK-DAG:       %[[LHS:.+]] = flow.dispatch.tensor.load %[[LHS_BINDING]]
+// CHECK-SAME:           offsets = [%[[IV0]], 0], sizes = [%[[TILESIZE_M]], %[[K]]]
+//  CHECK-DAG:       %[[TILESIZE_N:.+]] = affine.min #[[MAP2]](%[[IV1]])[%[[N]]]
+//  CHECK-DAG:       %[[RHS:.+]] = flow.dispatch.tensor.load %[[RHS_BINDING]]
+// CHECK-SAME:           offsets = [0, %[[IV1]]], sizes = [%[[K]], %[[TILESIZE_N]]]
+//  CHECK-DAG:       %[[INIT:.+]] = flow.dispatch.tensor.load %[[INIT_BINDING]]
+// CHECK-SAME:           offsets = [%[[IV0]], %[[IV1]]], sizes = [%[[TILESIZE_M]], %[[TILESIZE_N]]]
+//      CHECK:       %[[GEMM:.+]] = linalg.matmul
+// CHECK-SAME:           ins(%[[LHS]], %[[RHS]] :
+// CHECK-SAME:           outs(%[[INIT]] :
+//      CHECK:       flow.dispatch.tensor.store %[[GEMM]], %[[OUT_BINDING]]
+// CHECK-SAME:           offsets = [%[[IV0]], %[[IV1]]], sizes = [%[[TILESIZE_M]], %[[TILESIZE_N]]]
 
 // -----
 
-func @non_tiled_reduction_fill() {
-  %zero = arith.constant 0.0 : f32
-  %c0 = arith.constant 0 : index
-  %input_binding = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(32)
-      : !flow.dispatch.tensor<readonly:1000xf32>
-  %output_binding = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(32)
-      : !flow.dispatch.tensor<writeonly:f32>
-  %input = flow.dispatch.tensor.load %input_binding, offsets = [0], sizes = [1000], strides = [1]
-      : !flow.dispatch.tensor<readonly:1000xf32> -> tensor<1000xf32>
-  %init = linalg.init_tensor [] : tensor<f32>
-  %fill = linalg.fill(%zero, %init) : f32, tensor<f32> -> tensor<f32>
-  %reduce = linalg.generic {
-        indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> ()>],
-        iterator_types = ["reduction"]}
-        ins(%input : tensor<1000xf32>) outs(%fill : tensor<f32>) {
-          ^bb0(%b0 : f32, %b1 : f32):
-            %update = arith.addf %b0, %b1 : f32
-            linalg.yield %update : f32
-        } -> tensor<f32>
-  flow.dispatch.tensor.store %reduce, %output_binding, offsets = [], sizes = [], strides = [] : tensor<f32> -> !flow.dispatch.tensor<writeonly:f32>
-  return
+#config = #iree_codegen.lowering.config<tile_sizes = [[64, 64], [1, 4], [0, 0]], native_vector_size = []>
+#executable_layout = #hal.executable.layout<push_constants = 0, sets = [
+  #hal.descriptor_set.layout<0, bindings = [
+    #hal.descriptor_set.binding<0, storage_buffer>,
+    #hal.descriptor_set.binding<1, storage_buffer>,
+    #hal.descriptor_set.binding<2, storage_buffer>
+  ]>
+]>
+#executable_target_embedded_elf_x86_64_ = #hal.executable.target<"llvm", "embedded-elf-x86_64", {
+  data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128",
+  native_vector_size = 16 : index,
+  target_triple = "x86_64-unknown-linux-gnu"
+}>
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1) -> (d1)>
+#translation = #iree_codegen.translation.info<"CPUDoubleTilingExpert", workload_per_wg = []>
+hal.executable private @add {
+  hal.executable.variant public @llvm, target = #executable_target_embedded_elf_x86_64_ {
+    hal.executable.entry_point public @add layout(#executable_layout) {translation.info = #translation}
+    builtin.module {
+      func @add() {
+        %0 = hal.interface.constant.load[0] : index
+        %1 = hal.interface.constant.load[1] : index
+        %2 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer)
+            : !flow.dispatch.tensor<readonly:?x?xf32>{%0, %1}
+        %3 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer)
+            : !flow.dispatch.tensor<readonly:?xf32>{%1}
+        %4 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer)
+            : !flow.dispatch.tensor<writeonly:?x?xf32>{%0, %1}
+        %5 = flow.dispatch.tensor.load %2, offsets = [0, 0], sizes = [%0, %1], strides = [1, 1]
+            : !flow.dispatch.tensor<readonly:?x?xf32>{%0, %1} -> tensor<?x?xf32>
+        %6 = flow.dispatch.tensor.load %3, offsets = [0], sizes = [%1], strides = [1]
+            : !flow.dispatch.tensor<readonly:?xf32>{%1} -> tensor<?xf32>
+        %7 = linalg.init_tensor [%0, %1] : tensor<?x?xf32>
+        %8 = linalg.generic {
+            indexing_maps = [#map0, #map1, #map0], iterator_types = ["parallel", "parallel"]}
+            ins(%5, %6 : tensor<?x?xf32>, tensor<?xf32>) outs(%7 : tensor<?x?xf32>)
+            attrs =  {lowering.config = #config} {
+        ^bb0(%arg0: f32, %arg1: f32, %arg2: f32):
+          %9 = arith.addf %arg0, %arg1 : f32
+          linalg.yield %9 : f32
+        } -> tensor<?x?xf32>
+        flow.dispatch.tensor.store %8, %4, offsets = [0, 0], sizes = [%0, %1], strides = [1, 1]
+            : tensor<?x?xf32> -> !flow.dispatch.tensor<writeonly:?x?xf32>{%0, %1}
+        return
+      }
+    }
+  }
 }
-// CHECK-LABEL: func @non_tiled_reduction_fill()
-//   CHECK-DAG:   %[[INPUT_BINDING:.+]] = hal.interface.binding.subspan set(0) binding(0)
-//   CHECK-DAG:   %[[OUTPUT_BINDING:.+]] = hal.interface.binding.subspan set(0) binding(1)
-//   CHECK-DAG:   %[[INPUT:.+]] = flow.dispatch.tensor.load %[[INPUT_BINDING]]
-//  CHECK-SAME:       offsets = [0], sizes = [1000]
-//       CHECK:   %[[INIT:.+]] = linalg.init_tensor []
-//       CHECK:   %[[FILL:.+]] = linalg.fill(%{{.+}}, %[[INIT]])
-//       CHECK:   %[[GENERIC:.+]] = linalg.generic
-//  CHECK-SAME:       ins(%[[INPUT]] :
-//  CHECK-SAME:       outs(%[[FILL]] :
-//       CHECK:   flow.dispatch.tensor.store %[[GENERIC]], %[[OUTPUT_BINDING]]
-//  CHECK-SAME:       offsets = [], sizes = []
-
-// -----
-
-func @multi_result() {
-  %cmin = arith.constant -2147483648 : i32
-  %czero = arith.constant 0.0 : f32
-  %c0 = arith.constant 0 : index
-  %d0_i32 = hal.interface.constant.load[0] : i32
-  %d1_i32 = hal.interface.constant.load[1] : i32
-  %d0 = arith.index_cast %d0_i32 : i32 to index
-  %d1 = arith.index_cast %d1_i32 : i32 to index
-  %input1_binding = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(32)
-      : !flow.dispatch.tensor<readonly:?x?xf32>{%d0, %d1}
-  %input2_binding = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(32)
-      : !flow.dispatch.tensor<readonly:?x?xi32>{%d0, %d1}
-  %output1_binding = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) offset(%c0) alignment(32)
-      : !flow.dispatch.tensor<writeonly:?xf32>{%d0}
-  %output2_binding = hal.interface.binding.subspan set(0) binding(3) type(storage_buffer) offset(%c0) alignment(32)
-      : !flow.dispatch.tensor<writeonly:?xi32>{%d0}
-  %input1 = flow.dispatch.tensor.load %input1_binding, offsets = [0, 0], sizes = [%d0, %d1], strides = [1, 1]
-      : !flow.dispatch.tensor<readonly:?x?xf32>{%d0, %d1} -> tensor<?x?xf32>
-  %input2 = flow.dispatch.tensor.load %input2_binding, offsets = [0, 0], sizes = [%d0, %d1], strides = [1, 1]
-      : !flow.dispatch.tensor<readonly:?x?xi32>{%d0, %d1} -> tensor<?x?xi32>
-  %init1 = linalg.init_tensor [%d0] : tensor<?xf32>
-  %init2 = linalg.init_tensor [%d0] : tensor<?xi32>
-  %fill1 = linalg.fill(%czero, %init1) : f32, tensor<?xf32> -> tensor<?xf32>
-  %fill2 = linalg.fill(%cmin, %init2) : i32, tensor<?xi32> -> tensor<?xi32>
-  %generic:2 = linalg.generic {
-      indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>,
-                       affine_map<(d0, d1) -> (d1, d0)>,
-                       affine_map<(d0, d1) -> (d0)>,
-                       affine_map<(d0, d1) -> (d0)>],
-      iterator_types = ["parallel", "reduction"]}
-      ins(%input1, %input2 : tensor<?x?xf32>, tensor<?x?xi32>)
-      outs(%fill1, %fill2 : tensor<?xf32>, tensor<?xi32>) {
-      ^bb0(%arg2: f32, %arg3: i32, %arg4: f32, %arg5: i32):  // no predecessors
-        %5 = arith.cmpf oge, %arg2, %arg4 : f32
-        %6 = arith.select %5, %arg2, %arg4 : f32
-        %7 = arith.cmpf oeq, %arg2, %arg4 : f32
-        %8 = arith.cmpi slt, %arg3, %arg5 : i32
-        %9 = arith.select %8, %arg3, %arg5 : i32
-        %10 = arith.select %5, %arg3, %arg5 : i32
-        %11 = arith.select %7, %9, %10 : i32
-        linalg.yield %6, %11 : f32, i32
-    } -> (tensor<?xf32>, tensor<?xi32>)
-  flow.dispatch.tensor.store %generic#0, %output1_binding, offsets = [0], sizes = [%d0], strides = [1]
-      : tensor<?xf32> -> !flow.dispatch.tensor<writeonly:?xf32>{%d0}
-  flow.dispatch.tensor.store %generic#1, %output2_binding, offsets = [0], sizes = [%d0], strides = [1]
-      : tensor<?xi32> -> !flow.dispatch.tensor<writeonly:?xi32>{%d0}
-  return
-}
-// CHECK-LABEL: func @multi_result()
-//   CHECK-DAG:   %[[OUT1:.+]] = hal.interface.binding.subspan set(0) binding(2)
-//   CHECK-DAG:   %[[OUT2:.+]] = hal.interface.binding.subspan set(0) binding(3)
-//       CHECK:   scf.for
-//   CHECK-NOT:     scf.for
-//       CHECK:       %[[GENERIC:.+]]:2 = linalg.generic
-//   CHECK-DAG:       flow.dispatch.tensor.store %[[GENERIC]]#0, %[[OUT1]]
-//   CHECK-DAG:       flow.dispatch.tensor.store %[[GENERIC]]#1, %[[OUT2]]
-
-// -----
-
-func @scatter() {
-  %c0 = arith.constant 0 : index
-  %d0_i32 = hal.interface.constant.load[0] : i32
-  %d1_i32 = hal.interface.constant.load[1] : i32
-  %d2_i32 = hal.interface.constant.load[2] : i32
-  %d0 = arith.index_cast %d0_i32 : i32 to index
-  %d1 = arith.index_cast %d1_i32 : i32 to index
-  %d2 = arith.index_cast %d2_i32 : i32 to index
-  %original_binding = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(32)
-      : !flow.dispatch.tensor<readonly:?x?xf32>{%d0, %d1}
-  %indices_binding = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(32)
-      : !flow.dispatch.tensor<readonly:?x1xi32>{%d0}
-  %update_binding = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) offset(%c0) alignment(32)
-      : !flow.dispatch.tensor<writeonly:?x?xf32>{%d2, %d1}
-  %original = flow.dispatch.tensor.load %original_binding, offsets = [0, 0], sizes = [%d0, %d1], strides = [1, 1]
-      : !flow.dispatch.tensor<readonly:?x?xf32>{%d0, %d1} -> tensor<?x?xf32>
-  %indices = flow.dispatch.tensor.load %indices_binding, offsets = [0, 0], sizes = [%d2, 1], strides = [1, 1]
-      : !flow.dispatch.tensor<readonly:?x1xi32>{%d0} -> tensor<?x1xi32>
-  %update = flow.dispatch.tensor.load %update_binding, offsets = [0, 0], sizes = [%d2, %d1], strides = [1, 1]
-      : !flow.dispatch.tensor<writeonly:?x?xf32>{%d2, %d1} -> tensor<?x?xf32>
-  %result = iree_linalg_ext.scatter
-      unique_indices(true)
-      ins(%original, %indices : tensor<?x?xf32>, tensor<?x1xi32>)
-      outs(%update : tensor<?x?xf32>) {
-      ^bb0(%arg0: f32, %arg1: f32):
-        %1 = arith.addf %arg0, %arg1 : f32
-        iree_linalg_ext.yield %1 : f32
-  } -> tensor<?x?xf32>
-  flow.dispatch.tensor.store %result, %update_binding, offsets = [0, 0], sizes = [%d2, %d1], strides = [1, 1]
-      : tensor<?x?xf32> -> !flow.dispatch.tensor<writeonly:?x?xf32>{%d2, %d1}
-  return
-}
-// CHECK-LABEL: func @scatter()
-//   CHECK-DAG:   %[[D0_VAL:.+]] = hal.interface.constant.load[0]
-//   CHECK-DAG:   %[[D1_VAL:.+]] = hal.interface.constant.load[1]
-//   CHECK-DAG:   %[[D2_VAL:.+]] = hal.interface.constant.load[2]
-//   CHECK-DAG:   %[[D0:.+]] = arith.index_cast %[[D0_VAL]]
-//   CHECK-DAG:   %[[D1:.+]] = arith.index_cast %[[D1_VAL]]
-//   CHECK-DAG:   %[[D2:.+]] = arith.index_cast %[[D2_VAL]]
-//   CHECK-DAG:   %[[ORIGINAL:.+]] = hal.interface.binding.subspan set(0) binding(0)
-//   CHECK-DAG:   %[[INDICES:.+]] = hal.interface.binding.subspan set(0) binding(1)
-//   CHECK-DAG:   %[[UPDATE:.+]] = hal.interface.binding.subspan set(0) binding(2)
-//       CHECK:   scf.for %[[IV0:.+]] = %{{.+}} to %[[D0]]
-//       CHECK:     scf.for %[[IV1:.+]] = %{{.+}} to %[[D1]]
-//   CHECK-DAG:       %[[ORIGINAL_TILE:.+]] = flow.dispatch.tensor.load %[[ORIGINAL]], offsets = [%[[IV0]], %[[IV1]]]
-//   CHECK-DAG:       %[[INDICES_TILE:.+]] = flow.dispatch.tensor.load %[[INDICES]], offsets = [%[[IV0]], 0]
-//   CHECK-DAG:       %[[UPDATE_TILE:.+]] = flow.dispatch.tensor.load %[[UPDATE]], offsets = [0, %[[IV1]]], sizes = [%[[D2]],
-//       CHECK:       %[[SCATTER_TILE:.+]] = iree_linalg_ext.scatter
-//  CHECK-SAME:           ins(%[[ORIGINAL_TILE]], %[[INDICES_TILE]] :
-//  CHECK-SAME:           outs(%[[UPDATE_TILE]] :
-//       CHECK:       flow.dispatch.tensor.store %[[SCATTER_TILE]], %[[UPDATE]], offsets = [0, %[[IV1]]], sizes = [%[[D2]],
-
-// -----
-
-func @sort_3d() {
-  %c0 = arith.constant 0 : index
-  %d0_i32 = hal.interface.constant.load[0] : i32
-  %d1_i32 = hal.interface.constant.load[1] : i32
-  %d2_i32 = hal.interface.constant.load[2] : i32
-  %d0 = arith.index_cast %d0_i32 : i32 to index
-  %d1 = arith.index_cast %d1_i32 : i32 to index
-  %d2 = arith.index_cast %d2_i32 : i32 to index
-  %output1_binding = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(32)
-      : !flow.dispatch.tensor<readwrite:?x?x?xf32>{%d0, %d1, %d2}
-  %output2_binding = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(32)
-      : !flow.dispatch.tensor<readwrite:?x?x?xi32>{%d0, %d1, %d2}
-  %output1 = flow.dispatch.tensor.load %output1_binding, offsets = [0, 0, 0], sizes = [%d0, %d1, %d2], strides = [1, 1, 1]
-      : !flow.dispatch.tensor<readwrite:?x?x?xf32>{%d0, %d1, %d2} -> tensor<?x?x?xf32>
-  %output2 = flow.dispatch.tensor.load %output2_binding, offsets = [0, 0, 0], sizes = [%d0, %d1, %d2], strides = [1, 1, 1]
-      : !flow.dispatch.tensor<readwrite:?x?x?xi32>{%d0, %d1, %d2} -> tensor<?x?x?xi32>
-  %result:2 = iree_linalg_ext.sort dimension(0)
-      outs(%output1, %output2 : tensor<?x?x?xf32>, tensor<?x?x?xi32>) {
-        ^bb0(%b0: f32, %b1: f32, %b2 : i32, %b3 : i32):  // no predecessors
-          %2 = arith.cmpf ogt, %b0, %b1 : f32
-          iree_linalg_ext.yield %2 : i1
-      } -> tensor<?x?x?xf32>, tensor<?x?x?xi32>
-  flow.dispatch.tensor.store %result#0, %output1_binding, offsets = [0, 0, 0], sizes = [%d0, %d1, %d2], strides = [1, 1, 1]
-      : tensor<?x?x?xf32> -> !flow.dispatch.tensor<readwrite:?x?x?xf32>{%d0, %d1, %d2}
-  flow.dispatch.tensor.store %result#1, %output2_binding, offsets = [0, 0, 0], sizes = [%d0, %d1, %d2], strides = [1, 1, 1]
-      : tensor<?x?x?xi32> -> !flow.dispatch.tensor<readwrite:?x?x?xi32>{%d0, %d1, %d2}
-  return
-}
-// CHECK-LABEL: func @sort_3d()
-//   CHECK-DAG:   %[[OUTPUT1:.+]] = hal.interface.binding.subspan set(0) binding(0)
-//   CHECK-DAG:   %[[OUTPUT2:.+]] = hal.interface.binding.subspan set(0) binding(1)
-//       CHECK:   scf.for %[[IV0:.+]] =
-//       CHECK:     scf.for %[[IV1:.+]] =
-//   CHECK-DAG:       %[[OUTPUT1_TILE:.+]] = flow.dispatch.tensor.load %[[OUTPUT1]], offsets = [0, %[[IV0]], %[[IV1]]]
-//   CHECK-DAG:       %[[OUTPUT2_TILE:.+]] = flow.dispatch.tensor.load %[[OUTPUT2]], offsets = [0, %[[IV0]], %[[IV1]]]
-//       CHECK:       %[[SORT_TILE:.+]]:2 = iree_linalg_ext.sort dimension(0)
-//  CHECK-SAME:           outs(%[[OUTPUT1_TILE]], %[[OUTPUT2_TILE]] :
-//   CHECK-DAG:       flow.dispatch.tensor.store %[[SORT_TILE]]#0, %[[OUTPUT1]], offsets = [0, %[[IV0]], %[[IV1]]]
-//   CHECK-DAG:       flow.dispatch.tensor.store %[[SORT_TILE]]#1, %[[OUTPUT2]], offsets = [0, %[[IV0]], %[[IV1]]]
-
-// -----
-
-func @extract_slice() {
-  %c0 = arith.constant 0 : index
-  %offset_y_i32 = hal.interface.constant.load[0] : i32
-  %offset_x_i32 = hal.interface.constant.load[1] : i32
-  %size_y_i32 = hal.interface.constant.load[2] : i32
-  %size_x_i32 = hal.interface.constant.load[3] : i32
-  %source_size_y_i32 = hal.interface.constant.load[4] : i32
-  %source_size_x_i32 = hal.interface.constant.load[5] : i32
-  %offset_y = arith.index_cast %offset_y_i32 : i32 to index
-  %offset_x = arith.index_cast %offset_x_i32 : i32 to index
-  %size_y = arith.index_cast %size_y_i32 : i32 to index
-  %size_x = arith.index_cast %size_x_i32 : i32 to index
-  %source_size_y = arith.index_cast %source_size_y_i32 : i32 to index
-  %source_size_x = arith.index_cast %source_size_x_i32 : i32 to index
-  %source_binding = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(32)
-      : !flow.dispatch.tensor<readonly:?x?xf32>{%source_size_y, %source_size_x}
-  %result_binding = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(32)
-      : !flow.dispatch.tensor<writeonly:?x?xf32>{%size_y, %size_x}
-  %source = flow.dispatch.tensor.load %source_binding, offsets = [0, 0], sizes = [%source_size_y, %source_size_x], strides = [1, 1]
-      : !flow.dispatch.tensor<readonly:?x?xf32>{%source_size_y, %source_size_x} -> tensor<?x?xf32>
-  %slice = tensor.extract_slice %source[%offset_y, %offset_x] [%size_y, %size_x] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
-  flow.dispatch.tensor.store %slice, %result_binding, offsets = [0, 0], sizes = [%size_y, %size_x], strides = [1, 1]
-      : tensor<?x?xf32> -> !flow.dispatch.tensor<writeonly:?x?xf32>{%size_y, %size_x}
-  return
-}
-//  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> (s1 * s0)>
-//  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (s0, -d0 + s1)>
-//  CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0] -> (d0 + s0)>
-//      CHECK: func @extract_slice()
-//  CHECK-DAG:   %[[OFFSET_Y_VAL:.+]] = hal.interface.constant.load[0]
-//  CHECK-DAG:   %[[OFFSET_X_VAL:.+]] = hal.interface.constant.load[1]
-//  CHECK-DAG:   %[[OFFSET_Y:.+]] = arith.index_cast %[[OFFSET_Y_VAL]]
-//  CHECK-DAG:   %[[OFFSET_X:.+]] = arith.index_cast %[[OFFSET_X_VAL]]
-//  CHECK-DAG:   %[[SOURCE:.+]] = hal.interface.binding.subspan set(0) binding(0)
-//  CHECK-DAG:   %[[RESULT:.+]] = hal.interface.binding.subspan set(0) binding(1)
+//  CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0] -> (s0 ceildiv 64)>
+//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"CPUDoubleTilingExpert", workload_per_wg = [64, 64]>
+//      CHECK: hal.executable private @add
+//      CHECK: hal.executable.entry_point public @add
+// CHECK-SAME:  translation.info = #[[TRANSLATION]]
+// CHECK-NEXT:  (%[[ARG0:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: index)
+//  CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+//  CHECK-DAG:   %[[D0:.+]] = affine.apply #[[MAP]]()[%[[ARG0]]]
+//  CHECK-DAG:   %[[D1:.+]] = affine.apply #[[MAP]]()[%[[ARG1]]]
+//      CHECK:   hal.return %[[D0]], %[[D1]], %[[C1]] : index, index, index
+//      CHECK: func @add()
 //      CHECK:   scf.for %[[IV0:.+]] =
 //      CHECK:     scf.for %[[IV1:.+]] =
-//  CHECK-DAG:       %[[TILE_OFFSET_Y:.+]] = affine.apply #[[MAP2]](%[[IV0]])[%[[OFFSET_Y]]]
-//  CHECK-DAG:       %[[TILE_OFFSET_X:.+]] = affine.apply #[[MAP2]](%[[IV1]])[%[[OFFSET_X]]]
-//      CHECK:       %[[TILE_SLICE:.+]] = flow.dispatch.tensor.load %[[SOURCE]], offsets = [%[[TILE_OFFSET_Y]], %[[TILE_OFFSET_X]]]
-//      CHECK:       flow.dispatch.tensor.store %[[TILE_SLICE]], %[[RESULT]], offsets = [%[[IV0]], %[[IV1]]]
+//      CHECK:       %[[RESULT:.+]] = linalg.generic
+//      CHECK:       flow.dispatch.tensor.store %[[RESULT]], %{{.+}}, offsets = [%[[IV0]], %[[IV1]]]
 
 // -----
 
-func @gemm_unitN() {
-  %cst = arith.constant 0.000000e+00 : f32
-  %c0 = arith.constant 0 : index
-  %0 = hal.interface.constant.load[0] : i32
-  %1 = hal.interface.constant.load[1] : i32
-  %4 = arith.index_cast %0 : i32 to index
-  %5 = arith.index_cast %1 : i32 to index
-  %8 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(32) : !flow.dispatch.tensor<readonly:?x?xf32>{%4, %5}
-  %9 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(32) : !flow.dispatch.tensor<readonly:?x1xf32>{%5}
-  %10 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) offset(%c0) alignment(32) : !flow.dispatch.tensor<writeonly:?x1xf32>{%4}
-  %11 = flow.dispatch.tensor.load %8, offsets = [0, 0], sizes = [%4, %5], strides = [1, 1] : !flow.dispatch.tensor<readonly:?x?xf32>{%4, %5} -> tensor<?x?xf32>
-  %12 = flow.dispatch.tensor.load %9, offsets = [0, 0], sizes = [%5, 1], strides = [1, 1] : !flow.dispatch.tensor<readonly:?x1xf32>{%5} -> tensor<?x1xf32>
-  %13 = linalg.init_tensor [%4, 1] : tensor<?x1xf32>
-  %14 = linalg.fill(%cst, %13) : f32, tensor<?x1xf32> -> tensor<?x1xf32>
-  %15 = linalg.matmul ins(%11, %12 : tensor<?x?xf32>, tensor<?x1xf32>) outs(%14 : tensor<?x1xf32>) -> tensor<?x1xf32>
-  flow.dispatch.tensor.store %15, %10, offsets = [0, 0], sizes = [%4, 1], strides = [1, 1] : tensor<?x1xf32> -> !flow.dispatch.tensor<writeonly:?x1xf32>{%4}
-  return
+#config = #iree_codegen.lowering.config<tile_sizes = [[0, 64, 64, 64], [1, 1, 1, 4], [0, 0, 0, 0]], native_vector_size = []>
+#executable_layout = #hal.executable.layout<push_constants = 0, sets = [
+  #hal.descriptor_set.layout<0, bindings = [
+    #hal.descriptor_set.binding<0, storage_buffer>,
+    #hal.descriptor_set.binding<1, storage_buffer>,
+    #hal.descriptor_set.binding<2, storage_buffer>,
+    #hal.descriptor_set.binding<3, storage_buffer>
+  ]>
+]>
+#executable_target_embedded_elf_x86_64_ = #hal.executable.target<"llvm", "embedded-elf-x86_64", {
+  data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128",
+  native_vector_size = 16 : index,
+  target_triple = "x86_64-unknown-linux-gnu"}>
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+#translation = #iree_codegen.translation.info<"CPUDoubleTilingExpert", workload_per_wg = []>
+hal.executable private @add4D {
+  hal.executable.variant public @llvm, target = #executable_target_embedded_elf_x86_64_ {
+    hal.executable.entry_point public @add4D layout(#executable_layout) {translation.info = #translation}
+    builtin.module {
+      func @add4D() {
+        %0 = hal.interface.constant.load[0] : index
+        %1 = hal.interface.constant.load[1] : index
+        %2 = hal.interface.constant.load[2] : index
+        %3 = hal.interface.constant.load[3] : index
+        %4 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(32)
+            : !flow.dispatch.tensor<readonly:?x?x?x?xf32>{%0, %1, %2, %3}
+        %5 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(32)
+            : !flow.dispatch.tensor<readonly:?x?x?x?xf32>{%0, %1, %2, %3}
+        %6 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(32)
+            : !flow.dispatch.tensor<writeonly:?x?x?x?xf32>{%0, %1, %2, %3}
+        %7 = flow.dispatch.tensor.load %4, offsets = [0, 0, 0, 0], sizes = [%0, %1, %2, %3], strides = [1, 1, 1, 1]
+            : !flow.dispatch.tensor<readonly:?x?x?x?xf32>{%0, %1, %2, %3} -> tensor<?x?x?x?xf32>
+        %8 = flow.dispatch.tensor.load %5, offsets = [0, 0, 0, 0], sizes = [%0, %1, %2, %3], strides = [1, 1, 1, 1]
+            : !flow.dispatch.tensor<readonly:?x?x?x?xf32>{%0, %1, %2, %3} -> tensor<?x?x?x?xf32>
+        %9 = linalg.init_tensor [%0, %1, %2, %3] : tensor<?x?x?x?xf32>
+        %10 = linalg.generic {
+            indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+            ins(%7, %8 : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) outs(%9 : tensor<?x?x?x?xf32>) attrs =  {lowering.config = #config} {
+        ^bb0(%arg0: f32, %arg1: f32, %arg2: f32):
+          %11 = arith.addf %arg0, %arg1 : f32
+          linalg.yield %11 : f32
+        } -> tensor<?x?x?x?xf32>
+        flow.dispatch.tensor.store %10, %6, offsets = [0, 0, 0, 0], sizes = [%0, %1, %2, %3], strides = [1, 1, 1, 1]
+            : tensor<?x?x?x?xf32> -> !flow.dispatch.tensor<writeonly:?x?x?x?xf32>{%0, %1, %2, %3}
+        return
+      }
+    }
+  }
 }
-// CHECK-LABEL: func @gemm_unitN()
-//   CHECK-DAG:   %[[M_VAL:.+]] = hal.interface.constant.load[0]
-//   CHECK-DAG:   %[[M:.+]] = arith.index_cast %[[M_VAL]] : i32 to index
-//       CHECK:   scf.for %[[IV0:.+]] = %{{.+}} to %[[M]]
-//   CHECK-NOT:   scf.for
-//       CHECK:     linalg.fill
-//       CHECK:     linalg.matmul
-//       CHECK:     flow.dispatch.tensor.store
+//  CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0] -> (s0 ceildiv 64)>
+//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"CPUDoubleTilingExpert", workload_per_wg = [64, 64, 64]>
+//      CHECK: hal.executable.entry_point public @add4D
+// CHECK-SAME:   translation.info = #[[TRANSLATION]]
+// CHECK-NEXT:   (%[[ARG0:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:    %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:    %[[ARG2:[a-zA-Z0-9_]+]]: index)
+//  CHECK-DAG:    %[[D0:.+]] = affine.apply #[[MAP]]()[%[[ARG0]]]
+//  CHECK-DAG:    %[[D1:.+]] = affine.apply #[[MAP]]()[%[[ARG1]]]
+//  CHECK-DAG:    %[[D2:.+]] = affine.apply #[[MAP]]()[%[[ARG2]]]
+//      CHECK:    hal.return %[[D0]], %[[D1]], %[[D2]] : index, index, index
+//      CHECK: func @add4D()
+//      CHECK:   scf.for %[[IV0:.+]] =
+//      CHECK:     scf.for %[[IV1:.+]] =
+//      CHECK:       scf.for %[[IV2:.+]] =
+//  CHECK-NOT:         scf.for
+//      CHECK:         %[[GENERIC:.+]] = linalg.generic
+//      CHECK:         flow.dispatch.tensor.store %[[GENERIC]], %{{.+}}, offsets = [0, %[[IV0]], %[[IV1]], %[[IV2]]]
 
 // -----
 
-func @gemm_unitM_unitN() {
-  %cst = arith.constant 0.000000e+00 : f32
-  %c0 = arith.constant 0 : index
-  %1 = hal.interface.constant.load[0] : i32
-  %5 = arith.index_cast %1 : i32 to index
-  %8 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(32) : !flow.dispatch.tensor<readonly:1x?xf32>{%5}
-  %9 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(32) : !flow.dispatch.tensor<readonly:?x1xf32>{%5}
-  %10 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) offset(%c0) alignment(32) : !flow.dispatch.tensor<writeonly:1x1xf32>
-  %11 = flow.dispatch.tensor.load %8, offsets = [0, 0], sizes = [1, %5], strides = [1, 1] : !flow.dispatch.tensor<readonly:1x?xf32>{%5} -> tensor<1x?xf32>
-  %12 = flow.dispatch.tensor.load %9, offsets = [0, 0], sizes = [%5, 1], strides = [1, 1] : !flow.dispatch.tensor<readonly:?x1xf32>{%5} -> tensor<?x1xf32>
-  %13 = linalg.init_tensor [1, 1] : tensor<1x1xf32>
-  %14 = linalg.fill(%cst, %13) : f32, tensor<1x1xf32> -> tensor<1x1xf32>
-  %15 = linalg.matmul ins(%11, %12 : tensor<1x?xf32>, tensor<?x1xf32>) outs(%14 : tensor<1x1xf32>) -> tensor<1x1xf32>
-  flow.dispatch.tensor.store %15, %10, offsets = [0, 0], sizes = [1, 1], strides = [1, 1] : tensor<1x1xf32> -> !flow.dispatch.tensor<writeonly:1x1xf32>
-  return
+#config = #iree_codegen.lowering.config<tile_sizes = [[1, 64, 64, 0], [1, 16, 4, 64], [1, 4, 4, 4]], native_vector_size = [1, 4, 4, 4]>
+#executable_layout = #hal.executable.layout<push_constants = 0, sets = [
+  #hal.descriptor_set.layout<0, bindings = [
+    #hal.descriptor_set.binding<0, storage_buffer>,
+    #hal.descriptor_set.binding<1, storage_buffer>,
+    #hal.descriptor_set.binding<2, storage_buffer>
+  ]>
+]>
+#executable_target_embedded_elf_arm_64_ = #hal.executable.target<"llvm", "embedded-elf-arm_64", {
+  data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128",
+  native_vector_size = 16 : index,
+  target_triple = "aarch64-unknown-unknown-eabi-elf"}>
+#translation = #iree_codegen.translation.info<"CPUTileFuseAndVectorize", workload_per_wg = []>
+hal.executable private @batch_matmul_tensors {
+  hal.executable.variant public @llvm, target = #executable_target_embedded_elf_arm_64_ {
+    hal.executable.entry_point public @batch_matmul_tensors layout(#executable_layout) {translation.info = #translation}
+    builtin.module {
+      func @batch_matmul_tensors() {
+        %cst = arith.constant 0.000000e+00 : f32
+        %0 = hal.interface.constant.load[0] : index
+        %1 = hal.interface.constant.load[1] : index
+        %2 = hal.interface.constant.load[2] : index
+        %3 = hal.interface.constant.load[3] : index
+        %4 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(32)
+            : !flow.dispatch.tensor<readonly:?x?x?xf32>{%0, %1, %3}
+        %5 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(32)
+            : !flow.dispatch.tensor<readonly:?x?x?xf32>{%0, %3, %2}
+        %6 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(32)
+            : !flow.dispatch.tensor<writeonly:?x?x?xf32>{%0, %1, %2}
+        %7 = flow.dispatch.tensor.load %4, offsets = [0, 0, 0], sizes = [%0, %1, %3], strides = [1, 1, 1]
+            : !flow.dispatch.tensor<readonly:?x?x?xf32>{%0, %1, %3} -> tensor<?x?x?xf32>
+        %8 = flow.dispatch.tensor.load %5, offsets = [0, 0, 0], sizes = [%0, %3, %2], strides = [1, 1, 1]
+            : !flow.dispatch.tensor<readonly:?x?x?xf32>{%0, %3, %2} -> tensor<?x?x?xf32>
+        %9 = linalg.init_tensor [%0, %1, %2] : tensor<?x?x?xf32>
+        %10 = linalg.fill(%cst, %9) : f32, tensor<?x?x?xf32> -> tensor<?x?x?xf32>
+        %11 = linalg.batch_matmul {lowering.config = #config}
+            ins(%7, %8 : tensor<?x?x?xf32>, tensor<?x?x?xf32>) outs(%10 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+        flow.dispatch.tensor.store %11, %6, offsets = [0, 0, 0], sizes = [%0, %1, %2], strides = [1, 1, 1]
+            : tensor<?x?x?xf32> -> !flow.dispatch.tensor<writeonly:?x?x?xf32>{%0, %1, %2}
+        return
+      }
+    }
+  }
 }
-// CHECK-LABEL: func @gemm_unitM_unitN()
-//   CHECK-NOT:   scf.for
-//       CHECK:   linalg.fill
-//       CHECK:   linalg.matmul
-//       CHECK:   flow.dispatch.tensor.store
-//       CHECK:   return
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 64)>
+//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"CPUTileFuseAndVectorize", workload_per_wg = [64, 64, 1]>
+//      CHECK: hal.executable.entry_point public @batch_matmul_tensors
+// CHECK-NEXT:   (%[[ARG0:[a-zA-Z0-9]+]]: index
+// CHECK-SAME:    %[[ARG1:[a-zA-Z0-9]+]]: index
+// CHECK-SAME:    %[[ARG2:[a-zA-Z0-9]+]]: index)
+//  CHECK-DAG:   %[[D0:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]]
+//  CHECK-DAG:   %[[D1:.+]] = affine.apply #[[MAP0]]()[%[[ARG1]]]
+//      CHECK:   hal.return %[[D0]], %[[D1]], %[[ARG2]]
+//      CHECK: func @batch_matmul_tensors()
+//      CHECK:   scf.for %[[IV0:.+]] =
+//      CHECK:     scf.for %[[IV1:.+]] =
+//      CHECK:       scf.for %[[IV2:.+]] =
+//  CHECK-NOT:         scf.for
+//      CHECK:         %[[BATCH_GEMM:.+]] = linalg.batch_matmul
+//      CHECK:         flow.dispatch.tensor.store %[[BATCH_GEMM]]
+// CHECK-SAME:             offsets = [%[[IV0]], %[[IV1]], %[[IV2]]], sizes = [1, %{{.+}}, %{{.+}}]
 
 // -----
 
-func @gemm_unitM() {
-  %cst = arith.constant 0.000000e+00 : f32
-  %c0 = arith.constant 0 : index
-  %1 = hal.interface.constant.load[0] : i32
-  %2 = hal.interface.constant.load[1] : i32
-  %5 = arith.index_cast %1 : i32 to index
-  %6 = arith.index_cast %2 : i32 to index
-  %8 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(32) : !flow.dispatch.tensor<readonly:1x?xf32>{%5}
-  %9 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(32) : !flow.dispatch.tensor<readonly:?x?xf32>{%5, %6}
-  %10 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) offset(%c0) alignment(32) : !flow.dispatch.tensor<writeonly:1x?xf32>{%6}
-  %11 = flow.dispatch.tensor.load %8, offsets = [0, 0], sizes = [1, %5], strides = [1, 1] : !flow.dispatch.tensor<readonly:1x?xf32>{%5} -> tensor<1x?xf32>
-  %12 = flow.dispatch.tensor.load %9, offsets = [0, 0], sizes = [%5, %6], strides = [1, 1] : !flow.dispatch.tensor<readonly:?x?xf32>{%5, %6} -> tensor<?x?xf32>
-  %13 = linalg.init_tensor [1, %6] : tensor<1x?xf32>
-  %14 = linalg.fill(%cst, %13) : f32, tensor<1x?xf32> -> tensor<1x?xf32>
-  %15 = linalg.matmul ins(%11, %12 : tensor<1x?xf32>, tensor<?x?xf32>) outs(%14 : tensor<1x?xf32>) -> tensor<1x?xf32>
-  flow.dispatch.tensor.store %15, %10, offsets = [0, 0], sizes = [1, %6], strides = [1, 1] : tensor<1x?xf32> -> !flow.dispatch.tensor<writeonly:1x?xf32>{%6}
-  return
+#config = #iree_codegen.lowering.config<tile_sizes = [[32, 16, 0], [16, 8, 0], [0, 0, 2]], native_vector_size = []>
+#executable_layout = #hal.executable.layout<push_constants = 0, sets = [
+  #hal.descriptor_set.layout<0, bindings = [
+    #hal.descriptor_set.binding<0, storage_buffer>,
+    #hal.descriptor_set.binding<1, storage_buffer>,
+    #hal.descriptor_set.binding<2, storage_buffer>
+  ]>
+]>
+#executable_target_system_elf_x86_64_ = #hal.executable.target<"llvm", "system-elf-x86_64">
+#translation = #iree_codegen.translation.info<"CPUDoubleTilingExpert", workload_per_wg = []>
+hal.executable private @preset_config_matmul_tensors {
+  hal.executable.variant public @system_elf_x86_64, target = #executable_target_system_elf_x86_64_ {
+    hal.executable.entry_point public @preset_config layout(#executable_layout) {translation.info = #translation}
+    builtin.module {
+      func @preset_config() {
+        %cst = arith.constant 0.000000e+00 : f32
+        %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer)
+            : !flow.dispatch.tensor<readonly:128x256xf32>
+        %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer)
+            : !flow.dispatch.tensor<readonly:256x512xf32>
+        %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer)
+            : !flow.dispatch.tensor<writeonly:128x512xf32>
+        %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [128, 256], strides = [1, 1]
+            : !flow.dispatch.tensor<readonly:128x256xf32> -> tensor<128x256xf32>
+        %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [256, 512], strides = [1, 1]
+            : !flow.dispatch.tensor<readonly:256x512xf32> -> tensor<256x512xf32>
+        %5 = linalg.init_tensor [128, 512] : tensor<128x512xf32>
+        %6 = linalg.fill(%cst, %5) : f32, tensor<128x512xf32> -> tensor<128x512xf32>
+        %7 = linalg.matmul {lowering.config = #config}
+            ins(%3, %4 : tensor<128x256xf32>, tensor<256x512xf32>) outs(%6 : tensor<128x512xf32>) -> tensor<128x512xf32>
+        flow.dispatch.tensor.store %7, %2, offsets = [0, 0], sizes = [128, 512], strides = [1, 1]
+            : tensor<128x512xf32> -> !flow.dispatch.tensor<writeonly:128x512xf32>
+        return
+      }
+    }
+  }
 }
-// CHECK-LABEL: func @gemm_unitM()
-//   CHECK-DAG:   %[[N_VAL:.+]] = hal.interface.constant.load[1]
-//   CHECK-DAG:   %[[N:.+]] = arith.index_cast %[[N_VAL]] : i32 to index
-//       CHECK:   scf.for %[[IV0:.+]] = %{{.+}} to %[[N]]
-//   CHECK-NOT:   scf.for
-//       CHECK:     linalg.fill
-//       CHECK:     linalg.matmul
-//       CHECK:     flow.dispatch.tensor.store
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 16)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 ceildiv 32)>
+//  CHECK-DAG: #[[MAP2:.+]] = affine_map<()[s0] -> (s0 * 32)>
+//  CHECK-DAG: #[[MAP3:.+]] = affine_map<()[s0] -> (s0 * 16)>
+//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"CPUDoubleTilingExpert", workload_per_wg = [16, 32]>
+//      CHECK: hal.executable.entry_point public @preset_config
+// CHECK-NEXT:   (%[[ARG0:[a-zA-Z0-9]+]]: index
+// CHECK-SAME:    %[[ARG1:[a-zA-Z0-9]+]]: index
+// CHECK-SAME:    %[[ARG2:[a-zA-Z0-9]+]]: index)
+//  CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+//  CHECK-DAG:   %[[D0:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]]
+//  CHECK-DAG:   %[[D1:.+]] = affine.apply #[[MAP1]]()[%[[ARG1]]]
+//      CHECK:   hal.return %[[D0]], %[[D1]], %[[C1]]
+//      CHECK: func @preset_config()
+//      CHECK:   scf.for %[[IV0:.+]] =
+//      CHECK:     scf.for %[[IV1:.+]] =
+//  CHECK-DAG:       %[[LHS:.+]] = flow.dispatch.tensor.load %{{.+}}, offsets = [%[[IV0]], 0], sizes = [32, 256]
+//  CHECK-DAG:       %[[RHS:.+]] = flow.dispatch.tensor.load %{{.+}}, offsets = [0, %[[IV1]]], sizes = [256, 16]
+//  CHECK-DAG:       %[[INIT:.+]] = linalg.init_tensor [32, 16]
+//  CHECK-DAG:       %[[FILL:.+]] = linalg.fill(%{{.+}}, %[[INIT]])
+//  CHECK-DAG:       %[[GEMM:.+]] = linalg.matmul
+// CHECK-SAME:           outs(%[[FILL]] :
+//      CHECK:       flow.dispatch.tensor.store %[[GEMM]]
+// CHECK-SAME:           offsets = [%[[IV0]], %[[IV1]]], sizes = [32, 16]
 
 // -----
 
+#config = #iree_codegen.lowering.config<tile_sizes = [[64, 64]], native_vector_size = []>
+#executable_layout = #hal.executable.layout<push_constants = 0, sets = [
+  #hal.descriptor_set.layout<0, bindings = [
+    #hal.descriptor_set.binding<0, storage_buffer>,
+    #hal.descriptor_set.binding<1, storage_buffer>
+  ]>
+]>
+#executable_target_system_elf_x86_64_ = #hal.executable.target<"llvm", "system-elf-x86_64">
+#translation = #iree_codegen.translation.info<"CPUDefault", workload_per_wg = []>
+hal.executable public @tensor_insert {
+  hal.executable.variant public @system_elf_x86_64, target = #executable_target_system_elf_x86_64_ {
+    hal.executable.entry_point public @tensor_insert_slice layout(#executable_layout) {translation.info = #translation}
+    builtin.module {
+      func @tensor_insert_slice() {
+        %0 = hal.interface.constant.load[0] : index
+        %1 = hal.interface.constant.load[1] : index
+        %2 = hal.interface.constant.load[2] : index
+        %3 = hal.interface.constant.load[3] : index
+        %4 = hal.interface.constant.load[4] : index
+        %5 = hal.interface.constant.load[5] : index
+        %6 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer)
+            : !flow.dispatch.tensor<readonly:?x?xi32>{%0, %1}
+        %7 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer)
+            : !flow.dispatch.tensor<readwrite:?x?xi32>{%2, %3}
+        %8 = flow.dispatch.tensor.load %6, offsets = [0, 0], sizes = [%0, %1], strides = [1, 1]
+            : !flow.dispatch.tensor<readonly:?x?xi32>{%0, %1} -> tensor<?x?xi32>
+        %9 = flow.dispatch.tensor.load %7, offsets = [0, 0], sizes = [%0, %1], strides = [1, 1]
+            : !flow.dispatch.tensor<readwrite:?x?xi32>{%2, %3} -> tensor<?x?xi32>
+        %10 = tensor.insert_slice %8 into %9[%4, %5] [%0, %1] [1, 1] {lowering.config = #config} : tensor<?x?xi32> into tensor<?x?xi32>
+        flow.dispatch.tensor.store %10, %7, offsets = [0, 0], sizes = [%2, %3], strides = [1, 1]
+            : tensor<?x?xi32> -> !flow.dispatch.tensor<readwrite:?x?xi32>{%2, %3}
+        return
+      }
+    }
+  }
+}
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 64)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 64)>
+//  CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0] -> (64, -d0 + s0)>
+//  CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0)[s0] -> (d0 + s0)>
+//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"CPUDefault", workload_per_wg = [64, 64]>
+//      CHECK: hal.executable.entry_point public @tensor_insert
+// CHECK-SAME:     translation.info = #[[TRANSLATION]]
+// CHECK-NEXT:   (%[[ARG0:[a-zA-Z0-9]+]]: index
+// CHECK-SAME:    %[[ARG1:[a-zA-Z0-9]+]]: index
+// CHECK-SAME:    %[[ARG2:[a-zA-Z0-9]+]]: index)
+//  CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+//  CHECK-DAG:   %[[D0:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]]
+//  CHECK-DAG:   %[[D1:.+]] = affine.apply #[[MAP0]]()[%[[ARG1]]]
+//      CHECK:   hal.return %[[D0]], %[[D1]], %[[C1]]
+//      CHECK: func @tensor_insert_slice()
+//  CHECK-DAG:   %[[SIZE_Y:.+]] = hal.interface.constant.load[0] : index
+//  CHECK-DAG:   %[[SIZE_X:.+]] = hal.interface.constant.load[1] : index
+//  CHECK-DAG:   %[[DEST_SIZE_Y:.+]] = hal.interface.constant.load[2] : index
+//  CHECK-DAG:   %[[DEST_SIZE_X:.+]] = hal.interface.constant.load[3] : index
+//  CHECK-DAG:   %[[OFFSET_Y:.+]] = hal.interface.constant.load[4] : index
+//  CHECK-DAG:   %[[OFFSET_X:.+]] = hal.interface.constant.load[5] : index
+//  CHECK-DAG:   %[[WG_ID_X:.+]] = hal.interface.workgroup.id[0]
+//  CHECK-DAG:   %[[WG_COUNT_X:.+]] = hal.interface.workgroup.count[0]
+//  CHECK-DAG:   %[[WG_ID_Y:.+]] = hal.interface.workgroup.id[1]
+//  CHECK-DAG:   %[[WG_COUNT_Y:.+]] = hal.interface.workgroup.count[1]
+//  CHECK-DAG:   %[[LB_Y:.+]] = affine.apply #[[MAP1]]()[%[[WG_ID_Y]]]
+//  CHECK-DAG:   %[[STEP_Y:.+]] = affine.apply #[[MAP1]]()[%[[WG_COUNT_Y]]]
+//      CHECK:   scf.for %[[IV0:.+]] = %[[LB_Y]] to %[[SIZE_Y]] step %[[STEP_Y]]
+//  CHECK-DAG:     %[[TILESIZE_Y:.+]] = affine.min #[[MAP2]](%[[ARG0]])[%[[SIZE_Y]]]
+//  CHECK-DAG:     %[[LB_X:.+]] = affine.apply #[[MAP1]]()[%[[WG_ID_X]]]
+//  CHECK-DAG:     %[[STEP_X:.+]] = affine.apply #[[MAP1]]()[%[[WG_COUNT_X]]]
+//      CHECK:     scf.for %[[IV1:.+]] = %[[LB_X]] to %[[SIZE_X]] step %[[STEP_X]]
+//  CHECK-DAG:       %[[TILESIZE_X:.+]] = affine.min #[[MAP2]](%[[ARG1]])[%[[SIZE_X]]]
+//  CHECK-DAG:       %[[SOURCE:.+]] = flow.dispatch.tensor.load
+// CHECK-SAME:           offsets = [%[[IV0]], %[[IV1]]], sizes = [%[[TILESIZE_Y]], %[[TILESIZE_X]]]
+//  CHECK-DAG:       %[[STORE_OFFSET_Y:.+]] = affine.apply #[[MAP3]](%[[IV0]])[%[[OFFSET_Y]]]
+//  CHECK-DAG:       %[[STORE_OFFSET_X:.+]] = affine.apply #[[MAP3]](%[[IV1]])[%[[OFFSET_X]]]
+//      CHECK:       flow.dispatch.tensor.store %[[SOURCE]]
+// CHECK-SAME:           offsets = [%[[STORE_OFFSET_Y]], %[[STORE_OFFSET_X]]], sizes = [%[[TILESIZE_Y]], %[[TILESIZE_X]]]
+
+// -----
+
+#config = #iree_codegen.lowering.config<tile_sizes = [[64, 64]], native_vector_size = []>
+#executable_layout = #hal.executable.layout<push_constants = 0, sets = [
+  #hal.descriptor_set.layout<0, bindings = [
+    #hal.descriptor_set.binding<0, storage_buffer>,
+    #hal.descriptor_set.binding<1, storage_buffer>
+  ]>
+]>
+#executable_target_system_elf_x86_64_ = #hal.executable.target<"llvm", "system-elf-x86_64">
+#translation = #iree_codegen.translation.info<"CPUDefault", workload_per_wg = []>
+hal.executable public @extract_slice {
+  hal.executable.variant public @system_elf_x86_64, target = #executable_target_system_elf_x86_64_ {
+    hal.executable.entry_point public @extract_slice layout(#executable_layout) {translation.info = #translation}
+    builtin.module {
+      func @extract_slice() {
+        %0 = hal.interface.constant.load[0] : index
+        %1 = hal.interface.constant.load[1] : index
+        %2 = hal.interface.constant.load[2] : index
+        %3 = hal.interface.constant.load[3] : index
+        %4 = hal.interface.constant.load[4] : index
+        %5 = hal.interface.constant.load[5] : index
+        %6 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer)
+            : !flow.dispatch.tensor<readonly:?x?xi32>{%0, %1}
+        %7 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer)
+            : !flow.dispatch.tensor<writeonly:?x?xi32>{%2, %3}
+        %8 = flow.dispatch.tensor.load %6, offsets = [0, 0], sizes = [%0, %1], strides = [1, 1]
+            : !flow.dispatch.tensor<readonly:?x?xi32>{%0, %1} -> tensor<?x?xi32>
+        %9 = tensor.extract_slice %8[%4, %5] [%2, %3] [1, 1] {lowering.config = #config} : tensor<?x?xi32> to tensor<?x?xi32>
+        flow.dispatch.tensor.store %9, %7, offsets = [0, 0], sizes = [%2, %3], strides = [1, 1]
+            : tensor<?x?xi32> -> !flow.dispatch.tensor<writeonly:?x?xi32>{%2, %3}
+        return
+      }
+    }
+  }
+}
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 64)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 64)>
+//  CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0] -> (64, -d0 + s0)>
+//  CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0)[s0] -> (d0 + s0)>
+//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"CPUDefault", workload_per_wg = [64, 64]>
+//      CHECK: hal.executable.entry_point public @extract_slice
+// CHECK-SAME:     translation.info = #[[TRANSLATION]]
+// CHECK-NEXT:   (%[[ARG0:[a-zA-Z0-9]+]]: index
+// CHECK-SAME:    %[[ARG1:[a-zA-Z0-9]+]]: index
+// CHECK-SAME:    %[[ARG2:[a-zA-Z0-9]+]]: index)
+//  CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+//  CHECK-DAG:   %[[D0:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]]
+//  CHECK-DAG:   %[[D1:.+]] = affine.apply #[[MAP0]]()[%[[ARG1]]]
+//      CHECK:   hal.return %[[D0]], %[[D1]], %[[C1]]
+//      CHECK: func @extract_slice()
+//  CHECK-DAG:   %[[SOURCE_SIZE_Y:.+]] = hal.interface.constant.load[0] : index
+//  CHECK-DAG:   %[[SOURCE_SIZE_X:.+]] = hal.interface.constant.load[1] : index
+//  CHECK-DAG:   %[[SIZE_Y:.+]] = hal.interface.constant.load[2] : index
+//  CHECK-DAG:   %[[SIZE_X:.+]] = hal.interface.constant.load[3] : index
+//  CHECK-DAG:   %[[OFFSET_Y:.+]] = hal.interface.constant.load[4] : index
+//  CHECK-DAG:   %[[OFFSET_X:.+]] = hal.interface.constant.load[5] : index
+//  CHECK-DAG:   %[[WG_ID_X:.+]] = hal.interface.workgroup.id[0]
+//  CHECK-DAG:   %[[WG_COUNT_X:.+]] = hal.interface.workgroup.count[0]
+//  CHECK-DAG:   %[[WG_ID_Y:.+]] = hal.interface.workgroup.id[1]
+//  CHECK-DAG:   %[[WG_COUNT_Y:.+]] = hal.interface.workgroup.count[1]
+//  CHECK-DAG:   %[[LB_Y:.+]] = affine.apply #[[MAP1]]()[%[[WG_ID_Y]]]
+//  CHECK-DAG:   %[[STEP_Y:.+]] = affine.apply #[[MAP1]]()[%[[WG_COUNT_Y]]]
+//      CHECK:   scf.for %[[IV0:.+]] = %[[LB_Y]] to %[[SIZE_Y]] step %[[STEP_Y]]
+//  CHECK-DAG:     %[[TILESIZE_Y:.+]] = affine.min #[[MAP2]](%[[ARG0]])[%[[SIZE_Y]]]
+//  CHECK-DAG:     %[[LB_X:.+]] = affine.apply #[[MAP1]]()[%[[WG_ID_X]]]
+//  CHECK-DAG:     %[[STEP_X:.+]] = affine.apply #[[MAP1]]()[%[[WG_COUNT_X]]]
+//      CHECK:     scf.for %[[IV1:.+]] = %[[LB_X]] to %[[SIZE_X]] step %[[STEP_X]]
+//  CHECK-DAG:       %[[TILESIZE_X:.+]] = affine.min #[[MAP2]](%[[ARG1]])[%[[SIZE_X]]]
+//  CHECK-DAG:       %[[LOAD_OFFSET_Y:.+]] = affine.apply #[[MAP3]](%[[IV0]])[%[[OFFSET_Y]]]
+//  CHECK-DAG:       %[[LOAD_OFFSET_X:.+]] = affine.apply #[[MAP3]](%[[IV1]])[%[[OFFSET_X]]]
+//  CHECK-DAG:       %[[SOURCE:.+]] = flow.dispatch.tensor.load
+// CHECK-SAME:           offsets = [%[[LOAD_OFFSET_Y]], %[[LOAD_OFFSET_X]]], sizes = [%[[TILESIZE_Y]], %[[TILESIZE_X]]]
+//      CHECK:       flow.dispatch.tensor.store %[[SOURCE]]
+// CHECK-SAME:           offsets = [%[[IV0]], %[[IV1]]], sizes = [%[[TILESIZE_Y]], %[[TILESIZE_X]]]
+
+// -----
+
+#config = #iree_codegen.lowering.config<tile_sizes = [[64]], native_vector_size = []>
+#executable_layout = #hal.executable.layout<push_constants = 0, sets = [
+  #hal.descriptor_set.layout<0, bindings = [
+    #hal.descriptor_set.binding<0, storage_buffer>,
+    #hal.descriptor_set.binding<1, storage_buffer>
+  ]>
+]>
+#executable_target_system_elf_x86_64_ = #hal.executable.target<"llvm", "system-elf-x86_64">
+#translation = #iree_codegen.translation.info<"CPUDefault", workload_per_wg = []>
+hal.executable private @static_1d_fft_stage2 {
+  hal.executable.variant public @system_elf_x86_64, target = #executable_target_system_elf_x86_64_ {
+    hal.executable.entry_point public @static_1d_fft_stage2 layout(#executable_layout) {translation.info = #translation}
+    builtin.module {
+      func @static_1d_fft_stage2() {
+        %c2 = arith.constant 2 : index
+        %cst = arith.constant dense<[1.000000e+00, 6.12323426E-17]> : tensor<2xf32>
+        %cst_0 = arith.constant dense<[-0.000000e+00, -1.000000e+00]> : tensor<2xf32>
+        %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer)
+            : !flow.dispatch.tensor<readwrite:32xf32>
+        %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer)
+            : !flow.dispatch.tensor<readwrite:32xf32>
+        %2 = flow.dispatch.tensor.load %0, offsets = [0], sizes = [32], strides = [1]
+            : !flow.dispatch.tensor<readwrite:32xf32> -> tensor<32xf32>
+        %3 = flow.dispatch.tensor.load %1, offsets = [0], sizes = [32], strides = [1]
+            : !flow.dispatch.tensor<readwrite:32xf32> -> tensor<32xf32>
+        %4:2 = iree_linalg_ext.fft {__internal_linalg_transform__ = "workgroup", lowering.config = #config}
+            ins(%c2, %cst, %cst_0 : index, tensor<2xf32>, tensor<2xf32>) outs(%2, %3 : tensor<32xf32>, tensor<32xf32>) : tensor<32xf32>, tensor<32xf32>
+        flow.dispatch.tensor.store %4#0, %0, offsets = [0], sizes = [32], strides = [1]
+            : tensor<32xf32> -> !flow.dispatch.tensor<readwrite:32xf32>
+        flow.dispatch.tensor.store %4#1, %1, offsets = [0], sizes = [32], strides = [1]
+            : tensor<32xf32> -> !flow.dispatch.tensor<readwrite:32xf32>
+        return
+      }
+    }
+  }
+}
+//  CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0] -> (s0 ceildiv 64)>
+//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"CPUDefault", workload_per_wg = [64]>
+//      CHECK: hal.executable private @static_1d_fft_stage2
+//      CHECK: hal.executable.entry_point public @static_1d_fft_stage2
+// CHECK-SAME:  translation.info = #[[TRANSLATION]]
+// CHECK-NEXT:  (%[[ARG0:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: index)
+//  CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+//  CHECK-DAG:   %[[D0:.+]] = affine.apply #[[MAP]]()[%[[ARG0]]]
+//      CHECK:   hal.return %[[D0]], %[[C1]], %[[C1]] : index, index, index
+//      CHECK: func @static_1d_fft_stage2()
+//      CHECK:   scf.for %[[IV0:.+]] =
+//      CHECK:     %[[RESULT:.+]]:2 = iree_linalg_ext.fft
+//  CHECK-DAG:     flow.dispatch.tensor.store %[[RESULT]]#0, %{{.+}}, offsets = [%[[IV0]]]
+//  CHECK-DAG:     flow.dispatch.tensor.store %[[RESULT]]#1, %{{.+}}, offsets = [%[[IV0]]]
+
+// -----
+
+#config = #iree_codegen.lowering.config<tile_sizes = [[64, 64, 64]], native_vector_size = []>
+#executable_layout = #hal.executable.layout<push_constants = 0, sets = [
+  #hal.descriptor_set.layout<0, bindings = [
+    #hal.descriptor_set.binding<0, storage_buffer>,
+    #hal.descriptor_set.binding<1, storage_buffer>
+  ]>
+]>
+#executable_target_system_elf_x86_64_ = #hal.executable.target<"llvm", "system-elf-x86_64">
+#translation = #iree_codegen.translation.info<"CPUDefault", workload_per_wg = []>
+hal.executable private @static_3d_fft_stage3 {
+  hal.executable.variant public @system_elf_x86_64, target = #executable_target_system_elf_x86_64_ {
+    hal.executable.entry_point public @static_3d_fft_stage3 layout(#executable_layout) {translation.info = #translation}
+    builtin.module {
+      func @static_3d_fft_stage3() {
+        %c3 = arith.constant 3 : index
+        %cst = arith.constant dense<[1.000000e+00, 0.707106769, 6.12323426E-17, -0.707106769]> : tensor<4xf32>
+        %cst_0 = arith.constant dense<[-0.000000e+00, -0.707106769, -1.000000e+00, -0.707106769]> : tensor<4xf32>
+        %0 = bufferization.to_memref %cst_0 : memref<4xf32>
+        %1 = bufferization.to_memref %cst : memref<4xf32>
+        %2 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : memref<64x128x32xf32>
+        %3 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) : memref<64x128x32xf32>
+        iree_linalg_ext.fft {lowering.config = #config}
+            ins(%c3, %1, %0 : index, memref<4xf32>, memref<4xf32>) outs(%2, %3 : memref<64x128x32xf32>, memref<64x128x32xf32>)
+        return
+      }
+    }
+  }
+}
+//  CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0] -> (s0 ceildiv 64)>
+//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"CPUDefault", workload_per_wg = [64, 64, 64]>
+//      CHECK: hal.executable private @static_3d_fft_stage3
+//      CHECK: hal.executable.entry_point public @static_3d_fft_stage3
+// CHECK-SAME:  translation.info = #[[TRANSLATION]]
+// CHECK-NEXT:  (%[[ARG0:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: index)
+//  CHECK-DAG:   %[[D0:.+]] = affine.apply #[[MAP]]()[%[[ARG0]]]
+//  CHECK-DAG:   %[[D1:.+]] = affine.apply #[[MAP]]()[%[[ARG1]]]
+//  CHECK-DAG:   %[[D2:.+]] = affine.apply #[[MAP]]()[%[[ARG2]]]
+//      CHECK:   hal.return %[[D0]], %[[D1]], %[[D2]] : index, index, index
+//      CHECK: func @static_3d_fft_stage3()
+//      CHECK:   scf.for %[[IV0:.+]] =
+//      CHECK:     scf.for %[[IV1:.+]] =
+//      CHECK:       scf.for %[[IV2:.+]] =
+//  CHECK-DAG:         %[[SUBVIEW1:.+]] = memref.subview %{{.+}}[%[[IV0]], %[[IV1]], %[[IV2]]]
+//  CHECK-DAG:         %[[SUBVIEW2:.+]] = memref.subview %{{.+}}[%[[IV0]], %[[IV1]], %[[IV2]]]
+//      CHECK:         iree_linalg_ext.fft
+// CHECK-SAME:             outs(%[[SUBVIEW1]], %[[SUBVIEW2]] :
+
+// -----
+
+#config = #iree_codegen.lowering.config<tile_sizes = [[64, 64, 0], [1, 4, 0], [0, 0, 4]], native_vector_size = []>
+#executable_layout = #hal.executable.layout<push_constants = 0, sets = [
+  #hal.descriptor_set.layout<0, bindings = [
+    #hal.descriptor_set.binding<0, storage_buffer>,
+    #hal.descriptor_set.binding<1, storage_buffer>,
+    #hal.descriptor_set.binding<2, storage_buffer>
+  ]>
+]>
+#executable_target_system_elf_x86_64_ = #hal.executable.target<"llvm", "system-elf-x86_64">
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+#translation = #iree_codegen.translation.info<"CPUDoubleTilingExpert", workload_per_wg = []>
+hal.executable private @outs_fusion {
+  hal.executable.variant public @system_elf_x86_64, target = #executable_target_system_elf_x86_64_ {
+    hal.executable.entry_point public @outs_fusion_fn layout(#executable_layout) {translation.info = #translation}
+    builtin.module {
+      func @outs_fusion_fn() {
+        %cst = arith.constant 0.000000e+00 : f32
+        %0 = hal.interface.constant.load[0] : index
+        %1 = hal.interface.constant.load[1] : index
+        %2 = hal.interface.constant.load[2] : index
+        %3 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer)
+            : !flow.dispatch.tensor<readonly:?x?xf32>{%0, %2}
+        %4 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer)
+            : !flow.dispatch.tensor<readonly:?x?xf32>{%2, %1}
+        %5 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer)
+            : !flow.dispatch.tensor<writeonly:?x?xf32>{%0, %1}
+        %6 = linalg.init_tensor [%0, %1] : tensor<?x?xf32>
+        %7 = linalg.generic {
+            indexing_maps = [#map0], iterator_types = ["parallel", "parallel"]} outs(%6 : tensor<?x?xf32>) {
+        ^bb0(%arg0: f32):
+          linalg.yield %cst : f32
+        } -> tensor<?x?xf32>
+        %8 = flow.dispatch.tensor.load %3, offsets = [0, 0], sizes = [%0, %2], strides = [1, 1]
+            : !flow.dispatch.tensor<readonly:?x?xf32>{%0, %2} -> tensor<?x?xf32>
+        %9 = flow.dispatch.tensor.load %4, offsets = [0, 0], sizes = [%2, %1], strides = [1, 1]
+            : !flow.dispatch.tensor<readonly:?x?xf32>{%2, %1} -> tensor<?x?xf32>
+        %10 = linalg.generic {
+            indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"]}
+            ins(%8, %9 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%7 : tensor<?x?xf32>) attrs =  {lowering.config = #config} {
+        ^bb0(%arg0: f32, %arg1: f32, %arg2: f32):
+          %11 = arith.mulf %arg0, %arg1 : f32
+          linalg.yield %11 : f32
+        } -> tensor<?x?xf32>
+        flow.dispatch.tensor.store %10, %5, offsets = [0, 0], sizes = [%0, %1], strides = [1, 1]
+            : tensor<?x?xf32> -> !flow.dispatch.tensor<writeonly:?x?xf32>{%0, %1}
+        return
+      }
+    }
+  }
+}
+//      CHECK: func @outs_fusion_fn
+//      CHECK:   scf.for %[[IV0:.+]] =
+//      CHECK:     scf.for %[[IV1:.+]] =
+//      CHECK:       %[[INIT:.+]] = linalg.init_tensor
+//      CHECK:       %[[FILL:.+]] = linalg.generic
+// CHECK-SAME:           outs(%[[INIT]] :
+//      CHECK:       %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME:           outs(%[[FILL]] :
+//      CHECK:       flow.dispatch.tensor.store %[[GENERIC]], %{{.+}}, offsets = [%[[IV0]], %[[IV1]]]
+
+// -----
+
+#config = #iree_codegen.lowering.config<tile_sizes = [[0, 64, 64, 64, 0, 0, 0]], native_vector_size = []>
+#executable_layout = #hal.executable.layout<push_constants = 0, sets = [
+  #hal.descriptor_set.layout<0, bindings = [
+    #hal.descriptor_set.binding<0, storage_buffer>,
+    #hal.descriptor_set.binding<1, storage_buffer>,
+    #hal.descriptor_set.binding<2, storage_buffer>
+  ]>
+]>
+#executable_target_system_elf_x86_64_ = #hal.executable.target<"llvm", "system-elf-x86_64", {
+  data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128",
+  native_vector_size = 16 : index,
+  target_triple = "x86_64-unknown-linux-gnu"}>
+#translation = #iree_codegen.translation.info<"CPUDefault", workload_per_wg = []>
+hal.executable private @conv {
+  hal.executable.variant public @system_elf_x86_64, target = #executable_target_system_elf_x86_64_ {
+    hal.executable.entry_point public @conv layout(#executable_layout) {translation.info = #translation}
+    builtin.module {
+      func @conv() {
+        %0 = hal.interface.constant.load[0] : index
+        %1 = hal.interface.constant.load[1] : index
+        %2 = hal.interface.constant.load[2] : index
+        %3 = hal.interface.constant.load[3] : index
+        %4 = hal.interface.constant.load[4] : index
+        %5 = hal.interface.constant.load[5] : index
+        %6 = hal.interface.constant.load[6] : index
+        %7 = hal.interface.constant.load[7] : index
+        %8 = hal.interface.constant.load[8] : index
+        %9 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer)
+            : !flow.dispatch.tensor<readonly:?x?x?x?xf32>{%0, %1, %2, %3}
+        %10 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer)
+            : !flow.dispatch.tensor<readonly:?x?x?x?xf32>{%4, %5, %3, %6}
+        %11 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer)
+            : !flow.dispatch.tensor<readwrite:?x?x?x?xf32>{%0, %7, %8, %6}
+        %12 = flow.dispatch.tensor.load %9, offsets = [0, 0, 0, 0], sizes = [%0, %1, %2, %3], strides = [1, 1, 1, 1]
+            : !flow.dispatch.tensor<readonly:?x?x?x?xf32>{%0, %1, %2, %3} -> tensor<?x?x?x?xf32>
+        %13 = flow.dispatch.tensor.load %10, offsets = [0, 0, 0, 0], sizes = [%4, %5, %3, %6], strides = [1, 1, 1, 1]
+            : !flow.dispatch.tensor<readonly:?x?x?x?xf32>{%4, %5, %3, %6} -> tensor<?x?x?x?xf32>
+        %14 = flow.dispatch.tensor.load %11, offsets = [0, 0, 0, 0], sizes = [%0, %7, %8, %6], strides = [1, 1, 1, 1]
+            : !flow.dispatch.tensor<readwrite:?x?x?x?xf32>{%0, %7, %8, %6} -> tensor<?x?x?x?xf32>
+        %15 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, lowering.config = #config, strides = dense<1> : tensor<2xi64>}
+            ins(%12, %13 : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) outs(%14 : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+        flow.dispatch.tensor.store %15, %11, offsets = [0, 0, 0, 0], sizes = [%0, %7, %8, %6], strides = [1, 1, 1, 1]
+            : tensor<?x?x?x?xf32> -> !flow.dispatch.tensor<readwrite:?x?x?x?xf32>{%0, %7, %8, %6}
+        return
+      }
+    }
+  }
+}
+//  CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0] -> (s0 ceildiv 64)>
+//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"CPUDefault", workload_per_wg = [64, 64, 64]>
+//      CHECK: hal.executable private @conv
+//      CHECK: hal.executable.entry_point public @conv
+// CHECK-SAME:  translation.info = #[[TRANSLATION]]
+// CHECK-NEXT:  (%[[ARG0:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: index)
+//  CHECK-DAG:   %[[D0:.+]] = affine.apply #[[MAP]]()[%[[ARG0]]]
+//  CHECK-DAG:   %[[D1:.+]] = affine.apply #[[MAP]]()[%[[ARG1]]]
+//  CHECK-DAG:   %[[D2:.+]] = affine.apply #[[MAP]]()[%[[ARG2]]]
+//      CHECK:   hal.return %[[D0]], %[[D1]], %[[D2]] : index, index, index
+//      CHECK: func @conv()
+//      CHECK:   scf.for %[[IV0:.+]] =
+//      CHECK:     scf.for %[[IV1:.+]] =
+//      CHECK:       scf.for %[[IV2:.+]] =
+//  CHECK-DAG:         %[[INPUT:.+]] = flow.dispatch.tensor.load %{{.+}}, offsets = [0, %[[IV0]], %[[IV1]], 0]
+//  CHECK-DAG:         %[[FILTER:.+]] = flow.dispatch.tensor.load %{{.+}}, offsets = [0, 0, 0, %[[IV2]]]
+//  CHECK-DAG:         %[[INIT:.+]] = flow.dispatch.tensor.load %{{.+}}, offsets = [0, %[[IV0]], %[[IV1]], %[[IV2]]]
+//      CHECK:         %[[RESULT:.+]] = linalg.conv_2d_nhwc_hwcf
+//      CHECK:         flow.dispatch.tensor.store %[[RESULT]], %{{.+}}, offsets = [0, %[[IV0]], %[[IV1]], %[[IV2]]]
+
+// -----
+
+#config = #iree_codegen.lowering.config<tile_sizes = [[0, 20, 40, 48, 0, 0]], native_vector_size = []>
+#executable_layout = #hal.executable.layout<push_constants = 0, sets = [
+  #hal.descriptor_set.layout<0, bindings = [
+    #hal.descriptor_set.binding<0, storage_buffer>,
+    #hal.descriptor_set.binding<1, storage_buffer>,
+    #hal.descriptor_set.binding<2, storage_buffer>
+  ]>
+]>
+#executable_target_system_elf_x86_64_ = #hal.executable.target<"llvm", "system-elf-x86_64", {
+  data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128",
+  native_vector_size = 16 : index,
+  target_triple = "x86_64-unknown-linux-gnu"}>
+#translation = #iree_codegen.translation.info<"CPUDefault", workload_per_wg = []>
+hal.executable private @conv_static {
+  hal.executable.variant public @system_elf_x86_64, target = #executable_target_system_elf_x86_64_ {
+    hal.executable.entry_point public @conv_static layout(#executable_layout) {translation.info = #translation}
+    builtin.module {
+      func @conv_static() {
+        %cst = arith.constant 0.000000e+00 : f32
+        %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer)
+            : !flow.dispatch.tensor<readonly:1x161x161x96xf32>
+        %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer)
+            : !flow.dispatch.tensor<readonly:3x3x96xf32>
+        %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer)
+            : !flow.dispatch.tensor<writeonly:1x80x80x96xf32>
+        %3 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0, 0], sizes = [1, 161, 161, 96], strides = [1, 1, 1, 1]
+            : !flow.dispatch.tensor<readonly:1x161x161x96xf32> -> tensor<1x161x161x96xf32>
+        %4 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0], sizes = [3, 3, 96], strides = [1, 1, 1]
+            : !flow.dispatch.tensor<readonly:3x3x96xf32> -> tensor<3x3x96xf32>
+        %5 = linalg.init_tensor [1, 80, 80, 96] : tensor<1x80x80x96xf32>
+        %6 = linalg.fill(%cst, %5) : f32, tensor<1x80x80x96xf32> -> tensor<1x80x80x96xf32>
+        %7 = linalg.depthwise_conv_2d_nhwc_hwc {dilations = dense<1> : tensor<2xi64>, lowering.config = #config, strides = dense<2> : tensor<2xi64>}
+            ins(%3, %4 : tensor<1x161x161x96xf32>, tensor<3x3x96xf32>) outs(%6 : tensor<1x80x80x96xf32>) -> tensor<1x80x80x96xf32>
+        flow.dispatch.tensor.store %7, %2, offsets = [0, 0, 0, 0], sizes = [1, 80, 80, 96], strides = [1, 1, 1, 1]
+            : tensor<1x80x80x96xf32> -> !flow.dispatch.tensor<writeonly:1x80x80x96xf32>
+        return
+      }
+    }
+  }
+}
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 48)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 ceildiv 40)>
+//  CHECK-DAG: #[[MAP2:.+]] = affine_map<()[s0] -> (s0 ceildiv 20)>
+//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"CPUDefault", workload_per_wg = [48, 40, 20]>
+//      CHECK: hal.executable private @conv_static
+//      CHECK: hal.executable.entry_point public @conv_static
+// CHECK-SAME:  translation.info = #[[TRANSLATION]]
+// CHECK-NEXT:  (%[[ARG0:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: index)
+//  CHECK-DAG:   %[[D0:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]]
+//  CHECK-DAG:   %[[D1:.+]] = affine.apply #[[MAP1]]()[%[[ARG1]]]
+//  CHECK-DAG:   %[[D2:.+]] = affine.apply #[[MAP2]]()[%[[ARG2]]]
+//      CHECK:   hal.return %[[D0]], %[[D1]], %[[D2]] : index, index, index
+//      CHECK: func @conv_static()
+//      CHECK:   scf.for %[[IV0:.+]] =
+//      CHECK:     scf.for %[[IV1:.+]] =
+//      CHECK:       scf.for %[[IV2:.+]] =
+//      CHECK:         %[[INIT:.+]] = linalg.init_tensor [1, 20, 40, 48]
+//      CHECK:         %[[FILL:.+]] = linalg.fill(%{{.+}}, %[[INIT]])
+//      CHECK:         %[[RESULT:.+]] = linalg.depthwise_conv_2d_nhwc_hwc
+// CHECK-SAME:             outs(%[[FILL]] :
+//      CHECK:         flow.dispatch.tensor.store %[[RESULT]], %{{.+}}, offsets = [0, %[[IV0]], %[[IV1]], %[[IV2]]], sizes = [1, 20, 40, 48]
+
+// -----
+
+#config = #iree_codegen.lowering.config<tile_sizes = [[16, 32], [16, 16], [0, 0]], native_vector_size = []>
+#executable_layout = #hal.executable.layout<push_constants = 0, sets = [
+  #hal.descriptor_set.layout<0, bindings = [
+    #hal.descriptor_set.binding<0, storage_buffer>,
+    #hal.descriptor_set.binding<1, storage_buffer>
+  ]>
+]>
+#executable_target_system_elf_x86_64_ = #hal.executable.target<"llvm", "system-elf-x86_64", {
+  data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128",
+  native_vector_size = 64 : index,
+  target_triple = "x86_64-pc-linux-gnu"}>
+#map0 = affine_map<(d0, d1) -> (d1, d0)>
+#map1 = affine_map<(d0, d1) -> (d0, d1)>
+#translation = #iree_codegen.translation.info<"CPUDoubleTilingExpert", workload_per_wg = []>
+hal.executable private @generic_static {
+  hal.executable.variant public @system_elf_x86_64, target = #executable_target_system_elf_x86_64_ {
+    hal.executable.entry_point public @generic_static layout(#executable_layout) {translation.info = #translation}
+    builtin.module {
+      func @generic_static() {
+        %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer)
+            : !flow.dispatch.tensor<readonly:96x16xf32>
+        %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer)
+            : !flow.dispatch.tensor<writeonly:16x96xf32>
+        %2 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [96, 16], strides = [1, 1]
+            : !flow.dispatch.tensor<readonly:96x16xf32> -> tensor<96x16xf32>
+        %3 = linalg.init_tensor [16, 96] : tensor<16x96xf32>
+        %4 = linalg.generic {
+            indexing_maps = [#map0, #map1], iterator_types = ["parallel", "parallel"]}
+            ins(%2 : tensor<96x16xf32>) outs(%3 : tensor<16x96xf32>) attrs =  {lowering.config = #config} {
+        ^bb0(%arg0: f32, %arg1: f32):
+          linalg.yield %arg0 : f32
+        } -> tensor<16x96xf32>
+        flow.dispatch.tensor.store %4, %1, offsets = [0, 0], sizes = [16, 96], strides = [1, 1]
+            : tensor<16x96xf32> -> !flow.dispatch.tensor<writeonly:16x96xf32>
+        return
+      }
+    }
+  }
+}
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 32)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 ceildiv 16)>
+//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"CPUDoubleTilingExpert", workload_per_wg = [32, 16]>
+//      CHECK: hal.executable private @generic_static
+//      CHECK: hal.executable.entry_point public @generic_static
+// CHECK-SAME:  translation.info = #[[TRANSLATION]]
+// CHECK-NEXT:  (%[[ARG0:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: index)
+//  CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+//  CHECK-DAG:   %[[D0:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]]
+//  CHECK-DAG:   %[[D1:.+]] = affine.apply #[[MAP1]]()[%[[ARG1]]]
+//      CHECK:   hal.return %[[D0]], %[[D1]], %[[C1]] : index, index, index
+//      CHECK: func @generic_static()
+//      CHECK:   scf.for %[[IV0:.+]] =
+//      CHECK:     scf.for %[[IV1:.+]] =
+//      CHECK:       %[[RESULT:.+]] = linalg.generic
+//      CHECK:       flow.dispatch.tensor.store %[[RESULT]], %{{.+}}, offsets = [%[[IV0]], %[[IV1]]]
+
+// -----
+
+#config = #iree_codegen.lowering.config<tile_sizes = [[28, 8, 0], [4, 4, 60], [4, 4, 4]], native_vector_size = [4, 4, 4]>
+#executable_layout = #hal.executable.layout<push_constants = 0, sets = [
+  #hal.descriptor_set.layout<0, bindings = [
+    #hal.descriptor_set.binding<0, storage_buffer>,
+    #hal.descriptor_set.binding<1, storage_buffer>,
+    #hal.descriptor_set.binding<2, storage_buffer>
+  ]>
+]>
+#executable_target_system_elf_arm_64_ = #hal.executable.target<"llvm", "system-elf-arm_64", {
+  data_layout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128",
+  native_vector_size = 16 : index,
+  target_triple = "aarch64-none-linux-android30"}>
+#translation = #iree_codegen.translation.info<"CPUTileFuseAndVectorize", workload_per_wg = []>
+hal.executable private @matmul_static {
+  hal.executable.variant public @system_elf_arm_64, target = #executable_target_system_elf_arm_64_ {
+    hal.executable.entry_point public @matmul_static layout(#executable_layout) {translation.info = #translation}
+    builtin.module {
+      func @matmul_static() {
+        %cst = arith.constant 0.000000e+00 : f32
+        %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer)
+            : !flow.dispatch.tensor<readonly:196x240xf32>
+        %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer)
+            : !flow.dispatch.tensor<readonly:240x40xf32>
+        %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer)
+            : !flow.dispatch.tensor<writeonly:196x40xf32>
+        %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [196, 240], strides = [1, 1]
+            : !flow.dispatch.tensor<readonly:196x240xf32> -> tensor<196x240xf32>
+        %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [240, 40], strides = [1, 1]
+            : !flow.dispatch.tensor<readonly:240x40xf32> -> tensor<240x40xf32>
+        %5 = linalg.init_tensor [196, 40] : tensor<196x40xf32>
+        %6 = linalg.fill(%cst, %5) : f32, tensor<196x40xf32> -> tensor<196x40xf32>
+        %7 = linalg.matmul {lowering.config = #config}
+            ins(%3, %4 : tensor<196x240xf32>, tensor<240x40xf32>) outs(%6 : tensor<196x40xf32>) -> tensor<196x40xf32>
+        flow.dispatch.tensor.store %7, %2, offsets = [0, 0], sizes = [196, 40], strides = [1, 1]
+            : tensor<196x40xf32> -> !flow.dispatch.tensor<writeonly:196x40xf32>
+        return
+      }
+    }
+  }
+}
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 8)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 ceildiv 28)>
+//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"CPUTileFuseAndVectorize", workload_per_wg = [8, 28]>
+//      CHECK: hal.executable private @matmul_static
+//      CHECK: hal.executable.entry_point public @matmul_static
+// CHECK-SAME:  translation.info = #[[TRANSLATION]]
+// CHECK-NEXT:  (%[[ARG0:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: index)
+//  CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+//  CHECK-DAG:   %[[D0:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]]
+//  CHECK-DAG:   %[[D1:.+]] = affine.apply #[[MAP1]]()[%[[ARG1]]]
+//      CHECK:   hal.return %[[D0]], %[[D1]], %[[C1]] : index, index, index
+
+// -----
+
+#config = #iree_codegen.lowering.config<tile_sizes = [[0, 1, 7, 64, 0, 0]], native_vector_size = []>
+#executable_layout = #hal.executable.layout<push_constants = 0, sets = [
+  #hal.descriptor_set.layout<0, bindings = [
+    #hal.descriptor_set.binding<0, storage_buffer>,
+    #hal.descriptor_set.binding<1, storage_buffer>,
+    #hal.descriptor_set.binding<2, storage_buffer>
+  ]>
+]>
+#executable_target_system_elf_arm_64_ = #hal.executable.target<"llvm", "system-elf-arm_64", {
+  data_layout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128",
+  native_vector_size = 16 : index,
+  target_triple = "aarch64-none-linux-android30"}>
+#translation = #iree_codegen.translation.info<"CPUDefault", workload_per_wg = []>
+hal.executable private @restrict_num_workgroups {
+  hal.executable.variant public @system_elf_arm_64, target = #executable_target_system_elf_arm_64_ {
+    hal.executable.entry_point public @restrict_num_workgroups layout(#executable_layout) {translation.info = #translation}
+    builtin.module {
+      func @restrict_num_workgroups() {
+        %cst = arith.constant 0.000000e+00 : f32
+        %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer)
+            : !flow.dispatch.tensor<readonly:1x11x11x576xf32>
+        %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer)
+            : !flow.dispatch.tensor<readonly:5x5x576xf32>
+        %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer)
+            : !flow.dispatch.tensor<writeonly:1x7x7x576xf32>
+        %3 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0, 0], sizes = [1, 11, 11, 576], strides = [1, 1, 1, 1]
+            : !flow.dispatch.tensor<readonly:1x11x11x576xf32> -> tensor<1x11x11x576xf32>
+        %4 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0], sizes = [5, 5, 576], strides = [1, 1, 1]
+            : !flow.dispatch.tensor<readonly:5x5x576xf32> -> tensor<5x5x576xf32>
+        %5 = linalg.init_tensor [1, 7, 7, 576] : tensor<1x7x7x576xf32>
+        %6 = linalg.fill(%cst, %5) : f32, tensor<1x7x7x576xf32> -> tensor<1x7x7x576xf32>
+        %7 = linalg.depthwise_conv_2d_nhwc_hwc {dilations = dense<1> : tensor<2xi64>, lowering.config = #config, strides = dense<1> : tensor<2xi64>}
+            ins(%3, %4 : tensor<1x11x11x576xf32>, tensor<5x5x576xf32>) outs(%6 : tensor<1x7x7x576xf32>) -> tensor<1x7x7x576xf32>
+        flow.dispatch.tensor.store %7, %2, offsets = [0, 0, 0, 0], sizes = [1, 7, 7, 576], strides = [1, 1, 1, 1]
+            : tensor<1x7x7x576xf32> -> !flow.dispatch.tensor<writeonly:1x7x7x576xf32>
+        return
+      }
+    }
+  }
+}
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 64)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 ceildiv 7)>
+//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"CPUDefault", workload_per_wg = [64, 7, 1]>
+//      CHECK: hal.executable private @restrict_num_workgroups
+//      CHECK: hal.executable.entry_point public @restrict_num_workgroups
+// CHECK-SAME:  translation.info = #[[TRANSLATION]]
+// CHECK-NEXT:  (%[[ARG0:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: index)
+//  CHECK-DAG:   %[[D0:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]]
+//  CHECK-DAG:   %[[D1:.+]] = affine.apply #[[MAP1]]()[%[[ARG1]]]
+//      CHECK:   hal.return %[[D0]], %[[D1]], %[[ARG2]] : index, index, index
+
+// -----
+
+#config = #iree_codegen.lowering.config<tile_sizes = [[4, 0, 0], [4, 0, 0], [0, 1, 4]], native_vector_size = []>
+#executable_layout = #hal.executable.layout<push_constants = 4, sets = [
+  #hal.descriptor_set.layout<0, bindings = [
+    #hal.descriptor_set.binding<0, storage_buffer>,
+    #hal.descriptor_set.binding<1, storage_buffer>,
+    #hal.descriptor_set.binding<2, storage_buffer>
+  ]>
+]>
+#executable_target_embedded_elf_x86_64_ = #hal.executable.target<"llvm", "embedded-elf-x86_64", {
+  data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128",
+  native_vector_size = 16 : index,
+  target_triple = "x86_64-unknown-unknown-eabi-elf"}>
+#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d0)>
+#map2 = affine_map<(d0) -> (d0)>
+#translation = #iree_codegen.translation.info<"CPUDoubleTilingExpert", workload_per_wg = []>
+hal.executable private @reduction {
+  hal.executable.variant public @reduction, target = #executable_target_embedded_elf_x86_64_ {
+    hal.executable.entry_point public @reduction ordinal(0) layout(#executable_layout) {translation.info = #translation}
+    builtin.module {
+      func @reduction(%arg0 : !flow.dispatch.tensor<readonly:7x7x2048xf32>,
+          %arg1 : !flow.dispatch.tensor<writeonly:7xf32>) {
+        %cst = arith.constant 0.000000e+00 : f32
+        %cst_0 = arith.constant 1.000000e+01 : f32
+        %0 = flow.dispatch.tensor.load %arg0, offsets = [0, 0, 0], sizes = [7, 7, 2048], strides = [1, 1, 1]
+            : !flow.dispatch.tensor<readonly:7x7x2048xf32> -> tensor<7x7x2048xf32>
+        %1 = linalg.init_tensor [7] : tensor<7xf32>
+        %2 = linalg.fill(%cst, %1) : f32, tensor<7xf32> -> tensor<7xf32>
+        %3 = linalg.generic {
+            indexing_maps = [#map0, #map1], iterator_types = ["parallel", "reduction", "reduction"]}
+            ins(%0 : tensor<7x7x2048xf32>) outs(%2 : tensor<7xf32>) attrs =  {lowering.config = #config} {
+        ^bb0(%arg2: f32, %arg3: f32):
+          %5 = arith.addf %arg2, %arg3 : f32
+          linalg.yield %5 : f32
+        } -> tensor<7xf32>
+        %4 = linalg.generic {
+            indexing_maps = [#map2, #map2], iterator_types = ["parallel"]}
+            ins(%3 : tensor<7xf32>) outs(%1 : tensor<7xf32>) {
+        ^bb0(%arg2: f32, %arg3: f32):
+          %5 = arith.divf %arg2, %cst_0 : f32
+          linalg.yield %5 : f32
+        } -> tensor<7xf32>
+        flow.dispatch.tensor.store %4, %arg1, offsets = [0], sizes = [7], strides = [1]
+            : tensor<7xf32> -> !flow.dispatch.tensor<writeonly:7xf32>
+        return
+      }
+    }
+  }
+}
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 4)>
+//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"CPUDoubleTilingExpert", workload_per_wg = [4]>
+//      CHECK: hal.executable private @reduction
+//      CHECK: hal.executable.entry_point public @reduction
+// CHECK-SAME:  translation.info = #[[TRANSLATION]]
+// CHECK-NEXT:  (%[[ARG0:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: index)
+//  CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+//  CHECK-DAG:   %[[D0:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]]
+//      CHECK:   hal.return %[[D0]], %[[C1]], %[[C1]] : index, index, index
+//      CHECK: func @reduction
+//      CHECK:   scf.for %[[IV0:.+]] =
+//      CHECK:     %[[INIT0:.+]] = linalg.init_tensor
+//      CHECK:     %[[INIT:.+]] = linalg.init_tensor
+//      CHECK:     %[[FILL:.+]] = linalg.fill(%{{.+}}, %[[INIT]])
+//      CHECK:     %[[REDUCE:.+]] = linalg.generic
+// CHECK-SAME:         outs(%[[FILL]] :
+//      CHECK:     %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME:         ins(%[[REDUCE]] :
+//      CHECK:     flow.dispatch.tensor.store %[[GENERIC]], %{{.+}}, offsets = [%[[IV0]]]
+
+// -----
+
+#config = #iree_codegen.lowering.config<tile_sizes = [[64, 0, 0], [8, 0, 0], [0, 0, 16]], native_vector_size = []>
+#executable_layout = #hal.executable.layout<push_constants = 4, sets = [
+  #hal.descriptor_set.layout<0, bindings = [
+    #hal.descriptor_set.binding<0, storage_buffer>,
+    #hal.descriptor_set.binding<1, storage_buffer>,
+    #hal.descriptor_set.binding<2, storage_buffer>
+  ]>
+]>
+#executable_target_embedded_elf_x86_64_ = #hal.executable.target<"llvm", "embedded-elf-x86_64", {
+  data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128",
+  native_vector_size = 16 : index,
+  target_triple = "x86_64-unknown-unknown-eabi-elf"}>
+#translation = #iree_codegen.translation.info<"CPUDoubleTilingExpert", workload_per_wg = []>
+hal.executable private @gemm_unit_N {
+  hal.executable.variant public @embedded_elf_x86_64, target = #executable_target_embedded_elf_x86_64_ {
+    hal.executable.entry_point public @gemm_unit_N ordinal(0) layout(#executable_layout) {translation.info = #translation}
+    builtin.module {
+      func @gemm_unit_N() {
+        %c0 = arith.constant 0 : index
+        %0 = hal.interface.constant.load[0] : index
+        %1 = hal.interface.constant.load[1] : index
+        %2 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(32)
+            : !flow.dispatch.tensor<readonly:?x?xf32>{%0, %1}
+        %3 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(32)
+            : !flow.dispatch.tensor<readonly:?x1xf32>{%1}
+        %4 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) offset(%c0) alignment(32)
+            : !flow.dispatch.tensor<readwrite:?x1xf32>{%0}
+        %5 = flow.dispatch.tensor.load %3, offsets = [0, 0], sizes = [%1, 1], strides = [1, 1]
+            : !flow.dispatch.tensor<readonly:?x1xf32>{%1} -> tensor<?x1xf32>
+        %6 = flow.dispatch.tensor.load %2, offsets = [0, 0], sizes = [%0, %1], strides = [1, 1]
+            : !flow.dispatch.tensor<readonly:?x?xf32>{%0, %1} -> tensor<?x?xf32>
+        %7 = flow.dispatch.tensor.load %4, offsets = [0, 0], sizes = [%0, 1], strides = [1, 1]
+            : !flow.dispatch.tensor<readwrite:?x1xf32>{%0} -> tensor<?x1xf32>
+        %8 = linalg.matmul {lowering.config = #config}
+            ins(%6, %5 : tensor<?x?xf32>, tensor<?x1xf32>) outs(%7 : tensor<?x1xf32>) -> tensor<?x1xf32>
+        flow.dispatch.tensor.store %8, %4, offsets = [0, 0], sizes = [%0, 1], strides = [1, 1]
+            : tensor<?x1xf32> -> !flow.dispatch.tensor<readwrite:?x1xf32>{%0}
+        return
+      }
+    }
+  }
+}
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 64)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 64)>
+//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"CPUDoubleTilingExpert", workload_per_wg = [64]>
+//      CHECK: hal.executable private @gemm_unit_N
+//      CHECK: hal.executable.entry_point public @gemm_unit_N
+// CHECK-SAME:  translation.info = #[[TRANSLATION]]
+// CHECK-NEXT:  (%[[ARG0:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: index)
+//  CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+//  CHECK-DAG:   %[[D0:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]]
+//      CHECK:   hal.return %[[D0]], %[[C1]], %[[C1]] : index, index, index
+//      CHECK: func @gemm_unit_N()
+//  CHECK-DAG:   %[[M:.+]] = hal.interface.constant.load[0]
+//  CHECK-DAG:   %[[WG_ID_X:.+]] = hal.interface.workgroup.id[0]
+//  CHECK-DAG:   %[[WG_COUNT_X:.+]] = hal.interface.workgroup.count[0]
+//  CHECK-DAG:   %[[LB:.+]] = affine.apply #[[MAP1]]()[%[[WG_ID_X]]]
+//  CHECK-DAG:   %[[STEP:.+]] = affine.apply #[[MAP1]]()[%[[WG_COUNT_X]]]
+//      CHECK:   scf.for %[[IV0:.+]] = %[[LB]] to %[[M]] step %[[STEP]]
+//  CHECK-NOT:     scf.for
+//      CHECK:     %[[GEMM:.+]] = linalg.matmul
+//      CHECK:     flow.dispatch.tensor.store %[[GEMM]],
+// CHECK-SAME:         offsets = [%[[IV0]], 0]
+
+// -----
+#config = #iree_codegen.lowering.config<tile_sizes = [[0, 0, 0], [0, 0, 0], [0, 0, 16]], native_vector_size = []>
+#executable_layout = #hal.executable.layout<push_constants = 4, sets = [
+  #hal.descriptor_set.layout<0, bindings = [
+    #hal.descriptor_set.binding<0, storage_buffer>,
+    #hal.descriptor_set.binding<1, storage_buffer>,
+    #hal.descriptor_set.binding<2, storage_buffer>
+  ]>
+]>
+#executable_target_embedded_elf_x86_64_ = #hal.executable.target<"llvm", "embedded-elf-x86_64", {
+  data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128",
+  native_vector_size = 16 : index,
+  target_triple = "x86_64-unknown-unknown-eabi-elf"}>
+#translation = #iree_codegen.translation.info<"CPUDoubleTilingExpert", workload_per_wg = []>
+hal.executable private @gemm_unit_M_unit_N {
+  hal.executable.variant public @embedded_elf_x86_64, target = #executable_target_embedded_elf_x86_64_ {
+    hal.executable.entry_point public @gemm_unit_M_unit_N ordinal(0) layout(#executable_layout) {translation.info = #translation}
+    builtin.module {
+      func @gemm_unit_M_unit_N() {
+        %c0 = arith.constant 0 : index
+        %0 = hal.interface.constant.load[0] : index
+        %1 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(32)
+            : !flow.dispatch.tensor<readonly:1x?xf32>{%0}
+        %2 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(32)
+            : !flow.dispatch.tensor<readonly:?x1xf32>{%0}
+        %3 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) offset(%c0) alignment(32)
+            : !flow.dispatch.tensor<readwrite:1x1xf32>
+        %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [1, %0], strides = [1, 1]
+            : !flow.dispatch.tensor<readonly:1x?xf32>{%0} -> tensor<1x?xf32>
+        %5 = flow.dispatch.tensor.load %2, offsets = [0, 0], sizes = [%0, 1], strides = [1, 1]
+            : !flow.dispatch.tensor<readonly:?x1xf32>{%0} -> tensor<?x1xf32>
+        %6 = flow.dispatch.tensor.load %3, offsets = [0, 0], sizes = [1, 1], strides = [1, 1]
+            : !flow.dispatch.tensor<readwrite:1x1xf32> -> tensor<1x1xf32>
+        %7 = linalg.matmul {lowering.config = #config}
+            ins(%4, %5 : tensor<1x?xf32>, tensor<?x1xf32>) outs(%6 : tensor<1x1xf32>) -> tensor<1x1xf32>
+        flow.dispatch.tensor.store %7, %3, offsets = [0, 0], sizes = [1, 1], strides = [1, 1]
+            : tensor<1x1xf32> -> !flow.dispatch.tensor<readwrite:1x1xf32>
+        return
+      }
+    }
+  }
+}
+//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"CPUDoubleTilingExpert", workload_per_wg = []>
+//      CHECK: hal.executable private @gemm_unit_M_unit_N
+//      CHECK: hal.executable.entry_point public @gemm_unit_M_unit_N
+// CHECK-SAME:  translation.info = #[[TRANSLATION]]
+// CHECK-NEXT:  (%[[ARG0:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: index)
+//  CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+//      CHECK:   hal.return %[[C1]], %[[C1]], %[[C1]] : index, index, index
+//      CHECK: func @gemm_unit_M_unit_N()
+//  CHECK-NOT:   scf.for
+//      CHECK:   %[[GEMM:.+]] = linalg.matmul
+//      CHECK:   flow.dispatch.tensor.store %[[GEMM]], %{{.+}}, offsets = [0, 0]
+
+// -----
+
+#config = #iree_codegen.lowering.config<tile_sizes = [[0, 0, 0, 0, 64, 64, 0, 64], [0, 1, 0, 0, 1, 1, 0, 4], [0, 0, 0, 0, 0, 0, 0, 0]], native_vector_size = []>
+#executable_layout = #hal.executable.layout<push_constants = 0, sets = [
+  #hal.descriptor_set.layout<0, bindings = [
+    #hal.descriptor_set.binding<0, storage_buffer>,
+    #hal.descriptor_set.binding<1, storage_buffer>,
+    #hal.descriptor_set.binding<2, storage_buffer>
+  ]>
+]>
+#executable_target_embedded_elf_x86_64_ = #hal.executable.target<"llvm", "embedded-elf-x86_64", {
+  data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128",
+  native_vector_size = 16 : index,
+  target_triple = "x86_64-unknown-linux-gnu"}>
 #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d5, d6, d7)>
-func @unit_dim_generic_op() {
-  %cst = arith.constant 0.000000e+00 : f32
-  %c0 = arith.constant 0 : index
-  %0 = hal.interface.constant.load[0] : i32
-  %1 = hal.interface.constant.load[1] : i32
-  %2 = hal.interface.constant.load[2] : i32
-  %3 = hal.interface.constant.load[3] : i32
-  %d0 = arith.index_cast %0 : i32 to index
-  %d1 = arith.index_cast %1 : i32 to index
-  %d2 = arith.index_cast %2 : i32 to index
-  %d3 = arith.index_cast %3 : i32 to index
-  %8 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(32)
-      : !flow.dispatch.tensor<readonly:1x?x1x1x?x?x1x?xf32>{%d0, %d1, %d2, %d3}
-  %9 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(32)
-      : !flow.dispatch.tensor<writeonly:1x?x1x1x?x?x1x?xf32>{%d0, %d1, %d2, %d3}
-  %10 = flow.dispatch.tensor.load %8, offsets = [0, 0, 0, 0, 0, 0, 0, 0], sizes = [1, %d0, 1, 1, %d1, %d2, 1, %d3],
-      strides = [1, 1, 1, 1, 1, 1, 1, 1] : !flow.dispatch.tensor<readonly:1x?x1x1x?x?x1x?xf32>{%d0, %d1, %d2, %d3} -> tensor<1x?x1x1x?x?x1x?xf32>
-  %13 = linalg.init_tensor [1, %d0, 1, 1, %d1, %d2, 1, %d3] : tensor<1x?x1x1x?x?x1x?xf32>
-  %15 = linalg.generic {
-    indexing_maps = [#map, #map],
-    iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]}
-    ins (%10: tensor<1x?x1x1x?x?x1x?xf32>)
-    outs (%13 : tensor<1x?x1x1x?x?x1x?xf32>) {
-      ^bb0(%arg0 : f32, %arg2 : f32):
-        %16 = arith.addf %arg0, %arg0 : f32
-        linalg.yield %16 : f32
-    } -> tensor<1x?x1x1x?x?x1x?xf32>
-  flow.dispatch.tensor.store %15, %9, offsets = [0, 0, 0, 0, 0, 0, 0, 0], sizes = [1, %d0, 1, 1, %d1, %d2, 1, %d3],
-      strides = [1, 1, 1, 1, 1, 1, 1, 1] : tensor<1x?x1x1x?x?x1x?xf32> -> !flow.dispatch.tensor<writeonly:1x?x1x1x?x?x1x?xf32>{%d0, %d1, %d2, %d3}
-  return
+#translation = #iree_codegen.translation.info<"CPUDoubleTilingExpert", workload_per_wg = []>
+hal.executable private @generic_unit_dims {
+  hal.executable.variant public @llvm, target = #executable_target_embedded_elf_x86_64_ {
+    hal.executable.entry_point public @generic_unit_dims layout(#executable_layout) {translation.info = #translation}
+    builtin.module {
+      func @generic_unit_dims() {
+        %0 = hal.interface.constant.load[0] : index
+        %1 = hal.interface.constant.load[1] : index
+        %2 = hal.interface.constant.load[2] : index
+        %3 = hal.interface.constant.load[3] : index
+        %4 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer)
+            : !flow.dispatch.tensor<readonly:1x?x1x1x?x?x1x?xf32>{%0, %1, %2, %3}
+        %5 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer)
+            : !flow.dispatch.tensor<writeonly:1x?x1x1x?x?x1x?xf32>{%0, %1, %2, %3}
+        %6 = flow.dispatch.tensor.load %4, offsets = [0, 0, 0, 0, 0, 0, 0, 0], sizes = [1, %0, 1, 1, %1, %2, 1, %3], strides = [1, 1, 1, 1, 1, 1, 1, 1]
+            : !flow.dispatch.tensor<readonly:1x?x1x1x?x?x1x?xf32>{%0, %1, %2, %3} -> tensor<1x?x1x1x?x?x1x?xf32>
+        %7 = linalg.init_tensor [1, %0, 1, 1, %1, %2, 1, %3] : tensor<1x?x1x1x?x?x1x?xf32>
+        %8 = linalg.generic {
+            indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]}
+            ins(%6 : tensor<1x?x1x1x?x?x1x?xf32>) outs(%7 : tensor<1x?x1x1x?x?x1x?xf32>) attrs =  {lowering.config = #config} {
+        ^bb0(%arg0: f32, %arg1: f32):
+          %9 = arith.addf %arg0, %arg0 : f32
+          linalg.yield %9 : f32
+        } -> tensor<1x?x1x1x?x?x1x?xf32>
+        flow.dispatch.tensor.store %8, %5, offsets = [0, 0, 0, 0, 0, 0, 0, 0], sizes = [1, %0, 1, 1, %1, %2, 1, %3], strides = [1, 1, 1, 1, 1, 1, 1, 1]
+            : tensor<1x?x1x1x?x?x1x?xf32> -> !flow.dispatch.tensor<writeonly:1x?x1x1x?x?x1x?xf32>{%0, %1, %2, %3}
+        return
+      }
+    }
+  }
 }
-// CHECK-LABEL: func @unit_dim_generic_op()
-//   CHECK-DAG:   %[[D0_VAL:.+]] = hal.interface.constant.load[0]
-//   CHECK-DAG:   %[[D1_VAL:.+]] = hal.interface.constant.load[1]
-//   CHECK-DAG:   %[[D2_VAL:.+]] = hal.interface.constant.load[2]
-//   CHECK-DAG:   %[[D3_VAL:.+]] = hal.interface.constant.load[3]
-//   CHECK-DAG:   %[[D0:.+]] = arith.index_cast %[[D0_VAL]]
-//   CHECK-DAG:   %[[D1:.+]] = arith.index_cast %[[D1_VAL]]
-//   CHECK-DAG:   %[[D2:.+]] = arith.index_cast %[[D2_VAL]]
-//   CHECK-DAG:   %[[D3:.+]] = arith.index_cast %[[D3_VAL]]
-//   CHECK-DAG:   %[[INPUT:.+]] = hal.interface.binding.subspan set(0) binding(0)
-//   CHECK-DAG:   %[[OUTPUT:.+]] = hal.interface.binding.subspan set(0) binding(1)
-//       CHECK:   scf.for %[[IV0:.+]] = %{{.+}} to %[[D1]]
-//       CHECK:     scf.for %[[IV1:.+]] = %{{.+}} to %[[D2]]
-//       CHECK:       scf.for %[[IV2:.+]] = %{{.+}} to %[[D3]]
-//       CHECK:         %[[INPUT_TILE:.+]] = flow.dispatch.tensor.load %[[INPUT]]
-//  CHECK-SAME:             offsets = [0, 0, 0, 0, %[[IV0]], %[[IV1]], 0, %[[IV2]]]
-//       CHECK:         %[[GENERIC_TILE:.+]] = linalg.generic
-//  CHECK-SAME:             ins(%[[INPUT_TILE]] :
-//       CHECK:         flow.dispatch.tensor.store %[[GENERIC_TILE]], %[[OUTPUT]]
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 64)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 64)>
+//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"CPUDoubleTilingExpert", workload_per_wg = [64, 64, 64]>
+//      CHECK: hal.executable private @generic_unit_dims
+//      CHECK: hal.executable.entry_point public @generic_unit_dims
+// CHECK-SAME:  translation.info = #[[TRANSLATION]]
+// CHECK-NEXT:  (%[[ARG0:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: index)
+//  CHECK-DAG:   %[[D0:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]]
+//  CHECK-DAG:   %[[D1:.+]] = affine.apply #[[MAP0]]()[%[[ARG1]]]
+//  CHECK-DAG:   %[[D2:.+]] = affine.apply #[[MAP0]]()[%[[ARG2]]]
+//      CHECK:   hal.return %[[D0]], %[[D1]], %[[D2]] : index, index, index
+//      CHECK: func @generic_unit_dims()
+//      CHECK:   scf.for %[[IV0:.+]] =
+//      CHECK:     scf.for %[[IV1:.+]] =
+//      CHECK:       scf.for %[[IV2:.+]] =
+//      CHECK:         %[[GENERIC:.+]] = linalg.generic
+//      CHECK:         flow.dispatch.tensor.store %[[GENERIC]],
+// CHECK-SAME:             offsets = [0, 0, 0, 0, %[[IV0]], %[[IV1]], 0, %[[IV2]]]
 
 // -----
-
-func @repeated_indices_scatter_update_slice_2D() {
-  %c0 = arith.constant 0 : index
-  %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(32) : !flow.dispatch.tensor<readonly:2x3xi32>
-  %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(32) : !flow.dispatch.tensor<readonly:2x1xi32>
-  %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) offset(%c0) alignment(32) : !flow.dispatch.tensor<readwrite:6x3xi32>
-  %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [2, 3], strides = [1, 1] : !flow.dispatch.tensor<readonly:2x3xi32> -> tensor<2x3xi32>
-  %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [2, 1], strides = [1, 1] : !flow.dispatch.tensor<readonly:2x1xi32> -> tensor<2x1xi32>
-  %5 = flow.dispatch.tensor.load %2, offsets = [0, 0], sizes = [6, 3], strides = [1, 1] : !flow.dispatch.tensor<readwrite:6x3xi32> -> tensor<6x3xi32>
-  %6 = iree_linalg_ext.scatter
-    unique_indices(false)
-    ins(%3, %4 : tensor<2x3xi32>, tensor<2x1xi32>)
-    outs(%5 : tensor<6x3xi32>) {
-  ^bb0(%arg0: i32, %arg1: i32):
-    iree_linalg_ext.yield %arg0 : i32
-  } -> tensor<6x3xi32>
-  flow.dispatch.tensor.store %6, %2, offsets = [0, 0], sizes = [6, 3], strides = [1, 1] : tensor<6x3xi32> -> !flow.dispatch.tensor<readwrite:6x3xi32>
-  return
+#config = #iree_codegen.lowering.config<tile_sizes = [[0], [0], [4]], native_vector_size = []>
+#executable_layout = #hal.executable.layout<push_constants = 0, sets = [
+  #hal.descriptor_set.layout<0, bindings = [
+    #hal.descriptor_set.binding<0, storage_buffer>,
+    #hal.descriptor_set.binding<1, storage_buffer>,
+    #hal.descriptor_set.binding<2, storage_buffer>
+  ]>
+]>
+#executable_target_embedded_elf_x86_64_ = #hal.executable.target<"llvm", "embedded-elf-x86_64", {
+  data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128",
+  native_vector_size = 16 : index,
+  target_triple = "x86_64-unknown-linux-gnu"}>
+#map0 = affine_map<(d0) -> (d0)>
+#map1 = affine_map<(d0) -> ()>
+#translation = #iree_codegen.translation.info<"CPUDoubleTilingExpert", workload_per_wg = []>
+hal.executable private @reduce_to_scalar {
+  hal.executable.variant public @llvm, target = #executable_target_embedded_elf_x86_64_ {
+    hal.executable.entry_point public @reduce_to_scalar layout(#executable_layout) {translation.info = #translation}
+    builtin.module {
+      func @reduce_to_scalar() {
+        %0 = hal.interface.constant.load[0] : index
+        %1 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer)
+            : !flow.dispatch.tensor<readonly:?xf32>{%0}
+        %2 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer)
+            : !flow.dispatch.tensor<readwrite:f32>
+        %3 = flow.dispatch.tensor.load %1, offsets = [0], sizes = [%0], strides = [1]
+            : !flow.dispatch.tensor<readonly:?xf32>{%0} -> tensor<?xf32>
+        %4 = flow.dispatch.tensor.load %2, offsets = [], sizes = [], strides = []
+            : !flow.dispatch.tensor<readwrite:f32> -> tensor<f32>
+        %5 = linalg.generic {
+            indexing_maps = [#map0, #map1], iterator_types = ["reduction"]}
+            ins(%3 : tensor<?xf32>) outs(%4 : tensor<f32>) attrs =  {lowering.config = #config} {
+        ^bb0(%arg0: f32, %arg1: f32):
+          %6 = arith.addf %arg0, %arg1 : f32
+          linalg.yield %6 : f32
+        } -> tensor<f32>
+        flow.dispatch.tensor.store %5, %2, offsets = [], sizes = [], strides = []
+            : tensor<f32> -> !flow.dispatch.tensor<readwrite:f32>
+        return
+      }
+    }
+  }
 }
-//      CHECK: func @repeated_indices_scatter_update_slice_2D
-//  CHECK-DAG:   %[[ARG0_CAPTURE:[a-zA-Z0-9_]+]] = {{.+}} !flow.dispatch.tensor<readonly:2x3xi32>
-//  CHECK-DAG:   %[[ARG1_CAPTURE:[a-zA-Z0-9_]+]] = {{.+}} !flow.dispatch.tensor<readonly:2x1xi32>
-//  CHECK-DAG:   %[[ARG2_CAPTURE:[a-zA-Z0-9_]+]] = {{.+}} !flow.dispatch.tensor<readwrite:6x3xi32>
-//  CHECK-DAG:   %[[UPDATE:.+]] = flow.dispatch.tensor.load %[[ARG0_CAPTURE]], offsets = [0, 0]
-//  CHECK-DAG:   %[[INDICES:.+]] = flow.dispatch.tensor.load %[[ARG1_CAPTURE]], offsets = [0, 0]
-//  CHECK-DAG:   %[[ORIGINAL:.+]] = flow.dispatch.tensor.load %[[ARG2_CAPTURE]], offsets = [0, 0]
-//  CHECK-DAG:   %[[SCATTER:.+]] = iree_linalg_ext.scatter
-// CHECK-SAME:       unique_indices(false)
-// CHECK-SAME:       ins(%[[UPDATE]], %[[INDICES]] : tensor<2x3xi32>, tensor<2x1xi32>)
-// CHECK-SAME:       outs(%[[ORIGINAL]] : tensor<6x3xi32>)
-//      CHECK:   flow.dispatch.tensor.store %[[SCATTER]], %[[ARG2_CAPTURE]], offsets = [0, 0]
+//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"CPUDoubleTilingExpert", workload_per_wg = []>
+//      CHECK: hal.executable private @reduce_to_scalar
+//      CHECK: hal.executable.entry_point public @reduce_to_scalar
+// CHECK-SAME:  translation.info = #[[TRANSLATION]]
+// CHECK-NEXT:  (%[[ARG0:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: index)
+//      CHECK:   %[[C1:.+]] = arith.constant 1 : index
+//      CHECK:   hal.return %[[C1]], %[[C1]], %[[C1]] : index, index, index
+//      CHECK: func @reduce_to_scalar()
+//  CHECK-NOT:   scf.for
+
+// -----
+#executable_layout = #hal.executable.layout<push_constants = 0, sets = [
+  #hal.descriptor_set.layout<0, bindings = [
+    #hal.descriptor_set.binding<0, storage_buffer>,
+    #hal.descriptor_set.binding<1, storage_buffer>,
+    #hal.descriptor_set.binding<2, storage_buffer>
+  ]>
+]>
+#executable_target_embedded_elf_x86_64_ = #hal.executable.target<"llvm", "embedded-elf-x86_64", {
+  data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128",
+  native_vector_size = 16 : index,
+  target_triple = "x86_64-unknown-linux-gnu"}>
+#map = affine_map<() -> ()>
+#translation = #iree_codegen.translation.info<"CPUDefault", workload_per_wg = []>
+hal.executable private @scalar {
+  hal.executable.variant public @llvm, target = #executable_target_embedded_elf_x86_64_ {
+    hal.executable.entry_point public @scalar layout(#executable_layout) {translation.info = #translation}
+    builtin.module {
+      func @scalar() {
+        %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer)
+            : !flow.dispatch.tensor<readonly:f32>
+        %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer)
+            : !flow.dispatch.tensor<writeonly:f32>
+        %2 = flow.dispatch.tensor.load %0, offsets = [], sizes = [], strides = []
+            : !flow.dispatch.tensor<readonly:f32> -> tensor<f32>
+        %3 = flow.dispatch.tensor.load %1, offsets = [], sizes = [], strides = []
+            : !flow.dispatch.tensor<writeonly:f32> -> tensor<f32>
+        %4 = linalg.generic {
+            indexing_maps = [#map, #map], iterator_types = []}
+            ins(%2 : tensor<f32>) outs(%3 : tensor<f32>) {
+        ^bb0(%arg0: f32, %arg1: f32):
+          %5 = arith.addf %arg0, %arg1 : f32
+          linalg.yield %5 : f32
+        } -> tensor<f32>
+        flow.dispatch.tensor.store %4, %1, offsets = [], sizes = [], strides = []
+            : tensor<f32> -> !flow.dispatch.tensor<writeonly:f32>
+        return
+      }
+    }
+  }
+}
+//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"CPUDefault", workload_per_wg = []>
+//      CHECK: hal.executable private @scalar
+//      CHECK: hal.executable.entry_point public @scalar
+// CHECK-SAME:  translation.info = #[[TRANSLATION]]
+// CHECK-NEXT:  (%[[ARG0:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: index)
+//      CHECK:   %[[C1:.+]] = arith.constant 1 : index
+//      CHECK:   hal.return %[[C1]], %[[C1]], %[[C1]] : index, index, index
+//      CHECK: func @scalar()
+//  CHECK-NOT:   scf.for
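
For reference, the workload_per_wg values checked in the tests above follow directly from the first level of the lowering config: every partitioned loop with a non-zero first-level tile size contributes that size, listed innermost loop first. A minimal standalone sketch of that mapping (a hypothetical helper written for illustration, not code from this change):

#include <algorithm>
#include <cstdint>
#include <vector>

// E.g. tile_sizes [[28, 8, 0], ...] -> workload_per_wg = [8, 28], and
// tile_sizes [[0, 1, 7, 64, 0, 0], ...] -> workload_per_wg = [64, 7, 1].
std::vector<int64_t> workloadPerWorkgroup(
    const std::vector<int64_t> &firstLevelTileSizes,
    const std::vector<unsigned> &partitionedLoops) {
  std::vector<int64_t> workload;
  for (unsigned loop : partitionedLoops)
    if (firstLevelTileSizes[loop] != 0)
      workload.push_back(firstLevelTileSizes[loop]);
  // workload_per_wg lists the innermost distributed loop first.
  std::reverse(workload.begin(), workload.end());
  return workload;
}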
diff --git a/iree/compiler/Codegen/Dialect/LoweringConfig.cpp b/iree/compiler/Codegen/Dialect/LoweringConfig.cpp
index f3c7409..e8df94c 100644
--- a/iree/compiler/Codegen/Dialect/LoweringConfig.cpp
+++ b/iree/compiler/Codegen/Dialect/LoweringConfig.cpp
@@ -329,48 +329,6 @@
   op->setAttr(kConfigAttrName, config);
 }
 
-LogicalResult setOpConfigAndEntryPointFnTranslation(
-    FuncOp entryPointFn, Operation *op,
-    IREE::Codegen::LoweringConfigAttr config,
-    IREE::Codegen::DispatchLoweringPassPipeline passPipeline,
-    ArrayRef<int64_t> workgroupSize) {
-  auto interfaceOp = cast<IREE::Flow::PartitionableLoopsInterface>(*op);
-  auto partitionedLoops =
-      interfaceOp.getPartitionableLoops(kNumMaxParallelDims);
-  SmallVector<int64_t, 3> workloadPerWorkgroup;
-  auto tileSizes = config.getTileSizeVals(0);
-  if (!tileSizes.empty() && !partitionedLoops.empty()) {
-    for (unsigned depth : partitionedLoops) {
-      if (depth >= tileSizes.size()) {
-        return op->emitOpError(
-                   "illegal configuration for lowering op, expect first level "
-                   "tile size to contain at least ")
-               << partitionedLoops.back() << " elements";
-      }
-      if (tileSizes[depth] == 0) {
-        return op->emitOpError("illegal to set tilesize of loop ")
-               << depth
-               << " to zero since it is set to be partitioned at the flow "
-                  "level";
-      }
-      workloadPerWorkgroup.push_back(tileSizes[depth]);
-    }
-    if (!workloadPerWorkgroup.empty()) {
-      workloadPerWorkgroup =
-          llvm::to_vector<3>(llvm::reverse(workloadPerWorkgroup));
-    }
-  }
-  auto entryPointOp = getEntryPoint(entryPointFn);
-  if (!entryPointOp) {
-    return entryPointFn.emitOpError(
-        "unable to find entry point op for entry point function");
-  }
-  auto translationInfo = IREE::Codegen::TranslationInfoAttr::get(
-      entryPointOp->getContext(), passPipeline, workloadPerWorkgroup);
-  setTranslationInfo(entryPointOp, translationInfo, workgroupSize);
-  return success();
-}
-
 //===----------------------------------------------------------------------===//
 // Helpers for getting/setting `iree_codegen.compilation.info` attribute on root
 // operations to override IREEs default compilation.
diff --git a/iree/compiler/Codegen/Dialect/LoweringConfig.h b/iree/compiler/Codegen/Dialect/LoweringConfig.h
index 6d99215..4a01712 100644
--- a/iree/compiler/Codegen/Dialect/LoweringConfig.h
+++ b/iree/compiler/Codegen/Dialect/LoweringConfig.h
@@ -109,11 +109,16 @@
 void setLoweringConfig(Operation *op, IREE::Codegen::LoweringConfigAttr config);
 
 /// Sets translation for the entry-point function based on op configuration.
-LogicalResult setOpConfigAndEntryPointFnTranslation(
+inline LogicalResult setOpConfigAndEntryPointFnTranslation(
     FuncOp entryPointFn, Operation *op,
     IREE::Codegen::LoweringConfigAttr config,
     IREE::Codegen::DispatchLoweringPassPipeline passPipeline,
-    ArrayRef<int64_t> workgroupSize = {});
+    ArrayRef<int64_t> workgroupSize = {}) {
+  auto translationInfo = IREE::Codegen::TranslationInfoAttr::get(
+      entryPointFn->getContext(), passPipeline, ArrayRef<int64_t>{});
+  setTranslationInfo(entryPointFn, translationInfo, workgroupSize);
+  return success();
+}
 inline LogicalResult setOpConfigAndEntryPointFnTranslation(
     FuncOp entryPointFn, Operation *op, TileSizesListTypeRef tileSizes,
     ArrayRef<int64_t> nativeVectorSize,
diff --git a/iree/compiler/Codegen/Dialect/LoweringConfig.td b/iree/compiler/Codegen/Dialect/LoweringConfig.td
index 4609041..204eaee 100644
--- a/iree/compiler/Codegen/Dialect/LoweringConfig.td
+++ b/iree/compiler/Codegen/Dialect/LoweringConfig.td
@@ -135,7 +135,7 @@
     TileSizesListType getTileSizeVals();
 
     // Returns the tile sizes for a level set for the op.
-    SmallVector<int64_t> getTileSizeVals(unsigned level = 0);
+    SmallVector<int64_t> getTileSizeVals(unsigned level);
 
     // Returns the native vector size to use.
     SmallVector<int64_t> getNativeVectorSizeVals();
diff --git a/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp b/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
index 95bcd63..384c29e 100644
--- a/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
+++ b/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
@@ -19,6 +19,7 @@
 #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Transforms/Passes.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -125,51 +126,43 @@
   return referenceTypeLengthInBytes;
 }
 
-static SmallVector<int64_t> getDefaultWorkloadPerWorkgroup(
-    ArrayRef<LoopTilingAndDistributionInfo> tiledLoops,
-    ArrayRef<int64_t> nativeVectorSizeInElements) {
-  if (tiledLoops.empty()) {
+/// Returns the default tile sizes to use for the loops that are distributed
+/// at the Flow level.
+static SmallVector<int64_t> getDefaultDistributedLoopTileSizes(
+    ArrayRef<int64_t> lbs, ArrayRef<int64_t> ubs,
+    ArrayRef<int64_t> minTileSizes, ArrayRef<int64_t> maxTileSizes) {
+  assert(lbs.size() == ubs.size() && lbs.size() == minTileSizes.size() &&
+         lbs.size() == maxTileSizes.size() &&
+         "expected all vectors to be of equal size");
+  if (lbs.empty()) {
     return {};
   }
-  assert(tiledLoops.size() == nativeVectorSizeInElements.size());
-  unsigned maxDim = 0;
-  for (auto tiledLoop : tiledLoops) {
-    maxDim = std::max<unsigned>(tiledLoop.processorDistributionDim, maxDim);
-  }
-  SmallVector<int64_t> workloadPerWorkgroup(maxDim + 1, 1);
-  SmallVector<int64_t> numWorkgroupsPerDim(maxDim + 1, 1);
-  SmallVector<int64_t> workload(maxDim + 1, 1);
-  auto getStaticValue = [](OpFoldResult ofr) -> Optional<int64_t> {
-    return (ofr ? getConstantIntValue(ofr) : llvm::None);
-  };
+  size_t numDims = lbs.size();
+  SmallVector<int64_t> distributedTileSizes(numDims, 1);
+  SmallVector<int64_t> numWorkgroupsPerDim(numDims, 1);
+  SmallVector<int64_t> workload(numDims, 1);
   auto ceilFn = [](int64_t a, int64_t b) { return (a + b - 1) / b; };
 
-  for (auto tiledLoop : enumerate(tiledLoops)) {
-    Optional<int64_t> lb = getStaticValue(tiledLoop.value().untiledLowerBound);
-    Optional<int64_t> ub = getStaticValue(tiledLoop.value().untiledUpperBound);
-    unsigned dim = tiledLoop.value().processorDistributionDim;
-    if (!lb || !ub) {
-      workloadPerWorkgroup[dim] = defaultWorkgroupTileSize;
-      workload[dim] = ShapedType::kDynamicSize;
+  for (auto i : llvm::seq<size_t>(0, numDims)) {
+    if (ShapedType::isDynamic(lbs[i]) || ShapedType::isDynamic(ubs[i])) {
+      distributedTileSizes[i] = maxTileSizes[i];
+      workload[i] = ShapedType::kDynamicSize;
       continue;
     }
-    int64_t candidateTileSize = nativeVectorSizeInElements[tiledLoop.index()];
-    if (*ub <= *lb) {
-      // Should be avoiding tiling this loop, but use tile size of 1.
-      candidateTileSize = 1;
-    } else {
+    int64_t candidateTileSize = 1;
+    if (ubs[i] > lbs[i]) {
       // Pick a value that evenly distributes the workload.
       candidateTileSize = std::max<int64_t>(
-          llvm::PowerOf2Floor(static_cast<uint64_t>(*ub - *lb) / 2),
-          candidateTileSize);
+          llvm::PowerOf2Floor(static_cast<uint64_t>(ubs[i] - lbs[i]) / 2),
+          minTileSizes[i]);
     }
 
     // Limit the workload per workgroup to the default being the max to keep the
     // work per invocation reasonable.
-    workloadPerWorkgroup[dim] =
-        std::min<int64_t>(candidateTileSize, defaultWorkgroupTileSize);
-    workload[dim] = (*ub <= *lb ? 1 : *ub - *lb);
-    numWorkgroupsPerDim[dim] = ceilFn(workload[dim], workloadPerWorkgroup[dim]);
+    distributedTileSizes[i] =
+        std::min<int64_t>(candidateTileSize, maxTileSizes[i]);
+    workload[i] = (ubs[i] <= lbs[i] ? 1 : ubs[i] - lbs[i]);
+    numWorkgroupsPerDim[i] = ceilFn(workload[i], distributedTileSizes[i]);
   }
 
   // Reduce the number of workgroups in cases where we are dividing the work too
@@ -180,26 +173,26 @@
   for (auto ng : numWorkgroupsPerDim) {
     numWorkgroups *= ng;
   }
-  unsigned currDim = 0;
-  while (numWorkgroups > numWorkgroupsLimit &&
-         currDim < numWorkgroupsPerDim.size()) {
-    if (workloadPerWorkgroup[currDim] >= defaultWorkgroupTileSize ||
-        workload[currDim] == ShapedType::kDynamicSize ||
-        workloadPerWorkgroup[currDim] >= workload[currDim]) {
-      currDim++;
+  unsigned currDim = numDims;
+  while (numWorkgroups > numWorkgroupsLimit && currDim > 0) {
+    if (distributedTileSizes[currDim - 1] >= maxTileSizes[currDim - 1] ||
+        workload[currDim - 1] == ShapedType::kDynamicSize ||
+        distributedTileSizes[currDim - 1] >= workload[currDim - 1]) {
+      currDim--;
       continue;
     }
-    workloadPerWorkgroup[currDim] = std::min<int64_t>(
-        workloadPerWorkgroup[currDim] * 2, defaultWorkgroupTileSize);
-    int64_t nwg = ceilFn(workload[currDim], workloadPerWorkgroup[currDim]);
-    if (nwg < numWorkgroupsPerDim[currDim]) {
-      numWorkgroups /= numWorkgroupsPerDim[currDim];
+    distributedTileSizes[currDim - 1] = std::min<int64_t>(
+        distributedTileSizes[currDim - 1] * 2, maxTileSizes[currDim - 1]);
+    int64_t nwg =
+        ceilFn(workload[currDim - 1], distributedTileSizes[currDim - 1]);
+    if (nwg < numWorkgroupsPerDim[currDim - 1]) {
+      numWorkgroups /= numWorkgroupsPerDim[currDim - 1];
       numWorkgroups *= nwg;
     } else {
-      currDim++;
+      currDim--;
     }
   }
-  return workloadPerWorkgroup;
+  return distributedTileSizes;
 }
 
 /// Adjusts the workload per workgroup to be a multiple of vector size to ensure
@@ -210,77 +203,122 @@
     return maxSize;
   }
   int64_t dim = ub - lb;
-  if (dim < vectorSizeVal) return 0;
+  if (dim < vectorSizeVal) return dim;
   for (int64_t i = std::min(maxSize, dim); i > 0; --i) {
     if (dim % i == 0 && i % vectorSizeVal == 0) {
       return i;
     }
   }
-  return maxSize;
+  return vectorSizeVal;
 }
 
-/// Compute the workload per workgroup. The `vectorSize` is expected to contain
-/// the vector size to use along each loop of the `interfaceOp`.
-static SmallVector<int64_t> getDefaultWorkloadPerWorkgroup(
-    ArrayRef<LoopTilingAndDistributionInfo> tiledLoops,
-    ArrayRef<unsigned> partitionedLoops, ArrayRef<int64_t> vectorSize) {
-  if (tiledLoops.empty()) {
-    // Nothing to do.
-    return {};
+/// Returns the tile sizes to use at the Flow level for an operation that
+/// implements the `PartitionableLoopsInterface`.
+static SmallVector<int64_t> getDefaultDistributedLevelTileSizes(
+    ArrayRef<Range> iterationDomain,
+    IREE::Flow::PartitionableLoopsInterface partitionableLoopInterfaceOp,
+    ArrayRef<int64_t> minTileSizes, ArrayRef<int64_t> maxTileSizes) {
+  assert(iterationDomain.size() == minTileSizes.size() &&
+         "expected as many min tile sizes as number of loops");
+  auto getStaticValue = [](Value v) -> int64_t {
+    IntegerAttr attr;
+    if (!matchPattern(v, m_Constant(&attr))) return ShapedType::kDynamicSize;
+    return attr.getInt();
+  };
+  auto lbs = llvm::to_vector(llvm::map_range(
+      iterationDomain, [&](Range r) { return getStaticValue(r.offset); }));
+  auto ubs = llvm::to_vector(llvm::map_range(
+      iterationDomain, [&](Range r) { return getStaticValue(r.size); }));
+
+  SmallVector<unsigned> partitionableLoops =
+      partitionableLoopInterfaceOp.getPartitionableLoops(kNumMaxParallelDims);
+  llvm::SmallDenseSet<unsigned, 4> partitionableLoopsSet;
+  partitionableLoopsSet.insert(partitionableLoops.begin(),
+                               partitionableLoops.end());
+
+  size_t numPartitionedLoops = partitionableLoops.size();
+  SmallVector<int64_t> distributedLoopLbs(numPartitionedLoops,
+                                          ShapedType::kDynamicSize),
+      distributedLoopUbs(numPartitionedLoops, ShapedType::kDynamicSize),
+      minDistributedLoopTileSizes(numPartitionedLoops, 1),
+      maxDistributedLoopTileSizes(numPartitionedLoops,
+                                  defaultWorkgroupTileSize);
+  // Find the bounds of the partitionable loops
+  unsigned index = 0;
+  for (auto range : llvm::enumerate(iterationDomain)) {
+    if (!partitionableLoopsSet.count(range.index())) continue;
+
+    minDistributedLoopTileSizes[index] = minTileSizes[range.index()];
+    maxDistributedLoopTileSizes[index] = maxTileSizes[range.index()];
+    distributedLoopLbs[index] = lbs[range.index()];
+    distributedLoopUbs[index] = ubs[range.index()];
+    index++;
   }
 
-  assert(partitionedLoops.size() == tiledLoops.size() &&
-         "mismatch in expected parallelization");
-  SmallVector<int64_t> partitionedLoopsVectorSize(tiledLoops.size(), 1);
-  for (auto loopDim : llvm::enumerate(partitionedLoops)) {
-    partitionedLoopsVectorSize[loopDim.index()] = vectorSize[loopDim.value()];
+  SmallVector<int64_t> distributedTileSizes =
+      getDefaultDistributedLoopTileSizes(distributedLoopLbs, distributedLoopUbs,
+                                         minDistributedLoopTileSizes,
+                                         maxDistributedLoopTileSizes);
+  SmallVector<int64_t> distributedLevelTileSizes(iterationDomain.size(), 0);
+  for (auto loopID : llvm::enumerate(partitionableLoops)) {
+    distributedLevelTileSizes[loopID.value()] =
+        distributedTileSizes[loopID.index()];
   }
+  // Final fixup of the tile sizes to make sure they divide the problem size
+  // and keep it vectorizable.
+  for (auto i : llvm::seq<unsigned>(0, distributedLevelTileSizes.size())) {
+    distributedLevelTileSizes[i] =
+        distributedLevelTileSizes[i] != 0
+            ? getMaxTileSize(lbs[i], ubs[i], distributedLevelTileSizes[i],
+                             minTileSizes[i])
+            : 0;
+  }
+  return distributedLevelTileSizes;
+}
 
-  SmallVector<int64_t> workLoadPerWorkgroup =
-      getDefaultWorkloadPerWorkgroup(tiledLoops, partitionedLoopsVectorSize);
-  for (auto tiledLoop : llvm::enumerate(tiledLoops)) {
-    Optional<int64_t> lb =
-        getConstantIntValue(tiledLoop.value().untiledLowerBound);
-    Optional<int64_t> ub =
-        getConstantIntValue(tiledLoop.value().untiledUpperBound);
-    if (!lb || !ub) continue;
-    unsigned workloadIndex = tiledLoops.size() - 1 - tiledLoop.index();
-    workLoadPerWorkgroup[workloadIndex] = getMaxTileSize(
-        lb.getValue(), ub.getValue(), workLoadPerWorkgroup[workloadIndex],
-        partitionedLoopsVectorSize[tiledLoop.index()]);
-    if (workLoadPerWorkgroup[workloadIndex] == 0) {
-      // If the tile size chosen is 0 set the workLoadPerWorkgroup to problem
-      // size.
-      workLoadPerWorkgroup[workloadIndex] = ub.getValue() - lb.getValue();
+/// Sets the default configuration to use for an operation that implements the
+/// `PartitionableLoopsInterface`, given the iteration domain of all the loops.
+static LogicalResult setDefaultRootConfig(
+    FuncOp entryPointFn,
+    IREE::Flow::PartitionableLoopsInterface partitionableLoopsInterfaceOp,
+    ArrayRef<Range> iterationDomain) {
+  if (getLoweringConfig(partitionableLoopsInterfaceOp)) return success();
+
+  SmallVector<unsigned> partitionableLoops =
+      partitionableLoopsInterfaceOp.getPartitionableLoops(kNumMaxParallelDims);
+
+  SmallVector<int64_t> minTileSizes(iterationDomain.size(), 1);
+  SmallVector<int64_t> maxTileSizes(iterationDomain.size(), 1);
+  if (!partitionableLoops.empty()) {
+    // TODO: The min tile size here is derived only from the element type of
+    // the data in the entry point function and the vector size that follows
+    // from it. For `LinalgOp`s we can use the indexing maps to find the
+    // fastest-varying loops and set their min tile size to the vector
+    // length. A version of this is done for generic ops. Generalize that and
+    // use it for `LinalgOp`s.
+    unsigned typeWidthInBytes = getReferenceTypeLengthInBytes(entryPointFn);
+    minTileSizes[partitionableLoops.back()] =
+        getVectorSize(entryPointFn, typeWidthInBytes);
+    for (auto partitionableLoopId : partitionableLoops) {
+      maxTileSizes[partitionableLoopId] = defaultWorkgroupTileSize;
     }
   }
-  return workLoadPerWorkgroup;
+
+  SmallVector<int64_t> flowTileSizes = getDefaultDistributedLevelTileSizes(
+      iterationDomain, partitionableLoopsInterfaceOp, minTileSizes,
+      maxTileSizes);
+  TileSizesListType tileSizes;
+  tileSizes.emplace_back(std::move(flowTileSizes));
+  return setOpConfigAndEntryPointFnTranslation(
+      entryPointFn, partitionableLoopsInterfaceOp, tileSizes,
+      /*nativeVectorSize=*/ArrayRef<int64_t>{},
+      DispatchLoweringPassPipeline::CPUDefault);
 }
 
-/// Sets the default launch configuration to use for a tiled + distributed
-/// dispatch region based on the `tiledLoops` found.
-static LogicalResult setDefaultLaunchConfig(
-    FuncOp entryPointFn, ArrayRef<LoopTilingAndDistributionInfo> tiledLoops) {
-  SmallVector<int64_t> nativeVectorSizeInElements(tiledLoops.size(), 1);
-  if (!tiledLoops.empty()) {
-    unsigned typeWidthInBytes = getReferenceTypeLengthInBytes(entryPointFn);
-    nativeVectorSizeInElements.back() =
-        getVectorSize(entryPointFn, typeWidthInBytes);
-  }
-
-  SmallVector<int64_t> workloadPerWorkgroup =
-      getDefaultWorkloadPerWorkgroup(tiledLoops, nativeVectorSizeInElements);
-
-  setTranslationInfo(entryPointFn, DispatchLoweringPassPipeline::CPUDefault,
-                     workloadPerWorkgroup,
-                     /*workgroupSize =*/ArrayRef<int64_t>{});
-  return success();
-}
-
-static LogicalResult setX86SandboxRootConfig(
-    FuncOp entryPointFn, linalg::ContractionOpInterface op,
-    ArrayRef<int64_t> flowTileSizes, ArrayRef<unsigned> partionableLoops,
-    int vectorSize) {
+static LogicalResult setX86SandboxRootConfig(FuncOp entryPointFn,
+                                             linalg::ContractionOpInterface op,
+                                             ArrayRef<int64_t> flowTileSizes,
+                                             int vectorSize) {
   // Hardcoded tiling sizes {1, 1, ..., 8, 32, 16}.
   // The tiling for parallel dims and reduction dims should be separated.
   SmallVector<int64_t> l1TileSizes;
@@ -289,13 +327,6 @@
   l1TileSizes.push_back(getMaxTileSize(0, flowTileSizes[nLoops - 3], 8, 8));
   l1TileSizes.push_back(getMaxTileSize(0, flowTileSizes[nLoops - 2], 32, 32));
   l1TileSizes.push_back(0);
-  llvm::SmallDenseSet<unsigned> pLoopsSet;
-  for (auto i : partionableLoops) pLoopsSet.insert(i);
-  for (auto en : llvm::enumerate(l1TileSizes)) {
-    if (en.value() != 0 && !pLoopsSet.contains(en.index())) {
-      l1TileSizes[en.index()] = 0;
-    }
-  }
 
   auto lhsShapedType = op.lhs().getType().cast<ShapedType>();
   int64_t K = lhsShapedType.getShape().back();
@@ -304,14 +335,13 @@
   vectorTileSizes.push_back(getMaxTileSize(0, K, 16, 16));
 
   TileSizesListType tileSizes;
-  tileSizes.push_back({});
+  tileSizes.emplace_back(flowTileSizes.begin(), flowTileSizes.end());
   tileSizes.push_back(l1TileSizes);
   tileSizes.push_back(vectorTileSizes);
-  auto config = IREE::Codegen::LoweringConfigAttr::get(
-      entryPointFn.getContext(), tileSizes, {});
-  setLoweringConfig(op, config);
 
-  return success();
+  return setOpConfigAndEntryPointFnTranslation(
+      entryPointFn, op, tileSizes, /*nativeVectorSize=*/ArrayRef<int64_t>{},
+      DispatchLoweringPassPipeline::CPUDoubleTilingExpert);
 }
 
 static LogicalResult setX86TileFuseAndVectorizeRootConfig(
@@ -336,15 +366,13 @@
   l1TileSizes.push_back(getMaxTileSize(0, K, 2 * vectorSize, vectorSize));
   vectorTileSizes.push_back(vectorSize);
   TileSizesListType tileSizes;
-  tileSizes.push_back({});  // Empty here since there is nothing to do in first
-                            // level tiling.
+  tileSizes.emplace_back(flowTileSizes.begin(), flowTileSizes.end());
   tileSizes.push_back(l1TileSizes);
   tileSizes.push_back(vectorTileSizes);
-  auto config = IREE::Codegen::LoweringConfigAttr::get(
-      entryPointFn.getContext(), tileSizes, vectorTileSizes);
-  setLoweringConfig(op, config);
 
-  return success();
+  return setOpConfigAndEntryPointFnTranslation(
+      entryPointFn, op, tileSizes, vectorTileSizes,
+      DispatchLoweringPassPipeline::CPUTileFuseAndVectorize);
 }
 
 static LogicalResult setARMRootConfig(FuncOp entryPointFn,
@@ -371,15 +399,13 @@
   l1TileSizes.push_back(getMaxTileSize(0, K, 16 * vectorSize, vectorSize));
   vectorTileSizes.push_back(vectorSize);
   TileSizesListType tileSizes;
-  tileSizes.push_back({});  // Empty here since there is nothing to do in first
-                            // level tiling.
+  tileSizes.emplace_back(flowTileSizes.begin(), flowTileSizes.end());
   tileSizes.push_back(l1TileSizes);
   tileSizes.push_back(vectorTileSizes);
-  auto config = IREE::Codegen::LoweringConfigAttr::get(
-      entryPointFn.getContext(), tileSizes, vectorTileSizes);
-  setLoweringConfig(op, config);
 
-  return success();
+  return setOpConfigAndEntryPointFnTranslation(
+      entryPointFn, op, tileSizes, vectorTileSizes,
+      DispatchLoweringPassPipeline::CPUTileFuseAndVectorize);
 }
 
 /// Sets the lowering configuration for dispatch region with root op that
@@ -387,38 +413,28 @@
 static LogicalResult setRootConfig(
     FuncOp entryPointFn, linalg::ContractionOpInterface contractionOp,
     ArrayRef<LoopTilingAndDistributionInfo> tiledLoops) {
+  auto linalgOp = cast<linalg::LinalgOp>(contractionOp.getOperation());
   auto lhsShapedType = contractionOp.lhs().getType().cast<ShapedType>();
   // Use the default distribution for the matmul loops.
-  unsigned numBatchDims = 0;
-  auto interfaceOp = cast<IREE::Flow::PartitionableLoopsInterface>(
-      contractionOp.getOperation());
-  unsigned numLoops = interfaceOp.getNumLoops();
-  SmallVector<unsigned> partitionedLoops =
-      interfaceOp.getPartitionableLoops(kNumMaxParallelDims);
-  // The batch dim is distributed if numLoops > 3 and partitionedLoops.begin()
-  // == 0.
-  if (numLoops > 3 && !partitionedLoops.empty() && partitionedLoops[0] == 0) {
-    numBatchDims = 1;
-  }
-
+  unsigned numLoops = linalgOp.getNumLoops();
   int64_t vectorSize = getVectorSize(entryPointFn, lhsShapedType);
-  SmallVector<int64_t> vectorSizeVals(numLoops, 1);
-  vectorSizeVals.back() = vectorSize;
-  vectorSizeVals[vectorSizeVals.size() - 2] = vectorSize;
-  vectorSizeVals[vectorSizeVals.size() - 3] = vectorSize;
-
-  SmallVector<int64_t> workloadPerWorkgroup = getDefaultWorkloadPerWorkgroup(
-      tiledLoops.drop_front(numBatchDims),
-      ArrayRef<unsigned>(partitionedLoops).drop_front(numBatchDims),
-      ArrayRef<int64_t>(vectorSizeVals).drop_front(numBatchDims));
-  if (numBatchDims) {
-    workloadPerWorkgroup.push_back(1);
+  SmallVector<int64_t> minTileSizes(numLoops, vectorSize);
+  SmallVector<int64_t> maxTileSizes(numLoops, defaultWorkgroupTileSize);
+  if (numLoops > 3) {
+    minTileSizes[0] = 1;
+    maxTileSizes[0] = 1;
   }
 
-  SmallVector<int64_t> flowTileSizes =
-      getDistributedTileSizes(interfaceOp, workloadPerWorkgroup);
+  OpBuilder builder(entryPointFn.getContext());
+  builder.setInsertionPoint(contractionOp);
+  SmallVector<Range> iterationDomain =
+      linalgOp.createLoopRanges(builder, linalgOp->getLoc());
+  SmallVector<int64_t> flowTileSizes = getDefaultDistributedLevelTileSizes(
+      iterationDomain,
+      cast<IREE::Flow::PartitionableLoopsInterface>(
+          contractionOp.getOperation()),
+      minTileSizes, maxTileSizes);
 
-  Optional<DispatchLoweringPassPipeline> passPipeline = {};
   if (isX86(entryPointFn)) {
     // There is a tileInterchange option. If it needs to be configured, we can
     // only apply the pipeline to linalg.matmul. Because we don't know the
@@ -430,36 +446,17 @@
     Type resElemType =
         getElementTypeOrSelf(contractionOp->getResult(0).getType());
     if (lhsElemType == rhsElemType && rhsElemType == resElemType) {
-      passPipeline = DispatchLoweringPassPipeline::CPUDoubleTilingExpert;
-      if (failed(setX86SandboxRootConfig(entryPointFn, contractionOp,
-                                         flowTileSizes, partitionedLoops,
-                                         vectorSize))) {
-        return failure();
-      }
+      return setX86SandboxRootConfig(entryPointFn, contractionOp, flowTileSizes,
+                                     vectorSize);
     } else {
-      passPipeline = DispatchLoweringPassPipeline::CPUTileFuseAndVectorize;
-      if (failed(setX86TileFuseAndVectorizeRootConfig(
-              entryPointFn, contractionOp, flowTileSizes, vectorSize))) {
-        return failure();
-      }
-    }
-  } else {
-    // Fall back to ARM configurations.
-    passPipeline = DispatchLoweringPassPipeline::CPUTileFuseAndVectorize;
-    if (failed(setARMRootConfig(entryPointFn, contractionOp, flowTileSizes,
-                                vectorSize))) {
-      return failure();
+      return setX86TileFuseAndVectorizeRootConfig(entryPointFn, contractionOp,
+                                                  flowTileSizes, vectorSize);
     }
   }
 
-  if (!passPipeline) {
-    // Do nothing.
-    return success();
-  }
-  setTranslationInfo(entryPointFn, passPipeline.getValue(),
-                     workloadPerWorkgroup,
-                     /*workgroupSize=*/ArrayRef<int64_t>{});
-  return success();
+  // Fall back to ARM configurations.
+  return setARMRootConfig(entryPointFn, contractionOp, flowTileSizes,
+                          vectorSize);
 }
 
 /// Sets the lowering configuration for dispatch region for linalg.mmt4d root
@@ -479,8 +476,8 @@
   };
 
   auto getL1TileSizes = [&]() -> SmallVector<int64_t> {
-    auto lhsShape = getUntiledShape(mmt4dOp.inputs()[0]);
-    auto rhsShape = getUntiledShape(mmt4dOp.inputs()[1]);
+    auto lhsShape = mmt4dOp.inputs()[0].getType().cast<ShapedType>().getShape();
+    auto rhsShape = mmt4dOp.inputs()[1].getType().cast<ShapedType>().getShape();
     int M0 = lhsShape[2];
     int N0 = rhsShape[2];
     int K0 = lhsShape[3];
@@ -492,8 +489,8 @@
   };
 
   auto getVectorSizes = [&]() -> SmallVector<int64_t> {
-    auto lhsShape = getUntiledShape(mmt4dOp.inputs()[0]);
-    auto rhsShape = getUntiledShape(mmt4dOp.inputs()[1]);
+    auto lhsShape = mmt4dOp.inputs()[0].getType().cast<ShapedType>().getShape();
+    auto rhsShape = mmt4dOp.inputs()[1].getType().cast<ShapedType>().getShape();
     int M0 = lhsShape[2];
     int N0 = rhsShape[2];
     int K0 = lhsShape[3];
@@ -519,11 +516,9 @@
 static LogicalResult setRootConfig(
     FuncOp entryPointFn, IREE::LinalgExt::FftOp fftOp,
     ArrayRef<LoopTilingAndDistributionInfo> tiledLoops) {
-  auto interfaceOp = cast<IREE::Flow::PartitionableLoopsInterface>(*fftOp);
-  auto partitionedLoops =
-      interfaceOp.getPartitionableLoops(kNumMaxParallelDims);
-  unsigned maxDepth = partitionedLoops.back() + 1;
-  SmallVector<int64_t> workgroupTileSizes(maxDepth, defaultWorkgroupTileSize);
+  unsigned numLoops = fftOp.getLoopIteratorTypes().size();
+  auto partitionedLoops = fftOp.getPartitionableLoops(kNumMaxParallelDims);
+  SmallVector<int64_t> workgroupTileSizes(numLoops, defaultWorkgroupTileSize);
   llvm::DenseSet<unsigned> partitionedLoopsSet(partitionedLoops.begin(),
                                                partitionedLoops.end());
   for (auto dim : llvm::seq<int64_t>(0, workgroupTileSizes.size())) {
@@ -541,15 +536,12 @@
           std::max(workgroupTileSizes[rank - 1],
                    static_cast<int64_t>(defaultWorkgroupTileSize));
     } else {
-      fftOp.emitError("non-constant stage might not work for fft op");
-      return failure();
+      return fftOp.emitOpError("non-constant stage might not work for fft op");
     }
   }
   TileSizesListType tileSizes = {workgroupTileSizes};
-
   return setOpConfigAndEntryPointFnTranslation(
-      entryPointFn, fftOp, tileSizes,
-      /*nativeVectorSizes=*/ArrayRef<int64_t>{},
+      entryPointFn, fftOp, tileSizes, /*nativeVectorSize=*/ArrayRef<int64_t>{},
       DispatchLoweringPassPipeline::CPUDefault);
 }
 
@@ -561,7 +553,8 @@
   unsigned numLoops = genericOp.getNumLoops();
   if (numLoops == 0) return success();
 
-  SmallVector<int64_t> nativeVectorSize(numLoops, 1);
+  SmallVector<int64_t> minTileSizes(numLoops, 1),
+      maxTileSizes(numLoops, defaultWorkgroupTileSize);
   auto inputOutputOpOperands = genericOp.getInputAndOutputOperands();
   for (auto map : llvm::enumerate(genericOp.getIndexingMaps())) {
     // Check the fastest varying dimension of the operand. Set the vector size
@@ -575,56 +568,102 @@
     // If the indexing map has result it has to be a shaped type.
     auto operandType =
         inputOutputOpOperands[map.index()]->get().getType().cast<ShapedType>();
-    nativeVectorSize[fastestVaryingDim] =
-        std::max<int64_t>(nativeVectorSize[fastestVaryingDim],
+    minTileSizes[fastestVaryingDim] =
+        std::max<int64_t>(minTileSizes[fastestVaryingDim],
                           getVectorSize(entryPointFn, operandType));
   }
-  if (llvm::all_of(nativeVectorSize, [](int64_t vs) { return vs == 1; })) {
+  if (llvm::all_of(minTileSizes, [](int64_t vs) { return vs == 1; })) {
     // Nothing to vectorize just lower to loops.
     return success();
   }
 
   // Set the flow level tiling to the default.
-  auto interfaceOp =
+  OpBuilder builder(genericOp.getContext());
+  builder.setInsertionPoint(genericOp);
+  SmallVector<Range> iterationDomain =
+      cast<linalg::LinalgOp>(genericOp.getOperation())
+          .createLoopRanges(builder, genericOp.getLoc());
+  auto partitionableLoopsInterfaceOp =
       cast<IREE::Flow::PartitionableLoopsInterface>(genericOp.getOperation());
-  SmallVector<unsigned> partitionedLoops =
-      interfaceOp.getPartitionableLoops(kNumMaxParallelDims);
-  SmallVector<int64_t> workloadPerWorkgroup = getDefaultWorkloadPerWorkgroup(
-      tiledLoops, partitionedLoops, nativeVectorSize);
-  setTranslationInfo(entryPointFn,
-                     DispatchLoweringPassPipeline::CPUDoubleTilingExpert,
-                     workloadPerWorkgroup,
-                     /*workgroupSize=*/ArrayRef<int64_t>{});
+  SmallVector<int64_t> flowTileSizes = getDefaultDistributedLevelTileSizes(
+      iterationDomain, partitionableLoopsInterfaceOp, minTileSizes,
+      maxTileSizes);
 
-  llvm::SmallDenseSet<unsigned> pLoopsSet;
-  for (auto i : interfaceOp.getPartitionableLoops(
-           /*maxNumPartitionedLoops=*/std::numeric_limits<unsigned>::max())) {
-    pLoopsSet.insert(i);
+  // Set the next level of tile sizes.
+  SmallVector<int64_t> l1TileSizes(numLoops, 0);
+  Optional<SmallVector<int64_t, 4>> staticLoopRanges =
+      cast<linalg::LinalgOp>(genericOp.getOperation()).getStaticLoopRanges();
+  for (auto loopNum : llvm::seq<unsigned>(0, numLoops)) {
+    if (flowTileSizes[loopNum]) {
+      l1TileSizes[loopNum] =
+          getMaxTileSize(0, flowTileSizes[loopNum], minTileSizes[loopNum],
+                         minTileSizes[loopNum]);
+    } else {
+      // If the flow level tile size is zero and the static loop range is 1
+      // (i.e. a unit dimension), set the tile size here to zero as well;
+      // otherwise fall back to the minimum tile size.
+      l1TileSizes[loopNum] =
+          (staticLoopRanges && staticLoopRanges.getValue()[loopNum] == 1)
+              ? 0
+              : minTileSizes[loopNum];
+    }
   }
 
-  SmallVector<int64_t> l1TileSizes = nativeVectorSize;
-  SmallVector<int64_t> vectorTileSizes = nativeVectorSize;
-  for (auto i : llvm::seq<unsigned>(0, l1TileSizes.size())) {
-    // This excludes unit parallel dims.
-    if (!pLoopsSet.contains(i)) l1TileSizes[i] = 0;
-  }
-  {
-    SmallVector<unsigned> parallelDims;
-    genericOp.getParallelDims(parallelDims);
-    for (auto d : parallelDims) vectorTileSizes[d] = 0;
+  SmallVector<int64_t> vectorTileSizes = l1TileSizes;
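+  // Split the tile sizes by iterator type: parallel dimensions are kept in
+  // l1TileSizes and reduction dimensions in vectorTileSizes.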
+  for (auto iteratorType : llvm::enumerate(genericOp.iterator_types())) {
+    if (iteratorType.value().cast<StringAttr>().getValue() ==
+        getParallelIteratorTypeName()) {
+      vectorTileSizes[iteratorType.index()] = 0;
+    } else {
+      l1TileSizes[iteratorType.index()] = 0;
+    }
   }
 
   TileSizesListType tileSizes;
-  tileSizes.push_back({});  // Empty since nothing to do for first level tiling.
+  tileSizes.push_back(flowTileSizes);
   tileSizes.push_back(l1TileSizes);
   tileSizes.push_back(vectorTileSizes);
-  auto config = IREE::Codegen::LoweringConfigAttr::get(
-      entryPointFn.getContext(), tileSizes, {});
-  setLoweringConfig(genericOp, config);
-
-  return success();
+  return setOpConfigAndEntryPointFnTranslation(
+      entryPointFn, genericOp, tileSizes,
+      /*nativeVectorSize=*/ArrayRef<int64_t>{},
+      DispatchLoweringPassPipeline::CPUDoubleTilingExpert);
 }
 
+/// Set the default configuration for Linalg ops.
+static LogicalResult setRootConfig(
+    FuncOp entryPointFn, linalg::LinalgOp linalgOp,
+    ArrayRef<LoopTilingAndDistributionInfo> tiledLoops) {
+  if (getLoweringConfig(linalgOp)) return success();
+
+  OpBuilder builder(linalgOp.getContext());
+  builder.setInsertionPoint(linalgOp);
+  SmallVector<Range> iterationDomain =
+      linalgOp.createLoopRanges(builder, linalgOp.getLoc());
+
+  auto partitionableLoopOp =
+      cast<IREE::Flow::PartitionableLoopsInterface>(linalgOp.getOperation());
+  return setDefaultRootConfig(entryPointFn, partitionableLoopOp,
+                              iterationDomain);
+}
+
+/// Set the default configuration for operations that implement the
+/// `TiledOpInterface`.
+static LogicalResult setRootConfig(
+    FuncOp entryPointFn, IREE::LinalgExt::TiledOpInterface tiledOpInterfaceOp,
+    ArrayRef<LoopTilingAndDistributionInfo> tiledLoops) {
+  if (getLoweringConfig(tiledOpInterfaceOp)) return success();
+
+  OpBuilder builder(tiledOpInterfaceOp.getContext());
+  builder.setInsertionPoint(tiledOpInterfaceOp);
+  SmallVector<Range> iterationDomain =
+      tiledOpInterfaceOp.getIterationDomain(builder);
+  auto partitionableLoopInterfaceOp =
+      cast<IREE::Flow::PartitionableLoopsInterface>(
+          tiledOpInterfaceOp.getOperation());
+  return setDefaultRootConfig(entryPointFn, partitionableLoopInterfaceOp,
+                              iterationDomain);
+}
+
+/// Redirects to methods that set the configuration based on operation type.
 static LogicalResult setRootConfigImpl(
     FuncOp entryPointFn, Operation *op,
     ArrayRef<LoopTilingAndDistributionInfo> tiledLoops) {
@@ -634,68 +673,135 @@
   // Redirect to individual operations.
   auto setRootConfigFn = [&](Operation *op) -> LogicalResult {
     return TypeSwitch<Operation *, LogicalResult>(op)
-        .Case<linalg::Mmt4DOp, linalg::ContractionOpInterface,
-              IREE::LinalgExt::FftOp>([&](auto op) {
+        .Case<IREE::LinalgExt::FftOp, linalg::GenericOp, linalg::Mmt4DOp>(
+            [&](auto op) {
+              return setRootConfig(entryPointFn, op, tiledLoops);
+            })
+        .Case<linalg::ContractionOpInterface>([&](auto op) {
           return setRootConfig(entryPointFn, op, tiledLoops);
         })
-        .Case<linalg::GenericOp>([&](auto genericOp) {
-          if (genericOp.getNumLoops() == genericOp.getNumParallelLoops()) {
-            // Ignore parallel elementwise operations now. They will be set as
-            // roots ops if there are no other ops that can be treated as a
-            // root op.
-            return success();
-          }
-          return setRootConfig(entryPointFn, genericOp, tiledLoops);
-        })
+        .Case<linalg::LinalgOp, IREE::LinalgExt::TiledOpInterface>(
+            [&](auto op) {
+              return setRootConfig(entryPointFn, op, tiledLoops);
+            })
         .Default([&](Operation *op) { return success(); });
   };
   return setRootConfigFn(op);
 }
 
+/// Redirects to methods that set the configuration based on operation type for
+/// the VMVX backend.
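+/// VMVX does not use vectorization, so only the default lowering
+/// configurations are set.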
+static LogicalResult setVMVXRootConfigImpl(
+    FuncOp entryPointFn, Operation *op,
+    ArrayRef<LoopTilingAndDistributionInfo> tiledLoops) {
+  if (getLoweringConfig(op)) return success();
+
+  // Redirect to individual operations.
+  auto setRootConfigFn = [&](Operation *op) -> LogicalResult {
+    return TypeSwitch<Operation *, LogicalResult>(op)
+        .Case<linalg::LinalgOp, IREE::LinalgExt::TiledOpInterface>(
+            [&](auto op) {
+              return setRootConfig(entryPointFn, op, tiledLoops);
+            })
+        .Default([&](Operation *op) { return success(); });
+  };
+  return setRootConfigFn(op);
+}
+
+/// Find the root operation for the dispatch region.
+static FailureOr<Operation *> getRootOperation(
+    ArrayRef<Operation *> computeOps) {
+  Operation *rootOperation = nullptr;
+  auto updateRootOperation = [&](Operation *op) -> LogicalResult {
+    if (rootOperation) {
+      return op->emitOpError(
+          "unhandled multiple root operations in dispatch region");
+    }
+    rootOperation = op;
+    return success();
+  };
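+  // First sweep: prefer ops with reduction loops or other tiled-interface ops
+  // as the root; linalg.generic ops that are purely parallel are only
+  // considered in the second sweep below.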
+  for (auto op : computeOps) {
+    if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
+      // Do not treat linalg ops that are all parallel as root operations in
+      // this sweep.
+      if (linalgOp.getNumLoops() == linalgOp.getNumParallelLoops()) continue;
+
+      // All other linalg ops are root ops.
+      if (failed(updateRootOperation(op))) return failure();
+      continue;
+    }
+
+    if (auto tiledOpInterfaceOp =
+            dyn_cast<IREE::LinalgExt::TiledOpInterface>(op)) {
+      // TODO(ravishankarm): For now `tensor.extract_slice` and
+      // `tensor.insert_slice` implement the `TiledOpInterface`. With tile +
+      // distribute moved out of the Flow dialect, this doesn't work anymore.
+      // Remove this when the external model implementation of
+      // `tensor.extract_slice`/`tensor.insert_slice` is dropped.
+      if (isa<tensor::ExtractSliceOp, tensor::InsertSliceOp>(op)) continue;
+
+      // All other operations that implement this interface are root ops.
+      if (failed(updateRootOperation(op))) return failure();
+      continue;
+    }
+  }
+  if (rootOperation) return rootOperation;
+
+  // If no root operation is found yet, look for linalg.generic ops.
+  for (auto op : computeOps) {
+    if (isa<linalg::GenericOp>(op)) {
+      if (failed(updateRootOperation(op))) return failure();
+    }
+  }
+  if (rootOperation) return rootOperation;
+
+  // TODO(ravishankarm): Currently there is a corner case of a dispatch region
+  // with just a `tensor.extract_slice`/`tensor.insert_slice`. Those need to be
+  // folded with `flow.dispatch.tensor.load`/`flow.dispatch.tensor.store` ops
+  // respectively. This should go hand-in-hand with dropping the external model
+  // implementation of the `TiledOpInterface` for these ops. Until we cross
+  // that bridge, handle that case here. linalg.fill is handled here as well,
+  // though a dispatch with only a fill should never happen either.
+  if (computeOps.size() == 1 &&
+      isa<linalg::FillOp, tensor::ExtractSliceOp, tensor::InsertSliceOp>(
+          computeOps[0])) {
+    rootOperation = computeOps[0];
+  }
+  return rootOperation;
+}
+
 /// Finds the root operation in the given list of Linalg operations and sets
 /// its configuration. Returns error for multiple root operations.
 static LogicalResult setRootConfig(
     FuncOp entryPointFn, ArrayRef<Operation *> computeOps,
     ArrayRef<LoopTilingAndDistributionInfo> tiledLoops) {
-  Operation *rootOp = nullptr;
-  for (auto computeOp : computeOps) {
-    if (failed(setRootConfigImpl(entryPointFn, computeOp, tiledLoops))) {
-      return failure();
-    }
-    if (getLoweringConfig(computeOp)) {
-      if (rootOp) {
-        return computeOp->emitOpError(
-            "unhandled multiple roots in dispatch region");
-      }
-      rootOp = computeOp;
-    }
+  FailureOr<Operation *> rootOp = getRootOperation(computeOps);
+  if (failed(rootOp)) {
+    return failure();
   }
-  if (rootOp) return success();
+  Operation *rootOperation = rootOp.getValue();
 
-  // If there are any other ops other than `linalg.generic`, `linalg.generic` or
-  // `linalg.fill` then just use the default.
-  for (auto computeOp : computeOps) {
-    if (!isa<linalg::GenericOp, linalg::FillOp>(computeOp)) {
-      return success();
-    }
-  }
-
-  // If there are no root ops, then check for a single `linalg.generic` op. Make
-  // this the root, and vectorize the operation.
-  for (auto computeOp : computeOps) {
-    if (auto genericOp = dyn_cast<linalg::GenericOp>(computeOp)) {
-      if (failed(setRootConfig(entryPointFn, genericOp, tiledLoops))) {
+  if (rootOperation) {
+    if (isVMVXBackend(entryPointFn)) {
+      if (failed(
+              setVMVXRootConfigImpl(entryPointFn, rootOperation, tiledLoops))) {
         return failure();
       }
-      if (getLoweringConfig(computeOp)) {
-        if (rootOp) {
-          return computeOp->emitOpError(
-              "unhanlded multiple parallel generic ops within a dispatch");
-        }
-        rootOp = computeOp;
+    } else {
+      if (failed(setRootConfigImpl(entryPointFn, rootOperation, tiledLoops))) {
+        return failure();
       }
     }
   }
+
+  if (!getTranslationInfo(entryPointFn)) {
+    // Fall back, just set the translation to CPUDefault.
+    setTranslationInfo(entryPointFn, DispatchLoweringPassPipeline::CPUDefault,
+                       /*workloadPerWorkgroup=*/ArrayRef<int64_t>{},
+                       /*workgroupSize=*/ArrayRef<int64_t>{});
+  }
+
   return success();
 }
 
@@ -724,18 +830,7 @@
   }
 
   // Next set the configuration of the operations.
-  // For VMVX, do not use vectorization. Just lower as default.
-  if (!isVMVXBackend(entryPointFn)) {
-    if (failed(setRootConfig(entryPointFn, computeOps, tiledLoops))) {
-      return failure();
-    }
-  }
-
-  // Check if the translation info for the entry point is already set.
-  if (!getTranslationInfo(entryPointFn)) {
-    return setDefaultLaunchConfig(entryPointFn, tiledLoops);
-  }
-  return success();
+  return setRootConfig(entryPointFn, computeOps, tiledLoops);
 }
 
 LogicalResult initCPULaunchConfig(ModuleOp moduleOp) {
@@ -758,7 +853,12 @@
       return failure();
     }
   }
-  return success();
+
+  // The root configuration setting introduces `tensor.dim` operations. Resolve
+  // those away.
+  RewritePatternSet patterns(moduleOp.getContext());
+  memref::populateResolveRankedShapeTypeResultDimsPatterns(patterns);
+  return applyPatternsAndFoldGreedily(moduleOp, std::move(patterns));
 }
 
 }  // namespace iree_compiler
diff --git a/iree/compiler/Codegen/LLVMCPU/LLVMCPULowerExecutableTarget.cpp b/iree/compiler/Codegen/LLVMCPU/LLVMCPULowerExecutableTarget.cpp
index 7cb17f8..3644e91 100644
--- a/iree/compiler/Codegen/LLVMCPU/LLVMCPULowerExecutableTarget.cpp
+++ b/iree/compiler/Codegen/LLVMCPU/LLVMCPULowerExecutableTarget.cpp
@@ -164,8 +164,6 @@
         return signalPassFailure();
       }
 
-      executableLoweringPipeline.addPass(createSetNumWorkgroupsPass());
-      executableLoweringPipeline.addPass(createCanonicalizerPass());
       bool lowerToVectors = !isVMVXBackend(variantOp);
       if (!testLoweringConfiguration) {
         OpPassManager &nestedModulePM =
diff --git a/iree/compiler/Codegen/LLVMCPU/Passes.cpp b/iree/compiler/Codegen/LLVMCPU/Passes.cpp
index 9a2f78b..4f0cf24 100644
--- a/iree/compiler/Codegen/LLVMCPU/Passes.cpp
+++ b/iree/compiler/Codegen/LLVMCPU/Passes.cpp
@@ -92,77 +92,30 @@
            << pipelineName;
   }
 
+  // Verify that the workload per workgroup is not set.
+  // TODO(ravishankarm): Remove workload_per_wg eventually.
+  SmallVector<int64_t> workloadPerWorkgroup =
+      translationInfo.getWorkloadPerWorkgroupVals();
+  if (!workloadPerWorkgroup.empty()) {
+    return op->emitOpError(
+        "workload_per_wg expected to be empty since it is an internal "
+        "compiler implementation detail");
+  }
+
   if (loweringConfig.getTileSizes().size() != 3) {
     return op->emitOpError("expected three tiling sizes for ")
            << pipelineName << ", got " << loweringConfig.getTileSizes().size();
   }
 
-  // Verify that the workload per workgroup is set and is non-zero.
-  SmallVector<int64_t> workloadPerWorkgroup =
-      translationInfo.getWorkloadPerWorkgroupVals();
-  if (workloadPerWorkgroup.size() > kNumMaxParallelDims) {
-    return op->emitOpError(
-               "workload_per_wg size should be less than or equal to ")
-           << kNumMaxParallelDims;
-  }
-  if (llvm::any_of(workloadPerWorkgroup,
-                   [](int64_t val) { return val == 0; })) {
-    return op->emitOpError("invalid to use 0 in workload_per_wg");
-  }
-
   IREE::Flow::PartitionableLoopsInterface interfaceOp =
       dyn_cast_or_null<IREE::Flow::PartitionableLoopsInterface>(op);
   if (interfaceOp) {
-    SmallVector<unsigned> partitionedLoops =
-        interfaceOp.getPartitionableLoops(kNumMaxParallelDims);
-    // TODO(hanchung): Allow empty workloadPerWorkgroup for now. We're going to
-    // defer tile and distribution which may affect this.
-    if (!workloadPerWorkgroup.empty() &&
-        workloadPerWorkgroup.size() != partitionedLoops.size()) {
-      return op->emitOpError("expected ")
-             << partitionedLoops.size()
-             << " entries for workload_per_wg, but got "
-             << workloadPerWorkgroup.size();
-    }
-    SmallVector<int64_t> firstLevelTileSizes = loweringConfig.getTileSizeVals(
-        static_cast<unsigned>(TilingLevel::WorkGroupTiles));
-
-    if (!firstLevelTileSizes.empty()) {
-      // Verify that if the first-level tile sizes are set, they are the same as
-      // workload_per_wg for the partitioned loops.
-      SmallVector<unsigned> partitionedLoops =
-          interfaceOp.getPartitionableLoops(kNumMaxParallelDims);
-      size_t minElements =
-          (partitionedLoops.empty() ? 0 : partitionedLoops.back() + 1);
-      if (firstLevelTileSizes.size() < minElements) {
-        return op->emitOpError("expected at least ")
-               << minElements
-               << " size for first level tiling to get the distribution fully "
-                  "specified.";
-      }
-      llvm::SmallDenseSet<unsigned> partitionedLoopsSet;
-      partitionedLoopsSet.insert(partitionedLoops.begin(),
-                                 partitionedLoops.end());
-      SmallVector<int64_t> partitionedTileSizes;
-      for (auto tileSize : llvm::enumerate(firstLevelTileSizes)) {
-        if (!partitionedLoopsSet.count(tileSize.index())) {
-          continue;
-        }
-        partitionedTileSizes.push_back(tileSize.value());
-      }
-      for (auto val : llvm::enumerate(llvm::reverse(workloadPerWorkgroup))) {
-        if (val.value() != partitionedTileSizes[val.index()]) {
-          return op->emitOpError("mismatch in distributed tile size value ")
-                 << partitionedTileSizes[val.index()] << " at position "
-                 << val.index() << " and workload_per_wg value " << val.value();
-        }
-      }
-    }
-
     llvm::SmallDenseSet<unsigned> pLoopsSet;
-    for (auto i : interfaceOp.getPartitionableLoops(
-             /*maxNumPartitionedLoops=*/std::numeric_limits<unsigned>::max())) {
-      pLoopsSet.insert(i);
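+    // Collect the parallel loop dimensions from the iterator types.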
+    for (auto iteratorType : llvm::enumerate(interfaceOp.getIteratorTypes())) {
+      if (iteratorType.value() == getParallelIteratorTypeName()) {
+        pLoopsSet.insert(iteratorType.index());
+      }
     }
 
     SmallVector<int64_t> secondLevelTileSizes = loweringConfig.getTileSizeVals(
@@ -170,7 +123,7 @@
     for (auto en : llvm::enumerate(secondLevelTileSizes)) {
       if (en.value() != 0 && !pLoopsSet.contains(en.index())) {
         return op->emitOpError(
-                   "expected only non-unit parallel dims can be set in the "
+                   "expected only parallel dims to be set in the "
                    "second tiling sizes, got ")
                << en.index() << "-th tile size set";
       }
@@ -181,7 +134,7 @@
     for (auto en : llvm::enumerate(thirdLevelTileSizes)) {
       if (en.value() != 0 && pLoopsSet.contains(en.index())) {
         return op->emitOpError(
-                   "expected only reduction dims can be set in the third "
+                   "expected only reduction dims to be set in the third "
                    "tiling sizes, got ")
                << en.index() << "-th tile size set";
       }
@@ -208,6 +161,11 @@
 //===---------------------------------------------------------------------===//
 
 void addSingleTilingExpertPassPipeline(OpPassManager &passManager) {
+  // Do first level of tiling and distribution.
+  passManager.addNestedPass<FuncOp>(createTileAndDistributeToWorkgroupsPass());
+  passManager.addPass(createCanonicalizerPass());
+  passManager.addPass(createCSEPass());
+
   passManager.addNestedPass<FuncOp>(
       createConvertToDestinationPassingStylePass());
   passManager.addPass(createCanonicalizerPass());
@@ -239,6 +197,11 @@
 }
 
 void addDoubleTilingExpertPassPipeline(OpPassManager &passManager) {
+  // Do first level of tiling and distribution.
+  passManager.addNestedPass<FuncOp>(createTileAndDistributeToWorkgroupsPass());
+  passManager.addPass(createCanonicalizerPass());
+  passManager.addPass(createCSEPass());
+
   passManager.addNestedPass<FuncOp>(
       createConvertToDestinationPassingStylePass());
 
@@ -298,7 +261,10 @@
 
 void addTileFuseAndVectorizePassPipeline(OpPassManager &passManager,
                                          bool lowerToVectors) {
+  // Do first level of tile and distribute to workgroups.
+  passManager.addNestedPass<FuncOp>(createTileAndDistributeToWorkgroupsPass());
   passManager.addPass(createCanonicalizerPass());
+  passManager.addPass(createCSEPass());
 
   // Tile and vectorize linalg ops on tensors.
   passManager.addNestedPass<FuncOp>(
@@ -327,7 +293,10 @@
 }
 
 void addCPUDefaultPassPipeline(OpPassManager &passManager) {
+  // Do first level of tile and distribute to workgroups.
+  passManager.addNestedPass<FuncOp>(createTileAndDistributeToWorkgroupsPass());
   passManager.addPass(createCanonicalizerPass());
+  passManager.addPass(createCSEPass());
   // Use stack allocation on CPU side.
   addLinalgBufferizePasses(passManager, cpuAllocationFunction);
 }
@@ -375,10 +344,6 @@
 void buildLLVMCPUCodegenPassPipeline(OpPassManager &passManager) {
   passManager.nest<ModuleOp>().nest<FuncOp>().addPass(
       createTypePropagationPass());
-  passManager.nest<ModuleOp>().nest<FuncOp>().addPass(
-      createTileAndDistributeToWorkgroupsPass());
-  passManager.addPass(createCanonicalizerPass());
-  passManager.addPass(createCSEPass());
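+  // Tile and distribute to workgroups now runs as the first step of the
+  // individual lowering pipelines (e.g. addDoubleTilingExpertPassPipeline).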
   passManager.addPass(createLLVMCPULowerExecutableTargetPass());
   OpPassManager &nestedModulePM = passManager.nest<ModuleOp>();
   addLowerToLLVMPasses(nestedModulePM);
diff --git a/iree/compiler/Codegen/LLVMCPU/test/illegal_configuration.mlir b/iree/compiler/Codegen/LLVMCPU/test/illegal_configuration.mlir
index bcc4d4b..47e47e0 100644
--- a/iree/compiler/Codegen/LLVMCPU/test/illegal_configuration.mlir
+++ b/iree/compiler/Codegen/LLVMCPU/test/illegal_configuration.mlir
@@ -11,7 +11,7 @@
 ]>
 hal.executable private @matmul_tensors {
   hal.executable.variant @llvm, target = #hal.executable.target<"llvm", "embedded-elf-x86_64", {}> {
-    hal.executable.entry_point @illegal layout(#executable_layout) attributes {
+    hal.executable.entry_point @illegal layout(#executable_layout) {
       translation.info = #translation
     }
     builtin.module {
@@ -31,8 +31,8 @@
 
 // -----
 
-#config = #iree_codegen.lowering.config<tile_sizes = [[], [], []], native_vector_size = []>
-#translation = #iree_codegen.translation.info<"CPUDoubleTilingExpert", workload_per_wg = [1, 0]>
+#config = #iree_codegen.lowering.config<tile_sizes = [[4, 8], [8, 8, 0], [0, 0, 8]], native_vector_size = [0, 0, 4]>
+#translation = #iree_codegen.translation.info<"CPUDoubleTilingExpert", workload_per_wg = []>
 #executable_layout = #hal.executable.layout<push_constants = 0, sets = [
   #hal.descriptor_set.layout<0, bindings = [
     #hal.descriptor_set.binding<0, storage_buffer>,
@@ -42,100 +42,7 @@
 ]>
 hal.executable private @matmul_tensors {
   hal.executable.variant @llvm, target = #hal.executable.target<"llvm", "embedded-elf-x86_64", {}> {
-    hal.executable.entry_point @illegal layout(#executable_layout) attributes {
-      translation.info = #translation
-    }
-    builtin.module {
-      func @illegal() {
-        %c0 = arith.constant 0 : index
-        %lhs = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : memref<4x8xf32>
-        %rhs = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) : memref<8x16xf32>
-        %result = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) : memref<4x16xf32>
-        // expected-error @+1 {{invalid to use 0 in workload_per_wg}}
-        linalg.matmul {lowering.config = #config} ins(%lhs, %rhs : memref<4x8xf32>, memref<8x16xf32>)
-          outs(%result: memref<4x16xf32>)
-        return
-      }
-    }
-  }
-}
-
-// -----
-
-#config = #iree_codegen.lowering.config<tile_sizes = [[], [], []], native_vector_size = []>
-#translation = #iree_codegen.translation.info<"CPUDoubleTilingExpert", workload_per_wg = [1, 1, 1, 1]>
-#executable_layout = #hal.executable.layout<push_constants = 0, sets = [
-  #hal.descriptor_set.layout<0, bindings = [
-    #hal.descriptor_set.binding<0, storage_buffer>,
-    #hal.descriptor_set.binding<1, storage_buffer>,
-    #hal.descriptor_set.binding<2, storage_buffer>
-  ]>
-]>
-hal.executable private @matmul_tensors {
-  hal.executable.variant @llvm, target = #hal.executable.target<"llvm", "embedded-elf-x86_64", {}> {
-    hal.executable.entry_point @illegal layout(#executable_layout) attributes {
-      translation.info = #translation
-    }
-    builtin.module {
-      func @illegal() {
-        %c0 = arith.constant 0 : index
-        %lhs = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : memref<4x8xf32>
-        %rhs = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) : memref<8x16xf32>
-        %result = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) : memref<4x16xf32>
-        // expected-error @+1 {{workload_per_wg size should be less than or equal to 3}}
-        linalg.matmul {lowering.config = #config} ins(%lhs, %rhs : memref<4x8xf32>, memref<8x16xf32>)
-          outs(%result: memref<4x16xf32>)
-        return
-      }
-    }
-  }
-}
-
-// -----
-
-#config = #iree_codegen.lowering.config<tile_sizes = [[4, 8], [], []], native_vector_size = []>
-#translation = #iree_codegen.translation.info<"CPUDoubleTilingExpert", workload_per_wg = [8, 6]>
-#executable_layout = #hal.executable.layout<push_constants = 0, sets = [
-  #hal.descriptor_set.layout<0, bindings = [
-    #hal.descriptor_set.binding<0, storage_buffer>,
-    #hal.descriptor_set.binding<1, storage_buffer>,
-    #hal.descriptor_set.binding<2, storage_buffer>
-  ]>
-]>
-hal.executable private @matmul_tensors {
-  hal.executable.variant @llvm, target = #hal.executable.target<"llvm", "embedded-elf-x86_64", {}> {
-    hal.executable.entry_point @illegal layout(#executable_layout) attributes {
-      translation.info = #translation
-    }
-    builtin.module {
-      func @illegal() {
-        %c0 = arith.constant 0 : index
-        %lhs = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : memref<4x8xf32>
-        %rhs = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) : memref<8x16xf32>
-        %result = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) : memref<4x16xf32>
-        // expected-error @+1 {{mismatch in distributed tile size value 4 at position 0 and workload_per_wg value 6}}
-        linalg.matmul {lowering.config = #config} ins(%lhs, %rhs : memref<4x8xf32>, memref<8x16xf32>)
-          outs(%result: memref<4x16xf32>)
-        return
-      }
-    }
-  }
-}
-
-// -----
-
-#config = #iree_codegen.lowering.config<tile_sizes = [[], [8, 8, 0], [0, 0, 8]], native_vector_size = [0, 0, 4]>
-#translation = #iree_codegen.translation.info<"CPUDoubleTilingExpert", workload_per_wg = [8, 4]>
-#executable_layout = #hal.executable.layout<push_constants = 0, sets = [
-  #hal.descriptor_set.layout<0, bindings = [
-    #hal.descriptor_set.binding<0, storage_buffer>,
-    #hal.descriptor_set.binding<1, storage_buffer>,
-    #hal.descriptor_set.binding<2, storage_buffer>
-  ]>
-]>
-hal.executable private @matmul_tensors {
-  hal.executable.variant @llvm, target = #hal.executable.target<"llvm", "embedded-elf-x86_64", {}> {
-    hal.executable.entry_point @illegal layout(#executable_layout) attributes {
+    hal.executable.entry_point @illegal layout(#executable_layout) {
       translation.info = #translation
     }
     builtin.module {
@@ -155,8 +62,8 @@
 
 // -----
 
-#config = #iree_codegen.lowering.config<tile_sizes = [[], [8, 32, 0], [0, 0, 16]], native_vector_size = []>
-#translation = #iree_codegen.translation.info<"CPUDoubleTilingExpert", workload_per_wg = [32]>
+#config = #iree_codegen.lowering.config<tile_sizes = [[64, 64, 0], [8, 32, 16], [0, 0, 16]], native_vector_size = []>
+#translation = #iree_codegen.translation.info<"CPUDoubleTilingExpert", workload_per_wg = []>
 #executable_layout = #hal.executable.layout<push_constants = 0, sets = [
   #hal.descriptor_set.layout<0, bindings = [
     #hal.descriptor_set.binding<0, storage_buffer>,
@@ -166,7 +73,7 @@
 ]>
 hal.executable private @matmul_tensors {
   hal.executable.variant @llvm, target = #hal.executable.target<"llvm", "embedded-elf-x86_64", {}> {
-    hal.executable.entry_point @illegal layout(#executable_layout) attributes {
+    hal.executable.entry_point @illegal layout(#executable_layout) {
       translation.info = #translation
     }
     builtin.module {
@@ -175,7 +82,7 @@
         %lhs = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : memref<4x8xf32>
         %rhs = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) : memref<8x16xf32>
         %result = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) : memref<4x16xf32>
-        // expected-error @+1 {{expected 2 entries for workload_per_wg, but got 1}}
+        // expected-error @+1 {{expected only parallel dims to be set in the second tiling sizes, got 2-th tile size set}}
         linalg.matmul {lowering.config = #config} ins(%lhs, %rhs : memref<4x8xf32>, memref<8x16xf32>)
           outs(%result: memref<4x16xf32>)
         return
@@ -186,8 +93,8 @@
 
 // -----
 
-#config = #iree_codegen.lowering.config<tile_sizes = [[64, 32], [8, 32, 0], [0, 0, 16]], native_vector_size = []>
-#translation = #iree_codegen.translation.info<"CPUDoubleTilingExpert", workload_per_wg = [64, 64]>
+#config = #iree_codegen.lowering.config<tile_sizes = [[64, 64], [8, 0, 0], [0, 16, 16]], native_vector_size = []>
+#translation = #iree_codegen.translation.info<"CPUDoubleTilingExpert", workload_per_wg = []>
 #executable_layout = #hal.executable.layout<push_constants = 0, sets = [
   #hal.descriptor_set.layout<0, bindings = [
     #hal.descriptor_set.binding<0, storage_buffer>,
@@ -197,7 +104,7 @@
 ]>
 hal.executable private @matmul_tensors {
   hal.executable.variant @llvm, target = #hal.executable.target<"llvm", "embedded-elf-x86_64", {}> {
-    hal.executable.entry_point @illegal layout(#executable_layout) attributes {
+    hal.executable.entry_point @illegal layout(#executable_layout) {
       translation.info = #translation
     }
     builtin.module {
@@ -206,69 +113,7 @@
         %lhs = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : memref<4x8xf32>
         %rhs = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) : memref<8x16xf32>
         %result = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) : memref<4x16xf32>
-        // expected-error @+1 {{mismatch in distributed tile size value 32 at position 1 and workload_per_wg value 64}}
-        linalg.matmul {lowering.config = #config} ins(%lhs, %rhs : memref<4x8xf32>, memref<8x16xf32>)
-          outs(%result: memref<4x16xf32>)
-        return
-      }
-    }
-  }
-}
-
-// -----
-
-#config = #iree_codegen.lowering.config<tile_sizes = [[], [8, 32, 16], [0, 0, 16]], native_vector_size = []>
-#translation = #iree_codegen.translation.info<"CPUDoubleTilingExpert", workload_per_wg = [64, 64]>
-#executable_layout = #hal.executable.layout<push_constants = 0, sets = [
-  #hal.descriptor_set.layout<0, bindings = [
-    #hal.descriptor_set.binding<0, storage_buffer>,
-    #hal.descriptor_set.binding<1, storage_buffer>,
-    #hal.descriptor_set.binding<2, storage_buffer>
-  ]>
-]>
-hal.executable private @matmul_tensors {
-  hal.executable.variant @llvm, target = #hal.executable.target<"llvm", "embedded-elf-x86_64", {}> {
-    hal.executable.entry_point @illegal layout(#executable_layout) attributes {
-      translation.info = #translation
-    }
-    builtin.module {
-      func @illegal() {
-        %c0 = arith.constant 0 : index
-        %lhs = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : memref<4x8xf32>
-        %rhs = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) : memref<8x16xf32>
-        %result = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) : memref<4x16xf32>
-        // expected-error @+1 {{expected only non-unit parallel dims can be set in the second tiling sizes, got 2-th tile size set}}
-        linalg.matmul {lowering.config = #config} ins(%lhs, %rhs : memref<4x8xf32>, memref<8x16xf32>)
-          outs(%result: memref<4x16xf32>)
-        return
-      }
-    }
-  }
-}
-
-// -----
-
-#config = #iree_codegen.lowering.config<tile_sizes = [[], [8, 0, 0], [0, 16, 16]], native_vector_size = []>
-#translation = #iree_codegen.translation.info<"CPUDoubleTilingExpert", workload_per_wg = [64, 64]>
-#executable_layout = #hal.executable.layout<push_constants = 0, sets = [
-  #hal.descriptor_set.layout<0, bindings = [
-    #hal.descriptor_set.binding<0, storage_buffer>,
-    #hal.descriptor_set.binding<1, storage_buffer>,
-    #hal.descriptor_set.binding<2, storage_buffer>
-  ]>
-]>
-hal.executable private @matmul_tensors {
-  hal.executable.variant @llvm, target = #hal.executable.target<"llvm", "embedded-elf-x86_64", {}> {
-    hal.executable.entry_point @illegal layout(#executable_layout) attributes {
-      translation.info = #translation
-    }
-    builtin.module {
-      func @illegal() {
-        %c0 = arith.constant 0 : index
-        %lhs = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : memref<4x8xf32>
-        %rhs = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) : memref<8x16xf32>
-        %result = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) : memref<4x16xf32>
-        // expected-error @+1 {{only reduction dims can be set in the third tiling sizes, got 1-th tile size set}}
+        // expected-error @+1 {{only reduction dims to be set in the third tiling sizes, got 1-th tile size set}}
         linalg.matmul {lowering.config = #config} ins(%lhs, %rhs : memref<4x8xf32>, memref<8x16xf32>)
           outs(%result: memref<4x16xf32>)
         return
diff --git a/iree/compiler/Codegen/LLVMCPU/test/materialize_launch_configuration.mlir b/iree/compiler/Codegen/LLVMCPU/test/materialize_launch_configuration.mlir
index e00c8eb..09d6dee 100644
--- a/iree/compiler/Codegen/LLVMCPU/test/materialize_launch_configuration.mlir
+++ b/iree/compiler/Codegen/LLVMCPU/test/materialize_launch_configuration.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt -pass-pipeline='hal.executable(hal.executable.variant(iree-llvmcpu-lower-executable-target{test-lowering-configuration=true}))' -cse -canonicalize -split-input-file %s | FileCheck %s
+// RUN: iree-opt -pass-pipeline='hal.executable(hal.executable.variant(iree-llvmcpu-lower-executable-target{test-lowering-configuration=true}))' -split-input-file %s | FileCheck %s
 
 #executable_layout = #hal.executable.layout<push_constants = 0, sets = [
   #hal.descriptor_set.layout<0, bindings = [
@@ -22,51 +22,34 @@
         %M = hal.interface.constant.load[0] : index
         %N = hal.interface.constant.load[1] : index
         %K = hal.interface.constant.load[2] : index
-        %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : !flow.dispatch.tensor<readonly:?x?xf32>{%M, %K}
-        %2 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) : !flow.dispatch.tensor<readonly:?x?xf32>{%K, %N}
-        %4 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) : !flow.dispatch.tensor<readonly:?x?xf32>{%M, %N}
-        %6 = hal.interface.binding.subspan set(0) binding(3) type(storage_buffer) : !flow.dispatch.tensor<writeonly:?x?xf32>{%M, %N}
-        %workgroup_size_x = hal.interface.workgroup.size[0] : index
-        %workgroup_size_y = hal.interface.workgroup.size[1] : index
-        %workgroup_id_x = hal.interface.workgroup.id[0] : index
-        %workgroup_count_x = hal.interface.workgroup.count[0] : index
-        %workgroup_id_y = hal.interface.workgroup.id[1] : index
-        %workgroup_count_y = hal.interface.workgroup.count[1] : index
-        %8 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_size_y, %workgroup_id_y]
-        %9 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_size_y, %workgroup_count_y]
-        scf.for %arg0 = %8 to %M step %9 {
-          %10 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_size_x, %workgroup_id_x]
-          %11 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_size_x, %workgroup_count_x]
-          scf.for %arg1 = %10 to %N step %11 {
-            %12 = affine.min affine_map<(d0)[s0, s1] -> (s0, -d0 + s1)>(%arg0)[%workgroup_size_y, %N]
-            %13 = flow.dispatch.tensor.load %0, offsets=[%arg0, 0], sizes=[%12, %K], strides=[1, 1] : !flow.dispatch.tensor<readonly:?x?xf32>{%M, %K} -> tensor<?x?xf32>
-            %14 = affine.min affine_map<(d0)[s0, s1] -> (s0, -d0 + s1)>(%arg1)[%workgroup_size_x, %M]
-            %15 = flow.dispatch.tensor.load %2, offsets=[0, %arg1], sizes=[%K, %14], strides=[1, 1] : !flow.dispatch.tensor<readonly:?x?xf32>{%K, %N} -> tensor<?x?xf32>
-            %16 = flow.dispatch.tensor.load %4, offsets=[%arg0, %arg1], sizes=[%12, %14], strides=[1, 1] : !flow.dispatch.tensor<readonly:?x?xf32>{%M, %N} -> tensor<?x?xf32>
-            %17 = linalg.matmul ins(%13, %15 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%16 : tensor<?x?xf32>) -> tensor<?x?xf32>
-            flow.dispatch.tensor.store %17, %6, offsets=[%arg0, %arg1], sizes=[%12, %14], strides=[1, 1] : tensor<?x?xf32> -> !flow.dispatch.tensor<writeonly:?x?xf32>{%M, %N}
-          }
-        }
+        %lhs_binding = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer)
+            : !flow.dispatch.tensor<readonly:?x?xf32>{%M, %K}
+        %rhs_binding = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer)
+            : !flow.dispatch.tensor<readonly:?x?xf32>{%K, %N}
+        %init_binding = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer)
+            : !flow.dispatch.tensor<readonly:?x?xf32>{%M, %N}
+        %result_binding = hal.interface.binding.subspan set(0) binding(3) type(storage_buffer)
+            : !flow.dispatch.tensor<writeonly:?x?xf32>{%M, %N}
+        %lhs = flow.dispatch.tensor.load %lhs_binding, offsets = [0, 0], sizes = [%M, %K], strides = [1, 1]
+            : !flow.dispatch.tensor<readonly:?x?xf32>{%M, %K} -> tensor<?x?xf32>
+        %rhs = flow.dispatch.tensor.load %rhs_binding, offsets = [0, 0], sizes = [%K, %N], strides = [1, 1]
+            : !flow.dispatch.tensor<readonly:?x?xf32>{%K, %N} -> tensor<?x?xf32>
+        %init = flow.dispatch.tensor.load %init_binding, offsets = [0, 0], sizes = [%M, %N], strides = [1, 1]
+            : !flow.dispatch.tensor<readonly:?x?xf32>{%M, %N} -> tensor<?x?xf32>
+        %gemm = linalg.matmul ins(%lhs, %rhs : tensor<?x?xf32>, tensor<?x?xf32>) outs(%init : tensor<?x?xf32>) -> tensor<?x?xf32>
+        flow.dispatch.tensor.store %gemm, %result_binding, offsets = [0, 0], sizes = [%M, %N], strides = [1, 1]
+            : tensor<?x?xf32> -> !flow.dispatch.tensor<writeonly:?x?xf32>{%M, %N}
         return
       }
     }
   }
 }
-
-//  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering.config<tile_sizes = {{\[}}[], [16, 4, 64], [4, 4, 4]{{\]}}, native_vector_size = [4, 4, 4]>
-//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"CPUTileFuseAndVectorize", workload_per_wg = [64, 64]>
-//  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 64)>
+//  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering.config<tile_sizes = {{\[}}[64, 64, 0], [16, 4, 64], [4, 4, 4]{{\]}}, native_vector_size = [4, 4, 4]>
+//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"CPUTileFuseAndVectorize", workload_per_wg = []>
 //      CHECK: hal.executable.entry_point public @matmul_tensors
-// CHECK-SAME:   translation.info = #[[TRANSLATION]]
-// CHECK-NEXT:   (%[[ARG0:[a-zA-Z0-9_]+]]: index
-// CHECK-SAME:    %[[ARG1:[a-zA-Z0-9_]+]]: index
-// CHECK-SAME:    %[[ARG2:[a-zA-Z0-9_]+]]: index)
-//  CHECK-DAG:    %[[C1:.+]] = arith.constant 1 : index
-//  CHECK-DAG:    %[[D0:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]]
-//  CHECK-DAG:    %[[D1:.+]] = affine.apply #[[MAP0]]()[%[[ARG1]]]
-//      CHECK:    hal.return %[[D0]], %[[D1]], %[[C1]] : index, index, index
+// CHECK-SAME:     translation.info = #[[TRANSLATION]]
 //      CHECK: linalg.matmul
-// CHECK-SAME:   lowering.config = #[[CONFIG]]
+// CHECK-SAME:     lowering.config = #[[CONFIG]]
 
 // -----
 
@@ -77,15 +60,15 @@
     #hal.descriptor_set.binding<2, storage_buffer>
   ]>
 ]>
-hal.executable private @add_no_config  {
+hal.executable private @add {
   hal.executable.variant @llvm, target = <"llvm", "embedded-elf-x86_64", {
     data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128",
     native_vector_size = 16 : index,
     target_triple = "x86_64-unknown-linux-gnu"
   }> {
-    hal.executable.entry_point @add_no_config layout(#executable_layout)
+    hal.executable.entry_point @add layout(#executable_layout)
     builtin.module {
-      func @add_no_config() {
+      func @add() {
         %c0 = arith.constant 0 : index
         %dim0 = hal.interface.constant.load[0] : index
         %dim1 = hal.interface.constant.load[1] : index
@@ -111,97 +94,16 @@
     }
   }
 }
-//  CHECK-DAG:  #[[CONFIG:.+]] = #iree_codegen.lowering.config<tile_sizes = {{\[}}[], [1, 4], [0, 0]], native_vector_size = []>
-//  CHECK-DAG:  #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"CPUDoubleTilingExpert", workload_per_wg = []>
-//      CHECK:  hal.executable private @add_no_config
-//      CHECK:  hal.executable.entry_point public @add_no_config
-// CHECK-SAME:      translation.info = #[[TRANSLATION]]
-//      CHECK:    %[[C1:.+]] = arith.constant 1 : index
-//      CHECK:    hal.return %[[C1]], %[[C1]], %[[C1]]
-//      CHECK:  func @add_no_config() {
-//      CHECK:    linalg.generic
-// CHECK-SAME:        lowering.config = #[[CONFIG]]
-
-// -----
-
-#executable_layout = #hal.executable.layout<push_constants = 0, sets = [
-  #hal.descriptor_set.layout<0, bindings = [
-    #hal.descriptor_set.binding<0, storage_buffer>,
-    #hal.descriptor_set.binding<1, storage_buffer>,
-    #hal.descriptor_set.binding<2, storage_buffer>
-  ]>
-]>
-hal.executable private @add  {
-  hal.executable.variant @llvm, target = <"llvm", "embedded-elf-x86_64", {
-       data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128",
-       native_vector_size = 16 : index,
-       target_triple = "x86_64-unknown-linux-gnu"
-    }> {
-    hal.executable.entry_point @add layout(#executable_layout)
-    builtin.module  {
-      func @add() {
-        %c0 = arith.constant 0 : index
-        %c1 = arith.constant 1 : index
-        %dim0 = hal.interface.constant.load[0] : index
-        %dim1 = hal.interface.constant.load[1] : index
-        %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : !flow.dispatch.tensor<readonly:?x?xf32>{%dim0, %dim1}
-        %2 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) : !flow.dispatch.tensor<readonly:?xf32>{%dim1}
-        %3 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) : !flow.dispatch.tensor<writeonly:?x?xf32>{%dim0, %dim1}
-        %workgroup_size_x = hal.interface.workgroup.size[0] : index
-        %workgroup_size_y = hal.interface.workgroup.size[1] : index
-        %workgroup_id_x = hal.interface.workgroup.id[0] : index
-        %workgroup_count_x = hal.interface.workgroup.count[0] : index
-        %workgroup_id_y = hal.interface.workgroup.id[1] : index
-        %workgroup_count_y = hal.interface.workgroup.count[1] : index
-        %8 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_size_y, %workgroup_id_y]
-        %9 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_size_y, %workgroup_count_y]
-        scf.for %arg0 = %8 to %dim0 step %9 {
-          %10 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_size_x, %workgroup_id_x]
-          %11 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_size_x, %workgroup_count_x]
-          scf.for %arg1 = %10 to %dim1 step %11 {
-            %12 = affine.min affine_map<(d0)[s0, s1] -> (s0, -d0 + s1)>(%arg0)[%workgroup_size_y, %dim0]
-            %13 = affine.min affine_map<(d0)[s0, s1] -> (s0, -d0 + s1)>(%arg1)[%workgroup_size_x, %dim1]
-            %14 = flow.dispatch.tensor.load %0, offsets=[%arg0, %arg1], sizes=[%12, %13], strides=[1, 1] : !flow.dispatch.tensor<readonly:?x?xf32>{%dim0, %dim1} -> tensor<?x?xf32>
-            %15 = flow.dispatch.tensor.load %2, offsets=[%arg1], sizes=[%13], strides=[1] : !flow.dispatch.tensor<readonly:?xf32>{%dim1} -> tensor<?xf32>
-            %16 = linalg.init_tensor [%12, %13] : tensor<?x?xf32>
-            %17 = linalg.generic {
-              indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
-                               affine_map<(d0, d1) -> (d1)>,
-                               affine_map<(d0, d1) -> (d0, d1)>],
-              iterator_types = ["parallel", "parallel"]}
-              ins(%14, %15 : tensor<?x?xf32>, tensor<?xf32>) outs(%16 : tensor<?x?xf32>) {
-              ^bb0(%arg2: f32, %arg3: f32, %arg4: f32):  // no predecessors
-                %23 = arith.addf %arg2, %arg3 : f32
-                linalg.yield %23 : f32
-              } -> tensor<?x?xf32>
-            flow.dispatch.tensor.store %17, %3, offsets=[%arg0, %arg1], sizes=[%12, %13], strides=[1, 1] : tensor<?x?xf32> -> !flow.dispatch.tensor<writeonly:?x?xf32>{%dim0, %dim1}
-          }
-        }
-        return
-      }
-    }
-  }
-}
-//  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering.config<tile_sizes = {{\[}}[], [1, 4], [0, 0]], native_vector_size = []>
-//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"CPUDoubleTilingExpert", workload_per_wg = [64, 64]>
-//  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 64)>
+//  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering.config<tile_sizes = {{\[}}[64, 64], [1, 4], [0, 0]], native_vector_size = []>
+//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"CPUDoubleTilingExpert", workload_per_wg = []>
 //      CHECK: hal.executable.entry_point public @add
-// CHECK-SAME:   translation.info = #[[TRANSLATION]]
-// CHECK-NEXT:   (%[[ARG0:[a-zA-Z0-9_]+]]: index
-// CHECK-SAME:    %[[ARG1:[a-zA-Z0-9_]+]]: index
-// CHECK-SAME:    %[[ARG2:[a-zA-Z0-9_]+]]: index)
-//  CHECK-DAG:    %[[C1:.+]] = arith.constant 1 : index
-//  CHECK-DAG:    %[[D0:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]]
-//  CHECK-DAG:    %[[D1:.+]] = affine.apply #[[MAP0]]()[%[[ARG1]]]
-//      CHECK:    hal.return %[[D0]], %[[D1]], %[[C1]] : index, index, index
+// CHECK-SAME:     translation.info = #[[TRANSLATION]]
 //      CHECK: linalg.generic
 // CHECK-SAME:     lowering.config = #[[CONFIG]]
 
 // -----
 
-#map0 = affine_map<()[s0, s1] -> (s0 * s1)>
-#map1 = affine_map<(d0)[s0, s1] -> (s1, -d0 + s0)>
-#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
 #executable_layout = #hal.executable.layout<push_constants = 0, sets = [
   #hal.descriptor_set.layout<0, bindings = [
     #hal.descriptor_set.binding<0, storage_buffer>,
@@ -219,77 +121,46 @@
     hal.executable.entry_point @add4D layout(#executable_layout)
     builtin.module {
       func @add4D() {
-        %c0 = arith.constant 0 : index
-        %0 = hal.interface.constant.load[0] : index
-        %1 = hal.interface.constant.load[1] : index
-        %2 = hal.interface.constant.load[2] : index
-        %3 = hal.interface.constant.load[3] : index
-        %4 = hal.interface.constant.load[4] : index
-        %5 = hal.interface.constant.load[5] : index
-        %6 = hal.interface.constant.load[6] : index
-        %7 = hal.interface.constant.load[7] : index
-        %8 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(32) : !flow.dispatch.tensor<readonly:?x?x?x?xf32>{%0, %1, %2, %3}
-        %9 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(32) : !flow.dispatch.tensor<readonly:?x?x?x?xf32>{%4, %5, %6, %7}
-        %10 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(32) : !flow.dispatch.tensor<writeonly:?x?x?x?xf32>{%0, %1, %2, %3}
-        %workgroup_size_x = hal.interface.workgroup.size[0] : index
-        %workgroup_size_y = hal.interface.workgroup.size[1] : index
-        %workgroup_size_z = hal.interface.workgroup.size[2] : index
-        %workgroup_id_x = hal.interface.workgroup.id[0] : index
-        %workgroup_count_x = hal.interface.workgroup.count[0] : index
-        %workgroup_id_y = hal.interface.workgroup.id[1] : index
-        %workgroup_count_y = hal.interface.workgroup.count[1] : index
-        %workgroup_id_z = hal.interface.workgroup.id[2] : index
-        %workgroup_count_z = hal.interface.workgroup.count[2] : index
-        %11 = affine.apply #map0()[%workgroup_id_z, %workgroup_size_z]
-        %12 = affine.apply #map0()[%workgroup_count_z, %workgroup_size_z]
-        scf.for %arg0 = %11 to %1 step %12 {
-          %13 = affine.apply #map0()[%workgroup_id_y, %workgroup_size_y]
-          %14 = affine.apply #map0()[%workgroup_count_y, %workgroup_size_y]
-          scf.for %arg1 = %13 to %2 step %14 {
-            %15 = affine.apply #map0()[%workgroup_id_x, %workgroup_size_x]
-            %16 = affine.apply #map0()[%workgroup_count_x, %workgroup_size_x]
-            scf.for %arg2 = %15 to %3 step %16 {
-              %17 = affine.min #map1(%arg0)[%1, %workgroup_size_z]
-              %18 = affine.min #map1(%arg1)[%2, %workgroup_size_y]
-              %19 = affine.min #map1(%arg2)[%3, %workgroup_size_x]
-              %20 = flow.dispatch.tensor.load %8, offsets = [0, %arg0, %arg1, %arg2], sizes = [%0, %17, %18, %19], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:?x?x?x?xf32>{%0, %1, %2, %3} -> tensor<?x?x?x?xf32>
-              %21 = flow.dispatch.tensor.load %9, offsets = [0, %arg0, %arg1, %arg2], sizes = [%4, %17, %18, %19], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:?x?x?x?xf32>{%4, %5, %6, %7} -> tensor<?x?x?x?xf32>
-              %22 = linalg.init_tensor [%0, %17, %18, %19] : tensor<?x?x?x?xf32>
-              %23 = linalg.generic {indexing_maps = [#map2, #map2, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%20, %21 : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) outs(%22 : tensor<?x?x?x?xf32>) {
-              ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):  // no predecessors
-                %24 = arith.addf %arg3, %arg4 : f32
-                linalg.yield %24 : f32
-              } -> tensor<?x?x?x?xf32>
-              flow.dispatch.tensor.store %23, %10, offsets = [0, %arg0, %arg1, %arg2], sizes = [%0, %17, %18, %19], strides = [1, 1, 1, 1] : tensor<?x?x?x?xf32> -> !flow.dispatch.tensor<writeonly:?x?x?x?xf32>{%0, %1, %2, %3}
-            }
-          }
-        }
+        %d0 = hal.interface.constant.load[0] : index
+        %d1 = hal.interface.constant.load[1] : index
+        %d2 = hal.interface.constant.load[2] : index
+        %d3 = hal.interface.constant.load[3] : index
+        %arg1_binding = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(32)
+            : !flow.dispatch.tensor<readonly:?x?x?x?xf32>{%d0, %d1, %d2, %d3}
+        %arg2_binding = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(32)
+            : !flow.dispatch.tensor<readonly:?x?x?x?xf32>{%d0, %d1, %d2, %d3}
+        %result_binding = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(32)
+            : !flow.dispatch.tensor<writeonly:?x?x?x?xf32>{%d0, %d1, %d2, %d3}
+        %arg1 = flow.dispatch.tensor.load %arg1_binding, offsets = [0, 0, 0, 0], sizes = [%d0, %d1, %d2, %d3], strides = [1, 1, 1, 1]
+            : !flow.dispatch.tensor<readonly:?x?x?x?xf32>{%d0, %d1, %d2, %d3} -> tensor<?x?x?x?xf32>
+        %arg2 = flow.dispatch.tensor.load %arg2_binding, offsets = [0, 0, 0, 0], sizes = [%d0, %d1, %d2, %d3], strides = [1, 1, 1, 1]
+            : !flow.dispatch.tensor<readonly:?x?x?x?xf32>{%d0, %d1, %d2, %d3} -> tensor<?x?x?x?xf32>
+        %init = linalg.init_tensor [%d0, %d1, %d2, %d3] : tensor<?x?x?x?xf32>
+        %add = linalg.generic {
+            indexing_maps = [#map, #map, #map],
+            iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+            ins(%arg1, %arg2 : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) outs(%init : tensor<?x?x?x?xf32>) {
+            ^bb0(%b0: f32, %b1: f32, %b2: f32):  // no predecessors
+              %addf = arith.addf %b0, %b1 : f32
+              linalg.yield %addf : f32
+            } -> tensor<?x?x?x?xf32>
+        flow.dispatch.tensor.store %add, %result_binding, offsets = [0, 0, 0, 0], sizes = [%d0, %d1, %d2, %d3], strides = [1, 1, 1, 1]
+            : tensor<?x?x?x?xf32> -> !flow.dispatch.tensor<writeonly:?x?x?x?xf32>{%d0, %d1, %d2, %d3}
         return
       }
     }
   }
 }
 
-//  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering.config<tile_sizes = {{\[}}[], [1, 1, 1, 4], [0, 0, 0, 0]], native_vector_size = []>
-//  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 64)>
-//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"CPUDoubleTilingExpert", workload_per_wg = [64, 64, 64]>
+//  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering.config<tile_sizes = {{\[}}[0, 64, 64, 64], [1, 1, 1, 4], [0, 0, 0, 0]], native_vector_size = []>
+//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"CPUDoubleTilingExpert", workload_per_wg = []>
 //      CHECK: hal.executable.entry_point public @add4D
-// CHECK-SAME:   translation.info = #[[TRANSLATION]]
-// CHECK-NEXT:   (%[[ARG0:[a-zA-Z0-9_]+]]: index
-// CHECK-SAME:    %[[ARG1:[a-zA-Z0-9_]+]]: index
-// CHECK-SAME:    %[[ARG2:[a-zA-Z0-9_]+]]: index)
-//  CHECK-DAG:    %[[D0:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]]
-//  CHECK-DAG:    %[[D1:.+]] = affine.apply #[[MAP0]]()[%[[ARG1]]]
-//  CHECK-DAG:    %[[D2:.+]] = affine.apply #[[MAP0]]()[%[[ARG2]]]
-//      CHECK:    hal.return %[[D0]], %[[D1]], %[[D2]] : index, index, index
+// CHECK-SAME:     translation.info = #[[TRANSLATION]]
 //      CHECK: linalg.generic
 // CHECK-SAME:     lowering.config = #[[CONFIG]]
 
 // -----
 
-#map0 = affine_map<()[s0, s1] -> (s0 * s1)>
-#map1 = affine_map<(d0)[s0, s1] -> (s1, -d0 + s0)>
-#map2 = affine_map<(d0)[s0, s1] -> (-d0 + s0, s1)>
 #executable_layout = #hal.executable.layout<push_constants = 0, sets = [
   #hal.descriptor_set.layout<0, bindings = [
     #hal.descriptor_set.binding<0, storage_buffer>,
@@ -307,73 +178,43 @@
     builtin.module {
       func @batch_matmul_tensors() {
         %cst = arith.constant 0.000000e+00 : f32
-        %c0 = arith.constant 0 : index
-        %0 = hal.interface.constant.load[0] : index
-        %1 = hal.interface.constant.load[1] : index
-        %2 = hal.interface.constant.load[2] : index
-        %3 = hal.interface.constant.load[3] : index
-        %4 = hal.interface.constant.load[4] : index
-        %5 = hal.interface.constant.load[5] : index
-        %6 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(32) : !flow.dispatch.tensor<readonly:?x?x?xf32>{%0, %1, %2}
-        %7 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(32) : !flow.dispatch.tensor<readonly:?x?x?xf32>{%3, %4, %5}
-        %8 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(32) : !flow.dispatch.tensor<writeonly:?x?x?xf32>{%0, %1, %5}
-        %workgroup_size_x = hal.interface.workgroup.size[0] : index
-        %workgroup_size_y = hal.interface.workgroup.size[1] : index
-        %workgroup_size_z = hal.interface.workgroup.size[2] : index
-        %workgroup_id_x = hal.interface.workgroup.id[0] : index
-        %workgroup_count_x = hal.interface.workgroup.count[0] : index
-        %workgroup_id_y = hal.interface.workgroup.id[1] : index
-        %workgroup_count_y = hal.interface.workgroup.count[1] : index
-        %workgroup_id_z = hal.interface.workgroup.id[2] : index
-        %workgroup_count_z = hal.interface.workgroup.count[2] : index
-        %9 = affine.apply #map0()[%workgroup_id_z, %workgroup_size_z]
-        %10 = affine.apply #map0()[%workgroup_count_z, %workgroup_size_z]
-        scf.for %arg0 = %9 to %0 step %10 {
-          %11 = affine.apply #map0()[%workgroup_id_y, %workgroup_size_y]
-          %12 = affine.apply #map0()[%workgroup_count_y, %workgroup_size_y]
-          scf.for %arg1 = %11 to %1 step %12 {
-            %13 = affine.apply #map0()[%workgroup_id_x, %workgroup_size_x]
-            %14 = affine.apply #map0()[%workgroup_count_x, %workgroup_size_x]
-            scf.for %arg2 = %13 to %5 step %14 {
-              %15 = affine.min #map1(%arg0)[%0, %workgroup_size_z]
-              %16 = affine.min #map1(%arg1)[%1, %workgroup_size_y]
-              %17 = flow.dispatch.tensor.load %6, offsets = [%arg0, %arg1, 0], sizes = [%15, %16, %2], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:?x?x?xf32>{%0, %1, %2} -> tensor<?x?x?xf32>
-              %18 = affine.min #map1(%arg2)[%5, %workgroup_size_x]
-              %19 = flow.dispatch.tensor.load %7, offsets = [%arg0, 0, %arg2], sizes = [%15, %4, %18], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:?x?x?xf32>{%3, %4, %5} -> tensor<?x?x?xf32>
-              %20 = affine.min #map2(%arg0)[%0, %workgroup_size_z]
-              %21 = affine.min #map2(%arg1)[%1, %workgroup_size_y]
-              %22 = affine.min #map2(%arg2)[%5, %workgroup_size_x]
-              %23 = linalg.init_tensor [%20, %21, %22] : tensor<?x?x?xf32>
-              %24 = linalg.fill(%cst, %23) : f32, tensor<?x?x?xf32> -> tensor<?x?x?xf32>
-              %25 = linalg.batch_matmul ins(%17, %19 : tensor<?x?x?xf32>, tensor<?x?x?xf32>) outs(%24 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
-              flow.dispatch.tensor.store %25, %8, offsets = [%arg0, %arg1, %arg2], sizes = [%15, %16, %18], strides = [1, 1, 1] : tensor<?x?x?xf32> -> !flow.dispatch.tensor<writeonly:?x?x?xf32>{%0, %1, %5}
-            }
-          }
-        }
+        %B = hal.interface.constant.load[0] : index
+        %M = hal.interface.constant.load[1] : index
+        %N = hal.interface.constant.load[2] : index
+        %K = hal.interface.constant.load[3] : index
+        %lhs_binding = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(32)
+            : !flow.dispatch.tensor<readonly:?x?x?xf32>{%B, %M, %K}
+        %rhs_binding = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(32)
+            : !flow.dispatch.tensor<readonly:?x?x?xf32>{%B, %K, %N}
+        %result_binding = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(32)
+            : !flow.dispatch.tensor<writeonly:?x?x?xf32>{%B, %M, %N}
+        %lhs = flow.dispatch.tensor.load %lhs_binding, offsets = [0, 0, 0], sizes = [%B, %M, %K], strides = [1, 1, 1]
+            : !flow.dispatch.tensor<readonly:?x?x?xf32>{%B, %M, %K} -> tensor<?x?x?xf32>
+        %rhs = flow.dispatch.tensor.load %rhs_binding, offsets = [0, 0, 0], sizes = [%B, %K, %N], strides = [1, 1, 1]
+            : !flow.dispatch.tensor<readonly:?x?x?xf32>{%B, %K, %N} -> tensor<?x?x?xf32>
+        %init = linalg.init_tensor [%B, %M, %N] : tensor<?x?x?xf32>
+        %fill = linalg.fill(%cst, %init) : f32, tensor<?x?x?xf32> -> tensor<?x?x?xf32>
+        %batch_gemm = linalg.batch_matmul
+            ins(%lhs, %rhs : tensor<?x?x?xf32>, tensor<?x?x?xf32>) outs(%fill : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+        flow.dispatch.tensor.store %batch_gemm, %result_binding, offsets = [0, 0, 0], sizes = [%B, %M, %N], strides = [1, 1, 1]
+            : tensor<?x?x?xf32> -> !flow.dispatch.tensor<writeonly:?x?x?xf32>{%B, %M, %N}
         return
       }
     }
   }
 }
-
-//  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering.config<tile_sizes = {{\[}}[], [1, 16, 4, 64], [1, 4, 4, 4]{{\]}}, native_vector_size = [1, 4, 4, 4]>
-//  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 64)>
-//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"CPUTileFuseAndVectorize", workload_per_wg = [64, 64, 1]>
+//  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering.config<tile_sizes = {{\[}}[1, 64, 64, 0], [1, 16, 4, 64], [1, 4, 4, 4]{{\]}}, native_vector_size = [1, 4, 4, 4]>
+//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"CPUTileFuseAndVectorize", workload_per_wg = []>
 //      CHECK: hal.executable.entry_point public @batch_matmul_tensors
-// CHECK-NEXT: (%[[ARG0:[a-zA-Z0-9]+]]: index
-// CHECK-SAME:  %[[ARG1:[a-zA-Z0-9]+]]: index
-// CHECK-SAME:  %[[ARG2:[a-zA-Z0-9]+]]: index)
-//  CHECK-DAG:  %[[D0:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]]
-//  CHECK-DAG:  %[[D1:.+]] = affine.apply #[[MAP0]]()[%[[ARG1]]]
-//      CHECK:  hal.return %[[D0]], %[[D1]], %[[ARG2]]
+// CHECK-SAME:     translation.info = #[[TRANSLATION]]
 //      CHECK:  linalg.batch_matmul
-// CHECK-SAME:    lowering.config = #[[CONFIG]]
+// CHECK-SAME:     lowering.config = #[[CONFIG]]
 
 // -----
 
 #compilation = #iree_codegen.compilation.info<
-    #iree_codegen.lowering.config<tile_sizes = [[], [32, 32, 0], [0, 0, 32]], native_vector_size = []>,
-    #iree_codegen.translation.info<"CPUDoubleTilingExpert", workload_per_wg = [32, 32]>,
+    #iree_codegen.lowering.config<tile_sizes = [[64, 64, 0], [32, 32, 0], [0, 0, 32]], native_vector_size = []>,
+    #iree_codegen.translation.info<"CPUDoubleTilingExpert", workload_per_wg = []>,
     workgroup_size = []>
 #executable_layout = #hal.executable.layout<push_constants = 0, sets = [
   #hal.descriptor_set.layout<0, bindings = [
@@ -387,74 +228,37 @@
     hal.executable.entry_point @preset_config layout(#executable_layout)
     builtin.module {
       builtin.func @preset_config() {
-        %c0 = arith.constant 0 : index
-        %c512 = arith.constant 512 : index
-        %c128 = arith.constant 128 : index
         %cst = arith.constant 0.000000e+00 : f32
-        %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : !flow.dispatch.tensor<readonly:128x256xf32>
-        %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) : !flow.dispatch.tensor<readonly:256x512xf32>
-        %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) : !flow.dispatch.tensor<writeonly:128x512xf32>
-        %workgroup_size_x = hal.interface.workgroup.size[0] : index
-        %workgroup_size_y = hal.interface.workgroup.size[1] : index
-        %workgroup_id_x = hal.interface.workgroup.id[0] : index
-        %workgroup_count_x = hal.interface.workgroup.count[0] : index
-        %workgroup_id_y = hal.interface.workgroup.id[1] : index
-        %workgroup_count_y = hal.interface.workgroup.count[1] : index
-        %3 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_y, %workgroup_size_y]
-        %4 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_y, %workgroup_size_y]
-        scf.for %arg0 = %3 to %c128 step %4 {
-          %5 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_x, %workgroup_size_x]
-          %6 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_x, %workgroup_size_x]
-          scf.for %arg1 = %5 to %c512 step %6 {
-            %7 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 128)>(%arg0)[%workgroup_size_y]
-            %8 = flow.dispatch.tensor.load %0, offsets = [%arg0, 0], sizes = [%7, 256], strides = [1, 1] : !flow.dispatch.tensor<readonly:128x256xf32> -> tensor<?x256xf32>
-            %9 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 512)>(%arg1)[%workgroup_size_x]
-            %10 = flow.dispatch.tensor.load %1, offsets = [0, %arg1], sizes = [256, %9], strides = [1, 1] : !flow.dispatch.tensor<readonly:256x512xf32> -> tensor<256x?xf32>
-            %11 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 128)>(%arg0)[%workgroup_size_y]
-            %12 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 512)>(%arg1)[%workgroup_size_x]
-            %13 = affine.min affine_map<(d0)[s0] -> (-d0 + 128, s0)>(%arg0)[%workgroup_size_y]
-            %14 = affine.min affine_map<(d0)[s0] -> (-d0 + 512, s0)>(%arg1)[%workgroup_size_x]
-            %15 = linalg.init_tensor [%13, %14] : tensor<?x?xf32>
-            %16 = linalg.fill(%cst, %15) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
-            %17 = linalg.matmul {compilation.info = #compilation}
-                 ins(%8, %10 : tensor<?x256xf32>, tensor<256x?xf32>)
-                 outs(%16 : tensor<?x?xf32>) -> tensor<?x?xf32>
-            flow.dispatch.tensor.store %17, %2, offsets = [%arg0, %arg1], sizes = [%11, %12], strides = [1, 1] : tensor<?x?xf32> -> !flow.dispatch.tensor<writeonly:128x512xf32>
-          }
-        }
+        %lhs_binding = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer)
+            : !flow.dispatch.tensor<readonly:128x256xf32>
+        %rhs_binding = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer)
+            : !flow.dispatch.tensor<readonly:256x512xf32>
+        %result_binding = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer)
+            : !flow.dispatch.tensor<writeonly:128x512xf32>
+        %lhs = flow.dispatch.tensor.load %lhs_binding, offsets = [0, 0], sizes = [128, 256], strides = [1, 1]
+            : !flow.dispatch.tensor<readonly:128x256xf32> -> tensor<128x256xf32>
+        %rhs = flow.dispatch.tensor.load %rhs_binding, offsets = [0, 0], sizes = [256, 512], strides = [1, 1]
+            : !flow.dispatch.tensor<readonly:256x512xf32> -> tensor<256x512xf32>
+        %init = linalg.init_tensor [128, 512] : tensor<128x512xf32>
+        %fill = linalg.fill(%cst, %init) : f32, tensor<128x512xf32> -> tensor<128x512xf32>
+        %gemm = linalg.matmul {compilation.info = #compilation}
+            ins(%lhs, %rhs : tensor<128x256xf32>, tensor<256x512xf32>)
+            outs(%fill : tensor<128x512xf32>) -> tensor<128x512xf32>
+        flow.dispatch.tensor.store %gemm, %result_binding, offsets = [0, 0], sizes = [128, 512], strides = [1, 1]
+            : tensor<128x512xf32> -> !flow.dispatch.tensor<writeonly:128x512xf32>
         return
       }
     }
   }
 }
 
-//  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering.config<tile_sizes = {{\[}}[], [32, 32, 0], [0, 0, 32]], native_vector_size = []>
-//  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 32)>
-//  CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 32)>
-//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"CPUDoubleTilingExpert", workload_per_wg = [32, 32]>
+//  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering.config<tile_sizes = {{\[}}[64, 64, 0], [32, 32, 0], [0, 0, 32]], native_vector_size = []>
+//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"CPUDoubleTilingExpert", workload_per_wg = []>
 //      CHECK: hal.executable.entry_point
 // CHECK-SAME:     translation.info = #[[TRANSLATION]]
-// CHECK-NEXT:   ^bb0(%[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index
-//  CHECK-DAG:     %[[C1:.+]] = arith.constant 1 : index
-//  CHECK-DAG:     %[[NWG_X:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]]
-//  CHECK-DAG:     %[[NWG_Y:.+]] = affine.apply #[[MAP0]]()[%[[ARG1]]]
-//      CHECK:     return %[[NWG_X]], %[[NWG_Y]], %[[C1]]
-//      CHECK: builtin.module
-//      CHECK:   func @preset_config
-//  CHECK-DAG:     %[[WGID_X:.+]] = hal.interface.workgroup.id[0]
-//  CHECK-DAG:     %[[WGCOUNT_X:.+]] = hal.interface.workgroup.count[0]
-//  CHECK-DAG:     %[[WGID_Y:.+]] = hal.interface.workgroup.id[1]
-//  CHECK-DAG:     %[[WGCOUNT_Y:.+]] = hal.interface.workgroup.count[1]
-//      CHECK:     %[[LB_Y:.+]] = affine.apply #[[MAP1]]()[%[[WGID_Y]]]
-//      CHECK:     %[[STEP_Y:.+]] = affine.apply #[[MAP1]]()[%[[WGCOUNT_Y]]]
-//      CHECK:     scf.for %[[IV0:.+]] = %[[LB_Y]] to %{{.+}} step %[[STEP_Y]]
-//      CHECK:       %[[LB_X:.+]] = affine.apply #[[MAP1]]()[%[[WGID_X]]]
-//      CHECK:       %[[STEP_X:.+]] = affine.apply #[[MAP1]]()[%[[WGCOUNT_X]]]
-//      CHECK:       scf.for %[[IV1:.+]] = %[[LB_X]] to %{{.+}} step %[[STEP_X]]
-//      CHECK:         linalg.matmul
-// CHECK-SAME:             lowering.config = #[[CONFIG]]
-// CHECK-SAME:             ins(%{{.+}}, %{{.+}} : tensor<32x256xf32>, tensor<256x32xf32>)
-// CHECK-SAME:             outs(%{{.+}} : tensor<32x32xf32>)
+//      CHECK: func @preset_config
+//      CHECK:   linalg.matmul
+// CHECK-SAME:       lowering.config = #[[CONFIG]]
 
 // -----
 
@@ -469,49 +273,79 @@
     hal.executable.entry_point @tensor_insert_slice layout(#executable_layout)
     builtin.module {
       builtin.func @tensor_insert_slice() {
-        %c0 = arith.constant 0 : index
-        %1 = hal.interface.constant.load[0] : index
-        %2 = hal.interface.constant.load[1] : index
-        %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : !flow.dispatch.tensor<readonly:?x?xi32>{%1, %2}
-        %3 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) : !flow.dispatch.tensor<writeonly:?x?xi32>{%1, %2}
-        %workgroup_size_x = hal.interface.workgroup.size[0] : index
-        %workgroup_size_y = hal.interface.workgroup.size[1] : index
-        %workgroup_id_x = hal.interface.workgroup.id[0] : index
-        %workgroup_count_x = hal.interface.workgroup.count[0] : index
-        %workgroup_id_y = hal.interface.workgroup.id[1] : index
-        %workgroup_count_y = hal.interface.workgroup.count[1] : index
-        %4 = affine.apply affine_map<()[s0, s1] -> (s1 * s0)>()[%workgroup_size_y, %workgroup_id_y]
-        %5 = affine.apply affine_map<()[s0, s1] -> (s1 * s0)>()[%workgroup_size_y, %workgroup_count_y]
-        %d0 = hal.interface.constant.load[2] : index
-        %d1 = hal.interface.constant.load[2] : index
-        scf.for %arg0 = %4 to %d0 step %5 {
-          %6 = affine.min affine_map<(d0)[s0, s1] -> (s0, -d0 + s1)>(%arg0)[%workgroup_size_y, %d0]
-          %7 = affine.apply affine_map<()[s0, s1] -> (s1 * s0)>()[%workgroup_size_x, %workgroup_id_x]
-          %8 = affine.apply affine_map<()[s0, s1] -> (s1 * s0)>()[%workgroup_size_x, %workgroup_count_x]
-          scf.for %arg1 = %7 to %d1 step %8 {
-            %9 = affine.min affine_map<(d0)[s0, s1] -> (s0, -d0 + s1)>(%arg1)[%workgroup_size_x, %d1]
-            %10 = flow.dispatch.tensor.load %0, offsets = [%arg0, %arg1], sizes = [%6, %9], strides = [1, 1] : !flow.dispatch.tensor<readonly:?x?xi32>{%1, %2} -> tensor<?x?xi32>
-            %11 = affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%arg0)[%1]
-            %12 = affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%arg1)[%2]
-            flow.dispatch.tensor.store %10, %3, offsets = [%11, %12], sizes = [%6, %9], strides = [1, 1] : tensor<?x?xi32> -> !flow.dispatch.tensor<writeonly:?x?xi32>{%1, %2}
-          }
-        }
+        %d0 = hal.interface.constant.load[0] : index
+        %d1 = hal.interface.constant.load[1] : index
+        %d2 = hal.interface.constant.load[2] : index
+        %d3 = hal.interface.constant.load[3] : index
+        %o0 = hal.interface.constant.load[4] : index
+        %o1 = hal.interface.constant.load[5] : index
+        %source_binding = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer)
+            : !flow.dispatch.tensor<readonly:?x?xi32>{%d0, %d1}
+        %dest_binding = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer)
+            : !flow.dispatch.tensor<readwrite:?x?xi32>{%d2, %d3}
+        %source = flow.dispatch.tensor.load %source_binding, offsets = [0, 0], sizes = [%d0, %d1], strides = [1, 1]
+            : !flow.dispatch.tensor<readonly:?x?xi32>{%d0, %d1} -> tensor<?x?xi32>
+        %dest = flow.dispatch.tensor.load %dest_binding, offsets = [0, 0], sizes = [%d2, %d3], strides = [1, 1]
+            : !flow.dispatch.tensor<readwrite:?x?xi32>{%d2, %d3} -> tensor<?x?xi32>
+        %result = tensor.insert_slice %source into %dest[%o0, %o1] [%d0, %d1] [1, 1] : tensor<?x?xi32> into tensor<?x?xi32>
+        flow.dispatch.tensor.store %result, %dest_binding, offsets = [0, 0], sizes = [%d2, %d3], strides = [1, 1]
+            : tensor<?x?xi32> -> !flow.dispatch.tensor<readwrite:?x?xi32>{%d2, %d3}
         return
       }
     }
   }
 }
 
-//  CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0] -> (s0 ceildiv 64)>
-//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"CPUDefault", workload_per_wg = [64, 64]>
+//  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering.config<tile_sizes = {{\[}}[64, 64]{{\]}}, native_vector_size = []>
+//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"CPUDefault", workload_per_wg = []>
 //      CHECK: hal.executable.entry_point public @tensor_insert_slice
-// CHECK-SAME:   translation.info = #[[TRANSLATION]]
-// CHECK-NEXT:   %[[ARG0:[a-zA-Z0-9_]+]]: index
-// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
-//  CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
-//  CHECK-DAG:   %[[NWGSX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]]
-//  CHECK-DAG:   %[[NWGSY:.+]] = affine.apply #[[MAP0]]()[%[[ARG1]]]
-//      CHECK:   hal.return %[[NWGSX]], %[[NWGSY]], %[[C1]]
+// CHECK-SAME:     translation.info = #[[TRANSLATION]]
+//      CHECK: tensor.insert_slice
+// CHECK-SAME:     lowering.config = #[[CONFIG]]
+
+// -----
+
+#executable_layout = #hal.executable.layout<push_constants = 0, sets = [
+  #hal.descriptor_set.layout<0, bindings = [
+    #hal.descriptor_set.binding<0, storage_buffer>,
+    #hal.descriptor_set.binding<1, storage_buffer>
+  ]>
+]>
+hal.executable @extract_slice {
+  hal.executable.variant @system_elf_x86_64, target = <"llvm", "system-elf-x86_64"> {
+    hal.executable.entry_point @extract_slice layout(#executable_layout)
+    builtin.module {
+      builtin.func @extract_slice() {
+        %d0 = hal.interface.constant.load[0] : index
+        %d1 = hal.interface.constant.load[1] : index
+        %d2 = hal.interface.constant.load[2] : index
+        %d3 = hal.interface.constant.load[3] : index
+        %o0 = hal.interface.constant.load[4] : index
+        %o1 = hal.interface.constant.load[5] : index
+        %source_binding = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer)
+            : !flow.dispatch.tensor<readonly:?x?xi32>{%d0, %d1}
+        %dest_binding = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer)
+            : !flow.dispatch.tensor<writeonly:?x?xi32>{%d2, %d3}
+        %source = flow.dispatch.tensor.load %source_binding, offsets = [0, 0], sizes = [%d0, %d1], strides = [1, 1]
+            : !flow.dispatch.tensor<readonly:?x?xi32>{%d0, %d1} -> tensor<?x?xi32>
+        %dest = flow.dispatch.tensor.load %dest_binding, offsets = [0, 0], sizes = [%d0, %d1], strides = [1, 1]
+            : !flow.dispatch.tensor<writeonly:?x?xi32>{%d2, %d3} -> tensor<?x?xi32>
+        %result = tensor.extract_slice %source[%o0, %o1] [%d2, %d3] [1, 1] : tensor<?x?xi32> to tensor<?x?xi32>
+        flow.dispatch.tensor.store %result, %dest_binding, offsets = [0, 0], sizes = [%d2, %d3], strides = [1, 1]
+            : tensor<?x?xi32> -> !flow.dispatch.tensor<writeonly:?x?xi32>{%d2, %d3}
+        return
+      }
+    }
+  }
+}
+
+//  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering.config<tile_sizes = {{\[}}[64, 64]{{\]}}, native_vector_size = []>
+//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"CPUDefault", workload_per_wg = []>
+//      CHECK: hal.executable.entry_point public @extract_slice
+// CHECK-SAME:     translation.info = #[[TRANSLATION]]
+//      CHECK: tensor.extract_slice
+// CHECK-SAME:     lowering.config = #[[CONFIG]]
+
 
 // -----
 
@@ -544,18 +378,12 @@
 }
 
 //   CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering.config<tile_sizes = {{\[}}[64]{{\]}}, native_vector_size = []>
-//   CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 64)>
-//   CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"CPUDefault", workload_per_wg = [64]>
+//   CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"CPUDefault", workload_per_wg = []>
 //       CHECK: hal.executable.entry_point public @static_1d_fft_stage2
-//  CHECK-SAME:   translation.info = #[[TRANSLATION]]
-//  CHECK-NEXT: ^{{.+}}(%[[ARG0:.+]]: index, %[[ARG1:.+]]: index, %[[ARG2:.+]]: index):
-//  CHECK-NEXT:   %[[C1:.+]] = arith.constant 1 : index
-//  CHECK-NEXT:   %[[T0:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]]
-//  CHECK-NEXT:   hal.return %[[T0]], %[[C1]], %[[C1]]
-
+//  CHECK-SAME:     translation.info = #[[TRANSLATION]]
 //       CHECK: func @static_1d_fft_stage2()
 //       CHECK:   iree_linalg_ext.fft
-//  CHECK-SAME:     lowering.config = #[[CONFIG]]
+//  CHECK-SAME:       lowering.config = #[[CONFIG]]
 
 // -----
 
@@ -570,45 +398,16 @@
     hal.executable.entry_point @static_3d_fft_stage3 layout(#executable_layout)
     builtin.module {
       builtin.func @static_3d_fft_stage3() {
-        %c0 = arith.constant 0 : index
         %c3 = arith.constant 3 : index
-        %c64 = arith.constant 64 : index
-        %c128 = arith.constant 128 : index
-        %c32 = arith.constant 32 : index
         %cst = arith.constant dense<[1.000000e+00, 0.707106769, 6.12323426E-17, -0.707106769]> : tensor<4xf32>
         %cst_0 = arith.constant dense<[-0.000000e+00, -0.707106769, -1.000000e+00, -0.707106769]> : tensor<4xf32>
         %0 = bufferization.to_memref %cst_0 : memref<4xf32>
         %1 = bufferization.to_memref %cst : memref<4xf32>
         %2 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : memref<64x128x32xf32>
         %3 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) : memref<64x128x32xf32>
-        %workgroup_id_x = hal.interface.workgroup.id[0] : index
-        %workgroup_count_x = hal.interface.workgroup.count[0] : index
-        %workgroup_size_x = hal.interface.workgroup.size[0] : index
-        %workgroup_id_y = hal.interface.workgroup.id[1] : index
-        %workgroup_count_y = hal.interface.workgroup.count[1] : index
-        %workgroup_size_y = hal.interface.workgroup.size[1] : index
-        %workgroup_id_z = hal.interface.workgroup.id[2] : index
-        %workgroup_count_z = hal.interface.workgroup.count[2] : index
-        %workgroup_size_z = hal.interface.workgroup.size[2] : index
-        %lb_z = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_z, %workgroup_size_z]
-        %step_z = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_z, %workgroup_size_z]
-        scf.for %arg0 = %workgroup_id_z to %c64 step %workgroup_count_z {
-          %lb_y = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_y, %workgroup_size_y]
-          %step_y = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_y, %workgroup_size_y]
-          scf.for %arg1 = %workgroup_id_y to %c128 step %workgroup_count_y {
-            %4 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_x, %workgroup_size_x]
-            %5 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_x, %workgroup_size_x]
-            scf.for %arg2 = %4 to %c32 step %5 {
-              %6 = memref.subview %2[%arg0, %arg1, %arg2] [1, 1, 4] [1, 1, 1] : memref<64x128x32xf32> to memref<1x1x4xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 4096 + s0 + d1 * 32 + d2)>>
-              %7 = memref.cast %6 : memref<1x1x4xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 4096 + s0 + d1 * 32 + d2)>> to memref<?x?x?xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 4096 + s0 + d1 * 32 + d2)>>
-              %8 = memref.subview %3[%arg0, %arg1, %arg2] [1, 1, 4] [1, 1, 1] : memref<64x128x32xf32> to memref<1x1x4xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 4096 + s0 + d1 * 32 + d2)>>
-              %9 = memref.cast %8 : memref<1x1x4xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 4096 + s0 + d1 * 32 + d2)>> to memref<?x?x?xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 4096 + s0 + d1 * 32 + d2)>>
-              iree_linalg_ext.fft
-                ins(%c3, %1, %0 : index, memref<4xf32>, memref<4xf32>)
-                outs(%7, %9 : memref<?x?x?xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 4096 + s0 + d1 * 32 + d2)>>, memref<?x?x?xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 4096 + s0 + d1 * 32 + d2)>>)
-            }
-          }
-        }
+        iree_linalg_ext.fft
+            ins(%c3, %1, %0 : index, memref<4xf32>, memref<4xf32>)
+            outs(%2, %3 : memref<64x128x32xf32>, memref<64x128x32xf32>)
         return
       }
     }
@@ -616,19 +415,12 @@
 }
 
 //   CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering.config<tile_sizes = {{\[}}[64, 64, 64]{{\]}}, native_vector_size = []>
-//   CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 64)>
-//   CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"CPUDefault", workload_per_wg = [64, 64, 64]>
+//   CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"CPUDefault", workload_per_wg = []>
 //       CHECK: hal.executable.entry_point public @static_3d_fft_stage3
-//  CHECK-SAME:   translation.info = #[[TRANSLATION]]
-//  CHECK-NEXT: ^{{.+}}(%[[ARG0:.+]]: index, %[[ARG1:.+]]: index, %[[ARG2:.+]]: index):
-//  CHECK-NEXT:   %[[T0:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]]
-//  CHECK-NEXT:   %[[T1:.+]] = affine.apply #[[MAP0]]()[%[[ARG1]]]
-//  CHECK-NEXT:   %[[T2:.+]] = affine.apply #[[MAP0]]()[%[[ARG2]]]
-//  CHECK-NEXT:   hal.return %[[T0]], %[[T1]], %[[T2]]
-
+//  CHECK-SAME:     translation.info = #[[TRANSLATION]]
 //       CHECK: func @static_3d_fft_stage3()
 //       CHECK:   iree_linalg_ext.fft
-//  CHECK-SAME:     lowering.config = #[[CONFIG]]
+//  CHECK-SAME:       lowering.config = #[[CONFIG]]
 
 // -----
 
@@ -644,39 +436,28 @@
     hal.executable.entry_point @outs_fusion_fn layout(#executable_layout)
     builtin.module {
       builtin.func @outs_fusion_fn() {
-        %c0 = arith.constant 0 : index
         %cst = arith.constant 0.0 : f32
-        %2 = hal.interface.constant.load[0] : index
-        %3 = hal.interface.constant.load[1] : index
-        %4 = hal.interface.constant.load[2] : index
-        %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : !flow.dispatch.tensor<readonly:?x?xf32>{%2, %3}
-        %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) : !flow.dispatch.tensor<readonly:?x?xf32>{%2, %3}
-        %5 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) : !flow.dispatch.tensor<writeonly:?x?xf32>{%2, %3}
-        %workgroup_id_x = hal.interface.workgroup.id[0] : index
-        %workgroup_count_x = hal.interface.workgroup.count[0] : index
-        %workgroup_size_x = hal.interface.workgroup.size[0] : index
-        %workgroup_id_y = hal.interface.workgroup.id[1] : index
-        %workgroup_count_y = hal.interface.workgroup.count[1] : index
-        %workgroup_size_y = hal.interface.workgroup.size[1] : index
-        %lb_y = affine.apply affine_map<(d0)[s0] -> (d0 * s0)>(%workgroup_id_y)[%workgroup_size_y]
-        %step_y = affine.apply affine_map<(d0)[s0] -> (d0 * s0)>(%workgroup_count_y)[%workgroup_size_y]
-        %lb_x = affine.apply affine_map<(d0)[s0] -> (d0 * s0)>(%workgroup_id_x)[%workgroup_size_x]
-        %step_x = affine.apply affine_map<(d0)[s0] -> (d0 * s0)>(%workgroup_count_x)[%workgroup_size_x]
-        scf.for %iv0 = %lb_y to %2 step %step_y {
-          scf.for %iv1 = %lb_x to %3 step %step_x {
-            %tile_m = affine.min affine_map<(d0)[s0, s1] -> (s0, -d0 + s1)>(%iv0)[%workgroup_size_y, %2]
-            %tile_n = affine.min affine_map<(d0)[s0, s1] -> (s0, -d0 + s1)>(%iv1)[%workgroup_size_x, %3]
-            %init = linalg.init_tensor[%tile_m, %tile_n] : tensor<?x?xf32>
-            %fill = linalg.generic {
-                indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>],
-                iterator_types = ["parallel", "parallel"]}
-                outs(%init : tensor<?x?xf32>) {
-                ^bb0(%arg0: f32):
+        %d0 = hal.interface.constant.load[0] : index
+        %d1 = hal.interface.constant.load[1] : index
+        %d2 = hal.interface.constant.load[2] : index
+        %lhs_binding = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer)
+            : !flow.dispatch.tensor<readonly:?x?xf32>{%d0, %d2}
+        %rhs_binding = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer)
+            : !flow.dispatch.tensor<readonly:?x?xf32>{%d2, %d1}
+        %result_binding = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer)
+            : !flow.dispatch.tensor<writeonly:?x?xf32>{%d0, %d1}
+        %init = linalg.init_tensor[%d0, %d1] : tensor<?x?xf32>
+        %fill = linalg.generic {
+              indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]}
+              outs(%init : tensor<?x?xf32>) {
+              ^bb0(%b0: f32):
                   linalg.yield %cst : f32
-                } -> tensor<?x?xf32>
-            %lhs = flow.dispatch.tensor.load %0, offsets = [%iv0, 0], sizes = [%tile_m, %4], strides = [1, 1] : !flow.dispatch.tensor<readonly:?x?xf32>{%2, %3} -> tensor<?x?xf32>
-            %rhs = flow.dispatch.tensor.load %0, offsets = [0, %iv1], sizes = [%4, %tile_n], strides = [1, 1] : !flow.dispatch.tensor<readonly:?x?xf32>{%2, %3} -> tensor<?x?xf32>
-            %gemm = linalg.generic {
+              } -> tensor<?x?xf32>
+        %lhs = flow.dispatch.tensor.load %lhs_binding, offsets = [0, 0], sizes = [%d0, %d2], strides = [1, 1]
+            : !flow.dispatch.tensor<readonly:?x?xf32>{%d0, %d2} -> tensor<?x?xf32>
+        %rhs = flow.dispatch.tensor.load %rhs_binding, offsets = [0, 0], sizes = [%d2, %d1], strides = [1, 1]
+            : !flow.dispatch.tensor<readonly:?x?xf32>{%d2, %d1} -> tensor<?x?xf32>
+        %gemm = linalg.generic {
                 indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
                                  affine_map<(d0, d1, d2) -> (d2, d1)>,
                                  affine_map<(d0, d1, d2) -> (d0, d1)>],
@@ -688,16 +469,15 @@
                   %7 = arith.addf %6, %arg2 : f32
                   linalg.yield %6 : f32
                 } -> tensor<?x?xf32>
-            flow.dispatch.tensor.store %gemm, %5, offsets = [%iv0, %iv1], sizes = [%tile_m, %tile_n], strides = [1, 1] : tensor<?x?xf32> -> !flow.dispatch.tensor<writeonly:?x?xf32>{%2, %3}
-          }
-        }
+        flow.dispatch.tensor.store %gemm, %result_binding, offsets = [0, 0], sizes = [%d0, %d1], strides = [1, 1]
+            : tensor<?x?xf32> -> !flow.dispatch.tensor<writeonly:?x?xf32>{%d0, %d1}
         return
       }
     }
   }
 }
-
-//      CHECK: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"CPUDoubleTilingExpert", workload_per_wg = [64, 64]>
+//  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering.config<tile_sizes = {{\[}}[64, 64, 0], [1, 4, 0], [0, 0, 4]{{\]}}, native_vector_size = []>
+//      CHECK: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"CPUDoubleTilingExpert", workload_per_wg = []>
 //      CHECK: hal.executable.entry_point public @outs_fusion_fn
 // CHECK-SAME:   translation.info = #[[TRANSLATION]]
 //      CHECK: func @outs_fusion_fn()
@@ -724,71 +504,44 @@
     hal.executable.entry_point public @conv layout(#executable_layout)
     builtin.module {
       func @conv() {
-        %c0 = arith.constant 0 : index
-        %0 = hal.interface.constant.load[0] : index
-        %1 = hal.interface.constant.load[1] : index
-        %2 = hal.interface.constant.load[2] : index
-        %3 = hal.interface.constant.load[3] : index
-        %4 = hal.interface.constant.load[4] : index
-        %5 = hal.interface.constant.load[5] : index
-        %6 = hal.interface.constant.load[6] : index
-        %7 = hal.interface.constant.load[7] : index
-        %8 = hal.interface.constant.load[8] : index
-        %9 = hal.interface.constant.load[9] : index
-        %10 = hal.interface.constant.load[10] : index
-        %11 = hal.interface.constant.load[11] : index
-        %12 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : !flow.dispatch.tensor<readonly:?x?x?x?xf32>{%0, %1, %2, %3}
-        %13 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) : !flow.dispatch.tensor<readwrite:?x?x?x?xf32>{%4, %5, %6, %7}
-        %14 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) : !flow.dispatch.tensor<readonly:?x?x?x?xf32>{%8, %9, %10, %11}
-        %workgroup_size_x = hal.interface.workgroup.size[0] : index
-        %workgroup_size_y = hal.interface.workgroup.size[1] : index
-        %workgroup_size_z = hal.interface.workgroup.size[2] : index
-        %workgroup_id_x = hal.interface.workgroup.id[0] : index
-        %workgroup_count_x = hal.interface.workgroup.count[0] : index
-        %workgroup_id_y = hal.interface.workgroup.id[1] : index
-        %workgroup_count_y = hal.interface.workgroup.count[1] : index
-        %workgroup_id_z = hal.interface.workgroup.id[2] : index
-        %workgroup_count_z = hal.interface.workgroup.count[2] : index
-        %15 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_z, %workgroup_size_z]
-        %16 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_z, %workgroup_size_z]
-        scf.for %arg0 = %15 to %5 step %16 {
-          %17 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_y, %workgroup_size_y]
-          %18 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_y, %workgroup_size_y]
-          scf.for %arg1 = %17 to %6 step %18 {
-            %19 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_x, %workgroup_size_x]
-            %20 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_x, %workgroup_size_x]
-            scf.for %arg2 = %19 to %3 step %20 {
-              %21 = affine.min affine_map<(d0)[s0, s1, s2] -> (s0 + s2 - 1, -d0 + s0 + s1)>(%arg0)[%0, %5, %workgroup_size_z]
-              %22 = affine.min affine_map<(d0)[s0, s1, s2] -> (s0 + s2 - 1, -d0 + s0 + s1)>(%arg1)[%1, %6, %workgroup_size_y]
-              %23 = flow.dispatch.tensor.load %14, offsets = [0, %arg0, %arg1, 0], sizes = [%8, %21, %22, %11], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:?x?x?x?xf32>{%8, %9, %10, %11} -> tensor<?x?x?x?xf32>
-              %24 = affine.min affine_map<(d0)[s0, s1] -> (s1, -d0 + s0)>(%arg2)[%3, %workgroup_size_x]
-              %25 = flow.dispatch.tensor.load %12, offsets = [0, 0, 0, %arg2], sizes = [%0, %1, %2, %24], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:?x?x?x?xf32>{%0, %1, %2, %3} -> tensor<?x?x?x?xf32>
-              %26 = affine.min affine_map<(d0)[s0, s1] -> (s1, -d0 + s0)>(%arg0)[%5, %workgroup_size_z]
-              %27 = affine.min affine_map<(d0)[s0, s1] -> (s1, -d0 + s0)>(%arg1)[%6, %workgroup_size_y]
-              %28 = affine.min affine_map<(d0)[s0, s1] -> (s1, -d0 + s0)>(%arg2)[%3, %workgroup_size_x]
-              %29 = flow.dispatch.tensor.load %13, offsets = [0, %arg0, %arg1, %arg2], sizes = [%4, %26, %27, %28], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readwrite:?x?x?x?xf32>{%4, %5, %6, %7} -> tensor<?x?x?x?xf32>
-              %30 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%23, %25 : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) outs(%29 : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
-              flow.dispatch.tensor.store %30, %13, offsets = [0, %arg0, %arg1, %arg2], sizes = [%4, %26, %27, %28], strides = [1, 1, 1, 1] : tensor<?x?x?x?xf32> -> !flow.dispatch.tensor<readwrite:?x?x?x?xf32>{%4, %5, %6, %7}
-            }
-          }
-        }
+        %N = hal.interface.constant.load[0] : index
+        %H = hal.interface.constant.load[1] : index
+        %W = hal.interface.constant.load[2] : index
+        %C = hal.interface.constant.load[3] : index
+        %R = hal.interface.constant.load[4] : index
+        %S = hal.interface.constant.load[5] : index
+        %F = hal.interface.constant.load[6] : index
+        %P = hal.interface.constant.load[7] : index
+        %Q = hal.interface.constant.load[8] : index
+        %input_binding = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer)
+            : !flow.dispatch.tensor<readonly:?x?x?x?xf32>{%N, %H, %W, %C}
+        %filter_binding = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer)
+            : !flow.dispatch.tensor<readonly:?x?x?x?xf32>{%R, %S, %C, %F}
+        %result_binding = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer)
+            : !flow.dispatch.tensor<readwrite:?x?x?x?xf32>{%N, %P, %Q, %F}
+        %input = flow.dispatch.tensor.load %input_binding, offsets = [0, 0, 0, 0], sizes = [%N, %H, %W, %C], strides = [1, 1, 1, 1]
+            : !flow.dispatch.tensor<readonly:?x?x?x?xf32>{%N, %H, %W, %C} -> tensor<?x?x?x?xf32>
+        %filter = flow.dispatch.tensor.load %filter_binding, offsets = [0, 0, 0, 0], sizes = [%R, %S, %C, %F], strides = [1, 1, 1, 1]
+            : !flow.dispatch.tensor<readonly:?x?x?x?xf32>{%R, %S, %C, %F} -> tensor<?x?x?x?xf32>
+        %init = flow.dispatch.tensor.load %result_binding, offsets = [0, 0, 0, 0], sizes = [%N, %P, %Q, %F], strides = [1, 1, 1, 1]
+            : !flow.dispatch.tensor<readwrite:?x?x?x?xf32>{%N, %P, %Q, %F} -> tensor<?x?x?x?xf32>
+        %conv = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
+            ins(%input, %filter : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
+            outs(%init : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+        flow.dispatch.tensor.store %conv, %result_binding, offsets = [0, 0, 0, 0], sizes = [%N, %P, %Q, %F], strides = [1, 1, 1, 1]
+            : tensor<?x?x?x?xf32> -> !flow.dispatch.tensor<readwrite:?x?x?x?xf32>{%N, %P, %Q, %F}
         return
       }
     }
   }
 }
 
-//  CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0] -> (s0 ceildiv 64)>
-//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"CPUDefault", workload_per_wg = [64, 64, 64]>
+//  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering.config<tile_sizes = {{\[}}[0, 64, 64, 64, 0, 0, 0]{{\]}}, native_vector_size = []>
+//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"CPUDefault", workload_per_wg = []>
 //      CHECK: hal.executable.entry_point public @conv
 // CHECK-SAME:     translation.info = #[[TRANSLATION]]
-// CHECK-NEXT:   ^bb0(%[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index, %[[ARG2:[a-zA-Z0-9]+]]: index)
-//  CHECK-DAG:     %[[D0:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]
-//  CHECK-DAG:     %[[D1:.+]] = affine.apply #[[MAP0]]()[%[[ARG1]]
-//  CHECK-DAG:     %[[D2:.+]] = affine.apply #[[MAP0]]()[%[[ARG2]]
-//      CHECK:     hal.return %[[D0]], %[[D1]], %[[D2]]
 //      CHECK:     linalg.conv_2d_nhwc_hwcf
-//  CHECK-NOT:       lowering.config
+//      CHECK:         lowering.config = #[[CONFIG]]
 
 // -----
 
@@ -802,76 +555,40 @@
 hal.executable private @conv_static {
   hal.executable.variant public @system_elf_x86_64, target = <"llvm", "system-elf-x86_64", {
     data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128",
-    native_vector_size = 64 : index,
-    target_triple = "x86_64-pc-linux-gnu"
+    native_vector_size = 16 : index,
+    target_triple = "x86_64-unknown-linux-gnu"
   }> {
     hal.executable.entry_point public @conv_static layout(#executable_layout)
     builtin.module {
       func @conv_static() {
-        %cst = arith.constant 0.000000e+00 : f32
-        %c80 = arith.constant 80 : index
-        %c96 = arith.constant 96 : index
-        %c0 = arith.constant 0 : index
-        %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : !flow.dispatch.tensor<readonly:1x161x161x96xf32>
-        %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) : !flow.dispatch.tensor<readonly:3x3x96xf32>
-        %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) : !flow.dispatch.tensor<writeonly:1x80x80x96xf32>
-        %workgroup_size_x = hal.interface.workgroup.size[0] : index
-        %workgroup_size_y = hal.interface.workgroup.size[1] : index
-        %workgroup_size_z = hal.interface.workgroup.size[2] : index
-        %workgroup_id_x = hal.interface.workgroup.id[0] : index
-        %workgroup_count_x = hal.interface.workgroup.count[0] : index
-        %workgroup_id_y = hal.interface.workgroup.id[1] : index
-        %workgroup_count_y = hal.interface.workgroup.count[1] : index
-        %workgroup_id_z = hal.interface.workgroup.id[2] : index
-        %workgroup_count_z = hal.interface.workgroup.count[2] : index
-        %3 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_z, %workgroup_size_z]
-        %4 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_z, %workgroup_size_z]
-        scf.for %arg0 = %3 to %c80 step %4 {
-          %5 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_y, %workgroup_size_y]
-          %6 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_y, %workgroup_size_y]
-          scf.for %arg1 = %5 to %c80 step %6 {
-            %7 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_x, %workgroup_size_x]
-            %8 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_x, %workgroup_size_x]
-            scf.for %arg2 = %7 to %c96 step %8 {
-              %9 = affine.apply affine_map<(d0) -> (d0 * 2)>(%arg0)
-              %10 = affine.min affine_map<(d0)[s0] -> (s0 * 2 + 1, d0 * -2 + 163)>(%arg0)[%workgroup_size_z]
-              %11 = affine.apply affine_map<(d0) -> (d0 * 2)>(%arg1)
-              %12 = affine.min affine_map<(d0)[s0] -> (s0 * 2 + 1, d0 * -2 + 163)>(%arg1)[%workgroup_size_y]
-              %13 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 96)>(%arg2)[%workgroup_size_x]
-              %14 = flow.dispatch.tensor.load %0, offsets = [0, %9, %11, %arg2], sizes = [1, %10, %12, %13], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:1x161x161x96xf32> -> tensor<1x?x?x?xf32>
-              %15 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 96)>(%arg2)[%workgroup_size_x]
-              %16 = flow.dispatch.tensor.load %1, offsets = [0, 0, %arg2], sizes = [3, 3, %15], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:3x3x96xf32> -> tensor<3x3x?xf32>
-              %17 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 80)>(%arg0)[%workgroup_size_z]
-              %18 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 80)>(%arg1)[%workgroup_size_y]
-              %19 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 96)>(%arg2)[%workgroup_size_x]
-              %20 = affine.min affine_map<(d0)[s0] -> (-d0 + 80, s0)>(%arg0)[%workgroup_size_z]
-              %21 = affine.min affine_map<(d0)[s0] -> (-d0 + 80, s0)>(%arg1)[%workgroup_size_y]
-              %22 = affine.min affine_map<(d0)[s0] -> (-d0 + 96, s0)>(%arg2)[%workgroup_size_x]
-              %23 = linalg.init_tensor [1, %20, %21, %22] : tensor<1x?x?x?xf32>
-              %24 = linalg.fill(%cst, %23) : f32, tensor<1x?x?x?xf32> -> tensor<1x?x?x?xf32>
-              %25 = linalg.depthwise_conv_2d_nhwc_hwc {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%14, %16 : tensor<1x?x?x?xf32>, tensor<3x3x?xf32>) outs(%24 : tensor<1x?x?x?xf32>) -> tensor<1x?x?x?xf32>
-              flow.dispatch.tensor.store %25, %2, offsets = [0, %arg0, %arg1, %arg2], sizes = [1, %17, %18, %19], strides = [1, 1, 1, 1] : tensor<1x?x?x?xf32> -> !flow.dispatch.tensor<writeonly:1x80x80x96xf32>
-            }
-          }
-        }
+        %cst = arith.constant 0.0 : f32
+        %input_binding = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer)
+            : !flow.dispatch.tensor<readonly:1x161x161x96xf32>
+        %filter_binding = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer)
+            : !flow.dispatch.tensor<readonly:3x3x96xf32>
+        %result_binding = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer)
+            : !flow.dispatch.tensor<writeonly:1x80x80x96xf32>
+        %input = flow.dispatch.tensor.load %input_binding, offsets = [0, 0, 0, 0], sizes = [1, 161, 161, 96], strides = [1, 1, 1, 1]
+            : !flow.dispatch.tensor<readonly:1x161x161x96xf32> -> tensor<1x161x161x96xf32>
+        %filter = flow.dispatch.tensor.load %filter_binding, offsets = [0, 0, 0], sizes = [3, 3, 96], strides = [1, 1, 1]
+            : !flow.dispatch.tensor<readonly:3x3x96xf32> -> tensor<3x3x96xf32>
+        %init = linalg.init_tensor [1, 80, 80, 96] : tensor<1x80x80x96xf32>
+        %fill = linalg.fill(%cst, %init) : f32, tensor<1x80x80x96xf32> -> tensor<1x80x80x96xf32>
+        %conv = linalg.depthwise_conv_2d_nhwc_hwc {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>}
+            ins(%input, %filter : tensor<1x161x161x96xf32>, tensor<3x3x96xf32>) outs(%fill : tensor<1x80x80x96xf32>) -> tensor<1x80x80x96xf32>
+        flow.dispatch.tensor.store %conv, %result_binding, offsets = [0, 0, 0, 0], sizes = [1, 80, 80, 96], strides = [1, 1, 1, 1]
+            : tensor<1x80x80x96xf32> -> !flow.dispatch.tensor<writeonly:1x80x80x96xf32>
         return
       }
     }
   }
 }
-
-//  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 64)>
-//  CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 ceildiv 32)>
-//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"CPUDefault", workload_per_wg = [64, 64, 32]>
+//  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering.config<tile_sizes = {{\[}}[0, 20, 40, 48, 0, 0]{{\]}}, native_vector_size = []>
+//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"CPUDefault", workload_per_wg = []>
 //      CHECK: hal.executable.entry_point public @conv_static
 // CHECK-SAME:     translation.info = #[[TRANSLATION]]
-// CHECK-NEXT:   ^bb0(%[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index, %[[ARG2:[a-zA-Z0-9]+]]: index)
-//  CHECK-DAG:     %[[D0:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]
-//  CHECK-DAG:     %[[D1:.+]] = affine.apply #[[MAP0]]()[%[[ARG1]]
-//  CHECK-DAG:     %[[D2:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]]
-//      CHECK:     hal.return %[[D0]], %[[D1]], %[[D2]]
 //      CHECK:     linalg.depthwise_conv_2d_nhwc_hwc
-//  CHECK-NOT:       lowering.config
+// CHECK-SAME:       lowering.config = #[[CONFIG]]
 
 // -----
 
@@ -890,53 +607,32 @@
     hal.executable.entry_point public @generic_static layout(#executable_layout)
     builtin.module {
       func @generic_static() {
-        %c16 = arith.constant 16 : index
-        %c96 = arith.constant 96 : index
-        %c0 = arith.constant 0 : index
-        %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : !flow.dispatch.tensor<readonly:96x16xf32>
-        %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) : !flow.dispatch.tensor<writeonly:16x96xf32>
-        %workgroup_size_x = hal.interface.workgroup.size[0] : index
-        %workgroup_size_y = hal.interface.workgroup.size[1] : index
-        %workgroup_id_x = hal.interface.workgroup.id[0] : index
-        %workgroup_count_x = hal.interface.workgroup.count[0] : index
-        %workgroup_id_y = hal.interface.workgroup.id[1] : index
-        %workgroup_count_y = hal.interface.workgroup.count[1] : index
-        %2 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_y, %workgroup_size_y]
-        %3 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_y, %workgroup_size_y]
-        scf.for %arg0 = %2 to %c16 step %3 {
-          %4 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_x, %workgroup_size_x]
-          %5 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_x, %workgroup_size_x]
-          scf.for %arg1 = %4 to %c96 step %5 {
-            %6 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 96)>(%arg1)[%workgroup_size_x]
-            %7 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 16)>(%arg0)[%workgroup_size_y]
-            %8 = flow.dispatch.tensor.load %0, offsets = [%arg1, %arg0], sizes = [%6, %7], strides = [1, 1] : !flow.dispatch.tensor<readonly:96x16xf32> -> tensor<?x?xf32>
-            %9 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 16)>(%arg0)[%workgroup_size_y]
-            %10 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 96)>(%arg1)[%workgroup_size_x]
-            %11 = linalg.init_tensor [%9, %10] : tensor<?x?xf32>
-            %12 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%8 : tensor<?x?xf32>) outs(%11 : tensor<?x?xf32>) {
-            ^bb0(%arg2: f32, %arg3: f32):  // no predecessors
-              linalg.yield %arg2 : f32
-            } -> tensor<?x?xf32>
-            flow.dispatch.tensor.store %12, %1, offsets = [%arg0, %arg1], sizes = [%9, %10], strides = [1, 1] : tensor<?x?xf32> -> !flow.dispatch.tensor<writeonly:16x96xf32>
-          }
-        }
+        %input_binding = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer)
+            : !flow.dispatch.tensor<readonly:96x16xf32>
+        %result_binding = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer)
+            : !flow.dispatch.tensor<writeonly:16x96xf32>
+        %input = flow.dispatch.tensor.load %input_binding, offsets = [0, 0], sizes = [96, 16], strides = [1, 1]
+            : !flow.dispatch.tensor<readonly:96x16xf32> -> tensor<96x16xf32>
+        %init = linalg.init_tensor [16, 96] : tensor<16x96xf32>
+        %result = linalg.generic {
+            indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>],
+            iterator_types = ["parallel", "parallel"]}
+            ins(%input : tensor<96x16xf32>) outs(%init : tensor<16x96xf32>) {
+            ^bb0(%b0: f32, %b1: f32):  // no predecessors
+              linalg.yield %b0 : f32
+            } -> tensor<16x96xf32>
+        flow.dispatch.tensor.store %result, %result_binding, offsets = [0, 0], sizes = [16, 96], strides = [1, 1]
+            : tensor<16x96xf32> -> !flow.dispatch.tensor<writeonly:16x96xf32>
         return
       }
     }
   }
 }
-//  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering.config<tile_sizes = {{\[}}[], [16, 16], [0, 0]], native_vector_size = []>
-//  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 32)>
-//  CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 ceildiv 16)>
-//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"CPUDoubleTilingExpert", workload_per_wg = [32, 16]>
+//  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering.config<tile_sizes = {{\[}}[16, 32], [16, 16], [0, 0]], native_vector_size = []>
+//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"CPUDoubleTilingExpert", workload_per_wg = []>
 //      CHECK: hal.executable.entry_point public @generic_static
 // CHECK-SAME:     translation.info = #[[TRANSLATION]]
-// CHECK-NEXT:   ^bb0(%[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index, %[[ARG2:[a-zA-Z0-9]+]]: index)
-//  CHECK-DAG:     %[[C1:.+]] = arith.constant 1 : index
-//  CHECK-DAG:     %[[D0:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]
-//  CHECK-DAG:     %[[D1:.+]] = affine.apply #[[MAP1]]()[%[[ARG1]]
-//      CHECK:     hal.return %[[D0]], %[[D1]], %[[C1]]
-//      CHECK:     linalg.generic
+//      CHECK:   linalg.generic
 //      CHECK:       lowering.config = #[[CONFIG]]
 
 // -----
@@ -957,54 +653,33 @@
     hal.executable.entry_point public @matmul_static layout(#executable_layout)
     builtin.module {
       func @matmul_static() {
-        %cst = arith.constant 0.000000e+00 : f32
-        %c196 = arith.constant 196 : index
-        %c40 = arith.constant 40 : index
-        %c0 = arith.constant 0 : index
-        %c8 = arith.constant 8 : index
-        %c28 = arith.constant 28 : index
-        %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : !flow.dispatch.tensor<readonly:196x240xf32>
-        %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) : !flow.dispatch.tensor<readonly:240x40xf32>
-        %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) : !flow.dispatch.tensor<writeonly:196x40xf32>
-        %workgroup_id_x = hal.interface.workgroup.id[0] : index
-        %workgroup_count_x = hal.interface.workgroup.count[0] : index
-        %workgroup_size_x = hal.interface.workgroup.size[0] : index
-        %workgroup_id_y = hal.interface.workgroup.id[1] : index
-        %workgroup_count_y = hal.interface.workgroup.count[1] : index
-        %workgroup_size_y = hal.interface.workgroup.size[1] : index
-        %3 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_y, %workgroup_size_y]
-        %4 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_y, %workgroup_size_y]
-        scf.for %arg0 = %3 to %c196 step %4 {
-          %5 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_x, %workgroup_size_x]
-          %6 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_x, %workgroup_size_x]
-          scf.for %arg1 = %5 to %c40 step %6 {
-            %7 = flow.dispatch.tensor.load %0, offsets = [%arg0, 0], sizes = [28, 240], strides = [1, 1] : !flow.dispatch.tensor<readonly:196x240xf32> -> tensor<28x240xf32>
-            %8 = tensor.cast %7 : tensor<28x240xf32> to tensor<?x240xf32>
-            %9 = flow.dispatch.tensor.load %1, offsets = [0, %arg1], sizes = [240, 8], strides = [1, 1] : !flow.dispatch.tensor<readonly:240x40xf32> -> tensor<240x8xf32>
-            %10 = tensor.cast %9 : tensor<240x8xf32> to tensor<240x?xf32>
-            %11 = linalg.init_tensor [%c28, %c8] : tensor<?x?xf32>
-            %12 = linalg.fill(%cst, %11) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
-            %13 = linalg.matmul ins(%8, %10 : tensor<?x240xf32>, tensor<240x?xf32>) outs(%12 : tensor<?x?xf32>) -> tensor<?x?xf32>
-            flow.dispatch.tensor.store %13, %2, offsets = [%arg0, %arg1], sizes = [%c28, %c8], strides = [1, 1] : tensor<?x?xf32> -> !flow.dispatch.tensor<writeonly:196x40xf32>
-          }
-        }
+        %cst = arith.constant 0.0 : f32
+        %lhs_binding = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer)
+            : !flow.dispatch.tensor<readonly:196x240xf32>
+        %rhs_binding = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer)
+            : !flow.dispatch.tensor<readonly:240x40xf32>
+        %result_binding = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer)
+            : !flow.dispatch.tensor<writeonly:196x40xf32>
+        %lhs = flow.dispatch.tensor.load %lhs_binding, offsets = [0, 0], sizes = [196, 240], strides = [1, 1]
+            : !flow.dispatch.tensor<readonly:196x240xf32> -> tensor<196x240xf32>
+        %rhs = flow.dispatch.tensor.load %rhs_binding, offsets = [0, 0], sizes = [240, 40], strides = [1, 1]
+            : !flow.dispatch.tensor<readonly:240x40xf32> -> tensor<240x40xf32>
+        %init = linalg.init_tensor [196, 40] : tensor<196x40xf32>
+        %fill = linalg.fill(%cst, %init) : f32, tensor<196x40xf32> -> tensor<196x40xf32>
+        %gemm = linalg.matmul ins(%lhs, %rhs : tensor<196x240xf32>, tensor<240x40xf32>)
+            outs(%fill : tensor<196x40xf32>) -> tensor<196x40xf32>
+        flow.dispatch.tensor.store %gemm, %result_binding, offsets = [0, 0], sizes = [196, 40], strides = [1, 1]
+            : tensor<196x40xf32> -> !flow.dispatch.tensor<writeonly:196x40xf32>
         return
       }
     }
   }
 }
 
-//   CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering.config<tile_sizes = {{\[}}[], [4, 4, 60], [4, 4, 4]{{\]}}, native_vector_size = [4, 4, 4]>
-//   CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 8)>
-//   CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 ceildiv 28)>
-//   CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"CPUTileFuseAndVectorize", workload_per_wg = [8, 28]>
+//   CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering.config<tile_sizes = {{\[}}[28, 8, 0], [4, 4, 60], [4, 4, 4]{{\]}}, native_vector_size = [4, 4, 4]>
+//   CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"CPUTileFuseAndVectorize", workload_per_wg = []>
 //       CHECK: hal.executable.entry_point public @matmul_static
 //  CHECK-SAME:     translation.info = #[[TRANSLATION]]
-//  CHECK-NEXT:   ^bb0(%[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index, %[[ARG2:[a-zA-Z0-9]+]]: index)
-//   CHECK-DAG:     %[[C1:.+]] = arith.constant 1 : index
-//   CHECK-DAG:     %[[D0:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]]
-//   CHECK-DAG:     %[[D1:.+]] = affine.apply #[[MAP1]]()[%[[ARG1]]]
-//       CHECK:     hal.return %[[D0]], %[[D1]], %[[C1]]
 //       CHECK: linalg.matmul
 //  CHECK-SAME:     lowering.config = #[[CONFIG]]
 
@@ -1027,299 +702,35 @@
     builtin.module {
       func @restrict_num_workgroups() {
         %cst = arith.constant 0.000000e+00 : f32
-        %c7 = arith.constant 7 : index
-        %c576 = arith.constant 576 : index
-        %c0 = arith.constant 0 : index
-        %c64 = arith.constant 64 : index
-        %c2 = arith.constant 2 : index
-        %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : !flow.dispatch.tensor<readonly:1x11x11x576xf32>
-        %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) : !flow.dispatch.tensor<readonly:5x5x576xf32>
-        %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) : !flow.dispatch.tensor<writeonly:1x7x7x576xf32>
-        %workgroup_id_x = hal.interface.workgroup.id[0] : index
-        %workgroup_count_x = hal.interface.workgroup.count[0] : index
-        %workgroup_size_x = hal.interface.workgroup.size[0] : index
-        %workgroup_id_y = hal.interface.workgroup.id[1] : index
-        %workgroup_count_y = hal.interface.workgroup.count[1] : index
-        %workgroup_size_y = hal.interface.workgroup.size[1] : index
-        %workgroup_id_z = hal.interface.workgroup.id[2] : index
-        %workgroup_count_z = hal.interface.workgroup.count[2] : index
-        %workgroup_size_z = hal.interface.workgroup.size[2] : index
-        %3 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_z, %workgroup_size_z]
-        %4 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_z, %workgroup_size_z]
-        scf.for %arg0 = %3 to %c7 step %4 {
-          %5 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_y, %workgroup_size_y]
-          %6 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_y, %workgroup_size_y]
-          scf.for %arg1 = %5 to %c7 step %6 {
-            %7 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_x, %workgroup_size_x]
-            %8 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_x, %workgroup_size_x]
-            scf.for %arg2 = %7 to %c576 step %8 {
-              %9 = affine.min affine_map<(d0) -> (6, -d0 + 12)>(%arg0)
-              %10 = affine.min affine_map<(d0) -> (11, -d0 + 12)>(%arg1)
-              %11 = flow.dispatch.tensor.load %0, offsets = [0, %arg0, %arg1, %arg2], sizes = [1, %9, %10, 64], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:1x11x11x576xf32> -> tensor<1x?x?x64xf32>
-              %12 = tensor.cast %11 : tensor<1x?x?x64xf32> to tensor<1x?x?x?xf32>
-              %13 = flow.dispatch.tensor.load %1, offsets = [0, 0, %arg2], sizes = [5, 5, 64], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:5x5x576xf32> -> tensor<5x5x64xf32>
-              %14 = tensor.cast %13 : tensor<5x5x64xf32> to tensor<5x5x?xf32>
-              %15 = affine.min affine_map<(d0) -> (2, -d0 + 7)>(%arg0)
-              %16 = affine.min affine_map<(d0) -> (-d0 + 7, 2)>(%arg0)
-              %17 = linalg.init_tensor [1, %16, %c7, %c64] : tensor<1x?x?x?xf32>
-              %18 = linalg.fill(%cst, %17) : f32, tensor<1x?x?x?xf32> -> tensor<1x?x?x?xf32>
-              %19 = linalg.depthwise_conv_2d_nhwc_hwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%12, %14 : tensor<1x?x?x?xf32>, tensor<5x5x?xf32>) outs(%18 : tensor<1x?x?x?xf32>) -> tensor<1x?x?x?xf32>
-              flow.dispatch.tensor.store %19, %2, offsets = [0, %arg0, %arg1, %arg2], sizes = [1, %15, %c7, %c64], strides = [1, 1, 1, 1] : tensor<1x?x?x?xf32> -> !flow.dispatch.tensor<writeonly:1x7x7x576xf32>
-            }
-          }
-        }
+        %input_binding = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer)
+            : !flow.dispatch.tensor<readonly:1x11x11x576xf32>
+        %filter_binding = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer)
+            : !flow.dispatch.tensor<readonly:5x5x576xf32>
+        %result_binding = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer)
+            : !flow.dispatch.tensor<writeonly:1x7x7x576xf32>
+        %input = flow.dispatch.tensor.load %input_binding, offsets = [0, 0, 0, 0], sizes = [1, 11, 11, 576], strides = [1, 1, 1, 1]
+            : !flow.dispatch.tensor<readonly:1x11x11x576xf32> -> tensor<1x11x11x576xf32>
+        %filter = flow.dispatch.tensor.load %filter_binding, offsets = [0, 0, 0], sizes = [5, 5, 576], strides = [1, 1, 1]
+            : !flow.dispatch.tensor<readonly:5x5x576xf32> -> tensor<5x5x576xf32>
+        %init = linalg.init_tensor [1, 7, 7, 576] : tensor<1x7x7x576xf32>
+        %fill = linalg.fill(%cst, %init) : f32, tensor<1x7x7x576xf32> -> tensor<1x7x7x576xf32>
+        %conv = linalg.depthwise_conv_2d_nhwc_hwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
+            ins(%input, %filter : tensor<1x11x11x576xf32>, tensor<5x5x576xf32>)
+            outs(%fill : tensor<1x7x7x576xf32>) -> tensor<1x7x7x576xf32>
+        flow.dispatch.tensor.store %conv, %result_binding, offsets = [0, 0, 0, 0], sizes = [1, 7, 7, 576], strides = [1, 1, 1, 1]
+            : tensor<1x7x7x576xf32> -> !flow.dispatch.tensor<writeonly:1x7x7x576xf32>
         return
       }
     }
   }
 }
-
-//   CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 64)>
-//   CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 ceildiv 8)>
-//   CHECK-DAG: #[[MAP2:.+]] = affine_map<()[s0] -> (s0 ceildiv 4)>
-//   CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"CPUDefault", workload_per_wg = [64, 8, 4]>
+//   CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering.config<tile_sizes = {{\[}}[0, 1, 7, 64, 0, 0]{{\]}}, native_vector_size = []>
+//   CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"CPUDefault", workload_per_wg = []>
 //       CHECK: hal.executable.entry_point public @restrict_num_workgroups
 //  CHECK-SAME:     translation.info = #[[TRANSLATION]]
-//  CHECK-NEXT:   ^bb0(%[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index, %[[ARG2:[a-zA-Z0-9]+]]: index)
-//   CHECK-DAG:     %[[D0:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]]
-//   CHECK-DAG:     %[[D1:.+]] = affine.apply #[[MAP1]]()[%[[ARG1]]]
-//   CHECK-DAG:     %[[D2:.+]] = affine.apply #[[MAP2]]()[%[[ARG2]]]
-//       CHECK:     hal.return %[[D0]], %[[D1]], %[[D2]]
+//       CHECK: linalg.depthwise_conv_2d_nhwc_hwc
+//  CHECK-SAME:     lowering.config = #[[CONFIG]]
 
-// -----
-
-#executable_layout = #hal.executable.layout<push_constants = 3, sets = [
-  #hal.descriptor_set.layout<0, bindings = [
-    #hal.descriptor_set.binding<0, storage_buffer>,
-    #hal.descriptor_set.binding<1, storage_buffer>
-  ]>
-]>
-hal.executable private @test_exp_0 {
-  hal.executable.variant public @system_elf_arm_64, target = <"llvm", "system-elf-arm_64", {
-    data_layout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128",
-    native_vector_size = 16 : index,
-    target_triple = "aarch64-none-linux-android30"
-  }> {
-    hal.executable.entry_point public @test_exp_0 layout(#executable_layout)
-    builtin.module {
-      func @test_exp_0() {
-        %c0 = arith.constant 0 : index
-        %size = hal.interface.workgroup.size[0] : index
-        %count = hal.interface.workgroup.count[0] : index
-        %id = hal.interface.workgroup.id[0] : index
-        %lb = hal.interface.constant.load[0] : index
-        %ub = hal.interface.constant.load[1] : index
-        %step = hal.interface.constant.load[2] : index
-        %read = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : memref<?xf32>{%ub}
-        %write = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) : memref<?xf32>{%ub}
-        %offset = affine.apply affine_map<(d0)[s0,s1] -> (d0 + s0 * s1)>(%lb)[%id, %size]
-        %stride = affine.apply affine_map<(d0)[s0,s1] -> (d0 * s0 * s1)>(%step)[%count, %size]
-        scf.for %iv = %offset to %ub step %stride {
-          %val = memref.load %read[%iv] : memref<?xf32>
-          memref.store %val, %write[%iv] : memref<?xf32>
-        }
-        return
-      }
-    }
-  }
-}
-
-//  CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0] -> (s0 ceildiv 64)>
-//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"CPUDefault", workload_per_wg = [64]>
-//      CHECK: hal.executable.entry_point public @test_exp_0
-// CHECK-SAME:     translation.info = #[[TRANSLATION]]
-// CHECK-NEXT:   ^bb0(%[[ARG0:[a-zA-Z0-9]+]]: index
-//  CHECK-DAG:     %[[C1:.+]] = arith.constant 1 : index
-//  CHECK-DAG:     %[[D0:.+]] = affine.apply #[[MAP]]()[%[[ARG0]]]
-//      CHECK:     hal.return %[[D0]], %[[C1]], %[[C1]]
-
-// -----
-
-#executable_layout = #hal.executable.layout<push_constants = 3, sets = [
-  #hal.descriptor_set.layout<0, bindings = [
-    #hal.descriptor_set.binding<0, storage_buffer>,
-    #hal.descriptor_set.binding<1, storage_buffer>
-  ]>
-]>
-hal.executable private @test_exp_1 {
-  hal.executable.variant public @system_elf_arm_64, target = <"llvm", "system-elf-arm_64", {
-    data_layout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128",
-    native_vector_size = 16 : index,
-    target_triple = "aarch64-none-linux-android30"
-  }> {
-    hal.executable.entry_point public @test_exp_1 layout(#executable_layout)
-    builtin.module {
-      func @test_exp_1() {
-        %c0 = arith.constant 0 : index
-        %size = hal.interface.workgroup.size[0] : index
-        %count = hal.interface.workgroup.count[0] : index
-        %id = hal.interface.workgroup.id[0] : index
-        %lb = hal.interface.constant.load[0] : index
-        %ub = hal.interface.constant.load[1] : index
-        %step = hal.interface.constant.load[2] : index
-        %read = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : memref<?xf32>{%ub}
-        %write = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) : memref<?xf32>{%ub}
-        %offset = affine.apply affine_map<()[s0,s1] -> (5 + s0 * s1)>()[%id, %size]
-        %stride = affine.apply affine_map<(d0)[s0,s1] -> (s0 * d0 * s1)>(%step)[%count, %size]
-        scf.for %iv = %offset to %ub step %stride {
-          %val = memref.load %read[%iv] : memref<?xf32>
-          memref.store %val, %write[%iv] : memref<?xf32>
-        }
-        return
-      }
-    }
-  }
-}
-
-//  CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0] -> (s0 ceildiv 64)>
-//  CHECk-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"CPUDefault", workload_per_wg = [64]>
-//      CHECK: hal.executable.entry_point public @test_exp_1
-// CHECK-SAME:     translation.info = #[[TRANSLATION]]
-// CHECK-NEXT:   ^bb0(%[[ARG0:[a-zA-Z0-9]+]]: index
-//  CHECK-DAG:     %[[C1:.+]] = arith.constant 1 : index
-//  CHECK-DAG:     %[[D0:.+]] = affine.apply #[[MAP]]()[%[[ARG0]]]
-//      CHECK:     hal.return %[[D0]], %[[C1]], %[[C1]]
-
-// -----
-
-#executable_layout = #hal.executable.layout<push_constants = 3, sets = [
-  #hal.descriptor_set.layout<0, bindings = [
-    #hal.descriptor_set.binding<0, storage_buffer>,
-    #hal.descriptor_set.binding<1, storage_buffer>
-  ]>
-]>
-hal.executable private @test_exp_2 {
-  hal.executable.variant public @system_elf_arm_64, target = <"llvm", "system-elf-arm_64", {
-    data_layout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128",
-    native_vector_size = 16 : index,
-    target_triple = "aarch64-none-linux-android30"
-  }> {
-    hal.executable.entry_point public @test_exp_3 layout(#executable_layout)
-    builtin.module {
-      func @test_exp_3() {
-        %c0 = arith.constant 0 : index
-        %size = hal.interface.workgroup.size[0] : index
-        %count = hal.interface.workgroup.count[0] : index
-        %id = hal.interface.workgroup.id[0] : index
-        %lb = hal.interface.constant.load[0] : index
-        %ub = hal.interface.constant.load[1] : index
-        %step = hal.interface.constant.load[2] : index
-        %read = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : memref<?xf32>{%ub}
-        %write = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) : memref<?xf32>{%ub}
-        %offset = affine.apply affine_map<(d0)[s0,s1] -> (d0 + s0 * s1)>(%lb)[%id, %size]
-        %stride = affine.apply affine_map<()[s0,s1] -> (5 * s0 * s1)>()[%count, %size]
-        scf.for %iv = %offset to %ub step %stride {
-          %val = memref.load %read[%iv] : memref<?xf32>
-          memref.store %val, %write[%iv] : memref<?xf32>
-        }
-        return
-      }
-    }
-  }
-}
-
-//  CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0] -> (s0 ceildiv 64)>
-//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"CPUDefault", workload_per_wg = [64]>
-//      CHECK: hal.executable.entry_point public @test_exp_3
-// CHECK-SAME:     translation.info = #[[TRANSLATION]]
-// CHECK-NEXT:   ^bb0(%[[ARG0:[a-zA-Z0-9]+]]: index
-//  CHECK-DAG:     %[[C1:.+]] = arith.constant 1 : index
-//  CHECK-DAG:     %[[D0:.+]] = affine.apply #[[MAP]]()[%[[ARG0]]]
-//      CHECK:     hal.return %[[D0]], %[[C1]], %[[C1]]
-
-// -----
-
-#executable_layout = #hal.executable.layout<push_constants = 3, sets = [
-  #hal.descriptor_set.layout<0, bindings = [
-    #hal.descriptor_set.binding<0, storage_buffer>,
-    #hal.descriptor_set.binding<1, storage_buffer>
-  ]>
-]>
-hal.executable private @test_exp_3 {
-  hal.executable.variant public @system_elf_arm_64, target = <"llvm", "system-elf-arm_64", {
-    data_layout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128",
-    native_vector_size = 16 : index,
-    target_triple = "aarch64-none-linux-android30"
-  }> {
-    hal.executable.entry_point public @test_exp_4 layout(#executable_layout)
-    builtin.module {
-      func @test_exp_4() {
-        %c0 = arith.constant 0 : index
-        %size = hal.interface.workgroup.size[0] : index
-        %count = hal.interface.workgroup.count[0] : index
-        %id = hal.interface.workgroup.id[0] : index
-        %lb = hal.interface.constant.load[0] : index
-        %ub = hal.interface.constant.load[1] : index
-        %step = hal.interface.constant.load[2] : index
-        %read = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : memref<?xf32>{%ub}
-        %write = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) : memref<?xf32>{%ub}
-        %offset = affine.apply affine_map<(d0)[s0,s1] -> (s0 * s1 + d0)>(%lb)[%id, %size]
-        %stride = affine.apply affine_map<()[s0,s1] -> (s0 * 5 * s1)>()[%count, %size]
-        scf.for %iv = %offset to %ub step %stride {
-          %val = memref.load %read[%iv] : memref<?xf32>
-          memref.store %val, %write[%iv] : memref<?xf32>
-        }
-        return
-      }
-    }
-  }
-}
-
-//  CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0] -> (s0 ceildiv 64)>
-//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"CPUDefault", workload_per_wg = [64]>
-//      CHECK: hal.executable.entry_point public @test_exp_4
-// CHECK-SAME:     translation.info = #[[TRANSLATION]]
-// CHECK-NEXT:   ^bb0(%[[ARG0:[a-zA-Z0-9]+]]: index
-//  CHECK-DAG:     %[[C1:.+]] = arith.constant 1 : index
-//  CHECK-DAG:     %[[D0:.+]] = affine.apply #[[MAP]]()[%[[ARG0]]]
-//      CHECK:     hal.return %[[D0]], %[[C1]], %[[C1]]
-
-// -----
-
-#executable_layout = #hal.executable.layout<push_constants = 3, sets = [
-  #hal.descriptor_set.layout<0, bindings = [
-    #hal.descriptor_set.binding<0, storage_buffer>,
-    #hal.descriptor_set.binding<1, storage_buffer>
-  ]>
-]>
-hal.executable private @test_exp_4 {
-  hal.executable.variant public @system_elf_arm_64, target = <"llvm", "system-elf-arm_64", {
-    data_layout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128",
-    native_vector_size = 16 : index,
-    target_triple = "aarch64-none-linux-android30"
-  }> {
-    hal.executable.entry_point public @test_exp_5 layout(#executable_layout)
-    builtin.module {
-      func @test_exp_5() {
-        %c0 = arith.constant 0 : index
-        %size = hal.interface.workgroup.size[0] : index
-        %count = hal.interface.workgroup.count[0] : index
-        %id = hal.interface.workgroup.id[0] : index
-        %lb = hal.interface.constant.load[0] : index
-        %ub = hal.interface.constant.load[1] : index
-        %step = hal.interface.constant.load[2] : index
-        %read = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : memref<?xf32>{%ub}
-        %write = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) : memref<?xf32>{%ub}
-        %offset = affine.apply affine_map<()[s0,s1] -> (s0 * s1 + 5)>()[%id, %size]
-        %stride = affine.apply affine_map<()[s0,s1] -> (s0 * s1 * 5)>()[%count, %size]
-        scf.for %iv = %offset to %ub step %stride {
-          %val = memref.load %read[%iv] : memref<?xf32>
-          memref.store %val, %write[%iv] : memref<?xf32>
-        }
-        return
-      }
-    }
-  }
-}
-
-//  CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0] -> (s0 ceildiv 64)>
-//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"CPUDefault", workload_per_wg = [64]>
-//      CHECK: hal.executable.entry_point public @test_exp_5
-// CHECK-SAME:     translation.info = #[[TRANSLATION]]
-// CHECK-NEXT:   ^bb0(%[[ARG0:[a-zA-Z0-9]+]]: index
-//  CHECK-DAG:     %[[C1:.+]] = arith.constant 1 : index
-//  CHECK-DAG:     %[[D0:.+]] = affine.apply #[[MAP]]()[%[[ARG0]]]
-//      CHECK:     hal.return %[[D0]], %[[C1]], %[[C1]]
 
 // -----
 
@@ -1341,46 +752,90 @@
     hal.executable.entry_point public @matmul_x86 layout(#executable_layout)
     builtin.module {
       func @matmul_x86() {
-        %c128 = arith.constant 128 : index
-        %c384 = arith.constant 384 : index
-        %cst = arith.constant 0.000000e+00 : f32
-        %c0 = arith.constant 0 : index
-        %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : !flow.dispatch.tensor<readonly:384x512xf32>
-        %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) : !flow.dispatch.tensor<readonly:512x128xf32>
-        %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) : !flow.dispatch.tensor<writeonly:384x128xf32>
-        %workgroup_size_x = hal.interface.workgroup.size[0] : index
-        %workgroup_size_y = hal.interface.workgroup.size[1] : index
-        %workgroup_id_x = hal.interface.workgroup.id[0] : index
-        %workgroup_count_x = hal.interface.workgroup.count[0] : index
-        %workgroup_id_y = hal.interface.workgroup.id[1] : index
-        %workgroup_count_y = hal.interface.workgroup.count[1] : index
-        %3 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_y, %workgroup_size_y]
-        %4 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_y, %workgroup_size_y]
-        scf.for %arg0 = %3 to %c384 step %4 {
-          %5 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_x, %workgroup_size_x]
-          %6 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_x, %workgroup_size_x]
-          scf.for %arg1 = %5 to %c128 step %6 {
-            %7 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 384)>(%arg0)[%workgroup_size_y]
-            %8 = flow.dispatch.tensor.load %0, offsets = [%arg0, 0], sizes = [%7, 512], strides = [1, 1] : !flow.dispatch.tensor<readonly:384x512xf32> -> tensor<?x512xf32>
-            %9 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 128)>(%arg1)[%workgroup_size_x]
-            %10 = flow.dispatch.tensor.load %1, offsets = [0, %arg1], sizes = [512, %9], strides = [1, 1] : !flow.dispatch.tensor<readonly:512x128xf32> -> tensor<512x?xf32>
-            %11 = affine.min affine_map<(d0)[s0] -> (-d0 + 384, s0)>(%arg0)[%workgroup_size_y]
-            %12 = affine.min affine_map<(d0)[s0] -> (-d0 + 128, s0)>(%arg1)[%workgroup_size_x]
-            %13 = linalg.init_tensor [%11, %12] : tensor<?x?xf32>
-            %14 = linalg.fill(%cst, %13) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
-            %15 = linalg.matmul ins(%8, %10 : tensor<?x512xf32>, tensor<512x?xf32>) outs(%14 : tensor<?x?xf32>) -> tensor<?x?xf32>
-            flow.dispatch.tensor.store %15, %2, offsets = [%arg0, %arg1], sizes = [%7, %9], strides = [1, 1] : tensor<?x?xf32> -> !flow.dispatch.tensor<writeonly:384x128xf32>
-          }
-        }
+        %cst = arith.constant 0.0 : f32
+        %lhs_binding = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : !flow.dispatch.tensor<readonly:384x512xf32>
+        %rhs_binding = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) : !flow.dispatch.tensor<readonly:512x128xf32>
+        %result_binding = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) : !flow.dispatch.tensor<writeonly:384x128xf32>
+        %lhs = flow.dispatch.tensor.load %lhs_binding, offsets = [0, 0], sizes = [384, 512], strides = [1, 1]
+            : !flow.dispatch.tensor<readonly:384x512xf32> -> tensor<384x512xf32>
+        %rhs = flow.dispatch.tensor.load %rhs_binding, offsets = [0, 0], sizes = [512, 128], strides = [1, 1]
+            : !flow.dispatch.tensor<readonly:512x128xf32> -> tensor<512x128xf32>
+        %init = linalg.init_tensor [384, 128] : tensor<384x128xf32>
+        %fill = linalg.fill(%cst, %init) : f32, tensor<384x128xf32> -> tensor<384x128xf32>
+        %gemm = linalg.matmul ins(%lhs, %rhs : tensor<384x512xf32>, tensor<512x128xf32>)
+            outs(%fill : tensor<384x128xf32>) -> tensor<384x128xf32>
+        flow.dispatch.tensor.store %gemm, %result_binding, offsets = [0, 0], sizes = [384, 128], strides = [1, 1]
+            : tensor<384x128xf32> -> !flow.dispatch.tensor<writeonly:384x128xf32>
         return
       }
     }
   }
 }
 
-//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"CPUDoubleTilingExpert", workload_per_wg = [64, 64]>
-//  CHECK-DAG: #[[CONFIG:.+]] =  #iree_codegen.lowering.config<tile_sizes = [{{\[}}], [8, 32, 0], [0, 0, 16]], native_vector_size = []>
-//  CHECK:       linalg.matmul {lowering.config = #[[CONFIG]]}
+//  CHECK-DAG: #[[CONFIG:.+]] =  #iree_codegen.lowering.config<tile_sizes = {{\[}}[64, 64, 0], [8, 32, 0], [0, 0, 16]{{\]}}, native_vector_size = []>
+//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"CPUDoubleTilingExpert", workload_per_wg = []>
+//      CHECK: hal.executable.entry_point public @matmul_x86
+// CHECK-SAME:     translation.info = #[[TRANSLATION]]
+//      CHECK: linalg.matmul
+// CHECK-SAME:     lowering.config = #[[CONFIG]]
+
+// -----
+
+#executable_layout = #hal.executable.layout<push_constants = 4, sets = [
+  #hal.descriptor_set.layout<0, bindings = [
+    #hal.descriptor_set.binding<0, storage_buffer>,
+    #hal.descriptor_set.binding<1, storage_buffer>,
+    #hal.descriptor_set.binding<2, storage_buffer>
+  ]>
+]>
+#executable_target_embedded_elf_x86_64_ = #hal.executable.target<
+  "llvm", "embedded-elf-x86_64", {
+    data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128",
+    native_vector_size = 16 : index,
+    target_triple = "x86_64-unknown-unknown-eabi-elf"
+  }
+>
+hal.executable private @reduction {
+  hal.executable.variant public @embedded_elf_x86_64, target = #executable_target_embedded_elf_x86_64_ {
+    hal.executable.entry_point public @predict_dispatch_86 ordinal(0) layout(#executable_layout)
+    builtin.module  {
+      func @predict_dispatch_86(%arg0: !flow.dispatch.tensor<readonly:7x7x2048xf32>,
+          %arg1: !flow.dispatch.tensor<writeonly:7xf32>) {
+        %cst = arith.constant 0.0 : f32
+        %cst1 = arith.constant 10.0 : f32
+        %input = flow.dispatch.tensor.load %arg0, offsets = [0, 0, 0], sizes = [7, 7, 2048], strides = [1, 1, 1]
+            : !flow.dispatch.tensor<readonly:7x7x2048xf32> -> tensor<7x7x2048xf32>
+        %init = linalg.init_tensor [7] : tensor<7xf32>
+        %fill = linalg.fill(%cst, %init) : f32, tensor<7xf32> -> tensor<7xf32> 
+        %reduce = linalg.generic {
+            indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0)>],
+            iterator_types = ["parallel", "reduction", "reduction"]}
+            ins(%input : tensor<7x7x2048xf32>) outs(%fill : tensor<7xf32>) {
+            ^bb0(%b0: f32, %b1: f32):
+              %addf = arith.addf %b0, %b1 : f32
+              linalg.yield %addf : f32
+            } -> tensor<7xf32>
+        %generic = linalg.generic {
+            indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
+            iterator_types = ["parallel"]}
+            ins(%reduce : tensor<7xf32>) outs(%init : tensor<7xf32>) {
+            ^bb0(%b0: f32, %b1: f32):
+              %11 = arith.divf %b0, %cst1 : f32
+              linalg.yield %11 : f32
+            } -> tensor<7xf32>
+        flow.dispatch.tensor.store %generic, %arg1, offsets = [0], sizes = [7], strides = [1]
+            : tensor<7xf32> -> !flow.dispatch.tensor<writeonly:7xf32>
+        return
+      }
+    }
+  }
+}
+//  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering.config<tile_sizes = {{\[}}[4, 0, 0], [4, 0, 0], [0, 1, 4]{{\]}}, native_vector_size = []>
+//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"CPUDoubleTilingExpert", workload_per_wg = []>
+//      CHECK: hal.executable.entry_point public @predict_dispatch_86
+// CHECK-SAME:     translation.info = #[[TRANSLATION]]
+//      CHECK: linalg.generic {indexing_maps = [#{{.+}}, #{{.+}}], iterator_types = ["parallel", "reduction", "reduction"]}
+// CHECK-SAME:     lowering.config = #[[CONFIG]]
 
 // -----
 
@@ -1403,51 +858,36 @@
     builtin.module {
       func @matmul_i8_i8_i32() {
         %c0 = arith.constant 0 : index
-        %0 = hal.interface.constant.load[0] : i32
-        %1 = hal.interface.constant.load[1] : i32
-        %2 = hal.interface.constant.load[2] : i32
-        %3 = hal.interface.constant.load[3] : i32
-        %4 = hal.interface.constant.load[4] : i32
-        %5 = hal.interface.constant.load[5] : i32
-        %6 = arith.index_cast %0 : i32 to index
-        %7 = arith.index_cast %1 : i32 to index
-        %8 = arith.index_cast %2 : i32 to index
-        %9 = arith.index_cast %3 : i32 to index
-        %10 = arith.index_cast %4 : i32 to index
-        %11 = arith.index_cast %5 : i32 to index
-        %12 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(32) : !flow.dispatch.tensor<readonly:?x?xi8>{%6, %7}
-        %13 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(32) : !flow.dispatch.tensor<readonly:?x?xi8>{%8, %9}
-        %14 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) offset(%c0) alignment(32) : !flow.dispatch.tensor<readwrite:?x?xi32>{%10, %11}
-        %workgroup_size_x = hal.interface.workgroup.size[0] : index
-        %workgroup_size_y = hal.interface.workgroup.size[1] : index
-        %workgroup_id_x = hal.interface.workgroup.id[0] : index
-        %workgroup_count_x = hal.interface.workgroup.count[0] : index
-        %workgroup_id_y = hal.interface.workgroup.id[1] : index
-        %workgroup_count_y = hal.interface.workgroup.count[1] : index
-        %15 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_y, %workgroup_size_y]
-        %16 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_y, %workgroup_size_y]
-        scf.for %arg0 = %15 to %6 step %16 {
-          %17 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_x, %workgroup_size_x]
-          %18 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_x, %workgroup_size_x]
-          scf.for %arg1 = %17 to %9 step %18 {
-            %19 = affine.min affine_map<(d0)[s0, s1] -> (s1, -d0 + s0)>(%arg0)[%6, %workgroup_size_y]
-            %20 = flow.dispatch.tensor.load %12, offsets = [%arg0, 0], sizes = [%19, %7], strides = [1, 1] : !flow.dispatch.tensor<readonly:?x?xi8>{%6, %7} -> tensor<?x?xi8>
-            %21 = affine.min affine_map<(d0)[s0, s1] -> (s1, -d0 + s0)>(%arg1)[%9, %workgroup_size_x]
-            %22 = flow.dispatch.tensor.load %13, offsets = [0, %arg1], sizes = [%8, %21], strides = [1, 1] : !flow.dispatch.tensor<readonly:?x?xi8>{%8, %9} -> tensor<?x?xi8>
-            %23 = flow.dispatch.tensor.load %14, offsets = [%arg0, %arg1], sizes = [%19, %21], strides = [1, 1] : !flow.dispatch.tensor<readwrite:?x?xi32>{%10, %11} -> tensor<?x?xi32>
-            %24 = linalg.matmul ins(%20, %22 : tensor<?x?xi8>, tensor<?x?xi8>) outs(%23 : tensor<?x?xi32>) -> tensor<?x?xi32>
-            flow.dispatch.tensor.store %24, %14, offsets = [%arg0, %arg1], sizes = [%19, %21], strides = [1, 1] : tensor<?x?xi32> -> !flow.dispatch.tensor<readwrite:?x?xi32>{%10, %11}
-          }
-        }
+        %M = hal.interface.constant.load[0] : index
+        %N = hal.interface.constant.load[1] : index
+        %K = hal.interface.constant.load[2] : index
+        %lhs_binding = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(32)
+            : !flow.dispatch.tensor<readonly:?x?xi8>{%M, %K}
+        %rhs_binding = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(32)
+            : !flow.dispatch.tensor<readonly:?x?xi8>{%K, %N}
+        %result_binding = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) offset(%c0) alignment(32)
+            : !flow.dispatch.tensor<readwrite:?x?xi32>{%M, %N}
+        %lhs = flow.dispatch.tensor.load %lhs_binding, offsets = [0, 0], sizes = [%M, %K], strides = [1, 1]
+            : !flow.dispatch.tensor<readonly:?x?xi8>{%M, %K} -> tensor<?x?xi8>
+        %rhs = flow.dispatch.tensor.load %rhs_binding, offsets = [0, 0], sizes = [%K, %N], strides = [1, 1]
+            : !flow.dispatch.tensor<readonly:?x?xi8>{%K, %N} -> tensor<?x?xi8>
+        %init = flow.dispatch.tensor.load %result_binding, offsets = [0, 0], sizes = [%M, %N], strides = [1, 1]
+            : !flow.dispatch.tensor<readwrite:?x?xi32>{%M, %N} -> tensor<?x?xi32>
+        %gemm = linalg.matmul ins(%lhs, %rhs : tensor<?x?xi8>, tensor<?x?xi8>) outs(%init : tensor<?x?xi32>) -> tensor<?x?xi32>
+        flow.dispatch.tensor.store %gemm, %result_binding, offsets = [0, 0], sizes = [%M, %N], strides = [1, 1]
+            : tensor<?x?xi32> -> !flow.dispatch.tensor<readwrite:?x?xi32>{%M, %N}
         return
       }
     }
   }
 }
 
-//  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering.config<tile_sizes = {{\[}}[], [8, 8, 8], [1, 4, 4]{{\]}}, native_vector_size = [1, 4, 4]>
-//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"CPUTileFuseAndVectorize", workload_per_wg = [64, 64]>
-//  CHECK:       linalg.matmul {lowering.config = #[[CONFIG]]}
+//  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering.config<tile_sizes = {{\[}}[64, 64, 0], [8, 8, 8], [1, 4, 4]{{\]}}, native_vector_size = [1, 4, 4]>
+//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"CPUTileFuseAndVectorize", workload_per_wg = []>
+//      CHECK: hal.executable.entry_point public @matmul_i8_i8_i32
+// CHECK-SAME:     translation.info = #[[TRANSLATION]]
+//      CHECK:   linalg.matmul
+// CHECK-SAME:       lowering.config = #[[CONFIG]]
 
 // -----
 
@@ -1473,47 +913,32 @@
     builtin.module  {
       func @gemm_unit_N() {
         %c0 = arith.constant 0 : index
-        %0 = hal.interface.constant.load[0] : i32
-        %1 = hal.interface.constant.load[1] : i32
-        %2 = hal.interface.constant.load[2] : i32
-        %3 = hal.interface.constant.load[3] : i32
-        %4 = arith.index_cast %0 : i32 to index
-        %5 = arith.index_cast %1 : i32 to index
-        %6 = arith.index_cast %2 : i32 to index
-        %7 = arith.index_cast %3 : i32 to index
-        %8 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(32) : !flow.dispatch.tensor<readonly:?x?xf32>{%4, %5}
-        %9 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(32) : !flow.dispatch.tensor<readonly:?x1xf32>{%6}
-        %10 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) offset(%c0) alignment(32) : !flow.dispatch.tensor<readwrite:?x1xf32>{%7}
-        %11 = flow.dispatch.tensor.load %9, offsets = [0, 0], sizes = [%6, 1], strides = [1, 1] : !flow.dispatch.tensor<readonly:?x1xf32>{%6} -> tensor<?x1xf32>
-        %workgroup_size_x = hal.interface.workgroup.size[0] : index
-        %workgroup_id_x = hal.interface.workgroup.id[0] : index
-        %workgroup_count_x = hal.interface.workgroup.count[0] : index
-        %12 = affine.apply #map0()[%workgroup_id_x, %workgroup_size_x]
-        %13 = affine.apply #map0()[%workgroup_count_x, %workgroup_size_x]
-        scf.for %arg0 = %12 to %4 step %13 {
-          %14 = affine.min #map1(%arg0)[%4, %workgroup_size_x]
-          %15 = flow.dispatch.tensor.load %8, offsets = [%arg0, 0], sizes = [%14, %5], strides = [1, 1] : !flow.dispatch.tensor<readonly:?x?xf32>{%4, %5} -> tensor<?x?xf32>
-          %16 = flow.dispatch.tensor.load %10, offsets = [%arg0, 0], sizes = [%14, 1], strides = [1, 1] : !flow.dispatch.tensor<readwrite:?x1xf32>{%7} -> tensor<?x1xf32>
-          %17 = linalg.matmul ins(%15, %11 : tensor<?x?xf32>, tensor<?x1xf32>) outs(%16 : tensor<?x1xf32>) -> tensor<?x1xf32>
-          flow.dispatch.tensor.store %17, %10, offsets = [%arg0, 0], sizes = [%14, 1], strides = [1, 1] : tensor<?x1xf32> -> !flow.dispatch.tensor<readwrite:?x1xf32>{%7}
-        }
+        %M = hal.interface.constant.load[0] : index
+        %K = hal.interface.constant.load[1] : index
+        %lhs_binding = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(32)
+            : !flow.dispatch.tensor<readonly:?x?xf32>{%M, %K}
+        %rhs_binding = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(32)
+            : !flow.dispatch.tensor<readonly:?x1xf32>{%K}
+        %result_binding = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) offset(%c0) alignment(32)
+            : !flow.dispatch.tensor<readwrite:?x1xf32>{%M}
+        %rhs = flow.dispatch.tensor.load %rhs_binding, offsets = [0, 0], sizes = [%K, 1], strides = [1, 1]
+            : !flow.dispatch.tensor<readonly:?x1xf32>{%K} -> tensor<?x1xf32>
+        %lhs = flow.dispatch.tensor.load %lhs_binding, offsets = [0, 0], sizes = [%M, %K], strides = [1, 1]
+            : !flow.dispatch.tensor<readonly:?x?xf32>{%M, %K} -> tensor<?x?xf32>
+        %init = flow.dispatch.tensor.load %result_binding, offsets = [0, 0], sizes = [%M, 1], strides = [1, 1]
+            : !flow.dispatch.tensor<readwrite:?x1xf32>{%M} -> tensor<?x1xf32>
+        %gemm = linalg.matmul ins(%lhs, %rhs : tensor<?x?xf32>, tensor<?x1xf32>) outs(%init : tensor<?x1xf32>) -> tensor<?x1xf32>
+        flow.dispatch.tensor.store %gemm, %result_binding, offsets = [0, 0], sizes = [%M, 1], strides = [1, 1]
+            : tensor<?x1xf32> -> !flow.dispatch.tensor<readwrite:?x1xf32>{%M}
         return
       }
     }
   }
 }
-//  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering.config<tile_sizes = {{\[}}[], [8, 0, 0], [0, 0, 16]], native_vector_size = []>
-//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"CPUDoubleTilingExpert", workload_per_wg = [64]>
-//  CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0] -> (s0 ceildiv 64)>
+//  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering.config<tile_sizes = {{\[}}[64, 0, 0], [8, 0, 0], [0, 0, 16]], native_vector_size = []>
+//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"CPUDoubleTilingExpert", workload_per_wg = []>
 //      CHECK: hal.executable.entry_point public @gemm_unit_N
-// CHECK-SAME:     translation.info = #[[TRANSLATION]]
-//      CHECK:   ^{{[a-z0-9]+}}
-// CHECK-SAME:       %[[ARG0:[a-zA-Z0-9]+]]: index
-// CHECK-SAME:       %[[ARG1:[a-zA-Z0-9]+]]: index
-// CHECK-SAME:       %[[ARG2:[a-zA-Z0-9]+]]: index
-//  CHECK-DAG:     %[[C1:.+]] = arith.constant 1 : index
-//  CHECK-DAG:     %[[N0:.+]] = affine.apply #[[MAP]]()[%[[ARG0]]]
-//      CHECK:     hal.return %[[N0]], %[[C1]], %[[C1]]
+// CHECK-SAME:     translation.info = #[[TRANSLATION]]
 //      CHECK:   linalg.matmul
 // CHECK-SAME:       lowering.config = #[[CONFIG]]
 
@@ -1539,161 +964,169 @@
     builtin.module  {
       func @gemm_unit_M_unit_N() {
         %c0 = arith.constant 0 : index
-        %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(32) : !flow.dispatch.tensor<readonly:1x1xf32>
-        %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(32) : !flow.dispatch.tensor<readonly:1x1xf32>
-        %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) offset(%c0) alignment(32) : !flow.dispatch.tensor<readwrite:1x1xf32>
-        %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [1, 1], strides = [1, 1] : !flow.dispatch.tensor<readonly:1x1xf32> -> tensor<1x1xf32>
-        %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [1, 1], strides = [1, 1] : !flow.dispatch.tensor<readonly:1x1xf32> -> tensor<1x1xf32>
-        %5 = flow.dispatch.tensor.load %2, offsets = [0, 0], sizes = [1, 1], strides = [1, 1] : !flow.dispatch.tensor<readwrite:1x1xf32> -> tensor<1x1xf32>
-        %6 = linalg.matmul ins(%3, %4 : tensor<1x1xf32>, tensor<1x1xf32>) outs(%5 : tensor<1x1xf32>) -> tensor<1x1xf32>
-        flow.dispatch.tensor.store %6, %2, offsets = [0, 0], sizes = [1, 1], strides = [1, 1] : tensor<1x1xf32> -> !flow.dispatch.tensor<readwrite:1x1xf32>
+        %K = hal.interface.constant.load[0] : index
+        %lhs_binding = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(32) : !flow.dispatch.tensor<readonly:1x?xf32>{%K}
+        %rhs_binding = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(32) : !flow.dispatch.tensor<readonly:?x1xf32>{%K}
+        %result_binding = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) offset(%c0) alignment(32) : !flow.dispatch.tensor<readwrite:1x1xf32>
+        %lhs = flow.dispatch.tensor.load %lhs_binding, offsets = [0, 0], sizes = [1, %K], strides = [1, 1]
+            : !flow.dispatch.tensor<readonly:1x?xf32>{%K} -> tensor<1x?xf32>
+        %rhs = flow.dispatch.tensor.load %rhs_binding, offsets = [0, 0], sizes = [%K, 1], strides = [1, 1]
+            : !flow.dispatch.tensor<readonly:?x1xf32>{%K} -> tensor<?x1xf32>
+        %init = flow.dispatch.tensor.load %result_binding, offsets = [0, 0], sizes = [1, 1], strides = [1, 1]
+            : !flow.dispatch.tensor<readwrite:1x1xf32> -> tensor<1x1xf32>
+        %gemm = linalg.matmul ins(%lhs, %rhs : tensor<1x?xf32>, tensor<?x1xf32>) outs(%init : tensor<1x1xf32>) -> tensor<1x1xf32>
+        flow.dispatch.tensor.store %gemm, %result_binding, offsets = [0, 0], sizes = [1, 1], strides = [1, 1]
+            : tensor<1x1xf32> -> !flow.dispatch.tensor<readwrite:1x1xf32>
         return
       }
     }
   }
 }
-//  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering.config<tile_sizes = {{\[}}[], [0, 0, 0], [0, 0, 0]], native_vector_size = []>
+//  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering.config<tile_sizes = {{\[}}[0, 0, 0], [0, 0, 0], [0, 0, 16]], native_vector_size = []>
 //  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"CPUDoubleTilingExpert", workload_per_wg = []>
 //      CHECK: hal.executable.entry_point public @gemm_unit_M_unit_N
-// CHECK-SAME:     translation.info = #[[TRANSLATION]]
-//      CHECK:   ^{{[a-z0-9]+}}
-// CHECK-SAME:       %[[ARG0:[a-zA-Z0-9]+]]: index
-// CHECK-SAME:       %[[ARG1:[a-zA-Z0-9]+]]: index
-// CHECK-SAME:       %[[ARG2:[a-zA-Z0-9]+]]: index
-//  CHECK-DAG:     %[[C1:.+]] = arith.constant 1 : index
-//      CHECK:     hal.return %[[C1]], %[[C1]], %[[C1]]
+// CHECK-SAME:     translation.info = #[[TRANSLATION]]
 //      CHECK:   linalg.matmul
 // CHECK-SAME:       lowering.config = #[[CONFIG]]
 
 // -----
 
-#executable_layout = #hal.executable.layout<push_constants = 4, sets = [
+#executable_layout = #hal.executable.layout<push_constants = 0, sets = [
   #hal.descriptor_set.layout<0, bindings = [
     #hal.descriptor_set.binding<0, storage_buffer>,
     #hal.descriptor_set.binding<1, storage_buffer>,
     #hal.descriptor_set.binding<2, storage_buffer>
   ]>
 ]>
-#executable_target_embedded_elf_x86_64_ = #hal.executable.target<
-  "llvm", "embedded-elf-x86_64", {
+hal.executable private @generic_unit_dims {
+  hal.executable.variant @llvm, target = <"llvm", "embedded-elf-x86_64", {
     data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128",
     native_vector_size = 16 : index,
-    target_triple = "x86_64-unknown-unknown-eabi-elf"
-  }
->
-#map0 = affine_map<()[s0, s1] -> (s0 * s1)>
-#map1 = affine_map<(d0)[s0, s1] -> (s1, -d0 + s0)>
-hal.executable private @gemm_unit_M {
-  hal.executable.variant public @embedded_elf_x86_64, target = #executable_target_embedded_elf_x86_64_ {
-    hal.executable.entry_point public @gemm_unit_M ordinal(0) layout(#executable_layout)
-    builtin.module  {
-      func @gemm_unit_M() {
+    target_triple = "x86_64-unknown-linux-gnu"
+  }> {
+    hal.executable.entry_point @generic_unit_dims layout(#executable_layout)
+    builtin.module {
+      func @generic_unit_dims() {
         %c0 = arith.constant 0 : index
-        %0 = hal.interface.constant.load[0] : i32
-        %1 = hal.interface.constant.load[1] : i32
-        %2 = hal.interface.constant.load[2] : i32
-        %3 = hal.interface.constant.load[3] : i32
-        %4 = arith.index_cast %0 : i32 to index
-        %5 = arith.index_cast %1 : i32 to index
-        %6 = arith.index_cast %2 : i32 to index
-        %7 = arith.index_cast %3 : i32 to index
-        %8 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(32) : !flow.dispatch.tensor<readonly:?x?xf32>{%4, %5}
-        %9 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(32) : !flow.dispatch.tensor<readonly:1x?xf32>{%6}
-        %10 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) offset(%c0) alignment(32) : !flow.dispatch.tensor<readwrite:1x?xf32>{%7}
-        %11 = flow.dispatch.tensor.load %9, offsets = [0, 0], sizes = [1, %6], strides = [1, 1] : !flow.dispatch.tensor<readonly:1x?xf32>{%6} -> tensor<1x?xf32>
-        %workgroup_size_x = hal.interface.workgroup.size[0] : index
-        %workgroup_id_x = hal.interface.workgroup.id[0] : index
-        %workgroup_count_x = hal.interface.workgroup.count[0] : index
-        %12 = affine.apply #map0()[%workgroup_id_x, %workgroup_size_x]
-        %13 = affine.apply #map0()[%workgroup_count_x, %workgroup_size_x]
-        scf.for %arg0 = %12 to %5 step %13 {
-          %14 = affine.min #map1(%arg0)[%5, %workgroup_size_x]
-          %15 = flow.dispatch.tensor.load %8, offsets = [0, %arg0], sizes = [%4, %14], strides = [1, 1] : !flow.dispatch.tensor<readonly:?x?xf32>{%4, %5} -> tensor<?x?xf32>
-          %16 = flow.dispatch.tensor.load %10, offsets = [0, %arg0], sizes = [1, %14], strides = [1, 1] : !flow.dispatch.tensor<readwrite:1x?xf32>{%7} -> tensor<1x?xf32>
-          %17 = linalg.matmul ins(%11, %15 : tensor<1x?xf32>, tensor<?x?xf32>) outs(%16 : tensor<1x?xf32>) -> tensor<1x?xf32>
-          flow.dispatch.tensor.store %17, %10, offsets = [0, %arg0], sizes = [1, %14], strides = [1, 1] : tensor<1x?xf32> -> !flow.dispatch.tensor<readwrite:1x?xf32>{%7}
-        }
+        %d0 = hal.interface.constant.load[0] : index
+        %d1 = hal.interface.constant.load[1] : index
+        %d2 = hal.interface.constant.load[2] : index
+        %d3 = hal.interface.constant.load[3] : index
+        %in_binding = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer)
+            : !flow.dispatch.tensor<readonly:1x?x1x1x?x?x1x?xf32>{%d0, %d1, %d2, %d3}
+        %result_binding = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer)
+            : !flow.dispatch.tensor<writeonly:1x?x1x1x?x?x1x?xf32>{%d0, %d1, %d2, %d3}
+        %in = flow.dispatch.tensor.load %in_binding, offsets=[0, 0, 0, 0, 0, 0, 0, 0],
+            sizes=[1, %d0, 1, 1, %d1, %d2, 1, %d3], strides=[1, 1, 1, 1, 1, 1, 1, 1]
+            : !flow.dispatch.tensor<readonly:1x?x1x1x?x?x1x?xf32>{%d0, %d1, %d2, %d3} -> tensor<1x?x1x1x?x?x1x?xf32>
+        %init = linalg.init_tensor [1, %d0, 1, 1, %d1, %d2, 1, %d3] : tensor<1x?x1x1x?x?x1x?xf32>
+        %generic = linalg.generic {
+          indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d5, d6, d7)>,
+                           affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d5, d6, d7)>],
+          iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]}
+          ins(%in : tensor<1x?x1x1x?x?x1x?xf32>) outs(%init : tensor<1x?x1x1x?x?x1x?xf32>) {
+          ^bb0(%arg0: f32, %arg1: f32):  // no predecessors
+            %7 = arith.addf %arg0, %arg0 : f32
+            linalg.yield %7 : f32
+          } -> tensor<1x?x1x1x?x?x1x?xf32>
+        flow.dispatch.tensor.store %generic, %result_binding, offsets = [0, 0, 0, 0, 0, 0, 0, 0],
+            sizes = [1, %d0, 1, 1, %d1, %d2, 1, %d3], strides = [1, 1, 1, 1, 1, 1, 1, 1]
+            : tensor<1x?x1x1x?x?x1x?xf32> -> !flow.dispatch.tensor<writeonly:1x?x1x1x?x?x1x?xf32>{%d0, %d1, %d2, %d3}
         return
       }
     }
   }
 }
-//  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering.config<tile_sizes = {{\[}}[], [0, 32, 0], [0, 0, 16]], native_vector_size = []>
-//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"CPUDoubleTilingExpert", workload_per_wg = [64]>
-//  CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0] -> (s0 ceildiv 64)>
-//      CHECK: hal.executable.entry_point public @gemm_unit_M
+//  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering.config<tile_sizes = {{\[}}[0, 0, 0, 0, 64, 64, 0, 64], [0, 1, 0, 0, 1, 1, 0, 4], [0, 0, 0, 0, 0, 0, 0, 0]{{\]}}, native_vector_size = []>
+//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"CPUDoubleTilingExpert", workload_per_wg = []>
+//      CHECK: hal.executable.entry_point public @generic_unit_dims
 // CHECK-SAME:     translation.info = #[[TRANSLATION]]
-//      CHECK:   ^{{[a-z0-9]+}}
-// CHECK-SAME:       %[[ARG0:[a-zA-Z0-9]+]]: index
-// CHECK-SAME:       %[[ARG1:[a-zA-Z0-9]+]]: index
-// CHECK-SAME:       %[[ARG2:[a-zA-Z0-9]+]]: index
-//  CHECK-DAG:     %[[C1:.+]] = arith.constant 1 : index
-//  CHECK-DAG:     %[[N0:.+]] = affine.apply #[[MAP]]()[%[[ARG0]]]
-//      CHECK:     hal.return %[[N0]], %[[C1]], %[[C1]]
-//      CHECK:   linalg.matmul
-// CHECK-SAME:       lowering.config = #[[CONFIG]]
+//      CHECK: linalg.generic
+// CHECK-SAME:     lowering.config = #[[CONFIG]]
 
 // -----
 
-#executable_layout = #hal.executable.layout<push_constants = 4, sets = [
+#executable_layout = #hal.executable.layout<push_constants = 0, sets = [
   #hal.descriptor_set.layout<0, bindings = [
     #hal.descriptor_set.binding<0, storage_buffer>,
     #hal.descriptor_set.binding<1, storage_buffer>,
     #hal.descriptor_set.binding<2, storage_buffer>
   ]>
 ]>
-#executable_target_embedded_elf_x86_64_ = #hal.executable.target<
-  "llvm", "embedded-elf-x86_64", {
+hal.executable private @reduce_to_scalar {
+  hal.executable.variant @llvm, target = <"llvm", "embedded-elf-x86_64", {
     data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128",
     native_vector_size = 16 : index,
-    target_triple = "x86_64-unknown-unknown-eabi-elf"
-  }
->
-
-#map4 = affine_map<()[s0, s1] -> (s0 * s1)>
-#map57 = affine_map<(d0, d1) -> (d0, -d1 + 2048)>
-#map59 = affine_map<(d0, d1) -> (-d0 + 2048, d1)>
-#map60 = affine_map<(d0, d1, d2) -> (d1, d2, d0)>
-#map61 = affine_map<(d0, d1, d2) -> (d0)>
-#map62 = affine_map<(d0) -> (d0)>
-
-hal.executable private @gemm_unit_M {
-  hal.executable.variant public @embedded_elf_x86_64, target = #executable_target_embedded_elf_x86_64_ {
-    hal.executable.entry_point public @predict_dispatch_86 ordinal(0) layout(#executable_layout)
-    builtin.module  {
-      func @predict_dispatch_86(%arg0: !flow.dispatch.tensor<readonly:7x7x2048xf32>, %arg1: !flow.dispatch.tensor<writeonly:2048xf32>) {
-        %cst = arith.constant 4.900000e+01 : f32
-        %cst_0 = arith.constant 0.000000e+00 : f32
-        %c2048 = arith.constant 2048 : index
-        %workgroup_size_0 = flow.dispatch.workgroup.size[0] : index
-        %workgroup_id_0 = flow.dispatch.workgroup.id[0] : index
-        %workgroup_count_0 = flow.dispatch.workgroup.count[0] : index
-        %0 = affine.apply #map4()[%workgroup_id_0, %workgroup_size_0]
-        %1 = affine.apply #map4()[%workgroup_count_0, %workgroup_size_0]
-        scf.for %arg2 = %0 to %c2048 step %1 {
-          %2 = affine.min #map57(%workgroup_size_0, %arg2)
-          %3 = linalg.init_tensor [%2] : tensor<?xf32>
-          %4 = affine.min #map59(%arg2, %workgroup_size_0)
-          %5 = flow.dispatch.tensor.load %arg0, offsets = [0, 0, %arg2], sizes = [7, 7, %4], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:7x7x2048xf32> -> tensor<7x7x?xf32>
-          %6 = affine.min #map59(%arg2, %workgroup_size_0)
-          %7 = linalg.init_tensor [%6] : tensor<?xf32>
-          %8 = linalg.fill(%cst_0, %7) : f32, tensor<?xf32> -> tensor<?xf32> 
-          %9 = linalg.generic {indexing_maps = [#map60, #map61], iterator_types = ["parallel", "reduction", "reduction"]} ins(%5 : tensor<7x7x?xf32>) outs(%8 : tensor<?xf32>) {
-          ^bb0(%arg3: f32, %arg4: f32):
-            %11 = arith.addf %arg3, %arg4 : f32
-            linalg.yield %11 : f32
-          } -> tensor<?xf32>
-          %10 = linalg.generic {indexing_maps = [#map62, #map62], iterator_types = ["parallel"]} ins(%9 : tensor<?xf32>) outs(%3 : tensor<?xf32>) {
-          ^bb0(%arg3: f32, %arg4: f32):
-            %11 = arith.divf %arg3, %cst : f32
-            linalg.yield %11 : f32
-          } -> tensor<?xf32>
-          flow.dispatch.tensor.store %10, %arg1, offsets = [%arg2], sizes = [%2], strides = [1] : tensor<?xf32> -> !flow.dispatch.tensor<writeonly:2048xf32>
-        }
+    target_triple = "x86_64-unknown-linux-gnu"
+  }> {
+    hal.executable.entry_point @reduce_to_scalar layout(#executable_layout)
+    builtin.module {
+      func @reduce_to_scalar() {
+        %c0 = arith.constant 0 : index
+        %d0 = hal.interface.constant.load[0] : index
+        %in_binding = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : !flow.dispatch.tensor<readonly:?xf32>{%d0}
+        %out_binding = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) : !flow.dispatch.tensor<readwrite:f32>
+        %in = flow.dispatch.tensor.load %in_binding, offsets=[0], sizes=[%d0], strides=[1] : !flow.dispatch.tensor<readonly:?xf32>{%d0} -> tensor<?xf32>
+        %out = flow.dispatch.tensor.load %out_binding, offsets=[], sizes=[], strides=[] : !flow.dispatch.tensor<readwrite:f32> -> tensor<f32>
+        %reduce = linalg.generic {
+          indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> ()>],
+          iterator_types = ["reduction"]}
+          ins(%in : tensor<?xf32>) outs(%out : tensor<f32>) {
+          ^bb0(%arg0: f32, %arg1: f32):  // no predecessors
+            %7 = arith.addf %arg0, %arg1 : f32
+            linalg.yield %7 : f32
+          } -> tensor<f32>
+        flow.dispatch.tensor.store %reduce, %out_binding, offsets = [], sizes = [], strides = [] : tensor<f32> -> !flow.dispatch.tensor<readwrite:f32>
         return
       }
     }
   }
 }
-//  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering.config<tile_sizes = {{\[}}[], [4, 0, 0], [0, 1, 1]], native_vector_size = []>
+//  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering.config<tile_sizes = {{\[}}[0], [0], [4]{{\]}}, native_vector_size = []>
 //  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"CPUDoubleTilingExpert", workload_per_wg = []>
+//      CHECK: hal.executable.entry_point public @reduce_to_scalar
+// CHECK-SAME:     translation.info = #[[TRANSLATION]]
+//      CHECK: linalg.generic
+// CHECK-SAME:     lowering.config = #[[CONFIG]]
+
+// -----
+
+#executable_layout = #hal.executable.layout<push_constants = 0, sets = [
+  #hal.descriptor_set.layout<0, bindings = [
+    #hal.descriptor_set.binding<0, storage_buffer>,
+    #hal.descriptor_set.binding<1, storage_buffer>,
+    #hal.descriptor_set.binding<2, storage_buffer>
+  ]>
+]>
+hal.executable private @scalar {
+  hal.executable.variant @llvm, target = <"llvm", "embedded-elf-x86_64", {
+    data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128",
+    native_vector_size = 16 : index,
+    target_triple = "x86_64-unknown-linux-gnu"
+  }> {
+    hal.executable.entry_point @scalar layout(#executable_layout)
+    builtin.module {
+      func @scalar() {
+        %c0 = arith.constant 0 : index
+        %in_binding = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : !flow.dispatch.tensor<readonly:f32>
+        %out_binding = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) : !flow.dispatch.tensor<writeonly:f32>
+        %in = flow.dispatch.tensor.load %in_binding, offsets=[], sizes=[], strides=[] : !flow.dispatch.tensor<readonly:f32> -> tensor<f32>
+        %out = flow.dispatch.tensor.load %out_binding, offsets=[], sizes=[], strides=[] : !flow.dispatch.tensor<writeonly:f32> -> tensor<f32>
+        %reduce = linalg.generic {
+          indexing_maps = [affine_map<() -> ()>,
+                           affine_map<() -> ()>],
+          iterator_types = []}
+          ins(%in : tensor<f32>) outs(%out : tensor<f32>) {
+          ^bb0(%arg0: f32, %arg1: f32):  // no predecessors
+            %7 = arith.addf %arg0, %arg1 : f32
+            linalg.yield %7 : f32
+          } -> tensor<f32>
+        flow.dispatch.tensor.store %reduce, %out_binding, offsets = [], sizes = [], strides = [] : tensor<f32> -> !flow.dispatch.tensor<writeonly:f32>
+        return
+      }
+    }
+  }
+}
+//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"CPUDefault", workload_per_wg = []>
+//      CHECK: hal.executable.entry_point public @scalar
+// CHECK-SAME:     translation.info = #[[TRANSLATION]]
diff --git a/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp b/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
index 865bb24..765be1a 100644
--- a/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
+++ b/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
@@ -99,53 +99,34 @@
                          llvm::ArrayRef<int64_t> workgroupSize,
                          IREE::Codegen::DispatchLoweringPassPipeline pipeline) {
         TileSizesListType tileSizes;
-        SmallVector<int64_t> ts;
+        unsigned numParallelLoops = op.getNumParallelLoops();
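+        // Initialize the workgroup tile sizes: 1 for the outer parallel
+        // loops, tileX/tileY for the two innermost parallel loops, and
+        // tileK for every reduction loop.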
+        SmallVector<int64_t> workgroupTileSizes(numParallelLoops - 2, 1);
+        workgroupTileSizes.append({tileX, tileY});
+        workgroupTileSizes.append(op.getNumReductionLoops(), tileK);
+
         SmallVector<unsigned> partitionedLoops =
             cast<IREE::Flow::PartitionableLoopsInterface>(op.getOperation())
                 .getPartitionableLoops(kNumMaxParallelDims);
-        unsigned index = 0;
-        // Tile all the higher parallel dimension with a size of 1 and the 2
-        // most inner dimension with the tileX/tileY size.
-        for (auto loopNum :
-             llvm::seq<unsigned>(0, op.getNumParallelLoops() - 2)) {
-          int64_t tileSize = 0;
-          if (index < partitionedLoops.size() &&
-              partitionedLoops[index] == loopNum) {
-            tileSize = 1;
-            index++;
+        llvm::SmallDenseSet<unsigned, 4> partitionedLoopsSet;
+        partitionedLoopsSet.insert(partitionedLoops.begin(),
+                                   partitionedLoops.end());
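+        // Reset the tile size to 0 for any parallel loop that is not
+        // distributed to workgroups, leaving it untiled.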
+        for (auto loopID : llvm::seq<unsigned>(0, numParallelLoops)) {
+          if (!partitionedLoopsSet.count(loopID)) {
+            workgroupTileSizes[loopID] = 0;
           }
-          ts.push_back(tileSize);
         }
 
-        // Check for M loop being partitioned.
-        if (index < partitionedLoops.size() &&
-            partitionedLoops[index] == op.getNumParallelLoops() - 2) {
-          index++;
-        } else {
-          // M dim isnt partitioned.
-          tileX = 0;
-        }
-
-        // Check for N loop being partitioned.
-        if (index < partitionedLoops.size() &&
-            partitionedLoops[index] == op.getNumParallelLoops() - 1) {
-          index++;
-        } else {
-          // N dim isnt partitioned.
-          tileY = 0;
-        }
-
-        ts.append({tileX, tileY});
-        // Tile all the reduction dimensions.
-        ts.append(op.getNumReductionLoops(), tileK);
-        tileSizes.push_back(ts);  // Workgroup level.
+        tileSizes.emplace_back(
+            std::move(workgroupTileSizes));  // Workgroup level.
         return setOpConfigAndEntryPointFnTranslation(
             entryPoint, op, tileSizes,
             /*nativeVectorSizes=*/ArrayRef<int64_t>{}, pipeline, workgroupSize);
       };
   // Infer the MxN size of the matmul based on operands and indexing maps.
-  auto lhsShape = getUntiledShape(op.getInputOperand(0)->get());
-  auto rhsShape = getUntiledShape(op.getInputOperand(1)->get());
+  auto lhsShape =
+      op.getInputOperand(0)->get().getType().cast<ShapedType>().getShape();
+  auto rhsShape =
+      op.getInputOperand(1)->get().getType().cast<ShapedType>().getShape();
   int64_t sizeM = ShapedType::kDynamicSize;
   int64_t sizeN = ShapedType::kDynamicSize;
   int64_t sizeK = ShapedType::kDynamicSize;
@@ -332,8 +313,12 @@
         vectorSize = 1;
         break;
       }
-      SmallVector<int64_t> shape = getUntiledResultShape(
-          cast<linalg::LinalgOp>(op), outputOperand.index());
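+      // Use the static shape of the untiled output operand to pick the
+      // vector size; dynamic shapes fall back to a vector size of 1 below.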
+      ArrayRef<int64_t> shape = cast<linalg::LinalgOp>(op)
+                                    .getOutputOperand(outputOperand.index())
+                                    ->get()
+                                    .getType()
+                                    .cast<ShapedType>()
+                                    .getShape();
       if (llvm::any_of(shape, ShapedType::isDynamic)) {
         vectorSize = 1;
         break;
@@ -428,29 +413,6 @@
       return funcOp.emitOpError("failed to get compute ops");
     }
 
-    if (computeOps.empty()) {
-      std::array<int64_t, 3> workgroupSize = {1, 1, 1};
-      SmallVector<int64_t> workloadPerWorkgroup;
-      if (!tiledLoops.empty()) {
-        // If the tiled loops are not empty then this could be a corner case of
-        // tensor.insert_slice being tiled and distributed, that just shows up
-        // as a `flow.dispatch.tensor.load` and a `flow.dispatch.tensor.store`.
-        // For now just treat the tiled loops not being empty as an indicator of
-        // that. Need a better way of information flow from flow dialect to hal.
-        workgroupSize[0] = cudaWarpSize;
-        workloadPerWorkgroup.resize(tiledLoops.size(), 1);
-        workloadPerWorkgroup.front() = cudaWarpSize * 4;
-      }
-      // TODO(ravishankarm): Maybe this should just return without setting
-      // anything. Without any compute ops, this shouldnt be using tile and
-      // distribute.
-      setTranslationInfo(
-          funcOp,
-          IREE::Codegen::DispatchLoweringPassPipeline::LLVMGPUDistribute,
-          workloadPerWorkgroup, workgroupSize);
-      continue;
-    }
-
     Operation *rootOperation = nullptr;
     // Find the root operation. linalg.generic and linalg.fill are not root
     // operations if there are other compute operations present.
@@ -478,14 +440,26 @@
     }
 
     if (!rootOperation) {
-      // TODO(ravishankarm): Maybe this should just return without setting
-      // anything. Without any compute ops, this shouldnt be using tile and
-      // distribute.
-      setTranslationInfo(
-          funcOp,
-          IREE::Codegen::DispatchLoweringPassPipeline::LLVMGPUDistribute,
-          /*workloadPerWorkgroup=*/{}, {1, 1, 1});
-      continue;
+      // TODO(ravishankarm): Currently dispatches can contain a single
+      // tensor.insert_slice or tensor.extract_slice. These are handled by
+      // tile + distribute as well since they have an external model
+      // implementing the `TiledOpInterface`. This is legacy; these ops
+      // shouldn't implement the interface, and backends must handle a
+      // dispatch with flow.dispatch.tensor.load -> flow.dispatch.tensor.store.
+      // Until this is cleaned up, set a configuration for this case.
+      if (computeOps.size() == 1 &&
+          isa<tensor::ExtractSliceOp, tensor::InsertSliceOp>(computeOps[0])) {
+        rootOperation = computeOps[0];
+      }
+    }
+
+    if (!rootOperation) {
+      return funcOp.emitOpError("unable to find root operation");
     }
     if (failed(setRootConfig(funcOp, rootOperation))) continue;
 
diff --git a/iree/compiler/Codegen/LLVMGPU/LLVMGPULowerExecutableTarget.cpp b/iree/compiler/Codegen/LLVMGPU/LLVMGPULowerExecutableTarget.cpp
index 795c84f..21c4924 100644
--- a/iree/compiler/Codegen/LLVMGPU/LLVMGPULowerExecutableTarget.cpp
+++ b/iree/compiler/Codegen/LLVMGPU/LLVMGPULowerExecutableTarget.cpp
@@ -139,8 +139,6 @@
     }
   }
 
-  executableLoweringPipeline.addPass(createSetNumWorkgroupsPass());
-  executableLoweringPipeline.addPass(createCanonicalizerPass());
   if (!testLoweringConfiguration && translationInfo.hasValue()) {
     OpPassManager &nestedModulePM = executableLoweringPipeline.nest<ModuleOp>();
     switch (translationInfo.getValue().getDispatchLoweringPassPipeline()) {
diff --git a/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/iree/compiler/Codegen/LLVMGPU/Passes.cpp
index 0ef1ebc..274a9c3 100644
--- a/iree/compiler/Codegen/LLVMGPU/Passes.cpp
+++ b/iree/compiler/Codegen/LLVMGPU/Passes.cpp
@@ -37,17 +37,23 @@
   return builder.create<memref::AllocOp>(loc, allocType, dynamicSizes);
 }
 
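+// Tiles and distributes the dispatch to workgroups and then bufferizes,
+// with canonicalization + CSE cleanups after each step.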
+static void tileAndBufferize(OpPassManager &pm) {
+  pm.addNestedPass<FuncOp>(createTileAndDistributeToWorkgroupsPass());
+  pm.addPass(createCanonicalizerPass());
+  pm.addPass(createCSEPass());
+
+  addLinalgBufferizePasses(pm, gpuAllocationFunction);
+
+  pm.addPass(createCanonicalizerPass());
+  pm.addPass(createCSEPass());
+}
 
 //===---------------------------------------------------------------------===//
 // Codegen pipelines.
 //===---------------------------------------------------------------------===//
 
 void addGPUVectorizationPassPipeline(OpPassManager &pm) {
-  //===--------------------------------------------------------------------===//
-  // Initial clean up.
-  //===--------------------------------------------------------------------===//
-  pm.addPass(createCanonicalizerPass());
-  pm.addPass(createCSEPass());
+  tileAndBufferize(pm);
 
   // Distribute linalg onto threads within the workgroup.
   pm.addNestedPass<FuncOp>(createLLVMGPUTileAndDistribute());
@@ -64,9 +70,12 @@
 }
 
 void addGPUMatmulSimtPassPipeline(OpPassManager &pm) {
-  //===--------------------------------------------------------------------===//
-  // Initial clean up.
-  //===--------------------------------------------------------------------===//
+  pm.addNestedPass<FuncOp>(createTileAndDistributeToWorkgroupsPass());
+  pm.addPass(createCanonicalizerPass());
+  pm.addPass(createCSEPass());
+
+  addLinalgBufferizePasses(pm, gpuAllocationFunction);
+
   pm.addPass(createCanonicalizerPass());
   pm.addPass(createCSEPass());
 
@@ -89,11 +98,7 @@
 }
 
 void addGPUMatmulTensorCorePassPipeline(OpPassManager &pm) {
-  //===--------------------------------------------------------------------===//
-  // Initial clean up.
-  //===--------------------------------------------------------------------===//
-  pm.addPass(createCanonicalizerPass());
-  pm.addPass(createCSEPass());
+  tileAndBufferize(pm);
 
   // Distribute linalg onto warps within the workgroup.
   pm.addNestedPass<FuncOp>(
@@ -123,11 +128,7 @@
 }
 
 void addGPUSimpleDistributePassPipeline(OpPassManager &pm) {
-  //===--------------------------------------------------------------------===//
-  // Initial clean up.
-  //===--------------------------------------------------------------------===//
-  pm.addPass(createCanonicalizerPass());
-  pm.addPass(createCSEPass());
+  tileAndBufferize(pm);
 
   // Distribute linalg onto threads within the workgroup.
   pm.addNestedPass<FuncOp>(createLLVMGPUTileAndDistribute());
@@ -183,13 +184,6 @@
 
 void buildLLVMGPUTransformPassPipeline(OpPassManager &pm, bool useROCM) {
   pm.nest<ModuleOp>().nest<FuncOp>().addPass(createTypePropagationPass());
-  pm.nest<ModuleOp>().nest<FuncOp>().addPass(
-      createTileAndDistributeToWorkgroupsPass());
-  pm.addPass(createCanonicalizerPass());
-  pm.addPass(createCSEPass());
-
-  OpPassManager &bufferizePassPM = pm.nest<ModuleOp>();
-  addLinalgBufferizePasses(bufferizePassPM, gpuAllocationFunction);
   pm.addPass(createLLVMGPULowerExecutableTargetPass());
   OpPassManager &nestedModulePM = pm.nest<ModuleOp>();
   //===--------------------------------------------------------------------===//
diff --git a/iree/compiler/Codegen/LLVMGPU/Verifiers.cpp b/iree/compiler/Codegen/LLVMGPU/Verifiers.cpp
index 550fdde..76bfe2e 100644
--- a/iree/compiler/Codegen/LLVMGPU/Verifiers.cpp
+++ b/iree/compiler/Codegen/LLVMGPU/Verifiers.cpp
@@ -95,8 +95,10 @@
   }
 
   Type inputType = op->getOperand(0).getType();
-  ArrayRef<int64_t> lhsShape = getUntiledShape(op->getOperand(0));
-  ArrayRef<int64_t> rhsShape = getUntiledShape(op->getOperand(1));
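+  // Read the matmul operand shapes directly from the operand types.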
+  ArrayRef<int64_t> lhsShape =
+      op->getOperand(0).getType().cast<ShapedType>().getShape();
+  ArrayRef<int64_t> rhsShape =
+      op->getOperand(1).getType().cast<ShapedType>().getShape();
   SmallVector<int64_t> firstLevelTileSizes =
       loweringConfig.getTileSizeVals(kWorkgroupTileLevel);
 
diff --git a/iree/compiler/Codegen/LLVMGPU/test/distribute_to_thread.mlir b/iree/compiler/Codegen/LLVMGPU/test/distribute_to_thread.mlir
index 5c181fa..b371c6f 100644
--- a/iree/compiler/Codegen/LLVMGPU/test/distribute_to_thread.mlir
+++ b/iree/compiler/Codegen/LLVMGPU/test/distribute_to_thread.mlir
@@ -17,7 +17,7 @@
 #map4 = affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>
 hal.executable private @dot_dispatch_0  {
   hal.executable.variant @cuda, target = #executable_target_cuda_nvptx_fb {
-    hal.executable.entry_point @dot_dispatch_0 layout(#executable_layout) attributes {
+    hal.executable.entry_point @dot_dispatch_0 layout(#executable_layout) {
       translation.info = #translation,
       workgroup_size = [64 : index, 1 : index, 1 : index]
     }
@@ -97,7 +97,7 @@
 ]>
 hal.executable private @batch_matmul_func  {
   hal.executable.variant @cuda, target = #executable_target_cuda_nvptx_fb {
-    hal.executable.entry_point @batch_matmul_func layout(#executable_layout) attributes {
+    hal.executable.entry_point @batch_matmul_func layout(#executable_layout) {
       translation.info = #translation,
       workgroup_size = [8 : index, 8 : index, 1 : index]
     }
@@ -177,7 +177,7 @@
 #map4 = affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>
 hal.executable private @dot_dispatch_0  {
   hal.executable.variant @cuda, target = #executable_target_cuda_nvptx_fb {
-    hal.executable.entry_point @dot_dispatch_0 layout(#executable_layout) attributes {
+    hal.executable.entry_point @dot_dispatch_0 layout(#executable_layout) {
       translation.info = #translation,
       workgroup_size = [64 : index, 8 : index, 1 : index]
     }
@@ -260,7 +260,7 @@
 // Pure reducion case, skip tiling.
 hal.executable @reduction_dispatch {
   hal.executable.variant @cuda, target = <"cuda", "cuda-nvptx-fb"> {
-    hal.executable.entry_point @predict_dispatch_153 layout(#executable_layout) attributes {
+    hal.executable.entry_point @predict_dispatch_153 layout(#executable_layout) {
       translation.info = #translation,
       workgroup_size = [1: index, 1: index, 1: index]
     }
diff --git a/iree/compiler/Codegen/LLVMGPU/test/distribute_wg_copy.mlir b/iree/compiler/Codegen/LLVMGPU/test/distribute_wg_copy.mlir
index 8bc4f5d..5511cce 100644
--- a/iree/compiler/Codegen/LLVMGPU/test/distribute_wg_copy.mlir
+++ b/iree/compiler/Codegen/LLVMGPU/test/distribute_wg_copy.mlir
@@ -17,7 +17,7 @@
 ]>
 hal.executable private @shared_mem_cpy  {
   hal.executable.variant @cuda, target = <"cuda", "cuda-nvptx-fb"> {
-    hal.executable.entry_point @shared_mem_cpy layout(#executable_layout) attributes {
+    hal.executable.entry_point @shared_mem_cpy layout(#executable_layout) {
       workgroup_size = [32: index, 4: index, 1:index]
     }
     builtin.module {
diff --git a/iree/compiler/Codegen/LLVMGPU/test/gpu_set_num_workgroups.mlir b/iree/compiler/Codegen/LLVMGPU/test/gpu_set_num_workgroups.mlir
index e7644df..d03ce6f 100644
--- a/iree/compiler/Codegen/LLVMGPU/test/gpu_set_num_workgroups.mlir
+++ b/iree/compiler/Codegen/LLVMGPU/test/gpu_set_num_workgroups.mlir
@@ -32,15 +32,10 @@
 }
 
 //  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering.config<tile_sizes = {{\[}}[256]{{\]}}, native_vector_size = []>
-//  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 256)>
-//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"LLVMGPUVectorize", workload_per_wg = [256]>
+//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"LLVMGPUVectorize", workload_per_wg = []>
 //      CHECK: hal.executable.entry_point public @add_dispatch_0
 // CHECK-SAME:     translation.info = #[[TRANSLATION]]
 // CHECK-SAME:     workgroup_size = [64 : index, 1 : index, 1 : index]
-// CHECK-NEXT:   ^bb0(%[[ARG0:[a-zA-Z0-9]+]]: index,
-//  CHECK-DAG:     %[[C1:.+]] = arith.constant 1 : index
-//  CHECK-DAG:     %[[NWGS_X:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]]
-//      CHECK:     hal.return %[[NWGS_X]], %[[C1]], %[[C1]]
 //      CHECK: func @add_dispatch_0
 //      CHECK:   linalg.generic
 // CHECK-SAME:       lowering.config = #[[CONFIG]]
@@ -66,45 +61,19 @@
         %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : memref<2x3xf32>
         %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) : memref<3x4xf32>
         %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) : memref<2x4xf32>
-        %workgroup_size_x = hal.interface.workgroup.size[0] : index
-        %workgroup_size_y = hal.interface.workgroup.size[1] : index
-        %workgroup_id_x = hal.interface.workgroup.id[0] : index
-        %workgroup_count_x = hal.interface.workgroup.count[0] : index
-        %workgroup_id_y = hal.interface.workgroup.id[1] : index
-        %workgroup_count_y = hal.interface.workgroup.count[1] : index
-        %3 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_y, %workgroup_size_y]
-        %4 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_y, %workgroup_size_y]
-        scf.for %arg0 = %3 to %c2 step %4 {
-          %5 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_x, %workgroup_size_x]
-          %6 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_x, %workgroup_size_x]
-          scf.for %arg1 = %5 to %c4 step %6 {
-            %7 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 2)>(%arg0)[%workgroup_size_y]
-            %8 = memref.subview %0[%arg0, 0] [%7, 3] [1, 1] : memref<2x3xf32> to memref<?x3xf32, affine_map<(d0, d1)[s0] -> (d0 * 3 + s0 + d1)>>
-            %9 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 4)>(%arg1)[%workgroup_size_x]
-            %10 = memref.subview %1[0, %arg1] [3, %9] [1, 1] : memref<3x4xf32> to memref<3x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 4 + s0 + d1)>>
-            %11 = memref.subview %2[%arg0, %arg1] [%7, %9] [1, 1] : memref<2x4xf32> to memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 4 + s0 + d1)>>
-            linalg.fill(%cst, %11) : f32, memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 4 + s0 + d1)>>
-            linalg.matmul {__internal_linalg_transform__ = "workgroup"} ins(%8, %10 : memref<?x3xf32, affine_map<(d0, d1)[s0] -> (d0 * 3 + s0 + d1)>>, memref<3x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 4 + s0 + d1)>>) outs(%11 : memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 4 + s0 + d1)>>)
-          }
-        }
-        return
+        linalg.fill(%cst, %2) : f32, memref<2x4xf32>
+        linalg.matmul ins(%0, %1 : memref<2x3xf32>, memref<3x4xf32>) outs(%2 : memref<2x4xf32>)
+        return
       }
     }
   }
 }
 
 //  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering.config<tile_sizes = {{\[}}[4, 2, 4]{{\]}}, native_vector_size = []>
-//  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 2)>
-//  CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 ceildiv 4)>
-//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"LLVMGPUMatmulSimt", workload_per_wg = [2, 4]>
+//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"LLVMGPUMatmulSimt", workload_per_wg = []>
 //      CHECK: hal.executable.entry_point public @dot_dispatch_1
 // CHECK-SAME:     translation.info = #[[TRANSLATION]]
 // CHECK-SAME:     workgroup_size = [2 : index, 4 : index, 1 : index]
-// CHECK-NEXT:   ^bb0(%[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index,
-//  CHECK-DAG:     %[[C1:.+]] = arith.constant 1 : index
-//  CHECK-DAG:     %[[NWGS_X:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]]
-//  CHECK-DAG:     %[[NWGS_Y:.+]] = affine.apply #[[MAP1]]()[%[[ARG1]]]
-//      CHECK:     hal.return %[[NWGS_X]], %[[NWGS_Y]], %[[C1]]
 //      CHECK: func @dot_dispatch_1
 //      CHECK:   linalg.fill
 // CHECK-SAME:       lowering.config = #[[CONFIG]]
@@ -148,10 +117,6 @@
 //  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"LLVMGPUDistribute", workload_per_wg = []>
 //      CHECK: hal.executable.entry_point public @predict_dispatch_153
 // CHECK-SAME:     translation.info = #[[TRANSLATION]]
-// CHECK-SAME:     workgroup_size = [1 : index, 1 : index, 1 : index]
-// CHECK-NEXT:   ^bb0(%[[ARG0:[a-zA-Z0-9]+]]: index,
-//  CHECK-DAG:     %[[C1:.+]] = arith.constant 1 : index
-//      CHECK:     hal.return %[[C1]], %[[C1]], %[[C1]]
 //      CHECK: linalg.fill
 // CHECK-SAME:   lowering.config = #[[CONFIG]]
 //      CHECK: linalg.generic
@@ -165,53 +130,41 @@
     #hal.descriptor_set.binding<1, storage_buffer>
   ]>
 ]>
-hal.executable @tensor_insert {
+hal.executable @tensor_insert_slice {
   hal.executable.variant @cuda, target = <"cuda", "cuda-nvptx-fb"> {
     hal.executable.entry_point @tensor_insert_slice layout(#executable_layout)
     builtin.module {
       builtin.func @tensor_insert_slice() {
         %c0 = arith.constant 0 : index
-        %1 = hal.interface.constant.load[0] : index
-        %2 = hal.interface.constant.load[1] : index
-        %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : !flow.dispatch.tensor<readonly:?x?xi32>{%1, %2}
-        %3 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) : !flow.dispatch.tensor<writeonly:?x?xi32>{%1, %2}
-        %workgroup_size_x = hal.interface.workgroup.size[0] : index
-        %workgroup_size_y = hal.interface.workgroup.size[1] : index
-        %workgroup_id_x = hal.interface.workgroup.id[0] : index
-        %workgroup_count_x = hal.interface.workgroup.count[0] : index
-        %workgroup_id_y = hal.interface.workgroup.id[1] : index
-        %workgroup_count_y = hal.interface.workgroup.count[1] : index
-        %4 = affine.apply affine_map<()[s0, s1] -> (s1 * s0)>()[%workgroup_size_y, %workgroup_id_y]
-        %5 = affine.apply affine_map<()[s0, s1] -> (s1 * s0)>()[%workgroup_size_y, %workgroup_count_y]
-        %d0 = hal.interface.constant.load[2] : index
-        %d1 = hal.interface.constant.load[2] : index
-        scf.for %arg0 = %4 to %d0 step %5 {
-          %6 = affine.min affine_map<(d0)[s0, s1] -> (s0, -d0 + s1)>(%arg0)[%workgroup_size_y, %d0]
-          %7 = affine.apply affine_map<()[s0, s1] -> (s1 * s0)>()[%workgroup_size_x, %workgroup_id_x]
-          %8 = affine.apply affine_map<()[s0, s1] -> (s1 * s0)>()[%workgroup_size_x, %workgroup_count_x]
-          scf.for %arg1 = %7 to %d1 step %8 {
-            %9 = affine.min affine_map<(d0)[s0, s1] -> (s0, -d0 + s1)>(%arg1)[%workgroup_size_x, %d1]
-            %10 = flow.dispatch.tensor.load %0, offsets = [%arg0, %arg1], sizes = [%6, %9], strides = [1, 1] : !flow.dispatch.tensor<readonly:?x?xi32>{%1, %2} -> tensor<?x?xi32>
-            %11 = affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%arg0)[%1]
-            %12 = affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%arg1)[%2]
-            flow.dispatch.tensor.store %10, %3, offsets = [%11, %12], sizes = [%6, %9], strides = [1, 1] : tensor<?x?xi32> -> !flow.dispatch.tensor<writeonly:?x?xi32>{%1, %2}
-          }
-        }
+        %size_y = hal.interface.constant.load[0] : index
+        %size_x = hal.interface.constant.load[1] : index
+        %dest_size_y = hal.interface.constant.load[2] : index
+        %dest_size_x = hal.interface.constant.load[3] : index
+        %offset_y = hal.interface.constant.load[4] : index
+        %offset_x = hal.interface.constant.load[5] : index
+        %source_binding = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer)
+            : !flow.dispatch.tensor<readonly:?x?xi32>{%size_y, %size_x}
+        %dest_binding = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer)
+            : !flow.dispatch.tensor<readwrite:?x?xi32>{%dest_size_y, %dest_size_x}
+        %source = flow.dispatch.tensor.load %source_binding, offsets = [0, 0], sizes = [%size_y, %size_x], strides = [1, 1]
+            : !flow.dispatch.tensor<readonly:?x?xi32>{%size_y, %size_x} -> tensor<?x?xi32>
+        %dest = flow.dispatch.tensor.load %dest_binding, offsets = [0, 0], sizes = [%dest_size_y, %dest_size_x], strides = [1, 1]
+            : !flow.dispatch.tensor<readwrite:?x?xi32>{%dest_size_y, %dest_size_x} -> tensor<?x?xi32>
+        %result = tensor.insert_slice %source into %dest[%offset_y, %offset_x] [%size_y, %size_x] [1, 1]
+            : tensor<?x?xi32> into tensor<?x?xi32>
+        flow.dispatch.tensor.store %result, %dest_binding, offsets = [0, 0], sizes = [%dest_size_y, %dest_size_x], strides = [1, 1]
+            : tensor<?x?xi32> -> !flow.dispatch.tensor<readwrite:?x?xi32>{%dest_size_y, %dest_size_x}
         return
       }
     }
   }
 }
-
-//  CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0] -> (s0 ceildiv 128)>
-//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"LLVMGPUDistribute", workload_per_wg = [128, 1]>
+//  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering.config<tile_sizes = {{\[}}[1, 64]{{\]}}, native_vector_size = []>
+//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"LLVMGPUVectorize", workload_per_wg = []>
 //      CHECK: hal.executable.entry_point public @tensor_insert_slice
-// CHECK-SAME:   translation.info = #[[TRANSLATION]]
-// CHECK-NEXT:   %[[ARG0:[a-zA-Z0-9_]+]]: index
-// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
-//  CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
-//  CHECK-DAG:   %[[NWGSX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]]
-//      CHECK:   hal.return %[[NWGSX]], %[[ARG1]], %[[C1]]
+// CHECK-SAME:     translation.info = #[[TRANSLATION]]
+//      CHECK: tensor.insert_slice
+// CHECK-SAME:     lowering.config = #[[CONFIG]]
 
 // -----
 
@@ -221,42 +174,21 @@
     #hal.descriptor_set.binding<1, storage_buffer>
   ]>
 ]>
-hal.executable @tensor_insert {
+hal.executable @copy_as_generic {
   hal.executable.variant @cuda, target = <"cuda", "cuda-nvptx-fb"> {
-    hal.executable.entry_point @tensor_insert_slice layout(#executable_layout)
+    hal.executable.entry_point @copy_as_generic layout(#executable_layout)
     builtin.module {
-      builtin.func @tensor_insert_slice() {
+      builtin.func @copy_as_generic() {
         %c0 = arith.constant 0 : index
         %d0 = hal.interface.constant.load[0] : index
         %d1 = hal.interface.constant.load[1] : index
         %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : memref<?x?xi32>{%d0, %d1}
         %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) : memref<?x?xi32>{%d0, %d1}
-        %workgroup_size_x = hal.interface.workgroup.size[0] : index
-        %workgroup_size_y = hal.interface.workgroup.size[1] : index
-        %workgroup_id_x = hal.interface.workgroup.id[0] : index
-        %workgroup_count_x = hal.interface.workgroup.count[0] : index
-        %workgroup_id_y = hal.interface.workgroup.id[1] : index
-        %workgroup_count_y = hal.interface.workgroup.count[1] : index
-        %2 = affine.apply affine_map<()[s0, s1] -> (s1 * s0)>()[%workgroup_size_y, %workgroup_id_y]
-        %3 = affine.apply affine_map<()[s0, s1] -> (s1 * s0)>()[%workgroup_size_y, %workgroup_count_y]
-        scf.for %arg0 = %2 to %d0 step %3 {
-          %4 = affine.min affine_map<(d0)[s0, s1] -> (s0, -d0 + s1)>(%arg0)[%workgroup_size_y, %d0]
-          %5 = affine.apply affine_map<()[s0, s1] -> (s1 * s0)>()[%workgroup_size_x, %workgroup_id_x]
-          %6 = affine.apply affine_map<()[s0, s1] -> (s1 * s0)>()[%workgroup_size_x, %workgroup_count_x]
-          scf.for %arg1 = %5 to %d1 step %6 {
-            %7 = affine.min affine_map<(d0)[s0, s1] -> (s0, -d0 + s1)>(%arg1)[%workgroup_size_x, %d1]
-            %8 = memref.subview %0[%arg0, %arg1] [%4, %7] [1, 1] : memref<?x?xi32> to memref<?x?xi32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>
-            %9 = affine.apply affine_map<(d0) -> (d0 + 4)>(%arg0)
-            %10 = affine.apply affine_map<(d0) -> (d0 + 3)>(%arg1)
-            %11 = memref.subview %1[%9, %10] [%4, %7] [1, 1] : memref<?x?xi32> to memref<?x?xi32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>
-            linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]}
-              ins(%8 : memref<?x?xi32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>)
-              outs(%11 : memref<?x?xi32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>) {
-              ^bb0(%arg4: i32, %s: i32):  // no predecessors
-                linalg.yield %arg4 : i32
-            }
+        linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]}
+            ins(%0 : memref<?x?xi32>) outs(%1 : memref<?x?xi32>) {
+          ^bb0(%arg4: i32, %s: i32):  // no predecessors
+            linalg.yield %arg4 : i32
           }
-        }
         return
       }
     }
@@ -264,16 +196,10 @@
 }
 
 //  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering.config<tile_sizes = {{\[}}[1, 64]{{\]}}, native_vector_size = []>
-//  CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0] -> (s0 ceildiv 64)>
-//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"LLVMGPUVectorize", workload_per_wg = [64, 1]>
-//      CHECK: hal.executable.entry_point public @tensor_insert_slice
-// CHECK-SAME:   translation.info = #[[TRANSLATION]]
-// CHECK-NEXT:   %[[ARG0:[a-zA-Z0-9_]+]]: index
-// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
-//  CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
-//  CHECK-DAG:   %[[NWGSX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]]
-//      CHECK:   hal.return %[[NWGSX]], %[[ARG1]], %[[C1]]
-//      CHECK:   linalg.generic
+//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"LLVMGPUVectorize", workload_per_wg = []>
+//      CHECK: hal.executable.entry_point public @copy_as_generic
+// CHECK-SAME:     translation.info = #[[TRANSLATION]]
+//      CHECK: linalg.generic
 // CHECK-SAME:     lowering.config = #[[CONFIG]]
 
 // -----
@@ -307,18 +233,10 @@
 }
 
 //   CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering.config<tile_sizes = {{\[}}[4]{{\]}}, native_vector_size = []>
-//   CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 4)>
-//   CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"LLVMGPUDistribute", workload_per_wg = [4]>
+//   CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"LLVMGPUDistribute", workload_per_wg = []>
 //       CHECK: hal.executable.entry_point public @static_1d_fft_stage2
-//  CHECK-SAME:   translation.info = #[[TRANSLATION]]
-//  CHECK-SAME:   workgroup_size = [32 : index, 1 : index, 1 : index]
-//  CHECK-NEXT: ^{{.+}}(%[[ARG0:.+]]: index, %{{.+}}: index, %{{.+}}: index):
-//  CHECK-NEXT:   %[[ONE:.+]] = arith.constant 1 : index
-//  CHECK-NEXT:   %[[T:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]]
-//  CHECK-NEXT:   hal.return %[[T]], %[[ONE]], %[[ONE]]
-
-//       CHECK: func @static_1d_fft_stage2()
-//       CHECK:   iree_linalg_ext.fft
+//  CHECK-SAME:     translation.info = #[[TRANSLATION]]
+//       CHECK: iree_linalg_ext.fft
 //  CHECK-SAME:     lowering.config = #[[CONFIG]]
 
 // -----
@@ -345,27 +263,9 @@
         %1 = bufferization.to_memref %cst : memref<4xf32>
         %2 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : memref<64x128x32xf32>
         %3 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) : memref<64x128x32xf32>
-        %workgroup_id_x = hal.interface.workgroup.id[0] : index
-        %workgroup_count_x = hal.interface.workgroup.count[0] : index
-        %workgroup_id_y = hal.interface.workgroup.id[1] : index
-        %workgroup_count_y = hal.interface.workgroup.count[1] : index
-        %workgroup_id_z = hal.interface.workgroup.id[2] : index
-        %workgroup_count_z = hal.interface.workgroup.count[2] : index
-        scf.for %arg0 = %workgroup_id_z to %c64 step %workgroup_count_z {
-          scf.for %arg1 = %workgroup_id_y to %c128 step %workgroup_count_y {
-            %4 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%workgroup_id_x]
-            %5 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%workgroup_count_x]
-            scf.for %arg2 = %4 to %c32 step %5 {
-              %6 = memref.subview %2[%arg0, %arg1, %arg2] [1, 1, 4] [1, 1, 1] : memref<64x128x32xf32> to memref<1x1x4xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 4096 + s0 + d1 * 32 + d2)>>
-              %7 = memref.cast %6 : memref<1x1x4xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 4096 + s0 + d1 * 32 + d2)>> to memref<?x?x?xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 4096 + s0 + d1 * 32 + d2)>>
-              %8 = memref.subview %3[%arg0, %arg1, %arg2] [1, 1, 4] [1, 1, 1] : memref<64x128x32xf32> to memref<1x1x4xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 4096 + s0 + d1 * 32 + d2)>>
-              %9 = memref.cast %8 : memref<1x1x4xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 4096 + s0 + d1 * 32 + d2)>> to memref<?x?x?xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 4096 + s0 + d1 * 32 + d2)>>
-              iree_linalg_ext.fft {__internal_linalg_transform__ = "workgroup"}
-                ins(%c3, %1, %0 : index, memref<4xf32>, memref<4xf32>)
-                outs(%7, %9 : memref<?x?x?xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 4096 + s0 + d1 * 32 + d2)>>, memref<?x?x?xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 4096 + s0 + d1 * 32 + d2)>>)
-            }
-          }
-        }
+        iree_linalg_ext.fft {__internal_linalg_transform__ = "workgroup"}
+            ins(%c3, %1, %0 : index, memref<4xf32>, memref<4xf32>)
+            outs(%2, %3 : memref<64x128x32xf32>, memref<64x128x32xf32>)
         return
       }
     }
@@ -373,17 +273,10 @@
 }
 
 //   CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering.config<tile_sizes = {{\[}}[1, 1, 8]{{\]}}, native_vector_size = []>
-//   CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 8)>
-//   CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"LLVMGPUDistribute", workload_per_wg = [8, 1, 1]>
+//   CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"LLVMGPUDistribute", workload_per_wg = []>
 //       CHECK: hal.executable.entry_point public @static_3d_fft_stage3
-//  CHECK-SAME:   translation.info = #[[TRANSLATION]]
-//  CHECK-SAME:   workgroup_size = [32 : index, 1 : index, 1 : index]
-//  CHECK-NEXT: ^{{.+}}(%[[ARG0:.+]]: index, %[[ARG1:.+]]: index, %[[ARG2:.+]]: index):
-//  CHECK-NEXT:   %[[T:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]]
-//  CHECK-NEXT:   hal.return %[[T]], %[[ARG1]], %[[ARG2]]
-
-//       CHECK: func @static_3d_fft_stage3()
-//       CHECK:   iree_linalg_ext.fft
+//  CHECK-SAME:     translation.info = #[[TRANSLATION]]
+//       CHECK: iree_linalg_ext.fft
 //  CHECK-SAME:     lowering.config = #[[CONFIG]]
 
 // -----
@@ -411,32 +304,15 @@
       %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : !flow.dispatch.tensor<readonly:128x256xf32>
       %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) : !flow.dispatch.tensor<readonly:256x1024xf32>
       %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) : !flow.dispatch.tensor<writeonly:128x1024xf32>
-      %workgroup_size_x = hal.interface.workgroup.size[0] : index
-      %workgroup_size_y = hal.interface.workgroup.size[1] : index
-      %workgroup_id_x = hal.interface.workgroup.id[0] : index
-      %workgroup_count_x = hal.interface.workgroup.count[0] : index
-      %workgroup_id_y = hal.interface.workgroup.id[1] : index
-      %workgroup_count_y = hal.interface.workgroup.count[1] : index
-      %3 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_y, %workgroup_size_y]
-      %4 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_y, %workgroup_size_y]
-      scf.for %arg0 = %3 to %c128 step %4 {
-        %5 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_x, %workgroup_size_x]
-        %6 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_x, %workgroup_size_x]
-        scf.for %arg1 = %5 to %c1024 step %6 {
-          %7 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 128)>(%arg0)[%workgroup_size_y]
-          %8 = flow.dispatch.tensor.load %0, offsets = [%arg0, 0], sizes = [%7, 256], strides = [1, 1] : !flow.dispatch.tensor<readonly:128x256xf32> -> tensor<?x256xf32>
-          %9 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 1024)>(%arg1)[%workgroup_size_x]
-          %10 = flow.dispatch.tensor.load %1, offsets = [0, %arg1], sizes = [256, %9], strides = [1, 1] : !flow.dispatch.tensor<readonly:256x1024xf32> -> tensor<256x?xf32>
-          %11 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 128)>(%arg0)[%workgroup_size_y]
-          %12 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 1024)>(%arg1)[%workgroup_size_x]
-          %13 = affine.min affine_map<(d0)[s0] -> (-d0 + 128, s0)>(%arg0)[%workgroup_size_y]
-          %14 = affine.min affine_map<(d0)[s0] -> (-d0 + 1024, s0)>(%arg1)[%workgroup_size_x]
-          %15 = linalg.init_tensor [%13, %14] : tensor<?x?xf32>
-          %16 = linalg.fill(%cst, %15) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
-          %17 = linalg.matmul {__internal_linalg_transform__ = "workgroup", compilation.info = #compilation} ins(%8, %10 : tensor<?x256xf32>, tensor<256x?xf32>) outs(%16 : tensor<?x?xf32>) -> tensor<?x?xf32>
-          flow.dispatch.tensor.store %17, %2, offsets = [%arg0, %arg1], sizes = [%11, %12], strides = [1, 1] : tensor<?x?xf32> -> !flow.dispatch.tensor<writeonly:128x1024xf32>
-        }
-      }
+      %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [128, 256], strides = [1, 1]
+          : !flow.dispatch.tensor<readonly:128x256xf32> -> tensor<128x256xf32>
+      %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [256, 1024], strides = [1, 1]
+          : !flow.dispatch.tensor<readonly:256x1024xf32> -> tensor<256x1024xf32>
+      %15 = linalg.init_tensor [128, 1024] : tensor<128x1024xf32>
+      %16 = linalg.fill(%cst, %15) : f32, tensor<128x1024xf32> -> tensor<128x1024xf32>
+      %17 = linalg.matmul {__internal_linalg_transform__ = "workgroup", compilation.info = #compilation}
+          ins(%3, %4 : tensor<128x256xf32>, tensor<256x1024xf32>) outs(%16 : tensor<128x1024xf32>) -> tensor<128x1024xf32>
+      flow.dispatch.tensor.store %17, %2, offsets = [0, 0], sizes = [128, 1024], strides = [1, 1] : tensor<128x1024xf32> -> !flow.dispatch.tensor<writeonly:128x1024xf32>
       return
     }
   }
@@ -448,11 +324,10 @@
 //      CHECK: hal.executable.entry_point public @_lowering_config_test_dispatch_1
 // CHECK-SAME:     translation.info = #[[TRANSLATION]]
 // CHECK-SAME:     workgroup_size = [16 : index, 8 : index, 1 : index]
-//      CHECK: func @_lowering_config_test_dispatch_1
-//      CHECK:   linalg.fill
-// CHECK-SAME:       lowering.config = #[[CONFIG]]
-//      CHECK:   linalg.matmul
-// CHECK-SAME:       lowering.config = #[[CONFIG]]
+//      CHECK: linalg.fill
+// CHECK-SAME:     lowering.config = #[[CONFIG]]
+//      CHECK: linalg.matmul
+// CHECK-SAME:     lowering.config = #[[CONFIG]]
 
 // -----
 
@@ -476,23 +351,17 @@
         %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(32) : !flow.dispatch.tensor<readonly:1x576000xi32>
         %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(32) : !flow.dispatch.tensor<writeonly:1x576000xf32>
         %3 = hal.interface.binding.subspan set(0) binding(3) type(storage_buffer) offset(%c2304000) alignment(32) : !flow.dispatch.tensor<writeonly:1x576000xi32>
-        %workgroup_size_x = hal.interface.workgroup.size[0] : index
-        %workgroup_id_x = hal.interface.workgroup.id[0] : index
-        %workgroup_count_x = hal.interface.workgroup.count[0] : index
-        %4 = affine.apply affine_map<()[s0, s1] -> (s1 * s0)>()[%workgroup_size_x, %workgroup_id_x]
-        %5 = affine.apply affine_map<()[s0, s1] -> (s1 * s0)>()[%workgroup_size_x, %workgroup_count_x]
-        scf.for %arg0 = %4 to %c1 step %5 {
-          %6 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 1)>(%arg0)[%workgroup_size_x]
-          %7 = flow.dispatch.tensor.load %0, offsets = [%arg0, 0], sizes = [%6, 576000], strides = [1, 1] : !flow.dispatch.tensor<readonly:1x576000xf32> -> tensor<?x576000xf32>
-          %8 = flow.dispatch.tensor.load %1, offsets = [%arg0, 0], sizes = [%6, 576000], strides = [1, 1] : !flow.dispatch.tensor<readonly:1x576000xi32> -> tensor<?x576000xi32>
-          %9:2 = iree_linalg_ext.sort dimension(1) outs(%7, %8 : tensor<?x576000xf32>, tensor<?x576000xi32>)  {
-          ^bb0(%arg1: f32, %arg2: f32, %arg3: i32, %arg4: i32):  // no predecessors
-            %10 = arith.cmpf ogt, %arg1, %arg2 : f32
-            iree_linalg_ext.yield %10 : i1
-          } -> tensor<?x576000xf32>, tensor<?x576000xi32>
-          flow.dispatch.tensor.store %9#0, %2, offsets = [%arg0, 0], sizes = [%6, 576000], strides = [1, 1] : tensor<?x576000xf32> -> !flow.dispatch.tensor<writeonly:1x576000xf32>
-          flow.dispatch.tensor.store %9#1, %3, offsets = [%arg0, 0], sizes = [%6, 576000], strides = [1, 1] : tensor<?x576000xi32> -> !flow.dispatch.tensor<writeonly:1x576000xi32>
-        }
+        %4 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [1, 576000], strides = [1, 1]
+            : !flow.dispatch.tensor<readonly:1x576000xf32> -> tensor<1x576000xf32>
+        %5 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [1, 576000], strides = [1, 1]
+            : !flow.dispatch.tensor<readonly:1x576000xi32> -> tensor<1x576000xi32>
+        %9:2 = iree_linalg_ext.sort dimension(1) outs(%4, %5 : tensor<1x576000xf32>, tensor<1x576000xi32>)  {
+        ^bb0(%arg1: f32, %arg2: f32, %arg3: i32, %arg4: i32):  // no predecessors
+          %10 = arith.cmpf ogt, %arg1, %arg2 : f32
+          iree_linalg_ext.yield %10 : i1
+        } -> tensor<1x576000xf32>, tensor<1x576000xi32>
+        flow.dispatch.tensor.store %9#0, %2, offsets = [0, 0], sizes = [1, 576000], strides = [1, 1] : tensor<1x576000xf32> -> !flow.dispatch.tensor<writeonly:1x576000xf32>
+        flow.dispatch.tensor.store %9#1, %3, offsets = [0, 0], sizes = [1, 576000], strides = [1, 1] : tensor<1x576000xi32> -> !flow.dispatch.tensor<writeonly:1x576000xi32>
         return
       }
     }
@@ -500,11 +369,8 @@
 }
 
 //   CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering.config<tile_sizes = {{\[}}[64]{{\]}}, native_vector_size = []>
-//   CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 64)>
-//   CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"LLVMGPUDistribute", workload_per_wg = [64]>
+//   CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"LLVMGPUDistribute", workload_per_wg = []>
 //       CHECK: hal.executable.entry_point public @sort_op
-//  CHECK-SAME:   translation.info = #[[TRANSLATION]]
-
-//       CHECK: func @sort_op()
-//       CHECK:   iree_linalg_ext.sort
+//  CHECK-SAME:     translation.info = #[[TRANSLATION]]
+//       CHECK: iree_linalg_ext.sort
 //  CHECK-SAME:     lowering.config = #[[CONFIG]]
diff --git a/iree/compiler/Codegen/LLVMGPU/test/illegal_configuration.mlir b/iree/compiler/Codegen/LLVMGPU/test/illegal_configuration.mlir
index 9cad730..0b79546 100644
--- a/iree/compiler/Codegen/LLVMGPU/test/illegal_configuration.mlir
+++ b/iree/compiler/Codegen/LLVMGPU/test/illegal_configuration.mlir
@@ -1,7 +1,7 @@
 // RUN: iree-opt -pass-pipeline='hal.executable(hal.executable.variant(iree-llvmgpu-lower-executable-target-pass{test-lowering-configuration=true}))' -verify-diagnostics -split-input-file %s
 
 #config = #iree_codegen.lowering.config<tile_sizes = [], native_vector_size = []>
-#translation = #iree_codegen.translation.info<"LLVMGPUMatmulSimt", workload_per_wg = [128, 32]>
+#translation = #iree_codegen.translation.info<"LLVMGPUMatmulSimt", workload_per_wg = []>
 #executable_layout = #hal.executable.layout<push_constants = 0, sets = [
   #hal.descriptor_set.layout<0, bindings = [
     #hal.descriptor_set.binding<0, storage_buffer>,
@@ -11,7 +11,7 @@
 ]>
 hal.executable private @matmul_tensors {
   hal.executable.variant @cuda, target = #hal.executable.target<"cuda", "cuda-nvptx-fb"> {
-    hal.executable.entry_point @illegal layout(#executable_layout) attributes {
+    hal.executable.entry_point @illegal layout(#executable_layout) {
       translation.info = #translation,
       workgroup_size = [32 : index, 8 : index, 8 : index]
     }
@@ -33,7 +33,7 @@
 // -----
 
 #config = #iree_codegen.lowering.config<tile_sizes = [], native_vector_size = []>
-#translation = #iree_codegen.translation.info<"LLVMGPUMatmulSimt", workload_per_wg = [128, 32]>
+#translation = #iree_codegen.translation.info<"LLVMGPUMatmulSimt", workload_per_wg = []>
 #executable_layout = #hal.executable.layout<push_constants = 0, sets = [
   #hal.descriptor_set.layout<0, bindings = [
     #hal.descriptor_set.binding<0, storage_buffer>,
@@ -43,7 +43,7 @@
 ]>
 hal.executable private @matmul_tensors {
   hal.executable.variant @cuda, target = #hal.executable.target<"cuda", "cuda-nvptx-fb"> {
-    hal.executable.entry_point @illegal layout(#executable_layout) attributes {
+    hal.executable.entry_point @illegal layout(#executable_layout) {
       translation.info = #translation,
       workgroup_size = [32 : index, 8 : index, 2 : index]
     }
@@ -65,7 +65,7 @@
 // -----
 
 #config = #iree_codegen.lowering.config<tile_sizes = [[32, 32, 16]], native_vector_size = []>
-#translation = #iree_codegen.translation.info<"LLVMGPUMatmulTensorCore", workload_per_wg = [32, 32]>
+#translation = #iree_codegen.translation.info<"LLVMGPUMatmulTensorCore", workload_per_wg = []>
 #executable_layout = #hal.executable.layout<push_constants = 0, sets = [
   #hal.descriptor_set.layout<0, bindings = [
     #hal.descriptor_set.binding<0, storage_buffer>,
@@ -75,7 +75,7 @@
 ]>
 hal.executable private @matmul_tensors {
   hal.executable.variant @cuda, target = #hal.executable.target<"cuda", "cuda-nvptx-fb"> {
-    hal.executable.entry_point @illegal layout(#executable_layout) attributes {
+    hal.executable.entry_point @illegal layout(#executable_layout) {
       translation.info = #translation,
       workgroup_size = [64 : index, 2 : index, 10 : index]
     }
@@ -97,7 +97,7 @@
 // -----
 
 #config = #iree_codegen.lowering.config<tile_sizes = [[32, 32, 16]], native_vector_size = []>
-#translation = #iree_codegen.translation.info<"LLVMGPUMatmulTensorCore", workload_per_wg = [32, 32]>
+#translation = #iree_codegen.translation.info<"LLVMGPUMatmulTensorCore", workload_per_wg = []>
 #executable_layout = #hal.executable.layout<push_constants = 0, sets = [
   #hal.descriptor_set.layout<0, bindings = [
     #hal.descriptor_set.binding<0, storage_buffer>,
@@ -107,7 +107,7 @@
 ]>
 hal.executable private @matmul_tensors {
   hal.executable.variant @cuda, target = #hal.executable.target<"cuda", "cuda-nvptx-fb"> {
-    hal.executable.entry_point @illegal layout(#executable_layout) attributes {
+    hal.executable.entry_point @illegal layout(#executable_layout) {
       translation.info = #translation,
       workgroup_size = [48 : index, 2 : index, 1 : index]
     }
@@ -129,7 +129,7 @@
 // -----
 
 #config = #iree_codegen.lowering.config<tile_sizes = [[32, 32, 16]], native_vector_size = []>
-#translation = #iree_codegen.translation.info<"LLVMGPUMatmulTensorCore", workload_per_wg = [32, 32]>
+#translation = #iree_codegen.translation.info<"LLVMGPUMatmulTensorCore", workload_per_wg = []>
 #executable_layout = #hal.executable.layout<push_constants = 0, sets = [
   #hal.descriptor_set.layout<0, bindings = [
     #hal.descriptor_set.binding<0, storage_buffer>,
@@ -139,7 +139,7 @@
 ]>
 hal.executable private @matmul_tensors {
   hal.executable.variant @cuda, target = #hal.executable.target<"cuda", "cuda-nvptx-fb"> {
-    hal.executable.entry_point @illegal layout(#executable_layout) attributes {
+    hal.executable.entry_point @illegal layout(#executable_layout) {
       translation.info = #translation,
       workgroup_size = [64 : index, 2 : index, 2 : index]
     }
@@ -161,7 +161,7 @@
 // -----
 
 #config = #iree_codegen.lowering.config<tile_sizes = [[32, 32, 20]], native_vector_size = []>
-#translation = #iree_codegen.translation.info<"LLVMGPUMatmulTensorCore", workload_per_wg = [32, 32]>
+#translation = #iree_codegen.translation.info<"LLVMGPUMatmulTensorCore", workload_per_wg = []>
 #executable_layout = #hal.executable.layout<push_constants = 0, sets = [
   #hal.descriptor_set.layout<0, bindings = [
     #hal.descriptor_set.binding<0, storage_buffer>,
@@ -171,7 +171,7 @@
 ]>
 hal.executable private @matmul_tensors {
   hal.executable.variant @cuda, target = #hal.executable.target<"cuda", "cuda-nvptx-fb"> {
-    hal.executable.entry_point @illegal layout(#executable_layout) attributes {
+    hal.executable.entry_point @illegal layout(#executable_layout) {
       translation.info = #translation,
       workgroup_size = [64 : index, 2 : index, 1 : index]
     }
@@ -193,7 +193,7 @@
 // -----
 
 #config = #iree_codegen.lowering.config<tile_sizes = [[32, 32, 16]], native_vector_size = []>
-#translation = #iree_codegen.translation.info<"LLVMGPUMatmulTensorCore", workload_per_wg = [32, 32]>
+#translation = #iree_codegen.translation.info<"LLVMGPUMatmulTensorCore", workload_per_wg = []>
 #executable_layout = #hal.executable.layout<push_constants = 0, sets = [
   #hal.descriptor_set.layout<0, bindings = [
     #hal.descriptor_set.binding<0, storage_buffer>,
@@ -203,7 +203,7 @@
 ]>
 hal.executable private @matmul_tensors {
   hal.executable.variant @cuda, target = #hal.executable.target<"cuda", "cuda-nvptx-fb"> {
-    hal.executable.entry_point @illegal layout(#executable_layout) attributes {
+    hal.executable.entry_point @illegal layout(#executable_layout) {
       translation.info = #translation,
       workgroup_size = [64 : index, 2 : index, 1 : index]
     }
@@ -225,7 +225,7 @@
 // -----
 
 #config = #iree_codegen.lowering.config<tile_sizes = [[32, 32, 16]], native_vector_size = []>
-#translation = #iree_codegen.translation.info<"LLVMGPUMatmulTensorCore", workload_per_wg = [32, 32]>
+#translation = #iree_codegen.translation.info<"LLVMGPUMatmulTensorCore", workload_per_wg = []>
 #executable_layout = #hal.executable.layout<push_constants = 0, sets = [
   #hal.descriptor_set.layout<0, bindings = [
     #hal.descriptor_set.binding<0, storage_buffer>,
@@ -235,7 +235,7 @@
 ]>
 hal.executable private @matmul_tensors {
   hal.executable.variant @cuda, target = #hal.executable.target<"cuda", "cuda-nvptx-fb"> {
-    hal.executable.entry_point @illegal layout(#executable_layout) attributes {
+    hal.executable.entry_point @illegal layout(#executable_layout) {
       translation.info = #translation,
       workgroup_size = [64 : index, 2 : index, 1 : index]
     }
@@ -256,61 +256,44 @@
 
 // -----
 
-#config = #iree_codegen.lowering.config<tile_sizes = [[2, 32, 32, 16]], native_vector_size = []>
-#translation = #iree_codegen.translation.info<"LLVMGPUMatmulTensorCore", workload_per_wg = [32, 8, 1]>
-#executable_target_cuda_nvptx_fb = #hal.executable.target<"cuda", "cuda-nvptx-fb">
-#executable_layout = #hal.executable.layout<push_constants = 0, sets = [
-  #hal.descriptor_set.layout<0, bindings = [
-    #hal.descriptor_set.binding<0, storage_buffer>,
-    #hal.descriptor_set.binding<1, storage_buffer>,
-    #hal.descriptor_set.binding<2, storage_buffer>
-  ]>
-]>
-hal.executable private @batch_matmul_func  {
-  hal.executable.variant @cuda, target = #executable_target_cuda_nvptx_fb {
-    hal.executable.entry_point @batch_matmul_func layout(#executable_layout) attributes {
-      translation.info = #translation,
-      workgroup_size = [64 : index, 2 : index, 1 : index]
-    }
-builtin.module {
-  func @batch_matmul_func() {
-    %c0 = arith.constant 0 : index
-    %cst = arith.constant 0.000000e+00 : f32
-    %c4 = arith.constant 4 : index
-    %c32 = arith.constant 32 : index
-    %c64 = arith.constant 64 : index
-    %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(32) : memref<4x32x1024xf32>
-    memref.assume_alignment %0, 32 : memref<4x32x1024xf32>
-    %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(32) : memref<4x1024x64xf32>
-    memref.assume_alignment %1, 32 : memref<4x1024x64xf32>
-    %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) offset(%c0) alignment(32) : memref<4x32x64xf32>
-    memref.assume_alignment %2, 32 : memref<4x32x64xf32>
-    %workgroup_id_x = hal.interface.workgroup.id[0] : index
-    %workgroup_count_x = hal.interface.workgroup.count[0] : index
-    %workgroup_id_y = hal.interface.workgroup.id[1] : index
-    %workgroup_count_y = hal.interface.workgroup.count[1] : index
-    %workgroup_id_z = hal.interface.workgroup.id[2] : index
-    %workgroup_count_z = hal.interface.workgroup.count[2] : index
-    scf.for %arg0 = %workgroup_id_z to %c4 step %workgroup_count_z {
-      %3 = affine.apply affine_map<()[s0] -> (s0 * 8)>()[%workgroup_id_y]
-      %4 = affine.apply affine_map<()[s0] -> (s0 * 8)>()[%workgroup_count_y]
-      scf.for %arg1 = %3 to %c32 step %4 {
-        %5 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%workgroup_id_x]
-        %6 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%workgroup_count_x]
-        scf.for %arg2 = %5 to %c64 step %6 {
-          %7 = memref.subview %0[%arg0, %arg1, 0] [1, 8, 1024] [1, 1, 1] : memref<4x32x1024xf32> to memref<1x8x1024xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 32768 + s0 + d1 * 1024 + d2)>>
-          %8 = memref.subview %1[%arg0, 0, %arg2] [1, 1024, 32] [1, 1, 1] : memref<4x1024x64xf32> to memref<1x1024x32xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 65536 + s0 + d1 * 64 + d2)>>
-          %9 = memref.subview %2[%arg0, %arg1, %arg2] [1, 8, 32] [1, 1, 1] : memref<4x32x64xf32> to memref<1x8x32xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 2048 + s0 + d1 * 64 + d2)>>
-          linalg.fill(%cst, %9) {lowering.config = #config} : f32, memref<1x8x32xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 2048 + s0 + d1 * 64 + d2)>>
-          // expected-error @+1 {{Received first tile dimension of 2 instead of 1 for LLVMGPUMatmulTensorCore}}
-          linalg.batch_matmul {lowering.config = #config} ins(%7, %8 : memref<1x8x1024xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 32768 + s0 + d1 * 1024 + d2)>>, memref<1x1024x32xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 65536 + s0 + d1 * 64 + d2)>>) outs(%9 : memref<1x8x32xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 2048 + s0 + d1 * 64 + d2)>>)
-        }
-      }
-    }
-    return
-  }
-}
-}
-}
+// This test might not be valid anymore when setting configuration on untiled ops.
+
+// #config = #iree_codegen.lowering.config<tile_sizes = [[2, 32, 32, 16]], native_vector_size = []>
+// #translation = #iree_codegen.translation.info<"LLVMGPUMatmulTensorCore", workload_per_wg = []>
+// #executable_target_cuda_nvptx_fb = #hal.executable.target<"cuda", "cuda-nvptx-fb">
+// #executable_layout = #hal.executable.layout<push_constants = 0, sets = [
+//   #hal.descriptor_set.layout<0, bindings = [
+//     #hal.descriptor_set.binding<0, storage_buffer>,
+//     #hal.descriptor_set.binding<1, storage_buffer>,
+//     #hal.descriptor_set.binding<2, storage_buffer>
+//   ]>
+// ]>
+// hal.executable private @batch_matmul_func  {
+//   hal.executable.variant @cuda, target = #executable_target_cuda_nvptx_fb {
+//     hal.executable.entry_point @batch_matmul_func layout(#executable_layout) {
+//       translation.info = #translation,
+//       workgroup_size = [64 : index, 2 : index, 1 : index]
+//     }
+// builtin.module {
+//   func @batch_matmul_func() {
+//     %c0 = arith.constant 0 : index
+//     %cst = arith.constant 0.000000e+00 : f32
+//     %c4 = arith.constant 4 : index
+//     %c32 = arith.constant 32 : index
+//     %c64 = arith.constant 64 : index
+//     %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(32) : memref<4x32x1024xf32>
+//     memref.assume_alignment %0, 32 : memref<4x32x1024xf32>
+//     %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(32) : memref<4x1024x64xf32>
+//     memref.assume_alignment %1, 32 : memref<4x1024x64xf32>
+//     %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) offset(%c0) alignment(32) : memref<4x32x64xf32>
+//     memref.assume_alignment %2, 32 : memref<4x32x64xf32>
+//     linalg.fill(%cst, %2) {lowering.config = #config} : f32, memref<4x32x64xf32>
+//     // exp- ected-error @+1 {{Received first tile dimension of 2 instead of 1 for LLVMGPUMatmulTensorCore}}
+//     linalg.batch_matmul {lowering.config = #config} ins(%0, %1 : memref<4x32x1024xf32>, memref<4x1024x64xf32>) outs(%2 : memref<4x32x64xf32>)
+//     return
+//   }
+// }
+// }
+// }
 
 // -----
diff --git a/iree/compiler/Codegen/LLVMGPU/test/nvvm_pipeline_test.mlir b/iree/compiler/Codegen/LLVMGPU/test/nvvm_pipeline_test.mlir
index 889ccb8..2379836 100644
--- a/iree/compiler/Codegen/LLVMGPU/test/nvvm_pipeline_test.mlir
+++ b/iree/compiler/Codegen/LLVMGPU/test/nvvm_pipeline_test.mlir
@@ -63,32 +63,16 @@
         %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : !flow.dispatch.tensor<readonly:1024x1024xf32>
         %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) : !flow.dispatch.tensor<readonly:1024x1024xf32>
         %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) : !flow.dispatch.tensor<writeonly:1024x1024xf32>
-        %workgroup_size_x = hal.interface.workgroup.size[0] : index
-        %workgroup_size_y = hal.interface.workgroup.size[1] : index
-        %workgroup_id_x = hal.interface.workgroup.id[0] : index
-        %workgroup_count_x = hal.interface.workgroup.count[0] : index
-        %workgroup_id_y = hal.interface.workgroup.id[1] : index
-        %workgroup_count_y = hal.interface.workgroup.count[1] : index
-        %3 = affine.apply #map0()[%workgroup_id_y, %workgroup_size_y]
-        %4 = affine.apply #map0()[%workgroup_count_y, %workgroup_size_y]
-        scf.for %arg0 = %3 to %c1024 step %4 {
-          %5 = affine.apply #map0()[%workgroup_id_x, %workgroup_size_x]
-          %6 = affine.apply #map0()[%workgroup_count_x, %workgroup_size_x]
-          scf.for %arg1 = %5 to %c1024 step %6 {
-            %7 = affine.min #map1(%arg0)[%workgroup_size_y]
-            %8 = flow.dispatch.tensor.load %0, offsets = [%arg0, %c0], sizes = [%7, %c1024], strides = [%c1, %c1] : !flow.dispatch.tensor<readonly:1024x1024xf32> -> tensor<?x1024xf32>
-            %9 = affine.min #map1(%arg1)[%workgroup_size_x]
-            %10 = flow.dispatch.tensor.load %1, offsets = [%c0, %arg1], sizes = [%c1024, %9], strides = [%c1, %c1] : !flow.dispatch.tensor<readonly:1024x1024xf32> -> tensor<1024x?xf32>
-            %11 = affine.min #map1(%arg0)[%workgroup_size_y]
-            %12 = affine.min #map1(%arg1)[%workgroup_size_x]
-            %13 = affine.min #map2(%arg0)[%workgroup_size_y]
-            %14 = affine.min #map2(%arg1)[%workgroup_size_x]
-            %15 = linalg.init_tensor [%13, %14] : tensor<?x?xf32>
-            %16 = linalg.fill(%cst, %15) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
-            %17 = linalg.matmul ins(%8, %10 : tensor<?x1024xf32>, tensor<1024x?xf32>) outs(%16 : tensor<?x?xf32>) -> tensor<?x?xf32>
-            flow.dispatch.tensor.store %17, %2, offsets = [%arg0, %arg1], sizes = [%11, %12], strides = [%c1, %c1] : tensor<?x?xf32> -> !flow.dispatch.tensor<writeonly:1024x1024xf32>
-          }
-        }
+        %8 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [1024, 1024], strides = [1, 1]
+            : !flow.dispatch.tensor<readonly:1024x1024xf32> -> tensor<1024x1024xf32>
+        %10 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [1024, 1024], strides = [1, 1]
+            : !flow.dispatch.tensor<readonly:1024x1024xf32> -> tensor<1024x1024xf32>
+        %15 = linalg.init_tensor [1024, 1024] : tensor<1024x1024xf32>
+        %16 = linalg.fill(%cst, %15) : f32, tensor<1024x1024xf32> -> tensor<1024x1024xf32>
+        %17 = linalg.matmul ins(%8, %10 : tensor<1024x1024xf32>, tensor<1024x1024xf32>)
+            outs(%16 : tensor<1024x1024xf32>) -> tensor<1024x1024xf32>
+        flow.dispatch.tensor.store %17, %2, offsets = [0, 0], sizes = [1024, 1024], strides = [1, 1]
+            : tensor<1024x1024xf32> -> !flow.dispatch.tensor<writeonly:1024x1024xf32>
         return
       }
     }
@@ -146,37 +130,21 @@
         %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : !flow.dispatch.tensor<readonly:1024x1024xf32>
         %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) : !flow.dispatch.tensor<readonly:1024x1024xf32>
         %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) : !flow.dispatch.tensor<writeonly:1024x1024xf32>
-        %workgroup_size_x = hal.interface.workgroup.size[0] : index
-        %workgroup_size_y = hal.interface.workgroup.size[1] : index
-        %workgroup_id_x = hal.interface.workgroup.id[0] : index
-        %workgroup_count_x = hal.interface.workgroup.count[0] : index
-        %workgroup_id_y = hal.interface.workgroup.id[1] : index
-        %workgroup_count_y = hal.interface.workgroup.count[1] : index
-        %3 = affine.apply #map0()[%workgroup_id_y, %workgroup_size_y]
-        %4 = affine.apply #map0()[%workgroup_count_y, %workgroup_size_y]
-        scf.for %arg0 = %3 to %c1024 step %4 {
-          %5 = affine.apply #map0()[%workgroup_id_x, %workgroup_size_x]
-          %6 = affine.apply #map0()[%workgroup_count_x, %workgroup_size_x]
-          scf.for %arg1 = %5 to %c1024 step %6 {
-            %7 = affine.min #map1(%arg0)[%workgroup_size_y]
-            %8 = flow.dispatch.tensor.load %0, offsets = [%arg0, %c0], sizes = [%7, %c1024], strides = [%c1, %c1] : !flow.dispatch.tensor<readonly:1024x1024xf32> -> tensor<?x1024xf32>
-            %9 = affine.min #map1(%arg1)[%workgroup_size_x]
-            %10 = flow.dispatch.tensor.load %1, offsets = [%c0, %arg1], sizes = [%c1024, %9], strides = [%c1, %c1] : !flow.dispatch.tensor<readonly:1024x1024xf32> -> tensor<1024x?xf32>
-            %11 = affine.min #map1(%arg0)[%workgroup_size_y]
-            %12 = affine.min #map1(%arg1)[%workgroup_size_x]
-            %13 = affine.min #map2(%arg0)[%workgroup_size_y]
-            %14 = affine.min #map2(%arg1)[%workgroup_size_x]
-            %15 = linalg.init_tensor [%13, %14] : tensor<?x?xf32>
-            %16 = linalg.fill(%cst, %15) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
-            %17 = linalg.generic #matmul_trait ins(%8, %10 : tensor<?x1024xf32>, tensor<1024x?xf32>) outs(%16 : tensor<?x?xf32>)  {
-              ^bb(%a: f32, %b: f32, %c: f32) :
-              %d = arith.mulf %a, %b: f32
-              %e = arith.addf %c, %d: f32
-              linalg.yield %e : f32
-            } -> (tensor<?x?xf32>)
-            flow.dispatch.tensor.store %17, %2, offsets = [%arg0, %arg1], sizes = [%11, %12], strides = [%c1, %c1] : tensor<?x?xf32> -> !flow.dispatch.tensor<writeonly:1024x1024xf32>
-          }
-        }
+        %8 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [1024, 1024], strides = [1, 1]
+            : !flow.dispatch.tensor<readonly:1024x1024xf32> -> tensor<1024x1024xf32>
+        %10 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [1024, 1024], strides = [1, 1]
+            : !flow.dispatch.tensor<readonly:1024x1024xf32> -> tensor<1024x1024xf32>
+        %15 = linalg.init_tensor [1024, 1024] : tensor<1024x1024xf32>
+        %16 = linalg.fill(%cst, %15) : f32, tensor<1024x1024xf32> -> tensor<1024x1024xf32>
+        %17 = linalg.generic #matmul_trait
+            ins(%8, %10 : tensor<1024x1024xf32>, tensor<1024x1024xf32>) outs(%16 : tensor<1024x1024xf32>)  {
+          ^bb(%a: f32, %b: f32, %c: f32) :
+            %d = arith.mulf %a, %b: f32
+            %e = arith.addf %c, %d: f32
+            linalg.yield %e : f32
+          } -> (tensor<1024x1024xf32>)
+        flow.dispatch.tensor.store %17, %2, offsets = [0, 0], sizes = [1024, 1024], strides = [1, 1]
+            : tensor<1024x1024xf32> -> !flow.dispatch.tensor<writeonly:1024x1024xf32>
         return
       }
     }
@@ -212,42 +180,16 @@
       %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : !flow.dispatch.tensor<readonly:1x4x4x2xf32>
       %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) : !flow.dispatch.tensor<readonly:3x2x2x1xf32>
       %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) : !flow.dispatch.tensor<writeonly:1x2x3x1xf32>
-      %workgroup_size_x = hal.interface.workgroup.size[0] : index
-      %workgroup_size_y = hal.interface.workgroup.size[1] : index
-      %workgroup_size_z = hal.interface.workgroup.size[2] : index
-      %workgroup_id_x = hal.interface.workgroup.id[0] : index
-      %workgroup_count_x = hal.interface.workgroup.count[0] : index
-      %workgroup_id_y = hal.interface.workgroup.id[1] : index
-      %workgroup_count_y = hal.interface.workgroup.count[1] : index
-      %workgroup_id_z = hal.interface.workgroup.id[2] : index
-      %workgroup_count_z = hal.interface.workgroup.count[2] : index
-      %3 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_z, %workgroup_size_z]
-      %4 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_z, %workgroup_size_z]
-      scf.for %arg0 = %3 to %c2 step %4 {
-        %5 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_y, %workgroup_size_y]
-        %6 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_y, %workgroup_size_y]
-        scf.for %arg1 = %5 to %c3 step %6 {
-          %7 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_x, %workgroup_size_x]
-          %8 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_x, %workgroup_size_x]
-          scf.for %arg2 = %7 to %c1 step %8 {
-            %9 = affine.min affine_map<(d0)[s0] -> (s0 + 2, -d0 + 4)>(%arg0)[%workgroup_size_z]
-            %10 = affine.min affine_map<(d0)[s0] -> (s0 + 1, -d0 + 4)>(%arg1)[%workgroup_size_y]
-            %11 = flow.dispatch.tensor.load %0, offsets = [0, %arg0, %arg1, 0], sizes = [1, %9, %10, 2], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:1x4x4x2xf32> -> tensor<1x?x?x2xf32>
-            %12 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 1)>(%arg2)[%workgroup_size_x]
-            %13 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0, %arg2], sizes = [3, 2, 2, %12], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:3x2x2x1xf32> -> tensor<3x2x2x?xf32>
-            %14 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 2)>(%arg0)[%workgroup_size_z]
-            %15 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 3)>(%arg1)[%workgroup_size_y]
-            %16 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 1)>(%arg2)[%workgroup_size_x]
-            %17 = affine.min affine_map<(d0)[s0] -> (-d0 + 2, s0)>(%arg0)[%workgroup_size_z]
-            %18 = affine.min affine_map<(d0)[s0] -> (-d0 + 3, s0)>(%arg1)[%workgroup_size_y]
-            %19 = affine.min affine_map<(d0)[s0] -> (-d0 + 1, s0)>(%arg2)[%workgroup_size_x]
-            %20 = linalg.init_tensor [1, %17, %18, %19] : tensor<1x?x?x?xf32>
-            %21 = linalg.fill(%cst, %20) : f32, tensor<1x?x?x?xf32> -> tensor<1x?x?x?xf32>
-            %22 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%11, %13 : tensor<1x?x?x2xf32>, tensor<3x2x2x?xf32>) outs(%21 : tensor<1x?x?x?xf32>) -> tensor<1x?x?x?xf32>
-            flow.dispatch.tensor.store %22, %2, offsets = [0, %arg0, %arg1, %arg2], sizes = [1, %14, %15, %16], strides = [1, 1, 1, 1] : tensor<1x?x?x?xf32> -> !flow.dispatch.tensor<writeonly:1x2x3x1xf32>
-          }
-        }
-      }
+      %11 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0, 0], sizes = [1, 4, 4, 2], strides = [1, 1, 1, 1]
+          : !flow.dispatch.tensor<readonly:1x4x4x2xf32> -> tensor<1x4x4x2xf32>
+      %13 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0, 0], sizes = [3, 2, 2, 1], strides = [1, 1, 1, 1]
+          : !flow.dispatch.tensor<readonly:3x2x2x1xf32> -> tensor<3x2x2x1xf32>
+      %20 = linalg.init_tensor [1, 2, 3, 1] : tensor<1x2x3x1xf32>
+      %21 = linalg.fill(%cst, %20) : f32, tensor<1x2x3x1xf32> -> tensor<1x2x3x1xf32>
+      %22 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
+          ins(%11, %13 : tensor<1x4x4x2xf32>, tensor<3x2x2x1xf32>) outs(%21 : tensor<1x2x3x1xf32>) -> tensor<1x2x3x1xf32>
+      flow.dispatch.tensor.store %22, %2, offsets = [0, 0, 0, 0], sizes = [1, 2, 3, 1], strides = [1, 1, 1, 1]
+          : tensor<1x2x3x1xf32> -> !flow.dispatch.tensor<writeonly:1x2x3x1xf32>
       return
     }
   }
@@ -315,25 +257,20 @@
       %c96 = arith.constant 96 : index
       %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : !flow.dispatch.tensor<readonly:14x14x96xf32>
       %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) : !flow.dispatch.tensor<writeonly:96xf32>
-      %workgroup_size_x = hal.interface.workgroup.size[0] : index
-      %workgroup_id_x = hal.interface.workgroup.id[0] : index
-      %workgroup_count_x = hal.interface.workgroup.count[0] : index
-      %2 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_x, %workgroup_size_x]
-      %3 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_x, %workgroup_size_x]
-      scf.for %arg0 = %2 to %c96 step %3 {
-        %4 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 96)>(%arg0)[%workgroup_size_x]
-        %5 = flow.dispatch.tensor.load %0, offsets = [0, 0, %arg0], sizes = [14, 14, %4], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:14x14x96xf32> -> tensor<14x14x?xf32>
-        %6 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 96)>(%arg0)[%workgroup_size_x]
-        %7 = affine.min affine_map<(d0)[s0] -> (-d0 + 96, s0)>(%arg0)[%workgroup_size_x]
-        %8 = linalg.init_tensor [%7] : tensor<?xf32>
-        %9 = linalg.fill(%cst, %8) : f32, tensor<?xf32> -> tensor<?xf32>
-        %10 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2, d0)>, affine_map<(d0, d1, d2) -> (d0)>], iterator_types = ["parallel", "reduction", "reduction"]} ins(%5 : tensor<14x14x?xf32>) outs(%9 : tensor<?xf32>) {
+      %5 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0], sizes = [14, 14, 96], strides = [1, 1, 1]
+          : !flow.dispatch.tensor<readonly:14x14x96xf32> -> tensor<14x14x96xf32>
+      %8 = linalg.init_tensor [96] : tensor<96xf32>
+      %9 = linalg.fill(%cst, %8) : f32, tensor<96xf32> -> tensor<96xf32>
+      %10 = linalg.generic {
+            indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2, d0)>, affine_map<(d0, d1, d2) -> (d0)>],
+            iterator_types = ["parallel", "reduction", "reduction"]}
+            ins(%5 : tensor<14x14x96xf32>) outs(%9 : tensor<96xf32>) {
         ^bb0(%arg1: f32, %arg2: f32):  // no predecessors
           %11 = arith.addf %arg1, %arg2 : f32
           linalg.yield %11 : f32
-        } -> tensor<?xf32>
-        flow.dispatch.tensor.store %10, %1, offsets = [%arg0], sizes = [%6], strides = [1] : tensor<?xf32> -> !flow.dispatch.tensor<writeonly:96xf32>
-      }
+        } -> tensor<96xf32>
+      flow.dispatch.tensor.store %10, %1, offsets = [0], sizes = [96], strides = [1]
+          : tensor<96xf32> -> !flow.dispatch.tensor<writeonly:96xf32>
       return
     }
   }
@@ -363,25 +300,21 @@
       %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : !flow.dispatch.tensor<readonly:16384xf32>
       %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) : !flow.dispatch.tensor<readonly:16384xf32>
       %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) : !flow.dispatch.tensor<writeonly:16384xf32>
-      %workgroup_size_x = hal.interface.workgroup.size[0] : index
-      %workgroup_id_x = hal.interface.workgroup.id[0] : index
-      %workgroup_count_x = hal.interface.workgroup.count[0] : index
-      %3 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_x, %workgroup_size_x]
-      %4 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_x, %workgroup_size_x]
-      scf.for %arg0 = %3 to %c16384 step %4 {
-        %5 = affine.min affine_map<(d0, d1) -> (d1, -d0 + 16384)>(%arg0)[%workgroup_size_x]
-        %6 = flow.dispatch.tensor.load %0, offsets = [%arg0], sizes = [%5], strides = [1] : !flow.dispatch.tensor<readonly:16384xf32> -> tensor<?xf32>
-        %7 = affine.min affine_map<(d0, d1) -> (d1, -d0 + 16384)>(%arg0)[%workgroup_size_x]
-        %8 = flow.dispatch.tensor.load %1, offsets = [%arg0], sizes = [%7], strides = [1] : !flow.dispatch.tensor<readonly:16384xf32> -> tensor<?xf32>
-        %9 = affine.min affine_map<(d0, d1) -> (d1, -d0 + 16384)>(%arg0)[%workgroup_size_x]
-        %10 = linalg.init_tensor [%9] : tensor<?xf32>
-        %11 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%6, %8 : tensor<?xf32>, tensor<?xf32>) outs(%10 : tensor<?xf32>) {
+      %6 = flow.dispatch.tensor.load %0, offsets = [0], sizes = [16384], strides = [1]
+          : !flow.dispatch.tensor<readonly:16384xf32> -> tensor<16384xf32>
+      %8 = flow.dispatch.tensor.load %1, offsets = [0], sizes = [16384], strides = [1]
+          : !flow.dispatch.tensor<readonly:16384xf32> -> tensor<16384xf32>
+      %10 = linalg.init_tensor [16384] : tensor<16384xf32>
+      %11 = linalg.generic {
+          indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
+          iterator_types = ["parallel"]}
+          ins(%6, %8 : tensor<16384xf32>, tensor<16384xf32>) outs(%10 : tensor<16384xf32>) {
         ^bb0(%arg1: f32, %arg2: f32, %arg3: f32):  // no predecessors
           %12 = arith.addf %arg1, %arg2 : f32
           linalg.yield %12 : f32
-        } -> tensor<?xf32>
-        flow.dispatch.tensor.store %11, %2, offsets = [%arg0], sizes = [%9], strides = [1] : tensor<?xf32> -> !flow.dispatch.tensor<writeonly:16384xf32>
-      }
+        } -> tensor<16384xf32>
+      flow.dispatch.tensor.store %11, %2, offsets = [0], sizes = [16384], strides = [1]
+          : tensor<16384xf32> -> !flow.dispatch.tensor<writeonly:16384xf32>
       return
     }
   }
@@ -416,25 +349,19 @@
       %cst = arith.constant 1.000000e+00 : f32
       %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : !flow.dispatch.tensor<readonly:512x16384xf32>
       %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) : !flow.dispatch.tensor<writeonly:16384xf32>
-      %workgroup_size_x = hal.interface.workgroup.size[0] : index
-      %workgroup_id_x = hal.interface.workgroup.id[0] : index
-      %workgroup_count_x = hal.interface.workgroup.count[0] : index
-      %2 = affine.apply #map0()[%workgroup_id_x, %workgroup_size_x]
-      %3 = affine.apply #map0()[%workgroup_count_x, %workgroup_size_x]
-      scf.for %arg0 = %2 to %c16384 step %3 {
-        %4 = affine.min #map1(%arg0)[%workgroup_size_x]
-        %5 = flow.dispatch.tensor.load %0, offsets = [0, %arg0], sizes = [512, %4], strides = [1, 1] : !flow.dispatch.tensor<readonly:512x16384xf32> -> tensor<512x?xf32>
-        %6 = affine.min #map1(%arg0)[%workgroup_size_x]
-        %7 = affine.min #map2(%arg0)[%workgroup_size_x]
-        %8 = linalg.init_tensor [%7] : tensor<?xf32>
-        %9 = linalg.fill(%cst, %8) : f32, tensor<?xf32> -> tensor<?xf32>
-        %10 = linalg.generic {indexing_maps = [#map3, #map4], iterator_types = ["parallel", "reduction"]} ins(%5 : tensor<512x?xf32>) outs(%9 : tensor<?xf32>) {
+      %5 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [512, 16384], strides = [1, 1]
+          : !flow.dispatch.tensor<readonly:512x16384xf32> -> tensor<512x16384xf32>
+      %8 = linalg.init_tensor [16384] : tensor<16384xf32>
+      %9 = linalg.fill(%cst, %8) : f32, tensor<16384xf32> -> tensor<16384xf32>
+      %10 = linalg.generic {
+          indexing_maps = [#map3, #map4], iterator_types = ["parallel", "reduction"]}
+          ins(%5 : tensor<512x16384xf32>) outs(%9 : tensor<16384xf32>) {
         ^bb0(%arg1: f32, %arg2: f32):  // no predecessors
           %11 = arith.addf %arg1, %arg2 : f32
           linalg.yield %11 : f32
-        } -> tensor<?xf32>
-        flow.dispatch.tensor.store %10, %1, offsets = [%arg0], sizes = [%6], strides = [1] : tensor<?xf32> -> !flow.dispatch.tensor<writeonly:16384xf32>
-      }
+        } -> tensor<16384xf32>
+      flow.dispatch.tensor.store %10, %1, offsets = [0], sizes = [16384], strides = [1]
+          : tensor<16384xf32> -> !flow.dispatch.tensor<writeonly:16384xf32>
       return
     }
   }
@@ -464,39 +391,30 @@
       %cst = arith.constant 0.000000e+00 : f32
       %c2048 = arith.constant 2048 : index
       %c512 = arith.constant 512 : index
-      %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(32) : memref<2048x1024xf32>
-      memref.assume_alignment %0, 32 : memref<2048x1024xf32>
-      %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(32) : memref<1024x512xf32>
-      memref.assume_alignment %1, 32 : memref<1024x512xf32>
-      %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) offset(%c0) alignment(32) : memref<2048x512xf32>
-      memref.assume_alignment %2, 32 : memref<2048x512xf32>
-      %workgroup_size_x = hal.interface.workgroup.size[0] : index
-      %workgroup_size_y = hal.interface.workgroup.size[1] : index
-      %workgroup_id_x = hal.interface.workgroup.id[0] : index
-      %workgroup_count_x = hal.interface.workgroup.count[0] : index
-      %workgroup_id_y = hal.interface.workgroup.id[1] : index
-      %workgroup_count_y = hal.interface.workgroup.count[1] : index
-      %3 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_y, %workgroup_size_y]
-      %4 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_y, %workgroup_size_y]
-      scf.for %arg0 = %3 to %c2048 step %4 {
-        %5 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_x, %workgroup_size_x]
-        %6 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_x, %workgroup_size_x]
-        scf.for %arg1 = %5 to %c512 step %6 {
-          %7 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 2048)>(%arg0)[%workgroup_size_y]
-          %8 = memref.subview %0[%arg0, 0] [%7, 1024] [1, 1] : memref<2048x1024xf32> to memref<?x1024xf32, affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>>
-          %9 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 512)>(%arg1)[%workgroup_size_x]
-          %10 = memref.subview %1[0, %arg1] [1024, %9] [1, 1] : memref<1024x512xf32> to memref<1024x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 512 + s0 + d1)>>
-          %11 = memref.subview %2[%arg0, %arg1] [%7, %9] [1, 1] : memref<2048x512xf32> to memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 512 + s0 + d1)>>
-          linalg.fill(%cst, %11) : f32, memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 512 + s0 + d1)>>
-          linalg.matmul ins(%8, %10 : memref<?x1024xf32, affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>>, memref<1024x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 512 + s0 + d1)>>) outs(%11 : memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 512 + s0 + d1)>>)
-          linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
-              iterator_types = ["parallel", "parallel"]} ins(%11, %11 : memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 512 + s0 + d1)>>, memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 512 + s0 + d1)>>) outs(%11 : memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 512 + s0 + d1)>>) {
-            ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):  // no predecessors
-              %19 = arith.addf %arg3, %arg4 : f32
-             linalg.yield %19 : f32
-            }
-        }
-      }
+      %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : !flow.dispatch.tensor<readonly:2048x1024xf32>
+      %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) : !flow.dispatch.tensor<readonly:1024x512xf32>
+      %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) : !flow.dispatch.tensor<writeonly:2048x512xf32>
+      %di = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) : !flow.dispatch.tensor<readonly:2048x512xf32>
+      %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [2048, 1024], strides = [1, 1]
+          : !flow.dispatch.tensor<readonly:2048x1024xf32> -> tensor<2048x1024xf32>
+      %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [1024, 512], strides = [1, 1]
+          : !flow.dispatch.tensor<readonly:1024x512xf32> -> tensor<1024x512xf32>
+      %d = flow.dispatch.tensor.load %di, offsets = [0, 0], sizes = [2048, 512], strides = [1, 1]
+          : !flow.dispatch.tensor<readonly:2048x512xf32> -> tensor<2048x512xf32>
+      %init = linalg.init_tensor [2048, 512] : tensor<2048x512xf32>
+      %f = linalg.fill(%cst, %init) : f32, tensor<2048x512xf32> -> tensor<2048x512xf32>
+      %m = linalg.matmul ins(%3, %4 : tensor<2048x1024xf32>, tensor<1024x512xf32>) outs(%f : tensor<2048x512xf32>) -> tensor<2048x512xf32>
+      %init2 = linalg.init_tensor [2048, 512] : tensor<2048x512xf32>
+      %a = linalg.generic {
+          indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
+          iterator_types = ["parallel", "parallel"]}
+          ins(%m, %d : tensor<2048x512xf32>, tensor<2048x512xf32>) outs(%m : tensor<2048x512xf32>) {
+        ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):  // no predecessors
+          %19 = arith.addf %arg3, %arg4 : f32
+          linalg.yield %19 : f32
+        } -> (tensor<2048x512xf32>)
+        flow.dispatch.tensor.store %a, %2, offsets = [0, 0], sizes = [2048, 512], strides = [1, 1]
+          : tensor<2048x512xf32> -> !flow.dispatch.tensor<writeonly:2048x512xf32>
       return
     }
   }
@@ -582,42 +500,22 @@
           %c4 = arith.constant 4 : index
           %cst = arith.constant 0.000000e+00 : f32
           %c0 = arith.constant 0 : index
-          %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(32) : !flow.dispatch.tensor<readonly:4x32x1024xf32>
-          %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(32) : !flow.dispatch.tensor<readonly:4x1024x64xf32>
-          %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) offset(%c0) alignment(32) : !flow.dispatch.tensor<writeonly:4x32x64xf32>
-          %workgroup_size_x = hal.interface.workgroup.size[0] : index
-          %workgroup_size_y = hal.interface.workgroup.size[1] : index
-          %workgroup_size_z = hal.interface.workgroup.size[2] : index
-          %workgroup_id_x = hal.interface.workgroup.id[0] : index
-          %workgroup_count_x = hal.interface.workgroup.count[0] : index
-          %workgroup_id_y = hal.interface.workgroup.id[1] : index
-          %workgroup_count_y = hal.interface.workgroup.count[1] : index
-          %workgroup_id_z = hal.interface.workgroup.id[2] : index
-          %workgroup_count_z = hal.interface.workgroup.count[2] : index
-          %3 = affine.apply #map0()[%workgroup_id_z, %workgroup_size_z]
-          %4 = affine.apply #map0()[%workgroup_count_z, %workgroup_size_z]
-          scf.for %arg0 = %3 to %c4 step %4 {
-            %5 = affine.apply #map0()[%workgroup_id_y, %workgroup_size_y]
-            %6 = affine.apply #map0()[%workgroup_count_y, %workgroup_size_y]
-            scf.for %arg1 = %5 to %c32 step %6 {
-              %7 = affine.apply #map0()[%workgroup_id_x, %workgroup_size_x]
-              %8 = affine.apply #map0()[%workgroup_count_x, %workgroup_size_x]
-              scf.for %arg2 = %7 to %c64 step %8 {
-                %9 = affine.min #map1(%arg0)[%workgroup_size_z]
-                %10 = affine.min #map2(%arg1)[%workgroup_size_y]
-                %11 = flow.dispatch.tensor.load %0, offsets = [%arg0, %arg1, 0], sizes = [%9, %10, 1024], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:4x32x1024xf32> -> tensor<?x?x1024xf32>
-                %12 = affine.min #map3(%arg2)[%workgroup_size_x]
-                %13 = flow.dispatch.tensor.load %1, offsets = [%arg0, 0, %arg2], sizes = [%9, 1024, %12], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:4x1024x64xf32> -> tensor<?x1024x?xf32>
-                %14 = affine.min #map4(%arg0)[%workgroup_size_z]
-                %15 = affine.min #map5(%arg1)[%workgroup_size_y]
-                %16 = affine.min #map6(%arg2)[%workgroup_size_x]
-                %17 = linalg.init_tensor [%14, %15, %16] : tensor<?x?x?xf32>
-                %18 = linalg.fill(%cst, %17) : f32, tensor<?x?x?xf32> -> tensor<?x?x?xf32>
-                %19 = linalg.batch_matmul ins(%11, %13 : tensor<?x?x1024xf32>, tensor<?x1024x?xf32>) outs(%18 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
-                flow.dispatch.tensor.store %19, %2, offsets = [%arg0, %arg1, %arg2], sizes = [%9, %10, %12], strides = [1, 1, 1] : tensor<?x?x?xf32> -> !flow.dispatch.tensor<writeonly:4x32x64xf32>
-              }
-            }
-          }
+          %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(32)
+              : !flow.dispatch.tensor<readonly:4x32x1024xf32>
+          %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(32)
+              : !flow.dispatch.tensor<readonly:4x1024x64xf32>
+          %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) offset(%c0) alignment(32)
+              : !flow.dispatch.tensor<writeonly:4x32x64xf32>
+          %11 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0], sizes = [4, 32, 1024], strides = [1, 1, 1]
+              : !flow.dispatch.tensor<readonly:4x32x1024xf32> -> tensor<4x32x1024xf32>
+          %13 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0], sizes = [4, 1024, 64], strides = [1, 1, 1]
+              : !flow.dispatch.tensor<readonly:4x1024x64xf32> -> tensor<4x1024x64xf32>
+          %17 = linalg.init_tensor [4, 32, 64] : tensor<4x32x64xf32>
+          %18 = linalg.fill(%cst, %17) : f32, tensor<4x32x64xf32> -> tensor<4x32x64xf32>
+          %19 = linalg.batch_matmul ins(%11, %13 : tensor<4x32x1024xf32>, tensor<4x1024x64xf32>)
+              outs(%18 : tensor<4x32x64xf32>) -> tensor<4x32x64xf32>
+          flow.dispatch.tensor.store %19, %2, offsets = [0, 0, 0], sizes = [4, 32, 64], strides = [1, 1, 1]
+              : tensor<4x32x64xf32> -> !flow.dispatch.tensor<writeonly:4x32x64xf32>
           return
         }
       }
diff --git a/iree/compiler/Codegen/LLVMGPU/test/rocdl_pipeline_test.mlir b/iree/compiler/Codegen/LLVMGPU/test/rocdl_pipeline_test.mlir
index 4f5d4de..ad61567 100644
--- a/iree/compiler/Codegen/LLVMGPU/test/rocdl_pipeline_test.mlir
+++ b/iree/compiler/Codegen/LLVMGPU/test/rocdl_pipeline_test.mlir
@@ -62,32 +62,16 @@
         %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : !flow.dispatch.tensor<readonly:1024x1024xf32>
         %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) : !flow.dispatch.tensor<readonly:1024x1024xf32>
         %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) : !flow.dispatch.tensor<writeonly:1024x1024xf32>
-        %workgroup_size_x = hal.interface.workgroup.size[0] : index
-        %workgroup_size_y = hal.interface.workgroup.size[1] : index
-        %workgroup_id_x = hal.interface.workgroup.id[0] : index
-        %workgroup_count_x = hal.interface.workgroup.count[0] : index
-        %workgroup_id_y = hal.interface.workgroup.id[1] : index
-        %workgroup_count_y = hal.interface.workgroup.count[1] : index
-        %3 = affine.apply #map0()[%workgroup_id_y, %workgroup_size_y]
-        %4 = affine.apply #map0()[%workgroup_count_y, %workgroup_size_y]
-        scf.for %arg0 = %3 to %c1024 step %4 {
-          %5 = affine.apply #map0()[%workgroup_id_x, %workgroup_size_x]
-          %6 = affine.apply #map0()[%workgroup_count_x, %workgroup_size_x]
-          scf.for %arg1 = %5 to %c1024 step %6 {
-            %7 = affine.min #map1(%arg0)[%workgroup_size_y]
-            %8 = flow.dispatch.tensor.load %0, offsets = [%arg0, %c0], sizes = [%7, %c1024], strides = [%c1, %c1] : !flow.dispatch.tensor<readonly:1024x1024xf32> -> tensor<?x1024xf32>
-            %9 = affine.min #map1(%arg1)[%workgroup_size_x]
-            %10 = flow.dispatch.tensor.load %1, offsets = [%c0, %arg1], sizes = [%c1024, %9], strides = [%c1, %c1] : !flow.dispatch.tensor<readonly:1024x1024xf32> -> tensor<1024x?xf32>
-            %11 = affine.min #map1(%arg0)[%workgroup_size_y]
-            %12 = affine.min #map1(%arg1)[%workgroup_size_x]
-            %13 = affine.min #map2(%arg0)[%workgroup_size_y]
-            %14 = affine.min #map2(%arg1)[%workgroup_size_x]
-            %15 = linalg.init_tensor [%13, %14] : tensor<?x?xf32>
-            %16 = linalg.fill(%cst, %15) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
-            %17 = linalg.matmul ins(%8, %10 : tensor<?x1024xf32>, tensor<1024x?xf32>) outs(%16 : tensor<?x?xf32>) -> tensor<?x?xf32>
-            flow.dispatch.tensor.store %17, %2, offsets = [%arg0, %arg1], sizes = [%11, %12], strides = [%c1, %c1] : tensor<?x?xf32> -> !flow.dispatch.tensor<writeonly:1024x1024xf32>
-          }
-        }
+        %8 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [1024, 1024], strides = [1, 1]
+            : !flow.dispatch.tensor<readonly:1024x1024xf32> -> tensor<1024x1024xf32>
+        %10 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [1024, 1024], strides = [1, 1]
+            : !flow.dispatch.tensor<readonly:1024x1024xf32> -> tensor<1024x1024xf32>
+        %15 = linalg.init_tensor [1024, 1024] : tensor<1024x1024xf32>
+        %16 = linalg.fill(%cst, %15) : f32, tensor<1024x1024xf32> -> tensor<1024x1024xf32>
+        %17 = linalg.matmul ins(%8, %10 : tensor<1024x1024xf32>, tensor<1024x1024xf32>)
+            outs(%16 : tensor<1024x1024xf32>) -> tensor<1024x1024xf32>
+        flow.dispatch.tensor.store %17, %2, offsets = [0, 0], sizes = [1024, 1024], strides = [1, 1]
+            : tensor<1024x1024xf32> -> !flow.dispatch.tensor<writeonly:1024x1024xf32>
         return
       }
     }
diff --git a/iree/compiler/Codegen/Passes.h b/iree/compiler/Codegen/Passes.h
index 24df586..d4867ca 100644
--- a/iree/compiler/Codegen/Passes.h
+++ b/iree/compiler/Codegen/Passes.h
@@ -371,8 +371,7 @@
 createSPIRVLowerExecutableTargetPass();
 
 /// Initializes CodeGen configuration for the given dispatch region.
-std::unique_ptr<OperationPass<IREE::HAL::ExecutableVariantOp>>
-createSPIRVInitConfigPass();
+std::unique_ptr<OperationPass<ModuleOp>> createSPIRVInitConfigPass();
 
 /// Pass to tile and distribute Linalg ops with buffer semantics to invocations.
 std::unique_ptr<OperationPass<FuncOp>> createSPIRVTileAndDistributePass();
diff --git a/iree/compiler/Codegen/Passes.td b/iree/compiler/Codegen/Passes.td
index 25a8682..5793bcb 100644
--- a/iree/compiler/Codegen/Passes.td
+++ b/iree/compiler/Codegen/Passes.td
@@ -306,8 +306,7 @@
 }
 
 def SPIRVInitConfig :
-    Pass<"iree-spirv-init-config-pass",
-         "mlir::iree_compiler::IREE::HAL::ExecutableVariantOp"> {
+    Pass<"iree-spirv-init-config-pass", "ModuleOp"> {
   let summary = "Initialize CodeGen configuration for a given dispatch region";
   let constructor = "mlir::iree_compiler::createSPIRVInitConfigPass()";
 }
diff --git a/iree/compiler/Codegen/SPIRV/KernelConfig.cpp b/iree/compiler/Codegen/SPIRV/KernelConfig.cpp
index 14fb49d..9567d3f 100644
--- a/iree/compiler/Codegen/SPIRV/KernelConfig.cpp
+++ b/iree/compiler/Codegen/SPIRV/KernelConfig.cpp
@@ -36,8 +36,16 @@
 LogicalResult setConvOpConfig(linalg::LinalgOp linalgOp,
                               const int64_t subgroupSize,
                               const int64_t bestTilingFactor) {
-  ArrayRef<int64_t> inputShape = getUntiledShape(linalgOp.inputs()[0]);
-  SmallVector<int64_t> outputShape = getUntiledResultShape(linalgOp, 0);
+  ArrayRef<int64_t> inputShape = linalgOp.getInputOperand(0)
+                                     ->get()
+                                     .getType()
+                                     .cast<ShapedType>()
+                                     .getShape();
+  ArrayRef<int64_t> outputShape = linalgOp.getOutputOperand(0)
+                                      ->get()
+                                      .getType()
+                                      .cast<ShapedType>()
+                                      .getShape();
   if (llvm::any_of(inputShape, ShapedType::isDynamic)) return success();
   if (llvm::any_of(outputShape, ShapedType::isDynamic)) return success();
 
@@ -163,8 +171,10 @@
   auto elementBits = lhsType.getElementType().getIntOrFloatBitWidth();
   if (elementBits != 16 && elementBits != 32) return success();
 
-  ArrayRef<int64_t> lhsShape = getUntiledShape(op.inputs()[0]);
-  ArrayRef<int64_t> rhsShape = getUntiledShape(op.inputs()[1]);
+  ArrayRef<int64_t> lhsShape =
+      op.getInputOperand(0)->get().getType().cast<ShapedType>().getShape();
+  ArrayRef<int64_t> rhsShape =
+      op.getInputOperand(1)->get().getType().cast<ShapedType>().getShape();
   if (llvm::any_of(lhsShape, ShapedType::isDynamic)) return success();
   if (llvm::any_of(rhsShape, ShapedType::isDynamic)) return success();
 
@@ -364,27 +374,6 @@
   // 1) distributing to as many threads as possible, and 2) avoid assigning too
   // many threads to handle out-of-bound elements (thus idle).
 
-  SmallVector<LoopTilingAndDistributionInfo> tiledLoopInfo =
-      getTiledAndDistributedLoopInfo(funcOp);
-  // The number of linalg implicit loops to partition and tiled loops
-  // surrounding the op should match. Otherwise, something is incorrect.
-  assert(partitionedLoops.size() == tiledLoopInfo.size());
-
-  // The upper bound for each implicit loop: 0 - untiled, negative - dynamic.
-  SmallVector<int64_t> loopBounds(loopDepth, 0);
-  // tiledLoopInfo uses the reverse order of partitionedLoops.
-  for (auto pair : llvm::zip(llvm::reverse(partitionedLoops), tiledLoopInfo)) {
-    unsigned loopIndex = std::get<0>(pair);
-    const LoopTilingAndDistributionInfo &loopInfo = std::get<1>(pair);
-    Optional<int64_t> attrValue =
-        getConstantIntValue(loopInfo.untiledUpperBound);
-    if (attrValue) {
-      loopBounds[loopIndex] = *attrValue;
-    } else {
-      loopBounds[loopIndex] = ShapedType::kDynamicSize;
-    }
-  }
-
   // Returns true if the given `operand` has 32-bit element type.
   auto has32BitElementType = [](Value operand) {
     auto shapedType = operand.getType().dyn_cast<ShapedType>();
@@ -394,7 +383,7 @@
   };
 
   // Whether we can try to use the vectorization pipeline.
-  auto untiledResultShape = getUntiledResultShape(linalgOp, 0);
+  Optional<SmallVector<int64_t, 4>> loopBounds = linalgOp.getStaticLoopRanges();
   bool vectorizable =
       !linalgOp.hasIndexSemantics() &&
       // Skip vectorization for non-minor identity inputs as it generates
@@ -407,16 +396,7 @@
       // TODO: Lowering of integers other than i32 may require emulation.
       // This is currently not supported for vector operation.
       llvm::all_of(linalgOp->getOperands(), has32BitElementType) &&
-      !untiledResultShape.empty() &&
-      llvm::none_of(untiledResultShape, ShapedType::isDynamic);
-
-  LLVM_DEBUG({
-    llvm::dbgs() << "Linalg op " << linalgOp << "\n  partitioned loops: [";
-    llvm::interleaveComma(partitionedLoops, llvm::dbgs());
-    llvm::dbgs() << "]\n  loop bounds: [";
-    llvm::interleaveComma(loopBounds, llvm::dbgs());
-    llvm::dbgs() << "]\n";
-  });
+      loopBounds && llvm::none_of(loopBounds.getValue(), ShapedType::isDynamic);
 
  // Distribute workload to the given `numThreads` by allowing a potential loss.
   auto distributeToThreads = [&](int64_t numThreads,
@@ -426,19 +406,11 @@
 
     // Scan from the innermost shape dimension and try to deduce the
     // configuration for the corresponding GPU workgroup dimension.
-    for (auto p : llvm::zip(llvm::reverse(partitionedLoops), tiledLoopInfo)) {
-      int shapeDim = std::get<0>(p);
-      int wgDim = std::get<1>(p).processorDistributionDim;
-      LLVM_DEBUG({
-        llvm::dbgs() << "Remaining threads: " << numThreads << "\n";
-        llvm::dbgs() << "Shape dim #" << shapeDim << "=";
-        llvm::dbgs() << loopBounds[shapeDim] << "\n"
-                     << "Workgroup dim #" << wgDim << "\n";
-      });
-
+    int64_t wgDim = 0;
+    for (auto shapeDim : llvm::reverse(partitionedLoops)) {
       // Skip untiled or dynamic dimensions.
       // TODO: Skip size-1 dimensions in Flow level tiling and distribution.
-      if (loopBounds[shapeDim] <= 0) continue;
+      if (loopBounds.getValue()[shapeDim] <= 0) continue;
 
       // Try to find some power of two that can divide the current shape dim
       // size. This vector keeps the candidate tile sizes.
@@ -460,10 +432,11 @@
       });
 
       for (int64_t candidate : candidates) {
-        if (loopBounds[shapeDim] % candidate != 0) {
+        if (loopBounds.getValue()[shapeDim] % candidate != 0) {
           if (!lossFactor) continue;
           // Skip this candidate if it causes many threads to be idle.
-          int64_t idleThreads = candidate - (loopBounds[shapeDim] % candidate);
+          int64_t idleThreads =
+              candidate - (loopBounds.getValue()[shapeDim] % candidate);
           if (idleThreads > candidate / *lossFactor) continue;
         }
         LLVM_DEBUG(llvm::dbgs() << "Chosen Candidate " << candidate << "\n");
@@ -473,13 +446,13 @@
         workgroupTileSizes[shapeDim] = candidate;
         if (vectorizable && wgDim == 0 && !lossFactor && candidate % 4 == 0) {
           threadTileSizes[shapeDim] = 4;
-          workgroupSize[wgDim++] = candidate / 4;
+          workgroupSize[wgDim] = candidate / 4;
           assert(numThreads % (candidate / 4) == 0);
           numThreads /= candidate / 4;
         } else {
           if (wgDim == 0) vectorizable = false;
           threadTileSizes[shapeDim] = 1;
-          workgroupSize[wgDim++] = candidate;
+          workgroupSize[wgDim] = candidate;
           assert(numThreads % candidate == 0);
           numThreads /= candidate;
         }
@@ -489,6 +462,7 @@
 
       // Stop if we have distributed all threads.
       if (numThreads == 1) break;
+      wgDim++;
     }
     return numThreads;
   };
@@ -617,6 +591,11 @@
       return funcOp.emitOpError("failed to get compute ops");
     }
 
+    if (computeOps.empty()) {
+      return funcOp.emitOpError(
+          "unhandled translation of function without compute ops");
+    }
+
     Operation *rootOperation = nullptr;
     // Try to find a configuration according to a matmul/convolution op and use
     // it as the root op.
@@ -635,7 +614,27 @@
 
     // If there are still no root op, check for any linalg.generic op.
     if (!rootOperation) {
-      for (Operation *computeOp : reverse(computeOps)) {
+      Operation *computeOp = computeOps.back();
+
+      // Handle the case of the compute op being a
+      // `tensor.extract_slice`/`tensor.insert_slice`. That needs bufferization
+      // to run before the configuration can be set. Just set the translation
+      // to use the `SPIRVDistributeCopy` pipeline. The configuration will be
+      // set again after bufferization.
+      //
+      // TODO(ravishankarm): This is awkward.
+      // `tensor.extract_slice`/`tensor.insert_slice` will be dropped from
+      // `TiledOpInterface` soon, and will not be compute ops. At that time, they
+      // will be folded with `flow.tensor.load` and `flow.tensor.store`
+      // operations. Then this case will degenerate to having no compute ops.
+      // Rework this at that stage to run bufferization early.
+      if (isa<tensor::ExtractSliceOp, tensor::InsertSliceOp>(computeOp)) {
+        setTranslationInfo(
+            funcOp,
+            IREE::Codegen::DispatchLoweringPassPipeline::SPIRVDistributeCopy,
+            /*workloadPerWorkgroup=*/ArrayRef<int64_t>{},
+            /*workgroupSize=*/ArrayRef<int64_t>{});
+      } else {
         if (failed(setDefaultOpConfig(limits, computeOp))) return failure();
 
         // Check if the op configuration was set.
@@ -644,33 +643,8 @@
               "without known roots, the last compute operation in the tiled "
               "loop body is expected to be set as root");
         }
-        rootOperation = computeOp;
-        break;
       }
-    }
-
-    if (!rootOperation) {
-      // If the tiled loops are not empty then this could be a corner case of
-      // tensor.insert_slice being tiled and distributed, that just shows up as
-      // a `flow.dispatch.tensor.load` and a `flow.dispatch.tensor.store` (or as
-      // a copy). For now just treat the tiled loops not being empty as an
-      // indicator of that. Need a better way of information flow from flow
-      // dialect to hal.
-      if (!tiledLoops.empty()) {
-        // These configuration parameters will be overwritten by the
-        // SPIRVDistributeCopy pipeline later.
-        const int64_t subgroupSize =
-            limits.subgroup_size().getValue().getSExtValue();
-        std::array<int64_t, 3> workgroupSize = {subgroupSize, 1, 1};
-        SmallVector<int64_t> workloadPerWorkgroup(tiledLoops.size(), 1);
-        workloadPerWorkgroup.front() = subgroupSize;
-        setTranslationInfo(
-            funcOp,
-            IREE::Codegen::DispatchLoweringPassPipeline::SPIRVDistributeCopy,
-            workloadPerWorkgroup, workgroupSize);
-        return success();
-      }
-      return funcOp.emitError("contains no root Linalg operation");
+      rootOperation = computeOp;
     }
 
     // Propagate the `lowering.config` attribute to the other ops.
diff --git a/iree/compiler/Codegen/SPIRV/NVIDIAConfig.cpp b/iree/compiler/Codegen/SPIRV/NVIDIAConfig.cpp
index 6166a94..7a245ba 100644
--- a/iree/compiler/Codegen/SPIRV/NVIDIAConfig.cpp
+++ b/iree/compiler/Codegen/SPIRV/NVIDIAConfig.cpp
@@ -63,8 +63,8 @@
 
   Value lhs = op.inputs()[0], rhs = op.inputs()[1], init = op.outputs()[0];
 
-  ArrayRef<int64_t> lhsShape = getUntiledShape(lhs);
-  ArrayRef<int64_t> rhsShape = getUntiledShape(rhs);
+  ArrayRef<int64_t> lhsShape = lhs.getType().cast<ShapedType>().getShape();
+  ArrayRef<int64_t> rhsShape = rhs.getType().cast<ShapedType>().getShape();
   if (llvm::any_of(lhsShape, ShapedType::isDynamic)) return success();
   if (llvm::any_of(rhsShape, ShapedType::isDynamic)) return success();
 
diff --git a/iree/compiler/Codegen/SPIRV/Passes.cpp b/iree/compiler/Codegen/SPIRV/Passes.cpp
index 80bb6d2..f822e22 100644
--- a/iree/compiler/Codegen/SPIRV/Passes.cpp
+++ b/iree/compiler/Codegen/SPIRV/Passes.cpp
@@ -135,6 +135,10 @@
 //===----------------------------------------------------------------------===//
 
 void addSPIRVTileAndVectorizePassPipeline(OpPassManager &pm) {
+  pm.addNestedPass<FuncOp>(createTileAndDistributeToWorkgroupsPass());
+  pm.addPass(createCanonicalizerPass());
+  pm.addPass(createCSEPass());
+
   pm.addPass(createCanonicalizerPass());
   pm.addPass(createCSEPass());
 
@@ -156,6 +160,10 @@
 }
 
 void addSPIRVTileAndVectorizeToCooperativeOpsPassPipeline(OpPassManager &pm) {
+  pm.addNestedPass<FuncOp>(createTileAndDistributeToWorkgroupsPass());
+  pm.addPass(createCanonicalizerPass());
+  pm.addPass(createCSEPass());
+
   addLinalgBufferizePasses(pm, gpuAllocationFunction);
 
   pm.addPass(createCanonicalizerPass());
@@ -178,6 +186,10 @@
 }
 
 void addSPIRVTileAndDistributePassPipeline(OpPassManager &pm) {
+  pm.addNestedPass<FuncOp>(createTileAndDistributeToWorkgroupsPass());
+  pm.addPass(createCanonicalizerPass());
+  pm.addPass(createCSEPass());
+
   addLinalgBufferizePasses(pm, gpuAllocationFunction);
 
   pm.addPass(createCanonicalizerPass());
@@ -205,25 +217,19 @@
 // still perform bufferization first to expose a linalg.copy op, from which we
 // can deduce the configuration.
 void addSPIRVTileAndDistributeCopyPassPipeline(OpPassManager &pm) {
-  addLinalgBufferizePasses(pm.nest<ModuleOp>(), gpuAllocationFunction);
-
-  // Rerun CodeGen configuration deduction after bufferization. This enables
-  // us to find a better configuration for linalg.copy ops and attach the
-  // `lowering.config` attribute properly to drive transformations.
+  addLinalgBufferizePasses(pm, gpuAllocationFunction);
   pm.addPass(createSPIRVInitConfigPass());
-  pm.addPass(createSetNumWorkgroupsPass());
 
-  OpPassManager &modulePM = pm.nest<ModuleOp>();
-
-  modulePM.addPass(createCanonicalizerPass());
-  modulePM.addPass(createCSEPass());
+  pm.addNestedPass<FuncOp>(createTileAndDistributeToWorkgroupsPass());
+  pm.addPass(createCanonicalizerPass());
+  pm.addPass(createCSEPass());
 
   // Tile and distribute to GPU invocations.
-  modulePM.addNestedPass<FuncOp>(createSPIRVTileAndDistributePass());
-  modulePM.addPass(createCanonicalizerPass());
-  modulePM.addPass(createCSEPass());
+  pm.addNestedPass<FuncOp>(createSPIRVTileAndDistributePass());
+  pm.addPass(createCanonicalizerPass());
+  pm.addPass(createCSEPass());
 
-  addLoopMaterializationPasses(modulePM);
+  addLoopMaterializationPasses(pm);
 }
 
 //===----------------------------------------------------------------------===//
@@ -232,10 +238,7 @@
 
 void buildSPIRVCodegenPassPipeline(OpPassManager &pm) {
   pm.nest<ModuleOp>().nest<FuncOp>().addPass(createTypePropagationPass());
-  pm.nest<ModuleOp>().nest<FuncOp>().addPass(
-      createTileAndDistributeToWorkgroupsPass());
-  pm.addPass(createCanonicalizerPass());
-  pm.addPass(createCSEPass());
+
   pm.addPass(createSPIRVLowerExecutableTargetPass());
   addMemRefLoweringPasses(pm.nest<ModuleOp>());
   addSPIRVLoweringPasses(pm.nest<ModuleOp>());
diff --git a/iree/compiler/Codegen/SPIRV/SPIRVInitConfigPass.cpp b/iree/compiler/Codegen/SPIRV/SPIRVInitConfigPass.cpp
index da57d08..c8b77cc 100644
--- a/iree/compiler/Codegen/SPIRV/SPIRVInitConfigPass.cpp
+++ b/iree/compiler/Codegen/SPIRV/SPIRVInitConfigPass.cpp
@@ -29,16 +29,14 @@
   }
 
   void runOnOperation() override {
-    IREE::HAL::ExecutableVariantOp variantOp = getOperation();
-    ModuleOp moduleOp = variantOp.getInnerModule();
+    ModuleOp moduleOp = getOperation();
     if (failed(initSPIRVLaunchConfig(moduleOp))) return signalPassFailure();
   }
 };
 
 }  // namespace
 
-std::unique_ptr<OperationPass<IREE::HAL::ExecutableVariantOp>>
-createSPIRVInitConfigPass() {
+std::unique_ptr<OperationPass<ModuleOp>> createSPIRVInitConfigPass() {
   return std::make_unique<SPIRVInitConfigPass>();
 }
 
diff --git a/iree/compiler/Codegen/SPIRV/SPIRVLowerExecutableTargetPass.cpp b/iree/compiler/Codegen/SPIRV/SPIRVLowerExecutableTargetPass.cpp
index 55a7940..a274609 100644
--- a/iree/compiler/Codegen/SPIRV/SPIRVLowerExecutableTargetPass.cpp
+++ b/iree/compiler/Codegen/SPIRV/SPIRVLowerExecutableTargetPass.cpp
@@ -93,13 +93,6 @@
     }
   }
 
-  if (*passPipeline !=
-      IREE::Codegen::DispatchLoweringPassPipeline::SPIRVDistributeCopy) {
-    // SPIRVDistributeCopy handles these passes by itself.
-    executableLoweringPipeline.addPass(createSetNumWorkgroupsPass());
-    executableLoweringPipeline.addPass(createCanonicalizerPass());
-  }
-
   if (!testLoweringConfiguration && passPipeline.hasValue()) {
     OpPassManager &nestedModulePM = executableLoweringPipeline.nest<ModuleOp>();
     switch (*passPipeline) {
@@ -107,7 +100,7 @@
         addSPIRVTileAndDistributePassPipeline(nestedModulePM);
         break;
       case IREE::Codegen::DispatchLoweringPassPipeline::SPIRVDistributeCopy:
-        addSPIRVTileAndDistributeCopyPassPipeline(executableLoweringPipeline);
+        addSPIRVTileAndDistributeCopyPassPipeline(nestedModulePM);
         break;
       case IREE::Codegen::DispatchLoweringPassPipeline::SPIRVVectorize:
         addSPIRVTileAndVectorizePassPipeline(nestedModulePM);
diff --git a/iree/compiler/Codegen/SPIRV/test/config_adreno_conv.mlir b/iree/compiler/Codegen/SPIRV/test/config_adreno_conv.mlir
index 29ddec1..7c28ecf 100644
--- a/iree/compiler/Codegen/SPIRV/test/config_adreno_conv.mlir
+++ b/iree/compiler/Codegen/SPIRV/test/config_adreno_conv.mlir
@@ -27,44 +27,16 @@
         %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : !flow.dispatch.tensor<readonly:1x225x225x3xf32>
         %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) : !flow.dispatch.tensor<readonly:3x3x3x512xf32>
         %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) : !flow.dispatch.tensor<writeonly:1x112x112x512xf32>
-        %workgroup_size_x = hal.interface.workgroup.size[0] : index
-        %workgroup_size_y = hal.interface.workgroup.size[1] : index
-        %workgroup_size_z = hal.interface.workgroup.size[2] : index
-        %workgroup_id_x = hal.interface.workgroup.id[0] : index
-        %workgroup_count_x = hal.interface.workgroup.count[0] : index
-        %workgroup_id_y = hal.interface.workgroup.id[1] : index
-        %workgroup_count_y = hal.interface.workgroup.count[1] : index
-        %workgroup_id_z = hal.interface.workgroup.id[2] : index
-        %workgroup_count_z = hal.interface.workgroup.count[2] : index
-        %3 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_z, %workgroup_size_z]
-        %4 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_z, %workgroup_size_z]
-        scf.for %arg0 = %3 to %c112 step %4 {
-          %5 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_y, %workgroup_size_y]
-          %6 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_y, %workgroup_size_y]
-          scf.for %arg1 = %5 to %c112 step %6 {
-            %7 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_x, %workgroup_size_x]
-            %8 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_x, %workgroup_size_x]
-            scf.for %arg2 = %7 to %c512 step %8 {
-              %9 = affine.apply affine_map<(d0) -> (d0 * 2)>(%arg0)
-              %10 = affine.min affine_map<(d0)[s0] -> (s0 * 2 + 1, d0 * -2 + 225)>(%arg0)[%workgroup_size_z]
-              %11 = affine.apply affine_map<(d0) -> (d0 * 2)>(%arg1)
-              %12 = affine.min affine_map<(d0)[s0] -> (s0 * 2 + 1, d0 * -2 + 225)>(%arg1)[%workgroup_size_y]
-              %13 = flow.dispatch.tensor.load %0, offsets = [0, %9, %11, 0], sizes = [1, %10, %12, 3], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:1x225x225x3xf32> -> tensor<1x?x?x3xf32>
-              %14 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 512)>(%arg2)[%workgroup_size_x]
-              %15 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0, %arg2], sizes = [3, 3, 3, %14], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:3x3x3x512xf32> -> tensor<3x3x3x?xf32>
-              %16 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 112)>(%arg0)[%workgroup_size_z]
-              %17 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 112)>(%arg1)[%workgroup_size_y]
-              %18 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 512)>(%arg2)[%workgroup_size_x]
-              %19 = affine.min affine_map<(d0)[s0] -> (-d0 + 112, s0)>(%arg0)[%workgroup_size_z]
-              %20 = affine.min affine_map<(d0)[s0] -> (-d0 + 112, s0)>(%arg1)[%workgroup_size_y]
-              %21 = affine.min affine_map<(d0)[s0] -> (-d0 + 512, s0)>(%arg2)[%workgroup_size_x]
-              %22 = linalg.init_tensor [1, %19, %20, %21] : tensor<1x?x?x?xf32>
-              %23 = linalg.fill(%cst, %22) : f32, tensor<1x?x?x?xf32> -> tensor<1x?x?x?xf32>
-              %24 = linalg.conv_2d_nhwc_hwcf {__internal_linalg_transform__ = "workgroup", dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%13, %15 : tensor<1x?x?x3xf32>, tensor<3x3x3x?xf32>) outs(%23 : tensor<1x?x?x?xf32>) -> tensor<1x?x?x?xf32>
-              flow.dispatch.tensor.store %24, %2, offsets = [0, %arg0, %arg1, %arg2], sizes = [1, %16, %17, %18], strides = [1, 1, 1, 1] : tensor<1x?x?x?xf32> -> !flow.dispatch.tensor<writeonly:1x112x112x512xf32>
-            }
-          }
-        }
+        %13 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0, 0], sizes = [1, 225, 225, 3], strides = [1, 1, 1, 1]
+            : !flow.dispatch.tensor<readonly:1x225x225x3xf32> -> tensor<1x225x225x3xf32>
+        %15 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0, 0], sizes = [3, 3, 3, 512], strides = [1, 1, 1, 1]
+            : !flow.dispatch.tensor<readonly:3x3x3x512xf32> -> tensor<3x3x3x512xf32>
+        %22 = linalg.init_tensor [1, 112, 112, 512] : tensor<1x112x112x512xf32>
+        %23 = linalg.fill(%cst, %22) : f32, tensor<1x112x112x512xf32> -> tensor<1x112x112x512xf32>
+        %24 = linalg.conv_2d_nhwc_hwcf {__internal_linalg_transform__ = "workgroup", dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>}
+            ins(%13, %15 : tensor<1x225x225x3xf32>, tensor<3x3x3x512xf32>) outs(%23 : tensor<1x112x112x512xf32>) -> tensor<1x112x112x512xf32>
+        flow.dispatch.tensor.store %24, %2, offsets = [0, 0, 0, 0], sizes = [1, 112, 112, 512], strides = [1, 1, 1, 1]
+            : tensor<1x112x112x512xf32> -> !flow.dispatch.tensor<writeonly:1x112x112x512xf32>
         return
       }
     }
@@ -72,17 +44,10 @@
 }
 
 //  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering.config<tile_sizes = {{\[}}[0, 1, 8, 256], [0, 1, 8, 4], [0, 0, 0, 0, 1, 1, 4]{{\]}}, native_vector_size = []>
-//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"SPIRVVectorize", workload_per_wg = [256, 8, 1]>
-//  CHECK-DAG: #[[MAP_X:.+]] = affine_map<()[s0] -> (s0 ceildiv 256)
-//  CHECK-DAG: #[[MAP_Y:.+]] = affine_map<()[s0] -> (s0 ceildiv 8)>
+//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"SPIRVVectorize", workload_per_wg = []>
 //      CHECK: hal.executable.entry_point public @conv_112x112x512
 // CHECK-SAME:   translation.info = #[[TRANSLATION]]
 // CHECK-SAME:   workgroup_size = [64 : index, 1 : index, 1 : index]
-// CHECK-NEXT: ^{{.+}}(%[[X:.+]]: index, %[[Y:.+]]: index, %[[Z:.+]]: index):
-// CHECK-NEXT:   %[[X_COUNT:.+]] = affine.apply #[[MAP_X]]()[%[[X]]]
-// CHECK-NEXT:   %[[Y_COUNT:.+]] = affine.apply #[[MAP_Y]]()[%[[Y]]]
-// CHECK-NEXT:   hal.return %[[X_COUNT]], %[[Y_COUNT]], %[[Z]]
-
 //      CHECK: func @conv_112x112x512()
 //      CHECK:   linalg.conv_2d_nhwc_hwcf
 // CHECK-SAME:     lowering.config = #[[CONFIG]]
@@ -116,44 +81,16 @@
         %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : !flow.dispatch.tensor<readonly:1x225x225x3xf32>
         %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) : !flow.dispatch.tensor<readonly:3x3x3x32xf32>
         %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) : !flow.dispatch.tensor<writeonly:1x112x112x32xf32>
-        %workgroup_size_x = hal.interface.workgroup.size[0] : index
-        %workgroup_size_y = hal.interface.workgroup.size[1] : index
-        %workgroup_size_z = hal.interface.workgroup.size[2] : index
-        %workgroup_id_x = hal.interface.workgroup.id[0] : index
-        %workgroup_count_x = hal.interface.workgroup.count[0] : index
-        %workgroup_id_y = hal.interface.workgroup.id[1] : index
-        %workgroup_count_y = hal.interface.workgroup.count[1] : index
-        %workgroup_id_z = hal.interface.workgroup.id[2] : index
-        %workgroup_count_z = hal.interface.workgroup.count[2] : index
-        %3 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_z, %workgroup_size_z]
-        %4 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_z, %workgroup_size_z]
-        scf.for %arg0 = %3 to %c112 step %4 {
-          %5 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_y, %workgroup_size_y]
-          %6 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_y, %workgroup_size_y]
-          scf.for %arg1 = %5 to %c112 step %6 {
-            %7 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_x, %workgroup_size_x]
-            %8 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_x, %workgroup_size_x]
-            scf.for %arg2 = %7 to %c32 step %8 {
-              %9 = affine.apply affine_map<(d0) -> (d0 * 2)>(%arg0)
-              %10 = affine.min affine_map<(d0)[s0] -> (s0 * 2 + 1, d0 * -2 + 225)>(%arg0)[%workgroup_size_z]
-              %11 = affine.apply affine_map<(d0) -> (d0 * 2)>(%arg1)
-              %12 = affine.min affine_map<(d0)[s0] -> (s0 * 2 + 1, d0 * -2 + 225)>(%arg1)[%workgroup_size_y]
-              %13 = flow.dispatch.tensor.load %0, offsets = [0, %9, %11, 0], sizes = [1, %10, %12, 3], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:1x225x225x3xf32> -> tensor<1x?x?x3xf32>
-              %14 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 32)>(%arg2)[%workgroup_size_x]
-              %15 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0, %arg2], sizes = [3, 3, 3, %14], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:3x3x3x32xf32> -> tensor<3x3x3x?xf32>
-              %16 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 112)>(%arg0)[%workgroup_size_z]
-              %17 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 112)>(%arg1)[%workgroup_size_y]
-              %18 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 32)>(%arg2)[%workgroup_size_x]
-              %19 = affine.min affine_map<(d0)[s0] -> (-d0 + 112, s0)>(%arg0)[%workgroup_size_z]
-              %20 = affine.min affine_map<(d0)[s0] -> (-d0 + 112, s0)>(%arg1)[%workgroup_size_y]
-              %21 = affine.min affine_map<(d0)[s0] -> (-d0 + 32, s0)>(%arg2)[%workgroup_size_x]
-              %22 = linalg.init_tensor [1, %19, %20, %21] : tensor<1x?x?x?xf32>
-              %23 = linalg.fill(%cst, %22) : f32, tensor<1x?x?x?xf32> -> tensor<1x?x?x?xf32>
-              %24 = linalg.conv_2d_nhwc_hwcf {__internal_linalg_transform__ = "workgroup", dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%13, %15 : tensor<1x?x?x3xf32>, tensor<3x3x3x?xf32>) outs(%23 : tensor<1x?x?x?xf32>) -> tensor<1x?x?x?xf32>
-              flow.dispatch.tensor.store %24, %2, offsets = [0, %arg0, %arg1, %arg2], sizes = [1, %16, %17, %18], strides = [1, 1, 1, 1] : tensor<1x?x?x?xf32> -> !flow.dispatch.tensor<writeonly:1x112x112x32xf32>
-            }
-          }
-        }
+        %13 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0, 0], sizes = [1, 225, 225, 3], strides = [1, 1, 1, 1]
+            : !flow.dispatch.tensor<readonly:1x225x225x3xf32> -> tensor<1x225x225x3xf32>
+        %15 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0, 0], sizes = [3, 3, 3, 32], strides = [1, 1, 1, 1]
+            : !flow.dispatch.tensor<readonly:3x3x3x32xf32> -> tensor<3x3x3x32xf32>
+        %22 = linalg.init_tensor [1, 112, 112, 32] : tensor<1x112x112x32xf32>
+        %23 = linalg.fill(%cst, %22) : f32, tensor<1x112x112x32xf32> -> tensor<1x112x112x32xf32>
+        %24 = linalg.conv_2d_nhwc_hwcf {__internal_linalg_transform__ = "workgroup", dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>}
+            ins(%13, %15 : tensor<1x225x225x3xf32>, tensor<3x3x3x32xf32>) outs(%23 : tensor<1x112x112x32xf32>) -> tensor<1x112x112x32xf32>
+        flow.dispatch.tensor.store %24, %2, offsets = [0, 0, 0, 0], sizes = [1, 112, 112, 32], strides = [1, 1, 1, 1]
+            : tensor<1x112x112x32xf32> -> !flow.dispatch.tensor<writeonly:1x112x112x32xf32>
         return
       }
     }
@@ -161,19 +98,10 @@
 }
 
 //  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering.config<tile_sizes = {{\[}}[0, 4, 16, 32], [0, 4, 2, 4], [0, 0, 0, 0, 1, 1, 4]{{\]}}, native_vector_size = []>
-//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"SPIRVVectorize", workload_per_wg = [32, 16, 4]>
-//  CHECK-DAG: #[[MAP_X:.+]] = affine_map<()[s0] -> (s0 ceildiv 32)
-//  CHECK-DAG: #[[MAP_Y:.+]] = affine_map<()[s0] -> (s0 ceildiv 16)>
-//  CHECK-DAG: #[[MAP_Z:.+]] = affine_map<()[s0] -> (s0 ceildiv 4)>
+//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"SPIRVVectorize", workload_per_wg = []>
 //      CHECK: hal.executable.entry_point public @conv_112x112x32
 // CHECK-SAME:   translation.info = #[[TRANSLATION]]
 // CHECK-SAME:   workgroup_size = [8 : index, 8 : index, 1 : index]
-// CHECK-NEXT: ^{{.+}}(%[[X:.+]]: index, %[[Y:.+]]: index, %[[Z:.+]]: index):
-// CHECK-NEXT:   %[[X_COUNT:.+]] = affine.apply #[[MAP_X]]()[%[[X]]]
-// CHECK-NEXT:   %[[Y_COUNT:.+]] = affine.apply #[[MAP_Y]]()[%[[Y]]]
-// CHECK-NEXT:   %[[Z_COUNT:.+]] = affine.apply #[[MAP_Z]]()[%[[Z]]]
-// CHECK-NEXT:   hal.return %[[X_COUNT]], %[[Y_COUNT]], %[[Z_COUNT]]
-
 //      CHECK: func @conv_112x112x32()
 //      CHECK:   linalg.conv_2d_nhwc_hwcf
 // CHECK-SAME:     lowering.config = #[[CONFIG]]
@@ -206,44 +134,16 @@
         %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : !flow.dispatch.tensor<readonly:1x33x33x3xf32>
         %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) : !flow.dispatch.tensor<readonly:3x3x3x16xf32>
         %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) : !flow.dispatch.tensor<writeonly:1x16x16x16xf32>
-        %workgroup_size_x = hal.interface.workgroup.size[0] : index
-        %workgroup_size_y = hal.interface.workgroup.size[1] : index
-        %workgroup_size_z = hal.interface.workgroup.size[2] : index
-        %workgroup_id_x = hal.interface.workgroup.id[0] : index
-        %workgroup_count_x = hal.interface.workgroup.count[0] : index
-        %workgroup_id_y = hal.interface.workgroup.id[1] : index
-        %workgroup_count_y = hal.interface.workgroup.count[1] : index
-        %workgroup_id_z = hal.interface.workgroup.id[2] : index
-        %workgroup_count_z = hal.interface.workgroup.count[2] : index
-        %3 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_z, %workgroup_size_z]
-        %4 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_z, %workgroup_size_z]
-        scf.for %arg0 = %3 to %c16 step %4 {
-          %5 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_y, %workgroup_size_y]
-          %6 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_y, %workgroup_size_y]
-          scf.for %arg1 = %5 to %c16 step %6 {
-            %7 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_x, %workgroup_size_x]
-            %8 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_x, %workgroup_size_x]
-            scf.for %arg2 = %7 to %c16 step %8 {
-              %9 = affine.apply affine_map<(d0) -> (d0 * 2)>(%arg0)
-              %10 = affine.min affine_map<(d0)[s0] -> (s0 * 2 + 1, d0 * -2 + 33)>(%arg0)[%workgroup_size_z]
-              %11 = affine.apply affine_map<(d0) -> (d0 * 2)>(%arg1)
-              %12 = affine.min affine_map<(d0)[s0] -> (s0 * 2 + 1, d0 * -2 + 33)>(%arg1)[%workgroup_size_y]
-              %13 = flow.dispatch.tensor.load %0, offsets = [0, %9, %11, 0], sizes = [1, %10, %12, 3], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:1x33x33x3xf32> -> tensor<1x?x?x3xf32>
-              %14 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 16)>(%arg2)[%workgroup_size_x]
-              %15 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0, %arg2], sizes = [3, 3, 3, %14], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:3x3x3x16xf32> -> tensor<3x3x3x?xf32>
-              %16 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 16)>(%arg0)[%workgroup_size_z]
-              %17 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 16)>(%arg1)[%workgroup_size_y]
-              %18 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 16)>(%arg2)[%workgroup_size_x]
-              %19 = affine.min affine_map<(d0)[s0] -> (-d0 + 16, s0)>(%arg0)[%workgroup_size_z]
-              %20 = affine.min affine_map<(d0)[s0] -> (-d0 + 16, s0)>(%arg1)[%workgroup_size_y]
-              %21 = affine.min affine_map<(d0)[s0] -> (-d0 + 16, s0)>(%arg2)[%workgroup_size_x]
-              %22 = linalg.init_tensor [1, %19, %20, %21] : tensor<1x?x?x?xf32>
-              %23 = linalg.fill(%cst, %22) : f32, tensor<1x?x?x?xf32> -> tensor<1x?x?x?xf32>
-              %24 = linalg.conv_2d_nhwc_hwcf {__internal_linalg_transform__ = "workgroup", dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%13, %15 : tensor<1x?x?x3xf32>, tensor<3x3x3x?xf32>) outs(%23 : tensor<1x?x?x?xf32>) -> tensor<1x?x?x?xf32>
-              flow.dispatch.tensor.store %24, %2, offsets = [0, %arg0, %arg1, %arg2], sizes = [1, %16, %17, %18], strides = [1, 1, 1, 1] : tensor<1x?x?x?xf32> -> !flow.dispatch.tensor<writeonly:1x16x16x16xf32>
-            }
-          }
-        }
+        %13 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0, 0], sizes = [1, 33, 33, 3], strides = [1, 1, 1, 1]
+            : !flow.dispatch.tensor<readonly:1x33x33x3xf32> -> tensor<1x33x33x3xf32>
+        %15 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0, 0], sizes = [3, 3, 3, 16], strides = [1, 1, 1, 1]
+            : !flow.dispatch.tensor<readonly:3x3x3x16xf32> -> tensor<3x3x3x16xf32>
+        %22 = linalg.init_tensor [1, 16, 16, 16] : tensor<1x16x16x16xf32>
+        %23 = linalg.fill(%cst, %22) : f32, tensor<1x16x16x16xf32> -> tensor<1x16x16x16xf32>
+        %24 = linalg.conv_2d_nhwc_hwcf {__internal_linalg_transform__ = "workgroup", dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>}
+            ins(%13, %15 : tensor<1x33x33x3xf32>, tensor<3x3x3x16xf32>) outs(%23 : tensor<1x16x16x16xf32>) -> tensor<1x16x16x16xf32>
+        flow.dispatch.tensor.store %24, %2, offsets = [0, 0, 0, 0], sizes = [1, 16, 16, 16], strides = [1, 1, 1, 1]
+            : tensor<1x16x16x16xf32> -> !flow.dispatch.tensor<writeonly:1x16x16x16xf32>
         return
       }
     }
@@ -251,17 +151,10 @@
 }
 
 //  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering.config<tile_sizes = {{\[}}[0, 8, 8, 16], [0, 2, 2, 4], [0, 0, 0, 0, 1, 1, 4]{{\]}}, native_vector_size = []>
-//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"SPIRVVectorize", workload_per_wg = [16, 8, 8]>
-//  CHECK-DAG: #[[MAP_X:.+]] = affine_map<()[s0] -> (s0 ceildiv 16)
-//  CHECK-DAG: #[[MAP_YZ:.+]] = affine_map<()[s0] -> (s0 ceildiv 8)>
+//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"SPIRVVectorize", workload_per_wg = []>
 //      CHECK: hal.executable.entry_point public @conv_16x16x16
 // CHECK-SAME:   translation.info = #[[TRANSLATION]]
 // CHECK-SAME:   workgroup_size = [4 : index, 4 : index, 4 : index]
-// CHECK-NEXT: ^{{.+}}(%[[X:.+]]: index, %[[Y:.+]]: index, %[[Z:.+]]: index):
-// CHECK-NEXT:   %[[X_COUNT:.+]] = affine.apply #[[MAP_X]]()[%[[X]]]
-// CHECK-NEXT:   %[[Y_COUNT:.+]] = affine.apply #[[MAP_YZ]]()[%[[Y]]]
-// CHECK-NEXT:   %[[Z_COUNT:.+]] = affine.apply #[[MAP_YZ]]()[%[[Z]]]
-// CHECK-NEXT:   hal.return %[[X_COUNT]], %[[Y_COUNT]], %[[Z_COUNT]]
 
 //      CHECK: func @conv_16x16x16()
 //      CHECK:   linalg.conv_2d_nhwc_hwcf
@@ -296,45 +189,16 @@
         %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : !flow.dispatch.tensor<readonly:1x57x57x144xf32>
         %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) : !flow.dispatch.tensor<readonly:3x3x144xf32>
         %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) : !flow.dispatch.tensor<writeonly:1x28x28x144xf32>
-        %workgroup_size_x = hal.interface.workgroup.size[0] : index
-        %workgroup_size_y = hal.interface.workgroup.size[1] : index
-        %workgroup_size_z = hal.interface.workgroup.size[2] : index
-        %workgroup_id_x = hal.interface.workgroup.id[0] : index
-        %workgroup_count_x = hal.interface.workgroup.count[0] : index
-        %workgroup_id_y = hal.interface.workgroup.id[1] : index
-        %workgroup_count_y = hal.interface.workgroup.count[1] : index
-        %workgroup_id_z = hal.interface.workgroup.id[2] : index
-        %workgroup_count_z = hal.interface.workgroup.count[2] : index
-        %3 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_z, %workgroup_size_z]
-        %4 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_z, %workgroup_size_z]
-        scf.for %arg0 = %3 to %c28 step %4 {
-          %5 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_y, %workgroup_size_y]
-          %6 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_y, %workgroup_size_y]
-          scf.for %arg1 = %5 to %c28 step %6 {
-            %7 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_x, %workgroup_size_x]
-            %8 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_x, %workgroup_size_x]
-            scf.for %arg2 = %7 to %c144 step %8 {
-              %9 = affine.apply affine_map<(d0) -> (d0 * 2)>(%arg0)
-              %10 = affine.min affine_map<(d0)[s0] -> (s0 * 2 + 1, d0 * -2 + 57)>(%arg0)[%workgroup_size_z]
-              %11 = affine.apply affine_map<(d0) -> (d0 * 2)>(%arg1)
-              %12 = affine.min affine_map<(d0)[s0] -> (s0 * 2 + 1, d0 * -2 + 57)>(%arg1)[%workgroup_size_y]
-              %13 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 144)>(%arg2)[%workgroup_size_x]
-              %14 = flow.dispatch.tensor.load %0, offsets = [0, %9, %11, %arg2], sizes = [1, %10, %12, %13], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:1x57x57x144xf32> -> tensor<1x?x?x?xf32>
-              %15 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 144)>(%arg2)[%workgroup_size_x]
-              %16 = flow.dispatch.tensor.load %1, offsets = [0, 0, %arg2], sizes = [3, 3, %15], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:3x3x144xf32> -> tensor<3x3x?xf32>
-              %17 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 28)>(%arg0)[%workgroup_size_z]
-              %18 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 28)>(%arg1)[%workgroup_size_y]
-              %19 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 144)>(%arg2)[%workgroup_size_x]
-              %20 = affine.min affine_map<(d0)[s0] -> (-d0 + 28, s0)>(%arg0)[%workgroup_size_z]
-              %21 = affine.min affine_map<(d0)[s0] -> (-d0 + 28, s0)>(%arg1)[%workgroup_size_y]
-              %22 = affine.min affine_map<(d0)[s0] -> (-d0 + 144, s0)>(%arg2)[%workgroup_size_x]
-              %23 = linalg.init_tensor [1, %20, %21, %22] : tensor<1x?x?x?xf32>
-              %24 = linalg.fill(%cst, %23) : f32, tensor<1x?x?x?xf32> -> tensor<1x?x?x?xf32>
-              %25 = linalg.depthwise_conv_2d_nhwc_hwc {__internal_linalg_transform__ = "workgroup", dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%14, %16 : tensor<1x?x?x?xf32>, tensor<3x3x?xf32>) outs(%24 : tensor<1x?x?x?xf32>) -> tensor<1x?x?x?xf32>
-              flow.dispatch.tensor.store %25, %2, offsets = [0, %arg0, %arg1, %arg2], sizes = [1, %17, %18, %19], strides = [1, 1, 1, 1] : tensor<1x?x?x?xf32> -> !flow.dispatch.tensor<writeonly:1x28x28x144xf32>
-            }
-          }
-        }
+        %14 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0, 0], sizes = [1, 57, 57, 144], strides = [1, 1, 1, 1]
+            : !flow.dispatch.tensor<readonly:1x57x57x144xf32> -> tensor<1x57x57x144xf32>
+        %16 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0], sizes = [3, 3, 144], strides = [1, 1, 1]
+            : !flow.dispatch.tensor<readonly:3x3x144xf32> -> tensor<3x3x144xf32>
+        %23 = linalg.init_tensor [1, 28, 28, 144] : tensor<1x28x28x144xf32>
+        %24 = linalg.fill(%cst, %23) : f32, tensor<1x28x28x144xf32> -> tensor<1x28x28x144xf32>
+        %25 = linalg.depthwise_conv_2d_nhwc_hwc {__internal_linalg_transform__ = "workgroup", dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>}
+                  ins(%14, %16 : tensor<1x57x57x144xf32>, tensor<3x3x144xf32>) outs(%24 : tensor<1x28x28x144xf32>) -> tensor<1x28x28x144xf32>
+        flow.dispatch.tensor.store %25, %2, offsets = [0, 0, 0, 0], sizes = [1, 28, 28, 144], strides = [1, 1, 1, 1]
+            : tensor<1x28x28x144xf32> -> !flow.dispatch.tensor<writeonly:1x28x28x144xf32>
         return
       }
     }
@@ -342,18 +206,10 @@
 }
 
 //  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering.config<tile_sizes = {{\[}}[0, 4, 4, 16], [0, 1, 1, 4], [0, 0, 0, 0, 1, 1]{{\]}}, native_vector_size = []>
-//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"SPIRVVectorize", workload_per_wg = [16, 4, 4]>
-//  CHECK-DAG: #[[MAP_X:.+]] = affine_map<()[s0] -> (s0 ceildiv 16)
-//  CHECK-DAG: #[[MAP_YZ:.+]] = affine_map<()[s0] -> (s0 ceildiv 4)>
+//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"SPIRVVectorize", workload_per_wg = []>
 //      CHECK: hal.executable.entry_point public @dwconv_28x28x144
 // CHECK-SAME:   translation.info = #[[TRANSLATION]]
 // CHECK-SAME:   workgroup_size = [4 : index, 4 : index, 4 : index]
-// CHECK-NEXT: ^{{.+}}(%[[X:.+]]: index, %[[Y:.+]]: index, %[[Z:.+]]: index):
-// CHECK-NEXT:   %[[X_COUNT:.+]] = affine.apply #[[MAP_X]]()[%[[X]]]
-// CHECK-NEXT:   %[[Y_COUNT:.+]] = affine.apply #[[MAP_YZ]]()[%[[Y]]]
-// CHECK-NEXT:   %[[Z_COUNT:.+]] = affine.apply #[[MAP_YZ]]()[%[[Z]]]
-// CHECK-NEXT:   hal.return %[[X_COUNT]], %[[Y_COUNT]], %[[Z_COUNT]]
-
 //      CHECK: func @dwconv_28x28x144()
 //      CHECK:   linalg.depthwise_conv_2d_nhwc_hwc
 // CHECK-SAME:     lowering.config = #[[CONFIG]]
@@ -387,63 +243,26 @@
         %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : !flow.dispatch.tensor<readonly:1x9x9x8xf32>
         %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) : !flow.dispatch.tensor<readonly:3x3x8xf32>
         %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) : !flow.dispatch.tensor<writeonly:1x4x4x8xf32>
-        %workgroup_size_x = hal.interface.workgroup.size[0] : index
-        %workgroup_size_y = hal.interface.workgroup.size[1] : index
-        %workgroup_size_z = hal.interface.workgroup.size[2] : index
-        %workgroup_id_x = hal.interface.workgroup.id[0] : index
-        %workgroup_count_x = hal.interface.workgroup.count[0] : index
-        %workgroup_id_y = hal.interface.workgroup.id[1] : index
-        %workgroup_count_y = hal.interface.workgroup.count[1] : index
-        %workgroup_id_z = hal.interface.workgroup.id[2] : index
-        %workgroup_count_z = hal.interface.workgroup.count[2] : index
-        %3 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_z, %workgroup_size_z]
-        %4 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_z, %workgroup_size_z]
-        scf.for %arg0 = %3 to %c4 step %4 {
-          %5 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_y, %workgroup_size_y]
-          %6 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_y, %workgroup_size_y]
-          scf.for %arg1 = %5 to %c4 step %6 {
-            %7 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_x, %workgroup_size_x]
-            %8 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_x, %workgroup_size_x]
-            scf.for %arg2 = %7 to %c8 step %8 {
-              %9 = affine.apply affine_map<(d0) -> (d0 * 2)>(%arg0)
-              %10 = affine.min affine_map<(d0)[s0] -> (s0 * 2 + 1, d0 * -2 + 9)>(%arg0)[%workgroup_size_z]
-              %11 = affine.apply affine_map<(d0) -> (d0 * 2)>(%arg1)
-              %12 = affine.min affine_map<(d0)[s0] -> (s0 * 2 + 1, d0 * -2 + 9)>(%arg1)[%workgroup_size_y]
-              %13 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 8)>(%arg2)[%workgroup_size_x]
-              %14 = flow.dispatch.tensor.load %0, offsets = [0, %9, %11, %arg2], sizes = [1, %10, %12, %13], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:1x9x9x8xf32> -> tensor<1x?x?x?xf32>
-              %15 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 8)>(%arg2)[%workgroup_size_x]
-              %16 = flow.dispatch.tensor.load %1, offsets = [0, 0, %arg2], sizes = [3, 3, %15], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:3x3x8xf32> -> tensor<3x3x?xf32>
-              %17 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 4)>(%arg0)[%workgroup_size_z]
-              %18 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 4)>(%arg1)[%workgroup_size_y]
-              %19 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 8)>(%arg2)[%workgroup_size_x]
-              %20 = affine.min affine_map<(d0)[s0] -> (-d0 + 4, s0)>(%arg0)[%workgroup_size_z]
-              %21 = affine.min affine_map<(d0)[s0] -> (-d0 + 4, s0)>(%arg1)[%workgroup_size_y]
-              %22 = affine.min affine_map<(d0)[s0] -> (-d0 + 8, s0)>(%arg2)[%workgroup_size_x]
-              %23 = linalg.init_tensor [1, %20, %21, %22] : tensor<1x?x?x?xf32>
-              %24 = linalg.fill(%cst, %23) : f32, tensor<1x?x?x?xf32> -> tensor<1x?x?x?xf32>
-              %25 = linalg.depthwise_conv_2d_nhwc_hwc {__internal_linalg_transform__ = "workgroup", dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%14, %16 : tensor<1x?x?x?xf32>, tensor<3x3x?xf32>) outs(%24 : tensor<1x?x?x?xf32>) -> tensor<1x?x?x?xf32>
-              flow.dispatch.tensor.store %25, %2, offsets = [0, %arg0, %arg1, %arg2], sizes = [1, %17, %18, %19], strides = [1, 1, 1, 1] : tensor<1x?x?x?xf32> -> !flow.dispatch.tensor<writeonly:1x4x4x8xf32>
-            }
-          }
-        }
+        %14 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0, 0], sizes = [1, 9, 9, 8], strides = [1, 1, 1, 1]
+            : !flow.dispatch.tensor<readonly:1x9x9x8xf32> -> tensor<1x9x9x8xf32>
+        %16 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0], sizes = [3, 3, 8], strides = [1, 1, 1]
+            : !flow.dispatch.tensor<readonly:3x3x8xf32> -> tensor<3x3x8xf32>
+        %23 = linalg.init_tensor [1, 4, 4, 8] : tensor<1x4x4x8xf32>
+        %24 = linalg.fill(%cst, %23) : f32, tensor<1x4x4x8xf32> -> tensor<1x4x4x8xf32>
+        %25 = linalg.depthwise_conv_2d_nhwc_hwc {__internal_linalg_transform__ = "workgroup", dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>}
+            ins(%14, %16 : tensor<1x9x9x8xf32>, tensor<3x3x8xf32>) outs(%24 : tensor<1x4x4x8xf32>) -> tensor<1x4x4x8xf32>
+        flow.dispatch.tensor.store %25, %2, offsets = [0, 0, 0, 0], sizes = [1, 4, 4, 8], strides = [1, 1, 1, 1]
+            : tensor<1x4x4x8xf32> -> !flow.dispatch.tensor<writeonly:1x4x4x8xf32>
         return
       }
     }
   }
 }
 //  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering.config<tile_sizes = {{\[}}[0, 4, 4, 8], [0, 1, 1, 4], [0, 0, 0, 0, 1, 1]{{\]}}, native_vector_size = []>
-//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"SPIRVVectorize", workload_per_wg = [8, 4, 4]>
-//  CHECK-DAG: #[[MAP_X:.+]] = affine_map<()[s0] -> (s0 ceildiv 8)
-//  CHECK-DAG: #[[MAP_YZ:.+]] = affine_map<()[s0] -> (s0 ceildiv 4)>
+//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"SPIRVVectorize", workload_per_wg = []>
 //      CHECK: hal.executable.entry_point public @dwconv_4x4x8
 // CHECK-SAME:   translation.info = #[[TRANSLATION]]
 // CHECK-SAME:   workgroup_size = [2 : index, 4 : index, 4 : index]
-// CHECK-NEXT: ^{{.+}}(%[[X:.+]]: index, %[[Y:.+]]: index, %[[Z:.+]]: index):
-// CHECK-NEXT:   %[[X_COUNT:.+]] = affine.apply #[[MAP_X]]()[%[[X]]]
-// CHECK-NEXT:   %[[Y_COUNT:.+]] = affine.apply #[[MAP_YZ]]()[%[[Y]]]
-// CHECK-NEXT:   %[[Z_COUNT:.+]] = affine.apply #[[MAP_YZ]]()[%[[Z]]]
-// CHECK-NEXT:   hal.return %[[X_COUNT]], %[[Y_COUNT]], %[[Z_COUNT]]
-
 //      CHECK: func @dwconv_4x4x8()
 //      CHECK:   linalg.depthwise_conv_2d_nhwc_hwc
 // CHECK-SAME:     lowering.config = #[[CONFIG]]
diff --git a/iree/compiler/Codegen/SPIRV/test/config_adreno_matmul.mlir b/iree/compiler/Codegen/SPIRV/test/config_adreno_matmul.mlir
index d5cbb5f..10c9fef 100644
--- a/iree/compiler/Codegen/SPIRV/test/config_adreno_matmul.mlir
+++ b/iree/compiler/Codegen/SPIRV/test/config_adreno_matmul.mlir
@@ -27,32 +27,16 @@
         %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : !flow.dispatch.tensor<readonly:1024x512xf32>
         %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) : !flow.dispatch.tensor<readonly:512x2048xf32>
         %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) : !flow.dispatch.tensor<writeonly:1024x2048xf32>
-        %workgroup_size_x = hal.interface.workgroup.size[0] : index
-        %workgroup_size_y = hal.interface.workgroup.size[1] : index
-        %workgroup_id_x = hal.interface.workgroup.id[0] : index
-        %workgroup_count_x = hal.interface.workgroup.count[0] : index
-        %workgroup_id_y = hal.interface.workgroup.id[1] : index
-        %workgroup_count_y = hal.interface.workgroup.count[1] : index
-        %3 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_y, %workgroup_size_y]
-        %4 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_y, %workgroup_size_y]
-        scf.for %arg0 = %3 to %c1024 step %4 {
-          %5 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_x, %workgroup_size_x]
-          %6 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_x, %workgroup_size_x]
-          scf.for %arg1 = %5 to %c2048 step %6 {
-            %7 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 1024)>(%arg0)[%workgroup_size_y]
-            %8 = flow.dispatch.tensor.load %0, offsets = [%arg0, 0], sizes = [%7, 512], strides = [1, 1] : !flow.dispatch.tensor<readonly:1024x512xf32> -> tensor<?x512xf32>
-            %9 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 2048)>(%arg1)[%workgroup_size_x]
-            %10 = flow.dispatch.tensor.load %1, offsets = [0, %arg1], sizes = [512, %9], strides = [1, 1] : !flow.dispatch.tensor<readonly:512x2048xf32> -> tensor<512x?xf32>
-            %11 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 1024)>(%arg0)[%workgroup_size_y]
-            %12 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 2048)>(%arg1)[%workgroup_size_x]
-            %13 = affine.min affine_map<(d0)[s0] -> (-d0 + 1024, s0)>(%arg0)[%workgroup_size_y]
-            %14 = affine.min affine_map<(d0)[s0] -> (-d0 + 2048, s0)>(%arg1)[%workgroup_size_x]
-            %15 = linalg.init_tensor [%13, %14] : tensor<?x?xf32>
-            %16 = linalg.fill(%cst, %15) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
-            %17 = linalg.matmul {__internal_linalg_transform__ = "workgroup"} ins(%8, %10 : tensor<?x512xf32>, tensor<512x?xf32>) outs(%16 : tensor<?x?xf32>) -> tensor<?x?xf32>
-            flow.dispatch.tensor.store %17, %2, offsets = [%arg0, %arg1], sizes = [%11, %12], strides = [1, 1] : tensor<?x?xf32> -> !flow.dispatch.tensor<writeonly:1024x2048xf32>
-          }
-        }
+        %8 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [1024, 512], strides = [1, 1]
+            : !flow.dispatch.tensor<readonly:1024x512xf32> -> tensor<1024x512xf32>
+        %10 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [512, 2048], strides = [1, 1]
+            : !flow.dispatch.tensor<readonly:512x2048xf32> -> tensor<512x2048xf32>
+        %15 = linalg.init_tensor [1024, 2048] : tensor<1024x2048xf32>
+        %16 = linalg.fill(%cst, %15) : f32, tensor<1024x2048xf32> -> tensor<1024x2048xf32>
+        %17 = linalg.matmul {__internal_linalg_transform__ = "workgroup"}
+            ins(%8, %10 : tensor<1024x512xf32>, tensor<512x2048xf32>) outs(%16 : tensor<1024x2048xf32>) -> tensor<1024x2048xf32>
+        flow.dispatch.tensor.store %17, %2, offsets = [0, 0], sizes = [1024, 2048], strides = [1, 1]
+            : tensor<1024x2048xf32> -> !flow.dispatch.tensor<writeonly:1024x2048xf32>
         return
       }
     }
@@ -60,18 +44,10 @@
 }
 
 //  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering.config<tile_sizes = {{\[}}[32, 128], [16, 4], [0, 0, 4]{{\]}}, native_vector_size = []>
-//  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 128)
-//  CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 ceildiv 32)>
-//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"SPIRVVectorize", workload_per_wg = [128, 32]>
+//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"SPIRVVectorize", workload_per_wg = []>
 //      CHECK: hal.executable.entry_point public @matmul_1024x2048x512
 // CHECK-SAME:   translation.info = #[[TRANSLATION]]
 // CHECK-SAME:   workgroup_size = [32 : index, 2 : index, 1 : index]
-// CHECK-NEXT: ^{{.+}}(%[[X:.+]]: index, %[[Y:.+]]: index, %{{.+}}: index):
-// CHECK-NEXT:   %[[ONE:.+]] = arith.constant 1 : index
-// CHECK-NEXT:   %[[X_COUNT:.+]] = affine.apply #[[MAP0]]()[%[[X]]]
-// CHECK-NEXT:   %[[Y_COUNT:.+]] = affine.apply #[[MAP1]]()[%[[Y]]]
-// CHECK-NEXT:   hal.return %[[X_COUNT]], %[[Y_COUNT]], %[[ONE]]
-
 //      CHECK: func @matmul_1024x2048x512()
 //      CHECK:   linalg.matmul
 // CHECK-SAME:     lowering.config = #[[CONFIG]]
@@ -105,32 +81,16 @@
         %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : !flow.dispatch.tensor<readonly:3136x96xf32>
         %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) : !flow.dispatch.tensor<readonly:96x24xf32>
         %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) : !flow.dispatch.tensor<writeonly:3136x24xf32>
-        %workgroup_size_x = hal.interface.workgroup.size[0] : index
-        %workgroup_size_y = hal.interface.workgroup.size[1] : index
-        %workgroup_id_x = hal.interface.workgroup.id[0] : index
-        %workgroup_count_x = hal.interface.workgroup.count[0] : index
-        %workgroup_id_y = hal.interface.workgroup.id[1] : index
-        %workgroup_count_y = hal.interface.workgroup.count[1] : index
-        %3 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_y, %workgroup_size_y]
-        %4 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_y, %workgroup_size_y]
-        scf.for %arg0 = %3 to %c3136 step %4 {
-          %5 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_x, %workgroup_size_x]
-          %6 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_x, %workgroup_size_x]
-          scf.for %arg1 = %5 to %c24 step %6 {
-            %7 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 3136)>(%arg0)[%workgroup_size_y]
-            %8 = flow.dispatch.tensor.load %0, offsets = [%arg0, 0], sizes = [%7, 96], strides = [1, 1] : !flow.dispatch.tensor<readonly:3136x96xf32> -> tensor<?x96xf32>
-            %9 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 24)>(%arg1)[%workgroup_size_x]
-            %10 = flow.dispatch.tensor.load %1, offsets = [0, %arg1], sizes = [96, %9], strides = [1, 1] : !flow.dispatch.tensor<readonly:96x24xf32> -> tensor<96x?xf32>
-            %11 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 3136)>(%arg0)[%workgroup_size_y]
-            %12 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 24)>(%arg1)[%workgroup_size_x]
-            %13 = affine.min affine_map<(d0)[s0] -> (-d0 + 3136, s0)>(%arg0)[%workgroup_size_y]
-            %14 = affine.min affine_map<(d0)[s0] -> (-d0 + 24, s0)>(%arg1)[%workgroup_size_x]
-            %15 = linalg.init_tensor [%13, %14] : tensor<?x?xf32>
-            %16 = linalg.fill(%cst, %15) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
-            %17 = linalg.matmul {__internal_linalg_transform__ = "workgroup"} ins(%8, %10 : tensor<?x96xf32>, tensor<96x?xf32>) outs(%16 : tensor<?x?xf32>) -> tensor<?x?xf32>
-            flow.dispatch.tensor.store %17, %2, offsets = [%arg0, %arg1], sizes = [%11, %12], strides = [1, 1] : tensor<?x?xf32> -> !flow.dispatch.tensor<writeonly:3136x24xf32>
-          }
-        }
+        %8 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [3136, 96], strides = [1, 1]
+            : !flow.dispatch.tensor<readonly:3136x96xf32> -> tensor<3136x96xf32>
+        %10 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [96, 24], strides = [1, 1]
+            : !flow.dispatch.tensor<readonly:96x24xf32> -> tensor<96x24xf32>
+        %15 = linalg.init_tensor [3136, 24] : tensor<3136x24xf32>
+        %16 = linalg.fill(%cst, %15) : f32, tensor<3136x24xf32> -> tensor<3136x24xf32>
+        %17 = linalg.matmul {__internal_linalg_transform__ = "workgroup"}
+            ins(%8, %10 : tensor<3136x96xf32>, tensor<96x24xf32>) outs(%16 : tensor<3136x24xf32>) -> tensor<3136x24xf32>
+        flow.dispatch.tensor.store %17, %2, offsets = [0, 0], sizes = [3136, 24], strides = [1, 1]
+            : tensor<3136x24xf32> -> !flow.dispatch.tensor<writeonly:3136x24xf32>
         return
       }
     }
@@ -138,18 +98,10 @@
 }
 
 //  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering.config<tile_sizes = {{\[}}[448, 8], [14, 4], [0, 0, 4]{{\]}}, native_vector_size = []>
-//  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 8)>
-//  CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 ceildiv 448)>
-//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"SPIRVVectorize", workload_per_wg = [8, 448]>
+//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"SPIRVVectorize", workload_per_wg = []>
 //      CHECK: hal.executable.entry_point public @matmul_3136x24x96
 // CHECK-SAME:   translation.info = #[[TRANSLATION]]
 // CHECK-SAME:   workgroup_size = [2 : index, 32 : index, 1 : index]
-// CHECK-NEXT: ^{{.+}}(%[[X:.+]]: index, %[[Y:.+]]: index, %{{.+}}: index):
-// CHECK-NEXT:   %[[ONE:.+]] = arith.constant 1 : index
-// CHECK-NEXT:   %[[X_COUNT:.+]] = affine.apply #[[MAP0]]()[%[[X]]]
-// CHECK-NEXT:   %[[Y_COUNT:.+]] = affine.apply #[[MAP1]]()[%[[Y]]]
-// CHECK-NEXT:   hal.return %[[X_COUNT]], %[[Y_COUNT]], %[[ONE]]
-
 //      CHECK: func @matmul_3136x24x96()
 //      CHECK:   linalg.matmul
 // CHECK-SAME:     lowering.config = #[[CONFIG]]
@@ -183,32 +135,16 @@
         %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : !flow.dispatch.tensor<readonly:196x192xf32>
         %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) : !flow.dispatch.tensor<readonly:192x64xf32>
         %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) : !flow.dispatch.tensor<writeonly:196x64xf32>
-        %workgroup_size_x = hal.interface.workgroup.size[0] : index
-        %workgroup_size_y = hal.interface.workgroup.size[1] : index
-        %workgroup_id_x = hal.interface.workgroup.id[0] : index
-        %workgroup_count_x = hal.interface.workgroup.count[0] : index
-        %workgroup_id_y = hal.interface.workgroup.id[1] : index
-        %workgroup_count_y = hal.interface.workgroup.count[1] : index
-        %3 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_y, %workgroup_size_y]
-        %4 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_y, %workgroup_size_y]
-        scf.for %arg0 = %3 to %c196 step %4 {
-          %5 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_x, %workgroup_size_x]
-          %6 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_x, %workgroup_size_x]
-          scf.for %arg1 = %5 to %c64 step %6 {
-            %7 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 196)>(%arg0)[%workgroup_size_y]
-            %8 = flow.dispatch.tensor.load %0, offsets = [%arg0, 0], sizes = [%7, 192], strides = [1, 1] : !flow.dispatch.tensor<readonly:196x192xf32> -> tensor<?x192xf32>
-            %9 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 64)>(%arg1)[%workgroup_size_x]
-            %10 = flow.dispatch.tensor.load %1, offsets = [0, %arg1], sizes = [192, %9], strides = [1, 1] : !flow.dispatch.tensor<readonly:192x64xf32> -> tensor<192x?xf32>
-            %11 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 196)>(%arg0)[%workgroup_size_y]
-            %12 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 64)>(%arg1)[%workgroup_size_x]
-            %13 = affine.min affine_map<(d0)[s0] -> (-d0 + 196, s0)>(%arg0)[%workgroup_size_y]
-            %14 = affine.min affine_map<(d0)[s0] -> (-d0 + 64, s0)>(%arg1)[%workgroup_size_x]
-            %15 = linalg.init_tensor [%13, %14] : tensor<?x?xf32>
-            %16 = linalg.fill(%cst, %15) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
-            %17 = linalg.matmul {__internal_linalg_transform__ = "workgroup"} ins(%8, %10 : tensor<?x192xf32>, tensor<192x?xf32>) outs(%16 : tensor<?x?xf32>) -> tensor<?x?xf32>
-            flow.dispatch.tensor.store %17, %2, offsets = [%arg0, %arg1], sizes = [%11, %12], strides = [1, 1] : tensor<?x?xf32> -> !flow.dispatch.tensor<writeonly:196x64xf32>
-          }
-        }
+        %8 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [196, 192], strides = [1, 1]
+            : !flow.dispatch.tensor<readonly:196x192xf32> -> tensor<196x192xf32>
+        %10 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [192, 64], strides = [1, 1]
+            : !flow.dispatch.tensor<readonly:192x64xf32> -> tensor<192x64xf32>
+        %15 = linalg.init_tensor [196, 64] : tensor<196x64xf32>
+        %16 = linalg.fill(%cst, %15) : f32, tensor<196x64xf32> -> tensor<196x64xf32>
+        %17 = linalg.matmul {__internal_linalg_transform__ = "workgroup"}
+            ins(%8, %10 : tensor<196x192xf32>, tensor<192x64xf32>) outs(%16 : tensor<196x64xf32>) -> tensor<196x64xf32>
+        flow.dispatch.tensor.store %17, %2, offsets = [0, 0], sizes = [196, 64], strides = [1, 1]
+            : tensor<196x64xf32> -> !flow.dispatch.tensor<writeonly:196x64xf32>
         return
       }
     }
@@ -216,18 +152,10 @@
 }
 
 //  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering.config<tile_sizes = {{\[}}[28, 64], [7, 4], [0, 0, 8]{{\]}}, native_vector_size = []>
-//  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 64)>
-//  CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 ceildiv 28)>
-//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"SPIRVVectorize", workload_per_wg = [64, 28]>
+//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"SPIRVVectorize", workload_per_wg = []>
 //      CHECK: hal.executable.entry_point public @matmul_196x64x192
 // CHECK-SAME:   translation.info = #[[TRANSLATION]]
 // CHECK-SAME:   workgroup_size = [16 : index, 4 : index, 1 : index]
-// CHECK-NEXT: ^{{.+}}(%[[X:.+]]: index, %[[Y:.+]]: index, %{{.+}}: index):
-// CHECK-NEXT:   %[[ONE:.+]] = arith.constant 1 : index
-// CHECK-NEXT:   %[[X_COUNT:.+]] = affine.apply #[[MAP0]]()[%[[X]]]
-// CHECK-NEXT:   %[[Y_COUNT:.+]] = affine.apply #[[MAP1]]()[%[[Y]]]
-// CHECK-NEXT:   hal.return %[[X_COUNT]], %[[Y_COUNT]], %[[ONE]]
-
 //      CHECK: func @matmul_196x64x192()
 //      CHECK:   linalg.matmul
 // CHECK-SAME:      lowering.config = #[[CONFIG]]
@@ -261,27 +189,8 @@
         %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : memref<12544x16xf32>
         %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) : memref<16x96xf32>
         %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) : memref<12544x96xf32>
-        %workgroup_size_x = hal.interface.workgroup.size[0] : index
-        %workgroup_size_y = hal.interface.workgroup.size[1] : index
-        %workgroup_id_x = hal.interface.workgroup.id[0] : index
-        %workgroup_count_x = hal.interface.workgroup.count[0] : index
-        %workgroup_id_y = hal.interface.workgroup.id[1] : index
-        %workgroup_count_y = hal.interface.workgroup.count[1] : index
-        %3 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_y, %workgroup_size_y]
-        %4 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_y, %workgroup_size_y]
-        scf.for %arg0 = %3 to %c12544 step %4 {
-          %5 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_x, %workgroup_size_x]
-          %6 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_x, %workgroup_size_x]
-          scf.for %arg1 = %5 to %c96 step %6 {
-            %7 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 12544)>(%arg0)[%workgroup_size_y]
-            %8 = memref.subview %0[%arg0, 0] [%7, 16] [1, 1] : memref<12544x16xf32> to memref<?x16xf32, affine_map<(d0, d1)[s0] -> (d0 * 16 + s0 + d1)>>
-            %9 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 96)>(%arg1)[%workgroup_size_x]
-            %10 = memref.subview %1[0, %arg1] [16, %9] [1, 1] : memref<16x96xf32> to memref<16x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 96 + s0 + d1)>>
-            %11 = memref.subview %2[%arg0, %arg1] [%7, %9] [1, 1] : memref<12544x96xf32> to memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 96 + s0 + d1)>>
-            linalg.fill(%cst, %11) : f32, memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 96 + s0 + d1)>>
-            linalg.matmul {__internal_linalg_transform__ = "workgroup"} ins(%8, %10 : memref<?x16xf32, affine_map<(d0, d1)[s0] -> (d0 * 16 + s0 + d1)>>, memref<16x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 96 + s0 + d1)>>) outs(%11 : memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 96 + s0 + d1)>>)
-          }
-        }
+        linalg.fill(%cst, %2) : f32, memref<12544x96xf32>
+        linalg.matmul {__internal_linalg_transform__ = "workgroup"} ins(%0, %1 : memref<12544x16xf32>, memref<16x96xf32>) outs(%2 : memref<12544x96xf32>)
         return
       }
     }
@@ -289,18 +198,10 @@
 }
 
 //  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering.config<tile_sizes = {{\[}}[128, 32], [16, 4], [0, 0, 4]{{\]}}, native_vector_size = []>
-//  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 32)>
-//  CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 ceildiv 128)>
-//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"SPIRVVectorize", workload_per_wg = [32, 128]>
+//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"SPIRVVectorize", workload_per_wg = []>
 //      CHECK: hal.executable.entry_point public @matmul_12544x96x16
 // CHECK-SAME:   translation.info = #[[TRANSLATION]]
 // CHECK-SAME:   workgroup_size = [8 : index, 8 : index, 1 : index]
-// CHECK-NEXT: ^{{.+}}(%[[X:.+]]: index, %[[Y:.+]]: index, %{{.+}}: index):
-// CHECK-NEXT:   %[[ONE:.+]] = arith.constant 1 : index
-// CHECK-NEXT:   %[[X_COUNT:.+]] = affine.apply #[[MAP0]]()[%[[X]]]
-// CHECK-NEXT:   %[[Y_COUNT:.+]] = affine.apply #[[MAP1]]()[%[[Y]]]
-// CHECK-NEXT:   hal.return %[[X_COUNT]], %[[Y_COUNT]], %[[ONE]]
-
 //      CHECK: func @matmul_12544x96x16()
 //      CHECK:   linalg.matmul
 // CHECK-SAME:     lowering.config = #[[CONFIG]]
@@ -334,32 +235,14 @@
         %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : !flow.dispatch.tensor<readonly:49x576xf32>
         %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) : !flow.dispatch.tensor<readonly:576x160xf32>
         %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) : !flow.dispatch.tensor<writeonly:49x160xf32>
-        %workgroup_size_x = hal.interface.workgroup.size[0] : index
-        %workgroup_size_y = hal.interface.workgroup.size[1] : index
-        %workgroup_id_x = hal.interface.workgroup.id[0] : index
-        %workgroup_count_x = hal.interface.workgroup.count[0] : index
-        %workgroup_id_y = hal.interface.workgroup.id[1] : index
-        %workgroup_count_y = hal.interface.workgroup.count[1] : index
-        %3 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_y, %workgroup_size_y]
-        %4 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_y, %workgroup_size_y]
-        scf.for %arg0 = %3 to %c49 step %4 {
-          %5 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_x, %workgroup_size_x]
-          %6 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_x, %workgroup_size_x]
-          scf.for %arg1 = %5 to %c160 step %6 {
-            %7 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 49)>(%arg0)[%workgroup_size_y]
-            %8 = flow.dispatch.tensor.load %0, offsets = [%arg0, 0], sizes = [%7, 576], strides = [1, 1] : !flow.dispatch.tensor<readonly:49x576xf32> -> tensor<?x576xf32>
-            %9 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 160)>(%arg1)[%workgroup_size_x]
-            %10 = flow.dispatch.tensor.load %1, offsets = [0, %arg1], sizes = [576, %9], strides = [1, 1] : !flow.dispatch.tensor<readonly:576x160xf32> -> tensor<576x?xf32>
-            %11 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 49)>(%arg0)[%workgroup_size_y]
-            %12 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 160)>(%arg1)[%workgroup_size_x]
-            %13 = affine.min affine_map<(d0)[s0] -> (-d0 + 49, s0)>(%arg0)[%workgroup_size_y]
-            %14 = affine.min affine_map<(d0)[s0] -> (-d0 + 160, s0)>(%arg1)[%workgroup_size_x]
-            %15 = linalg.init_tensor [%13, %14] : tensor<?x?xf32>
-            %16 = linalg.fill(%cst, %15) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
-            %17 = linalg.matmul {__internal_linalg_transform__ = "workgroup"} ins(%8, %10 : tensor<?x576xf32>, tensor<576x?xf32>) outs(%16 : tensor<?x?xf32>) -> tensor<?x?xf32>
-            flow.dispatch.tensor.store %17, %2, offsets = [%arg0, %arg1], sizes = [%11, %12], strides = [1, 1] : tensor<?x?xf32> -> !flow.dispatch.tensor<writeonly:49x160xf32>
-          }
-        }
+        %8 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [49, 576], strides = [1, 1] : !flow.dispatch.tensor<readonly:49x576xf32> -> tensor<49x576xf32>
+        %10 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [576, 160], strides = [1, 1] : !flow.dispatch.tensor<readonly:576x160xf32> -> tensor<576x160xf32>
+        %15 = linalg.init_tensor [49, 160] : tensor<49x160xf32>
+        %16 = linalg.fill(%cst, %15) : f32, tensor<49x160xf32> -> tensor<49x160xf32>
+        %17 = linalg.matmul {__internal_linalg_transform__ = "workgroup"}
+            ins(%8, %10 : tensor<49x576xf32>, tensor<576x160xf32>) outs(%16 : tensor<49x160xf32>) -> tensor<49x160xf32>
+        flow.dispatch.tensor.store %17, %2, offsets = [0, 0], sizes = [49, 160], strides = [1, 1]
+            : tensor<49x160xf32> -> !flow.dispatch.tensor<writeonly:49x160xf32>
         return
       }
     }
@@ -367,18 +250,10 @@
 }
 
 //  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering.config<tile_sizes = {{\[}}[7, 32], [7, 4], [0, 0, 8]{{\]}}, native_vector_size = []>
-//  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 32)>
-//  CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 ceildiv 7)>
-//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"SPIRVVectorize", workload_per_wg = [32, 7]>
+//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"SPIRVVectorize", workload_per_wg = []>
 //      CHECK: hal.executable.entry_point public @matmul_49x160x576
 // CHECK-SAME:   translation.info = #[[TRANSLATION]]
 // CHECK-SAME:   workgroup_size = [8 : index, 1 : index, 1 : index]
-// CHECK-NEXT: ^{{.+}}(%[[X:.+]]: index, %[[Y:.+]]: index, %{{.+}}: index):
-// CHECK-NEXT:   %[[ONE:.+]] = arith.constant 1 : index
-// CHECK-NEXT:   %[[X_COUNT:.+]] = affine.apply #[[MAP0]]()[%[[X]]]
-// CHECK-NEXT:   %[[Y_COUNT:.+]] = affine.apply #[[MAP1]]()[%[[Y]]]
-// CHECK-NEXT:   hal.return %[[X_COUNT]], %[[Y_COUNT]], %[[ONE]]
-
 //      CHECK: func @matmul_49x160x576()
 //      CHECK:   linalg.matmul
 // CHECK-SAME:     lowering.config = #[[CONFIG]]
@@ -412,43 +287,16 @@
         %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : !flow.dispatch.tensor<readonly:4x384x32xf32>
         %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) : !flow.dispatch.tensor<readonly:4x32x384xf32>
         %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) : !flow.dispatch.tensor<writeonly:4x384x384xf32>
-        %workgroup_size_x = hal.interface.workgroup.size[0] : index
-        %workgroup_size_y = hal.interface.workgroup.size[1] : index
-        %workgroup_size_z = hal.interface.workgroup.size[2] : index
-        %workgroup_id_x = hal.interface.workgroup.id[0] : index
-        %workgroup_count_x = hal.interface.workgroup.count[0] : index
-        %workgroup_id_y = hal.interface.workgroup.id[1] : index
-        %workgroup_count_y = hal.interface.workgroup.count[1] : index
-        %workgroup_id_z = hal.interface.workgroup.id[2] : index
-        %workgroup_count_z = hal.interface.workgroup.count[2] : index
-        %3 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_z, %workgroup_size_z]
-        %4 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_z, %workgroup_size_z]
-        scf.for %arg0 = %3 to %c4 step %4 {
-          %5 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_y, %workgroup_size_y]
-          %6 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_y, %workgroup_size_y]
-          scf.for %arg1 = %5 to %c384 step %6 {
-            %7 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_x, %workgroup_size_x]
-            %8 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_x, %workgroup_size_x]
-            scf.for %arg2 = %7 to %c384 step %8 {
-              %9 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 4)>(%arg0)[%workgroup_size_z]
-              %10 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 384)>(%arg1)[%workgroup_size_y]
-              %11 = flow.dispatch.tensor.load %0, offsets = [%arg0, %arg1, 0], sizes = [%9, %10, 32], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:4x384x32xf32> -> tensor<?x?x32xf32>
-              %12 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 4)>(%arg0)[%workgroup_size_z]
-              %13 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 384)>(%arg2)[%workgroup_size_x]
-              %14 = flow.dispatch.tensor.load %1, offsets = [%arg0, 0, %arg2], sizes = [%12, 32, %13], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:4x32x384xf32> -> tensor<?x32x?xf32>
-              %15 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 4)>(%arg0)[%workgroup_size_z]
-              %16 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 384)>(%arg1)[%workgroup_size_y]
-              %17 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 384)>(%arg2)[%workgroup_size_x]
-              %18 = affine.min affine_map<(d0)[s0] -> (-d0 + 4, s0)>(%arg0)[%workgroup_size_z]
-              %19 = affine.min affine_map<(d0)[s0] -> (-d0 + 384, s0)>(%arg1)[%workgroup_size_y]
-              %20 = affine.min affine_map<(d0)[s0] -> (-d0 + 384, s0)>(%arg2)[%workgroup_size_x]
-              %21 = linalg.init_tensor [%18, %19, %20] : tensor<?x?x?xf32>
-              %22 = linalg.fill(%cst, %21) : f32, tensor<?x?x?xf32> -> tensor<?x?x?xf32>
-              %23 = linalg.batch_matmul {__internal_linalg_transform__ = "workgroup"} ins(%11, %14 : tensor<?x?x32xf32>, tensor<?x32x?xf32>) outs(%22 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
-              flow.dispatch.tensor.store %23, %2, offsets = [%arg0, %arg1, %arg2], sizes = [%15, %16, %17], strides = [1, 1, 1] : tensor<?x?x?xf32> -> !flow.dispatch.tensor<writeonly:4x384x384xf32>
-            }
-          }
-        }
+        %11 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0], sizes = [4, 384, 32], strides = [1, 1, 1]
+            : !flow.dispatch.tensor<readonly:4x384x32xf32> -> tensor<4x384x32xf32>
+        %14 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0], sizes = [4, 32, 384], strides = [1, 1, 1]
+            : !flow.dispatch.tensor<readonly:4x32x384xf32> -> tensor<4x32x384xf32>
+        %21 = linalg.init_tensor [4, 384, 384] : tensor<4x384x384xf32>
+        %22 = linalg.fill(%cst, %21) : f32, tensor<4x384x384xf32> -> tensor<4x384x384xf32>
+        %23 = linalg.batch_matmul {__internal_linalg_transform__ = "workgroup"}
+            ins(%11, %14 : tensor<4x384x32xf32>, tensor<4x32x384xf32>) outs(%22 : tensor<4x384x384xf32>) -> tensor<4x384x384xf32>
+        flow.dispatch.tensor.store %23, %2, offsets = [0, 0, 0], sizes = [4, 384, 384], strides = [1, 1, 1]
+            : tensor<4x384x384xf32> -> !flow.dispatch.tensor<writeonly:4x384x384xf32>
         return
       }
     }
@@ -456,17 +304,10 @@
 }
 
 //  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering.config<tile_sizes = {{\[}}[1, 32, 128], [1, 16, 4], [0, 0, 0, 4]{{\]}}, native_vector_size = []>
-//  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 128)>
-//  CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 ceildiv 32)>
-//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"SPIRVVectorize", workload_per_wg = [128, 32, 1]>
+//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"SPIRVVectorize", workload_per_wg = []>
 //      CHECK: hal.executable.entry_point public @batch_matmul_4x384x384
 // CHECK-SAME:   translation.info = #[[TRANSLATION]]
 // CHECK-SAME:   workgroup_size = [32 : index, 2 : index, 1 : index]
-// CHECK-NEXT: ^{{.+}}(%[[X:.+]]: index, %[[Y:.+]]: index, %[[Z:.+]]: index):
-// CHECK-NEXT:   %[[X_COUNT:.+]] = affine.apply #[[MAP0]]()[%[[X]]]
-// CHECK-NEXT:   %[[Y_COUNT:.+]] = affine.apply #[[MAP1]]()[%[[Y]]]
-// CHECK-NEXT:   hal.return %[[X_COUNT]], %[[Y_COUNT]], %[[Z]]
-
 //      CHECK: func @batch_matmul_4x384x384()
 //      CHECK:   linalg.batch_matmul
 // CHECK-SAME:     lowering.config = #[[CONFIG]]
@@ -500,43 +341,16 @@
         %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : !flow.dispatch.tensor<readonly:4x8x32xf32>
         %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) : !flow.dispatch.tensor<readonly:4x32x8xf32>
         %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) : !flow.dispatch.tensor<writeonly:4x8x8xf32>
-        %workgroup_size_x = hal.interface.workgroup.size[0] : index
-        %workgroup_size_y = hal.interface.workgroup.size[1] : index
-        %workgroup_size_z = hal.interface.workgroup.size[2] : index
-        %workgroup_id_x = hal.interface.workgroup.id[0] : index
-        %workgroup_count_x = hal.interface.workgroup.count[0] : index
-        %workgroup_id_y = hal.interface.workgroup.id[1] : index
-        %workgroup_count_y = hal.interface.workgroup.count[1] : index
-        %workgroup_id_z = hal.interface.workgroup.id[2] : index
-        %workgroup_count_z = hal.interface.workgroup.count[2] : index
-        %3 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_z, %workgroup_size_z]
-        %4 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_z, %workgroup_size_z]
-        scf.for %arg0 = %3 to %c4 step %4 {
-          %5 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_y, %workgroup_size_y]
-          %6 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_y, %workgroup_size_y]
-          scf.for %arg1 = %5 to %c8 step %6 {
-            %7 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_x, %workgroup_size_x]
-            %8 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_x, %workgroup_size_x]
-            scf.for %arg2 = %7 to %c8 step %8 {
-              %9 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 4)>(%arg0)[%workgroup_size_z]
-              %10 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 8)>(%arg1)[%workgroup_size_y]
-              %11 = flow.dispatch.tensor.load %0, offsets = [%arg0, %arg1, 0], sizes = [%9, %10, 32], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:4x8x32xf32> -> tensor<?x?x32xf32>
-              %12 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 4)>(%arg0)[%workgroup_size_z]
-              %13 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 8)>(%arg2)[%workgroup_size_x]
-              %14 = flow.dispatch.tensor.load %1, offsets = [%arg0, 0, %arg2], sizes = [%12, 32, %13], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:4x32x8xf32> -> tensor<?x32x?xf32>
-              %15 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 4)>(%arg0)[%workgroup_size_z]
-              %16 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 8)>(%arg1)[%workgroup_size_y]
-              %17 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 8)>(%arg2)[%workgroup_size_x]
-              %18 = affine.min affine_map<(d0)[s0] -> (-d0 + 4, s0)>(%arg0)[%workgroup_size_z]
-              %19 = affine.min affine_map<(d0)[s0] -> (-d0 + 8, s0)>(%arg1)[%workgroup_size_y]
-              %20 = affine.min affine_map<(d0)[s0] -> (-d0 + 8, s0)>(%arg2)[%workgroup_size_x]
-              %21 = linalg.init_tensor [%18, %19, %20] : tensor<?x?x?xf32>
-              %22 = linalg.fill(%cst, %21) : f32, tensor<?x?x?xf32> -> tensor<?x?x?xf32>
-              %23 = linalg.batch_matmul {__internal_linalg_transform__ = "workgroup"} ins(%11, %14 : tensor<?x?x32xf32>, tensor<?x32x?xf32>) outs(%22 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
-              flow.dispatch.tensor.store %23, %2, offsets = [%arg0, %arg1, %arg2], sizes = [%15, %16, %17], strides = [1, 1, 1] : tensor<?x?x?xf32> -> !flow.dispatch.tensor<writeonly:4x8x8xf32>
-            }
-          }
-        }
+        %11 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0], sizes = [4, 8, 32], strides = [1, 1, 1]
+            : !flow.dispatch.tensor<readonly:4x8x32xf32> -> tensor<4x8x32xf32>
+        %14 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0], sizes = [4, 32, 8], strides = [1, 1, 1]
+            : !flow.dispatch.tensor<readonly:4x32x8xf32> -> tensor<4x32x8xf32>
+        %21 = linalg.init_tensor [4, 8, 8] : tensor<4x8x8xf32>
+        %22 = linalg.fill(%cst, %21) : f32, tensor<4x8x8xf32> -> tensor<4x8x8xf32>
+        %23 = linalg.batch_matmul {__internal_linalg_transform__ = "workgroup"}
+            ins(%11, %14 : tensor<4x8x32xf32>, tensor<4x32x8xf32>) outs(%22 : tensor<4x8x8xf32>) -> tensor<4x8x8xf32>
+        flow.dispatch.tensor.store %23, %2, offsets = [0, 0, 0], sizes = [4, 8, 8], strides = [1, 1, 1]
+            : tensor<4x8x8xf32> -> !flow.dispatch.tensor<writeonly:4x8x8xf32>
         return
       }
     }
@@ -544,16 +358,10 @@
 }
 
 //  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering.config<tile_sizes = {{\[}}[1, 8, 8], [1, 1, 4], [0, 0, 0, 16]{{\]}}, native_vector_size = []>
-//  CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0] -> (s0 ceildiv 8)>
-//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"SPIRVVectorize", workload_per_wg = [8, 8, 1]>
+//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"SPIRVVectorize", workload_per_wg = []>
 //      CHECK: hal.executable.entry_point public @batch_matmul_4x8x8
 // CHECK-SAME:   translation.info = #[[TRANSLATION]]
 // CHECK-SAME:   workgroup_size = [2 : index, 8 : index, 1 : index]
-// CHECK-NEXT: ^{{.+}}(%[[X:.+]]: index, %[[Y:.+]]: index, %[[Z:.+]]: index):
-// CHECK-NEXT:   %[[X_COUNT:.+]] = affine.apply #[[MAP]]()[%[[X]]]
-// CHECK-NEXT:   %[[Y_COUNT:.+]] = affine.apply #[[MAP]]()[%[[Y]]]
-// CHECK-NEXT:   hal.return %[[X_COUNT]], %[[Y_COUNT]], %[[Z]]
-
 //      CHECK: func @batch_matmul_4x8x8()
 //      CHECK:   linalg.batch_matmul
 // CHECK-SAME:     lowering.config = #[[CONFIG]]
diff --git a/iree/compiler/Codegen/SPIRV/test/config_default_conv.mlir b/iree/compiler/Codegen/SPIRV/test/config_default_conv.mlir
index 43a72cc..2f0e8be 100644
--- a/iree/compiler/Codegen/SPIRV/test/config_default_conv.mlir
+++ b/iree/compiler/Codegen/SPIRV/test/config_default_conv.mlir
@@ -37,50 +37,25 @@
         %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) : !flow.dispatch.tensor<readonly:1x225x225x3xf32>
         %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) : !flow.dispatch.tensor<readonly:3x3x3x32xf32>
         %3 = hal.interface.binding.subspan set(0) binding(3) type(storage_buffer) : !flow.dispatch.tensor<writeonly:1x112x112x32xf32>
-        %workgroup_size_x = hal.interface.workgroup.size[0] : index
-        %workgroup_size_y = hal.interface.workgroup.size[1] : index
-        %workgroup_size_z = hal.interface.workgroup.size[2] : index
-        %workgroup_id_x = hal.interface.workgroup.id[0] : index
-        %workgroup_count_x = hal.interface.workgroup.count[0] : index
-        %workgroup_id_y = hal.interface.workgroup.id[1] : index
-        %workgroup_count_y = hal.interface.workgroup.count[1] : index
-        %workgroup_id_z = hal.interface.workgroup.id[2] : index
-        %workgroup_count_z = hal.interface.workgroup.count[2] : index
-        %4 = affine.apply #map0()[%workgroup_id_z, %workgroup_size_z]
-        %5 = affine.apply #map0()[%workgroup_count_z, %workgroup_size_z]
-        scf.for %arg0 = %4 to %c112 step %5 {
-          %6 = affine.apply #map0()[%workgroup_id_y, %workgroup_size_y]
-          %7 = affine.apply #map0()[%workgroup_count_y, %workgroup_size_y]
-          scf.for %arg1 = %6 to %c112 step %7 {
-            %8 = affine.apply #map0()[%workgroup_id_x, %workgroup_size_x]
-            %9 = affine.apply #map0()[%workgroup_count_x, %workgroup_size_x]
-            scf.for %arg2 = %8 to %c32 step %9 {
-              %10 = affine.min #map1(%arg0)[%workgroup_size_z]
-              %11 = affine.min #map1(%arg1)[%workgroup_size_y]
-              %12 = affine.min #map2(%arg2)[%workgroup_size_x]
-              %13 = flow.dispatch.tensor.load %0, offsets = [0, %arg0, %arg1, %arg2], sizes = [1, %10, %11, %12], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:1x112x112x32xf32> -> tensor<1x?x?x?xf32>
-              %14 = linalg.init_tensor [1, %10, %11, %12] : tensor<1x?x?x?xf32>
-              %15 = affine.apply #map3(%arg0)
-              %16 = affine.min #map4(%10, %arg0)
-              %17 = affine.apply #map3(%arg1)
-              %18 = affine.min #map4(%11, %arg1)
-              %19 = flow.dispatch.tensor.load %1, offsets = [0, %15, %17, 0], sizes = [1, %16, %18, 3], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:1x225x225x3xf32> -> tensor<1x?x?x3xf32>
-              %20 = affine.min #map5(%arg2)[%workgroup_size_x]
-              %21 = flow.dispatch.tensor.load %2, offsets = [0, 0, 0, %arg2], sizes = [3, 3, 3, %20], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:3x3x3x32xf32> -> tensor<3x3x3x?xf32>
-              %22 = affine.min #map6(%arg0)[%workgroup_size_z]
-              %23 = affine.min #map6(%arg1)[%workgroup_size_y]
-              %24 = linalg.init_tensor [1, %22, %23, %20] : tensor<1x?x?x?xf32>
-              %25 = linalg.fill(%cst, %24) : f32, tensor<1x?x?x?xf32> -> tensor<1x?x?x?xf32>
-              %26 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%19, %21 : tensor<1x?x?x3xf32>, tensor<3x3x3x?xf32>) outs(%25 : tensor<1x?x?x?xf32>) -> tensor<1x?x?x?xf32>
-              %27 = linalg.generic {indexing_maps = [#map7, #map7, #map7], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%26, %13 : tensor<1x?x?x?xf32>, tensor<1x?x?x?xf32>) outs(%14 : tensor<1x?x?x?xf32>) {
-              ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):  // no predecessors
-                %28 = arith.subf %arg3, %arg4 : f32
-                linalg.yield %28 : f32
-              } -> tensor<1x?x?x?xf32>
-              flow.dispatch.tensor.store %27, %3, offsets = [0, %arg0, %arg1, %arg2], sizes = [1, %10, %11, %12], strides = [1, 1, 1, 1] : tensor<1x?x?x?xf32> -> !flow.dispatch.tensor<writeonly:1x112x112x32xf32>
-            }
-          }
-        }
+        %13 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0, 0], sizes = [1, 112, 112, 32], strides = [1, 1, 1, 1]
+            : !flow.dispatch.tensor<readonly:1x112x112x32xf32> -> tensor<1x112x112x32xf32>
+        %14 = linalg.init_tensor [1, 112, 112, 32] : tensor<1x112x112x32xf32>
+        %19 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0, 0], sizes = [1, 225, 225, 3], strides = [1, 1, 1, 1]
+            : !flow.dispatch.tensor<readonly:1x225x225x3xf32> -> tensor<1x225x225x3xf32>
+        %21 = flow.dispatch.tensor.load %2, offsets = [0, 0, 0, 0], sizes = [3, 3, 3, 32], strides = [1, 1, 1, 1]
+            : !flow.dispatch.tensor<readonly:3x3x3x32xf32> -> tensor<3x3x3x32xf32>
+        %24 = linalg.init_tensor [1, 112, 112, 32] : tensor<1x112x112x32xf32>
+        %25 = linalg.fill(%cst, %24) : f32, tensor<1x112x112x32xf32> -> tensor<1x112x112x32xf32>
+        %26 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>}
+            ins(%19, %21 : tensor<1x225x225x3xf32>, tensor<3x3x3x32xf32>) outs(%25 : tensor<1x112x112x32xf32>) -> tensor<1x112x112x32xf32>
+        %27 = linalg.generic {indexing_maps = [#map7, #map7, #map7], iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+            ins(%26, %13 : tensor<1x112x112x32xf32>, tensor<1x112x112x32xf32>) outs(%14 : tensor<1x112x112x32xf32>) {
+            ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):  // no predecessors
+              %28 = arith.subf %arg3, %arg4 : f32
+              linalg.yield %28 : f32
+            } -> tensor<1x112x112x32xf32>
+        flow.dispatch.tensor.store %27, %3, offsets = [0, 0, 0, 0], sizes = [1, 112, 112, 32], strides = [1, 1, 1, 1]
+            : tensor<1x112x112x32xf32> -> !flow.dispatch.tensor<writeonly:1x112x112x32xf32>
         return
       }
     }
@@ -88,18 +63,10 @@
 }
 
 //  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering.config<tile_sizes = {{\[}}[0, 4, 4, 32], [0, 2, 2, 4], [0, 0, 0, 0, 1, 1, 4]{{\]}}, native_vector_size = []>
-//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"SPIRVVectorize", workload_per_wg = [32, 4, 4]>
-//  CHECK-DAG: #[[MAP_X:.+]] = affine_map<()[s0] -> (s0 ceildiv 32)
-//  CHECK-DAG: #[[MAP_YZ:.+]] = affine_map<()[s0] -> (s0 ceildiv 4)>
+//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"SPIRVVectorize", workload_per_wg = []>
 //      CHECK: hal.executable.entry_point public @conv_pointwise_112x112x32
 // CHECK-SAME:   translation.info = #[[TRANSLATION]]
 // CHECK-SAME:   workgroup_size = [8 : index, 2 : index, 2 : index]
-// CHECK-NEXT: ^{{.+}}(%[[X:.+]]: index, %[[Y:.+]]: index, %[[Z:.+]]: index):
-// CHECK-NEXT:   %[[X_COUNT:.+]] = affine.apply #[[MAP_X]]()[%[[X]]]
-// CHECK-NEXT:   %[[Y_COUNT:.+]] = affine.apply #[[MAP_YZ]]()[%[[Y]]]
-// CHECK-NEXT:   %[[Z_COUNT:.+]] = affine.apply #[[MAP_YZ]]()[%[[Z]]]
-// CHECK-NEXT:   hal.return %[[X_COUNT]], %[[Y_COUNT]], %[[Z_COUNT]]
-
 //      CHECK: func @conv_pointwise_112x112x32()
 //      CHECK:   linalg.conv_2d_nhwc_hwcf
 // CHECK-SAME:     lowering.config = #[[CONFIG]]
diff --git a/iree/compiler/Codegen/SPIRV/test/config_default_linalg_ext_ops.mlir b/iree/compiler/Codegen/SPIRV/test/config_default_linalg_ext_ops.mlir
index 777e9bf..4500dae 100644
--- a/iree/compiler/Codegen/SPIRV/test/config_default_linalg_ext_ops.mlir
+++ b/iree/compiler/Codegen/SPIRV/test/config_default_linalg_ext_ops.mlir
@@ -38,10 +38,6 @@
 //       CHECK: hal.executable.entry_point public @static_1d_sort
 //  CHECK-SAME:   translation.info = #[[TRANSLATION]]
 //  CHECK-SAME:   workgroup_size = [1 : index, 1 : index, 1 : index]
-//  CHECK-NEXT: ^{{.+}}(%{{.+}}: index, %{{.+}}: index, %{{.+}}: index):
-//  CHECK-NEXT:   %[[ONE:.+]] = arith.constant 1 : index
-//  CHECK-NEXT:   hal.return %[[ONE]], %[[ONE]], %[[ONE]]
-
 //       CHECK: func @static_1d_sort()
 //       CHECK:   iree_linalg_ext.sort
 //  CHECK-SAME:     lowering.config = #[[CONFIG]]
@@ -70,35 +66,16 @@
         %c0 = arith.constant 0 : index
         %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : memref<64x32x128xi32>
         %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) : memref<64x32x128xi32>
-        %workgroup_size_x = hal.interface.workgroup.size[0] : index
-        %workgroup_size_y = hal.interface.workgroup.size[1] : index
-        %workgroup_id_x = hal.interface.workgroup.id[0] : index
-        %workgroup_count_x = hal.interface.workgroup.count[0] : index
-        %workgroup_id_y = hal.interface.workgroup.id[1] : index
-        %workgroup_count_y = hal.interface.workgroup.count[1] : index
-        %2 = affine.apply affine_map<()[s0, s1] -> (s1 * s0)>()[%workgroup_size_y, %workgroup_id_y]
-        %3 = affine.apply affine_map<()[s0, s1] -> (s1 * s0)>()[%workgroup_size_y, %workgroup_count_y]
-        scf.for %arg0 = %2 to %c64 step %3 {
-          %4 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 64)>(%arg0)[%workgroup_size_y]
-          %5 = affine.apply affine_map<()[s0, s1] -> (s1 * s0)>()[%workgroup_size_x, %workgroup_id_x]
-          %6 = affine.apply affine_map<()[s0, s1] -> (s1 * s0)>()[%workgroup_size_x, %workgroup_count_x]
-          scf.for %arg1 = %5 to %c128 step %6 {
-            %7 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 128)>(%arg1)[%workgroup_size_x]
-            %8 = memref.subview %0[%arg0, 0, %arg1] [%4, 32, %7] [1, 1, 1] : memref<64x32x128xi32> to memref<?x32x?xi32, affine_map<(d0, d1, d2)[s0] -> (d0 * 4096 + s0 + d1 * 128 + d2)>>
-            %9 = memref.cast %8 : memref<?x32x?xi32, affine_map<(d0, d1, d2)[s0] -> (d0 * 4096 + s0 + d1 * 128 + d2)>> to memref<?x?x?xi32>
-            %10 = memref.subview %1[%arg0, 0, %arg1] [%4, 32, %7] [1, 1, 1] : memref<64x32x128xi32> to memref<?x32x?xi32, affine_map<(d0, d1, d2)[s0] -> (d0 * 4096 + s0 + d1 * 128 + d2)>>
-            linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]}
-              ins(%9 : memref<?x?x?xi32>) outs(%10 : memref<?x32x?xi32, affine_map<(d0, d1, d2)[s0] -> (d0 * 4096 + s0 + d1 * 128 + d2)>>) {
-              ^bb0(%arg4: i32, %s: i32):  // no predecessors
-                linalg.yield %arg4 : i32
-            }
-            iree_linalg_ext.sort {__internal_linalg_transform__ = "workgroup"} dimension(1) outs(%10 : memref<?x32x?xi32, affine_map<(d0, d1, d2)[s0] -> (d0 * 4096 + s0 + d1 * 128 + d2)>>)  {
-            ^bb0(%arg2: i32, %arg3: i32):  // no predecessors
-              %11 = arith.cmpi slt, %arg2, %arg3 : i32
-              iree_linalg_ext.yield %11 : i1
-            }
+        linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]}
+            ins(%0 : memref<64x32x128xi32>) outs(%1 : memref<64x32x128xi32>) {
+          ^bb0(%arg4: i32, %s: i32):  // no predecessors
+              linalg.yield %arg4 : i32
           }
-        }
+        iree_linalg_ext.sort {__internal_linalg_transform__ = "workgroup"} dimension(1) outs(%1 : memref<64x32x128xi32>)  {
+          ^bb0(%arg2: i32, %arg3: i32):  // no predecessors
+            %11 = arith.cmpi slt, %arg2, %arg3 : i32
+            iree_linalg_ext.yield %11 : i1
+          }
         return
       }
     }
@@ -106,16 +83,10 @@
 }
 
 //  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering.config<tile_sizes = {{\[}}[1, 0, 16], [1, 0, 1]{{\]}}, native_vector_size = []>
-//  CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0] -> (s0 ceildiv 16)>
-//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"SPIRVDistribute", workload_per_wg = [16, 1]>
+//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"SPIRVDistribute", workload_per_wg = []>
 //      CHECK: hal.executable.entry_point public @static_3d_sort
 // CHECK-SAME:   translation.info = #[[TRANSLATION]]
 // CHECK-SAME:   workgroup_size = [16 : index, 1 : index, 1 : index]
-// CHECK-NEXT: ^{{.+}}(%[[X:.+]]: index, %[[Y:.+]]: index, %{{.+}}: index):
-// CHECK-NEXT:   %[[ONE:.+]] = arith.constant 1 : index
-// CHECK-NEXT:   %[[DIV:.+]] = affine.apply #[[MAP]]()[%[[X]]]
-// CHECK-NEXT:   hal.return %[[DIV]], %[[Y]], %[[ONE]]
-
 //      CHECK: func @static_3d_sort()
 //      CHECK:   iree_linalg_ext.sort
 // CHECK-SAME:     lowering.config = #[[CONFIG]]
@@ -157,16 +128,10 @@
 }
 
 //   CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering.config<tile_sizes = {{\[}}[4]{{\]}}, native_vector_size = []>
-//   CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0] -> (s0 ceildiv 4)>
-//   CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"SPIRVDistribute", workload_per_wg = [4]>
+//   CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"SPIRVDistribute", workload_per_wg = []>
 //       CHECK: hal.executable.entry_point public @static_1d_fft_stage2
 //  CHECK-SAME:   translation.info = #[[TRANSLATION]]
 //  CHECK-SAME:   workgroup_size = [16 : index, 1 : index, 1 : index]
-//  CHECK-NEXT: ^{{.+}}(%[[ARG0:.+]]: index, %{{.+}}: index, %{{.+}}: index):
-//  CHECK-NEXT:   %[[ONE:.+]] = arith.constant 1 : index
-//  CHECK-NEXT:   %[[T:.+]] = affine.apply #[[MAP]]()[%[[ARG0]]]
-//  CHECK-NEXT:   hal.return %[[T]], %[[ONE]], %[[ONE]]
-
 //       CHECK: func @static_1d_fft_stage2()
 //       CHECK:   iree_linalg_ext.fft
 //  CHECK-SAME:     lowering.config = #[[CONFIG]]
@@ -201,27 +166,9 @@
         %1 = bufferization.to_memref %cst : memref<4xf32>
         %2 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : memref<64x128x32xf32>
         %3 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) : memref<64x128x32xf32>
-        %workgroup_id_x = hal.interface.workgroup.id[0] : index
-        %workgroup_count_x = hal.interface.workgroup.count[0] : index
-        %workgroup_id_y = hal.interface.workgroup.id[1] : index
-        %workgroup_count_y = hal.interface.workgroup.count[1] : index
-        %workgroup_id_z = hal.interface.workgroup.id[2] : index
-        %workgroup_count_z = hal.interface.workgroup.count[2] : index
-        scf.for %arg0 = %workgroup_id_z to %c64 step %workgroup_count_z {
-          scf.for %arg1 = %workgroup_id_y to %c128 step %workgroup_count_y {
-            %4 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%workgroup_id_x]
-            %5 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%workgroup_count_x]
-            scf.for %arg2 = %4 to %c32 step %5 {
-              %6 = memref.subview %2[%arg0, %arg1, %arg2] [1, 1, 4] [1, 1, 1] : memref<64x128x32xf32> to memref<1x1x4xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 4096 + s0 + d1 * 32 + d2)>>
-              %7 = memref.cast %6 : memref<1x1x4xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 4096 + s0 + d1 * 32 + d2)>> to memref<?x?x?xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 4096 + s0 + d1 * 32 + d2)>>
-              %8 = memref.subview %3[%arg0, %arg1, %arg2] [1, 1, 4] [1, 1, 1] : memref<64x128x32xf32> to memref<1x1x4xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 4096 + s0 + d1 * 32 + d2)>>
-              %9 = memref.cast %8 : memref<1x1x4xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 4096 + s0 + d1 * 32 + d2)>> to memref<?x?x?xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 4096 + s0 + d1 * 32 + d2)>>
-              iree_linalg_ext.fft {__internal_linalg_transform__ = "workgroup"}
-                ins(%c3, %1, %0 : index, memref<4xf32>, memref<4xf32>)
-                outs(%7, %9 : memref<?x?x?xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 4096 + s0 + d1 * 32 + d2)>>, memref<?x?x?xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 4096 + s0 + d1 * 32 + d2)>>)
-            }
-          }
-        }
+        iree_linalg_ext.fft {__internal_linalg_transform__ = "workgroup"}
+            ins(%c3, %1, %0 : index, memref<4xf32>, memref<4xf32>)
+            outs(%2, %3 : memref<64x128x32xf32>, memref<64x128x32xf32>)
         return
       }
     }
@@ -230,15 +177,102 @@
 
 
 //   CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering.config<tile_sizes = {{\[}}[1, 1, 8]{{\]}}, native_vector_size = []>
-//   CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0] -> (s0 ceildiv 8)>
-//   CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"SPIRVDistribute", workload_per_wg = [8, 1, 1]>
+//   CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"SPIRVDistribute", workload_per_wg = []>
 //       CHECK: hal.executable.entry_point public @static_3d_fft_stage3
 //  CHECK-SAME:   translation.info = #[[TRANSLATION]]
 //  CHECK-SAME:   workgroup_size = [16 : index, 1 : index, 1 : index]
-//  CHECK-NEXT: ^{{.+}}(%[[ARG0:.+]]: index, %[[ARG1:.+]]: index, %[[ARG2:.+]]: index):
-//  CHECK-NEXT:   %[[T:.+]] = affine.apply #[[MAP]]()[%[[ARG0]]]
-//  CHECK-NEXT:   hal.return %[[T]], %[[ARG1]], %[[ARG2]]
-
 //       CHECK: func @static_3d_fft_stage3()
 //       CHECK:   iree_linalg_ext.fft
 //  CHECK-SAME:     lowering.config = #[[CONFIG]]
+
+// -----
+
+#executable_layout = #hal.executable.layout<push_constants = 0, sets = [
+  #hal.descriptor_set.layout<0, bindings = [
+    #hal.descriptor_set.binding<0, storage_buffer>,
+    #hal.descriptor_set.binding<1, storage_buffer>
+  ]>
+]>
+hal.executable private @tensor_insert {
+  hal.executable.variant @vulkan_spirv_fb, target = <"vulkan", "vulkan-spirvfb", {
+      spv.target_env = #spv.target_env<#spv.vce<v1.4, [Shader], []>, Unknown:IntegratedGPU, {
+        max_compute_shared_memory_size = 32768 : i32,
+        max_compute_workgroup_invocations = 512 : i32,
+        max_compute_workgroup_size = dense<512> : vector<3xi32>,
+        subgroup_size = 16 : i32}>
+    }> {
+    hal.executable.entry_point @tensor_insert layout(#executable_layout)
+    builtin.module {
+      builtin.func @tensor_insert() {
+        %offset_y = hal.interface.constant.load[0] : index
+        %offset_x = hal.interface.constant.load[1] : index
+        %source_size_y = hal.interface.constant.load[2] : index
+        %source_size_x = hal.interface.constant.load[3] : index
+        %dest_size_y = hal.interface.constant.load[4] : index
+        %dest_size_x = hal.interface.constant.load[5] : index
+        %source_binding = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer)
+            : !flow.dispatch.tensor<readonly:?x?xf32>{%source_size_y, %source_size_x}
+        %dest_binding = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer)
+            : !flow.dispatch.tensor<readwrite:?x?xf32>{%dest_size_y, %dest_size_x}
+        %source = flow.dispatch.tensor.load %source_binding, offsets = [0, 0], sizes = [%source_size_y, %source_size_x], strides = [1, 1]
+            : !flow.dispatch.tensor<readonly:?x?xf32>{%source_size_y, %source_size_x} -> tensor<?x?xf32>
+        %dest = flow.dispatch.tensor.load %dest_binding, offsets = [0, 0], sizes = [%dest_size_y, %dest_size_x], strides = [1, 1]
+            : !flow.dispatch.tensor<readwrite:?x?xf32>{%dest_size_y, %dest_size_x} -> tensor<?x?xf32>
+        %insert = tensor.insert_slice %source into %dest[%offset_y, %offset_x] [%source_size_y, %source_size_x] [1, 1]
+            : tensor<?x?xf32> into tensor<?x?xf32>
+        flow.dispatch.tensor.store %insert, %dest_binding, offsets = [0, 0], sizes = [%dest_size_y, %dest_size_x], strides = [1, 1]
+            : tensor<?x?xf32> -> !flow.dispatch.tensor<readwrite:?x?xf32>{%dest_size_y, %dest_size_x}
+        return
+      }
+    }
+  }
+}
+// Check that the pipeline is set to `SPIRVDistributeCopy`
+
+//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"SPIRVDistributeCopy", workload_per_wg = []>
+//      CHECK: tensor.insert_slice
+//  CHECK-NOT:     lowering.config
+
+// -----
+
+#executable_layout = #hal.executable.layout<push_constants = 0, sets = [
+  #hal.descriptor_set.layout<0, bindings = [
+    #hal.descriptor_set.binding<0, storage_buffer>,
+    #hal.descriptor_set.binding<1, storage_buffer>
+  ]>
+]>
+hal.executable private @tensor_extract {
+  hal.executable.variant @vulkan_spirv_fb, target = <"vulkan", "vulkan-spirvfb", {
+      spv.target_env = #spv.target_env<#spv.vce<v1.4, [Shader], []>, Unknown:IntegratedGPU, {
+        max_compute_shared_memory_size = 32768 : i32,
+        max_compute_workgroup_invocations = 512 : i32,
+        max_compute_workgroup_size = dense<512> : vector<3xi32>,
+        subgroup_size = 16 : i32}>
+    }> {
+    hal.executable.entry_point @tensor_extract layout(#executable_layout)
+    builtin.module {
+      builtin.func @tensor_extract() {
+        %offset_y = hal.interface.constant.load[0] : index
+        %offset_x = hal.interface.constant.load[1] : index
+        %source_size_y = hal.interface.constant.load[2] : index
+        %source_size_x = hal.interface.constant.load[3] : index
+        %result_size_y = hal.interface.constant.load[4] : index
+        %result_size_x = hal.interface.constant.load[5] : index
+        %source_binding = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer)
+            : !flow.dispatch.tensor<readonly:?x?xf32>{%source_size_y, %source_size_x}
+        %result_binding = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer)
+            : !flow.dispatch.tensor<writeonly:?x?xf32>{%result_size_y, %result_size_x}
+        %source = flow.dispatch.tensor.load %source_binding, offsets = [0, 0], sizes = [%source_size_y, %source_size_x], strides = [1, 1]
+            : !flow.dispatch.tensor<readonly:?x?xf32>{%source_size_y, %source_size_x} -> tensor<?x?xf32>
+        %extract = tensor.extract_slice %source[%offset_y, %offset_x] [%result_size_y, %result_size_x] [1, 1]
+            : tensor<?x?xf32> to tensor<?x?xf32>
+        flow.dispatch.tensor.store %extract, %result_binding, offsets = [0, 0], sizes = [%result_size_y, %result_size_x], strides = [1, 1]
+            : tensor<?x?xf32> -> !flow.dispatch.tensor<writeonly:?x?xf32>{%result_size_y, %result_size_x}
+        return
+      }
+    }
+  }
+}
+//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"SPIRVDistributeCopy", workload_per_wg = []>
+//      CHECK: tensor.extract_slice
+//  CHECK-NOT:     lowering.config
diff --git a/iree/compiler/Codegen/SPIRV/test/config_default_linalg_ops.mlir b/iree/compiler/Codegen/SPIRV/test/config_default_linalg_ops.mlir
index d4dcf98..d74e6e7 100644
--- a/iree/compiler/Codegen/SPIRV/test/config_default_linalg_ops.mlir
+++ b/iree/compiler/Codegen/SPIRV/test/config_default_linalg_ops.mlir
@@ -1,69 +1,12 @@
 // RUN: iree-opt -split-input-file -pass-pipeline='hal.executable(hal.executable.variant(iree-spirv-lower-executable-target-pass{test-lowering-configuration=true}))' %s | FileCheck %s
 
-#executable_layout = #hal.executable.layout<push_constants = 0, sets = [
-  #hal.descriptor_set.layout<0, bindings = [
-    #hal.descriptor_set.binding<0, storage_buffer>,
-    #hal.descriptor_set.binding<1, storage_buffer>
-  ]>
-]>
-hal.executable @tensor_insert {
-  hal.executable.variant @vulkan_spirv_fb, target = <"vulkan-spirv", "vulkan-spirv-fb", {
-      spv.target_env = #spv.target_env<#spv.vce<v1.4, [Shader], []>, Unknown:IntegratedGPU, {
-        max_compute_shared_memory_size = 32768 : i32,
-        max_compute_workgroup_invocations = 512 : i32,
-        max_compute_workgroup_size = dense<512> : vector<3xi32>,
-        subgroup_size = 16 : i32}>
-    }> {
-    hal.executable.entry_point @tensor_insert_slice layout(#executable_layout)
-    builtin.module {
-      builtin.func @tensor_insert_slice() {
-        %c0 = arith.constant 0 : index
-        %1 = hal.interface.constant.load[0] : index
-        %2 = hal.interface.constant.load[1] : index
-        %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : !flow.dispatch.tensor<readonly:?x?xi32>{%1, %2}
-        %3 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) : !flow.dispatch.tensor<writeonly:?x?xi32>{%1, %2}
-        %workgroup_size_x = hal.interface.workgroup.size[0] : index
-        %workgroup_size_y = hal.interface.workgroup.size[1] : index
-        %workgroup_id_x = hal.interface.workgroup.id[0] : index
-        %workgroup_count_x = hal.interface.workgroup.count[0] : index
-        %workgroup_id_y = hal.interface.workgroup.id[1] : index
-        %workgroup_count_y = hal.interface.workgroup.count[1] : index
-        %4 = affine.apply affine_map<()[s0, s1] -> (s1 * s0)>()[%workgroup_size_y, %workgroup_id_y]
-        %5 = affine.apply affine_map<()[s0, s1] -> (s1 * s0)>()[%workgroup_size_y, %workgroup_count_y]
-        %d0 = hal.interface.constant.load[2] : index
-        %d1 = hal.interface.constant.load[2] : index
-        scf.for %arg0 = %4 to %d0 step %5 {
-          %6 = affine.min affine_map<(d0)[s0, s1] -> (s0, -d0 + s1)>(%arg0)[%workgroup_size_y, %d0]
-          %7 = affine.apply affine_map<()[s0, s1] -> (s1 * s0)>()[%workgroup_size_x, %workgroup_id_x]
-          %8 = affine.apply affine_map<()[s0, s1] -> (s1 * s0)>()[%workgroup_size_x, %workgroup_count_x]
-          scf.for %arg1 = %7 to %d1 step %8 {
-            %9 = affine.min affine_map<(d0)[s0, s1] -> (s0, -d0 + s1)>(%arg1)[%workgroup_size_x, %d1]
-            %10 = flow.dispatch.tensor.load %0, offsets = [%arg0, %arg1], sizes = [%6, %9], strides = [1, 1] : !flow.dispatch.tensor<readonly:?x?xi32>{%1, %2} -> tensor<?x?xi32>
-            %11 = affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%arg0)[%1]
-            %12 = affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%arg1)[%2]
-            flow.dispatch.tensor.store %10, %3, offsets = [%11, %12], sizes = [%6, %9], strides = [1, 1] : tensor<?x?xi32> -> !flow.dispatch.tensor<writeonly:?x?xi32>{%1, %2}
-          }
-        }
-        return
-      }
-    }
-  }
-}
-
-//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"SPIRVDistributeCopy", workload_per_wg = [16, 1]>
-//      CHECK: hal.executable.entry_point public @tensor_insert_slice
-// CHECK-SAME:   translation.info = #[[TRANSLATION]]
-//  CHECK-NOT:   hal.return
-
-// -----
-
 #executable_layout = #hal.executable.layout<push_constants = 2, sets = [
   #hal.descriptor_set.layout<0, bindings = [
     #hal.descriptor_set.binding<0, storage_buffer>,
     #hal.descriptor_set.binding<1, storage_buffer>
   ]>
 ]>
-hal.executable @tensor_insert {
+hal.executable @copy_as_generic {
   hal.executable.variant @vulkan_spirv_fb, target = <"vulkan-spirv", "vulkan-spirv-fb", {
       spv.target_env = #spv.target_env<#spv.vce<v1.4, [Shader], []>, Unknown:IntegratedGPU, {
         max_compute_shared_memory_size = 32768 : i32,
@@ -71,39 +14,21 @@
         max_compute_workgroup_size = dense<512> : vector<3xi32>,
         subgroup_size = 16 : i32}>
     }> {
-    hal.executable.entry_point @tensor_insert_slice layout(#executable_layout)
+    hal.executable.entry_point @copy_as_generic layout(#executable_layout)
     builtin.module {
-      builtin.func @tensor_insert_slice() {
+      builtin.func @copy_as_generic() {
         %c0 = arith.constant 0 : index
         %d0 = hal.interface.constant.load[0] : index
         %d1 = hal.interface.constant.load[1] : index
         %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : memref<?x?xi32>{%d0, %d1}
         %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) : memref<?x?xi32>{%d0, %d1}
-        %workgroup_size_x = hal.interface.workgroup.size[0] : index
-        %workgroup_size_y = hal.interface.workgroup.size[1] : index
-        %workgroup_id_x = hal.interface.workgroup.id[0] : index
-        %workgroup_count_x = hal.interface.workgroup.count[0] : index
-        %workgroup_id_y = hal.interface.workgroup.id[1] : index
-        %workgroup_count_y = hal.interface.workgroup.count[1] : index
-        %2 = affine.apply affine_map<()[s0, s1] -> (s1 * s0)>()[%workgroup_size_y, %workgroup_id_y]
-        %3 = affine.apply affine_map<()[s0, s1] -> (s1 * s0)>()[%workgroup_size_y, %workgroup_count_y]
-        scf.for %arg0 = %2 to %d0 step %3 {
-          %4 = affine.min affine_map<(d0)[s0, s1] -> (s0, -d0 + s1)>(%arg0)[%workgroup_size_y, %d0]
-          %5 = affine.apply affine_map<()[s0, s1] -> (s1 * s0)>()[%workgroup_size_x, %workgroup_id_x]
-          %6 = affine.apply affine_map<()[s0, s1] -> (s1 * s0)>()[%workgroup_size_x, %workgroup_count_x]
-          scf.for %arg1 = %5 to %d1 step %6 {
-            %7 = affine.min affine_map<(d0)[s0, s1] -> (s0, -d0 + s1)>(%arg1)[%workgroup_size_x, %d1]
-            %8 = memref.subview %0[%arg0, %arg1] [%4, %7] [1, 1] : memref<?x?xi32> to memref<?x?xi32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>
-            %9 = affine.apply affine_map<(d0) -> (d0 + 4)>(%arg0)
-            %10 = affine.apply affine_map<(d0) -> (d0 + 3)>(%arg1)
-            %11 = memref.subview %1[%9, %10] [%4, %7] [1, 1] : memref<?x?xi32> to memref<?x?xi32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>
-            linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]}
-              ins(%8 : memref<?x?xi32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>) outs(%11 : memref<?x?xi32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>) {
-              ^bb0(%arg4: i32, %s: i32):  // no predecessors
-                linalg.yield %arg4 : i32
-            }
+        linalg.generic {
+            indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
+            iterator_types = ["parallel", "parallel"]}
+            ins(%0 : memref<?x?xi32>) outs(%1 : memref<?x?xi32>) {
+            ^bb0(%arg4: i32, %s: i32):  // no predecessors
+              linalg.yield %arg4 : i32
           }
-        }
         return
       }
     }
@@ -111,15 +36,9 @@
 }
 
 //  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering.config<tile_sizes = {{\[}}[1, 16], [1, 1]{{\]}}, native_vector_size = []>
-//  CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0] -> (s0 ceildiv 16)>
-//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"SPIRVDistribute", workload_per_wg = [16, 1]>
-//      CHECK: hal.executable.entry_point public @tensor_insert_slice
+//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"SPIRVDistribute", workload_per_wg = []>
+//      CHECK: hal.executable.entry_point public @copy_as_generic
 // CHECK-SAME:   translation.info = #[[TRANSLATION]]
-// CHECK-NEXT:   %[[ARG0:[a-zA-Z0-9_]+]]: index
-// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
-//  CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
-//  CHECK-DAG:   %[[NWGSX:.+]] = affine.apply #[[MAP]]()[%[[ARG0]]]
-//      CHECK:   hal.return %[[NWGSX]], %[[ARG1]], %[[C1]]
 //      CHECK:   linalg.generic
 // CHECK-SAME:     lowering.config = #[[CONFIG]]
 
@@ -145,40 +64,15 @@
         %c0 = arith.constant 0 : index
         %c224 = arith.constant 224 : index
         %c3 = arith.constant 3 : index
-        %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : memref<1x225x225x3xf32>
+        %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : memref<1x224x224x3xf32>
         %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) : memref<1x224x224x3xf32>
-        %workgroup_size_x = hal.interface.workgroup.size[0] : index
-        %workgroup_size_y = hal.interface.workgroup.size[1] : index
-        %workgroup_size_z = hal.interface.workgroup.size[2] : index
-        %workgroup_id_x = hal.interface.workgroup.id[0] : index
-        %workgroup_count_x = hal.interface.workgroup.count[0] : index
-        %workgroup_id_y = hal.interface.workgroup.id[1] : index
-        %workgroup_count_y = hal.interface.workgroup.count[1] : index
-        %workgroup_id_z = hal.interface.workgroup.id[2] : index
-        %workgroup_count_z = hal.interface.workgroup.count[2] : index
-        %2 = affine.apply affine_map<()[s0, s1] -> (s1 * s0)>()[%workgroup_size_z, %workgroup_id_z]
-        %3 = affine.apply affine_map<()[s0, s1] -> (s1 * s0)>()[%workgroup_size_z, %workgroup_count_z]
-        scf.for %arg0 = %2 to %c224 step %3 {
-          %4 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 224)>(%arg0)[%workgroup_size_z]
-          %5 = affine.apply affine_map<()[s0, s1] -> (s1 * s0)>()[%workgroup_size_y, %workgroup_id_y]
-          %6 = affine.apply affine_map<()[s0, s1] -> (s1 * s0)>()[%workgroup_size_y, %workgroup_count_y]
-          scf.for %arg1 = %5 to %c224 step %6 {
-            %7 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 224)>(%arg1)[%workgroup_size_y]
-            %8 = affine.apply affine_map<()[s0, s1] -> (s1 * s0)>()[%workgroup_size_x, %workgroup_id_x]
-            %9 = affine.apply affine_map<()[s0, s1] -> (s1 * s0)>()[%workgroup_size_x, %workgroup_count_x]
-            scf.for %arg2 = %8 to %c3 step %9 {
-              %10 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 3)>(%arg2)[%workgroup_size_x]
-              %11 = memref.subview %1[0, %arg0, %arg1, %arg2] [1, %4, %7, %10] [1, 1, 1, 1] : memref<1x224x224x3xf32> to memref<1x?x?x?xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 150528 + s0 + d1 * 672 + d2 * 3 + d3)>>
-              %12 = memref.subview %0[0, %arg0, %arg1, %arg2] [1, %4, %7, %10] [1, 1, 1, 1] : memref<1x225x225x3xf32> to memref<1x?x?x?xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 151875 + s0 + d1 * 675 + d2 * 3 + d3)>>
-              linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
-                ins(%11 : memref<1x?x?x?xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 150528 + s0 + d1 * 672 + d2 * 3 + d3)>>) 
-                outs(%12 : memref<1x?x?x?xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 151875 + s0 + d1 * 675 + d2 * 3 + d3)>>) {
-                ^bb0(%arg4: f32, %s: f32):  // no predecessors
-                  linalg.yield %arg4 : f32
-              }
-            }
+        linalg.generic {
+            indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
+            iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+            ins(%0 : memref<1x224x224x3xf32>) outs(%1 : memref<1x224x224x3xf32>) {
+          ^bb0(%arg4: f32, %s: f32):  // no predecessors
+            linalg.yield %arg4 : f32
           }
-        }
         return
       }
     }
@@ -186,17 +80,9 @@
 }
 
 //  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering.config<tile_sizes = {{\[}}[0, 2, 32, 1], [0, 1, 1, 1]{{\]}}, native_vector_size = []>
-//  CHECK-DAG: #[[MAP_Y:.+]] = affine_map<()[s0] -> (s0 ceildiv 32)>
-//  CHECK-DAG: #[[MAP_Z:.+]] = affine_map<()[s0] -> (s0 ceildiv 2)>
-//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"SPIRVDistribute", workload_per_wg = [1, 32, 2]>
-
+//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"SPIRVDistribute", workload_per_wg = []>
 //      CHECK: hal.executable.entry_point public @copy
 // CHECK-SAME:   translation.info = #[[TRANSLATION]]
-// CHECK-NEXT:   (%[[X:.+]]: index, %[[Y:.+]]: index, %[[Z:.+]]: index)
-//  CHECK-DAG:   %[[Y_COUNT:.+]] = affine.apply #[[MAP_Y]]()[%[[Y]]]
-//  CHECK-DAG:   %[[Z_COUNT:.+]] = affine.apply #[[MAP_Z]]()[%[[Z]]]
-//      CHECK:   hal.return %[[X]], %[[Y_COUNT]], %[[Z_COUNT]]
-
 //      CHECK:   linalg.generic
 // CHECK-SAME:     lowering.config = #[[CONFIG]]
 
@@ -204,14 +90,6 @@
 
 // Average pooling op with nice tilable input.
 
-#map0 = affine_map<()[s0, s1] -> (s0 * s1)>
-#map1 = affine_map<(d0) -> (d0 * 12)>
-#map2 = affine_map<(d0)[s0] -> (s0 * 12, d0 * -12 + 24)>
-#map3 = affine_map<(d0)[s0] -> (s0, -d0 + 8)>
-#map4 = affine_map<(d0)[s0] -> (s0, -d0 + 2)>
-#map5 = affine_map<(d0)[s0] -> (-d0 + 2, s0)>
-#map6 = affine_map<(d0)[s0] -> (-d0 + 8, s0)>
-#map7 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
 #executable_layout = #hal.executable.layout<push_constants = 0, sets = [
   #hal.descriptor_set.layout<0, bindings = [
     #hal.descriptor_set.binding<0, storage_buffer>,
@@ -236,42 +114,15 @@
         %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : !flow.dispatch.tensor<readonly:1x24x24x8xf32>
         %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) : !flow.dispatch.tensor<writeonly:1x2x2x8xf32>
         %2 = linalg.init_tensor [12, 12] : tensor<12x12xf32>
-        %workgroup_size_x = hal.interface.workgroup.size[0] : index
-        %workgroup_size_y = hal.interface.workgroup.size[1] : index
-        %workgroup_size_z = hal.interface.workgroup.size[2] : index
-        %workgroup_id_x = hal.interface.workgroup.id[0] : index
-        %workgroup_count_x = hal.interface.workgroup.count[0] : index
-        %workgroup_id_y = hal.interface.workgroup.id[1] : index
-        %workgroup_count_y = hal.interface.workgroup.count[1] : index
-        %workgroup_id_z = hal.interface.workgroup.id[2] : index
-        %workgroup_count_z = hal.interface.workgroup.count[2] : index
-        %3 = affine.apply #map0()[%workgroup_id_z, %workgroup_size_z]
-        %4 = affine.apply #map0()[%workgroup_count_z, %workgroup_size_z]
-        scf.for %arg0 = %3 to %c2 step %4 {
-          %5 = affine.apply #map0()[%workgroup_id_y, %workgroup_size_y]
-          %6 = affine.apply #map0()[%workgroup_count_y, %workgroup_size_y]
-          scf.for %arg1 = %5 to %c2 step %6 {
-            %7 = affine.apply #map0()[%workgroup_id_x, %workgroup_size_x]
-            %8 = affine.apply #map0()[%workgroup_count_x, %workgroup_size_x]
-            scf.for %arg2 = %7 to %c8 step %8 {
-              %9 = affine.apply #map1(%arg0)
-              %10 = affine.min #map2(%arg0)[%workgroup_size_z]
-              %11 = affine.apply #map1(%arg1)
-              %12 = affine.min #map2(%arg1)[%workgroup_size_y]
-              %13 = affine.min #map3(%arg2)[%workgroup_size_x]
-              %14 = flow.dispatch.tensor.load %0, offsets = [0, %9, %11, %arg2], sizes = [1, %10, %12, %13], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:1x24x24x8xf32> -> tensor<1x?x?x?xf32>
-              %15 = affine.min #map4(%arg0)[%workgroup_size_z]
-              %16 = affine.min #map4(%arg1)[%workgroup_size_y]
-              %17 = affine.min #map5(%arg0)[%workgroup_size_z]
-              %18 = affine.min #map5(%arg1)[%workgroup_size_y]
-              %19 = affine.min #map6(%arg2)[%workgroup_size_x]
-              %20 = linalg.init_tensor [1, %17, %18, %19] : tensor<1x?x?x?xf32>
-              %21 = linalg.fill(%cst, %20) : f32, tensor<1x?x?x?xf32> -> tensor<1x?x?x?xf32>
-              %22 = linalg.pooling_nhwc_sum {__internal_linalg_transform__ = "workgroup", dilations = dense<1> : vector<2xi64>, strides = dense<12> : vector<2xi64>} ins(%14, %2 : tensor<1x?x?x?xf32>, tensor<12x12xf32>) outs(%21 : tensor<1x?x?x?xf32>) -> tensor<1x?x?x?xf32>
-              flow.dispatch.tensor.store %22, %1, offsets = [0, %arg0, %arg1, %arg2], sizes = [1, %15, %16, %13], strides = [1, 1, 1, 1] : tensor<1x?x?x?xf32> -> !flow.dispatch.tensor<writeonly:1x2x2x8xf32>
-            }
-          }
-        }
+        %14 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0, 0], sizes = [1, 24, 24, 8], strides = [1, 1, 1, 1]
+            : !flow.dispatch.tensor<readonly:1x24x24x8xf32> -> tensor<1x24x24x8xf32>
+        %20 = linalg.init_tensor [1, 2, 2, 8] : tensor<1x2x2x8xf32>
+        %21 = linalg.fill(%cst, %20) : f32, tensor<1x2x2x8xf32> -> tensor<1x2x2x8xf32>
+        %22 = linalg.pooling_nhwc_sum {__internal_linalg_transform__ = "workgroup", dilations = dense<1> : vector<2xi64>, strides = dense<12> : vector<2xi64>}
+            ins(%14, %2 : tensor<1x24x24x8xf32>, tensor<12x12xf32>)
+            outs(%21 : tensor<1x2x2x8xf32>) -> tensor<1x2x2x8xf32>
+        flow.dispatch.tensor.store %22, %1, offsets = [0, 0, 0, 0], sizes = [1, 2, 2, 8], strides = [1, 1, 1, 1]
+            : tensor<1x2x2x8xf32> -> !flow.dispatch.tensor<writeonly:1x2x2x8xf32>
         return
       }
     }
@@ -279,18 +130,9 @@
 }
 
 //  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering.config<tile_sizes = {{\[}}[0, 2, 2, 8], [0, 1, 1, 1]{{\]}}, native_vector_size = []>
-//  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 8)>
-//  CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 ceildiv 2)>
-//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"SPIRVDistribute", workload_per_wg = [8, 2, 2]>
-
+//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"SPIRVDistribute", workload_per_wg = []>
 //      CHECK: hal.executable.entry_point public @avg_pool
 // CHECK-SAME:   translation.info = #[[TRANSLATION]]
-// CHECK-NEXT:   (%[[X:.+]]: index, %[[Y:.+]]: index, %[[Z:.+]]: index)
-//  CHECK-DAG:   %[[X_COUNT:.+]] = affine.apply #[[MAP0]]()[%[[X]]]
-//  CHECK-DAG:   %[[Y_COUNT:.+]] = affine.apply #[[MAP1]]()[%[[Y]]]
-//  CHECK-DAG:   %[[Z_COUNT:.+]] = affine.apply #[[MAP1]]()[%[[Z]]]
-//      CHECK:   hal.return %[[X_COUNT]], %[[Y_COUNT]], %[[Z_COUNT]]
-
 //      CHECK:   linalg.pooling_nhwc_sum
 // CHECK-SAME:     lowering.config = #[[CONFIG]]
 
@@ -304,15 +146,6 @@
     #hal.descriptor_set.binding<1, storage_buffer>
   ]>
 ]>
-
-#map0 = affine_map<()[s0, s1] -> (s0 * s1)>
-#map8 = affine_map<(d0)[s0] -> (s0, -d0 + 1)>
-#map10 = affine_map<(d0)[s0] -> (-d0 + 1, s0)>
-#map20 = affine_map<(d0) -> (d0 * 2)>
-#map21 = affine_map<(d0)[s0] -> (s0 * 2, d0 * -2 + 76)>
-#map22 = affine_map<(d0)[s0] -> (s0, -d0 + 38)>
-#map23 = affine_map<(d0)[s0] -> (-d0 + 38, s0)>
-
 hal.executable @max_pool {
   hal.executable.variant @vulkan_spirv_fb, target = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {
       spv.target_env = #spv.target_env<#spv.vce<v1.4, [Shader], []>, Unknown:IntegratedGPU, {
@@ -332,62 +165,26 @@
         %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : !flow.dispatch.tensor<readonly:1x76x1x1xf32>
         %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) : !flow.dispatch.tensor<writeonly:1x38x1x1xf32>
         %2 = linalg.init_tensor [2, 1] : tensor<2x1xf32>
-        %workgroup_size_x = hal.interface.workgroup.size[0] : index
-        %workgroup_size_y = hal.interface.workgroup.size[1] : index
-        %workgroup_size_z = hal.interface.workgroup.size[2] : index
-        %workgroup_id_x = hal.interface.workgroup.id[0] : index
-        %workgroup_count_x = hal.interface.workgroup.count[0] : index
-        %workgroup_id_y = hal.interface.workgroup.id[1] : index
-        %workgroup_count_y = hal.interface.workgroup.count[1] : index
-        %workgroup_id_z = hal.interface.workgroup.id[2] : index
-        %workgroup_count_z = hal.interface.workgroup.count[2] : index
-        %3 = affine.apply #map0()[%workgroup_id_z, %workgroup_size_z]
-        %4 = affine.apply #map0()[%workgroup_count_z, %workgroup_size_z]
-        scf.for %arg0 = %3 to %c38 step %4 {
-          %5 = affine.apply #map0()[%workgroup_id_y, %workgroup_size_y]
-          %6 = affine.apply #map0()[%workgroup_count_y, %workgroup_size_y]
-          scf.for %arg1 = %5 to %c1 step %6 {
-            %7 = affine.apply #map0()[%workgroup_id_x, %workgroup_size_x]
-            %8 = affine.apply #map0()[%workgroup_count_x, %workgroup_size_x]
-            scf.for %arg2 = %7 to %c1 step %8 {
-              %9 = affine.apply #map20(%arg0)
-              %10 = affine.min #map21(%arg0)[%workgroup_size_z]
-              %11 = affine.min #map8(%arg1)[%workgroup_size_y]
-              %12 = affine.min #map8(%arg2)[%workgroup_size_x]
-              %13 = flow.dispatch.tensor.load %0, offsets = [0, %9, %arg1, %arg2], sizes = [1, %10, %11, %12], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:1x76x1x1xf32> -> tensor<1x?x?x?xf32>
-              %14 = affine.min #map22(%arg0)[%workgroup_size_z]
-              %15 = affine.min #map23(%arg0)[%workgroup_size_z]
-              %16 = affine.min #map10(%arg1)[%workgroup_size_y]
-              %17 = affine.min #map10(%arg2)[%workgroup_size_x]
-              %18 = linalg.init_tensor [1, %15, %16, %17] : tensor<1x?x?x?xf32>
-              %19 = linalg.fill(%cst, %18) : f32, tensor<1x?x?x?xf32> -> tensor<1x?x?x?xf32>
-              %20 = linalg.pooling_nhwc_max {dilations = dense<1> : vector<2xi64>, strides = dense<[2, 1]> : vector<2xi64>} ins(%13, %2 : tensor<1x?x?x?xf32>, tensor<2x1xf32>) outs(%19 : tensor<1x?x?x?xf32>) -> tensor<1x?x?x?xf32>
-              flow.dispatch.tensor.store %20, %1, offsets = [0, %arg0, %arg1, %arg2], sizes = [1, %14, %11, %12], strides = [1, 1, 1, 1] : tensor<1x?x?x?xf32> -> !flow.dispatch.tensor<writeonly:1x38x1x1xf32>
-            }
-          }
-        }
+        %13 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0, 0], sizes = [1, 76, 1, 1], strides = [1, 1, 1, 1]
+            : !flow.dispatch.tensor<readonly:1x76x1x1xf32> -> tensor<1x76x1x1xf32>
+        %18 = linalg.init_tensor [1, 38, 1, 1] : tensor<1x38x1x1xf32>
+        %19 = linalg.fill(%cst, %18) : f32, tensor<1x38x1x1xf32> -> tensor<1x38x1x1xf32>
+        %20 = linalg.pooling_nhwc_max {dilations = dense<1> : vector<2xi64>, strides = dense<[2, 1]> : vector<2xi64>}
+            ins(%13, %2 : tensor<1x76x1x1xf32>, tensor<2x1xf32>)
+            outs(%19 : tensor<1x38x1x1xf32>) -> tensor<1x38x1x1xf32>
+        flow.dispatch.tensor.store %20, %1, offsets = [0, 0, 0, 0], sizes = [1, 38, 1, 1], strides = [1, 1, 1, 1]
+            : tensor<1x38x1x1xf32> -> !flow.dispatch.tensor<writeonly:1x38x1x1xf32>
         return
       }
     }
   }
 }
 
-//  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering.config<tile_sizes = {{\[}}[0, 8, 2, 2], [0, 1, 1, 1]{{\]}}, native_vector_size = []>
-
-//  CHECK-DAG: #[[MAPXY:.+]] = affine_map<()[s0] -> (s0 ceildiv 2)>
-//  CHECK-DAG: #[[MAPZ:.+]] = affine_map<()[s0] -> (s0 ceildiv 8)>
-//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"SPIRVDistribute", workload_per_wg = [2, 2, 8]>
+//  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering.config<tile_sizes = {{\[}}[0, 32], [0, 1]{{\]}}, native_vector_size = []>
+//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"SPIRVDistribute", workload_per_wg = []>
 //      CHECK: hal.executable.entry_point public @max_pool
 // CHECK-SAME:   translation.info = #[[TRANSLATION]]
-// CHECK-SAME:   workgroup_size = [2 : index, 2 : index, 8 : index]
-// CHECK-NEXT:   %[[WLOADX:[a-zA-Z0-9_]+]]: index
-// CHECK-SAME:   %[[WLOADY:[a-zA-Z0-9_]+]]: index
-// CHECK-SAME:   %[[WLOADZ:[a-zA-Z0-9_]+]]: index
-//  CHECK-DAG:   %[[COUNTX:.+]] = affine.apply #[[MAPXY]]()[%[[WLOADX]]]
-//  CHECK-DAG:   %[[COUNTY:.+]] = affine.apply #[[MAPXY]]()[%[[WLOADY]]]
-//  CHECK-DAG:   %[[COUNTZ:.+]] = affine.apply #[[MAPZ]]()[%[[WLOADZ]]]
-//      CHECK:   hal.return %[[COUNTX]], %[[COUNTY]], %[[COUNTZ]]
-
+// CHECK-SAME:   workgroup_size = [32 : index, 1 : index, 1 : index]
 //      CHECK:   linalg.pooling_nhwc_max
 // CHECK-SAME:     lowering.config = #[[CONFIG]]
 
@@ -420,54 +217,29 @@
         %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : !flow.dispatch.tensor<readonly:1x10xf32>
         %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) : !flow.dispatch.tensor<readonly:10xf32>
         %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) : !flow.dispatch.tensor<writeonly:10xf32>
-        %workgroup_size_x = hal.interface.workgroup.size[0] : index
-        %workgroup_size_y = hal.interface.workgroup.size[1] : index
-        %workgroup_id_x = hal.interface.workgroup.id[0] : index
-        %workgroup_count_x = hal.interface.workgroup.count[0] : index
-        %workgroup_id_y = hal.interface.workgroup.id[1] : index
-        %workgroup_count_y = hal.interface.workgroup.count[1] : index
-        %3 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_y, %workgroup_size_y]
-        %4 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_y, %workgroup_size_y]
-        scf.for %arg0 = %3 to %c1 step %4 {
-          %5 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_x, %workgroup_size_x]
-          %6 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_x, %workgroup_size_x]
-          scf.for %arg1 = %5 to %c10 step %6 {
-            %7 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 1)>(%arg0)[%workgroup_size_y]
-            %8 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 10)>(%arg1)[%workgroup_size_x]
-            %9 = flow.dispatch.tensor.load %0, offsets = [%arg0, %arg1], sizes = [%7, %8], strides = [1, 1] : !flow.dispatch.tensor<readonly:1x10xf32> -> tensor<?x?xf32>
-            %10 = flow.dispatch.tensor.load %1, offsets = [%arg1], sizes = [%8], strides = [1] : !flow.dispatch.tensor<readonly:10xf32> -> tensor<?xf32>
-            %11 = linalg.init_tensor [%8] : tensor<?xf32>
-            %12 = linalg.generic {
-              indexing_maps = [
-                affine_map<(d0, d1) -> (d0, d1)>,
-                affine_map<(d0, d1) -> (d1)>,
-                affine_map<(d0, d1) -> (d1)>],
-              iterator_types = ["parallel", "parallel"]
-            } ins(%9, %10 : tensor<?x?xf32>, tensor<?xf32>) outs(%11 : tensor<?xf32>) {
+        %9 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [1, 10], strides = [1, 1]
+            : !flow.dispatch.tensor<readonly:1x10xf32> -> tensor<1x10xf32>
+        %10 = flow.dispatch.tensor.load %1, offsets = [0], sizes = [10], strides = [1]
+            : !flow.dispatch.tensor<readonly:10xf32> -> tensor<10xf32>
+        %11 = linalg.init_tensor [10] : tensor<10xf32>
+        %12 = linalg.generic {
+            indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d1)>],
+            iterator_types = ["parallel", "parallel"]}
+            ins(%9, %10 : tensor<1x10xf32>, tensor<10xf32>) outs(%11 : tensor<10xf32>) {
             ^bb0(%arg2: f32, %arg3: f32, %arg4: f32):  // no predecessors
               %13 = arith.addf %arg2, %arg3 : f32
               linalg.yield %13 : f32
-            } -> tensor<?xf32>
-            flow.dispatch.tensor.store %12, %2, offsets = [%arg1], sizes = [%8], strides = [1] : tensor<?xf32> -> !flow.dispatch.tensor<writeonly:10xf32>
-          }
-        }
+            } -> tensor<10xf32>
+        flow.dispatch.tensor.store %12, %2, offsets = [0], sizes = [10], strides = [1] : tensor<10xf32> -> !flow.dispatch.tensor<writeonly:10xf32>
         return
       }
     }
   }
 }
 
-//  CHECK-DAG: #[[MAPX:.+]] = affine_map<()[s0] -> (s0 ceildiv 16)>
-//  CHECK-DAG: #[[MAPY:.+]] = affine_map<()[s0] -> (s0 ceildiv 2)>
-//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"SPIRVDistribute", workload_per_wg = [16, 2]>
+//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"SPIRVDistribute", workload_per_wg = []>
 //      CHECK: hal.executable.entry_point public @elementwise
 // CHECK-SAME:   translation.info = #[[TRANSLATION]]
-// CHECK-NEXT:   %[[WLOADX:[a-zA-Z0-9_]+]]: index
-// CHECK-SAME:   %[[WLOADY:[a-zA-Z0-9_]+]]: index
-//  CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
-//  CHECK-DAG:   %[[NWGSX:.+]] = affine.apply #[[MAPX]]()[%[[WLOADX]]]
-//  CHECK-DAG:   %[[NWGSY:.+]] = affine.apply #[[MAPY]]()[%[[WLOADY]]]
-//      CHECK:   hal.return %[[NWGSX]], %[[NWGSY]], %[[C1]]
 
 // -----
 
@@ -479,15 +251,6 @@
     #hal.descriptor_set.binding<1, storage_buffer>
   ]>
 ]>
-
-#map0 = affine_map<()[s0, s1] -> (s0 * s1)>
-#map8 = affine_map<(d0)[s0] -> (s0, -d0 + 1)>
-#map10 = affine_map<(d0)[s0] -> (-d0 + 1, s0)>
-#map17 = affine_map<(d0)[s0] -> (s0, -d0 + 18)>
-#map18 = affine_map<(d0)[s0] -> (s0, -d0 + 4)>
-#map19 = affine_map<(d0, d1) -> (d1 + 2, -d0 + 20)>
-#map20 = affine_map<(d0)[s0] -> (-d0 + 4, s0)>
-#map21 = affine_map<(d0)[s0] -> (-d0 + 18, s0)>
 #map22 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
 
 hal.executable @dwconv_elementwise {
@@ -511,67 +274,32 @@
         %c6272 = arith.constant 6272 : index
         %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : !flow.dispatch.tensor<readonly:1x21x20x1xf32>
         %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) : !flow.dispatch.tensor<writeonly:1x19x18x1x4xf32>
-        %workgroup_size_x = hal.interface.workgroup.size[0] : index
-        %workgroup_size_y = hal.interface.workgroup.size[1] : index
-        %workgroup_size_z = hal.interface.workgroup.size[2] : index
-        %workgroup_id_x = hal.interface.workgroup.id[0] : index
-        %workgroup_count_x = hal.interface.workgroup.count[0] : index
-        %workgroup_id_y = hal.interface.workgroup.id[1] : index
-        %workgroup_count_y = hal.interface.workgroup.count[1] : index
-        %workgroup_id_z = hal.interface.workgroup.id[2] : index
-        %workgroup_count_z = hal.interface.workgroup.count[2] : index
-        %2 = affine.apply #map0()[%workgroup_id_z, %workgroup_size_z]
-        %3 = affine.apply #map0()[%workgroup_count_z, %workgroup_size_z]
-        scf.for %arg0 = %2 to %c18 step %3 {
-          %4 = affine.apply #map0()[%workgroup_id_y, %workgroup_size_y]
-          %5 = affine.apply #map0()[%workgroup_count_y, %workgroup_size_y]
-          scf.for %arg1 = %4 to %c1 step %5 {
-            %6 = affine.apply #map0()[%workgroup_id_x, %workgroup_size_x]
-            %7 = affine.apply #map0()[%workgroup_count_x, %workgroup_size_x]
-            scf.for %arg2 = %6 to %c4 step %7 {
-              %8 = affine.min #map17(%arg0)[%workgroup_size_z]
-              %9 = affine.min #map8(%arg1)[%workgroup_size_y]
-              %10 = affine.min #map18(%arg2)[%workgroup_size_x]
-              %11 = linalg.init_tensor [1, 19, %8, %9, %10] : tensor<1x19x?x?x?xf32>
-              %12 = affine.min #map19(%arg0, %8)
-              %13 = affine.min #map10(%arg1)[%workgroup_size_y]
-              %14 = flow.dispatch.tensor.load %0, offsets = [0, 0, %arg0, %arg1], sizes = [1, 21, %12, %13], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:1x21x20x1xf32> -> tensor<1x21x?x?xf32>
-              %15 = affine.min #map20(%arg2)[%workgroup_size_x]
-              %16 = tensor.extract_slice %cst[0, 0, %arg1, %arg2] [3, 3, %13, %15] [1, 1, 1, 1] : tensor<3x3x1x4xf32> to tensor<3x3x?x?xf32>
-              %17 = affine.min #map21(%arg0)[%workgroup_size_z]
-              %18 = linalg.init_tensor [1, 19, %17, %13, %15] : tensor<1x19x?x?x?xf32>
-              %19 = linalg.fill(%cst_9, %18) : f32, tensor<1x19x?x?x?xf32> -> tensor<1x19x?x?x?xf32>
-              %20 = linalg.depthwise_conv_2d_nhwc_hwcm {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%14, %16 : tensor<1x21x?x?xf32>, tensor<3x3x?x?xf32>) outs(%19 : tensor<1x19x?x?x?xf32>) -> tensor<1x19x?x?x?xf32>
-              %21 = linalg.generic {indexing_maps = [#map22, #map22], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%20 : tensor<1x19x?x?x?xf32>) outs(%11 : tensor<1x19x?x?x?xf32>) {
-              ^bb0(%arg3: f32, %arg4: f32):
-                %22 = math.sqrt %cst_8 : f32
-                %23 = arith.addf %arg3, %cst_9 : f32
-                linalg.yield %23 : f32
-              } -> tensor<1x19x?x?x?xf32>
-              flow.dispatch.tensor.store %21, %1, offsets = [0, 0, %arg0, %arg1, %arg2], sizes = [1, 19, %8, %9, %10], strides = [1, 1, 1, 1, 1] : tensor<1x19x?x?x?xf32> -> !flow.dispatch.tensor<writeonly:1x19x18x1x4xf32>
-            }
-          }
-        }
+        %11 = linalg.init_tensor [1, 19, 18, 1, 4] : tensor<1x19x18x1x4xf32>
+        %14 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0, 0], sizes = [1, 21, 20, 1], strides = [1, 1, 1, 1]
+            : !flow.dispatch.tensor<readonly:1x21x20x1xf32> -> tensor<1x21x20x1xf32>
+        %18 = linalg.init_tensor [1, 19, 18, 1, 4] : tensor<1x19x18x1x4xf32>
+        %19 = linalg.fill(%cst_9, %18) : f32, tensor<1x19x18x1x4xf32> -> tensor<1x19x18x1x4xf32>
+        %20 = linalg.depthwise_conv_2d_nhwc_hwcm {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
+            ins(%14, %cst : tensor<1x21x20x1xf32>, tensor<3x3x1x4xf32>) outs(%19 : tensor<1x19x18x1x4xf32>) -> tensor<1x19x18x1x4xf32>
+        %21 = linalg.generic {
+            indexing_maps = [#map22, #map22], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]}
+            ins(%20 : tensor<1x19x18x1x4xf32>) outs(%11 : tensor<1x19x18x1x4xf32>) {
+          ^bb0(%arg3: f32, %arg4: f32):
+            %22 = math.sqrt %cst_8 : f32
+            %23 = arith.addf %arg3, %cst_9 : f32
+            linalg.yield %23 : f32
+          } -> tensor<1x19x18x1x4xf32>
+        flow.dispatch.tensor.store %21, %1, offsets = [0, 0, 0, 0, 0], sizes = [1, 19, 18, 1, 4], strides = [1, 1, 1, 1, 1]
+            : tensor<1x19x18x1x4xf32> -> !flow.dispatch.tensor<writeonly:1x19x18x1x4xf32>
         return
       }
     }
   }
 }
 
-//  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering.config<tile_sizes = {{\[}}[0, 0, 2, 2, 8], [0, 0, 1, 1, 1]{{\]}}, native_vector_size = []>
-
-//  CHECK-DAG: #[[MAPX:.+]] = affine_map<()[s0] -> (s0 ceildiv 8)>
-//  CHECK-DAG: #[[MAPYZ:.+]] = affine_map<()[s0] -> (s0 ceildiv 2)>
-//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"SPIRVDistribute", workload_per_wg = [8, 2, 2]>
+//  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering.config<tile_sizes = {{\[}}[0, 4, 2, 0, 4], [0, 1, 1, 0, 1]{{\]}}, native_vector_size = []>
+//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"SPIRVDistribute", workload_per_wg = []>
 //      CHECK: hal.executable.entry_point public @dwconv_elementwise
 // CHECK-SAME:   translation.info = #[[TRANSLATION]]
-// CHECK-NEXT:   %[[WLOADX:[a-zA-Z0-9_]+]]: index
-// CHECK-SAME:   %[[WLOADY:[a-zA-Z0-9_]+]]: index
-// CHECK-SAME:   %[[WLOADZ:[a-zA-Z0-9_]+]]: index
-//  CHECK-DAG:   %[[COUNTX:.+]] = affine.apply #[[MAPX]]()[%[[WLOADX]]]
-//  CHECK-DAG:   %[[COUNTY:.+]] = affine.apply #[[MAPYZ]]()[%[[WLOADY]]]
-//  CHECK-DAG:   %[[COUNTZ:.+]] = affine.apply #[[MAPYZ]]()[%[[WLOADZ]]]
-//      CHECK:   hal.return %[[COUNTX]], %[[COUNTY]], %[[COUNTZ]]
-
 //      CHECK:   linalg.generic
 // CHECK-SAME:     lowering.config = #[[CONFIG]]
diff --git a/iree/compiler/Codegen/SPIRV/test/config_default_matmul.mlir b/iree/compiler/Codegen/SPIRV/test/config_default_matmul.mlir
index 0be03f9..8165759 100644
--- a/iree/compiler/Codegen/SPIRV/test/config_default_matmul.mlir
+++ b/iree/compiler/Codegen/SPIRV/test/config_default_matmul.mlir
@@ -28,59 +28,27 @@
         %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : !flow.dispatch.tensor<readonly:1x3x3xf32>
         %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) : !flow.dispatch.tensor<readonly:1x3x32xf32>
         %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) : !flow.dispatch.tensor<writeonly:1x3x32xf32>
-        %workgroup_size_x = hal.interface.workgroup.size[0] : index
-        %workgroup_size_y = hal.interface.workgroup.size[1] : index
-        %workgroup_size_z = hal.interface.workgroup.size[2] : index
-        %workgroup_id_x = hal.interface.workgroup.id[0] : index
-        %workgroup_count_x = hal.interface.workgroup.count[0] : index
-        %workgroup_id_y = hal.interface.workgroup.id[1] : index
-        %workgroup_count_y = hal.interface.workgroup.count[1] : index
-        %workgroup_id_z = hal.interface.workgroup.id[2] : index
-        %workgroup_count_z = hal.interface.workgroup.count[2] : index
-        %3 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_z, %workgroup_size_z]
-        %4 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_z, %workgroup_size_z]
-        scf.for %arg0 = %3 to %c1 step %4 {
-          %5 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_y, %workgroup_size_y]
-          %6 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_y, %workgroup_size_y]
-          scf.for %arg1 = %5 to %c3 step %6 {
-            %7 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_x, %workgroup_size_x]
-            %8 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_x, %workgroup_size_x]
-            scf.for %arg2 = %7 to %c32 step %8 {
-              %9 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 1)>(%arg0)[%workgroup_size_z]
-              %10 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 3)>(%arg1)[%workgroup_size_y]
-              %11 = flow.dispatch.tensor.load %0, offsets = [%arg0, %arg1, 0], sizes = [%9, %10, 3], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:1x3x3xf32> -> tensor<?x?x3xf32>
-              %12 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 1)>(%arg0)[%workgroup_size_z]
-              %13 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 32)>(%arg2)[%workgroup_size_x]
-              %14 = flow.dispatch.tensor.load %1, offsets = [%arg0, 0, %arg2], sizes = [%12, 3, %13], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:1x3x32xf32> -> tensor<?x3x?xf32>
-              %15 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 1)>(%arg0)[%workgroup_size_z]
-              %16 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 3)>(%arg1)[%workgroup_size_y]
-              %17 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 32)>(%arg2)[%workgroup_size_x]
-              %18 = affine.min affine_map<(d0)[s0] -> (-d0 + 1, s0)>(%arg0)[%workgroup_size_z]
-              %19 = affine.min affine_map<(d0)[s0] -> (-d0 + 3, s0)>(%arg1)[%workgroup_size_y]
-              %20 = affine.min affine_map<(d0)[s0] -> (-d0 + 32, s0)>(%arg2)[%workgroup_size_x]
-              %21 = linalg.init_tensor [%18, %19, %20] : tensor<?x?x?xf32>
-              %22 = linalg.fill(%cst, %21) : f32, tensor<?x?x?xf32> -> tensor<?x?x?xf32>
-              %23 = linalg.batch_matmul {__internal_linalg_transform__ = "workgroup"} ins(%11, %14 : tensor<?x?x3xf32>, tensor<?x3x?xf32>) outs(%22 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
-              flow.dispatch.tensor.store %23, %2, offsets = [%arg0, %arg1, %arg2], sizes = [%15, %16, %17], strides = [1, 1, 1] : tensor<?x?x?xf32> -> !flow.dispatch.tensor<writeonly:1x3x32xf32>
-            }
-          }
-        }
+        %11 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0], sizes = [1, 3, 3], strides = [1, 1, 1]
+            : !flow.dispatch.tensor<readonly:1x3x3xf32> -> tensor<1x3x3xf32>
+        %14 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0], sizes = [1, 3, 32], strides = [1, 1, 1]
+            : !flow.dispatch.tensor<readonly:1x3x32xf32> -> tensor<1x3x32xf32>
+        %21 = linalg.init_tensor [1, 3, 32] : tensor<1x3x32xf32>
+        %22 = linalg.fill(%cst, %21) : f32, tensor<1x3x32xf32> -> tensor<1x3x32xf32>
+        %23 = linalg.batch_matmul {__internal_linalg_transform__ = "workgroup"}
+            ins(%11, %14 : tensor<1x3x3xf32>, tensor<1x3x32xf32>) outs(%22 : tensor<1x3x32xf32>) -> tensor<1x3x32xf32>
+        flow.dispatch.tensor.store %23, %2, offsets = [0, 0, 0], sizes = [1, 3, 32], strides = [1, 1, 1]
+            : tensor<1x3x32xf32> -> !flow.dispatch.tensor<writeonly:1x3x32xf32>
         return
       }
     }
   }
 }
 
-//  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering.config<tile_sizes = {{\[}}[1, 1, 32], [1, 1, 1]{{\]}}, native_vector_size = []>
-//  CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0] -> (s0 ceildiv 32)>
-//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"SPIRVDistribute", workload_per_wg = [32, 1, 1]>
+//  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering.config<tile_sizes = {{\[}}[0, 1, 32], [0, 1, 1]{{\]}}, native_vector_size = []>
+//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"SPIRVDistribute", workload_per_wg = []>
 //      CHECK: hal.executable.entry_point public @batch_matmul_1x3x32
 // CHECK-SAME:   translation.info = #[[TRANSLATION]]
 // CHECK-SAME:   workgroup_size = [32 : index, 1 : index, 1 : index]
-// CHECK-NEXT: ^{{.+}}(%[[X:.+]]: index, %[[Y:.+]]: index, %[[Z:.+]]: index):
-// CHECK-NEXT:   %[[X_COUNT:.+]] = affine.apply #[[MAP]]()[%[[X]]]
-// CHECK-NEXT:   hal.return %[[X_COUNT]], %[[Y]], %[[Z]]
-
 //      CHECK: func @batch_matmul_1x3x32()
 //      CHECK:   linalg.batch_matmul
 // CHECK-SAME:     lowering.config = #[[CONFIG]]
@@ -114,32 +82,16 @@
         %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : !flow.dispatch.tensor<readonly:64x32xi8>
         %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) : !flow.dispatch.tensor<readonly:32x16xi8>
         %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) : !flow.dispatch.tensor<writeonly:64x16xi32>
-        %workgroup_size_x = hal.interface.workgroup.size[0] : index
-        %workgroup_size_y = hal.interface.workgroup.size[1] : index
-        %workgroup_id_x = hal.interface.workgroup.id[0] : index
-        %workgroup_count_x = hal.interface.workgroup.count[0] : index
-        %workgroup_id_y = hal.interface.workgroup.id[1] : index
-        %workgroup_count_y = hal.interface.workgroup.count[1] : index
-        %3 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_y, %workgroup_size_y]
-        %4 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_y, %workgroup_size_y]
-        scf.for %arg0 = %3 to %c64 step %4 {
-          %5 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_x, %workgroup_size_x]
-          %6 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_x, %workgroup_size_x]
-          scf.for %arg1 = %5 to %c16 step %6 {
-            %7 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 64)>(%arg0)[%workgroup_size_y]
-            %8 = flow.dispatch.tensor.load %0, offsets = [%arg0, 0], sizes = [%7, 32], strides = [1, 1] : !flow.dispatch.tensor<readonly:64x32xi8> -> tensor<?x32xi8>
-            %9 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 16)>(%arg1)[%workgroup_size_x]
-            %10 = flow.dispatch.tensor.load %1, offsets = [0, %arg1], sizes = [32, %9], strides = [1, 1] : !flow.dispatch.tensor<readonly:32x16xi8> -> tensor<32x?xi8>
-            %11 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 64)>(%arg0)[%workgroup_size_y]
-            %12 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 16)>(%arg1)[%workgroup_size_x]
-            %13 = affine.min affine_map<(d0)[s0] -> (-d0 + 64, s0)>(%arg0)[%workgroup_size_y]
-            %14 = affine.min affine_map<(d0)[s0] -> (-d0 + 16, s0)>(%arg1)[%workgroup_size_x]
-            %15 = linalg.init_tensor [%13, %14] : tensor<?x?xi32>
-            %16 = linalg.fill(%c0_i32, %15) : i32, tensor<?x?xi32> -> tensor<?x?xi32>
-            %17 = linalg.matmul {__internal_linalg_transform__ = "workgroup"} ins(%8, %10 : tensor<?x32xi8>, tensor<32x?xi8>) outs(%16 : tensor<?x?xi32>) -> tensor<?x?xi32>
-            flow.dispatch.tensor.store %17, %2, offsets = [%arg0, %arg1], sizes = [%11, %12], strides = [1, 1] : tensor<?x?xi32> -> !flow.dispatch.tensor<writeonly:64x16xi32>
-          }
-        }
+        %8 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [64, 32], strides = [1, 1]
+            : !flow.dispatch.tensor<readonly:64x32xi8> -> tensor<64x32xi8>
+        %10 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [32, 16], strides = [1, 1]
+            : !flow.dispatch.tensor<readonly:32x16xi8> -> tensor<32x16xi8>
+        %15 = linalg.init_tensor [64, 16] : tensor<64x16xi32>
+        %16 = linalg.fill(%c0_i32, %15) : i32, tensor<64x16xi32> -> tensor<64x16xi32>
+        %17 = linalg.matmul {__internal_linalg_transform__ = "workgroup"}
+            ins(%8, %10 : tensor<64x32xi8>, tensor<32x16xi8>) outs(%16 : tensor<64x16xi32>) -> tensor<64x16xi32>
+        flow.dispatch.tensor.store %17, %2, offsets = [0, 0], sizes = [64, 16], strides = [1, 1]
+            : tensor<64x16xi32> -> !flow.dispatch.tensor<writeonly:64x16xi32>
         return
       }
     }
@@ -147,18 +99,10 @@
 }
 
 //  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering.config<tile_sizes = {{\[}}[4, 16], [1, 1]{{\]}}, native_vector_size = []>
-//  CHECK-DAG: #[[MAP_X:.+]] = affine_map<()[s0] -> (s0 ceildiv 16)>
-//  CHECK-DAG: #[[MAP_Y:.+]] = affine_map<()[s0] -> (s0 ceildiv 4)>
-//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"SPIRVDistribute", workload_per_wg = [16, 4]>
+//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"SPIRVDistribute", workload_per_wg = []>
 //      CHECK: hal.executable.entry_point public @matmul_64x16
 // CHECK-SAME:   translation.info = #[[TRANSLATION]]
 // CHECK-SAME:   workgroup_size = [16 : index, 4 : index, 1 : index]
-// CHECK-NEXT: ^{{.+}}(%[[X:.+]]: index, %[[Y:.+]]: index, %{{.+}}: index):
-// CHECK-NEXT:   %[[ONE:.+]] = arith.constant 1 : index
-// CHECK-NEXT:   %[[X_COUNT:.+]] = affine.apply #[[MAP_X]]()[%[[X]]]
-// CHECK-NEXT:   %[[Y_COUNT:.+]] = affine.apply #[[MAP_Y]]()[%[[Y]]]
-// CHECK-NEXT:   hal.return %[[X_COUNT]], %[[Y_COUNT]], %[[ONE]]
-
 //      CHECK: func @matmul_64x16()
 //      CHECK:   linalg.matmul
 // CHECK-SAME:     lowering.config = #[[CONFIG]]
@@ -194,37 +138,25 @@
         %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) : !flow.dispatch.tensor<readonly:400x576xf32>
         %2 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : !flow.dispatch.tensor<readonly:576x273xf32>
         %3 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) : !flow.dispatch.tensor<writeonly:400x273xf32>
-        %workgroup_size_x = hal.interface.workgroup.size[0] : index
-        %workgroup_size_y = hal.interface.workgroup.size[1] : index
-        %workgroup_id_x = hal.interface.workgroup.id[0] : index
-        %workgroup_count_x = hal.interface.workgroup.count[0] : index
-        %workgroup_id_y = hal.interface.workgroup.id[1] : index
-        %workgroup_count_y = hal.interface.workgroup.count[1] : index
-        %4 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_y, %workgroup_size_y]
-        %5 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_y, %workgroup_size_y]
-        scf.for %arg0 = %4 to %c400 step %5 {
-          %6 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_x, %workgroup_size_x]
-          %7 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_x, %workgroup_size_x]
-          scf.for %arg1 = %6 to %c273 step %7 {
-            %8 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 273)>(%arg1)[%workgroup_size_x]
-            %9 = flow.dispatch.tensor.load %0, offsets = [%arg1], sizes = [%8], strides = [1] : !flow.dispatch.tensor<readonly:273xf32> -> tensor<?xf32>
-            %10 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 400)>(%arg0)[%workgroup_size_y]
-            %11 = linalg.init_tensor [%10, %8] : tensor<?x?xf32>
-            %12 = affine.min affine_map<(d0)[s0] -> (-d0 + 400, s0)>(%arg0)[%workgroup_size_y]
-            %13 = flow.dispatch.tensor.load %1, offsets = [%arg0, 0], sizes = [%12, 576], strides = [1, 1] : !flow.dispatch.tensor<readonly:400x576xf32> -> tensor<?x576xf32>
-            %14 = affine.min affine_map<(d0)[s0] -> (-d0 + 273, s0)>(%arg1)[%workgroup_size_x]
-            %15 = flow.dispatch.tensor.load %2, offsets = [0, %arg1], sizes = [576, %14], strides = [1, 1] : !flow.dispatch.tensor<readonly:576x273xf32> -> tensor<576x?xf32>
-            %16 = linalg.init_tensor [%12, %14] : tensor<?x?xf32>
-            %17 = linalg.fill(%cst, %16) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
-            %18 = linalg.matmul ins(%13, %15 : tensor<?x576xf32>, tensor<576x?xf32>) outs(%17 : tensor<?x?xf32>) -> tensor<?x?xf32>
-            %19 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%9, %18 : tensor<?xf32>, tensor<?x?xf32>) outs(%11 : tensor<?x?xf32>) {
-            ^bb0(%arg2: f32, %arg3: f32, %arg4: f32):  // no predecessors
-              %20 = arith.addf %arg2, %arg3 : f32
-              linalg.yield %20 : f32
-            } -> tensor<?x?xf32>
-            flow.dispatch.tensor.store %19, %3, offsets = [%arg0, %arg1], sizes = [%10, %8], strides = [1, 1] : tensor<?x?xf32> -> !flow.dispatch.tensor<writeonly:400x273xf32>
-          }
-        }
+        %9 = flow.dispatch.tensor.load %0, offsets = [0], sizes = [273], strides = [1] : !flow.dispatch.tensor<readonly:273xf32> -> tensor<273xf32>
+        %11 = linalg.init_tensor [400, 273] : tensor<400x273xf32>
+        %13 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [400, 576], strides = [1, 1]
+            : !flow.dispatch.tensor<readonly:400x576xf32> -> tensor<400x576xf32>
+        %15 = flow.dispatch.tensor.load %2, offsets = [0, 0], sizes = [576, 273], strides = [1, 1]
+            : !flow.dispatch.tensor<readonly:576x273xf32> -> tensor<576x273xf32>
+        %16 = linalg.init_tensor [400, 273] : tensor<400x273xf32>
+        %17 = linalg.fill(%cst, %16) : f32, tensor<400x273xf32> -> tensor<400x273xf32>
+        %18 = linalg.matmul ins(%13, %15 : tensor<400x576xf32>, tensor<576x273xf32>) outs(%17 : tensor<400x273xf32>) -> tensor<400x273xf32>
+        %19 = linalg.generic {
+            indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
+            iterator_types = ["parallel", "parallel"]}
+            ins(%9, %18 : tensor<273xf32>, tensor<400x273xf32>) outs(%11 : tensor<400x273xf32>) {
+          ^bb0(%arg2: f32, %arg3: f32, %arg4: f32):  // no predecessors
+            %20 = arith.addf %arg2, %arg3 : f32
+            linalg.yield %20 : f32
+          } -> tensor<400x273xf32>
+        flow.dispatch.tensor.store %19, %3, offsets = [0, 0], sizes = [400, 273], strides = [1, 1]
+            : tensor<400x273xf32> -> !flow.dispatch.tensor<writeonly:400x273xf32>
         return
       }
     }
@@ -232,19 +164,10 @@
 }
 
 //  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering.config<tile_sizes = {{\[}}[2, 32], [1, 1]{{\]}}, native_vector_size = []>
-//  CHECK-DAG: #[[MAP_X:.+]] = affine_map<()[s0] -> (s0 ceildiv 32)>
-//  CHECK-DAG: #[[MAP_Y:.+]] = affine_map<()[s0] -> (s0 ceildiv 2)>
-//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"SPIRVDistribute", workload_per_wg = [32, 2]>
-
+//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"SPIRVDistribute", workload_per_wg = []>
 //      CHECK: hal.executable.entry_point public @matmul_400x273
 // CHECK-SAME:   translation.info = #[[TRANSLATION]]
 // CHECK-SAME:   workgroup_size = [32 : index, 2 : index, 1 : index]
-// CHECK-NEXT: ^{{.+}}(%[[X:.+]]: index, %[[Y:.+]]: index, %{{.+}}: index):
-// CHECK-NEXT:   %[[ONE:.+]] = arith.constant 1 : index
-// CHECK-NEXT:   %[[X_COUNT:.+]] = affine.apply #[[MAP_X]]()[%[[X]]]
-// CHECK-NEXT:   %[[Y_COUNT:.+]] = affine.apply #[[MAP_Y]]()[%[[Y]]]
-// CHECK-NEXT:   hal.return %[[X_COUNT]], %[[Y_COUNT]], %[[ONE]]
-
 //      CHECK: func @matmul_400x273()
 //      CHECK:   linalg.matmul
 // CHECK-SAME:     lowering.config = #[[CONFIG]]
@@ -280,37 +203,26 @@
         %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) : !flow.dispatch.tensor<readonly:25x512xf32>
         %2 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : !flow.dispatch.tensor<readonly:512x546xf32>
         %3 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) : !flow.dispatch.tensor<writeonly:25x546xf32>
-        %workgroup_size_x = hal.interface.workgroup.size[0] : index
-        %workgroup_size_y = hal.interface.workgroup.size[1] : index
-        %workgroup_id_x = hal.interface.workgroup.id[0] : index
-        %workgroup_count_x = hal.interface.workgroup.count[0] : index
-        %workgroup_id_y = hal.interface.workgroup.id[1] : index
-        %workgroup_count_y = hal.interface.workgroup.count[1] : index
-        %4 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_y, %workgroup_size_y]
-        %5 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_y, %workgroup_size_y]
-        scf.for %arg0 = %4 to %c25 step %5 {
-          %6 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_x, %workgroup_size_x]
-          %7 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_x, %workgroup_size_x]
-          scf.for %arg1 = %6 to %c546 step %7 {
-            %8 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 546)>(%arg1)[%workgroup_size_x]
-            %9 = flow.dispatch.tensor.load %0, offsets = [%arg1], sizes = [%8], strides = [1] : !flow.dispatch.tensor<readonly:546xf32> -> tensor<?xf32>
-            %10 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 25)>(%arg0)[%workgroup_size_y]
-            %11 = linalg.init_tensor [%10, %8] : tensor<?x?xf32>
-            %12 = affine.min affine_map<(d0)[s0] -> (-d0 + 25, s0)>(%arg0)[%workgroup_size_y]
-            %13 = flow.dispatch.tensor.load %1, offsets = [%arg0, 0], sizes = [%12, 512], strides = [1, 1] : !flow.dispatch.tensor<readonly:25x512xf32> -> tensor<?x512xf32>
-            %14 = affine.min affine_map<(d0)[s0] -> (-d0 + 546, s0)>(%arg1)[%workgroup_size_x]
-            %15 = flow.dispatch.tensor.load %2, offsets = [0, %arg1], sizes = [512, %14], strides = [1, 1] : !flow.dispatch.tensor<readonly:512x546xf32> -> tensor<512x?xf32>
-            %16 = linalg.init_tensor [%12, %14] : tensor<?x?xf32>
-            %17 = linalg.fill(%cst, %16) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
-            %18 = linalg.matmul ins(%13, %15 : tensor<?x512xf32>, tensor<512x?xf32>) outs(%17 : tensor<?x?xf32>) -> tensor<?x?xf32>
-            %19 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%9, %18 : tensor<?xf32>, tensor<?x?xf32>) outs(%11 : tensor<?x?xf32>) attrs =  {__internal_linalg_transform__ = "workgroup"} {
-            ^bb0(%arg2: f32, %arg3: f32, %arg4: f32):  // no predecessors
-              %20 = arith.addf %arg2, %arg3 : f32
-              linalg.yield %20 : f32
-            } -> tensor<?x?xf32>
-            flow.dispatch.tensor.store %19, %3, offsets = [%arg0, %arg1], sizes = [%10, %8], strides = [1, 1] : tensor<?x?xf32> -> !flow.dispatch.tensor<writeonly:25x546xf32>
-          }
-        }
+        %9 = flow.dispatch.tensor.load %0, offsets = [0], sizes = [546], strides = [1]
+            : !flow.dispatch.tensor<readonly:546xf32> -> tensor<546xf32>
+        %11 = linalg.init_tensor [25, 546] : tensor<25x546xf32>
+        %13 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [25, 512], strides = [1, 1]
+            : !flow.dispatch.tensor<readonly:25x512xf32> -> tensor<25x512xf32>
+        %15 = flow.dispatch.tensor.load %2, offsets = [0, 0], sizes = [512, 546], strides = [1, 1]
+            : !flow.dispatch.tensor<readonly:512x546xf32> -> tensor<512x546xf32>
+        %16 = linalg.init_tensor [25, 546] : tensor<25x546xf32>
+        %17 = linalg.fill(%cst, %16) : f32, tensor<25x546xf32> -> tensor<25x546xf32>
+        %18 = linalg.matmul ins(%13, %15 : tensor<25x512xf32>, tensor<512x546xf32>) outs(%17 : tensor<25x546xf32>) -> tensor<25x546xf32>
+        %19 = linalg.generic {
+            indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
+            iterator_types = ["parallel", "parallel"]}
+            ins(%9, %18 : tensor<546xf32>, tensor<25x546xf32>) outs(%11 : tensor<25x546xf32>) attrs =  {__internal_linalg_transform__ = "workgroup"} {
+          ^bb0(%arg2: f32, %arg3: f32, %arg4: f32):  // no predecessors
+            %20 = arith.addf %arg2, %arg3 : f32
+            linalg.yield %20 : f32
+          } -> tensor<25x546xf32>
+        flow.dispatch.tensor.store %19, %3, offsets = [0, 0], sizes = [25, 546], strides = [1, 1]
+            : tensor<25x546xf32> -> !flow.dispatch.tensor<writeonly:25x546xf32>
         return
       }
     }
@@ -318,18 +230,10 @@
 }
 
 //  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering.config<tile_sizes = {{\[}}[32, 2], [1, 1]{{\]}}, native_vector_size = []>
-//  CHECK-DAG: #[[MAP_X:.+]] = affine_map<()[s0] -> (s0 ceildiv 2)>
-//  CHECK-DAG: #[[MAP_Y:.+]] = affine_map<()[s0] -> (s0 ceildiv 32)>
-//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"SPIRVDistribute", workload_per_wg = [2, 32]>
+//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"SPIRVDistribute", workload_per_wg = []>
 //      CHECK: hal.executable.entry_point public @matmul_25x546
 // CHECK-SAME:   translation.info = #[[TRANSLATION]]
 // CHECK-SAME:   workgroup_size = [2 : index, 32 : index, 1 : index]
-// CHECK-NEXT: ^{{.+}}(%[[X:.+]]: index, %[[Y:.+]]: index, %{{.+}}: index):
-// CHECK-NEXT:   %[[ONE:.+]] = arith.constant 1 : index
-// CHECK-NEXT:   %[[X_COUNT:.+]] = affine.apply #[[MAP_X]]()[%[[X]]]
-// CHECK-NEXT:   %[[Y_COUNT:.+]] = affine.apply #[[MAP_Y]]()[%[[Y]]]
-// CHECK-NEXT:   hal.return %[[X_COUNT]], %[[Y_COUNT]], %[[ONE]]
-
 //      CHECK: func @matmul_25x546()
 //      CHECK:   linalg.matmul
 // CHECK-SAME:     lowering.config = #[[CONFIG]]
@@ -373,39 +277,28 @@
         %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) : !flow.dispatch.tensor<readonly:256x128xf16>
         %3 = hal.interface.binding.subspan set(0) binding(3) type(storage_buffer) : !flow.dispatch.tensor<readonly:128x1024xf16>
         %4 = hal.interface.binding.subspan set(0) binding(4) type(storage_buffer) : !flow.dispatch.tensor<writeonly:256x1024xf16>
-        %workgroup_size_x = hal.interface.workgroup.size[0] : index
-        %workgroup_size_y = hal.interface.workgroup.size[1] : index
-        %workgroup_id_x = hal.interface.workgroup.id[0] : index
-        %workgroup_count_x = hal.interface.workgroup.count[0] : index
-        %workgroup_id_y = hal.interface.workgroup.id[1] : index
-        %workgroup_count_y = hal.interface.workgroup.count[1] : index
-        %5 = affine.apply #map0()[%workgroup_id_y, %workgroup_size_y]
-        %6 = affine.apply #map0()[%workgroup_count_y, %workgroup_size_y]
-        scf.for %arg0 = %5 to %c256 step %6 {
-          %7 = affine.apply #map0()[%workgroup_id_x, %workgroup_size_x]
-          %8 = affine.apply #map0()[%workgroup_count_x, %workgroup_size_x]
-          scf.for %arg1 = %7 to %c1024 step %8 {
-            %9 = affine.min #map1(%arg0)[%workgroup_size_y]
-            %10 = affine.min #map2(%arg1)[%workgroup_size_x]
-            %11 = flow.dispatch.tensor.load %0, offsets = [%arg0, %arg1], sizes = [%9, %10], strides = [1, 1] : !flow.dispatch.tensor<readonly:256x1024xf16> -> tensor<?x?xf16>
-            %12 = flow.dispatch.tensor.load %1, offsets = [%arg0, %arg1], sizes = [%9, %10], strides = [1, 1] : !flow.dispatch.tensor<readonly:256x1024xf16> -> tensor<?x?xf16>
-            %13 = linalg.init_tensor [%9, %10] : tensor<?x?xf16>
-            %14 = affine.min #map3(%arg0)[%workgroup_size_y]
-            %15 = flow.dispatch.tensor.load %2, offsets = [%arg0, 0], sizes = [%14, 128], strides = [1, 1] : !flow.dispatch.tensor<readonly:256x128xf16> -> tensor<?x128xf16>
-            %16 = affine.min #map4(%arg1)[%workgroup_size_x]
-            %17 = flow.dispatch.tensor.load %3, offsets = [0, %arg1], sizes = [128, %16], strides = [1, 1] : !flow.dispatch.tensor<readonly:128x1024xf16> -> tensor<128x?xf16>
-            %18 = linalg.init_tensor [%14, %16] : tensor<?x?xf16>
-            %19 = linalg.fill(%cst, %18) : f16, tensor<?x?xf16> -> tensor<?x?xf16>
-            %20 = linalg.matmul ins(%15, %17 : tensor<?x128xf16>, tensor<128x?xf16>) outs(%19 : tensor<?x?xf16>) -> tensor<?x?xf16>
-            %21 = linalg.generic {indexing_maps = [#map5, #map5, #map5, #map5], iterator_types = ["parallel", "parallel"]} ins(%20, %11, %12 : tensor<?x?xf16>, tensor<?x?xf16>, tensor<?x?xf16>) outs(%13 : tensor<?x?xf16>) {
-            ^bb0(%arg2: f16, %arg3: f16, %arg4: f16, %arg5: f16):  // no predecessors
-              %22 = arith.divf %arg2, %arg3 : f16
-              %23 = arith.subf %22, %arg4 : f16
-              linalg.yield %23 : f16
-            } -> tensor<?x?xf16>
-            flow.dispatch.tensor.store %21, %4, offsets = [%arg0, %arg1], sizes = [%9, %10], strides = [1, 1] : tensor<?x?xf16> -> !flow.dispatch.tensor<writeonly:256x1024xf16>
-          }
-        }
+        %11 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [256, 1024], strides = [1, 1]
+            : !flow.dispatch.tensor<readonly:256x1024xf16> -> tensor<256x1024xf16>
+        %12 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [256, 1024], strides = [1, 1]
+            : !flow.dispatch.tensor<readonly:256x1024xf16> -> tensor<256x1024xf16>
+        %13 = linalg.init_tensor [256, 1024] : tensor<256x1024xf16>
+        %15 = flow.dispatch.tensor.load %2, offsets = [0, 0], sizes = [256, 128], strides = [1, 1]
+            : !flow.dispatch.tensor<readonly:256x128xf16> -> tensor<256x128xf16>
+        %17 = flow.dispatch.tensor.load %3, offsets = [0, 0], sizes = [128, 1024], strides = [1, 1]
+            : !flow.dispatch.tensor<readonly:128x1024xf16> -> tensor<128x1024xf16>
+        %18 = linalg.init_tensor [256, 1024] : tensor<256x1024xf16>
+        %19 = linalg.fill(%cst, %18) : f16, tensor<256x1024xf16> -> tensor<256x1024xf16>
+        %20 = linalg.matmul ins(%15, %17 : tensor<256x128xf16>, tensor<128x1024xf16>) outs(%19 : tensor<256x1024xf16>) -> tensor<256x1024xf16>
+        %21 = linalg.generic {
+            indexing_maps = [#map5, #map5, #map5, #map5], iterator_types = ["parallel", "parallel"]}
+            ins(%20, %11, %12 : tensor<256x1024xf16>, tensor<256x1024xf16>, tensor<256x1024xf16>) outs(%13 : tensor<256x1024xf16>) {
+          ^bb0(%arg2: f16, %arg3: f16, %arg4: f16, %arg5: f16):  // no predecessors
+            %22 = arith.divf %arg2, %arg3 : f16
+            %23 = arith.subf %22, %arg4 : f16
+            linalg.yield %23 : f16
+          } -> tensor<256x1024xf16>
+        flow.dispatch.tensor.store %21, %4, offsets = [0, 0], sizes = [256, 1024], strides = [1, 1]
+            : tensor<256x1024xf16> -> !flow.dispatch.tensor<writeonly:256x1024xf16>
         return
       }
     }
@@ -413,18 +306,10 @@
 }
 
 //  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering.config<tile_sizes = {{\[}}[16, 256], [8, 8], [0, 0, 4]{{\]}}, native_vector_size = []>
-//  CHECK-DAG: #[[MAP_X:.+]] = affine_map<()[s0] -> (s0 ceildiv 256)>
-//  CHECK-DAG: #[[MAP_Y:.+]] = affine_map<()[s0] -> (s0 ceildiv 16)>
-//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"SPIRVVectorize", workload_per_wg = [256, 16]>
+//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"SPIRVVectorize", workload_per_wg = []>
 //      CHECK: hal.executable.entry_point public @matmul_pointwise_256x1024
 // CHECK-SAME:   translation.info = #[[TRANSLATION]]
 // CHECK-SAME:   workgroup_size = [32 : index, 2 : index, 1 : index]
-// CHECK-NEXT: ^{{.+}}(%[[X:.+]]: index, %[[Y:.+]]: index, %{{.+}}: index):
-// CHECK-NEXT:   %[[ONE:.+]] = arith.constant 1 : index
-// CHECK-NEXT:   %[[X_COUNT:.+]] = affine.apply #[[MAP_X]]()[%[[X]]]
-// CHECK-NEXT:   %[[Y_COUNT:.+]] = affine.apply #[[MAP_Y]]()[%[[Y]]]
-// CHECK-NEXT:   hal.return %[[X_COUNT]], %[[Y_COUNT]], %[[ONE]]
-
 //      CHECK: func @matmul_pointwise_256x1024()
 //      CHECK:   linalg.matmul
 // CHECK-SAME:     lowering.config = #[[CONFIG]]
diff --git a/iree/compiler/Codegen/SPIRV/test/config_mali_conv.mlir b/iree/compiler/Codegen/SPIRV/test/config_mali_conv.mlir
index c78db62..eec7a58 100644
--- a/iree/compiler/Codegen/SPIRV/test/config_mali_conv.mlir
+++ b/iree/compiler/Codegen/SPIRV/test/config_mali_conv.mlir
@@ -27,44 +27,17 @@
         %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : !flow.dispatch.tensor<readonly:1x225x225x3xf32>
         %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) : !flow.dispatch.tensor<readonly:3x3x3x512xf32>
         %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) : !flow.dispatch.tensor<writeonly:1x112x112x512xf32>
-        %workgroup_size_x = hal.interface.workgroup.size[0] : index
-        %workgroup_size_y = hal.interface.workgroup.size[1] : index
-        %workgroup_size_z = hal.interface.workgroup.size[2] : index
-        %workgroup_id_x = hal.interface.workgroup.id[0] : index
-        %workgroup_count_x = hal.interface.workgroup.count[0] : index
-        %workgroup_id_y = hal.interface.workgroup.id[1] : index
-        %workgroup_count_y = hal.interface.workgroup.count[1] : index
-        %workgroup_id_z = hal.interface.workgroup.id[2] : index
-        %workgroup_count_z = hal.interface.workgroup.count[2] : index
-        %3 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_z, %workgroup_size_z]
-        %4 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_z, %workgroup_size_z]
-        scf.for %arg0 = %3 to %c112 step %4 {
-          %5 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_y, %workgroup_size_y]
-          %6 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_y, %workgroup_size_y]
-          scf.for %arg1 = %5 to %c112 step %6 {
-            %7 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_x, %workgroup_size_x]
-            %8 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_x, %workgroup_size_x]
-            scf.for %arg2 = %7 to %c512 step %8 {
-              %9 = affine.apply affine_map<(d0) -> (d0 * 2)>(%arg0)
-              %10 = affine.min affine_map<(d0)[s0] -> (s0 * 2 + 1, d0 * -2 + 225)>(%arg0)[%workgroup_size_z]
-              %11 = affine.apply affine_map<(d0) -> (d0 * 2)>(%arg1)
-              %12 = affine.min affine_map<(d0)[s0] -> (s0 * 2 + 1, d0 * -2 + 225)>(%arg1)[%workgroup_size_y]
-              %13 = flow.dispatch.tensor.load %0, offsets = [0, %9, %11, 0], sizes = [1, %10, %12, 3], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:1x225x225x3xf32> -> tensor<1x?x?x3xf32>
-              %14 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 512)>(%arg2)[%workgroup_size_x]
-              %15 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0, %arg2], sizes = [3, 3, 3, %14], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:3x3x3x512xf32> -> tensor<3x3x3x?xf32>
-              %16 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 112)>(%arg0)[%workgroup_size_z]
-              %17 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 112)>(%arg1)[%workgroup_size_y]
-              %18 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 512)>(%arg2)[%workgroup_size_x]
-              %19 = affine.min affine_map<(d0)[s0] -> (-d0 + 112, s0)>(%arg0)[%workgroup_size_z]
-              %20 = affine.min affine_map<(d0)[s0] -> (-d0 + 112, s0)>(%arg1)[%workgroup_size_y]
-              %21 = affine.min affine_map<(d0)[s0] -> (-d0 + 512, s0)>(%arg2)[%workgroup_size_x]
-              %22 = linalg.init_tensor [1, %19, %20, %21] : tensor<1x?x?x?xf32>
-              %23 = linalg.fill(%cst, %22) : f32, tensor<1x?x?x?xf32> -> tensor<1x?x?x?xf32>
-              %24 = linalg.conv_2d_nhwc_hwcf {__internal_linalg_transform__ = "workgroup", dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%13, %15 : tensor<1x?x?x3xf32>, tensor<3x3x3x?xf32>) outs(%23 : tensor<1x?x?x?xf32>) -> tensor<1x?x?x?xf32>
-              flow.dispatch.tensor.store %24, %2, offsets = [0, %arg0, %arg1, %arg2], sizes = [1, %16, %17, %18], strides = [1, 1, 1, 1] : tensor<1x?x?x?xf32> -> !flow.dispatch.tensor<writeonly:1x112x112x512xf32>
-            }
-          }
-        }
+        %13 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0, 0], sizes = [1, 225, 225, 3], strides = [1, 1, 1, 1]
+            : !flow.dispatch.tensor<readonly:1x225x225x3xf32> -> tensor<1x225x225x3xf32>
+        %15 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0, 0], sizes = [3, 3, 3, 512], strides = [1, 1, 1, 1]
+            : !flow.dispatch.tensor<readonly:3x3x3x512xf32> -> tensor<3x3x3x512xf32>
+        %22 = linalg.init_tensor [1, 112, 112, 512] : tensor<1x112x112x512xf32>
+        %23 = linalg.fill(%cst, %22) : f32, tensor<1x112x112x512xf32> -> tensor<1x112x112x512xf32>
+        %24 = linalg.conv_2d_nhwc_hwcf {__internal_linalg_transform__ = "workgroup", dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>}
+            ins(%13, %15 : tensor<1x225x225x3xf32>, tensor<3x3x3x512xf32>)
+            outs(%23 : tensor<1x112x112x512xf32>) -> tensor<1x112x112x512xf32>
+        flow.dispatch.tensor.store %24, %2, offsets = [0, 0, 0, 0], sizes = [1, 112, 112, 512], strides = [1, 1, 1, 1]
+            : tensor<1x112x112x512xf32> -> !flow.dispatch.tensor<writeonly:1x112x112x512xf32>
         return
       }
     }
@@ -72,17 +45,10 @@
 }
 
 //  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering.config<tile_sizes = {{\[}}[0, 1, 4, 64], [0, 1, 4, 4], [0, 0, 0, 0, 1, 1, 4]{{\]}}, native_vector_size = []>
-//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"SPIRVVectorize", workload_per_wg = [64, 4, 1]>
-//  CHECK-DAG: #[[MAP_X:.+]] = affine_map<()[s0] -> (s0 ceildiv 64)
-//  CHECK-DAG: #[[MAP_Y:.+]] = affine_map<()[s0] -> (s0 ceildiv 4)>
+//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"SPIRVVectorize", workload_per_wg = []>
 //      CHECK: hal.executable.entry_point public @conv_112x112x512
 // CHECK-SAME:   translation.info = #[[TRANSLATION]]
 // CHECK-SAME:   workgroup_size = [16 : index, 1 : index, 1 : index]
-// CHECK-NEXT: ^{{.+}}(%[[X:.+]]: index, %[[Y:.+]]: index, %[[Z:.+]]: index):
-// CHECK-NEXT:   %[[X_COUNT:.+]] = affine.apply #[[MAP_X]]()[%[[X]]]
-// CHECK-NEXT:   %[[Y_COUNT:.+]] = affine.apply #[[MAP_Y]]()[%[[Y]]]
-// CHECK-NEXT:   hal.return %[[X_COUNT]], %[[Y_COUNT]], %[[Z]]
-
 //      CHECK: func @conv_112x112x512()
 //      CHECK:   linalg.conv_2d_nhwc_hwcf
 // CHECK-SAME:     lowering.config = #[[CONFIG]]
@@ -116,62 +82,27 @@
         %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : !flow.dispatch.tensor<readonly:1x225x225x3xf32>
         %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) : !flow.dispatch.tensor<readonly:3x3x3x32xf32>
         %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) : !flow.dispatch.tensor<writeonly:1x112x112x32xf32>
-        %workgroup_size_x = hal.interface.workgroup.size[0] : index
-        %workgroup_size_y = hal.interface.workgroup.size[1] : index
-        %workgroup_size_z = hal.interface.workgroup.size[2] : index
-        %workgroup_id_x = hal.interface.workgroup.id[0] : index
-        %workgroup_count_x = hal.interface.workgroup.count[0] : index
-        %workgroup_id_y = hal.interface.workgroup.id[1] : index
-        %workgroup_count_y = hal.interface.workgroup.count[1] : index
-        %workgroup_id_z = hal.interface.workgroup.id[2] : index
-        %workgroup_count_z = hal.interface.workgroup.count[2] : index
-        %3 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_z, %workgroup_size_z]
-        %4 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_z, %workgroup_size_z]
-        scf.for %arg0 = %3 to %c112 step %4 {
-          %5 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_y, %workgroup_size_y]
-          %6 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_y, %workgroup_size_y]
-          scf.for %arg1 = %5 to %c112 step %6 {
-            %7 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_x, %workgroup_size_x]
-            %8 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_x, %workgroup_size_x]
-            scf.for %arg2 = %7 to %c32 step %8 {
-              %9 = affine.apply affine_map<(d0) -> (d0 * 2)>(%arg0)
-              %10 = affine.min affine_map<(d0)[s0] -> (s0 * 2 + 1, d0 * -2 + 225)>(%arg0)[%workgroup_size_z]
-              %11 = affine.apply affine_map<(d0) -> (d0 * 2)>(%arg1)
-              %12 = affine.min affine_map<(d0)[s0] -> (s0 * 2 + 1, d0 * -2 + 225)>(%arg1)[%workgroup_size_y]
-              %13 = flow.dispatch.tensor.load %0, offsets = [0, %9, %11, 0], sizes = [1, %10, %12, 3], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:1x225x225x3xf32> -> tensor<1x?x?x3xf32>
-              %14 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 32)>(%arg2)[%workgroup_size_x]
-              %15 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0, %arg2], sizes = [3, 3, 3, %14], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:3x3x3x32xf32> -> tensor<3x3x3x?xf32>
-              %16 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 112)>(%arg0)[%workgroup_size_z]
-              %17 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 112)>(%arg1)[%workgroup_size_y]
-              %18 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 32)>(%arg2)[%workgroup_size_x]
-              %19 = affine.min affine_map<(d0)[s0] -> (-d0 + 112, s0)>(%arg0)[%workgroup_size_z]
-              %20 = affine.min affine_map<(d0)[s0] -> (-d0 + 112, s0)>(%arg1)[%workgroup_size_y]
-              %21 = affine.min affine_map<(d0)[s0] -> (-d0 + 32, s0)>(%arg2)[%workgroup_size_x]
-              %22 = linalg.init_tensor [1, %19, %20, %21] : tensor<1x?x?x?xf32>
-              %23 = linalg.fill(%cst, %22) : f32, tensor<1x?x?x?xf32> -> tensor<1x?x?x?xf32>
-              %24 = linalg.conv_2d_nhwc_hwcf {__internal_linalg_transform__ = "workgroup", dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%13, %15 : tensor<1x?x?x3xf32>, tensor<3x3x3x?xf32>) outs(%23 : tensor<1x?x?x?xf32>) -> tensor<1x?x?x?xf32>
-              flow.dispatch.tensor.store %24, %2, offsets = [0, %arg0, %arg1, %arg2], sizes = [1, %16, %17, %18], strides = [1, 1, 1, 1] : tensor<1x?x?x?xf32> -> !flow.dispatch.tensor<writeonly:1x112x112x32xf32>
-            }
-          }
-        }
-        return
+        %13 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0, 0], sizes = [1, 225, 225, 3], strides = [1, 1, 1, 1]
+            : !flow.dispatch.tensor<readonly:1x225x225x3xf32> -> tensor<1x225x225x3xf32>
+        %15 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0, 0], sizes = [3, 3, 3, 32], strides = [1, 1, 1, 1]
+            : !flow.dispatch.tensor<readonly:3x3x3x32xf32> -> tensor<3x3x3x32xf32>
+        %22 = linalg.init_tensor [1, 112, 112, 32] : tensor<1x112x112x32xf32>
+        %23 = linalg.fill(%cst, %22) : f32, tensor<1x112x112x32xf32> -> tensor<1x112x112x32xf32>
+        %24 = linalg.conv_2d_nhwc_hwcf {__internal_linalg_transform__ = "workgroup", dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>}
+            ins(%13, %15 : tensor<1x225x225x3xf32>, tensor<3x3x3x32xf32>) outs(%23 : tensor<1x112x112x32xf32>) -> tensor<1x112x112x32xf32>
+        flow.dispatch.tensor.store %24, %2, offsets = [0, 0, 0, 0], sizes = [1, 112, 112, 32], strides = [1, 1, 1, 1]
+            : tensor<1x112x112x32xf32> -> !flow.dispatch.tensor<writeonly:1x112x112x32xf32>
+        return
       }
     }
   }
 }
 
 //  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering.config<tile_sizes = {{\[}}[0, 1, 8, 32], [0, 1, 4, 4], [0, 0, 0, 0, 1, 1, 4]{{\]}}, native_vector_size = []>
-//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"SPIRVVectorize", workload_per_wg = [32, 8, 1]>
-//  CHECK-DAG: #[[MAP_X:.+]] = affine_map<()[s0] -> (s0 ceildiv 32)
-//  CHECK-DAG: #[[MAP_Y:.+]] = affine_map<()[s0] -> (s0 ceildiv 8)>
+//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"SPIRVVectorize", workload_per_wg = []>
 //      CHECK: hal.executable.entry_point public @conv_112x112x32
 // CHECK-SAME:   translation.info = #[[TRANSLATION]]
 // CHECK-SAME:   workgroup_size = [8 : index, 2 : index, 1 : index]
-// CHECK-NEXT: ^{{.+}}(%[[X:.+]]: index, %[[Y:.+]]: index, %[[Z:.+]]: index):
-// CHECK-NEXT:   %[[X_COUNT:.+]] = affine.apply #[[MAP_X]]()[%[[X]]]
-// CHECK-NEXT:   %[[Y_COUNT:.+]] = affine.apply #[[MAP_Y]]()[%[[Y]]]
-// CHECK-NEXT:   hal.return %[[X_COUNT]], %[[Y_COUNT]], %[[Z]]
-
 //      CHECK: func @conv_112x112x32()
 //      CHECK:   linalg.conv_2d_nhwc_hwcf
 // CHECK-SAME:     lowering.config = #[[CONFIG]]
@@ -204,44 +135,16 @@
         %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : !flow.dispatch.tensor<readonly:1x33x33x3xf32>
         %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) : !flow.dispatch.tensor<readonly:3x3x3x16xf32>
         %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) : !flow.dispatch.tensor<writeonly:1x16x16x16xf32>
-        %workgroup_size_x = hal.interface.workgroup.size[0] : index
-        %workgroup_size_y = hal.interface.workgroup.size[1] : index
-        %workgroup_size_z = hal.interface.workgroup.size[2] : index
-        %workgroup_id_x = hal.interface.workgroup.id[0] : index
-        %workgroup_count_x = hal.interface.workgroup.count[0] : index
-        %workgroup_id_y = hal.interface.workgroup.id[1] : index
-        %workgroup_count_y = hal.interface.workgroup.count[1] : index
-        %workgroup_id_z = hal.interface.workgroup.id[2] : index
-        %workgroup_count_z = hal.interface.workgroup.count[2] : index
-        %3 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_z, %workgroup_size_z]
-        %4 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_z, %workgroup_size_z]
-        scf.for %arg0 = %3 to %c16 step %4 {
-          %5 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_y, %workgroup_size_y]
-          %6 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_y, %workgroup_size_y]
-          scf.for %arg1 = %5 to %c16 step %6 {
-            %7 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_x, %workgroup_size_x]
-            %8 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_x, %workgroup_size_x]
-            scf.for %arg2 = %7 to %c16 step %8 {
-              %9 = affine.apply affine_map<(d0) -> (d0 * 2)>(%arg0)
-              %10 = affine.min affine_map<(d0)[s0] -> (s0 * 2 + 1, d0 * -2 + 33)>(%arg0)[%workgroup_size_z]
-              %11 = affine.apply affine_map<(d0) -> (d0 * 2)>(%arg1)
-              %12 = affine.min affine_map<(d0)[s0] -> (s0 * 2 + 1, d0 * -2 + 33)>(%arg1)[%workgroup_size_y]
-              %13 = flow.dispatch.tensor.load %0, offsets = [0, %9, %11, 0], sizes = [1, %10, %12, 3], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:1x33x33x3xf32> -> tensor<1x?x?x3xf32>
-              %14 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 16)>(%arg2)[%workgroup_size_x]
-              %15 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0, %arg2], sizes = [3, 3, 3, %14], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:3x3x3x16xf32> -> tensor<3x3x3x?xf32>
-              %16 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 16)>(%arg0)[%workgroup_size_z]
-              %17 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 16)>(%arg1)[%workgroup_size_y]
-              %18 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 16)>(%arg2)[%workgroup_size_x]
-              %19 = affine.min affine_map<(d0)[s0] -> (-d0 + 16, s0)>(%arg0)[%workgroup_size_z]
-              %20 = affine.min affine_map<(d0)[s0] -> (-d0 + 16, s0)>(%arg1)[%workgroup_size_y]
-              %21 = affine.min affine_map<(d0)[s0] -> (-d0 + 16, s0)>(%arg2)[%workgroup_size_x]
-              %22 = linalg.init_tensor [1, %19, %20, %21] : tensor<1x?x?x?xf32>
-              %23 = linalg.fill(%cst, %22) : f32, tensor<1x?x?x?xf32> -> tensor<1x?x?x?xf32>
-              %24 = linalg.conv_2d_nhwc_hwcf {__internal_linalg_transform__ = "workgroup", dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%13, %15 : tensor<1x?x?x3xf32>, tensor<3x3x3x?xf32>) outs(%23 : tensor<1x?x?x?xf32>) -> tensor<1x?x?x?xf32>
-              flow.dispatch.tensor.store %24, %2, offsets = [0, %arg0, %arg1, %arg2], sizes = [1, %16, %17, %18], strides = [1, 1, 1, 1] : tensor<1x?x?x?xf32> -> !flow.dispatch.tensor<writeonly:1x16x16x16xf32>
-            }
-          }
-        }
+        %13 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0, 0], sizes = [1, 33, 33, 3], strides = [1, 1, 1, 1]
+            : !flow.dispatch.tensor<readonly:1x33x33x3xf32> -> tensor<1x33x33x3xf32>
+        %15 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0, 0], sizes = [3, 3, 3, 16], strides = [1, 1, 1, 1]
+            : !flow.dispatch.tensor<readonly:3x3x3x16xf32> -> tensor<3x3x3x16xf32>
+        %22 = linalg.init_tensor [1, 16, 16, 16] : tensor<1x16x16x16xf32>
+        %23 = linalg.fill(%cst, %22) : f32, tensor<1x16x16x16xf32> -> tensor<1x16x16x16xf32>
+        %24 = linalg.conv_2d_nhwc_hwcf {__internal_linalg_transform__ = "workgroup", dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>}
+            ins(%13, %15 : tensor<1x33x33x3xf32>, tensor<3x3x3x16xf32>) outs(%23 : tensor<1x16x16x16xf32>) -> tensor<1x16x16x16xf32>
+        flow.dispatch.tensor.store %24, %2, offsets = [0, 0, 0, 0], sizes = [1, 16, 16, 16], strides = [1, 1, 1, 1]
+            : tensor<1x16x16x16xf32> -> !flow.dispatch.tensor<writeonly:1x16x16x16xf32>
         return
       }
     }
@@ -249,18 +152,10 @@
 }
 
 //  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering.config<tile_sizes = {{\[}}[0, 4, 4, 16], [0, 2, 2, 4], [0, 0, 0, 0, 1, 1, 4]{{\]}}, native_vector_size = []>
-//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"SPIRVVectorize", workload_per_wg = [16, 4, 4]>
-//  CHECK-DAG: #[[MAP_X:.+]] = affine_map<()[s0] -> (s0 ceildiv 16)
-//  CHECK-DAG: #[[MAP_YZ:.+]] = affine_map<()[s0] -> (s0 ceildiv 4)>
+//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"SPIRVVectorize", workload_per_wg = []>
 //      CHECK: hal.executable.entry_point public @conv_16x16x16
 // CHECK-SAME:   translation.info = #[[TRANSLATION]]
 // CHECK-SAME:   workgroup_size = [4 : index, 2 : index, 2 : index]
-// CHECK-NEXT: ^{{.+}}(%[[X:.+]]: index, %[[Y:.+]]: index, %[[Z:.+]]: index):
-// CHECK-NEXT:   %[[X_COUNT:.+]] = affine.apply #[[MAP_X]]()[%[[X]]]
-// CHECK-NEXT:   %[[Y_COUNT:.+]] = affine.apply #[[MAP_YZ]]()[%[[Y]]]
-// CHECK-NEXT:   %[[Z_COUNT:.+]] = affine.apply #[[MAP_YZ]]()[%[[Z]]]
-// CHECK-NEXT:   hal.return %[[X_COUNT]], %[[Y_COUNT]], %[[Z_COUNT]]
-
 //      CHECK: func @conv_16x16x16()
 //      CHECK:   linalg.conv_2d_nhwc_hwcf
 // CHECK-SAME:     lowering.config = #[[CONFIG]]
@@ -294,45 +189,16 @@
         %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : !flow.dispatch.tensor<readonly:1x57x57x144xf32>
         %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) : !flow.dispatch.tensor<readonly:3x3x144xf32>
         %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) : !flow.dispatch.tensor<writeonly:1x28x28x144xf32>
-        %workgroup_size_x = hal.interface.workgroup.size[0] : index
-        %workgroup_size_y = hal.interface.workgroup.size[1] : index
-        %workgroup_size_z = hal.interface.workgroup.size[2] : index
-        %workgroup_id_x = hal.interface.workgroup.id[0] : index
-        %workgroup_count_x = hal.interface.workgroup.count[0] : index
-        %workgroup_id_y = hal.interface.workgroup.id[1] : index
-        %workgroup_count_y = hal.interface.workgroup.count[1] : index
-        %workgroup_id_z = hal.interface.workgroup.id[2] : index
-        %workgroup_count_z = hal.interface.workgroup.count[2] : index
-        %3 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_z, %workgroup_size_z]
-        %4 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_z, %workgroup_size_z]
-        scf.for %arg0 = %3 to %c28 step %4 {
-          %5 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_y, %workgroup_size_y]
-          %6 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_y, %workgroup_size_y]
-          scf.for %arg1 = %5 to %c28 step %6 {
-            %7 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_x, %workgroup_size_x]
-            %8 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_x, %workgroup_size_x]
-            scf.for %arg2 = %7 to %c144 step %8 {
-              %9 = affine.apply affine_map<(d0) -> (d0 * 2)>(%arg0)
-              %10 = affine.min affine_map<(d0)[s0] -> (s0 * 2 + 1, d0 * -2 + 57)>(%arg0)[%workgroup_size_z]
-              %11 = affine.apply affine_map<(d0) -> (d0 * 2)>(%arg1)
-              %12 = affine.min affine_map<(d0)[s0] -> (s0 * 2 + 1, d0 * -2 + 57)>(%arg1)[%workgroup_size_y]
-              %13 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 144)>(%arg2)[%workgroup_size_x]
-              %14 = flow.dispatch.tensor.load %0, offsets = [0, %9, %11, %arg2], sizes = [1, %10, %12, %13], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:1x57x57x144xf32> -> tensor<1x?x?x?xf32>
-              %15 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 144)>(%arg2)[%workgroup_size_x]
-              %16 = flow.dispatch.tensor.load %1, offsets = [0, 0, %arg2], sizes = [3, 3, %15], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:3x3x144xf32> -> tensor<3x3x?xf32>
-              %17 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 28)>(%arg0)[%workgroup_size_z]
-              %18 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 28)>(%arg1)[%workgroup_size_y]
-              %19 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 144)>(%arg2)[%workgroup_size_x]
-              %20 = affine.min affine_map<(d0)[s0] -> (-d0 + 28, s0)>(%arg0)[%workgroup_size_z]
-              %21 = affine.min affine_map<(d0)[s0] -> (-d0 + 28, s0)>(%arg1)[%workgroup_size_y]
-              %22 = affine.min affine_map<(d0)[s0] -> (-d0 + 144, s0)>(%arg2)[%workgroup_size_x]
-              %23 = linalg.init_tensor [1, %20, %21, %22] : tensor<1x?x?x?xf32>
-              %24 = linalg.fill(%cst, %23) : f32, tensor<1x?x?x?xf32> -> tensor<1x?x?x?xf32>
-              %25 = linalg.depthwise_conv_2d_nhwc_hwc {__internal_linalg_transform__ = "workgroup", dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%14, %16 : tensor<1x?x?x?xf32>, tensor<3x3x?xf32>) outs(%24 : tensor<1x?x?x?xf32>) -> tensor<1x?x?x?xf32>
-              flow.dispatch.tensor.store %25, %2, offsets = [0, %arg0, %arg1, %arg2], sizes = [1, %17, %18, %19], strides = [1, 1, 1, 1] : tensor<1x?x?x?xf32> -> !flow.dispatch.tensor<writeonly:1x28x28x144xf32>
-            }
-          }
-        }
+        %14 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0, 0], sizes = [1, 57, 57, 144], strides = [1, 1, 1, 1]
+            : !flow.dispatch.tensor<readonly:1x57x57x144xf32> -> tensor<1x57x57x144xf32>
+        %16 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0], sizes = [3, 3, 144], strides = [1, 1, 1]
+            : !flow.dispatch.tensor<readonly:3x3x144xf32> -> tensor<3x3x144xf32>
+        %23 = linalg.init_tensor [1, 28, 28, 144] : tensor<1x28x28x144xf32>
+        %24 = linalg.fill(%cst, %23) : f32, tensor<1x28x28x144xf32> -> tensor<1x28x28x144xf32>
+        %25 = linalg.depthwise_conv_2d_nhwc_hwc {__internal_linalg_transform__ = "workgroup", dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>}
+            ins(%14, %16 : tensor<1x57x57x144xf32>, tensor<3x3x144xf32>) outs(%24 : tensor<1x28x28x144xf32>) -> tensor<1x28x28x144xf32>
+        flow.dispatch.tensor.store %25, %2, offsets = [0, 0, 0, 0], sizes = [1, 28, 28, 144], strides = [1, 1, 1, 1]
+            : tensor<1x28x28x144xf32> -> !flow.dispatch.tensor<writeonly:1x28x28x144xf32>
         return
       }
     }
@@ -340,18 +206,10 @@
 }
 
 //  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering.config<tile_sizes = {{\[}}[0, 4, 4, 16], [0, 2, 2, 4], [0, 0, 0, 0, 1, 1]{{\]}}, native_vector_size = []>
-//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"SPIRVVectorize", workload_per_wg = [16, 4, 4]>
-//  CHECK-DAG: #[[MAP_X:.+]] = affine_map<()[s0] -> (s0 ceildiv 16)
-//  CHECK-DAG: #[[MAP_YZ:.+]] = affine_map<()[s0] -> (s0 ceildiv 4)>
+//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"SPIRVVectorize", workload_per_wg = []>
 //      CHECK: hal.executable.entry_point public @dwconv_28x28x144
 // CHECK-SAME:   translation.info = #[[TRANSLATION]]
 // CHECK-SAME:   workgroup_size = [4 : index, 2 : index, 2 : index]
-// CHECK-NEXT: ^{{.+}}(%[[X:.+]]: index, %[[Y:.+]]: index, %[[Z:.+]]: index):
-// CHECK-NEXT:   %[[X_COUNT:.+]] = affine.apply #[[MAP_X]]()[%[[X]]]
-// CHECK-NEXT:   %[[Y_COUNT:.+]] = affine.apply #[[MAP_YZ]]()[%[[Y]]]
-// CHECK-NEXT:   %[[Z_COUNT:.+]] = affine.apply #[[MAP_YZ]]()[%[[Z]]]
-// CHECK-NEXT:   hal.return %[[X_COUNT]], %[[Y_COUNT]], %[[Z_COUNT]]
-
 //      CHECK: func @dwconv_28x28x144()
 //      CHECK:   linalg.depthwise_conv_2d_nhwc_hwc
 // CHECK-SAME:     lowering.config = #[[CONFIG]]
@@ -386,45 +244,16 @@
         %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : !flow.dispatch.tensor<readonly:1x3x5x8xf32>
         %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) : !flow.dispatch.tensor<readonly:3x3x8xf32>
         %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) : !flow.dispatch.tensor<writeonly:1x1x2x8xf32>
-        %workgroup_size_x = hal.interface.workgroup.size[0] : index
-        %workgroup_size_y = hal.interface.workgroup.size[1] : index
-        %workgroup_size_z = hal.interface.workgroup.size[2] : index
-        %workgroup_id_x = hal.interface.workgroup.id[0] : index
-        %workgroup_count_x = hal.interface.workgroup.count[0] : index
-        %workgroup_id_y = hal.interface.workgroup.id[1] : index
-        %workgroup_count_y = hal.interface.workgroup.count[1] : index
-        %workgroup_id_z = hal.interface.workgroup.id[2] : index
-        %workgroup_count_z = hal.interface.workgroup.count[2] : index
-        %3 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_z, %workgroup_size_z]
-        %4 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_z, %workgroup_size_z]
-        scf.for %arg0 = %3 to %c1 step %4 {
-          %5 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_y, %workgroup_size_y]
-          %6 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_y, %workgroup_size_y]
-          scf.for %arg1 = %5 to %c2 step %6 {
-            %7 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_x, %workgroup_size_x]
-            %8 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_x, %workgroup_size_x]
-            scf.for %arg2 = %7 to %c8 step %8 {
-              %9 = affine.apply affine_map<(d0) -> (d0 * 2)>(%arg0)
-              %10 = affine.min affine_map<(d0)[s0] -> (s0 * 2 + 1, d0 * -2 + 5)>(%arg0)[%workgroup_size_z]
-              %11 = affine.apply affine_map<(d0) -> (d0 * 2)>(%arg1)
-              %12 = affine.min affine_map<(d0)[s0] -> (s0 * 2 + 1, d0 * -2 + 7)>(%arg1)[%workgroup_size_y]
-              %13 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 8)>(%arg2)[%workgroup_size_x]
-              %14 = flow.dispatch.tensor.load %0, offsets = [0, %9, %11, %arg2], sizes = [1, %10, %12, %13], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:1x3x5x8xf32> -> tensor<1x?x?x?xf32>
-              %15 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 8)>(%arg2)[%workgroup_size_x]
-              %16 = flow.dispatch.tensor.load %1, offsets = [0, 0, %arg2], sizes = [3, 3, %15], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:3x3x8xf32> -> tensor<3x3x?xf32>
-              %17 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 1)>(%arg0)[%workgroup_size_z]
-              %18 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 2)>(%arg1)[%workgroup_size_y]
-              %19 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 8)>(%arg2)[%workgroup_size_x]
-              %20 = affine.min affine_map<(d0)[s0] -> (-d0 + 1, s0)>(%arg0)[%workgroup_size_z]
-              %21 = affine.min affine_map<(d0)[s0] -> (-d0 + 2, s0)>(%arg1)[%workgroup_size_y]
-              %22 = affine.min affine_map<(d0)[s0] -> (-d0 + 8, s0)>(%arg2)[%workgroup_size_x]
-              %23 = linalg.init_tensor [1, %20, %21, %22] : tensor<1x?x?x?xf32>
-              %24 = linalg.fill(%cst, %23) : f32, tensor<1x?x?x?xf32> -> tensor<1x?x?x?xf32>
-              %25 = linalg.depthwise_conv_2d_nhwc_hwc {__internal_linalg_transform__ = "workgroup", dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%14, %16 : tensor<1x?x?x?xf32>, tensor<3x3x?xf32>) outs(%24 : tensor<1x?x?x?xf32>) -> tensor<1x?x?x?xf32>
-              flow.dispatch.tensor.store %25, %2, offsets = [0, %arg0, %arg1, %arg2], sizes = [1, %17, %18, %19], strides = [1, 1, 1, 1] : tensor<1x?x?x?xf32> -> !flow.dispatch.tensor<writeonly:1x1x2x8xf32>
-            }
-          }
-        }
+        %14 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0, 0], sizes = [1, 3, 5, 8], strides = [1, 1, 1, 1]
+            : !flow.dispatch.tensor<readonly:1x3x5x8xf32> -> tensor<1x3x5x8xf32>
+        %16 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0], sizes = [3, 3, 8], strides = [1, 1, 1]
+            : !flow.dispatch.tensor<readonly:3x3x8xf32> -> tensor<3x3x8xf32>
+        %23 = linalg.init_tensor [1, 1, 2, 8] : tensor<1x1x2x8xf32>
+        %24 = linalg.fill(%cst, %23) : f32, tensor<1x1x2x8xf32> -> tensor<1x1x2x8xf32>
+        %25 = linalg.depthwise_conv_2d_nhwc_hwc {__internal_linalg_transform__ = "workgroup", dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>}
+            ins(%14, %16 : tensor<1x3x5x8xf32>, tensor<3x3x8xf32>) outs(%24 : tensor<1x1x2x8xf32>) -> tensor<1x1x2x8xf32>
+        flow.dispatch.tensor.store %25, %2, offsets = [0, 0, 0, 0], sizes = [1, 1, 2, 8], strides = [1, 1, 1, 1]
+            : tensor<1x1x2x8xf32> -> !flow.dispatch.tensor<writeonly:1x1x2x8xf32>
         return
       }
     }
@@ -432,17 +261,10 @@
 }
 
 //  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering.config<tile_sizes = {{\[}}[0, 1, 2, 8], [0, 1, 1, 4], [0, 0, 0, 0, 1, 1]{{\]}}, native_vector_size = []>
-//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"SPIRVVectorize", workload_per_wg = [8, 2, 1]>
-//  CHECK-DAG: #[[MAP_X:.+]] = affine_map<()[s0] -> (s0 ceildiv 8)
-//  CHECK-DAG: #[[MAP_Y:.+]] = affine_map<()[s0] -> (s0 ceildiv 2)>
+//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"SPIRVVectorize", workload_per_wg = []>
 //      CHECK: hal.executable.entry_point public @dwconv_1x2x8
 // CHECK-SAME:   translation.info = #[[TRANSLATION]]
 // CHECK-SAME:   workgroup_size = [2 : index, 2 : index, 1 : index]
-// CHECK-NEXT: ^{{.+}}(%[[X:.+]]: index, %[[Y:.+]]: index, %[[Z:.+]]: index):
-// CHECK-NEXT:   %[[X_COUNT:.+]] = affine.apply #[[MAP_X]]()[%[[X]]]
-// CHECK-NEXT:   %[[Y_COUNT:.+]] = affine.apply #[[MAP_Y]]()[%[[Y]]]
-// CHECK-NEXT:   hal.return %[[X_COUNT]], %[[Y_COUNT]], %[[Z]]
-
 //      CHECK: func @dwconv_1x2x8()
 //      CHECK:   linalg.depthwise_conv_2d_nhwc_hwc
 // CHECK-SAME:     lowering.config = #[[CONFIG]]
diff --git a/iree/compiler/Codegen/SPIRV/test/config_mali_matmul.mlir b/iree/compiler/Codegen/SPIRV/test/config_mali_matmul.mlir
index 4b43ccf..b87eecd 100644
--- a/iree/compiler/Codegen/SPIRV/test/config_mali_matmul.mlir
+++ b/iree/compiler/Codegen/SPIRV/test/config_mali_matmul.mlir
@@ -27,32 +27,14 @@
         %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : !flow.dispatch.tensor<readonly:1024x512xf32>
         %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) : !flow.dispatch.tensor<readonly:512x2048xf32>
         %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) : !flow.dispatch.tensor<writeonly:1024x2048xf32>
-        %workgroup_size_x = hal.interface.workgroup.size[0] : index
-        %workgroup_size_y = hal.interface.workgroup.size[1] : index
-        %workgroup_id_x = hal.interface.workgroup.id[0] : index
-        %workgroup_count_x = hal.interface.workgroup.count[0] : index
-        %workgroup_id_y = hal.interface.workgroup.id[1] : index
-        %workgroup_count_y = hal.interface.workgroup.count[1] : index
-        %3 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_y, %workgroup_size_y]
-        %4 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_y, %workgroup_size_y]
-        scf.for %arg0 = %3 to %c1024 step %4 {
-          %5 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_x, %workgroup_size_x]
-          %6 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_x, %workgroup_size_x]
-          scf.for %arg1 = %5 to %c2048 step %6 {
-            %7 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 1024)>(%arg0)[%workgroup_size_y]
-            %8 = flow.dispatch.tensor.load %0, offsets = [%arg0, 0], sizes = [%7, 512], strides = [1, 1] : !flow.dispatch.tensor<readonly:1024x512xf32> -> tensor<?x512xf32>
-            %9 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 2048)>(%arg1)[%workgroup_size_x]
-            %10 = flow.dispatch.tensor.load %1, offsets = [0, %arg1], sizes = [512, %9], strides = [1, 1] : !flow.dispatch.tensor<readonly:512x2048xf32> -> tensor<512x?xf32>
-            %11 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 1024)>(%arg0)[%workgroup_size_y]
-            %12 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 2048)>(%arg1)[%workgroup_size_x]
-            %13 = affine.min affine_map<(d0)[s0] -> (-d0 + 1024, s0)>(%arg0)[%workgroup_size_y]
-            %14 = affine.min affine_map<(d0)[s0] -> (-d0 + 2048, s0)>(%arg1)[%workgroup_size_x]
-            %15 = linalg.init_tensor [%13, %14] : tensor<?x?xf32>
-            %16 = linalg.fill(%cst, %15) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
-            %17 = linalg.matmul {__internal_linalg_transform__ = "workgroup"} ins(%8, %10 : tensor<?x512xf32>, tensor<512x?xf32>) outs(%16 : tensor<?x?xf32>) -> tensor<?x?xf32>
-            flow.dispatch.tensor.store %17, %2, offsets = [%arg0, %arg1], sizes = [%11, %12], strides = [1, 1] : tensor<?x?xf32> -> !flow.dispatch.tensor<writeonly:1024x2048xf32>
-          }
-        }
+        %8 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [1024, 512], strides = [1, 1] : !flow.dispatch.tensor<readonly:1024x512xf32> -> tensor<1024x512xf32>
+        %10 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [512, 2048], strides = [1, 1] : !flow.dispatch.tensor<readonly:512x2048xf32> -> tensor<512x2048xf32>
+        %15 = linalg.init_tensor [1024, 2048] : tensor<1024x2048xf32>
+        %16 = linalg.fill(%cst, %15) : f32, tensor<1024x2048xf32> -> tensor<1024x2048xf32>
+        %17 = linalg.matmul {__internal_linalg_transform__ = "workgroup"}
+            ins(%8, %10 : tensor<1024x512xf32>, tensor<512x2048xf32>) outs(%16 : tensor<1024x2048xf32>) -> tensor<1024x2048xf32>
+        flow.dispatch.tensor.store %17, %2, offsets = [0, 0], sizes = [1024, 2048], strides = [1, 1]
+            : tensor<1024x2048xf32> -> !flow.dispatch.tensor<writeonly:1024x2048xf32>
         return
       }
     }
@@ -60,18 +42,10 @@
 }
 
 //  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering.config<tile_sizes = {{\[}}[8, 32], [4, 4], [0, 0, 4]{{\]}}, native_vector_size = []>
-//  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 32)>
-//  CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 ceildiv 8)>
-//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"SPIRVVectorize", workload_per_wg = [32, 8]>
+//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"SPIRVVectorize", workload_per_wg = []>
 //      CHECK: hal.executable.entry_point public @matmul_1024x2048x512
 // CHECK-SAME:   translation.info = #[[TRANSLATION]]
 // CHECK-SAME:   workgroup_size = [8 : index, 2 : index, 1 : index]
-// CHECK-NEXT: ^{{.+}}(%[[X:.+]]: index, %[[Y:.+]]: index, %{{.+}}: index):
-// CHECK-NEXT:   %[[ONE:.+]] = arith.constant 1 : index
-// CHECK-NEXT:   %[[X_COUNT:.+]] = affine.apply #[[MAP0]]()[%[[X]]]
-// CHECK-NEXT:   %[[Y_COUNT:.+]] = affine.apply #[[MAP1]]()[%[[Y]]]
-// CHECK-NEXT:   hal.return %[[X_COUNT]], %[[Y_COUNT]], %[[ONE]]
-
 //      CHECK: func @matmul_1024x2048x512()
 //      CHECK:   linalg.matmul
 // CHECK-SAME:     lowering.config = #[[CONFIG]]
@@ -105,32 +79,15 @@
         %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : !flow.dispatch.tensor<readonly:3136x96xf32>
         %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) : !flow.dispatch.tensor<readonly:96x24xf32>
         %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) : !flow.dispatch.tensor<writeonly:3136x24xf32>
-        %workgroup_size_x = hal.interface.workgroup.size[0] : index
-        %workgroup_size_y = hal.interface.workgroup.size[1] : index
-        %workgroup_id_x = hal.interface.workgroup.id[0] : index
-        %workgroup_count_x = hal.interface.workgroup.count[0] : index
-        %workgroup_id_y = hal.interface.workgroup.id[1] : index
-        %workgroup_count_y = hal.interface.workgroup.count[1] : index
-        %3 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_y, %workgroup_size_y]
-        %4 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_y, %workgroup_size_y]
-        scf.for %arg0 = %3 to %c3136 step %4 {
-          %5 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_x, %workgroup_size_x]
-          %6 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_x, %workgroup_size_x]
-          scf.for %arg1 = %5 to %c24 step %6 {
-            %7 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 3136)>(%arg0)[%workgroup_size_y]
-            %8 = flow.dispatch.tensor.load %0, offsets = [%arg0, 0], sizes = [%7, 96], strides = [1, 1] : !flow.dispatch.tensor<readonly:3136x96xf32> -> tensor<?x96xf32>
-            %9 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 24)>(%arg1)[%workgroup_size_x]
-            %10 = flow.dispatch.tensor.load %1, offsets = [0, %arg1], sizes = [96, %9], strides = [1, 1] : !flow.dispatch.tensor<readonly:96x24xf32> -> tensor<96x?xf32>
-            %11 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 3136)>(%arg0)[%workgroup_size_y]
-            %12 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 24)>(%arg1)[%workgroup_size_x]
-            %13 = affine.min affine_map<(d0)[s0] -> (-d0 + 3136, s0)>(%arg0)[%workgroup_size_y]
-            %14 = affine.min affine_map<(d0)[s0] -> (-d0 + 24, s0)>(%arg1)[%workgroup_size_x]
-            %15 = linalg.init_tensor [%13, %14] : tensor<?x?xf32>
-            %16 = linalg.fill(%cst, %15) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
-            %17 = linalg.matmul {__internal_linalg_transform__ = "workgroup"} ins(%8, %10 : tensor<?x96xf32>, tensor<96x?xf32>) outs(%16 : tensor<?x?xf32>) -> tensor<?x?xf32>
-            flow.dispatch.tensor.store %17, %2, offsets = [%arg0, %arg1], sizes = [%11, %12], strides = [1, 1] : tensor<?x?xf32> -> !flow.dispatch.tensor<writeonly:3136x24xf32>
-          }
-        }
+        %8 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [3136, 96], strides = [1, 1] : !flow.dispatch.tensor<readonly:3136x96xf32> -> tensor<3136x96xf32>
+        %10 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [96, 24], strides = [1, 1] : !flow.dispatch.tensor<readonly:96x24xf32> -> tensor<96x24xf32>
+        %15 = linalg.init_tensor [3136, 24] : tensor<3136x24xf32>
+        %16 = linalg.fill(%cst, %15) : f32, tensor<3136x24xf32> -> tensor<3136x24xf32>
+        %17 = linalg.matmul {__internal_linalg_transform__ = "workgroup"}
+            ins(%8, %10 : tensor<3136x96xf32>, tensor<96x24xf32>)
+            outs(%16 : tensor<3136x24xf32>) -> tensor<3136x24xf32>
+        flow.dispatch.tensor.store %17, %2, offsets = [0, 0], sizes = [3136, 24], strides = [1, 1]
+            : tensor<3136x24xf32> -> !flow.dispatch.tensor<writeonly:3136x24xf32>
         return
       }
     }
@@ -138,18 +95,10 @@
 }
 
 //  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering.config<tile_sizes = {{\[}}[32, 8], [4, 4], [0, 0, 4]{{\]}}, native_vector_size = []>
-//  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 8)>
-//  CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 ceildiv 32)>
-//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"SPIRVVectorize", workload_per_wg = [8, 32]>
+//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"SPIRVVectorize", workload_per_wg = []>
 //      CHECK: hal.executable.entry_point public @matmul_3136x24x96
 // CHECK-SAME:   translation.info = #[[TRANSLATION]]
 // CHECK-SAME:   workgroup_size = [2 : index, 8 : index, 1 : index]
-// CHECK-NEXT: ^{{.+}}(%[[X:.+]]: index, %[[Y:.+]]: index, %{{.+}}: index):
-// CHECK-NEXT:   %[[ONE:.+]] = arith.constant 1 : index
-// CHECK-NEXT:   %[[X_COUNT:.+]] = affine.apply #[[MAP0]]()[%[[X]]]
-// CHECK-NEXT:   %[[Y_COUNT:.+]] = affine.apply #[[MAP1]]()[%[[Y]]]
-// CHECK-NEXT:   hal.return %[[X_COUNT]], %[[Y_COUNT]], %[[ONE]]
-
 //      CHECK: func @matmul_3136x24x96()
 //      CHECK:   linalg.matmul
 // CHECK-SAME:     lowering.config = #[[CONFIG]]
@@ -183,32 +132,16 @@
         %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : !flow.dispatch.tensor<readonly:196x192xf32>
         %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) : !flow.dispatch.tensor<readonly:192x64xf32>
         %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) : !flow.dispatch.tensor<writeonly:196x64xf32>
-        %workgroup_size_x = hal.interface.workgroup.size[0] : index
-        %workgroup_size_y = hal.interface.workgroup.size[1] : index
-        %workgroup_id_x = hal.interface.workgroup.id[0] : index
-        %workgroup_count_x = hal.interface.workgroup.count[0] : index
-        %workgroup_id_y = hal.interface.workgroup.id[1] : index
-        %workgroup_count_y = hal.interface.workgroup.count[1] : index
-        %3 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_y, %workgroup_size_y]
-        %4 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_y, %workgroup_size_y]
-        scf.for %arg0 = %3 to %c196 step %4 {
-          %5 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_x, %workgroup_size_x]
-          %6 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_x, %workgroup_size_x]
-          scf.for %arg1 = %5 to %c64 step %6 {
-            %7 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 196)>(%arg0)[%workgroup_size_y]
-            %8 = flow.dispatch.tensor.load %0, offsets = [%arg0, 0], sizes = [%7, 192], strides = [1, 1] : !flow.dispatch.tensor<readonly:196x192xf32> -> tensor<?x192xf32>
-            %9 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 64)>(%arg1)[%workgroup_size_x]
-            %10 = flow.dispatch.tensor.load %1, offsets = [0, %arg1], sizes = [192, %9], strides = [1, 1] : !flow.dispatch.tensor<readonly:192x64xf32> -> tensor<192x?xf32>
-            %11 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 196)>(%arg0)[%workgroup_size_y]
-            %12 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 64)>(%arg1)[%workgroup_size_x]
-            %13 = affine.min affine_map<(d0)[s0] -> (-d0 + 196, s0)>(%arg0)[%workgroup_size_y]
-            %14 = affine.min affine_map<(d0)[s0] -> (-d0 + 64, s0)>(%arg1)[%workgroup_size_x]
-            %15 = linalg.init_tensor [%13, %14] : tensor<?x?xf32>
-            %16 = linalg.fill(%cst, %15) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
-            %17 = linalg.matmul {__internal_linalg_transform__ = "workgroup"} ins(%8, %10 : tensor<?x192xf32>, tensor<192x?xf32>) outs(%16 : tensor<?x?xf32>) -> tensor<?x?xf32>
-            flow.dispatch.tensor.store %17, %2, offsets = [%arg0, %arg1], sizes = [%11, %12], strides = [1, 1] : tensor<?x?xf32> -> !flow.dispatch.tensor<writeonly:196x64xf32>
-          }
-        }
+        %8 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [196, 192], strides = [1, 1]
+            : !flow.dispatch.tensor<readonly:196x192xf32> -> tensor<196x192xf32>
+        %10 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [192, 64], strides = [1, 1]
+            : !flow.dispatch.tensor<readonly:192x64xf32> -> tensor<192x64xf32>
+        %15 = linalg.init_tensor [196, 64] : tensor<196x64xf32>
+        %16 = linalg.fill(%cst, %15) : f32, tensor<196x64xf32> -> tensor<196x64xf32>
+        %17 = linalg.matmul {__internal_linalg_transform__ = "workgroup"}
+            ins(%8, %10 : tensor<196x192xf32>, tensor<192x64xf32>) outs(%16 : tensor<196x64xf32>) -> tensor<196x64xf32>
+        flow.dispatch.tensor.store %17, %2, offsets = [0, 0], sizes = [196, 64], strides = [1, 1]
+            : tensor<196x64xf32> -> !flow.dispatch.tensor<writeonly:196x64xf32>
         return
       }
     }
@@ -216,18 +149,10 @@
 }
 
 //  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering.config<tile_sizes = {{\[}}[4, 32], [2, 4], [0, 0, 8]{{\]}}, native_vector_size = []>
-//  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 32)>
-//  CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 ceildiv 4)>
-//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"SPIRVVectorize", workload_per_wg = [32, 4]>
+//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"SPIRVVectorize", workload_per_wg = []>
 //      CHECK: hal.executable.entry_point public @matmul_196x64x192
 // CHECK-SAME:   translation.info = #[[TRANSLATION]]
 // CHECK-SAME:   workgroup_size = [8 : index, 2 : index, 1 : index]
-// CHECK-NEXT: ^{{.+}}(%[[X:.+]]: index, %[[Y:.+]]: index, %{{.+}}: index):
-// CHECK-NEXT:   %[[ONE:.+]] = arith.constant 1 : index
-// CHECK-NEXT:   %[[X_COUNT:.+]] = affine.apply #[[MAP0]]()[%[[X]]]
-// CHECK-NEXT:   %[[Y_COUNT:.+]] = affine.apply #[[MAP1]]()[%[[Y]]]
-// CHECK-NEXT:   hal.return %[[X_COUNT]], %[[Y_COUNT]], %[[ONE]]
-
 //      CHECK: func @matmul_196x64x192()
 //      CHECK:   linalg.matmul
 // CHECK-SAME:      lowering.config = #[[CONFIG]]
@@ -261,27 +186,9 @@
         %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : memref<12544x16xf32>
         %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) : memref<16x96xf32>
         %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) : memref<12544x96xf32>
-        %workgroup_size_x = hal.interface.workgroup.size[0] : index
-        %workgroup_size_y = hal.interface.workgroup.size[1] : index
-        %workgroup_id_x = hal.interface.workgroup.id[0] : index
-        %workgroup_count_x = hal.interface.workgroup.count[0] : index
-        %workgroup_id_y = hal.interface.workgroup.id[1] : index
-        %workgroup_count_y = hal.interface.workgroup.count[1] : index
-        %3 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_y, %workgroup_size_y]
-        %4 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_y, %workgroup_size_y]
-        scf.for %arg0 = %3 to %c12544 step %4 {
-          %5 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_x, %workgroup_size_x]
-          %6 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_x, %workgroup_size_x]
-          scf.for %arg1 = %5 to %c96 step %6 {
-            %7 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 12544)>(%arg0)[%workgroup_size_y]
-            %8 = memref.subview %0[%arg0, 0] [%7, 16] [1, 1] : memref<12544x16xf32> to memref<?x16xf32, affine_map<(d0, d1)[s0] -> (d0 * 16 + s0 + d1)>>
-            %9 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 96)>(%arg1)[%workgroup_size_x]
-            %10 = memref.subview %1[0, %arg1] [16, %9] [1, 1] : memref<16x96xf32> to memref<16x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 96 + s0 + d1)>>
-            %11 = memref.subview %2[%arg0, %arg1] [%7, %9] [1, 1] : memref<12544x96xf32> to memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 96 + s0 + d1)>>
-            linalg.fill(%cst, %11) : f32, memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 96 + s0 + d1)>>
-            linalg.matmul {__internal_linalg_transform__ = "workgroup"} ins(%8, %10 : memref<?x16xf32, affine_map<(d0, d1)[s0] -> (d0 * 16 + s0 + d1)>>, memref<16x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 96 + s0 + d1)>>) outs(%11 : memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 96 + s0 + d1)>>)
-          }
-        }
+        linalg.fill(%cst, %2) : f32, memref<12544x96xf32>
+        linalg.matmul {__internal_linalg_transform__ = "workgroup"}
+            ins(%0, %1 : memref<12544x16xf32>, memref<16x96xf32>) outs(%2 : memref<12544x96xf32>)
         return
       }
     }
@@ -289,18 +196,10 @@
 }
 
 //  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering.config<tile_sizes = {{\[}}[8, 32], [4, 4], [0, 0, 4]{{\]}}, native_vector_size = []>
-//  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 32)>
-//  CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 ceildiv 8)>
-//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"SPIRVVectorize", workload_per_wg = [32, 8]>
+//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"SPIRVVectorize", workload_per_wg = []>
 //      CHECK: hal.executable.entry_point public @matmul_12544x96x16
 // CHECK-SAME:   translation.info = #[[TRANSLATION]]
 // CHECK-SAME:   workgroup_size = [8 : index, 2 : index, 1 : index]
-// CHECK-NEXT: ^{{.+}}(%[[X:.+]]: index, %[[Y:.+]]: index, %{{.+}}: index):
-// CHECK-NEXT:   %[[ONE:.+]] = arith.constant 1 : index
-// CHECK-NEXT:   %[[X_COUNT:.+]] = affine.apply #[[MAP0]]()[%[[X]]]
-// CHECK-NEXT:   %[[Y_COUNT:.+]] = affine.apply #[[MAP1]]()[%[[Y]]]
-// CHECK-NEXT:   hal.return %[[X_COUNT]], %[[Y_COUNT]], %[[ONE]]
-
 //      CHECK: func @matmul_12544x96x16()
 //      CHECK:   linalg.matmul
 // CHECK-SAME:     lowering.config = #[[CONFIG]]
@@ -334,32 +233,16 @@
         %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : !flow.dispatch.tensor<readonly:49x576xf32>
         %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) : !flow.dispatch.tensor<readonly:576x160xf32>
         %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) : !flow.dispatch.tensor<writeonly:49x160xf32>
-        %workgroup_size_x = hal.interface.workgroup.size[0] : index
-        %workgroup_size_y = hal.interface.workgroup.size[1] : index
-        %workgroup_id_x = hal.interface.workgroup.id[0] : index
-        %workgroup_count_x = hal.interface.workgroup.count[0] : index
-        %workgroup_id_y = hal.interface.workgroup.id[1] : index
-        %workgroup_count_y = hal.interface.workgroup.count[1] : index
-        %3 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_y, %workgroup_size_y]
-        %4 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_y, %workgroup_size_y]
-        scf.for %arg0 = %3 to %c49 step %4 {
-          %5 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_x, %workgroup_size_x]
-          %6 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_x, %workgroup_size_x]
-          scf.for %arg1 = %5 to %c160 step %6 {
-            %7 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 49)>(%arg0)[%workgroup_size_y]
-            %8 = flow.dispatch.tensor.load %0, offsets = [%arg0, 0], sizes = [%7, 576], strides = [1, 1] : !flow.dispatch.tensor<readonly:49x576xf32> -> tensor<?x576xf32>
-            %9 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 160)>(%arg1)[%workgroup_size_x]
-            %10 = flow.dispatch.tensor.load %1, offsets = [0, %arg1], sizes = [576, %9], strides = [1, 1] : !flow.dispatch.tensor<readonly:576x160xf32> -> tensor<576x?xf32>
-            %11 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 49)>(%arg0)[%workgroup_size_y]
-            %12 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 160)>(%arg1)[%workgroup_size_x]
-            %13 = affine.min affine_map<(d0)[s0] -> (-d0 + 49, s0)>(%arg0)[%workgroup_size_y]
-            %14 = affine.min affine_map<(d0)[s0] -> (-d0 + 160, s0)>(%arg1)[%workgroup_size_x]
-            %15 = linalg.init_tensor [%13, %14] : tensor<?x?xf32>
-            %16 = linalg.fill(%cst, %15) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
-            %17 = linalg.matmul {__internal_linalg_transform__ = "workgroup"} ins(%8, %10 : tensor<?x576xf32>, tensor<576x?xf32>) outs(%16 : tensor<?x?xf32>) -> tensor<?x?xf32>
-            flow.dispatch.tensor.store %17, %2, offsets = [%arg0, %arg1], sizes = [%11, %12], strides = [1, 1] : tensor<?x?xf32> -> !flow.dispatch.tensor<writeonly:49x160xf32>
-          }
-        }
+        %8 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [49, 576], strides = [1, 1]
+            : !flow.dispatch.tensor<readonly:49x576xf32> -> tensor<49x576xf32>
+        %10 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [576, 160], strides = [1, 1]
+            : !flow.dispatch.tensor<readonly:576x160xf32> -> tensor<576x160xf32>
+        %15 = linalg.init_tensor [49, 160] : tensor<49x160xf32>
+        %16 = linalg.fill(%cst, %15) : f32, tensor<49x160xf32> -> tensor<49x160xf32>
+        %17 = linalg.matmul {__internal_linalg_transform__ = "workgroup"}
+            ins(%8, %10 : tensor<49x576xf32>, tensor<576x160xf32>) outs(%16 : tensor<49x160xf32>) -> tensor<49x160xf32>
+        flow.dispatch.tensor.store %17, %2, offsets = [0, 0], sizes = [49, 160], strides = [1, 1]
+            : tensor<49x160xf32> -> !flow.dispatch.tensor<writeonly:49x160xf32>
         return
       }
     }
@@ -367,16 +250,10 @@
 }
 
 //  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering.config<tile_sizes = {{\[}}[1, 32], [1, 4], [0, 0, 8]{{\]}}, native_vector_size = []>
-//  CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0] -> (s0 ceildiv 32)>
-//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"SPIRVVectorize", workload_per_wg = [32, 1]>
+//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"SPIRVVectorize", workload_per_wg = []>
 //      CHECK: hal.executable.entry_point public @matmul_49x160x576
 // CHECK-SAME:   translation.info = #[[TRANSLATION]]
 // CHECK-SAME:   workgroup_size = [8 : index, 1 : index, 1 : index]
-// CHECK-NEXT: ^{{.+}}(%[[X:.+]]: index, %[[Y:.+]]: index, %{{.+}}: index):
-// CHECK-NEXT:   %[[ONE:.+]] = arith.constant 1 : index
-// CHECK-NEXT:   %[[X_COUNT:.+]] = affine.apply #[[MAP]]()[%[[X]]]
-// CHECK-NEXT:   hal.return %[[X_COUNT]], %[[Y]], %[[ONE]]
-
 //      CHECK: func @matmul_49x160x576()
 //      CHECK:   linalg.matmul
 // CHECK-SAME:     lowering.config = #[[CONFIG]]
@@ -410,43 +287,17 @@
         %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : !flow.dispatch.tensor<readonly:4x384x32xf32>
         %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) : !flow.dispatch.tensor<readonly:4x32x384xf32>
         %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) : !flow.dispatch.tensor<writeonly:4x384x384xf32>
-        %workgroup_size_x = hal.interface.workgroup.size[0] : index
-        %workgroup_size_y = hal.interface.workgroup.size[1] : index
-        %workgroup_size_z = hal.interface.workgroup.size[2] : index
-        %workgroup_id_x = hal.interface.workgroup.id[0] : index
-        %workgroup_count_x = hal.interface.workgroup.count[0] : index
-        %workgroup_id_y = hal.interface.workgroup.id[1] : index
-        %workgroup_count_y = hal.interface.workgroup.count[1] : index
-        %workgroup_id_z = hal.interface.workgroup.id[2] : index
-        %workgroup_count_z = hal.interface.workgroup.count[2] : index
-        %3 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_z, %workgroup_size_z]
-        %4 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_z, %workgroup_size_z]
-        scf.for %arg0 = %3 to %c4 step %4 {
-          %5 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_y, %workgroup_size_y]
-          %6 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_y, %workgroup_size_y]
-          scf.for %arg1 = %5 to %c384 step %6 {
-            %7 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_x, %workgroup_size_x]
-            %8 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_x, %workgroup_size_x]
-            scf.for %arg2 = %7 to %c384 step %8 {
-              %9 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 4)>(%arg0)[%workgroup_size_z]
-              %10 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 384)>(%arg1)[%workgroup_size_y]
-              %11 = flow.dispatch.tensor.load %0, offsets = [%arg0, %arg1, 0], sizes = [%9, %10, 32], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:4x384x32xf32> -> tensor<?x?x32xf32>
-              %12 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 4)>(%arg0)[%workgroup_size_z]
-              %13 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 384)>(%arg2)[%workgroup_size_x]
-              %14 = flow.dispatch.tensor.load %1, offsets = [%arg0, 0, %arg2], sizes = [%12, 32, %13], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:4x32x384xf32> -> tensor<?x32x?xf32>
-              %15 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 4)>(%arg0)[%workgroup_size_z]
-              %16 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 384)>(%arg1)[%workgroup_size_y]
-              %17 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 384)>(%arg2)[%workgroup_size_x]
-              %18 = affine.min affine_map<(d0)[s0] -> (-d0 + 4, s0)>(%arg0)[%workgroup_size_z]
-              %19 = affine.min affine_map<(d0)[s0] -> (-d0 + 384, s0)>(%arg1)[%workgroup_size_y]
-              %20 = affine.min affine_map<(d0)[s0] -> (-d0 + 384, s0)>(%arg2)[%workgroup_size_x]
-              %21 = linalg.init_tensor [%18, %19, %20] : tensor<?x?x?xf32>
-              %22 = linalg.fill(%cst, %21) : f32, tensor<?x?x?xf32> -> tensor<?x?x?xf32>
-              %23 = linalg.batch_matmul {__internal_linalg_transform__ = "workgroup"} ins(%11, %14 : tensor<?x?x32xf32>, tensor<?x32x?xf32>) outs(%22 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
-              flow.dispatch.tensor.store %23, %2, offsets = [%arg0, %arg1, %arg2], sizes = [%15, %16, %17], strides = [1, 1, 1] : tensor<?x?x?xf32> -> !flow.dispatch.tensor<writeonly:4x384x384xf32>
-            }
-          }
-        }
+        %11 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0], sizes = [4, 384, 32], strides = [1, 1, 1]
+            : !flow.dispatch.tensor<readonly:4x384x32xf32> -> tensor<4x384x32xf32>
+        %14 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0], sizes = [4, 32, 384], strides = [1, 1, 1]
+            : !flow.dispatch.tensor<readonly:4x32x384xf32> -> tensor<4x32x384xf32>
+        %21 = linalg.init_tensor [4, 384, 384] : tensor<4x384x384xf32>
+        %22 = linalg.fill(%cst, %21) : f32, tensor<4x384x384xf32> -> tensor<4x384x384xf32>
+        %23 = linalg.batch_matmul {__internal_linalg_transform__ = "workgroup"}
+            ins(%11, %14 : tensor<4x384x32xf32>, tensor<4x32x384xf32>)
+            outs(%22 : tensor<4x384x384xf32>) -> tensor<4x384x384xf32>
+        flow.dispatch.tensor.store %23, %2, offsets = [0, 0, 0], sizes = [4, 384, 384], strides = [1, 1, 1]
+            : tensor<4x384x384xf32> -> !flow.dispatch.tensor<writeonly:4x384x384xf32>
         return
       }
     }
@@ -454,17 +305,10 @@
 }
 
 //  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering.config<tile_sizes = {{\[}}[1, 12, 32], [1, 6, 4], [0, 0, 0, 4]{{\]}}, native_vector_size = []>
-//  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 32)>
-//  CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 ceildiv 12)>
-//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"SPIRVVectorize", workload_per_wg = [32, 12, 1]>
+//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"SPIRVVectorize", workload_per_wg = []>
 //      CHECK: hal.executable.entry_point public @batch_matmul_4x384x384
 // CHECK-SAME:   translation.info = #[[TRANSLATION]]
 // CHECK-SAME:   workgroup_size = [8 : index, 2 : index, 1 : index]
-// CHECK-NEXT: ^{{.+}}(%[[X:.+]]: index, %[[Y:.+]]: index, %[[Z:.+]]: index):
-// CHECK-NEXT:   %[[X_COUNT:.+]] = affine.apply #[[MAP0]]()[%[[X]]]
-// CHECK-NEXT:   %[[Y_COUNT:.+]] = affine.apply #[[MAP1]]()[%[[Y]]]
-// CHECK-NEXT:   hal.return %[[X_COUNT]], %[[Y_COUNT]], %[[Z]]
-
 //      CHECK: func @batch_matmul_4x384x384()
 //      CHECK:   linalg.batch_matmul
 // CHECK-SAME:     lowering.config = #[[CONFIG]]
@@ -499,43 +343,16 @@
         %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : !flow.dispatch.tensor<readonly:4x2x32xf32>
         %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) : !flow.dispatch.tensor<readonly:4x32x8xf32>
         %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) : !flow.dispatch.tensor<writeonly:4x2x8xf32>
-        %workgroup_size_x = hal.interface.workgroup.size[0] : index
-        %workgroup_size_y = hal.interface.workgroup.size[1] : index
-        %workgroup_size_z = hal.interface.workgroup.size[2] : index
-        %workgroup_id_x = hal.interface.workgroup.id[0] : index
-        %workgroup_count_x = hal.interface.workgroup.count[0] : index
-        %workgroup_id_y = hal.interface.workgroup.id[1] : index
-        %workgroup_count_y = hal.interface.workgroup.count[1] : index
-        %workgroup_id_z = hal.interface.workgroup.id[2] : index
-        %workgroup_count_z = hal.interface.workgroup.count[2] : index
-        %3 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_z, %workgroup_size_z]
-        %4 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_z, %workgroup_size_z]
-        scf.for %arg0 = %3 to %c4 step %4 {
-          %5 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_y, %workgroup_size_y]
-          %6 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_y, %workgroup_size_y]
-          scf.for %arg1 = %5 to %c2 step %6 {
-            %7 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_x, %workgroup_size_x]
-            %8 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_x, %workgroup_size_x]
-            scf.for %arg2 = %7 to %c8 step %8 {
-              %9 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 4)>(%arg0)[%workgroup_size_z]
-              %10 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 2)>(%arg1)[%workgroup_size_y]
-              %11 = flow.dispatch.tensor.load %0, offsets = [%arg0, %arg1, 0], sizes = [%9, %10, 32], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:4x2x32xf32> -> tensor<?x?x32xf32>
-              %12 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 4)>(%arg0)[%workgroup_size_z]
-              %13 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 8)>(%arg2)[%workgroup_size_x]
-              %14 = flow.dispatch.tensor.load %1, offsets = [%arg0, 0, %arg2], sizes = [%12, 32, %13], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:4x32x8xf32> -> tensor<?x32x?xf32>
-              %15 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 4)>(%arg0)[%workgroup_size_z]
-              %16 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 2)>(%arg1)[%workgroup_size_y]
-              %17 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 8)>(%arg2)[%workgroup_size_x]
-              %18 = affine.min affine_map<(d0)[s0] -> (-d0 + 4, s0)>(%arg0)[%workgroup_size_z]
-              %19 = affine.min affine_map<(d0)[s0] -> (-d0 + 2, s0)>(%arg1)[%workgroup_size_y]
-              %20 = affine.min affine_map<(d0)[s0] -> (-d0 + 8, s0)>(%arg2)[%workgroup_size_x]
-              %21 = linalg.init_tensor [%18, %19, %20] : tensor<?x?x?xf32>
-              %22 = linalg.fill(%cst, %21) : f32, tensor<?x?x?xf32> -> tensor<?x?x?xf32>
-              %23 = linalg.batch_matmul {__internal_linalg_transform__ = "workgroup"} ins(%11, %14 : tensor<?x?x32xf32>, tensor<?x32x?xf32>) outs(%22 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
-              flow.dispatch.tensor.store %23, %2, offsets = [%arg0, %arg1, %arg2], sizes = [%15, %16, %17], strides = [1, 1, 1] : tensor<?x?x?xf32> -> !flow.dispatch.tensor<writeonly:4x2x8xf32>
-            }
-          }
-        }
+        %11 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0], sizes = [4, 2, 32], strides = [1, 1, 1]
+            : !flow.dispatch.tensor<readonly:4x2x32xf32> -> tensor<4x2x32xf32>
+        %14 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0], sizes = [4, 32, 8], strides = [1, 1, 1]
+            : !flow.dispatch.tensor<readonly:4x32x8xf32> -> tensor<4x32x8xf32>
+        %21 = linalg.init_tensor [4, 2, 8] : tensor<4x2x8xf32>
+        %22 = linalg.fill(%cst, %21) : f32, tensor<4x2x8xf32> -> tensor<4x2x8xf32>
+        %23 = linalg.batch_matmul {__internal_linalg_transform__ = "workgroup"}
+            ins(%11, %14 : tensor<4x2x32xf32>, tensor<4x32x8xf32>) outs(%22 : tensor<4x2x8xf32>) -> tensor<4x2x8xf32>
+        flow.dispatch.tensor.store %23, %2, offsets = [0, 0, 0], sizes = [4, 2, 8], strides = [1, 1, 1]
+            : tensor<4x2x8xf32> -> !flow.dispatch.tensor<writeonly:4x2x8xf32>
         return
       }
     }
@@ -543,17 +360,10 @@
 }
 
 //  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering.config<tile_sizes = {{\[}}[1, 2, 8], [1, 1, 4], [0, 0, 0, 8]{{\]}}, native_vector_size = []>
-//  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 8)>
-//  CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 ceildiv 2)>
-//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"SPIRVVectorize", workload_per_wg = [8, 2, 1]>
+//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"SPIRVVectorize", workload_per_wg = []>
 //      CHECK: hal.executable.entry_point public @batch_matmul_4x2x8
 // CHECK-SAME:   translation.info = #[[TRANSLATION]]
 // CHECK-SAME:   workgroup_size = [2 : index, 2 : index, 1 : index]
-// CHECK-NEXT: ^{{.+}}(%[[X:.+]]: index, %[[Y:.+]]: index, %[[Z:.+]]: index):
-// CHECK-NEXT:   %[[X_COUNT:.+]] = affine.apply #[[MAP0]]()[%[[X]]]
-// CHECK-NEXT:   %[[Y_COUNT:.+]] = affine.apply #[[MAP1]]()[%[[Y]]]
-// CHECK-NEXT:   hal.return %[[X_COUNT]], %[[Y_COUNT]], %[[Z]]
-
 //      CHECK: func @batch_matmul_4x2x8()
 //      CHECK:   linalg.batch_matmul
 // CHECK-SAME:     lowering.config = #[[CONFIG]]
diff --git a/iree/compiler/Codegen/SPIRV/test/config_nvidia_matmul_cooperative_ops.mlir b/iree/compiler/Codegen/SPIRV/test/config_nvidia_matmul_cooperative_ops.mlir
index 1ec9f66..f5aa51b 100644
--- a/iree/compiler/Codegen/SPIRV/test/config_nvidia_matmul_cooperative_ops.mlir
+++ b/iree/compiler/Codegen/SPIRV/test/config_nvidia_matmul_cooperative_ops.mlir
@@ -46,48 +46,30 @@
         %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) : !flow.dispatch.tensor<readonly:256x128xf16>
         %3 = hal.interface.binding.subspan set(0) binding(3) type(storage_buffer) : !flow.dispatch.tensor<readonly:128x1024xf16>
         %4 = hal.interface.binding.subspan set(0) binding(4) type(storage_buffer) : !flow.dispatch.tensor<writeonly:256x1024xf16>
-        %workgroup_size_x = hal.interface.workgroup.size[0] : index
-        %workgroup_size_y = hal.interface.workgroup.size[1] : index
-        %workgroup_id_x = hal.interface.workgroup.id[0] : index
-        %workgroup_count_x = hal.interface.workgroup.count[0] : index
-        %workgroup_id_y = hal.interface.workgroup.id[1] : index
-        %workgroup_count_y = hal.interface.workgroup.count[1] : index
-        %5 = affine.apply #map0()[%workgroup_id_y, %workgroup_size_y]
-        %6 = affine.apply #map0()[%workgroup_count_y, %workgroup_size_y]
-        scf.for %arg0 = %5 to %c256 step %6 {
-          %7 = affine.apply #map0()[%workgroup_id_x, %workgroup_size_x]
-          %8 = affine.apply #map0()[%workgroup_count_x, %workgroup_size_x]
-          scf.for %arg1 = %7 to %c1024 step %8 {
-            %9 = affine.min #map1(%arg0)[%workgroup_size_y]
-            %10 = affine.min #map2(%arg1)[%workgroup_size_x]
-            %11 = flow.dispatch.tensor.load %0, offsets = [%arg0, %arg1], sizes = [%9, %10], strides = [1, 1] : !flow.dispatch.tensor<readonly:256x1024xf16> -> tensor<?x?xf16>
-            %12 = affine.min #map1(%arg0)[%workgroup_size_y]
-            %13 = affine.min #map2(%arg1)[%workgroup_size_x]
-            %14 = flow.dispatch.tensor.load %1, offsets = [%arg0, %arg1], sizes = [%12, %13], strides = [1, 1] : !flow.dispatch.tensor<readonly:256x1024xf16> -> tensor<?x?xf16>
-            %15 = affine.min #map1(%arg0)[%workgroup_size_y]
-            %16 = affine.min #map2(%arg1)[%workgroup_size_x]
-            %17 = linalg.init_tensor [%15, %16] : tensor<?x?xf16>
-            %18 = affine.min #map3(%arg0)[%workgroup_size_y]
-            %19 = flow.dispatch.tensor.load %2, offsets = [%arg0, 0], sizes = [%18, 128], strides = [1, 1] : !flow.dispatch.tensor<readonly:256x128xf16> -> tensor<?x128xf16>
-            %20 = affine.min #map4(%arg1)[%workgroup_size_x]
-            %21 = flow.dispatch.tensor.load %3, offsets = [0, %arg1], sizes = [128, %20], strides = [1, 1] : !flow.dispatch.tensor<readonly:128x1024xf16> -> tensor<128x?xf16>
-            %22 = affine.min #map3(%arg0)[%workgroup_size_y]
-            %23 = affine.min #map4(%arg1)[%workgroup_size_x]
-            %24 = linalg.init_tensor [%22, %23] : tensor<?x?xf16>
-            %25 = linalg.fill(%cst, %24) : f16, tensor<?x?xf16> -> tensor<?x?xf16>
-            %26 = linalg.matmul ins(%19, %21 : tensor<?x128xf16>, tensor<128x?xf16>) outs(%25 : tensor<?x?xf16>) -> tensor<?x?xf16>
-            %27 = linalg.generic {indexing_maps = [#map5, #map5, #map5, #map5], iterator_types = ["parallel", "parallel"]}
-              ins(%26, %11, %14 : tensor<?x?xf16>, tensor<?x?xf16>, tensor<?x?xf16>)
-              outs(%17 : tensor<?x?xf16>)
-              attrs =  {__internal_linalg_transform__ = "workgroup"} {
-            ^bb0(%arg2: f16, %arg3: f16, %arg4: f16, %arg5: f16):  // no predecessors
-              %28 = arith.divf %arg2, %arg3 : f16
-              %29 = arith.subf %28, %arg4 : f16
+        %11 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [256, 1024], strides = [1, 1]
+            : !flow.dispatch.tensor<readonly:256x1024xf16> -> tensor<256x1024xf16>
+        %14 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [256, 1024], strides = [1, 1]
+            : !flow.dispatch.tensor<readonly:256x1024xf16> -> tensor<256x1024xf16>
+        %17 = linalg.init_tensor [256, 1024] : tensor<256x1024xf16>
+        %19 = flow.dispatch.tensor.load %2, offsets = [0, 0], sizes = [256, 128], strides = [1, 1]
+            : !flow.dispatch.tensor<readonly:256x128xf16> -> tensor<256x128xf16>
+        %21 = flow.dispatch.tensor.load %3, offsets = [0, 0], sizes = [128, 1024], strides = [1, 1]
+            : !flow.dispatch.tensor<readonly:128x1024xf16> -> tensor<128x1024xf16>
+        %24 = linalg.init_tensor [256, 1024] : tensor<256x1024xf16>
+        %25 = linalg.fill(%cst, %24) : f16, tensor<256x1024xf16> -> tensor<256x1024xf16>
+        %26 = linalg.matmul ins(%19, %21 : tensor<256x128xf16>, tensor<128x1024xf16>) outs(%25 : tensor<256x1024xf16>) -> tensor<256x1024xf16>
+        %27 = linalg.generic {
+            indexing_maps = [#map5, #map5, #map5, #map5], iterator_types = ["parallel", "parallel"]}
+            ins(%26, %11, %14 : tensor<256x1024xf16>, tensor<256x1024xf16>, tensor<256x1024xf16>)
+            outs(%17 : tensor<256x1024xf16>)
+            attrs =  {__internal_linalg_transform__ = "workgroup"} {
+          ^bb0(%arg2: f16, %arg3: f16, %arg4: f16, %arg5: f16):  // no predecessors
+            %28 = arith.divf %arg2, %arg3 : f16
+            %29 = arith.subf %28, %arg4 : f16
               linalg.yield %29 : f16
-            } -> tensor<?x?xf16>
-            flow.dispatch.tensor.store %27, %4, offsets = [%arg0, %arg1], sizes = [%15, %16], strides = [1, 1] : tensor<?x?xf16> -> !flow.dispatch.tensor<writeonly:256x1024xf16>
-          }
-        }
+            } -> tensor<256x1024xf16>
+        flow.dispatch.tensor.store %27, %4, offsets = [0, 0], sizes = [256, 1024], strides = [1, 1]
+            : tensor<256x1024xf16> -> !flow.dispatch.tensor<writeonly:256x1024xf16>
         return
       }
     }
@@ -95,17 +77,10 @@
 }
 
 //  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering.config<tile_sizes = {{\[}}[16, 16, 16], [16, 16, 16]{{\]}}, native_vector_size = []>
-//  CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0] -> (s0 ceildiv 16)>
-//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"SPIRVVectorizeToCooperativeOps", workload_per_wg = [16, 16]>
+//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"SPIRVVectorizeToCooperativeOps", workload_per_wg = []>
 //      CHECK: hal.executable.entry_point public @matmul_256x1024x128_div_sub
 // CHECK-SAME:   translation.info = #[[TRANSLATION]]
 // CHECK-SAME:   workgroup_size = [32 : index, 1 : index, 1 : index]
-// CHECK-NEXT: ^{{.+}}(%[[X:.+]]: index, %[[Y:.+]]: index, %{{.+}}: index):
-// CHECK-NEXT:   %[[C1:.+]] = arith.constant 1 : index
-// CHECK-NEXT:   %[[X_COUNT:.+]] = affine.apply #[[MAP]]()[%[[X]]]
-// CHECK-NEXT:   %[[Y_COUNT:.+]] = affine.apply #[[MAP]]()[%[[Y]]]
-// CHECK-NEXT:   hal.return %[[X_COUNT]], %[[Y_COUNT]], %[[C1]]
-
 //      CHECK: func @matmul_256x1024x128_div_sub()
 //      CHECK:   linalg.matmul
 // CHECK-SAME:     lowering.config = #[[CONFIG]]
@@ -155,32 +130,14 @@
         %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : !flow.dispatch.tensor<readonly:256x8xf16>
         %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) : !flow.dispatch.tensor<readonly:8x1024xf16>
         %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) : !flow.dispatch.tensor<writeonly:256x1024xf16>
-        %workgroup_size_x = hal.interface.workgroup.size[0] : index
-        %workgroup_size_y = hal.interface.workgroup.size[1] : index
-        %workgroup_id_x = hal.interface.workgroup.id[0] : index
-        %workgroup_count_x = hal.interface.workgroup.count[0] : index
-        %workgroup_id_y = hal.interface.workgroup.id[1] : index
-        %workgroup_count_y = hal.interface.workgroup.count[1] : index
-        %3 = affine.apply #map0()[%workgroup_id_y, %workgroup_size_y]
-        %4 = affine.apply #map0()[%workgroup_count_y, %workgroup_size_y]
-        scf.for %arg0 = %3 to %c256 step %4 {
-          %5 = affine.apply #map0()[%workgroup_id_x, %workgroup_size_x]
-          %6 = affine.apply #map0()[%workgroup_count_x, %workgroup_size_x]
-          scf.for %arg1 = %5 to %c1024 step %6 {
-            %7 = affine.min #map1(%arg0)[%workgroup_size_y]
-            %8 = flow.dispatch.tensor.load %0, offsets = [%arg0, 0], sizes = [%7, 8], strides = [1, 1] : !flow.dispatch.tensor<readonly:256x8xf16> -> tensor<?x8xf16>
-            %9 = affine.min #map2(%arg1)[%workgroup_size_x]
-            %10 = flow.dispatch.tensor.load %1, offsets = [0, %arg1], sizes = [8, %9], strides = [1, 1] : !flow.dispatch.tensor<readonly:8x1024xf16> -> tensor<8x?xf16>
-            %11 = affine.min #map1(%arg0)[%workgroup_size_y]
-            %12 = affine.min #map2(%arg1)[%workgroup_size_x]
-            %13 = affine.min #map3(%arg0)[%workgroup_size_y]
-            %14 = affine.min #map4(%arg1)[%workgroup_size_x]
-            %15 = linalg.init_tensor [%13, %14] : tensor<?x?xf16>
-            %16 = linalg.fill(%cst, %15) : f16, tensor<?x?xf16> -> tensor<?x?xf16>
-            %17 = linalg.matmul {__internal_linalg_transform__ = "workgroup"} ins(%8, %10 : tensor<?x8xf16>, tensor<8x?xf16>) outs(%16 : tensor<?x?xf16>) -> tensor<?x?xf16>
-            flow.dispatch.tensor.store %17, %2, offsets = [%arg0, %arg1], sizes = [%11, %12], strides = [1, 1] : tensor<?x?xf16> -> !flow.dispatch.tensor<writeonly:256x1024xf16>
-          }
-        }
+        %8 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [256, 8], strides = [1, 1] : !flow.dispatch.tensor<readonly:256x8xf16> -> tensor<256x8xf16>
+        %10 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [8, 1024], strides = [1, 1] : !flow.dispatch.tensor<readonly:8x1024xf16> -> tensor<8x1024xf16>
+        %15 = linalg.init_tensor [256, 1024] : tensor<256x1024xf16>
+        %16 = linalg.fill(%cst, %15) : f16, tensor<256x1024xf16> -> tensor<256x1024xf16>
+        %17 = linalg.matmul {__internal_linalg_transform__ = "workgroup"}
+            ins(%8, %10 : tensor<256x8xf16>, tensor<8x1024xf16>) outs(%16 : tensor<256x1024xf16>) -> tensor<256x1024xf16>
+        flow.dispatch.tensor.store %17, %2, offsets = [0, 0], sizes = [256, 1024], strides = [1, 1]
+            : tensor<256x1024xf16> -> !flow.dispatch.tensor<writeonly:256x1024xf16>
         return
       }
     }
diff --git a/iree/compiler/Codegen/SPIRV/test/convert_to_spirv.mlir b/iree/compiler/Codegen/SPIRV/test/convert_to_spirv.mlir
index 5ccb53c..28c7b4b 100644
--- a/iree/compiler/Codegen/SPIRV/test/convert_to_spirv.mlir
+++ b/iree/compiler/Codegen/SPIRV/test/convert_to_spirv.mlir
@@ -9,7 +9,7 @@
 hal.executable private @push_constant {
   hal.executable.variant @vulkan, target = <"vulkan-spirv", "vulkan-spirv-fb", {
       spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], []>, {}>}> {
-    hal.executable.entry_point @push_constant layout(#executable_layout) attributes {
+    hal.executable.entry_point @push_constant layout(#executable_layout) {
       workgroup_size = [32: index, 1: index, 1: index]
     }
     builtin.module {
@@ -43,7 +43,7 @@
 hal.executable private @resource_bindings_in_same_func {
   hal.executable.variant @vulkan, target = <"vulkan-spirv", "vulkan-spirv-fb", {
       spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], []>, {}>}> {
-    hal.executable.entry_point @resource_bindings_in_same_func layout(#executable_layout) attributes {
+    hal.executable.entry_point @resource_bindings_in_same_func layout(#executable_layout) {
       workgroup_size = [32: index, 1: index, 1: index]
     }
     builtin.module {
@@ -98,10 +98,10 @@
 hal.executable private @resource_bindings_in_multi_entry_func {
   hal.executable.variant @vulkan, target = <"vulkan-spirv", "vulkan-spirv-fb", {
       spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], []>, {}>}> {
-    hal.executable.entry_point @resource_bindings_in_entry_func1 layout(#executable_layout) attributes {
+    hal.executable.entry_point @resource_bindings_in_entry_func1 layout(#executable_layout) {
       workgroup_size = [32: index, 1: index, 1: index]
     }
-    hal.executable.entry_point @resource_bindings_in_entry_func2 layout(#executable_layout) attributes {
+    hal.executable.entry_point @resource_bindings_in_entry_func2 layout(#executable_layout) {
       workgroup_size = [32: index, 1: index, 1: index]
     }
     builtin.module {
@@ -154,7 +154,7 @@
 hal.executable private @interface_binding {
   hal.executable.variant @vulkan, target = <"vulkan-spirv", "vulkan-spirv-fb", {
       spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], []>, {}>}> {
-    hal.executable.entry_point @interface_binding layout(#executable_layout) attributes {
+    hal.executable.entry_point @interface_binding layout(#executable_layout) {
       workgroup_size = [32: index, 1: index, 1: index]
     }
     builtin.module {
@@ -197,7 +197,7 @@
 hal.executable private @interface_wg_id {
   hal.executable.variant @vulkan, target = <"vulkan-spirv", "vulkan-spirv-fb", {
       spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], []>, {}>}> {
-    hal.executable.entry_point @interface_wg_id layout(#executable_layout) attributes {
+    hal.executable.entry_point @interface_wg_id layout(#executable_layout) {
       workgroup_size = [32: index, 1: index, 1: index]
     }
     builtin.module {
@@ -232,7 +232,7 @@
 hal.executable private @interface_wg_count {
   hal.executable.variant @vulkan, target = <"vulkan-spirv", "vulkan-spirv-fb", {
       spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], []>, {}>}> {
-    hal.executable.entry_point @interface_wg_count layout(#executable_layout) attributes {
+    hal.executable.entry_point @interface_wg_count layout(#executable_layout) {
       workgroup_size = [32: index, 1: index, 1: index]
     }
     builtin.module {
diff --git a/iree/compiler/Codegen/SPIRV/test/pipeline_matmul_cooperative_ops.mlir b/iree/compiler/Codegen/SPIRV/test/pipeline_matmul_cooperative_ops.mlir
index b156293..3dfe104 100644
--- a/iree/compiler/Codegen/SPIRV/test/pipeline_matmul_cooperative_ops.mlir
+++ b/iree/compiler/Codegen/SPIRV/test/pipeline_matmul_cooperative_ops.mlir
@@ -1,11 +1,5 @@
 // RUN: iree-opt -split-input-file -pass-pipeline='hal.executable(hal.executable.variant(iree-codegen-linalg-to-spirv-pipeline))' %s | FileCheck %s
 
-#map0 = affine_map<()[s0, s1] -> (s0 * s1)>
-#map1 = affine_map<(d0)[s0] -> (s0, -d0 + 256)>
-#map2 = affine_map<(d0)[s0] -> (s0, -d0 + 1024)>
-#map3 = affine_map<(d0)[s0] -> (-d0 + 256, s0)>
-#map4 = affine_map<(d0)[s0] -> (-d0 + 1024, s0)>
-#map5 = affine_map<(d0, d1) -> (d0, d1)>
 #executable_layout = #hal.executable.layout<push_constants = 0, sets = [
   #hal.descriptor_set.layout<0, bindings = [
     #hal.descriptor_set.binding<0, storage_buffer>,
@@ -15,6 +9,7 @@
     #hal.descriptor_set.binding<4, storage_buffer>
   ]>
 ]>
+
 hal.executable public @matmul_256x1024x128_div_sub {
   hal.executable.variant @vulkan, target = <"vulkan-spirv", "vulkan-spirv-fb", {
     spv.target_env =
@@ -46,47 +41,25 @@
         %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) : !flow.dispatch.tensor<readonly:256x128xf16>
         %3 = hal.interface.binding.subspan set(0) binding(3) type(storage_buffer) : !flow.dispatch.tensor<readonly:128x1024xf16>
         %4 = hal.interface.binding.subspan set(0) binding(4) type(storage_buffer) : !flow.dispatch.tensor<writeonly:256x1024xf16>
-        %workgroup_size_x = hal.interface.workgroup.size[0] : index
-        %workgroup_size_y = hal.interface.workgroup.size[1] : index
-        %workgroup_id_x = hal.interface.workgroup.id[0] : index
-        %workgroup_count_x = hal.interface.workgroup.count[0] : index
-        %workgroup_id_y = hal.interface.workgroup.id[1] : index
-        %workgroup_count_y = hal.interface.workgroup.count[1] : index
-        %5 = affine.apply #map0()[%workgroup_id_y, %workgroup_size_y]
-        %6 = affine.apply #map0()[%workgroup_count_y, %workgroup_size_y]
-        scf.for %arg0 = %5 to %c256 step %6 {
-          %7 = affine.apply #map0()[%workgroup_id_x, %workgroup_size_x]
-          %8 = affine.apply #map0()[%workgroup_count_x, %workgroup_size_x]
-          scf.for %arg1 = %7 to %c1024 step %8 {
-            %9 = affine.min #map1(%arg0)[%workgroup_size_y]
-            %10 = affine.min #map2(%arg1)[%workgroup_size_x]
-            %11 = flow.dispatch.tensor.load %0, offsets = [%arg0, %arg1], sizes = [%9, %10], strides = [1, 1] : !flow.dispatch.tensor<readonly:256x1024xf16> -> tensor<?x?xf16>
-            %12 = affine.min #map1(%arg0)[%workgroup_size_y]
-            %13 = affine.min #map2(%arg1)[%workgroup_size_x]
-            %14 = flow.dispatch.tensor.load %1, offsets = [%arg0, %arg1], sizes = [%12, %13], strides = [1, 1] : !flow.dispatch.tensor<readonly:256x1024xf16> -> tensor<?x?xf16>
-            %15 = affine.min #map1(%arg0)[%workgroup_size_y]
-            %16 = affine.min #map2(%arg1)[%workgroup_size_x]
-            %17 = linalg.init_tensor [%15, %16] : tensor<?x?xf16>
-            %18 = affine.min #map3(%arg0)[%workgroup_size_y]
-            %19 = flow.dispatch.tensor.load %2, offsets = [%arg0, 0], sizes = [%18, 128], strides = [1, 1] : !flow.dispatch.tensor<readonly:256x128xf16> -> tensor<?x128xf16>
-            %20 = affine.min #map4(%arg1)[%workgroup_size_x]
-            %21 = flow.dispatch.tensor.load %3, offsets = [0, %arg1], sizes = [128, %20], strides = [1, 1] : !flow.dispatch.tensor<readonly:128x1024xf16> -> tensor<128x?xf16>
-            %22 = affine.min #map3(%arg0)[%workgroup_size_y]
-            %23 = affine.min #map4(%arg1)[%workgroup_size_x]
-            %24 = linalg.init_tensor [%22, %23] : tensor<?x?xf16>
-            %25 = linalg.fill(%cst, %24) : f16, tensor<?x?xf16> -> tensor<?x?xf16>
-            %26 = linalg.matmul ins(%19, %21 : tensor<?x128xf16>, tensor<128x?xf16>) outs(%25 : tensor<?x?xf16>) -> tensor<?x?xf16>
-            %27 = linalg.generic {indexing_maps = [#map5, #map5, #map5, #map5], iterator_types = ["parallel", "parallel"]}
-              ins(%26, %11, %14 : tensor<?x?xf16>, tensor<?x?xf16>, tensor<?x?xf16>)
-              outs(%17 : tensor<?x?xf16>) {
-            ^bb0(%arg2: f16, %arg3: f16, %arg4: f16, %arg5: f16):  // no predecessors
-              %28 = arith.divf %arg2, %arg3 : f16
-              %29 = arith.subf %28, %arg4 : f16
-              linalg.yield %29 : f16
-            } -> tensor<?x?xf16>
-            flow.dispatch.tensor.store %27, %4, offsets = [%arg0, %arg1], sizes = [%15, %16], strides = [1, 1] : tensor<?x?xf16> -> !flow.dispatch.tensor<writeonly:256x1024xf16>
-          }
-        }
+        %11 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [256, 1024], strides = [1, 1] : !flow.dispatch.tensor<readonly:256x1024xf16> -> tensor<256x1024xf16>
+        %14 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [256, 1024], strides = [1, 1] : !flow.dispatch.tensor<readonly:256x1024xf16> -> tensor<256x1024xf16>
+        %17 = linalg.init_tensor [256, 1024] : tensor<256x1024xf16>
+        %19 = flow.dispatch.tensor.load %2, offsets = [0, 0], sizes = [256, 128], strides = [1, 1] : !flow.dispatch.tensor<readonly:256x128xf16> -> tensor<256x128xf16>
+        %21 = flow.dispatch.tensor.load %3, offsets = [0, 0], sizes = [128, 1024], strides = [1, 1] : !flow.dispatch.tensor<readonly:128x1024xf16> -> tensor<128x1024xf16>
+        %24 = linalg.init_tensor [256, 1024] : tensor<256x1024xf16>
+        %25 = linalg.fill(%cst, %24) : f16, tensor<256x1024xf16> -> tensor<256x1024xf16>
+        %26 = linalg.matmul ins(%19, %21 : tensor<256x128xf16>, tensor<128x1024xf16>) outs(%25 : tensor<256x1024xf16>) -> tensor<256x1024xf16>
+        %27 = linalg.generic {
+            indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
+            iterator_types = ["parallel", "parallel"]}
+          ins(%26, %11, %14 : tensor<256x1024xf16>, tensor<256x1024xf16>, tensor<256x1024xf16>)
+          outs(%17 : tensor<256x1024xf16>) {
+        ^bb0(%arg2: f16, %arg3: f16, %arg4: f16, %arg5: f16):
+          %28 = arith.divf %arg2, %arg3 : f16
+          %29 = arith.subf %28, %arg4 : f16
+          linalg.yield %29 : f16
+        } -> tensor<256x1024xf16>
+        flow.dispatch.tensor.store %27, %4, offsets = [0, 0], sizes = [256, 1024], strides = [1, 1] : tensor<256x1024xf16> -> !flow.dispatch.tensor<writeonly:256x1024xf16>
         return
       }
     }
diff --git a/iree/compiler/Codegen/SPIRV/test/pipeline_matmul_vectorization.mlir b/iree/compiler/Codegen/SPIRV/test/pipeline_matmul_vectorization.mlir
index 6042f19..aa3e66f 100644
--- a/iree/compiler/Codegen/SPIRV/test/pipeline_matmul_vectorization.mlir
+++ b/iree/compiler/Codegen/SPIRV/test/pipeline_matmul_vectorization.mlir
@@ -24,32 +24,12 @@
         %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : !flow.dispatch.tensor<readonly:4096x4096xf32>
         %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) : !flow.dispatch.tensor<readonly:4096x4096xf32>
         %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) : !flow.dispatch.tensor<writeonly:4096x4096xf32>
-        %workgroup_size_x = hal.interface.workgroup.size[0] : index
-        %workgroup_size_y = hal.interface.workgroup.size[1] : index
-        %workgroup_id_x = hal.interface.workgroup.id[0] : index
-        %workgroup_count_x = hal.interface.workgroup.count[0] : index
-        %workgroup_id_y = hal.interface.workgroup.id[1] : index
-        %workgroup_count_y = hal.interface.workgroup.count[1] : index
-        %3 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_y, %workgroup_size_y]
-        %4 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_y, %workgroup_size_y]
-        scf.for %arg0 = %3 to %c4096 step %4 {
-          %5 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_x, %workgroup_size_x]
-          %6 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_x, %workgroup_size_x]
-          scf.for %arg1 = %5 to %c4096 step %6 {
-            %7 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 4096)>(%arg0)[%workgroup_size_y]
-            %8 = flow.dispatch.tensor.load %0, offsets = [%arg0, 0], sizes = [%7, 4096], strides = [1, 1] : !flow.dispatch.tensor<readonly:4096x4096xf32> -> tensor<?x4096xf32>
-            %9 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 4096)>(%arg1)[%workgroup_size_x]
-            %10 = flow.dispatch.tensor.load %1, offsets = [0, %arg1], sizes = [4096, %9], strides = [1, 1] : !flow.dispatch.tensor<readonly:4096x4096xf32> -> tensor<4096x?xf32>
-            %11 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 4096)>(%arg0)[%workgroup_size_y]
-            %12 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 4096)>(%arg1)[%workgroup_size_x]
-            %13 = affine.min affine_map<(d0)[s0] -> (-d0 + 4096, s0)>(%arg0)[%workgroup_size_y]
-            %14 = affine.min affine_map<(d0)[s0] -> (-d0 + 4096, s0)>(%arg1)[%workgroup_size_x]
-            %15 = linalg.init_tensor [%13, %14] : tensor<?x?xf32>
-            %16 = linalg.fill(%cst, %15) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
-            %17 = linalg.matmul ins(%8, %10 : tensor<?x4096xf32>, tensor<4096x?xf32>) outs(%16 : tensor<?x?xf32>) -> tensor<?x?xf32>
-            flow.dispatch.tensor.store %17, %2, offsets = [%arg0, %arg1], sizes = [%11, %12], strides = [1, 1] : tensor<?x?xf32> -> !flow.dispatch.tensor<writeonly:4096x4096xf32>
-          }
-        }
+        %8 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [4096, 4096], strides = [1, 1] : !flow.dispatch.tensor<readonly:4096x4096xf32> -> tensor<4096x4096xf32>
+        %10 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [4096, 4096], strides = [1, 1] : !flow.dispatch.tensor<readonly:4096x4096xf32> -> tensor<4096x4096xf32>
+        %15 = linalg.init_tensor [4096, 4096] : tensor<4096x4096xf32>
+        %16 = linalg.fill(%cst, %15) : f32, tensor<4096x4096xf32> -> tensor<4096x4096xf32>
+        %17 = linalg.matmul ins(%8, %10 : tensor<4096x4096xf32>, tensor<4096x4096xf32>) outs(%16 : tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
+        flow.dispatch.tensor.store %17, %2, offsets = [0, 0], sizes = [4096, 4096], strides = [1, 1] : tensor<4096x4096xf32> -> !flow.dispatch.tensor<writeonly:4096x4096xf32>
         return
       }
     }
@@ -93,41 +73,19 @@
         %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) : !flow.dispatch.tensor<readonly:1024x512xf32>
         %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) : !flow.dispatch.tensor<readonly:512x256xf32>
         %3 = hal.interface.binding.subspan set(0) binding(3) type(storage_buffer) : !flow.dispatch.tensor<writeonly:1024x256xf32>
-        %workgroup_size_x = hal.interface.workgroup.size[0] : index
-        %workgroup_size_y = hal.interface.workgroup.size[1] : index
-        %workgroup_id_x = hal.interface.workgroup.id[0] : index
-        %workgroup_count_x = hal.interface.workgroup.count[0] : index
-        %workgroup_id_y = hal.interface.workgroup.id[1] : index
-        %workgroup_count_y = hal.interface.workgroup.count[1] : index
-        %4 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_y, %workgroup_size_y]
-        %5 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_y, %workgroup_size_y]
-        scf.for %arg0 = %4 to %c1024 step %5 {
-          %6 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_x, %workgroup_size_x]
-          %7 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_x, %workgroup_size_x]
-          scf.for %arg1 = %6 to %c256 step %7 {
-            %8 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 1024)>(%arg0)[%workgroup_size_y]
-            %9 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 256)>(%arg1)[%workgroup_size_x]
-            %10 = flow.dispatch.tensor.load %0, offsets = [%arg0, %arg1], sizes = [%8, %9], strides = [1, 1] : !flow.dispatch.tensor<readonly:1024x256xf32> -> tensor<?x?xf32>
-            %11 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 1024)>(%arg0)[%workgroup_size_y]
-            %12 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 256)>(%arg1)[%workgroup_size_x]
-            %13 = linalg.init_tensor [%11, %12] : tensor<?x?xf32>
-            %14 = affine.min affine_map<(d0)[s0] -> (-d0 + 1024, s0)>(%arg0)[%workgroup_size_y]
-            %15 = flow.dispatch.tensor.load %1, offsets = [%arg0, 0], sizes = [%14, 512], strides = [1, 1] : !flow.dispatch.tensor<readonly:1024x512xf32> -> tensor<?x512xf32>
-            %16 = affine.min affine_map<(d0)[s0] -> (-d0 + 256, s0)>(%arg1)[%workgroup_size_x]
-            %17 = flow.dispatch.tensor.load %2, offsets = [0, %arg1], sizes = [512, %16], strides = [1, 1] : !flow.dispatch.tensor<readonly:512x256xf32> -> tensor<512x?xf32>
-            %18 = affine.min affine_map<(d0)[s0] -> (-d0 + 1024, s0)>(%arg0)[%workgroup_size_y]
-            %19 = affine.min affine_map<(d0)[s0] -> (-d0 + 256, s0)>(%arg1)[%workgroup_size_x]
-            %20 = linalg.init_tensor [%18, %19] : tensor<?x?xf32>
-            %21 = linalg.fill(%cst, %20) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
-            %22 = linalg.matmul ins(%15, %17 : tensor<?x512xf32>, tensor<512x?xf32>) outs(%21 : tensor<?x?xf32>) -> tensor<?x?xf32>
-            %23 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%22, %10 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%13 : tensor<?x?xf32>) {
-            ^bb0(%arg2: f32, %arg3: f32, %arg4: f32):  // no predecessors
-              %24 = arith.addf %arg2, %arg3 : f32
-              linalg.yield %24 : f32
-            } -> tensor<?x?xf32>
-            flow.dispatch.tensor.store %23, %3, offsets = [%arg0, %arg1], sizes = [%11, %12], strides = [1, 1] : tensor<?x?xf32> -> !flow.dispatch.tensor<writeonly:1024x256xf32>
-          }
-        }
+        %10 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [1024, 256], strides = [1, 1] : !flow.dispatch.tensor<readonly:1024x256xf32> -> tensor<1024x256xf32>
+        %13 = linalg.init_tensor [1024, 256] : tensor<1024x256xf32>
+        %15 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [1024, 512], strides = [1, 1] : !flow.dispatch.tensor<readonly:1024x512xf32> -> tensor<1024x512xf32>
+        %17 = flow.dispatch.tensor.load %2, offsets = [0, 0], sizes = [512, 256], strides = [1, 1] : !flow.dispatch.tensor<readonly:512x256xf32> -> tensor<512x256xf32>
+        %20 = linalg.init_tensor [1024, 256] : tensor<1024x256xf32>
+        %21 = linalg.fill(%cst, %20) : f32, tensor<1024x256xf32> -> tensor<1024x256xf32>
+        %22 = linalg.matmul ins(%15, %17 : tensor<1024x512xf32>, tensor<512x256xf32>) outs(%21 : tensor<1024x256xf32>) -> tensor<1024x256xf32>
+        %23 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%22, %10 : tensor<1024x256xf32>, tensor<1024x256xf32>) outs(%13 : tensor<1024x256xf32>) {
+        ^bb0(%arg2: f32, %arg3: f32, %arg4: f32):
+          %24 = arith.addf %arg2, %arg3 : f32
+          linalg.yield %24 : f32
+        } -> tensor<1024x256xf32>
+        flow.dispatch.tensor.store %23, %3, offsets = [0, 0], sizes = [1024, 256], strides = [1, 1] : tensor<1024x256xf32> -> !flow.dispatch.tensor<writeonly:1024x256xf32>
         return
       }
     }
diff --git a/iree/compiler/Codegen/SPIRV/test/tile_and_distribute.mlir b/iree/compiler/Codegen/SPIRV/test/tile_and_distribute.mlir
index e27e144..db4595f 100644
--- a/iree/compiler/Codegen/SPIRV/test/tile_and_distribute.mlir
+++ b/iree/compiler/Codegen/SPIRV/test/tile_and_distribute.mlir
@@ -19,7 +19,7 @@
 ]>
 hal.executable private @matmul {
   hal.executable.variant @vulkan, target = <"vulkan-spirv", "vulkan-spirv-fb"> {
-    hal.executable.entry_point @matmul layout(#executable_layout) attributes {
+    hal.executable.entry_point @matmul layout(#executable_layout) {
       workgroup_size = [16: index, 8: index, 1: index],
       translation.info = #translation
     }
@@ -89,7 +89,7 @@
 ]>
 hal.executable private @conv_1d {
   hal.executable.variant @vulkan, target = <"vulkan-spirv", "vulkan-spirv-fb"> {
-    hal.executable.entry_point @conv_1d layout(#executable_layout) attributes {
+    hal.executable.entry_point @conv_1d layout(#executable_layout) {
       workgroup_size = [32: index, 4: index, 1: index],
       translation.info = #translation
     }
@@ -169,7 +169,7 @@
 ]>
 hal.executable private @conv_2d {
   hal.executable.variant @vulkan, target = <"vulkan-spirv", "vulkan-spirv-fb"> {
-    hal.executable.entry_point @conv_2d layout(#executable_layout) attributes {
+    hal.executable.entry_point @conv_2d layout(#executable_layout) {
       workgroup_size = [32: index, 4: index, 1: index],
       translation.info = #translation
     }
@@ -284,7 +284,7 @@
 ]>
 hal.executable private @conv_3d {
   hal.executable.variant @vulkan, target = <"vulkan-spirv", "vulkan-spirv-fb"> {
-    hal.executable.entry_point @conv_3d layout(#executable_layout) attributes {
+    hal.executable.entry_point @conv_3d layout(#executable_layout) {
       workgroup_size = [32: index, 4: index, 1: index],
       translation.info = #translation
     }
@@ -355,7 +355,7 @@
 module  {
   hal.executable private @pooling_nhwc_max {
     hal.executable.variant @vulkan, target = <"vulkan-spirv", "vulkan-spirv-fb"> {
-      hal.executable.entry_point @pooling_nhwc_max layout(#executable_layout) attributes {
+      hal.executable.entry_point @pooling_nhwc_max layout(#executable_layout) {
         workgroup_size = [32: index, 4: index, 1: index],
         translation.info = #translation
       }
diff --git a/iree/compiler/Codegen/SPIRV/test/tile_and_distribute_scatter.mlir b/iree/compiler/Codegen/SPIRV/test/tile_and_distribute_scatter.mlir
index dc3f8e7..cf08843 100644
--- a/iree/compiler/Codegen/SPIRV/test/tile_and_distribute_scatter.mlir
+++ b/iree/compiler/Codegen/SPIRV/test/tile_and_distribute_scatter.mlir
@@ -11,7 +11,7 @@
 ]>
 hal.executable private @static_scatter_update_slice  {
   hal.executable.variant @vulkan_spirv_fb, target = <"vulkan", "vulkan-spirv-fb"> {
-    hal.executable.entry_point @static_scatter_update_slice layout(#executable_layout) attributes {
+    hal.executable.entry_point @static_scatter_update_slice layout(#executable_layout) {
       translation.info = #translation,
       workgroup_size = [16 : index, 1 : index, 1 : index]
     }
diff --git a/iree/compiler/Codegen/SPIRV/test/tile_and_distribute_sort.mlir b/iree/compiler/Codegen/SPIRV/test/tile_and_distribute_sort.mlir
index a34ad27..b79898a 100644
--- a/iree/compiler/Codegen/SPIRV/test/tile_and_distribute_sort.mlir
+++ b/iree/compiler/Codegen/SPIRV/test/tile_and_distribute_sort.mlir
@@ -10,7 +10,7 @@
 ]>
 hal.executable private @static_3d_sort  {
   hal.executable.variant @vulkan_spirv_fb, target = <"vulkan-spirv", "vulkan-spirv-fb"> {
-    hal.executable.entry_point @static_3d_sort layout(#executable_layout) attributes {
+    hal.executable.entry_point @static_3d_sort layout(#executable_layout) {
       translation.info = #translation,
       workgroup_size = [16 : index, 1 : index, 1 : index]
     }
diff --git a/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_batch_matmul.mlir b/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_batch_matmul.mlir
index e7fc5dc..09b8a67 100644
--- a/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_batch_matmul.mlir
+++ b/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_batch_matmul.mlir
@@ -11,7 +11,7 @@
 ]>
 hal.executable private @fused_fill_batch_matmul {
   hal.executable.variant @vulkan, target = <"vulkan-spirv", "vulkan-spirv-fb"> {
-    hal.executable.entry_point @fused_fill_batch_matmul layout(#executable_layout) attributes {
+    hal.executable.entry_point @fused_fill_batch_matmul layout(#executable_layout) {
       workgroup_size = [16: index, 1: index, 1: index],
       translation.info = #translation
     }
diff --git a/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_conv.mlir b/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_conv.mlir
index f2667f3..7a5b27e 100644
--- a/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_conv.mlir
+++ b/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_conv.mlir
@@ -11,7 +11,7 @@
 ]>
 hal.executable private @conv_static_shape_f32 {
   hal.executable.variant @vulkan, target = <"vulkan-spirv", "vulkan-spirv-fb"> {
-    hal.executable.entry_point @conv_static_shape_f32 layout(#executable_layout) attributes {
+    hal.executable.entry_point @conv_static_shape_f32 layout(#executable_layout) {
       workgroup_size = [4: index, 4: index, 1: index],
       translation.info = #translation
     }
@@ -102,7 +102,7 @@
 ]>
 hal.executable private @depthwise_conv_static_shape_f32 {
   hal.executable.variant @vulkan, target = <"vulkan-spirv", "vulkan-spirv-fb"> {
-    hal.executable.entry_point @depthwise_conv_static_shape_f32 layout(#executable_layout) attributes {
+    hal.executable.entry_point @depthwise_conv_static_shape_f32 layout(#executable_layout) {
       workgroup_size = [4: index, 4: index, 4: index],
       translation.info = #translation
     }
diff --git a/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_matmul.mlir b/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_matmul.mlir
index cd5a873..cc3a4a0 100644
--- a/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_matmul.mlir
+++ b/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_matmul.mlir
@@ -11,7 +11,7 @@
 ]>
 hal.executable private @matmul_static_shape_f16 {
   hal.executable.variant @vulkan, target = <"vulkan-spirv", "vulkan-spirv-fb"> {
-    hal.executable.entry_point @matmul_static_shape_f16 layout(#executable_layout) attributes {
+    hal.executable.entry_point @matmul_static_shape_f16 layout(#executable_layout) {
       workgroup_size = [16: index, 1: index, 1: index],
       translation.info = #translation
     }
@@ -75,7 +75,7 @@
 ]>
 hal.executable private @matmul_static_shape_f32 {
   hal.executable.variant @vulkan, target = <"vulkan-spirv", "vulkan-spirv-fb"> {
-    hal.executable.entry_point @matmul_static_shape_f32 layout(#executable_layout) attributes {
+    hal.executable.entry_point @matmul_static_shape_f32 layout(#executable_layout) {
       workgroup_size = [16: index, 1: index, 1: index],
       translation.info = #translation
     }
diff --git a/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_to_cooperative_ops.mlir b/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_to_cooperative_ops.mlir
index f5b00ce..cde3767 100644
--- a/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_to_cooperative_ops.mlir
+++ b/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_to_cooperative_ops.mlir
@@ -30,7 +30,7 @@
            max_compute_workgroup_invocations = 1024 : i32,
            max_compute_workgroup_size = dense<[2147483647, 65535, 65535]> : vector<3xi32>,
            subgroup_size = 32 : i32}>}> {
-    hal.executable.entry_point public @matmul_256x1024x128_div_sub layout(#executable_layout) attributes {
+    hal.executable.entry_point public @matmul_256x1024x128_div_sub layout(#executable_layout) {
       translation.info = #translation,
       workgroup_size = [32 : index, 1 : index, 1 : index]
     } {
diff --git a/iree/compiler/Codegen/Utils/Utils.cpp b/iree/compiler/Codegen/Utils/Utils.cpp
index 096ca6f..5b54f96 100644
--- a/iree/compiler/Codegen/Utils/Utils.cpp
+++ b/iree/compiler/Codegen/Utils/Utils.cpp
@@ -70,119 +70,6 @@
 }
 
 //===----------------------------------------------------------------------===//
-// Utility functions to get untiled op shapes
-//===----------------------------------------------------------------------===//
-
-SmallVector<int64_t> getDistributedTileSizes(
-    IREE::Flow::PartitionableLoopsInterface interfaceOp,
-    ArrayRef<int64_t> workloadPerWorkgroup) {
-  SmallVector<int64_t> tileSizes(interfaceOp.getNumLoops(), 0);
-  SmallVector<unsigned> partitionableLoops =
-      interfaceOp.getPartitionableLoops(kNumMaxParallelDims);
-  assert(partitionableLoops.size() == workloadPerWorkgroup.size() &&
-         "mismatch in parallelization");
-  for (auto it :
-       llvm::zip(workloadPerWorkgroup, llvm::reverse(partitionableLoops))) {
-    tileSizes[std::get<1>(it)] = std::get<0>(it);
-  }
-  return tileSizes;
-}
-
-/// Walk up the defs of the view, to get the untiled value. Either walks up
-/// `ViewOpInterface` op-chains or the `subtensor` op-chains.
-static Value getViewSource(Value view) {
-  while (true) {
-    Operation *definingOp = view.getDefiningOp();
-    if (!definingOp) break;
-    if (auto viewOp = view.getDefiningOp<ViewLikeOpInterface>()) {
-      view = viewOp.getViewSource();
-      continue;
-    }
-    if (auto subTensorOp = view.getDefiningOp<tensor::ExtractSliceOp>()) {
-      view = subTensorOp.source();
-      continue;
-    }
-    if (auto dispatchTensorLoadOp =
-            view.getDefiningOp<IREE::Flow::DispatchTensorLoadOp>()) {
-      view = dispatchTensorLoadOp.source();
-      continue;
-    }
-    break;
-  }
-  return view;
-}
-
-ArrayRef<int64_t> getUntiledShape(Value tiledView) {
-  auto type = getViewSource(tiledView).getType();
-  return TypeSwitch<Type, ArrayRef<int64_t>>(type)
-      .Case<ShapedType, IREE::Flow::DispatchTensorType>(
-          [&](auto shapedType) { return shapedType.getShape(); })
-      .Default([&](Type type) { return ArrayRef<int64_t>{}; });
-}
-
-// TODO(ravishankarm): Using the result shape for vectorization should be
-// avoided. Ideally the tile size is enough. But there is a phase ordering issue
-// which prevents the tile size from being known at this point.
-SmallVector<int64_t> getUntiledResultShape(linalg::LinalgOp linalgOp,
-                                           unsigned resultNum) {
-  // Get the shape from the `outs` operand.
-  SmallVector<int64_t> outputShape =
-      llvm::to_vector<4>(getUntiledShape(linalgOp.outputs()[resultNum]));
-
-  // If this is already fully static, it means we didn't tile it at the flow
-  // level at all; just return.
-  if (llvm::none_of(outputShape, ShapedType::isDynamic)) return outputShape;
-
-  // For Linalg ops with buffer semantics, subview chains should give us enough
-  // information; also directly return.
-  if (linalgOp.hasBufferSemantics()) return outputShape;
-
-  // For Linalg ops with tensor semantics, we need to correlate how the op
-  // should be tiled and the materialized loop nest. The materialized loops'
-  // upper bounds should be the original dimension size for the corresponding
-  // tiled op shape dimension.
-  auto interfaceOp = cast<IREE::Flow::PartitionableLoopsInterface>(*linalgOp);
-  auto partitionedLoops =
-      interfaceOp.getPartitionableLoops(kNumMaxParallelDims);
-  SmallVector<LoopTilingAndDistributionInfo> loopInfo =
-      getTiledAndDistributedLoopInfo(linalgOp->getParentOfType<FuncOp>());
-  // The number of linalg implicit loops to partition and tiled loops
-  // surrounding the op should match. Otherwise, something is incorrect.
-  assert(partitionedLoops.size() == loopInfo.size());
-
-  // Collect the mapping from the implict loop / iterator indices of the Linalg
-  // op to the output shape dimensions. Normally the first batch of iterators of
-  // a Linalg op are used to index into the output shape dimensions; but that's
-  // not guaranteed. For example, we can see the `linalg.pooling_nhwc_sum` op
-  // having the following indexing map for the output:
-  //   (N, OH, OW, KH, KW, C) -> (N, OH, OW, C)
-  DenseMap<int64_t, int64_t> loopToOutputDimMap;
-  auto outputMap =
-      linalgOp.getTiedIndexingMap(linalgOp.getOutputOperand(resultNum));
-  for (const auto &indexedResult : llvm::enumerate(outputMap.getResults())) {
-    if (auto dimExpr = indexedResult.value().dyn_cast<AffineDimExpr>()) {
-      loopToOutputDimMap[dimExpr.getPosition()] = indexedResult.index();
-    }
-  }
-
-  for (auto pair : llvm::zip(llvm::reverse(partitionedLoops), loopInfo)) {
-    unsigned loopIndex = std::get<0>(pair);
-    const LoopTilingAndDistributionInfo &loopInfo = std::get<1>(pair);
-    // If we know the static upper bound of this loop..
-    if (Optional<int64_t> attrValue =
-            getConstantIntValue(loopInfo.untiledUpperBound)) {
-      // ..and it accesses one output dimension..
-      if (loopToOutputDimMap.count(loopIndex)) {
-        // then we can recover the corresponding shape dimension's size.
-        outputShape[loopToOutputDimMap[loopIndex]] = *attrValue;
-      }
-    }
-  }
-
-  return outputShape;
-}
-
-//===----------------------------------------------------------------------===//
 // Utility functions to set configurations
 //===----------------------------------------------------------------------===//
 
diff --git a/iree/compiler/Codegen/Utils/Utils.h b/iree/compiler/Codegen/Utils/Utils.h
index 8623c87..4b02007 100644
--- a/iree/compiler/Codegen/Utils/Utils.h
+++ b/iree/compiler/Codegen/Utils/Utils.h
@@ -51,31 +51,9 @@
 }
 
 //===----------------------------------------------------------------------===//
-// Utility functions to get untiled op shapes
-//===----------------------------------------------------------------------===//
-
-/// Returns the untiled type of a tiled view for both tensor and memref
-/// types. Either walks the `ViewOpInterface` chain (for memrefs) or the
-/// extract/load op chain (for tensors).
-ArrayRef<int64_t> getUntiledShape(Value tiledView);
-
-/// Returns the untiled result shape for the given Linalg `op` by inspecting
-/// the subview chain or the tiled and distributed loop nests around it.
-SmallVector<int64_t> getUntiledResultShape(linalg::LinalgOp linalgOp,
-                                           unsigned resultNum);
-
-//===----------------------------------------------------------------------===//
 // Utility functions to set configurations
 //===----------------------------------------------------------------------===//
 
-/// Return the tile sizes to use for the Flow partitioned loops given the
-/// workload per workgroup. The tile sizes for the partitioned loops are
-/// obtained from the workload per workgroup. The other loops are returned as
-/// zero.
-SmallVector<int64_t> getDistributedTileSizes(
-    IREE::Flow::PartitionableLoopsInterface interfaceOp,
-    ArrayRef<int64_t> workloadPerWorkgroup);
-
 /// Information about a tiled and distributed loop.
 ///
 /// Right now distribution is happening at the same time as we tile the linalg
diff --git a/iree/compiler/Dialect/Flow/IR/BUILD b/iree/compiler/Dialect/Flow/IR/BUILD
index 048b269..7e1a331 100644
--- a/iree/compiler/Dialect/Flow/IR/BUILD
+++ b/iree/compiler/Dialect/Flow/IR/BUILD
@@ -194,6 +194,7 @@
         ":PartitionableLoopsInterfaceGen",
         "//llvm-external-projects/iree-dialects:IREELinalgExtDialect",
         "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:IR",
         "@llvm-project//mlir:LinalgInterfaces",
         "@llvm-project//mlir:LinalgOps",
     ],
diff --git a/iree/compiler/Dialect/Flow/IR/CMakeLists.txt b/iree/compiler/Dialect/Flow/IR/CMakeLists.txt
index 97403e2..4dcb702 100644
--- a/iree/compiler/Dialect/Flow/IR/CMakeLists.txt
+++ b/iree/compiler/Dialect/Flow/IR/CMakeLists.txt
@@ -121,6 +121,7 @@
     ::PartitionableLoopsInterfaceGen
     IREELinalgExtDialect
     LLVMSupport
+    MLIRIR
     MLIRLinalg
   PUBLIC
 )
diff --git a/iree/compiler/Dialect/Flow/IR/PartitionableLoopsInterface.cpp b/iree/compiler/Dialect/Flow/IR/PartitionableLoopsInterface.cpp
index 4941a02..115224e 100644
--- a/iree/compiler/Dialect/Flow/IR/PartitionableLoopsInterface.cpp
+++ b/iree/compiler/Dialect/Flow/IR/PartitionableLoopsInterface.cpp
@@ -10,11 +10,15 @@
 #include "iree-dialects/Dialect/LinalgExt/IR/TiledOpInterface.h"
 #include "llvm/ADT/SmallVector.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/IR/BuiltinTypes.h"
 
 // clang-format off
 #include "iree/compiler/Dialect/Flow/IR/PartitionableLoopsInterface.cpp.inc"  // IWYU pragma: export
 // clang-format on
 
+namespace mlir {
+namespace iree_compiler {
+
 /// Filters out dimensions in `parallelLoops` that have unit range in
 /// `loopRanges`.
 static llvm::SmallVector<unsigned> pruneUnitTripParallelLoops(
@@ -27,7 +31,7 @@
 
 /// Returns the partitionable loops for all Linalg ops.
 llvm::SmallVector<unsigned> getPartitionableLoopsImpl(
-    mlir::linalg::LinalgOp linalgOp, unsigned maxNumPartitionedLoops) {
+    linalg::LinalgOp linalgOp, unsigned maxNumPartitionedLoops) {
   llvm::SmallVector<unsigned> parallelLoops;
   linalgOp.getParallelDims(parallelLoops);
   // Get the static loop ranges.
@@ -46,8 +50,13 @@
   return parallelLoops;
 }
 
-namespace mlir {
-namespace iree_compiler {
+static llvm::SmallVector<llvm::StringRef> getIteratorTypesFromAttr(
+    ArrayAttr iteratorTypesAttr) {
+  return llvm::to_vector(llvm::map_range(iteratorTypesAttr, [](Attribute attr) {
+    return attr.cast<StringAttr>().getValue();
+  }));
+}
+
 namespace IREE {
 namespace Flow {
 
@@ -65,6 +74,11 @@
     auto linalgOp = cast<linalg::LinalgOp>(op);
     return getPartitionableLoopsImpl(linalgOp, maxNumPartitionedLoops);
   }
+
+  llvm::SmallVector<llvm::StringRef> getIteratorTypes(Operation *op) const {
+    return getIteratorTypesFromAttr(
+        cast<linalg::LinalgOp>(op).iterator_types());
+  }
 };
 
 /// External model implementation for linalg::Mmt4DOp.
@@ -80,6 +94,11 @@
       Operation *op, unsigned maxNumPartitionedLoops) const {
     return {0, 1};
   }
+
+  llvm::SmallVector<StringRef> getIteratorTypes(Operation *op) const {
+    return getIteratorTypesFromAttr(
+        cast<linalg::LinalgOp>(op).iterator_types());
+  }
 };
 
 /// External model implementation for all operations that implement the
@@ -101,6 +120,11 @@
     auto tiledOp = cast<LinalgExt::TiledOpInterface>(op);
     return tiledOp.getPartitionableLoops(maxNumPartitionedLoops);
   }
+
+  llvm::SmallVector<StringRef> getIteratorTypes(Operation *op) const {
+    auto tiledOp = cast<LinalgExt::TiledOpInterface>(op);
+    return tiledOp.getLoopIteratorTypes();
+  }
 };
 
 /// Registers the `LinalgOpPartitionableLoops` model for all Linalg ops. This
diff --git a/iree/compiler/Dialect/Flow/IR/PartitionableLoopsInterface.td b/iree/compiler/Dialect/Flow/IR/PartitionableLoopsInterface.td
index a040c85..1103e70 100644
--- a/iree/compiler/Dialect/Flow/IR/PartitionableLoopsInterface.td
+++ b/iree/compiler/Dialect/Flow/IR/PartitionableLoopsInterface.td
@@ -37,6 +37,13 @@
       /*retTy=*/"llvm::SmallVector<unsigned>",
       /*methodName=*/"getPartitionableLoops",
       /*args=*/(ins "unsigned":$maxNumPartitionedLoops)
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Returns the iterator types for all the loops of the op.
+      }],
+      /*retTy=*/"llvm::SmallVector<llvm::StringRef>",
+      /*methodName=*/"getIteratorTypes"
     >
   ];
 }
diff --git a/iree/compiler/Dialect/HAL/IR/HALOps.cpp b/iree/compiler/Dialect/HAL/IR/HALOps.cpp
index 2469273..a2c4c5e 100644
--- a/iree/compiler/Dialect/HAL/IR/HALOps.cpp
+++ b/iree/compiler/Dialect/HAL/IR/HALOps.cpp
@@ -684,7 +684,7 @@
   if (failed(parser.parseKeyword("layout")) || failed(parser.parseLParen()) ||
       failed(parser.parseAttribute(layoutAttr)) ||
       failed(parser.parseRParen()) ||
-      failed(parser.parseOptionalAttrDictWithKeyword(result.attributes))) {
+      failed(parser.parseOptionalAttrDict(result.attributes))) {
     return failure();
   }
   result.addAttribute("layout", layoutAttr);
diff --git a/iree/compiler/Dialect/HAL/IR/test/executable_ops.mlir b/iree/compiler/Dialect/HAL/IR/test/executable_ops.mlir
index 4b89988..3ae833d 100644
--- a/iree/compiler/Dialect/HAL/IR/test/executable_ops.mlir
+++ b/iree/compiler/Dialect/HAL/IR/test/executable_ops.mlir
@@ -1,7 +1,7 @@
 // RUN: iree-opt -split-input-file %s | FileCheck %s
 
-#executable_target_format = #hal.executable.target<"backend", "format">
 
+#executable_target_format = #hal.executable.target<"backend", "format">
 // CHECK-LABEL: @ex
 hal.executable @ex {
   // CHECK: hal.executable.variant public @backend, target = #executable_target_format
@@ -13,7 +13,7 @@
         #hal.descriptor_set.binding<0, storage_buffer>,
         #hal.descriptor_set.binding<1, storage_buffer>
       ]>
-    ]>) attributes {
+    ]>) {
       workgroup_size = [4 : index, 1 : index, 1 : index]
     }
   }
@@ -41,7 +41,7 @@
         #hal.descriptor_set.binding<0, storage_buffer>,
         #hal.descriptor_set.binding<1, storage_buffer>
       ]>
-    ]>) attributes {
+    ]>) {
       workgroup_size = [4 : index, 1 : index, 1 : index]
     } {
     ^bb0(%arg0: index, %arg1: index, %arg2: index):
diff --git a/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/test/linking.mlir b/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/test/linking.mlir
index 0a6ab47..86ca654 100644
--- a/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/test/linking.mlir
+++ b/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/test/linking.mlir
@@ -64,13 +64,7 @@
 }
 hal.executable private @call_dispatch_3  {
   hal.executable.variant @vulkan_spirv_fb, target = #executable_target_vulkan_spirv_fb {
-    hal.executable.entry_point @call_dispatch_3 ordinal(0) layout(#executable_layout_1) {
-    ^bb0(%arg0: index, %arg1: index, %arg2: index):  // no predecessors
-      %c1 = arith.constant 1 : index
-      %c56 = arith.constant 56 : index
-      %c56_0 = arith.constant 56 : index
-      hal.return %c1, %c56, %c56_0 : index, index, index
-    }
+    hal.executable.entry_point @call_dispatch_3 ordinal(0) layout(#executable_layout_1)
     builtin.module {
       spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]> {
         spv.func @call_dispatch_3() "None" {
diff --git a/iree/compiler/Dialect/HAL/Transforms/test/convert_to_hal.mlir b/iree/compiler/Dialect/HAL/Transforms/test/convert_to_hal.mlir
index cc85089..e641e99 100644
--- a/iree/compiler/Dialect/HAL/Transforms/test/convert_to_hal.mlir
+++ b/iree/compiler/Dialect/HAL/Transforms/test/convert_to_hal.mlir
@@ -20,7 +20,7 @@
   // CHECK: hal.executable private @ex
   hal.executable private @ex {
     hal.executable.variant public @embedded_elf_x86_64, target = #executable_target_embedded_elf_x86_64_ {
-      hal.executable.entry_point public @dispatch ordinal(0) layout(#executable_layout) attributes {
+      hal.executable.entry_point public @dispatch ordinal(0) layout(#executable_layout) {
         translation.info = #iree_codegen.translation.info<"CPUDefault", workload_per_wg = [4]>
       } {
       ^bb0(%arg0: index, %arg1: index, %arg2: index):  // no predecessors
diff --git a/iree/compiler/Dialect/HAL/Transforms/test/materialize_resource_caches.mlir b/iree/compiler/Dialect/HAL/Transforms/test/materialize_resource_caches.mlir
index aca1cdf..52ab835 100644
--- a/iree/compiler/Dialect/HAL/Transforms/test/materialize_resource_caches.mlir
+++ b/iree/compiler/Dialect/HAL/Transforms/test/materialize_resource_caches.mlir
@@ -122,13 +122,13 @@
 //   - If there is no matching hal.executable.variant then the executable will not be cached
 hal.executable @exe {
   hal.executable.variant @vmvx, target = <"vmvx", "vmvx-bytecode-fb"> {
-    hal.executable.entry_point @entry0 ordinal(0) layout(#executable_layout_0) attributes {
+    hal.executable.entry_point @entry0 ordinal(0) layout(#executable_layout_0) {
       workgroup_size = [32 : index, 1 : index, 1 : index]
     }
-    hal.executable.entry_point @entry0_alias ordinal(0) layout(#executable_layout_0) attributes {
+    hal.executable.entry_point @entry0_alias ordinal(0) layout(#executable_layout_0) {
       workgroup_size = [32 : index, 1 : index, 1 : index]
     }
-    hal.executable.entry_point @entry1 ordinal(1) layout(#executable_layout_1) attributes {
+    hal.executable.entry_point @entry1 ordinal(1) layout(#executable_layout_1) {
       workgroup_size = [32 : index, 1 : index, 1 : index]
     }
   }
diff --git a/iree/compiler/Dialect/HAL/Transforms/test/resolve_entry_point_ordinals.mlir b/iree/compiler/Dialect/HAL/Transforms/test/resolve_entry_point_ordinals.mlir
index 3e35458..132c769 100644
--- a/iree/compiler/Dialect/HAL/Transforms/test/resolve_entry_point_ordinals.mlir
+++ b/iree/compiler/Dialect/HAL/Transforms/test/resolve_entry_point_ordinals.mlir
@@ -7,7 +7,7 @@
         #hal.descriptor_set.binding<0, storage_buffer>,
         #hal.descriptor_set.binding<1, storage_buffer>
       ]>
-    ]>) attributes {
+    ]>) {
       workgroup_size = [32 : index, 1 : index, 1 : index]
     }
   }
@@ -62,7 +62,7 @@
         #hal.descriptor_set.binding<0, storage_buffer>,
         #hal.descriptor_set.binding<1, storage_buffer>
       ]>
-    ]>) attributes {
+    ]>) {
       workgroup_size = [32 : index, 1 : index, 1 : index]
     }
   }
diff --git a/iree/compiler/Dialect/Modules/VMVX/Transforms/Passes.cpp b/iree/compiler/Dialect/Modules/VMVX/Transforms/Passes.cpp
index 94ac38c..2797737 100644
--- a/iree/compiler/Dialect/Modules/VMVX/Transforms/Passes.cpp
+++ b/iree/compiler/Dialect/Modules/VMVX/Transforms/Passes.cpp
@@ -30,10 +30,6 @@
 static void buildVectorVMVXTransformPassPipeline(OpPassManager &passManager) {
   passManager.nest<ModuleOp>().nest<FuncOp>().addPass(
       createTypePropagationPass());
-  passManager.nest<ModuleOp>().nest<FuncOp>().addPass(
-      createTileAndDistributeToWorkgroupsPass());
-  passManager.addPass(createCanonicalizerPass());
-  passManager.addPass(createCSEPass());
   passManager.addPass(createLLVMCPULowerExecutableTargetPass());
 
   OpPassManager &nestedModulePM = passManager.nest<ModuleOp>();
diff --git a/iree/test/e2e/regression/generate_e2e_matmul_tests.py b/iree/test/e2e/regression/generate_e2e_matmul_tests.py
index a2ae6b1..2547c96 100644
--- a/iree/test/e2e/regression/generate_e2e_matmul_tests.py
+++ b/iree/test/e2e/regression/generate_e2e_matmul_tests.py
@@ -416,7 +416,7 @@
     compilation_info_string = (
         f"#compilation{generate_function.compilation_index} = #iree_codegen.compilation.info<\n"
         f"  #iree_codegen.lowering.config<tile_sizes = [{compilation_info.tile_sizes}], native_vector_size = {compilation_info.native_vector_size}>,\n"
-        f"  #iree_codegen.translation.info<\"{compilation_info.dispatch_lowering_pass_pipeline}\", workload_per_wg = {compilation_info.workload_per_wg}>,\n"
+        f"  #iree_codegen.translation.info<\"{compilation_info.dispatch_lowering_pass_pipeline}\", workload_per_wg = []>,\n"
         f"  workgroup_size = {compilation_info.workgroup_size_str()}>\n")
     compilation_info_attr = f"{{compilation.info = #compilation{generate_function.compilation_index}}} "
     func_definition = func_definition + compilation_info_string
diff --git a/iree/test/e2e/regression/lowering_config.mlir b/iree/test/e2e/regression/lowering_config.mlir
index 5798ec7..07c49f9 100644
--- a/iree/test/e2e/regression/lowering_config.mlir
+++ b/iree/test/e2e/regression/lowering_config.mlir
@@ -1,10 +1,10 @@
 #compilation0 = #iree_codegen.compilation.info<
-    #iree_codegen.lowering.config<tile_sizes = [[], [8, 8, 0], [0, 0, 8]], native_vector_size = []>,
-    #iree_codegen.translation.info<"CPUDoubleTilingExpert", workload_per_wg = [32, 32]>,
+    #iree_codegen.lowering.config<tile_sizes = [[32, 32], [8, 8, 0], [0, 0, 8]], native_vector_size = []>,
+    #iree_codegen.translation.info<"CPUDoubleTilingExpert", workload_per_wg = []>,
     workgroup_size = []>
 #compilation1 = #iree_codegen.compilation.info<
-    #iree_codegen.lowering.config<tile_sizes = [[], [4, 4, 0], [0, 0, 4]], native_vector_size = []>,
-    #iree_codegen.translation.info<"CPUDoubleTilingExpert", workload_per_wg = [64, 64]>,
+    #iree_codegen.lowering.config<tile_sizes = [[64, 64], [4, 4, 0], [0, 0, 4]], native_vector_size = []>,
+    #iree_codegen.translation.info<"CPUDoubleTilingExpert", workload_per_wg = []>,
     workgroup_size = []>
 func @lowering_config_test() {
   %a = util.unfoldable_constant dense<1.0> : tensor<128x256xf32>