[Codegen][GPU] Handle dynamic and unaligned cases in DerivedThreadConfig (#18281)

This adds a default set of tile sizes for dynamically shaped or unaligned
copy/linalg ops: every dimension is tiled to 1 except the innermost, which
uses the preferred vector size derived from the element type bitwidth of the
op. The same fallback might also be worth applying in aligned cases, but this
patch opts not to change pre-existing behavior without proper benchmarking.

Additionally cleans up the tiling tests by replacing the
hal.interface.binding.subspan / flow.dispatch.tensor.load boilerplate with
plain tensor function arguments.
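
For context, the fallback added to DerivedConfigUtils.cpp reduces to the
sketch below. This is a minimal standalone illustration, not IREE code: the
`main` harness and the `kDynamic` constant are stand-ins (the real helper uses
`mlir::ShapedType::isDynamic`). It shows the tile sizes the new
`inferred_dynamic` and `inferred_small_inner_dim` tests expect for f32 inputs.

```cpp
#include <cstdint>
#include <iostream>
#include <limits>
#include <vector>

// Stand-in for mlir::ShapedType::kDynamic (illustrative only).
constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();
constexpr int64_t kPreferredCopyNumBits = 128;

// Mirrors the new getVectorSizeTileSizes helper: every dim is tiled to 1
// except the innermost, which gets the preferred vector size unless a static
// inner dim is smaller.
std::vector<int64_t> getVectorSizeTileSizes(int64_t rank, int64_t innerDimSize,
                                            int64_t vectorSize) {
  std::vector<int64_t> tileSizes(rank, 1);
  if (innerDimSize == kDynamic || innerDimSize >= vectorSize)
    tileSizes.back() = vectorSize;
  else
    tileSizes.back() = innerDimSize;
  return tileSizes;
}

int main() {
  // f32 elements -> 128 / 32 = 4 lanes per thread.
  int64_t vectorSize = kPreferredCopyNumBits / 32;
  for (int64_t s : getVectorSizeTileSizes(2, kDynamic, vectorSize))
    std::cout << s << ' ';  // dynamic inner dim: 1 4
  std::cout << '\n';
  for (int64_t s : getVectorSizeTileSizes(2, 2, vectorSize))
    std::cout << s << ' ';  // small static inner dim (8x2 test): 1 2
  std::cout << '\n';
}
```
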
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_apply_tiling_level.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_apply_tiling_level.mlir
index 1b9cc62..7c4cd2f 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_apply_tiling_level.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_apply_tiling_level.mlir
@@ -2,24 +2,10 @@
 // RUN: iree-opt --split-input-file --mlir-print-local-scope --pass-pipeline="builtin.module(func.func(iree-codegen-gpu-apply-tiling-level{tiling-level=thread}, canonicalize, cse))" %s | FileCheck %s --check-prefix=THREAD
 // RUN: iree-opt --split-input-file --mlir-print-local-scope --pass-pipeline="builtin.module(func.func(iree-codegen-gpu-apply-tiling-level{tiling-level=subgroup}, canonicalize, cse))" %s | FileCheck %s --check-prefix=SUBGROUP
 
-#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
-  #hal.descriptor_set.layout<0, bindings = [
-    #hal.descriptor_set.binding<0, storage_buffer>,
-    #hal.descriptor_set.binding<1, storage_buffer>,
-    #hal.descriptor_set.binding<2, storage_buffer>
-  ]>
-]>
 #config = #iree_gpu.lowering_config<{thread = [2, 16], subgroup = [2, 16]}>
 #map = affine_map<(d0, d1) -> (d0, d1)>
 module {
-  func.func @add_tensor() {
-    %c0 = arith.constant 0 : index
-    %0 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(0) alignment(64) offset(%c0) : !flow.dispatch.tensor<readonly:tensor<64x256xf32>>
-    %1 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(1) alignment(64) offset(%c0) : !flow.dispatch.tensor<readonly:tensor<64x256xf32>>
-    %2 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(2) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<64x256xf32>>
-    %3 = flow.dispatch.tensor.load %0, offsets = [%c0, %c0], sizes = [64, 256], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<64x256xf32>> -> tensor<64x256xf32>
-    %4 = flow.dispatch.tensor.load %1, offsets = [%c0, %c0], sizes = [64, 256], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<64x256xf32>> -> tensor<64x256xf32>
-    %5 = flow.dispatch.tensor.load %2, offsets = [%c0, %c0], sizes = [64, 256], strides = [1, 1] : !flow.dispatch.tensor<writeonly:tensor<64x256xf32>> -> tensor<64x256xf32>
+  func.func @add_tensor(%3: tensor<64x256xf32>, %4: tensor<64x256xf32>, %5: tensor<64x256xf32>) -> tensor<64x256xf32> {
     %6 = linalg.generic {
       indexing_maps = [#map, #map, #map],
       iterator_types = ["parallel", "parallel"]
@@ -28,8 +14,7 @@
       %7 = arith.addf %in, %in_0 : f32
       linalg.yield %7 : f32
     } -> tensor<64x256xf32>
-    flow.dispatch.tensor.store %6, %2, offsets = [%c0, %c0], sizes = [64, 256], strides = [1, 1] : tensor<64x256xf32> -> !flow.dispatch.tensor<writeonly:tensor<64x256xf32>>
-    return
+    return %6 : tensor<64x256xf32>
   }
 }
 
@@ -51,24 +36,10 @@
 
 // -----
 
-#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
-  #hal.descriptor_set.layout<0, bindings = [
-    #hal.descriptor_set.binding<0, storage_buffer>,
-    #hal.descriptor_set.binding<1, storage_buffer>,
-    #hal.descriptor_set.binding<2, storage_buffer>
-  ]>
-]>
 #config = #iree_gpu.lowering_config<{thread = [0, 16]}>
 #map = affine_map<(d0, d1) -> (d0, d1)>
 module {
-  func.func @sequential_forall_mappings() {
-    %c0 = arith.constant 0 : index
-    %0 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(0) alignment(64) offset(%c0) : !flow.dispatch.tensor<readonly:tensor<4x256xf32>>
-    %1 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(1) alignment(64) offset(%c0) : !flow.dispatch.tensor<readonly:tensor<4x256xf32>>
-    %2 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(2) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<4x256xf32>>
-    %3 = flow.dispatch.tensor.load %0, offsets = [%c0, %c0], sizes = [4, 256], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<4x256xf32>> -> tensor<4x256xf32>
-    %4 = flow.dispatch.tensor.load %1, offsets = [%c0, %c0], sizes = [4, 256], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<4x256xf32>> -> tensor<4x256xf32>
-    %5 = flow.dispatch.tensor.load %2, offsets = [%c0, %c0], sizes = [4, 256], strides = [1, 1] : !flow.dispatch.tensor<writeonly:tensor<4x256xf32>> -> tensor<4x256xf32>
+  func.func @sequential_forall_mappings(%3: tensor<4x256xf32>, %4: tensor<4x256xf32>, %5: tensor<4x256xf32>) -> tensor<4x256xf32> {
     %6 = linalg.generic {
       indexing_maps = [#map, #map, #map],
       iterator_types = ["parallel", "parallel"]
@@ -77,8 +48,7 @@
       %7 = arith.addf %in, %in_0 : f32
       linalg.yield %7 : f32
     } -> tensor<4x256xf32>
-    flow.dispatch.tensor.store %6, %2, offsets = [%c0, %c0], sizes = [4, 256], strides = [1, 1] : tensor<4x256xf32> -> !flow.dispatch.tensor<writeonly:tensor<4x256xf32>>
-    return
+    return %6 : tensor<4x256xf32>
   }
 }
 
@@ -94,28 +64,11 @@
 
 // -----
 
-#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
-  #hal.descriptor_set.layout<0, bindings = [
-    #hal.descriptor_set.binding<0, storage_buffer>,
-    #hal.descriptor_set.binding<1, storage_buffer>,
-    #hal.descriptor_set.binding<2, storage_buffer>
-  ]>
-]>
-func.func @matmul_transpose_b() attributes {translation_info = #iree_codegen.translation_info<LLVMGPUVectorize workgroup_size = [128, 2, 1] subgroup_size = 64>} {
+func.func @matmul_transpose_b(%5: tensor<64x64xf32>, %6: tensor<64x1280xf16>, %7: tensor<64x1280xf16>) -> tensor<64x64xf32> {
   %c4 = arith.constant 4 : index
   %c1280 = arith.constant 1280 : index
   %cst = arith.constant 0.000000e+00 : f16
   %c0 = arith.constant 0 : index
-  %0 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<2048x1280xf16>>
-  %1 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(1) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<10240x1280xf16>>
-  %2 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(2) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<2048x10240xf32>>
-  %workgroup_id_y = hal.interface.workgroup.id[1] : index
-  %3 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%workgroup_id_y]
-  %workgroup_id_x = hal.interface.workgroup.id[0] : index
-  %4 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%workgroup_id_x]
-  %5 = flow.dispatch.tensor.load %2, offsets = [%3, %4], sizes = [64, 64], strides = [1, 1] : !flow.dispatch.tensor<writeonly:tensor<2048x10240xf32>> -> tensor<64x64xf32>
-  %6 = flow.dispatch.tensor.load %0, offsets = [%3, 0], sizes = [64, 1280], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<2048x1280xf16>> -> tensor<64x1280xf16>
-  %7 = flow.dispatch.tensor.load %1, offsets = [%4, 0], sizes = [64, 1280], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<10240x1280xf16>> -> tensor<64x1280xf16>
   %8 = linalg.fill ins(%cst : f16) outs(%5 : tensor<64x64xf32>) -> tensor<64x64xf32>
   %9 = tensor.empty() : tensor<64x1280xf16>
   %10 = tensor.empty() : tensor<64x1280xf16>
@@ -129,8 +82,7 @@
     %14 = linalg.matmul_transpose_b {lowering_config = #iree_gpu.lowering_config<{thread = [4, 4]}>} ins(%12, %13 : tensor<64x4xf16>, tensor<64x4xf16>) outs(%arg1 : tensor<64x64xf32>) -> tensor<64x64xf32>
     scf.yield %14 : tensor<64x64xf32>
   }
-  flow.dispatch.tensor.store %11, %2, offsets = [%3, %4], sizes = [64, 64], strides = [1, 1] : tensor<64x64xf32> -> !flow.dispatch.tensor<writeonly:tensor<2048x10240xf32>>
-  return
+  return %11 : tensor<64x64xf32>
 }
 
 // CHECK-LABEL: func.func @matmul_transpose_b
@@ -148,22 +100,12 @@
 
 // -----
 
-#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
-  #hal.descriptor_set.layout<0, bindings = [
-    #hal.descriptor_set.binding<0, storage_buffer>,
-    #hal.descriptor_set.binding<1, storage_buffer>
-  ]>
-]>
 #config = #iree_gpu.lowering_config<{reduction = [0, 8]}>
 #map = affine_map<()[s0] -> (s0 * 64)>
 #map1 = affine_map<(d0, d1) -> (d0, d1)>
 #map2 = affine_map<(d0, d1) -> (d0)>
-func.func @reduction() {
-  %c0 = arith.constant 0 : index
+func.func @reduction(%3: tensor<128x384xf32>) -> tensor<128xf32> {
   %cst = arith.constant 0.000000e+00 : f32
-  %0 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(0) alignment(64) offset(%c0) : !flow.dispatch.tensor<readonly:tensor<128x384xf32>>
-  %1 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(1) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<128xf32>>
-  %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [128, 384], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<128x384xf32>> -> tensor<128x384xf32>
   %empty = tensor.empty() : tensor<128xf32>
   %4 = linalg.fill ins(%cst : f32) outs(%empty : tensor<128xf32>) -> tensor<128xf32>
   %5 = linalg.generic {
@@ -174,8 +116,7 @@
     %7 = arith.addf %in, %out : f32
     linalg.yield %7 : f32
   } -> tensor<128xf32>
-  flow.dispatch.tensor.store %5, %1, offsets = [%c0], sizes = [128], strides = [1] : tensor<128xf32> -> !flow.dispatch.tensor<writeonly:tensor<128xf32>>
-  return
+  return %5 : tensor<128xf32>
 }
 
 // CHECK-LABEL: func.func @reduction
@@ -190,24 +131,10 @@
 
 // -----
 
-#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
-  #hal.descriptor_set.layout<0, bindings = [
-    #hal.descriptor_set.binding<0, storage_buffer>,
-    #hal.descriptor_set.binding<1, storage_buffer>,
-    #hal.descriptor_set.binding<2, storage_buffer>
-  ]>
-]>
 #config = #iree_gpu.lowering_config<{reduction = [0, 0, 8]}>
 #map = affine_map<(d0, d1) -> (d0, d1)>
-func.func @matmul_fuse() {
-  %c0 = arith.constant 0 : index
+func.func @matmul_fuse(%3: tensor<64x64xf32>, %4: tensor<64x64xf32>, %5: tensor<64x64xf32>) -> tensor<64x64xf32> {
   %cst = arith.constant 1.0 : f32
-  %0 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(0) alignment(64) offset(%c0) : !flow.dispatch.tensor<readonly:tensor<64x64xf32>>
-  %1 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(1) alignment(64) offset(%c0) : !flow.dispatch.tensor<readonly:tensor<64x64xf32>>
-  %2 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(2) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<64x64xf32>>
-  %3 = flow.dispatch.tensor.load %0, offsets = [%c0, %c0], sizes = [64, 64], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<64x64xf32>> -> tensor<64x64xf32>
-  %4 = flow.dispatch.tensor.load %1, offsets = [%c0, %c0], sizes = [64, 64], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<64x64xf32>> -> tensor<64x64xf32>
-  %5 = flow.dispatch.tensor.load %2, offsets = [%c0, %c0], sizes = [64, 64], strides = [1, 1] : !flow.dispatch.tensor<writeonly:tensor<64x64xf32>> -> tensor<64x64xf32>
   %empty = tensor.empty() : tensor<64x64xf32>
   %6 = linalg.generic {
     indexing_maps = [#map, #map],
@@ -218,8 +145,7 @@
     linalg.yield %8 : f32
   } -> tensor<64x64xf32>
   %7 = linalg.matmul {lowering_config = #config} ins(%6, %4 : tensor<64x64xf32>, tensor<64x64xf32>) outs(%5 : tensor<64x64xf32>) -> tensor<64x64xf32>
-  flow.dispatch.tensor.store %7, %2, offsets = [%c0, %c0], sizes = [64, 64], strides = [1, 1] : tensor<64x64xf32> -> !flow.dispatch.tensor<writeonly:tensor<64x64xf32>>
-  return
+  return %7 : tensor<64x64xf32>
 }
 
 // CHECK-LABEL: func.func @matmul_fuse
@@ -229,39 +155,23 @@
 
 // -----
 
-#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
-  #hal.descriptor_set.layout<0, bindings = [
-    #hal.descriptor_set.binding<0, storage_buffer>,
-    #hal.descriptor_set.binding<1, storage_buffer>,
-    #hal.descriptor_set.binding<2, storage_buffer>
-  ]>
-]>
 #config = #iree_gpu.lowering_config<{thread = [8, 8]}>
-func.func @matmul_cleanup() {
+func.func @matmul_cleanup(%3: tensor<64x64xf32>, %4: tensor<64x64xf32>, %5: tensor<64x64xf32>) -> tensor<64x64xf32> {
   %c8 = arith.constant 8 : index
   %c64 = arith.constant 64 : index
   %c0 = arith.constant 0 : index
-  %0 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(0) alignment(64) offset(%c0) : !flow.dispatch.tensor<readonly:tensor<64x64xf32>>
-  %1 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(1) alignment(64) offset(%c0) : !flow.dispatch.tensor<readonly:tensor<64x64xf32>>
-  %2 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(2) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<64x64xf32>>
-  %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [64, 64], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<64x64xf32>> -> tensor<64x64xf32>
-  %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [64, 64], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<64x64xf32>> -> tensor<64x64xf32>
-  %5 = flow.dispatch.tensor.load %2, offsets = [0, 0], sizes = [64, 64], strides = [1, 1] : !flow.dispatch.tensor<writeonly:tensor<64x64xf32>> -> tensor<64x64xf32>
   %6 = scf.for %arg0 = %c0 to %c64 step %c8 iter_args(%arg1 = %5) -> (tensor<64x64xf32>) {
     %extracted_slice = tensor.extract_slice %3[0, %arg0] [64, 8] [1, 1] : tensor<64x64xf32> to tensor<64x8xf32>
     %extracted_slice_0 = tensor.extract_slice %4[%arg0, 0] [8, 64] [1, 1] : tensor<64x64xf32> to tensor<8x64xf32>
     %7 = linalg.matmul {lowering_config = #config} ins(%extracted_slice, %extracted_slice_0 : tensor<64x8xf32>, tensor<8x64xf32>) outs(%arg1 : tensor<64x64xf32>) -> tensor<64x64xf32>
     scf.yield %7 : tensor<64x64xf32>
   }
-  flow.dispatch.tensor.store %6, %2, offsets = [0, 0], sizes = [64, 64], strides = [1, 1] : tensor<64x64xf32> -> !flow.dispatch.tensor<writeonly:tensor<64x64xf32>>
-  return
+  return %6 : tensor<64x64xf32>
 }
 
 // THREAD-LABEL: func.func @matmul_cleanup
-//       THREAD:   %[[B0:.+]] = hal.interface.binding.subspan layout({{.+}}) set(0) binding(0)
-//       THREAD:   %[[B1:.+]] = hal.interface.binding.subspan layout({{.+}}) set(0) binding(1)
-//       THREAD:   %[[A:.+]] = flow.dispatch.tensor.load %[[B0]]
-//       THREAD:   %[[B:.+]] = flow.dispatch.tensor.load %[[B1]]
+//  THREAD-SAME:   %[[A:[A-Za-z0-9]+]]: tensor<64x64xf32>
+//  THREAD-SAME:   %[[B:[A-Za-z0-9]+]]: tensor<64x64xf32>
 //       THREAD:   scf.for %{{.*}} = %c0 to %c64 step %c8
 //       THREAD:     scf.forall
 //   THREAD-DAG:       %[[LHS:.+]] = tensor.extract_slice %[[A]]
@@ -270,25 +180,13 @@
 
 // -----
 
-#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
-  #hal.descriptor_set.layout<0, bindings = [
-    #hal.descriptor_set.binding<0, storage_buffer>,
-    #hal.descriptor_set.binding<1, storage_buffer>,
-    #hal.descriptor_set.binding<2, storage_buffer>
-  ]>
-]>
 #config = #iree_gpu.derived_thread_config
 #map = affine_map<(d0, d1) -> (d0, d1)>
 module {
-  func.func @inferred_add_tensor()
-      attributes {translation_info = #iree_codegen.translation_info<LLVMGPUVectorize workgroup_size = [16, 32, 1] subgroup_size = 64, {}>} {
-    %c0 = arith.constant 0 : index
-    %0 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(0) alignment(64) offset(%c0) : !flow.dispatch.tensor<readonly:tensor<64x256xf32>>
-    %1 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(1) alignment(64) offset(%c0) : !flow.dispatch.tensor<readonly:tensor<64x256xf32>>
-    %2 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(2) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<64x256xf32>>
-    %3 = flow.dispatch.tensor.load %0, offsets = [%c0, %c0], sizes = [64, 256], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<64x256xf32>> -> tensor<64x256xf32>
-    %4 = flow.dispatch.tensor.load %1, offsets = [%c0, %c0], sizes = [64, 256], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<64x256xf32>> -> tensor<64x256xf32>
-    %5 = flow.dispatch.tensor.load %2, offsets = [%c0, %c0], sizes = [64, 256], strides = [1, 1] : !flow.dispatch.tensor<writeonly:tensor<64x256xf32>> -> tensor<64x256xf32>
+  func.func @inferred_add_tensor(%3: tensor<64x256xf32>, %4: tensor<64x256xf32>, %5: tensor<64x256xf32>) -> tensor<64x256xf32>
+      attributes {
+        translation_info = #iree_codegen.translation_info<LLVMGPUTileAndFuse workgroup_size = [16, 32, 1] subgroup_size = 64, {}>
+      } {
     %6 = linalg.generic {
       indexing_maps = [#map, #map, #map],
       iterator_types = ["parallel", "parallel"]
@@ -297,8 +195,7 @@
       %7 = arith.addf %in, %in_0 : f32
       linalg.yield %7 : f32
     } -> tensor<64x256xf32>
-    flow.dispatch.tensor.store %6, %2, offsets = [%c0, %c0], sizes = [64, 256], strides = [1, 1] : tensor<64x256xf32> -> !flow.dispatch.tensor<writeonly:tensor<64x256xf32>>
-    return
+    return %6 : tensor<64x256xf32>
   }
 }
 
@@ -314,6 +211,63 @@
 
 // -----
 
+#config = #iree_gpu.derived_thread_config
+#map = affine_map<(d0, d1) -> (d0, d1)>
+module {
+  func.func @inferred_dynamic(%3: tensor<?x?xf32>, %4: tensor<?x?xf32>, %5: tensor<?x?xf32>) -> tensor<?x?xf32>
+      attributes {
+        translation_info = #iree_codegen.translation_info<LLVMGPUTileAndFuse workgroup_size = [16, 32, 1] subgroup_size = 64, {}>
+      } {
+    %6 = linalg.generic {
+      indexing_maps = [#map, #map, #map],
+      iterator_types = ["parallel", "parallel"]
+      } ins(%3, %4 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%5 : tensor<?x?xf32>) attrs =  {lowering_config = #config} {
+    ^bb0(%in: f32, %in_0: f32, %out: f32):
+      %7 = arith.addf %in, %in_0 : f32
+      linalg.yield %7 : f32
+    } -> tensor<?x?xf32>
+    return %6 : tensor<?x?xf32>
+  }
+}
+
+// THREAD-LABEL: func.func @inferred_dynamic
+//  THREAD-SAME:   %[[A:[A-Za-z0-9]+]]: tensor<?x?xf32>
+//   THREAD-DAG:   %[[DIM0:.+]] = tensor.dim %[[A]], %c0 : tensor<?x?xf32>
+//   THREAD-DAG:   %[[DIM1:.+]] = tensor.dim %[[A]], %c1 : tensor<?x?xf32>
+//       THREAD:   scf.forall ({{.*}}) = (0, 0) to (%[[DIM0]], %[[DIM1]]) step (1, 4)
+//       THREAD:     linalg.generic {{.*}} ins(%{{.*}}: tensor<1x?xf32>, tensor<1x?xf32>)
+//       THREAD:     scf.forall.in_parallel
+//       THREAD:   mapping = [#gpu.thread<linear_dim_1>, #gpu.thread<linear_dim_0>]
+
+// -----
+
+#config = #iree_gpu.derived_thread_config
+#map = affine_map<(d0, d1) -> (d0, d1)>
+module {
+  func.func @inferred_small_inner_dim(%3: tensor<8x2xf32>, %4: tensor<8x2xf32>, %5: tensor<8x2xf32>) -> tensor<8x2xf32>
+      attributes {
+        translation_info = #iree_codegen.translation_info<LLVMGPUTileAndFuse workgroup_size = [16, 32, 1] subgroup_size = 64, {}>
+      } {
+    %6 = linalg.generic {
+      indexing_maps = [#map, #map, #map],
+      iterator_types = ["parallel", "parallel"]
+      } ins(%3, %4 : tensor<8x2xf32>, tensor<8x2xf32>) outs(%5 : tensor<8x2xf32>) attrs =  {lowering_config = #config} {
+    ^bb0(%in: f32, %in_0: f32, %out: f32):
+      %7 = arith.addf %in, %in_0 : f32
+      linalg.yield %7 : f32
+    } -> tensor<8x2xf32>
+    return %6 : tensor<8x2xf32>
+  }
+}
+
+// THREAD-LABEL: func.func @inferred_small_inner_dim
+//       THREAD:   scf.forall ({{.*}}) = (0, 0) to (8, 2) step (1, 2)
+//       THREAD:     linalg.generic {{.*}} ins(%{{.*}}: tensor<1x2xf32>, tensor<1x2xf32>)
+//       THREAD:     scf.forall.in_parallel
+//       THREAD:   mapping = [#gpu.thread<linear_dim_1>, #gpu.thread<linear_dim_0>]
+
+// -----
+
 #pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
   #hal.descriptor_set.layout<0, bindings = [
     #hal.descriptor_set.binding<0, storage_buffer>,
@@ -323,16 +277,16 @@
 ]>
 #config = #iree_gpu.derived_thread_config
 module {
-  func.func @inferred_im2col()
-      attributes {translation_info = #iree_codegen.translation_info<LLVMGPUVectorize workgroup_size = [16, 32, 1] subgroup_size = 64, {}>} {
-    %c0 = arith.constant 0 : index
-    %0 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(0) alignment(64) offset(%c0) : !flow.dispatch.tensor<readonly:tensor<2x34x34x128xf16>>
-    %1 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(2) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<2x128x8xf16>>
-    %2 = flow.dispatch.tensor.load %0, offsets = [%c0, %c0, %c0, %c0], sizes = [2, 34, 34, 128], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<2x34x34x128xf16>> -> tensor<2x34x34x128xf16>
-    %3 = flow.dispatch.tensor.load %1, offsets = [%c0, %c0, %c0], sizes = [2, 128, 8], strides = [1, 1, 1] : !flow.dispatch.tensor<writeonly:tensor<2x128x8xf16>> -> tensor<2x128x8xf16>
-    %4 = iree_linalg_ext.im2col {lowering_config = #config} strides = [1, 1] dilations = [1, 1] kernel_size = [3, 3] m_offset = [0] k_offset = [0] batch_pos = [0] m_pos = [2, 3] k_pos = [1] ins(%2 : tensor<2x34x34x128xf16>) outs(%3 : tensor<2x128x8xf16>) -> tensor<2x128x8xf16>
-    flow.dispatch.tensor.store %4, %1, offsets = [%c0, %c0, %c0], sizes = [2, 128, 8], strides = [1, 1, 1] : tensor<2x128x8xf16> -> !flow.dispatch.tensor<writeonly:tensor<2x128x8xf16>>
-    return
+  func.func @inferred_im2col(%2: tensor<2x34x34x128xf16>, %3: tensor<2x128x8xf16>) -> tensor<2x128x8xf16>
+      attributes {
+        translation_info = #iree_codegen.translation_info<LLVMGPUTileAndFuse workgroup_size = [16, 32, 1] subgroup_size = 64, {}>
+      } {
+    %4 = iree_linalg_ext.im2col {lowering_config = #config}
+      strides = [1, 1] dilations = [1, 1] kernel_size = [3, 3]
+      m_offset = [0] k_offset = [0] batch_pos = [0] m_pos = [2, 3] k_pos = [1]
+      ins(%2 : tensor<2x34x34x128xf16>)
+      outs(%3 : tensor<2x128x8xf16>) -> tensor<2x128x8xf16>
+    return %4 : tensor<2x128x8xf16>
   }
 }
 
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/DerivedConfigUtils.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/DerivedConfigUtils.cpp
index 64af9dd..72dd92c 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/DerivedConfigUtils.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/DerivedConfigUtils.cpp
@@ -11,28 +11,41 @@
 #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
 #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
 #include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/IR/TypeUtilities.h"
 
 namespace mlir::iree_compiler::IREE::GPU {
 
 static constexpr int64_t kPreferredCopyNumBits = 128;
 
-SmallVector<int64_t>
+// Helper to construct a list of tile sizes that simply uses the given vector
+// size or the innerDimSize as the innermost tile size, whichever is smaller.
+// All other dims are tiled to 1.
+static SmallVector<int64_t>
+getVectorSizeTileSizes(int64_t rank, int64_t innerDimSize, int64_t vectorSize) {
+  SmallVector<int64_t> tileSizes(rank, 1);
+  if (ShapedType::isDynamic(innerDimSize) || innerDimSize >= vectorSize) {
+    tileSizes.back() = vectorSize;
+  } else {
+    tileSizes.back() = innerDimSize;
+  }
+  return tileSizes;
+}
+
+static SmallVector<int64_t>
 getThreadTileSizesFromLoopRanges(SmallVector<int64_t> loopRanges,
                                  int64_t numThreads, int64_t vectorSize) {
-  // TODO: We shouldn't need this check, however loop fusion currently requires
-  // loop trip counts to be identical, meaning we need to use a num_threads
-  // variant of tiling. Remove this and simply return the preferred vector size
-  // once loop fusion can resolve the forall properly.
   if (llvm::any_of(loopRanges,
                    [](int64_t s) { return ShapedType::isDynamic(s); })) {
-    return {};
+    return getVectorSizeTileSizes(loopRanges.size(), loopRanges.back(),
+                                  vectorSize);
   }
 
   int64_t flatNumTrips = std::accumulate(loopRanges.begin(), loopRanges.end(),
                                          1, std::multiplies<int64_t>());
   if (flatNumTrips % numThreads != 0) {
-    return {};
+    return getVectorSizeTileSizes(loopRanges.size(), loopRanges.back(),
+                                  vectorSize);
   }
   int64_t maxVectorSize = flatNumTrips / numThreads;
 
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir
index 9914509..a9f320c 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir
@@ -531,9 +531,9 @@
         %cst = arith.constant 0.000000e+00 : f32
         %cst_0 = arith.constant dense<1.0> : tensor<1x64xf32>
         %c0 = arith.constant 0 : index
-        %0 = hal.interface.binding.subspan layout(<push_constants = 0, sets = [<0, bindings = [<0, storage_buffer, "ReadOnly|Indirect">, <1, storage_buffer, ReadOnly>, <2, storage_buffer, Indirect>], flags = Indirect>]>) set(0) binding(0) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : !flow.dispatch.tensor<readonly:tensor<1x64x58x58xf32>>
-        %1 = hal.interface.binding.subspan layout(<push_constants = 0, sets = [<0, bindings = [<0, storage_buffer, "ReadOnly|Indirect">, <1, storage_buffer, ReadOnly>, <2, storage_buffer, Indirect>], flags = Indirect>]>) set(0) binding(1) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<64x64x3x3xf32>>
-        %2 = hal.interface.binding.subspan layout(<push_constants = 0, sets = [<0, bindings = [<0, storage_buffer, "ReadOnly|Indirect">, <1, storage_buffer, ReadOnly>, <2, storage_buffer, Indirect>], flags = Indirect>]>) set(0) binding(2) alignment(64) offset(%c0) flags(Indirect) : !flow.dispatch.tensor<writeonly:tensor<1x64x56x56xf32>>
+        %0 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(0) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : !flow.dispatch.tensor<readonly:tensor<1x64x58x58xf32>>
+        %1 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(1) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<64x64x3x3xf32>>
+        %2 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(2) alignment(64) offset(%c0) flags(Indirect) : !flow.dispatch.tensor<writeonly:tensor<1x64x56x56xf32>>
         %3 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0, 0], sizes = [1, 64, 58, 58], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<1x64x58x58xf32>> -> tensor<1x64x58x58xf32>
         %4 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0, 0], sizes = [64, 64, 3, 3], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<64x64x3x3xf32>> -> tensor<64x64x3x3xf32>
         %5 = tensor.empty() : tensor<1x64x56x56xf32>
@@ -562,3 +562,77 @@
 //       CHECK:   arith.addf
 //       CHECK:   arith.cmpf
 //       CHECK:   arith.select
+
+// -----
+
+#lowering_config = #iree_gpu.lowering_config<{
+  reduction = [0 : index, 0 : index, 4 : index],
+  thread = [1 : index, 4 : index, 0 : index],
+  workgroup = [4 : index, 32 : index, 0 : index]
+}>
+
+#translation_info = #iree_codegen.translation_info<LLVMGPUTileAndFuse workgroup_size = [8, 4, 1] subgroup_size = 32>
+
+#pipeline_layout = #hal.pipeline.layout<
+  push_constants = 0,
+  sets = [
+    <0, bindings = [
+      <0, storage_buffer, ReadOnly>,
+      <1, storage_buffer, "ReadOnly|Indirect">,
+      <2, storage_buffer, Indirect>
+    ], flags = Indirect>
+  ]>
+
+hal.executable public @main {
+  hal.executable.variant public @rocm_hsaco_fb target(<"rocm", "rocm-hsaco-fb">) {
+    hal.executable.export public @skinny_matmul_config ordinal(0) layout(#pipeline_layout) attributes {hal.interface.bindings = [#hal.interface.binding<0, 0>, #hal.interface.binding<0, 1>, #hal.interface.binding<0, 2>]} {
+    ^bb0(%arg0: !hal.device):
+      %x, %y, %z = flow.dispatch.workgroup_count_from_slice
+      hal.return %x, %y, %z : index, index, index
+    }
+    builtin.module {
+      func.func @skinny_matmul_config() attributes {translation_info = #translation_info} {
+        %cst = arith.constant 0.000000e+00 : f32
+        %c102227904 = arith.constant 102227904 : index
+        %c111444672 = arith.constant 111444672 : index
+        %c4014080 = arith.constant 4014080 : index
+        %c0 = arith.constant 0 : index
+        %0 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(0) alignment(64) offset(%c102227904) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<128x256xf32>>
+        %1 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(1) alignment(64) offset(%c4014080) flags("ReadOnly|Indirect") : !flow.dispatch.tensor<readonly:tensor<256x3136xf32>>
+        %2 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(0) alignment(64) offset(%c111444672) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<128xf32>>
+        %3 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(2) alignment(64) offset(%c0) flags(Indirect) : !flow.dispatch.tensor<writeonly:tensor<128x3136xf32>>
+        %4 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [128, 256], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<128x256xf32>> -> tensor<128x256xf32>
+        %5 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [256, 3136], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<256x3136xf32>> -> tensor<256x3136xf32>
+        %6 = flow.dispatch.tensor.load %2, offsets = [0], sizes = [128], strides = [1] : !flow.dispatch.tensor<readonly:tensor<128xf32>> -> tensor<128xf32>
+        %7 = tensor.empty() : tensor<128x3136xf32>
+        %8 = linalg.fill ins(%cst : f32) outs(%7 : tensor<128x3136xf32>) -> tensor<128x3136xf32>
+        %9 = linalg.matmul {lowering_config = #lowering_config} ins(%4, %5 : tensor<128x256xf32>, tensor<256x3136xf32>) outs(%8 : tensor<128x3136xf32>) -> tensor<128x3136xf32>
+        %10 = linalg.generic {
+          indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>],
+          iterator_types = ["parallel", "parallel"]}
+          ins(%9, %6 : tensor<128x3136xf32>, tensor<128xf32>) outs(%7 : tensor<128x3136xf32>) {
+        ^bb0(%in: f32, %in_0: f32, %out: f32):
+          %11 = arith.addf %in, %in_0 : f32
+          %12 = arith.cmpf ugt, %11, %cst : f32
+          %13 = arith.select %12, %11, %cst : f32
+          linalg.yield %13 : f32
+        } -> tensor<128x3136xf32>
+        flow.dispatch.tensor.store %10, %3, offsets = [0, 0], sizes = [128, 3136], strides = [1, 1] : tensor<128x3136xf32> -> !flow.dispatch.tensor<writeonly:tensor<128x3136xf32>>
+        return
+      }
+    }
+  }
+}
+
+// CHECK: #[[$MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1 * 8)>
+
+// CHECK-LABEL: func @skinny_matmul_config
+
+//   CHECK-DAG:   %[[IDX:.+]] = gpu.thread_id  x
+//   CHECK-DAG:   %[[IDY:.+]] = gpu.thread_id  y
+//       CHECK:   %[[LINID:.+]] = affine.apply #[[$MAP]]()[%[[IDX]], %[[IDY]]]
+//       CHECK:   scf.for %{{.*}} = %c0 to %c256 step %c4 {{.*}} -> (vector<1x4xf32>)
+//       CHECK:     scf.for %{{.*}} = %[[LINID]] to %c4 step %c32
+//       CHECK:       %[[READ:.+]] = vector.transfer_read {{.*}} : memref<128x256xf32, {{.*}}storage_buffer>>, vector<4xf32>
+//       CHECK:       vector.transfer_write %[[READ]], %{{.*}} : vector<4xf32>, memref<4x6xf32, #gpu.address_space<workgroup>>
+//       CHECK:     vector.contract