Vectorize Linalg ops for dynamic cases. (#8888)

This is a fallback solution that tiles dynamic dims with `1` at the vector level, so dynamic dims still get vectorized. The first-level (distribution) tiling logic is not changed.

This also adds more lit tests covering both static and dynamic cases.
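Below is a minimal sketch of the fallback rule, written with plain C++ containers rather than the MLIR types used in the patch; the real code walks `linalg::LinalgOp::getStaticLoopRanges()` and `iterator_types()`, and `kDynamic` here is a stand-in for `ShapedType::isDynamic`:

```cpp
#include <cstdint>
#include <string>
#include <vector>

// Placeholder for a dynamic loop range (the patch uses ShapedType::isDynamic).
constexpr int64_t kDynamic = -1;

// For every dynamic loop, force a tile size of 1 in the matching vector-level
// tile list so that dim can still be vectorized; static loops keep whatever
// tile size the target heuristics already picked.
void setAlwaysVectorizeSizesSketch(const std::vector<int64_t> &loopRanges,
                                   const std::vector<std::string> &iteratorTypes,
                                   std::vector<int64_t> &parallelSizes,
                                   std::vector<int64_t> &reductionSizes) {
  for (size_t i = 0; i < loopRanges.size(); ++i) {
    if (loopRanges[i] != kDynamic) continue;  // static dims are untouched
    if (iteratorTypes[i] == "parallel")
      parallelSizes[i] = 1;
    else
      reductionSizes[i] = 1;
  }
}
```

For a fully dynamic `linalg.matmul` this turns second-level tile sizes like `[4, 16, 0]` / `[0, 0, 4]` into `[1, 1, 0]` / `[0, 0, 1]`, which is exactly what the updated lit tests below check.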
diff --git a/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp b/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
index b424b9a..68805e3 100644
--- a/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
+++ b/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
@@ -331,6 +331,23 @@
   }
 }
 
+static void setAlwaysVectorizeSizes(linalg::LinalgOp op,
+                                    SmallVectorImpl<int64_t> &parallelSizes,
+                                    SmallVectorImpl<int64_t> &reductionSizes) {
+  Optional<SmallVector<int64_t, 4>> staticLoopRanges = op.getStaticLoopRanges();
+  for (auto en :
+       llvm::enumerate(llvm::zip(*staticLoopRanges, op.iterator_types()))) {
+    auto size = std::get<0>(en.value());
+    if (!ShapedType::isDynamic(size)) continue;
+    auto iterType = std::get<1>(en.value()).cast<StringAttr>().getValue();
+    if (iterType == getParallelIteratorTypeName()) {
+      parallelSizes[en.index()] = 1;
+    } else {
+      reductionSizes[en.index()] = 1;
+    }
+  }
+}
+
 /// Sets the default configuration to use for an operation that implements the
 /// `PartitionableLoopsInterface`, given the iteration domain of all the loops.
 static LogicalResult setDefaultRootConfig(
@@ -379,7 +396,8 @@
          "three loops");
   // The tiling for parallel dims and reduction dims should be separated.
   SmallVector<int64_t> parallelTileSizes;
-  int64_t nLoops = cast<linalg::LinalgOp>(op.getOperation()).getNumLoops();
+  auto linalgOp = cast<linalg::LinalgOp>(op.getOperation());
+  int64_t nLoops = linalgOp.getNumLoops();
   if (nLoops >= 3) {
     parallelTileSizes.append(nLoops - 3, 1);
     parallelTileSizes.push_back(getMaxTileSize(
@@ -398,6 +416,8 @@
   reductionTileSizes.push_back(
       getMaxTileSize(0, K, target2ndTileSizes[2], vectorSize));
 
+  setAlwaysVectorizeSizes(linalgOp, parallelTileSizes, reductionTileSizes);
+
   TileSizesListType tileSizes;
   tileSizes.emplace_back(flowTileSizes.begin(), flowTileSizes.end());
   tileSizes.push_back(parallelTileSizes);
@@ -646,6 +666,7 @@
   SmallVector<int64_t> reductionTileSizes;
   splitParallelAndReductionTiles(linalgOp, parallelTileSizes,
                                  reductionTileSizes);
+  setAlwaysVectorizeSizes(linalgOp, parallelTileSizes, reductionTileSizes);
 
   TileSizesListType tileSizes;
   tileSizes.push_back(flowTileSizes);
@@ -703,6 +724,7 @@
   }
   SmallVector<int64_t> reductionTileSizes;
   splitParallelAndReductionTiles(convOp, parallelTileSizes, reductionTileSizes);
+  setAlwaysVectorizeSizes(convOp, parallelTileSizes, reductionTileSizes);
 
   TileSizesListType tileSizes;
   tileSizes.push_back(flowTileSizes);
diff --git a/iree/compiler/Codegen/LLVMCPU/test/materialize_aarch64_launch_configuration.mlir b/iree/compiler/Codegen/LLVMCPU/test/materialize_aarch64_launch_configuration.mlir
index 4be2ead..a1fe67a 100644
--- a/iree/compiler/Codegen/LLVMCPU/test/materialize_aarch64_launch_configuration.mlir
+++ b/iree/compiler/Codegen/LLVMCPU/test/materialize_aarch64_launch_configuration.mlir
@@ -208,15 +208,57 @@
     #hal.descriptor_set.binding<2, storage_buffer>
   ]>
 ]>
-hal.executable private @matmul_aarch_i8_i8_i32  {
+hal.executable private @matmul_aarch_i8_i8_i32_static  {
   hal.executable.variant public @system_elf_arm_64, target = <"llvm", "system-elf-arm_64", {
     data_layout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128",
     native_vector_size = 16 : index,
     target_triple = "aarch64-none-linux-android30"
   }> {
-  hal.executable.entry_point public @matmul_aarch_i8_i8_i32 layout(#executable_layout)
+  hal.executable.entry_point public @matmul_aarch_i8_i8_i32_static layout(#executable_layout)
     builtin.module {
-      func.func @matmul_aarch_i8_i8_i32() {
+      func.func @matmul_aarch_i8_i8_i32_static() {
+        %c0_i32 = arith.constant 0 : i32
+        %c0 = arith.constant 0 : index
+        %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<readonly:128x384xi8>
+        %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<readonly:384x1536xi8>
+        %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<writeonly:128x1536xi32>
+        %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [128, 384], strides = [1, 1] : !flow.dispatch.tensor<readonly:128x384xi8> -> tensor<128x384xi8>
+        %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [384, 1536], strides = [1, 1] : !flow.dispatch.tensor<readonly:384x1536xi8> -> tensor<384x1536xi8>
+        %5 = linalg.init_tensor [128, 1536] : tensor<128x1536xi32>
+        %6 = linalg.fill ins(%c0_i32 : i32) outs(%5 : tensor<128x1536xi32>) -> tensor<128x1536xi32>
+        %7 = linalg.matmul ins(%3, %4 : tensor<128x384xi8>, tensor<384x1536xi8>) outs(%6 : tensor<128x1536xi32>) -> tensor<128x1536xi32>
+        flow.dispatch.tensor.store %7, %2, offsets = [0, 0], sizes = [128, 1536], strides = [1, 1] : tensor<128x1536xi32> -> !flow.dispatch.tensor<writeonly:128x1536xi32>
+        return
+      }
+    }
+  }
+}
+
+//  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[64, 64, 0], [4, 16, 0], [0, 0, 4]]>
+//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<CPUDoubleTilingExpert>
+//      CHECK: hal.executable.entry_point public @matmul_aarch_i8_i8_i32_static
+// CHECK-SAME:     translation_info = #[[TRANSLATION]]
+//      CHECK:   linalg.matmul
+// CHECK-SAME:       lowering_config = #[[CONFIG]]
+
+// -----
+
+#executable_layout = #hal.executable.layout<push_constants = 0, sets = [
+  #hal.descriptor_set.layout<0, bindings = [
+    #hal.descriptor_set.binding<0, storage_buffer>,
+    #hal.descriptor_set.binding<1, storage_buffer>,
+    #hal.descriptor_set.binding<2, storage_buffer>
+  ]>
+]>
+hal.executable private @matmul_aarch_i8_i8_i32_dynamic  {
+  hal.executable.variant public @system_elf_arm_64, target = <"llvm", "system-elf-arm_64", {
+    data_layout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128",
+    native_vector_size = 16 : index,
+    target_triple = "aarch64-none-linux-android30"
+  }> {
+  hal.executable.entry_point public @matmul_aarch_i8_i8_i32_dynamic layout(#executable_layout)
+    builtin.module {
+      func.func @matmul_aarch_i8_i8_i32_dynamic() {
         %c0 = arith.constant 0 : index
         %M = hal.interface.constant.load[0] : index
         %N = hal.interface.constant.load[1] : index
@@ -242,9 +284,9 @@
   }
 }
 
-//  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[64, 64, 0], [4, 16, 0], [0, 0, 4]{{\]}}>
+//  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[64, 64, 0], [1, 1, 0], [0, 0, 1]]>
 //  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<CPUDoubleTilingExpert>
-//      CHECK: hal.executable.entry_point public @matmul_aarch_i8_i8_i32
+//      CHECK: hal.executable.entry_point public @matmul_aarch_i8_i8_i32_dynamic
 // CHECK-SAME:     translation_info = #[[TRANSLATION]]
 //      CHECK:   linalg.matmul
 // CHECK-SAME:       lowering_config = #[[CONFIG]]
diff --git a/iree/compiler/Codegen/LLVMCPU/test/materialize_x86_64_launch_configuration.mlir b/iree/compiler/Codegen/LLVMCPU/test/materialize_x86_64_launch_configuration.mlir
index 8281ba3..d81ec74 100644
--- a/iree/compiler/Codegen/LLVMCPU/test/materialize_x86_64_launch_configuration.mlir
+++ b/iree/compiler/Codegen/LLVMCPU/test/materialize_x86_64_launch_configuration.mlir
@@ -7,15 +7,56 @@
     #hal.descriptor_set.binding<2, storage_buffer>
   ]>
 ]>
-hal.executable private @matvec_tensors  {
+hal.executable private @matvec_static  {
   hal.executable.variant @llvm, target = <"llvm", "embedded-elf-x86_64", {
     data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128",
     native_vector_size = 16 : index,
     target_triple = "x86_64-unknown-linux-gnu"
   }> {
-    hal.executable.entry_point @matvec_tensors layout(#executable_layout)
+    hal.executable.entry_point @matvec_static layout(#executable_layout)
     builtin.module {
-      func.func @matvec_tensors() {
+      func.func @matvec_static() {
+        %cst = arith.constant 0.000000e+00 : f32
+        %c0 = arith.constant 0 : index
+        %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<readonly:128x384xf32>
+        %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<readonly:384xf32>
+        %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<writeonly:128xf32>
+        %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [128, 384], strides = [1, 1] : !flow.dispatch.tensor<readonly:128x384xf32> -> tensor<128x384xf32>
+        %4 = flow.dispatch.tensor.load %1, offsets = [0], sizes = [384], strides = [1] : !flow.dispatch.tensor<readonly:384xf32> -> tensor<384xf32>
+        %5 = linalg.init_tensor [128] : tensor<128xf32>
+        %6 = linalg.fill ins(%cst : f32) outs(%5 : tensor<128xf32>) -> tensor<128xf32>
+        %7 = linalg.matvec ins(%3, %4 : tensor<128x384xf32>, tensor<384xf32>) outs(%6 : tensor<128xf32>) -> tensor<128xf32>
+        flow.dispatch.tensor.store %7, %2, offsets = [0], sizes = [128], strides = [1] : tensor<128xf32> -> !flow.dispatch.tensor<writeonly:128xf32>
+        return
+      }
+    }
+  }
+}
+//  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[64, 0], [32, 0], [0, 16]]>
+//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<CPUDoubleTilingExpert>
+//      CHECK: hal.executable.entry_point public @matvec_static
+// CHECK-SAME:     translation_info = #[[TRANSLATION]]
+//      CHECK: linalg.matvec
+// CHECK-SAME:     lowering_config = #[[CONFIG]]
+
+// -----
+
+#executable_layout = #hal.executable.layout<push_constants = 0, sets = [
+  #hal.descriptor_set.layout<0, bindings = [
+    #hal.descriptor_set.binding<0, storage_buffer>,
+    #hal.descriptor_set.binding<1, storage_buffer>,
+    #hal.descriptor_set.binding<2, storage_buffer>
+  ]>
+]>
+hal.executable private @matvec_dynamic  {
+  hal.executable.variant @llvm, target = <"llvm", "embedded-elf-x86_64", {
+    data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128",
+    native_vector_size = 16 : index,
+    target_triple = "x86_64-unknown-linux-gnu"
+  }> {
+    hal.executable.entry_point @matvec_dynamic layout(#executable_layout)
+    builtin.module {
+      func.func @matvec_dynamic() {
         %c0 = arith.constant 0 : index
         %cst = arith.constant 0.000000e+00 : f32
         %0 = hal.interface.constant.load[0] : i32
@@ -40,9 +81,9 @@
     }
   }
 }
-//  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[64, 0], [32, 0], [0, 16]]>
+//  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[64, 0], [1, 0], [0, 1]]>
 //  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<CPUDoubleTilingExpert>
-//      CHECK: hal.executable.entry_point public @matvec_tensors
+//      CHECK: hal.executable.entry_point public @matvec_dynamic
 // CHECK-SAME:     translation_info = #[[TRANSLATION]]
 //      CHECK: linalg.matvec
 // CHECK-SAME:     lowering_config = #[[CONFIG]]
@@ -56,15 +97,56 @@
     #hal.descriptor_set.binding<2, storage_buffer>
   ]>
 ]>
-hal.executable private @dot_tensors  {
+hal.executable private @dot_static  {
   hal.executable.variant @llvm, target = <"llvm", "embedded-elf-x86_64", {
     data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128",
     native_vector_size = 16 : index,
     target_triple = "x86_64-unknown-linux-gnu"
   }> {
-    hal.executable.entry_point @dot_tensors layout(#executable_layout)
+    hal.executable.entry_point @dot_static layout(#executable_layout)
     builtin.module {
-      func.func @dot_tensors() {
+      func.func @dot_static() {
+        %cst = arith.constant 0.000000e+00 : f32
+        %c0 = arith.constant 0 : index
+        %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<readonly:384xf32>
+        %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<readonly:384xf32>
+        %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<writeonly:f32>
+        %3 = flow.dispatch.tensor.load %0, offsets = [0], sizes = [384], strides = [1] : !flow.dispatch.tensor<readonly:384xf32> -> tensor<384xf32>
+        %4 = flow.dispatch.tensor.load %1, offsets = [0], sizes = [384], strides = [1] : !flow.dispatch.tensor<readonly:384xf32> -> tensor<384xf32>
+        %5 = linalg.init_tensor [] : tensor<f32>
+        %6 = linalg.fill ins(%cst : f32) outs(%5 : tensor<f32>) -> tensor<f32>
+        %7 = linalg.dot ins(%3, %4 : tensor<384xf32>, tensor<384xf32>) outs(%6 : tensor<f32>) -> tensor<f32>
+        flow.dispatch.tensor.store %7, %2, offsets = [], sizes = [], strides = [] : tensor<f32> -> !flow.dispatch.tensor<writeonly:f32>
+        return
+      }
+    }
+  }
+}
+//  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[0], [0], [16]]>
+//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<CPUDoubleTilingExpert>
+//      CHECK: hal.executable.entry_point public @dot_static
+// CHECK-SAME:     translation_info = #[[TRANSLATION]]
+//      CHECK: linalg.dot
+// CHECK-SAME:     lowering_config = #[[CONFIG]]
+
+// -----
+
+#executable_layout = #hal.executable.layout<push_constants = 0, sets = [
+  #hal.descriptor_set.layout<0, bindings = [
+    #hal.descriptor_set.binding<0, storage_buffer>,
+    #hal.descriptor_set.binding<1, storage_buffer>,
+    #hal.descriptor_set.binding<2, storage_buffer>
+  ]>
+]>
+hal.executable private @dot_dynamic  {
+  hal.executable.variant @llvm, target = <"llvm", "embedded-elf-x86_64", {
+    data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128",
+    native_vector_size = 16 : index,
+    target_triple = "x86_64-unknown-linux-gnu"
+  }> {
+    hal.executable.entry_point @dot_dynamic layout(#executable_layout)
+    builtin.module {
+      func.func @dot_dynamic() {
         %c0 = arith.constant 0 : index
         %cst = arith.constant 0.000000e+00 : f32
         %0 = hal.interface.constant.load[0] : i32
@@ -85,9 +167,9 @@
     }
   }
 }
-//  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[0], [0], [16]]>
+//  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[0], [0], [1]]>
 //  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<CPUDoubleTilingExpert>
-//      CHECK: hal.executable.entry_point public @dot_tensors
+//      CHECK: hal.executable.entry_point public @dot_dynamic
 // CHECK-SAME:     translation_info = #[[TRANSLATION]]
 //      CHECK: linalg.dot
 // CHECK-SAME:     lowering_config = #[[CONFIG]]
@@ -135,7 +217,7 @@
     }
   }
 }
-//  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[64, 64], [1, 4], [0, 0]]>
+//  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[64, 64], [1, 1], [0, 0]]>
 //  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<CPUDoubleTilingExpert>
 //      CHECK: hal.executable.entry_point public @add
 // CHECK-SAME:     translation_info = #[[TRANSLATION]]
@@ -193,7 +275,7 @@
   }
 }
 
-//  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[0, 64, 64, 64], [1, 1, 1, 4], [0, 0, 0, 0]]>
+//  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[0, 64, 64, 64], [1, 1, 1, 1], [0, 0, 0, 0]]>
 //  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<CPUDoubleTilingExpert>
 //      CHECK: hal.executable.entry_point public @add4D
 // CHECK-SAME:     translation_info = #[[TRANSLATION]]
@@ -202,6 +284,47 @@
 
 // -----
 
+#executable_layout = #hal.executable.layout<push_constants = 0, sets = [
+  #hal.descriptor_set.layout<0, bindings = [
+    #hal.descriptor_set.binding<0, storage_buffer>,
+    #hal.descriptor_set.binding<1, storage_buffer>,
+    #hal.descriptor_set.binding<2, storage_buffer>
+  ]>
+]>
+hal.executable private @add_static {
+  hal.executable.variant @llvm, target = <"llvm", "embedded-elf-x86_64", {
+    data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128",
+    native_vector_size = 16 : index,
+    target_triple = "x86_64-unknown-linux-gnu"
+  }> {
+    hal.executable.entry_point @add_static layout(#executable_layout)
+    builtin.module {
+      func.func @add_static() {
+        %c0 = arith.constant 0 : index
+        %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<readonly:64x16x32x128xf32>
+        %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<writeonly:64x16x32x128xf32>
+        %2 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0, 0], sizes = [64, 16, 32, 128], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:64x16x32x128xf32> -> tensor<64x16x32x128xf32>
+        %3 = linalg.init_tensor [64, 16, 32, 128] : tensor<64x16x32x128xf32>
+        %4 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%2 : tensor<64x16x32x128xf32>) outs(%3 : tensor<64x16x32x128xf32>) {
+        ^bb0(%arg0: f32, %arg1: f32):
+          %5 = arith.addf %arg0, %arg0 : f32
+          linalg.yield %5 : f32
+        } -> tensor<64x16x32x128xf32>
+        flow.dispatch.tensor.store %4, %1, offsets = [0, 0, 0, 0], sizes = [64, 16, 32, 128], strides = [1, 1, 1, 1] : tensor<64x16x32x128xf32> -> !flow.dispatch.tensor<writeonly:64x16x32x128xf32>
+        return
+      }
+    }
+  }
+}
+//  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[0, 8, 16, 64], [1, 1, 1, 4], [0, 0, 0, 0]]>
+//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<CPUDoubleTilingExpert>
+//      CHECK: hal.executable.entry_point public @add_static
+// CHECK-SAME:     translation_info = #[[TRANSLATION]]
+//      CHECK: linalg.generic
+// CHECK-SAME:     lowering_config = #[[CONFIG]]
+
+// -----
+
 #compilation = #iree_codegen.compilation_info<
     lowering_config = <tile_sizes = [[64, 64, 0], [32, 32, 0], [0, 0, 32]]>,
     translation_info  = <CPUDoubleTilingExpert>,
@@ -258,11 +381,11 @@
     #hal.descriptor_set.binding<1, storage_buffer>
   ]>
 ]>
-hal.executable @copy_op {
+hal.executable @copy_op_dynamic {
   hal.executable.variant @system_elf_x86_64, target = <"llvm", "system-elf-x86_64"> {
-    hal.executable.entry_point @copy_op layout(#executable_layout)
+    hal.executable.entry_point @copy_op_dynamic layout(#executable_layout)
     builtin.module {
-      func.func @copy_op() {
+      func.func @copy_op_dynamic() {
         %d0 = hal.interface.constant.load[0] : index
         %d1 = hal.interface.constant.load[1] : index
         %d2 = hal.interface.constant.load[2] : index
@@ -285,9 +408,9 @@
   }
 }
 
-//  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[64, 64], [1, 4], [0, 0]{{\]}}>
+//  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[64, 64], [1, 1], [0, 0]{{\]}}>
 //  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<CPUBufferOpsTileAndVectorize>
-//      CHECK: hal.executable.entry_point public @copy_op
+//      CHECK: hal.executable.entry_point public @copy_op_dynamic
 // CHECK-SAME:     translation_info = #[[TRANSLATION]]
 //      CHECK:   linalg.generic
 // CHECK-SAME:       lowering_config = #[[CONFIG]]
@@ -421,7 +544,7 @@
     }
   }
 }
-//  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[64, 64, 0], [1, 4, 0], [0, 0, 4]{{\]}}>
+//  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[64, 64, 0], [1, 1, 0], [0, 0, 1]{{\]}}>
 //      CHECK: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<CPUDoubleTilingExpert>
 //      CHECK: hal.executable.entry_point public @outs_fusion_fn
 // CHECK-SAME:   translation_info = #[[TRANSLATION]]
@@ -440,15 +563,15 @@
     #hal.descriptor_set.binding<2, storage_buffer>
   ]>
 ]>
-hal.executable private @conv {
+hal.executable private @conv_dynamic {
   hal.executable.variant public @system_elf_x86_64, target = <"llvm", "system-elf-x86_64", {
     data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128",
     native_vector_size = 16 : index,
     target_triple = "x86_64-unknown-linux-gnu"
   }> {
-    hal.executable.entry_point public @conv layout(#executable_layout)
+    hal.executable.entry_point public @conv_dynamic layout(#executable_layout)
     builtin.module {
-      func.func @conv() {
+      func.func @conv_dynamic() {
         %N = hal.interface.constant.load[0] : index
         %H = hal.interface.constant.load[1] : index
         %W = hal.interface.constant.load[2] : index
@@ -481,9 +604,9 @@
   }
 }
 
-//  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[0, 64, 64, 64, 0, 0, 0], [1, 1, 8, 8, 0, 0, 0], [0, 0, 0, 0, 1, 1, 8]]>
+//  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[0, 64, 64, 64, 0, 0, 0], [1, 1, 1, 1, 0, 0, 0], [0, 0, 0, 0, 1, 1, 1]]>
 //  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<CPUConvTileAndDecomposeExpert>
-//      CHECK: hal.executable.entry_point public @conv
+//      CHECK: hal.executable.entry_point public @conv_dynamic
 // CHECK-SAME:     translation_info = #[[TRANSLATION]]
 //      CHECK:     linalg.conv_2d_nhwc_hwcf
 //      CHECK:         lowering_config = #[[CONFIG]]
@@ -631,7 +754,7 @@
     #hal.descriptor_set.binding<2, storage_buffer>
   ]>
 ]>
-hal.executable private @matmul_x86  {
+hal.executable private @matmul_static  {
   hal.executable.variant public @embedded_elf_x86_64, target = #hal.executable.target<
     "llvm",
     "embedded-elf-x86_64", {
@@ -639,9 +762,9 @@
       native_vector_size = 16 : index,
       target_triple = "x86_64-unknown-unknown-eabi-elf"
     }> {
-    hal.executable.entry_point public @matmul_x86 layout(#executable_layout)
+    hal.executable.entry_point public @matmul_static layout(#executable_layout)
     builtin.module {
-      func.func @matmul_x86() {
+      func.func @matmul_static() {
         %cst = arith.constant 0.0 : f32
         %lhs_binding = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : !flow.dispatch.tensor<readonly:384x512xf32>
         %rhs_binding = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) : !flow.dispatch.tensor<readonly:512x128xf32>
@@ -664,7 +787,7 @@
 
 //  CHECK-DAG: #[[CONFIG:.+]] =  #iree_codegen.lowering_config<tile_sizes = {{\[}}[64, 64, 0], [8, 32, 0], [0, 0, 16]{{\]}}>
 //  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<CPUDoubleTilingExpert>
-//      CHECK: hal.executable.entry_point public @matmul_x86
+//      CHECK: hal.executable.entry_point public @matmul_static
 // CHECK-SAME:     translation_info = #[[TRANSLATION]]
 //      CHECK: linalg.matmul
 // CHECK-SAME:     lowering_config = #[[CONFIG]]
@@ -696,7 +819,7 @@
         %input = flow.dispatch.tensor.load %arg0, offsets = [0, 0, 0], sizes = [7, 7, 2048], strides = [1, 1, 1]
             : !flow.dispatch.tensor<readonly:7x7x2048xf32> -> tensor<7x7x2048xf32>
         %init = linalg.init_tensor [7] : tensor<7xf32>
-        %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<7xf32>) -> tensor<7xf32> 
+        %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<7xf32>) -> tensor<7xf32>
         %reduce = linalg.generic {
             indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0)>],
             iterator_types = ["parallel", "reduction", "reduction"]}
@@ -736,7 +859,7 @@
     #hal.descriptor_set.binding<2, storage_buffer>
   ]>
 ]>
-hal.executable private @matmul_x86_i8_i8_i32  {
+hal.executable private @matmul_i8_i8_i32_static  {
   hal.executable.variant public @embedded_elf_x86_64, target = #hal.executable.target<
     "llvm",
     "embedded-elf-x86_64", {
@@ -744,28 +867,20 @@
       native_vector_size = 4 : index,
       target_triple = "x86_64-unknown-unknown-eabi-elf"
     }> {
-    hal.executable.entry_point public @matmul_x86_i8_i8_i32 layout(#executable_layout)
+    hal.executable.entry_point public @matmul_i8_i8_i32_static layout(#executable_layout)
     builtin.module {
-      func.func @matmul_x86_i8_i8_i32() {
+      func.func @matmul_i8_i8_i32_static() {
+        %c0_i32 = arith.constant 0 : i32
         %c0 = arith.constant 0 : index
-        %M = hal.interface.constant.load[0] : index
-        %N = hal.interface.constant.load[1] : index
-        %K = hal.interface.constant.load[2] : index
-        %lhs_binding = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(32)
-            : !flow.dispatch.tensor<readonly:?x?xi8>{%M, %K}
-        %rhs_binding = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(32)
-            : !flow.dispatch.tensor<readonly:?x?xi8>{%K, %N}
-        %result_binding = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) offset(%c0) alignment(32)
-            : !flow.dispatch.tensor<readwrite:?x?xi32>{%M, %N}
-        %lhs = flow.dispatch.tensor.load %lhs_binding, offsets = [0, 0], sizes = [%M, %K], strides = [1, 1]
-            : !flow.dispatch.tensor<readonly:?x?xi8>{%M, %K} -> tensor<?x?xi8>
-        %rhs = flow.dispatch.tensor.load %rhs_binding, offsets = [0, 0], sizes = [%K, %N], strides = [1, 1]
-            : !flow.dispatch.tensor<readonly:?x?xi8>{%K, %N} -> tensor<?x?xi8>
-        %init = flow.dispatch.tensor.load %result_binding, offsets = [0, 0], sizes = [%M, %N], strides = [1, 1]
-            : !flow.dispatch.tensor<readwrite:?x?xi32>{%M, %N} -> tensor<?x?xi32>
-        %gemm = linalg.matmul ins(%lhs, %rhs : tensor<?x?xi8>, tensor<?x?xi8>) outs(%init : tensor<?x?xi32>) -> tensor<?x?xi32>
-        flow.dispatch.tensor.store %gemm, %result_binding, offsets = [0, 0], sizes = [%M, %N], strides = [1, 1]
-            : tensor<?x?xi32> -> !flow.dispatch.tensor<readwrite:?x?xi32>{%M, %N}
+        %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<readonly:128x384xi8>
+        %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<readonly:384x1536xi8>
+        %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<writeonly:128x1536xi32>
+        %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [128, 384], strides = [1, 1] : !flow.dispatch.tensor<readonly:128x384xi8> -> tensor<128x384xi8>
+        %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [384, 1536], strides = [1, 1] : !flow.dispatch.tensor<readonly:384x1536xi8> -> tensor<384x1536xi8>
+        %5 = linalg.init_tensor [128, 1536] : tensor<128x1536xi32>
+        %6 = linalg.fill ins(%c0_i32 : i32) outs(%5 : tensor<128x1536xi32>) -> tensor<128x1536xi32>
+        %7 = linalg.matmul ins(%3, %4 : tensor<128x384xi8>, tensor<384x1536xi8>) outs(%6 : tensor<128x1536xi32>) -> tensor<128x1536xi32>
+        flow.dispatch.tensor.store %7, %2, offsets = [0, 0], sizes = [128, 1536], strides = [1, 1] : tensor<128x1536xi32> -> !flow.dispatch.tensor<writeonly:128x1536xi32>
         return
       }
     }
@@ -774,7 +889,7 @@
 
 //  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[64, 64, 0], [8, 32, 0], [0, 0, 16]{{\]}}>
 //  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<CPUDoubleTilingExpert>
-//      CHECK: hal.executable.entry_point public @matmul_x86_i8_i8_i32
+//      CHECK: hal.executable.entry_point public @matmul_i8_i8_i32_static
 // CHECK-SAME:     translation_info = #[[TRANSLATION]]
 //      CHECK:   linalg.matmul
 // CHECK-SAME:       lowering_config = #[[CONFIG]]
@@ -825,7 +940,7 @@
     }
   }
 }
-//  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[64, 0, 0], [8, 0, 0], [0, 0, 16]]>
+//  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[64, 0, 0], [1, 0, 0], [0, 0, 1]]>
 //  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<CPUDoubleTilingExpert>
 //      CHECK: hal.executable.entry_point public @gemm_unit_N
 // CHECK-SAME:       translation_info = #[[TRANSLATION]]
@@ -872,7 +987,7 @@
     }
   }
 }
-//  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[0, 0, 0], [0, 0, 0], [0, 0, 16]]>
+//  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[0, 0, 0], [0, 0, 0], [0, 0, 1]]>
 //  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<CPUDoubleTilingExpert>
 //      CHECK: hal.executable.entry_point public @gemm_unit_M_unit_N
 // CHECK-SAME:       translation_info = #[[TRANSLATION]]
@@ -934,15 +1049,15 @@
     #hal.descriptor_set.binding<2, storage_buffer>
   ]>
 ]>
-hal.executable private @generic_unit_dims {
+hal.executable private @generic_unit_dims_dynamic {
   hal.executable.variant @llvm, target = <"llvm", "embedded-elf-x86_64", {
     data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128",
     native_vector_size = 16 : index,
     target_triple = "x86_64-unknown-linux-gnu"
   }> {
-    hal.executable.entry_point @generic_unit_dims layout(#executable_layout)
+    hal.executable.entry_point @generic_unit_dims_dynamic layout(#executable_layout)
     builtin.module {
-      func.func @generic_unit_dims() {
+      func.func @generic_unit_dims_dynamic() {
         %c0 = arith.constant 0 : index
         %d0 = hal.interface.constant.load[0] : index
         %d1 = hal.interface.constant.load[1] : index
@@ -973,9 +1088,9 @@
     }
   }
 }
-//  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[0, 0, 0, 0, 64, 64, 0, 64], [0, 1, 0, 0, 1, 1, 0, 4], [0, 0, 0, 0, 0, 0, 0, 0]{{\]}}>
+//  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[0, 0, 0, 0, 64, 64, 0, 64], [0, 1, 0, 0, 1, 1, 0, 1], [0, 0, 0, 0, 0, 0, 0, 0]{{\]}}>
 //  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<CPUDoubleTilingExpert>
-//      CHECK: hal.executable.entry_point public @generic_unit_dim
+//      CHECK: hal.executable.entry_point public @generic_unit_dims_dynamic
 // CHECK-SAME:     translation_info = #[[TRANSLATION]]
 //      CHECK: linalg.generic
 // CHECK-SAME:     lowering_config = #[[CONFIG]]
@@ -989,15 +1104,58 @@
     #hal.descriptor_set.binding<2, storage_buffer>
   ]>
 ]>
-hal.executable private @reduce_to_scalar {
+hal.executable private @reduce_to_scalar_static {
   hal.executable.variant @llvm, target = <"llvm", "embedded-elf-x86_64", {
     data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128",
     native_vector_size = 16 : index,
     target_triple = "x86_64-unknown-linux-gnu"
   }> {
-    hal.executable.entry_point @reduce_to_scalar layout(#executable_layout)
+    hal.executable.entry_point @reduce_to_scalar_static layout(#executable_layout)
     builtin.module {
-      func.func @reduce_to_scalar() {
+      func.func @reduce_to_scalar_static() {
+        %cst = arith.constant 0.000000e+00 : f32
+        %c0 = arith.constant 0 : index
+        %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<readonly:128xf32>
+        %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<writeonly:f32>
+        %2 = flow.dispatch.tensor.load %0, offsets = [0], sizes = [128], strides = [1] : !flow.dispatch.tensor<readonly:128xf32> -> tensor<128xf32>
+        %3 = linalg.init_tensor [] : tensor<f32>
+        %4 = linalg.fill ins(%cst : f32) outs(%3 : tensor<f32>) -> tensor<f32>
+        %5 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> ()>], iterator_types = ["reduction"]} ins(%2 : tensor<128xf32>) outs(%4 : tensor<f32>) {
+        ^bb0(%arg0: f32, %arg1: f32):
+          %6 = arith.addf %arg0, %arg1 : f32
+          linalg.yield %6 : f32
+        } -> tensor<f32>
+        flow.dispatch.tensor.store %5, %1, offsets = [], sizes = [], strides = [] : tensor<f32> -> !flow.dispatch.tensor<writeonly:f32>
+        return
+      }
+    }
+  }
+}
+//  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[0], [0], [4]{{\]}}>
+//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<CPUDoubleTilingExpert>
+//      CHECK: hal.executable.entry_point public @reduce_to_scalar_static
+// CHECK-SAME:     translation_info = #[[TRANSLATION]]
+//      CHECK: linalg.generic
+// CHECK-SAME:     lowering_config = #[[CONFIG]]
+
+// -----
+
+#executable_layout = #hal.executable.layout<push_constants = 0, sets = [
+  #hal.descriptor_set.layout<0, bindings = [
+    #hal.descriptor_set.binding<0, storage_buffer>,
+    #hal.descriptor_set.binding<1, storage_buffer>,
+    #hal.descriptor_set.binding<2, storage_buffer>
+  ]>
+]>
+hal.executable private @reduce_to_scalar_dynamic {
+  hal.executable.variant @llvm, target = <"llvm", "embedded-elf-x86_64", {
+    data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128",
+    native_vector_size = 16 : index,
+    target_triple = "x86_64-unknown-linux-gnu"
+  }> {
+    hal.executable.entry_point @reduce_to_scalar_dynamic layout(#executable_layout)
+    builtin.module {
+      func.func @reduce_to_scalar_dynamic() {
         %c0 = arith.constant 0 : index
         %d0 = hal.interface.constant.load[0] : index
         %in_binding = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : !flow.dispatch.tensor<readonly:?xf32>{%d0}
@@ -1018,9 +1176,9 @@
     }
   }
 }
-//  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[0], [0], [4]{{\]}}>
+//  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[0], [0], [1]{{\]}}>
 //  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<CPUDoubleTilingExpert>
-//      CHECK: hal.executable.entry_point public @reduce_to_scalar
+//      CHECK: hal.executable.entry_point public @reduce_to_scalar_dynamic
 // CHECK-SAME:     translation_info = #[[TRANSLATION]]
 //      CHECK: linalg.generic
 // CHECK-SAME:     lowering_config = #[[CONFIG]]