[spirv] Migrate tests for LinalgTileAndDistributePass (#5610)

Rename the LinalgToSPIRV vectorization tests so their names reflect the passes
they exercise: batch_matmul_vectorization.mlir becomes
tile_and_vectorize_batch_matmul.mlir, elementwise_vectorization.mlir becomes
vectorize_elementwise_ops.mlir, and matmul_vectorization.mlir becomes
vectorize_matmul.mlir. The BUILD and CMakeLists.txt source lists are updated to
match. The RUN lines no longer go through
iree-codegen-spirv-linalg-tile-and-distribute: the tile_and_vectorize_* tests
feed already-distributed IR (hal.interface workgroup ids, counts, and sizes
plus scf.for loops) through iree-spirv-concretize-tile-among-workgroups and
iree-spirv-tile-and-vectorize-in-one-workgroup, while the vectorize_* tests run
only the latter pass. matmul_fused_vectorization.mlir is deleted; its coverage
moves into tile_and_vectorize_matmul.mlir as @matmul_static_shape_f32.
KernelDispatchUtils.cpp now derives the matmul M, N, and K dimensions through
getInputOutputTypes instead of casting op.inputs() directly.
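
For reference, the renamed batch matmul test can be exercised outside of lit
with the pipeline from its updated RUN line. This is a sketch only: it assumes
iree-opt and IreeFileCheck are built and on PATH and that the command is run
from the repository root.

  iree-opt -split-input-file \
    -pass-pipeline="hal.executable(hal.executable.target(iree-spirv-concretize-tile-among-workgroups,iree-spirv-tile-and-vectorize-in-one-workgroup))" \
    -canonicalize -cse \
    -iree-spirv-workgroup-tile-size=1,8,64,4 -iree-spirv-invocation-tile-size=1,8,4,4 \
    -iree-spirv-workgroup-size=16,1,1 \
    iree/compiler/Conversion/LinalgToSPIRV/test/tile_and_vectorize_batch_matmul.mlir \
  | IreeFileCheck iree/compiler/Conversion/LinalgToSPIRV/test/tile_and_vectorize_batch_matmul.mlir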

diff --git a/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp b/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp
index c2f5d19..5355724 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp
@@ -480,14 +480,18 @@
     // memory available (maybe). For now, just hard-wire it.
     tileSizeK = 32;
   }
-  assert(tileSizes.empty());
-  int64_t M = op.inputs()[0].getType().cast<ShapedType>().getShape()[0];
-  int64_t N = op.inputs()[1].getType().cast<ShapedType>().getShape()[1];
-  int64_t K = op.inputs()[0].getType().cast<ShapedType>().getShape()[1];
+
+  SmallVector<ShapedType> inputTypes;
+  std::tie(inputTypes, std::ignore) = getInputOutputTypes(op);
+  int64_t M = inputTypes[0].getShape()[0];
+  int64_t N = inputTypes[1].getShape()[1];
+  int64_t K = inputTypes[0].getShape()[1];
+
   SmallVector<int64_t, 4> ts = {
       getMinIfShapeStatic(M, nRowsPerWorkitem * config.workgroupSize[1]),
       getMinIfShapeStatic(N, nColsPerWorkitem * config.workgroupSize[0]),
       getMinIfShapeStatic(K, tileSizeK)};
+  assert(tileSizes.empty());
   tileSizes.emplace_back(std::move(ts));
   return success();
 }
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/BUILD b/iree/compiler/Conversion/LinalgToSPIRV/test/BUILD
index 2b3f9a6..f6c84b2 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/BUILD
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/BUILD
@@ -27,28 +27,27 @@
     name = "lit",
     srcs = enforce_glob(
         [
-            "batch_matmul_vectorization.mlir",
             "concretize_tile_among_workgroups.mlir",
             "concretize_tile_among_workgroups_dynamic.mlir",
             "convert_to_gpu.mlir",
             "convert_to_spirv.mlir",
             "dead_alloc.mlir",
-            "elementwise_vectorization.mlir",
             "fold-gpu-procid-uses.mlir",
             "forop_canonicalization.mlir",
             "materialize_launch_configuration.mlir",
             "materialize_launch_configuration2.mlir",
-            "matmul_fused_vectorization.mlir",
-            "matmul_vectorization.mlir",
             "matmul_vectorization_licm.mlir",
-            "vectorize_memref_load_store.mlir",
             "pipeline_matmul_vectorization.mlir",
             "pipeline_test_cooperative_mat.mlir",
             "promote_workgroup_memory.mlir",
             "split_dispatch_function.mlir",
+            "tile_and_vectorize_batch_matmul.mlir",
             "tile_and_vectorize_conv.mlir",
             "tile_and_vectorize_matmul.mlir",
             "vector_to_gpu.mlir",
+            "vectorize_elementwise_ops.mlir",
+            "vectorize_matmul.mlir",
+            "vectorize_memref_load_store.mlir",
         ],
         include = ["*.mlir"],
     ),
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/CMakeLists.txt b/iree/compiler/Conversion/LinalgToSPIRV/test/CMakeLists.txt
index a00bf0a..45673f5 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/CMakeLists.txt
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/CMakeLists.txt
@@ -14,27 +14,26 @@
   NAME
     lit
   SRCS
-    "batch_matmul_vectorization.mlir"
     "concretize_tile_among_workgroups.mlir"
     "concretize_tile_among_workgroups_dynamic.mlir"
     "convert_to_gpu.mlir"
     "convert_to_spirv.mlir"
     "dead_alloc.mlir"
-    "elementwise_vectorization.mlir"
     "fold-gpu-procid-uses.mlir"
     "forop_canonicalization.mlir"
     "materialize_launch_configuration.mlir"
     "materialize_launch_configuration2.mlir"
-    "matmul_fused_vectorization.mlir"
-    "matmul_vectorization.mlir"
     "matmul_vectorization_licm.mlir"
     "pipeline_matmul_vectorization.mlir"
     "pipeline_test_cooperative_mat.mlir"
     "promote_workgroup_memory.mlir"
     "split_dispatch_function.mlir"
+    "tile_and_vectorize_batch_matmul.mlir"
     "tile_and_vectorize_conv.mlir"
     "tile_and_vectorize_matmul.mlir"
     "vector_to_gpu.mlir"
+    "vectorize_elementwise_ops.mlir"
+    "vectorize_matmul.mlir"
     "vectorize_memref_load_store.mlir"
   DATA
     iree::tools::IreeFileCheck
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/matmul_fused_vectorization.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/matmul_fused_vectorization.mlir
deleted file mode 100644
index 458e135..0000000
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/matmul_fused_vectorization.mlir
+++ /dev/null
@@ -1,63 +0,0 @@
-// RUN: iree-opt -split-input-file -iree-spirv-workgroup-tile-size=8,64,4 -iree-spirv-invocation-tile-size=8,4,4 -iree-spirv-workgroup-size=16,1,1 -pass-pipeline="hal.executable(hal.executable.target(iree-codegen-spirv-linalg-tile-and-distribute,iree-spirv-tile-and-vectorize-in-one-workgroup,canonicalize,cse))" -iree-spirv-enable-vectorization %s | IreeFileCheck %s
-
-hal.executable @matmul_static_shape attributes {sym_visibility = "private"} {
-  hal.interface @io {
-    hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
-    hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
-    hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
-  }
-  hal.executable.target @vulkan, filter="dylib*" {
-    hal.executable.entry_point @matmul_static_shape attributes {
-      interface = @io, ordinal = 0 : index,
-      signature = (!flow.dispatch.tensor<readonly:?x?xf32>, !flow.dispatch.tensor<readonly:?x?xf32>,
-        !flow.dispatch.tensor<writeonly:?x?xf32>) -> ()}
-    module attributes {
-      spv.target_env =
-        #spv.target_env<#spv.vce<v1.3,
-          [Shader, Float64, Float16, Int64, Int16, Int8, StorageBuffer16BitAccess,
-           StorageUniform16, StoragePushConstant16, StorageBuffer8BitAccess,
-           UniformAndStorageBuffer8BitAccess, StoragePushConstant8, GroupNonUniform,
-           GroupNonUniformVote, GroupNonUniformArithmetic, GroupNonUniformBallot,
-           GroupNonUniformShuffle, GroupNonUniformShuffleRelative, VariablePointers,
-           VariablePointersStorageBuffer],
-          [SPV_KHR_16bit_storage, SPV_KHR_8bit_storage,
-           SPV_KHR_storage_buffer_storage_class, SPV_KHR_variable_pointers]>,
-          ARM:IntegratedGPU,
-          {max_compute_shared_memory_size = 32768 : i32,
-           max_compute_workgroup_invocations = 512 : i32,
-           max_compute_workgroup_size = dense<512> : vector<3xi32>,
-           subgroup_size = 16 : i32}>} {
-      func @matmul_static_shape()
-        attributes {vkspv.num_workgroups_fn = @matmul_static_shape__num_workgroups__} {
-        %arg0 = iree.placeholder for "interface buffer"
-          {binding = @io::@arg0, operand_result_num = 0 : i32} : memref<4096x4096xf32>
-        %arg1 = iree.placeholder for "interface buffer"
-          {binding = @io::@arg1, operand_result_num = 1 : i32} : memref<4096x4096xf32>
-        %ret0 = iree.placeholder for "interface buffer"
-          {binding = @io::@ret0, operand_result_num = 2 : i32} : memref<4096x4096xf32>
-        %cst = constant 0.000000e+00 : f32
-        linalg.fill(%ret0, %cst) : memref<4096x4096xf32>, f32
-        linalg.matmul ins(%arg0, %arg1 : memref<4096x4096xf32>, memref<4096x4096xf32>)
-                     outs(%ret0 : memref<4096x4096xf32>)
-        return
-      }
-      func private @matmul_static_shape__num_workgroups__
-        (!shapex.ranked_shape<[4096, 4096]>, !shapex.ranked_shape<[4096, 4096]>,
-         !shapex.ranked_shape<[4096, 4096]>) -> (index, index, index)
-      hal.interface @io attributes {sym_visibility = "private"} {
-        hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
-        hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
-        hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
-      }
-    }
-  }
-}
-//    CHECK-LABEL: func @matmul_static_shape
-//  CHECK-COUNT-8:   vector.transfer_write
-//  CHECK-COUNT-8:   vector.transfer_read
-//          CHECK:   %[[FOR_RES:.+]]:8 = scf.for
-// CHECK-COUNT-12:     vector.transfer_read
-// CHECK-COUNT-32:     vector.contract
-//      CHECK:         scf.yield
-//  CHECK-COUNT-8:    vector.transfer_write %[[FOR_RES]]
-//          CHECK:    return
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/batch_matmul_vectorization.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/tile_and_vectorize_batch_matmul.mlir
similarity index 62%
rename from iree/compiler/Conversion/LinalgToSPIRV/test/batch_matmul_vectorization.mlir
rename to iree/compiler/Conversion/LinalgToSPIRV/test/tile_and_vectorize_batch_matmul.mlir
index da4dc0a..4cdaeeb 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/batch_matmul_vectorization.mlir
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/tile_and_vectorize_batch_matmul.mlir
@@ -1,46 +1,53 @@
-// RUN: iree-opt -split-input-file -iree-spirv-workgroup-tile-size=1,8,64,4 -iree-spirv-invocation-tile-size=1,8,4,4 -iree-spirv-workgroup-size=16,1,1 -pass-pipeline="hal.executable(hal.executable.target(iree-codegen-spirv-linalg-tile-and-distribute,iree-spirv-tile-and-vectorize-in-one-workgroup,canonicalize,cse))" -iree-spirv-enable-vectorization %s | IreeFileCheck %s
+// RUN: iree-opt -split-input-file -pass-pipeline="hal.executable(hal.executable.target(iree-spirv-concretize-tile-among-workgroups,iree-spirv-tile-and-vectorize-in-one-workgroup))" -canonicalize -cse -iree-spirv-workgroup-tile-size=1,8,64,4 -iree-spirv-invocation-tile-size=1,8,4,4 -iree-spirv-workgroup-size=16,1,1 %s | IreeFileCheck %s
 
 hal.executable @batch_matmul_static_shape attributes {sym_visibility = "private"} {
-  hal.interface @io {
+  hal.interface @io attributes {sym_visibility = "private"} {
     hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
     hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
     hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
   }
-  hal.executable.target @vulkan, filter="dylib*" {
+  hal.executable.target @vulkan_spirv, filter="vulkan*" {
     hal.executable.entry_point @batch_matmul_static_shape attributes {
       interface = @io, ordinal = 0 : index,
-      signature = (!flow.dispatch.tensor<readonly:?x?xf32>, !flow.dispatch.tensor<readonly:?x?xf32>,
-        !flow.dispatch.tensor<writeonly:?x?xf32>) -> ()}
-    module attributes {
-      spv.target_env =
-        #spv.target_env<#spv.vce<v1.3,
-          [Shader, Float64, Float16, Int64, Int16, Int8, StorageBuffer16BitAccess,
-           StorageUniform16, StoragePushConstant16, StorageBuffer8BitAccess,
-           UniformAndStorageBuffer8BitAccess, StoragePushConstant8, GroupNonUniform,
-           GroupNonUniformVote, GroupNonUniformArithmetic, GroupNonUniformBallot,
-           GroupNonUniformShuffle, GroupNonUniformShuffleRelative, VariablePointers,
-           VariablePointersStorageBuffer],
-          [SPV_KHR_16bit_storage, SPV_KHR_8bit_storage,
-           SPV_KHR_storage_buffer_storage_class, SPV_KHR_variable_pointers]>,
-          ARM:IntegratedGPU,
-          {max_compute_shared_memory_size = 32768 : i32,
-           max_compute_workgroup_invocations = 512 : i32,
-           max_compute_workgroup_size = dense<512> : vector<3xi32>,
-           subgroup_size = 16 : i32}>} {
-      func @batch_matmul_static_shape()
-        attributes {vkspv.num_workgroups_fn = @matmul_static_shape__num_workgroups__} {
-        %arg0 = iree.placeholder for "interface buffer"
-          {binding = @io::@arg0, operand_result_num = 0 : i32} : memref<4x1024x1024xf32>
-        %arg1 = iree.placeholder for "interface buffer"
-          {binding = @io::@arg1, operand_result_num = 1 : i32} : memref<4x1024x1024xf32>
-        %ret0 = iree.placeholder for "interface buffer"
-          {binding = @io::@ret0, operand_result_num = 2 : i32} : memref<4x1024x1024xf32>
-        linalg.batch_matmul ins(%arg0, %arg1 : memref<4x1024x1024xf32>, memref<4x1024x1024xf32>) outs(%ret0 : memref<4x1024x1024xf32>)
+      signature = (!flow.dispatch.tensor<readonly:4x1024x1024xf32>, !flow.dispatch.tensor<readonly:4x1024x1024xf32>, !flow.dispatch.tensor<writeonly:4x1024x1024xf32>) -> ()}
+    module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], [SPV_KHR_storage_buffer_storage_class]>, ARM:IntegratedGPU, {}>}  {
+      func @batch_matmul_static_shape() {
+        %c0 = constant 0 : index
+        %c4 = constant 4 : index
+        %c1024 = constant 1024 : index
+        %0 = hal.interface.binding.subspan @io::@s0b0_ro_external[%c0] : memref<4x1024x1024xf32>
+        %1 = hal.interface.binding.subspan @io::@s0b1_ro_external[%c0] : memref<4x1024x1024xf32>
+        %2 = hal.interface.binding.subspan @io::@s0b2_xw_external[%c0] : memref<4x1024x1024xf32>
+        %workgroup_size_x = hal.interface.workgroup.size[0] : index
+        %workgroup_size_y = hal.interface.workgroup.size[1] : index
+        %workgroup_size_z = hal.interface.workgroup.size[2] : index
+        %workgroup_id_x = hal.interface.workgroup.id[0] : index
+        %workgroup_count_x = hal.interface.workgroup.count[0] : index
+        %workgroup_id_y = hal.interface.workgroup.id[1] : index
+        %workgroup_count_y = hal.interface.workgroup.count[1] : index
+        %workgroup_id_z = hal.interface.workgroup.id[2] : index
+        %workgroup_count_z = hal.interface.workgroup.count[2] : index
+        %3 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_z, %workgroup_size_z]
+        %4 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_z, %workgroup_size_z]
+        scf.for %arg0 = %3 to %c4 step %4 {
+          %5 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_y, %workgroup_size_y]
+          %6 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_y, %workgroup_size_y]
+          scf.for %arg1 = %5 to %c1024 step %6 {
+            %7 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_x, %workgroup_size_x]
+            %8 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_x, %workgroup_size_x]
+            scf.for %arg2 = %7 to %c1024 step %8 {
+              %9 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 4)>(%arg0)[%workgroup_size_z]
+              %10 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 1024)>(%arg1)[%workgroup_size_y]
+              %11 = memref.subview %0[%arg0, %arg1, 0] [%9, %10, 1024] [1, 1, 1] : memref<4x1024x1024xf32> to memref<?x?x1024xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 1048576 + s0 + d1 * 1024 + d2)>>
+              %12 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 1024)>(%arg2)[%workgroup_size_x]
+              %13 = memref.subview %1[%arg0, 0, %arg2] [%9, 1024, %12] [1, 1, 1] : memref<4x1024x1024xf32> to memref<?x1024x?xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 1048576 + s0 + d1 * 1024 + d2)>>
+              %14 = memref.subview %2[%arg0, %arg1, %arg2] [%9, %10, %12] [1, 1, 1] : memref<4x1024x1024xf32> to memref<?x?x?xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 1048576 + s0 + d1 * 1024 + d2)>>
+              linalg.batch_matmul {__internal_linalg_transform__ = "workgroup"} ins(%11, %13 : memref<?x?x1024xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 1048576 + s0 + d1 * 1024 + d2)>>, memref<?x1024x?xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 1048576 + s0 + d1 * 1024 + d2)>>) outs(%14 : memref<?x?x?xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 1048576 + s0 + d1 * 1024 + d2)>>)
+            }
+          }
+        }
         return
       }
-      func private @matmul_static_shape__num_workgroups__
-        (!shapex.ranked_shape<[4096, 4096]>, !shapex.ranked_shape<[4096, 4096]>,
-         !shapex.ranked_shape<[4096, 4096]>) -> (index, index, index)
       hal.interface @io attributes {sym_visibility = "private"} {
         hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
         hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
@@ -49,15 +56,15 @@
     }
   }
 }
+
 //  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 * 8)>
 //  CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 64)>
 //  CHECK-DAG: #[[MAP2:.+]] = affine_map<()[s0] -> (s0 * 4)>
 //      CHECK: func @batch_matmul_static_shape
-//  CHECK-DAG:  %[[ARG0:.+]] = iree.placeholder {{.*}} {binding = @io::@arg0
-//  CHECK-DAG:  %[[ARG1:.+]] = iree.placeholder {{.*}} {binding = @io::@arg1
-//  CHECK-DAG:  %[[RET0:.+]] = iree.placeholder {{.*}} {binding = @io::@ret0
+//  CHECK-DAG:  %[[ARG0:.+]] = hal.interface.binding.subspan @io::@s0b0_ro_external[%c0]
+//  CHECK-DAG:  %[[ARG1:.+]] = hal.interface.binding.subspan @io::@s0b1_ro_external[%c0]
+//  CHECK-DAG:  %[[RET0:.+]] = hal.interface.binding.subspan @io::@s0b2_xw_external[%c0]
 //  CHECK-DAG:  %[[C0:.+]] = constant 0 : index
-//  CHECK-DAG:  %[[CST:.+]] = constant 0.0
 //  CHECK-DAG:  %[[C1:.+]] = constant 1 : index
 //  CHECK-DAG:  %[[C2:.+]] = constant 2 : index
 //  CHECK-DAG:  %[[C3:.+]] = constant 3 : index
@@ -65,11 +72,15 @@
 //  CHECK-DAG:  %[[C5:.+]] = constant 5 : index
 //  CHECK-DAG:  %[[C6:.+]] = constant 6 : index
 //  CHECK-DAG:  %[[C7:.+]] = constant 7 : index
-//      CHECK:  %[[BIDX:.+]] = "gpu.block_id"() {dimension = "x"}
-//      CHECK:  %[[BIDY:.+]] = "gpu.block_id"() {dimension = "y"}
-//      CHECK:  %[[BIDZ:.+]] = "gpu.block_id"() {dimension = "z"}
+//      CHECK:  %[[BIDX:.+]] = hal.interface.workgroup.id[0]
+//      CHECK:  %[[BIDY:.+]] = hal.interface.workgroup.id[1]
+//      CHECK:  %[[BIDZ:.+]] = hal.interface.workgroup.id[2]
 //  CHECK-DAG:  %[[BOFFSET_Y:.+]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
 //  CHECK-DAG:  %[[BOFFSET_X:.+]] = affine.apply #[[MAP1]]()[%[[BIDX]]]
+//      CHECK:  %[[SUBVIEW_ARG0:.+]] = memref.subview %[[ARG0]]
+// CHECK-SAME:      [%[[BIDZ]], %[[BOFFSET_Y]], 0] [1, 8, 1024]
+//      CHECK:  %[[SUBVIEW_ARG1:.+]] = memref.subview %[[ARG1]]
+// CHECK-SAME:      [%[[BIDZ]], 0, %[[BOFFSET_X]]] [1, 1024, 64]
 //      CHECK:  %[[SUBVIEW_RESULT:.+]] = memref.subview %[[RET0]]
 // CHECK-SAME:      [%[[BIDZ]], %[[BOFFSET_Y]], %[[BOFFSET_X]]] [1, 8, 64]
 //      CHECK:  %[[IIDX:.+]] = "gpu.thread_id"() {dimension = "x"}
@@ -105,12 +116,10 @@
 // CHECK-SAME:  %[[ACC_5:.+]] = %[[READ_INIT_5]],
 // CHECK-SAME:  %[[ACC_6:.+]] = %[[READ_INIT_6]],
 // CHECK-SAME:  %[[ACC_7:.+]] = %[[READ_INIT_7]])
-//  CHECK-DAG:    %[[SUBVIEW_LHS:.+]] = memref.subview %[[ARG0]]
-// CHECK-SAME:      [%[[BIDZ]], %[[BOFFSET_Y]], %[[IV0]]] [1, 8, 4]
-//  CHECK-DAG:    %[[SUBVIEW_RHS:.+]] = memref.subview %[[ARG1]]
-// CHECK-SAME:      [%[[BIDZ]], %[[IV0]], %[[BOFFSET_X]]] [1, 4, 64]
-//  CHECK-DAG:    %[[SUBVIEW_RHS_2:.+]] = memref.subview %[[SUBVIEW_RHS]]
-// CHECK-SAME:      [%[[IIDZ]], 0, %[[IOFFSET_X]]] [1, 4, 4] [1, 1, 1]
+//  CHECK-DAG:    %[[SUBVIEW_LHS:.+]] = memref.subview %[[SUBVIEW_ARG0]]
+// CHECK-SAME:      [%[[IIDZ]], %[[IOFFSET_Y]], %[[IV0]]] [1, 8, 4]
+//  CHECK-DAG:    %[[SUBVIEW_RHS:.+]] = memref.subview %[[SUBVIEW_ARG1]]
+// CHECK-SAME:      [%[[IIDZ]], %[[IV0]], %[[IOFFSET_X]]] [1, 4, 4] [1, 1, 1]
 
 //  CHECK-DAG:    %[[READ_LHS_0:.+]] = vector.transfer_read
 // CHECK-SAME:      %[[SUBVIEW_LHS]][%[[C0]], %[[C0]], %[[C0]]]
@@ -130,13 +139,13 @@
 // CHECK-SAME:      %[[SUBVIEW_LHS]][%[[C0]], %[[C7]], %[[C0]]]
 
 //  CHECK-DAG:    %[[READ_RHS_0:.+]] = vector.transfer_read
-// CHECK-SAME:      %[[SUBVIEW_RHS_2]][%[[C0]], %[[C0]], %[[C0]]]
+// CHECK-SAME:      %[[SUBVIEW_RHS]][%[[C0]], %[[C0]], %[[C0]]]
 //  CHECK-DAG:    %[[READ_RHS_1:.+]] = vector.transfer_read
-// CHECK-SAME:      %[[SUBVIEW_RHS_2]][%[[C0]], %[[C1]], %[[C0]]]
+// CHECK-SAME:      %[[SUBVIEW_RHS]][%[[C0]], %[[C1]], %[[C0]]]
 //  CHECK-DAG:    %[[READ_RHS_2:.+]] = vector.transfer_read
-// CHECK-SAME:      %[[SUBVIEW_RHS_2]][%[[C0]], %[[C2]], %[[C0]]]
+// CHECK-SAME:      %[[SUBVIEW_RHS]][%[[C0]], %[[C2]], %[[C0]]]
 //  CHECK-DAG:    %[[READ_RHS_3:.+]] = vector.transfer_read
-// CHECK-SAME:      %[[SUBVIEW_RHS_2]][%[[C0]], %[[C3]], %[[C0]]]
+// CHECK-SAME:      %[[SUBVIEW_RHS]][%[[C0]], %[[C3]], %[[C0]]]
 
 //  CHECK-DAG:    %[[READ_LHS_0_0:.+]] = vector.extract_strided_slice
 // CHECK-SAME:      %[[READ_LHS_0]] {offsets = [0, 0, 0]
@@ -290,49 +299,56 @@
 
 // -----
 
-hal.executable @batch_matmul_fused_fillop attributes {sym_visibility = "private"} {
-  hal.interface @io {
+hal.executable @fused_fill_batch_matmul attributes {sym_visibility = "private"} {
+  hal.interface @io attributes {sym_visibility = "private"} {
     hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
     hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
     hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
   }
-  hal.executable.target @vulkan, filter="dylib*" {
-    hal.executable.entry_point @batch_matmul_fused_fillop attributes {
+  hal.executable.target @vulkan_spirv, filter="vulkan*" {
+    hal.executable.entry_point @fused_fill_batch_matmul attributes {
       interface = @io, ordinal = 0 : index,
-      signature = (!flow.dispatch.tensor<readonly:?x?xf32>, !flow.dispatch.tensor<readonly:?x?xf32>,
-        !flow.dispatch.tensor<writeonly:?x?xf32>) -> ()}
-    module attributes {
-      spv.target_env =
-        #spv.target_env<#spv.vce<v1.3,
-          [Shader, Float64, Float16, Int64, Int16, Int8, StorageBuffer16BitAccess,
-           StorageUniform16, StoragePushConstant16, StorageBuffer8BitAccess,
-           UniformAndStorageBuffer8BitAccess, StoragePushConstant8, GroupNonUniform,
-           GroupNonUniformVote, GroupNonUniformArithmetic, GroupNonUniformBallot,
-           GroupNonUniformShuffle, GroupNonUniformShuffleRelative, VariablePointers,
-           VariablePointersStorageBuffer],
-          [SPV_KHR_16bit_storage, SPV_KHR_8bit_storage,
-           SPV_KHR_storage_buffer_storage_class, SPV_KHR_variable_pointers]>,
-          ARM:IntegratedGPU,
-          {max_compute_shared_memory_size = 32768 : i32,
-           max_compute_workgroup_invocations = 512 : i32,
-           max_compute_workgroup_size = dense<512> : vector<3xi32>,
-           subgroup_size = 16 : i32}>} {
-      func @batch_matmul_fused_fillop()
-        attributes {vkspv.num_workgroups_fn = @batch_matmul_fused_fillop__num_workgroups__} {
-        %cst = constant 0.000000e+00 : f32
-        %arg0 = iree.placeholder for "interface buffer"
-          {binding = @io::@arg0, operand_result_num = 0 : i32} : memref<4x1024x1024xf32>
-        %arg1 = iree.placeholder for "interface buffer"
-          {binding = @io::@arg1, operand_result_num = 1 : i32} : memref<4x1024x1024xf32>
-        %ret0 = iree.placeholder for "interface buffer"
-          {binding = @io::@ret0, operand_result_num = 2 : i32} : memref<4x1024x1024xf32>
-        linalg.fill(%ret0, %cst) : memref<4x1024x1024xf32>, f32
-        linalg.batch_matmul ins(%arg0, %arg1 : memref<4x1024x1024xf32>, memref<4x1024x1024xf32>) outs(%ret0 : memref<4x1024x1024xf32>)
+      signature = (!flow.dispatch.tensor<readonly:4x1024x1024xf32>, !flow.dispatch.tensor<readonly:4x1024x1024xf32>, !flow.dispatch.tensor<writeonly:4x1024x1024xf32>) -> ()}
+    module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], [SPV_KHR_storage_buffer_storage_class]>, ARM:IntegratedGPU, {}>}  {
+      func @fused_fill_batch_matmul() {
+        %zero = constant 0.0 : f32
+        %c0 = constant 0 : index
+        %c4 = constant 4 : index
+        %c1024 = constant 1024 : index
+        %0 = hal.interface.binding.subspan @io::@s0b0_ro_external[%c0] : memref<4x1024x1024xf32>
+        %1 = hal.interface.binding.subspan @io::@s0b1_ro_external[%c0] : memref<4x1024x1024xf32>
+        %2 = hal.interface.binding.subspan @io::@s0b2_xw_external[%c0] : memref<4x1024x1024xf32>
+        %workgroup_size_x = hal.interface.workgroup.size[0] : index
+        %workgroup_size_y = hal.interface.workgroup.size[1] : index
+        %workgroup_size_z = hal.interface.workgroup.size[2] : index
+        %workgroup_id_x = hal.interface.workgroup.id[0] : index
+        %workgroup_count_x = hal.interface.workgroup.count[0] : index
+        %workgroup_id_y = hal.interface.workgroup.id[1] : index
+        %workgroup_count_y = hal.interface.workgroup.count[1] : index
+        %workgroup_id_z = hal.interface.workgroup.id[2] : index
+        %workgroup_count_z = hal.interface.workgroup.count[2] : index
+        %3 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_z, %workgroup_size_z]
+        %4 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_z, %workgroup_size_z]
+        scf.for %arg0 = %3 to %c4 step %4 {
+          %5 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_y, %workgroup_size_y]
+          %6 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_y, %workgroup_size_y]
+          scf.for %arg1 = %5 to %c1024 step %6 {
+            %7 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_x, %workgroup_size_x]
+            %8 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_x, %workgroup_size_x]
+            scf.for %arg2 = %7 to %c1024 step %8 {
+              %9 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 4)>(%arg0)[%workgroup_size_z]
+              %10 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 1024)>(%arg1)[%workgroup_size_y]
+              %11 = memref.subview %0[%arg0, %arg1, 0] [%9, %10, 1024] [1, 1, 1] : memref<4x1024x1024xf32> to memref<?x?x1024xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 1048576 + s0 + d1 * 1024 + d2)>>
+              %12 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 1024)>(%arg2)[%workgroup_size_x]
+              %13 = memref.subview %1[%arg0, 0, %arg2] [%9, 1024, %12] [1, 1, 1] : memref<4x1024x1024xf32> to memref<?x1024x?xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 1048576 + s0 + d1 * 1024 + d2)>>
+              %14 = memref.subview %2[%arg0, %arg1, %arg2] [%9, %10, %12] [1, 1, 1] : memref<4x1024x1024xf32> to memref<?x?x?xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 1048576 + s0 + d1 * 1024 + d2)>>
+              linalg.fill(%14, %zero) : memref<?x?x?xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 1048576 + s0 + d1 * 1024 + d2)>>, f32
+              linalg.batch_matmul {__internal_linalg_transform__ = "workgroup"} ins(%11, %13 : memref<?x?x1024xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 1048576 + s0 + d1 * 1024 + d2)>>, memref<?x1024x?xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 1048576 + s0 + d1 * 1024 + d2)>>) outs(%14 : memref<?x?x?xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 1048576 + s0 + d1 * 1024 + d2)>>)
+            }
+          }
+        }
         return
       }
-      func private @batch_matmul_fused_fillop__num_workgroups__
-        (!shapex.ranked_shape<[4096, 4096]>, !shapex.ranked_shape<[4096, 4096]>,
-         !shapex.ranked_shape<[4096, 4096]>) -> (index, index, index)
       hal.interface @io attributes {sym_visibility = "private"} {
         hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
         hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
@@ -341,7 +357,8 @@
     }
   }
 }
-//    CHECK-LABEL: func @batch_matmul_fused_fillop
+
+//    CHECK-LABEL: func @fused_fill_batch_matmul
 //  CHECK-COUNT-8:   vector.transfer_write
 //  CHECK-COUNT-8:   vector.transfer_read
 //          CHECK:   %[[FOR_RES:.+]]:8 = scf.for
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/tile_and_vectorize_matmul.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/tile_and_vectorize_matmul.mlir
index 07e871d..89d2053 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/tile_and_vectorize_matmul.mlir
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/tile_and_vectorize_matmul.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt -split-input-file -pass-pipeline="hal.executable(hal.executable.target(iree-spirv-concretize-tile-among-workgroups,iree-spirv-tile-and-vectorize-in-one-workgroup))" -iree-spirv-enable-vectorization -iree-codegen-spirv-experimental-linalg-on-tensors -canonicalize -cse %s | IreeFileCheck %s
+// RUN: iree-opt -split-input-file -pass-pipeline="hal.executable(hal.executable.target(iree-spirv-concretize-tile-among-workgroups,iree-spirv-tile-and-vectorize-in-one-workgroup))" -canonicalize -cse -iree-spirv-workgroup-tile-size=8,64,4 -iree-spirv-invocation-tile-size=8,4,4 -iree-spirv-workgroup-size=16,1,1 %s | IreeFileCheck %s
 
 hal.executable @matmul_static_shape_f16 attributes {sym_visibility = "private"} {
   hal.interface @io attributes {sym_visibility = "private"} {
@@ -9,7 +9,7 @@
   hal.executable.target @vulkan_spirv, filter="vulkan*" {
     hal.executable.entry_point @matmul_static_shape_f16 attributes {
       interface = @io, ordinal = 0 : index,
-      signature = (!flow.dispatch.tensor<readonly:1x225x225x16xf32>, !flow.dispatch.tensor<readonly:3x3x16x32xf32>, !flow.dispatch.tensor<writeonly:1x112x112x32xf32>) -> ()}
+      signature = (!flow.dispatch.tensor<readonly:4096x4096xf16>, !flow.dispatch.tensor<readonly:4096x4096xf16>, !flow.dispatch.tensor<writeonly:4096x4096xf16>) -> ()}
     module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], [SPV_KHR_storage_buffer_storage_class]>, ARM:IntegratedGPU, {}>}  {
       func @matmul_static_shape_f16() {
         %cst = constant 0.000000e+00 : f16
@@ -51,11 +51,73 @@
 }
 
 //    CHECK-LABEL: func @matmul_static_shape_f16
-//  CHECK-COUNT-16:   vector.transfer_write
-//  CHECK-COUNT-16:   vector.transfer_read
-//          CHECK:   %[[FOR_RES:.+]]:16 = scf.for
-// CHECK-COUNT-16:     vector.transfer_read
-// CHECK-COUNT-64:     vector.contract
+//  CHECK-COUNT-8:   vector.transfer_write
+//  CHECK-COUNT-8:   vector.transfer_read
+//          CHECK:   %[[FOR_RES:.+]]:8 = scf.for
+// CHECK-COUNT-12:     vector.transfer_read
+// CHECK-COUNT-32:     vector.contract
 //      CHECK:         scf.yield
-//  CHECK-COUNT-16:    vector.transfer_write %[[FOR_RES]]
+//  CHECK-COUNT-8:    vector.transfer_write %[[FOR_RES]]
+//          CHECK:    return
+
+// -----
+
+hal.executable @matmul_static_shape_f32 attributes {sym_visibility = "private"} {
+  hal.interface @io attributes {sym_visibility = "private"} {
+    hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+    hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+    hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
+  }
+  hal.executable.target @vulkan_spirv, filter="vulkan*" {
+    hal.executable.entry_point @matmul_static_shape_f32 attributes {
+      interface = @io, ordinal = 0 : index,
+      signature = (!flow.dispatch.tensor<readonly:4096x4096xf32>, !flow.dispatch.tensor<readonly:4096x4096xf32>, !flow.dispatch.tensor<writeonly:4096x4096xf32>) -> ()}
+    module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], [SPV_KHR_storage_buffer_storage_class]>, ARM:IntegratedGPU, {}>}  {
+      func @matmul_static_shape_f32() {
+        %c0 = constant 0 : index
+        %cst = constant 0.000000e+00 : f32
+        %c4096 = constant 4096 : index
+        %0 = hal.interface.binding.subspan @io::@s0b0_ro_external[%c0] : memref<4096x4096xf32>
+        %1 = hal.interface.binding.subspan @io::@s0b1_ro_external[%c0] : memref<4096x4096xf32>
+        %2 = hal.interface.binding.subspan @io::@s0b2_xw_external[%c0] : memref<4096x4096xf32>
+        %workgroup_size_x = hal.interface.workgroup.size[0] : index
+        %workgroup_size_y = hal.interface.workgroup.size[1] : index
+        %workgroup_id_x = hal.interface.workgroup.id[0] : index
+        %workgroup_count_x = hal.interface.workgroup.count[0] : index
+        %workgroup_id_y = hal.interface.workgroup.id[1] : index
+        %workgroup_count_y = hal.interface.workgroup.count[1] : index
+        %3 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_y, %workgroup_size_y]
+        %4 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_y, %workgroup_size_y]
+        scf.for %arg0 = %3 to %c4096 step %4 {
+          %5 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_x, %workgroup_size_x]
+          %6 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_x, %workgroup_size_x]
+          scf.for %arg1 = %5 to %c4096 step %6 {
+            %7 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 4096)>(%arg0)[%workgroup_size_y]
+            %8 = memref.subview %0[%arg0, 0] [%7, 4096] [1, 1] : memref<4096x4096xf32> to memref<?x4096xf32, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>
+            %9 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 4096)>(%arg1)[%workgroup_size_x]
+            %10 = memref.subview %1[0, %arg1] [4096, %9] [1, 1] : memref<4096x4096xf32> to memref<4096x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>
+            %11 = memref.subview %2[%arg0, %arg1] [%7, %9] [1, 1] : memref<4096x4096xf32> to memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>
+            linalg.fill(%11, %cst) : memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>, f32
+            linalg.matmul {__internal_linalg_transform__ = "workgroup"} ins(%8, %10 : memref<?x4096xf32, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>, memref<4096x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>) outs(%11 : memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>)
+          }
+        }
+        return
+      }
+      hal.interface @io attributes {sym_visibility = "private"} {
+        hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+        hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+        hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
+      }
+    }
+  }
+}
+
+//    CHECK-LABEL: func @matmul_static_shape_f32
+//  CHECK-COUNT-8:   vector.transfer_write
+//  CHECK-COUNT-8:   vector.transfer_read
+//          CHECK:   %[[FOR_RES:.+]]:8 = scf.for
+// CHECK-COUNT-12:     vector.transfer_read
+// CHECK-COUNT-32:     vector.contract
+//      CHECK:         scf.yield
+//  CHECK-COUNT-8:    vector.transfer_write %[[FOR_RES]]
 //          CHECK:    return
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/elementwise_vectorization.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/vectorize_elementwise_ops.mlir
similarity index 61%
rename from iree/compiler/Conversion/LinalgToSPIRV/test/elementwise_vectorization.mlir
rename to iree/compiler/Conversion/LinalgToSPIRV/test/vectorize_elementwise_ops.mlir
index 983dd35..b0e3c2b 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/elementwise_vectorization.mlir
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/vectorize_elementwise_ops.mlir
@@ -1,8 +1,8 @@
-// RUN: iree-opt -split-input-file -pass-pipeline="hal.executable(hal.executable.target(iree-codegen-spirv-linalg-tile-and-distribute,iree-spirv-tile-and-vectorize-in-one-workgroup,canonicalize,cse))" -iree-spirv-enable-vectorization %s | IreeFileCheck %s
+// RUN: iree-opt -split-input-file -pass-pipeline="hal.executable(hal.executable.target(iree-spirv-tile-and-vectorize-in-one-workgroup,canonicalize,cse))" %s | IreeFileCheck %s
 
 // CHECK-LABEL: func @elementwise_static_shape
-//       CHECK:   vector.transfer_read %10[%c0], {{.*}} memref<4xf32, #map1>, vector<4xf32>
-//       CHECK:   vector.transfer_read %11[%c0], {{.*}} memref<4xf32, #map1>, vector<4xf32>
+//       CHECK:   vector.transfer_read %{{.+}}[%c0], {{.+}} memref<4xf32, #{{.+}}>, vector<4xf32>
+//       CHECK:   vector.transfer_read %{{.+}}[%c0], {{.+}} memref<4xf32, #{{.+}}>, vector<4xf32>
 //       CHECK:   addf %{{.*}}, %{{.*}} : vector<4xf32>
 //       CHECK:   vector.transfer_write {{.*}} : vector<4xf32>, memref<4xf32
 hal.executable @elementwise_static_shape attributes {sym_visibility = "private"} {
@@ -11,7 +11,7 @@
     hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
     hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
   }
-  hal.executable.target @vulkan, filter="dylib*" {
+  hal.executable.target @vulkan, filter="vulkan*" {
     hal.executable.entry_point @elementwise_static_shape attributes {
       interface = @io, ordinal = 0 : index,
       signature = (!flow.dispatch.tensor<readonly:?xf32>,
@@ -23,30 +23,25 @@
           [Shader],
           []>, NVIDIA:DiscreteGPU,
           {subgroup_size = 32 : i32}>} {
-      func @elementwise_static_shape()
-        attributes {vkspv.num_workgroups_fn = @elementwise_static_shape__num_workgroups__} {
-        %arg0 = iree.placeholder for "interface buffer"
-          {binding = @io::@arg0, operand_result_num = 0 : i32} : memref<128xf32>
-        %arg1 = iree.placeholder for "interface buffer"
-          {binding = @io::@arg1, operand_result_num = 1 : i32} : memref<128xf32>
-        %ret0 = iree.placeholder for "interface buffer"
-          {binding = @io::@ret0, operand_result_num = 2 : i32} : memref<128xf32>
+      func @elementwise_static_shape() {
+        %c0 = constant 0 : index
+        %arg0 = hal.interface.binding.subspan @io::@arg0[%c0] : memref<128xf32>
+        %arg1 = hal.interface.binding.subspan @io::@arg1[%c0] : memref<128xf32>
+        %ret0 = hal.interface.binding.subspan @io::@ret0[%c0] : memref<128xf32>
         linalg.generic {
+          __internal_linalg_transform__ = "workgroup",
           indexing_maps = [affine_map<(i) -> (i)>,
                            affine_map<(i) -> (i)>,
                            affine_map<(i) -> (i)>],
-           iterator_types = ["parallel"]
-          } ins(%arg0, %arg1 : memref<128xf32>, memref<128xf32>)
-            outs(%ret0 : memref<128xf32>) {
+          iterator_types = ["parallel"]
+        } ins(%arg0, %arg1 : memref<128xf32>, memref<128xf32>)
+          outs(%ret0 : memref<128xf32>) {
               ^bb0(%a : f32, %b : f32, %c : f32):
               %add = addf %a, %b : f32
               linalg.yield %add : f32
         }
         return
       }
-      func private @elementwise_static_shape__num_workgroups__
-        (!shapex.ranked_shape<[4096, 4096]>, !shapex.ranked_shape<[4096, 4096]>,
-         !shapex.ranked_shape<[4096, 4096]>) -> (index, index, index)
       hal.interface @io attributes {sym_visibility = "private"} {
         hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
         hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
@@ -81,30 +76,25 @@
           [Shader],
           []>, NVIDIA:DiscreteGPU,
           {subgroup_size = 32 : i32}>} {
-      func @elementwise_transpose()
-        attributes {vkspv.num_workgroups_fn = @elementwise_transpose__num_workgroups__} {
-        %arg0 = iree.placeholder for "interface buffer"
-          {binding = @io::@arg0, operand_result_num = 0 : i32} : memref<128x8xf32>
-        %arg1 = iree.placeholder for "interface buffer"
-          {binding = @io::@arg1, operand_result_num = 1 : i32} : memref<128xf32>
-        %ret0 = iree.placeholder for "interface buffer"
-          {binding = @io::@ret0, operand_result_num = 2 : i32} : memref<128x8xf32>
+      func @elementwise_transpose() {
+        %c0 = constant 0 : index
+        %arg0 = hal.interface.binding.subspan @io::@arg0[%c0] : memref<128x8xf32>
+        %arg1 = hal.interface.binding.subspan @io::@arg1[%c0] : memref<128xf32>
+        %ret0 = hal.interface.binding.subspan @io::@ret0[%c0] : memref<128x8xf32>
         linalg.generic {
+          __internal_linalg_transform__ = "workgroup",
           indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                            affine_map<(d0, d1) -> (d0)>,
                            affine_map<(d0, d1) -> (d0, d1)>],
-           iterator_types = ["parallel", "parallel"]
-          } ins(%arg0, %arg1 : memref<128x8xf32>, memref<128xf32>)
-            outs(%ret0 : memref<128x8xf32>) {
+          iterator_types = ["parallel", "parallel"]
+        } ins(%arg0, %arg1 : memref<128x8xf32>, memref<128xf32>)
+          outs(%ret0 : memref<128x8xf32>) {
               ^bb0(%a : f32, %b : f32, %c : f32):
               %add = addf %a, %b : f32
               linalg.yield %add : f32
         }
         return
       }
-      func private @elementwise_transpose__num_workgroups__
-        (!shapex.ranked_shape<[4096, 4096]>, !shapex.ranked_shape<[4096, 4096]>,
-         !shapex.ranked_shape<[4096, 4096]>) -> (index, index, index)
       hal.interface @io attributes {sym_visibility = "private"} {
         hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
         hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/matmul_vectorization.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/vectorize_matmul.mlir
similarity index 79%
rename from iree/compiler/Conversion/LinalgToSPIRV/test/matmul_vectorization.mlir
rename to iree/compiler/Conversion/LinalgToSPIRV/test/vectorize_matmul.mlir
index 80ed9b1..b97424c 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/matmul_vectorization.mlir
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/vectorize_matmul.mlir
@@ -1,17 +1,16 @@
-// RUN: iree-opt -split-input-file -pass-pipeline="hal.executable(hal.executable.target(iree-codegen-spirv-linalg-tile-and-distribute,iree-spirv-tile-and-vectorize-in-one-workgroup,canonicalize,cse))" -iree-spirv-enable-vectorization %s | IreeFileCheck %s
-// RUN: iree-opt -split-input-file -pass-pipeline="hal.executable(hal.executable.target(iree-codegen-spirv-linalg-tile-and-distribute,iree-spirv-tile-and-vectorize-in-one-workgroup,canonicalize,cse))" -iree-spirv-enable-vectorization -iree-spirv-use-workgroup-memory %s | IreeFileCheck %s -check-prefix=PROMOTE
+// RUN: iree-opt -split-input-file -pass-pipeline="hal.executable(hal.executable.target(iree-spirv-tile-and-vectorize-in-one-workgroup,canonicalize,cse))" %s | IreeFileCheck %s
+// RUN: iree-opt -split-input-file -pass-pipeline="hal.executable(hal.executable.target(iree-spirv-tile-and-vectorize-in-one-workgroup,canonicalize,cse))" -iree-spirv-use-workgroup-memory %s | IreeFileCheck %s -check-prefix=PROMOTE
 
 hal.executable @matmul_static_shape attributes {sym_visibility = "private"} {
-  hal.interface @io {
+  hal.interface @io attributes {sym_visibility = "private"} {
     hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
     hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
     hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
   }
-  hal.executable.target @vulkan, filter="dylib*" {
+  hal.executable.target @vulkan_spirv, filter="vulkan*" {
     hal.executable.entry_point @matmul_static_shape attributes {
       interface = @io, ordinal = 0 : index,
-      signature = (!flow.dispatch.tensor<readonly:?x?xf32>, !flow.dispatch.tensor<readonly:?x?xf32>,
-        !flow.dispatch.tensor<writeonly:?x?xf32>) -> ()}
+      signature = (!flow.dispatch.tensor<readonly:4096x4096xf16>, !flow.dispatch.tensor<readonly:4096x4096xf16>, !flow.dispatch.tensor<writeonly:4096x4096xf16>) -> ()}
     module attributes {
       spv.target_env =
         #spv.target_env<#spv.vce<v1.5,
@@ -37,21 +36,27 @@
            max_compute_workgroup_invocations = 1024 : i32,
            max_compute_workgroup_size = dense<[2147483647, 65535, 65535]> : vector<3xi32>,
            subgroup_size = 32 : i32}>} {
-      func @matmul_static_shape()
-        attributes {vkspv.num_workgroups_fn = @matmul_static_shape__num_workgroups__} {
-        %arg0 = iree.placeholder for "interface buffer"
-          {binding = @io::@arg0, operand_result_num = 0 : i32} : memref<4096x4096xf16>
-        %arg1 = iree.placeholder for "interface buffer"
-          {binding = @io::@arg1, operand_result_num = 1 : i32} : memref<4096x4096xf16>
-        %ret0 = iree.placeholder for "interface buffer"
-          {binding = @io::@ret0, operand_result_num = 2 : i32} : memref<4096x4096xf16>
-        linalg.matmul ins(%arg0, %arg1 : memref<4096x4096xf16>, memref<4096x4096xf16>)
-                     outs(%ret0 : memref<4096x4096xf16>)
+      func @matmul_static_shape() {
+        %c32 = constant 32 : index
+        %c4096 = constant 4096 : index
+        %c0 = constant 0 : index
+        %0 = hal.interface.binding.subspan @io::@arg0[%c0] : memref<4096x4096xf16>
+        %1 = hal.interface.binding.subspan @io::@arg1[%c0] : memref<4096x4096xf16>
+        %2 = hal.interface.binding.subspan @io::@ret0[%c0] : memref<4096x4096xf16>
+        %3 = hal.interface.workgroup.size[0] : index
+        %4 = hal.interface.workgroup.size[1] : index
+        scf.for %arg0 = %c0 to %c4096 step %c32 {
+          %5 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%4]
+          %6 = memref.subview %0[%5, %arg0] [64, 32] [1, 1] : memref<4096x4096xf16> to memref<64x32xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>
+          %7 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%3]
+          %8 = memref.subview %1[%arg0, %7] [32, 64] [1, 1] : memref<4096x4096xf16> to memref<32x64xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>
+          %9 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%4]
+          %10 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%3]
+          %11 = memref.subview %2[%9, %10] [64, 64] [1, 1] : memref<4096x4096xf16> to memref<64x64xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>
+          linalg.matmul {__internal_linalg_transform__ = "workgroup"} ins(%6, %8 : memref<64x32xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>, memref<32x64xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>) outs(%11 : memref<64x64xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>)
+        }
         return
       }
-      func private @matmul_static_shape__num_workgroups__
-        (!shapex.ranked_shape<[4096, 4096]>, !shapex.ranked_shape<[4096, 4096]>,
-         !shapex.ranked_shape<[4096, 4096]>) -> (index, index, index)
       hal.interface @io attributes {sym_visibility = "private"} {
         hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
         hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
@@ -60,18 +65,19 @@
     }
   }
 }
+
 //  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 * 64)>
 //      CHECK: func @matmul_static_shape
-//  CHECK-DAG:  %[[ARG0:.+]] = iree.placeholder {{.*}} {binding = @io::@arg0
-//  CHECK-DAG:  %[[ARG1:.+]] = iree.placeholder {{.*}} {binding = @io::@arg1
-//  CHECK-DAG:  %[[RET0:.+]] = iree.placeholder {{.*}} {binding = @io::@ret0
-//  CHECK-DAG:  %[[C0:.+]] = constant 0 : index
 //  CHECK-DAG:  %[[CST:.+]] = constant 0.0
+//  CHECK-DAG:  %[[C0:.+]] = constant 0 : index
 //  CHECK-DAG:  %[[C16:.+]] = constant 16 : index
 //  CHECK-DAG:  %[[C32:.+]] = constant 32 : index
 //  CHECK-DAG:  %[[C48:.+]] = constant 48 : index
-//      CHECK:  %[[BIDX:.+]] = "gpu.block_id"() {dimension = "x"}
-//      CHECK:  %[[BIDY:.+]] = "gpu.block_id"() {dimension = "y"}
+//  CHECK-DAG:  %[[ARG0:.+]] = hal.interface.binding.subspan @io::@arg0[%[[C0]]]
+//  CHECK-DAG:  %[[ARG1:.+]] = hal.interface.binding.subspan @io::@arg1[%[[C0]]]
+//  CHECK-DAG:  %[[RET0:.+]] = hal.interface.binding.subspan @io::@ret0[%[[C0]]]
+//      CHECK:  %[[BIDX:.+]] = hal.interface.workgroup.size[0]
+//      CHECK:  %[[BIDY:.+]] = hal.interface.workgroup.size[1]
 //      CHECK:  %[[BOFFSET_Y:.+]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
 //      CHECK:  %[[BOFFSET_X:.+]] = affine.apply #[[MAP0]]()[%[[BIDX]]]
 //      CHECK:    %[[SUBVIEW_RESULT:.+]] = memref.subview %[[RET0]]
@@ -268,12 +274,75 @@
 //  CHECK-DAG:  vector.transfer_write %[[FOR_RES]]#14, %[[SUBVIEW_RESULT]][%[[C48]], %[[C32]]]
 //  CHECK-DAG:  vector.transfer_write %[[FOR_RES]]#15, %[[SUBVIEW_RESULT]][%[[C48]], %[[C48]]]
 
+// -----
+
+hal.executable @matmul_static_shape attributes {sym_visibility = "private"} {
+  hal.interface @io attributes {sym_visibility = "private"} {
+    hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+    hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+    hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
+  }
+  hal.executable.target @vulkan_spirv, filter="vulkan*" {
+    hal.executable.entry_point @matmul_static_shape attributes {
+      interface = @io, ordinal = 0 : index,
+      signature = (!flow.dispatch.tensor<readonly:4096x4096xf16>, !flow.dispatch.tensor<readonly:4096x4096xf16>, !flow.dispatch.tensor<writeonly:4096x4096xf16>) -> ()}
+    module attributes {
+      spv.target_env =
+        #spv.target_env<#spv.vce<v1.5,
+          [Shader, Float64, Float16, Int64, Int16, Int8, StorageBuffer16BitAccess,
+           StorageUniform16, StoragePushConstant16, StorageBuffer8BitAccess,
+           UniformAndStorageBuffer8BitAccess, StoragePushConstant8, GroupNonUniform,
+           GroupNonUniformVote, GroupNonUniformArithmetic, GroupNonUniformBallot,
+           GroupNonUniformShuffle, GroupNonUniformShuffleRelative, VariablePointers,
+           VariablePointersStorageBuffer, CooperativeMatrixNV],
+          [SPV_KHR_16bit_storage, SPV_KHR_8bit_storage,
+           SPV_KHR_storage_buffer_storage_class, SPV_KHR_variable_pointers,
+           SPV_NV_cooperative_matrix]>, NVIDIA:DiscreteGPU,
+          {cooperative_matrix_properties_nv = [
+            {a_type = i8, b_type = i8, c_type = i32, k_size = 32 : i32,
+             m_size = 8 : i32, n_size = 8 : i32, result_type = i32, scope = 3 : i32},
+            {a_type = f16, b_type = f16, c_type = f16, k_size = 16 : i32,
+             m_size = 16 : i32, n_size = 16 : i32, result_type = f16,
+             scope = 3 : i32},
+            {a_type = f16, b_type = f16, c_type = f32, k_size = 16 : i32,
+             m_size = 16 : i32, n_size = 16 : i32, result_type = f32,
+             scope = 3 : i32}],
+           max_compute_shared_memory_size = 49152 : i32,
+           max_compute_workgroup_invocations = 1024 : i32,
+           max_compute_workgroup_size = dense<[2147483647, 65535, 65535]> : vector<3xi32>,
+           subgroup_size = 32 : i32}>} {
+      func @matmul_static_shape() {
+        %c32 = constant 32 : index
+        %c4096 = constant 4096 : index
+        %c0 = constant 0 : index
+        %0 = hal.interface.binding.subspan @io::@arg0[%c0] : memref<4096x4096xf16>
+        %1 = hal.interface.binding.subspan @io::@arg1[%c0] : memref<4096x4096xf16>
+        %2 = hal.interface.binding.subspan @io::@ret0[%c0] : memref<4096x4096xf16>
+        %3 = hal.interface.workgroup.size[0] : index
+        %4 = hal.interface.workgroup.size[1] : index
+        scf.for %arg0 = %c0 to %c4096 step %c32 {
+          %5 = affine.apply affine_map<()[s0] -> (s0 * 128)>()[%4]
+          %6 = memref.subview %0[%5, %arg0] [128, 32] [1, 1] : memref<4096x4096xf16> to memref<128x32xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>
+          %7 = affine.apply affine_map<()[s0] -> (s0 * 128)>()[%3]
+          %8 = memref.subview %1[%arg0, %7] [32, 128] [1, 1] : memref<4096x4096xf16> to memref<32x128xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>
+          %9 = affine.apply affine_map<()[s0] -> (s0 * 128)>()[%4]
+          %10 = affine.apply affine_map<()[s0] -> (s0 * 128)>()[%3]
+          %11 = memref.subview %2[%9, %10] [128, 128] [1, 1] : memref<4096x4096xf16> to memref<128x128xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>
+          linalg.matmul {__internal_linalg_transform__ = "workgroup", is_root_op, launch_info_key = "__op_num_0__"} ins(%6, %8 : memref<128x32xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>, memref<32x128xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>) outs(%11 : memref<128x128xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>)
+        }
+        return
+      }
+      hal.interface @io attributes {sym_visibility = "private"} {
+        hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+        hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+        hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
+      }
+    }
+  }
+}
 
 //  PROMOTE-DAG: #[[MAP4:.+]] = affine_map<()[s0] -> (s0 * 64 - (s0 floordiv 2) * 128)>
 //      PROMOTE: func @matmul_static_shape
-//  PROMOTE-DAG:  %[[ARG0:.+]] = iree.placeholder {{.*}} {binding = @io::@arg0
-//  PROMOTE-DAG:  %[[ARG1:.+]] = iree.placeholder {{.*}} {binding = @io::@arg1
-//  PROMOTE-DAG:  %[[RET0:.+]] = iree.placeholder {{.*}} {binding = @io::@ret0
 //  PROMOTE-DAG:  %[[C0:.+]] = constant 0 : index
 //  PROMOTE-DAG:  %[[C2:.+]] = constant 2
 //  PROMOTE-DAG:  %[[C16:.+]] = constant 16
@@ -281,6 +350,9 @@
 //  PROMOTE-DAG:  %[[C48:.+]] = constant 48
 //  PROMOTE-DAG:  %[[ALLOC1:.+]] = memref.alloc() : memref<128x32xf16, 3>
 //  PROMOTE-DAG:  %[[ALLOC2:.+]] = memref.alloc() : memref<32x128xf16, 3>
+//  PROMOTE-DAG:  %[[ARG0:.+]] = hal.interface.binding.subspan @io::@arg0[%[[C0]]]
+//  PROMOTE-DAG:  %[[ARG1:.+]] = hal.interface.binding.subspan @io::@arg1[%[[C0]]]
+//  PROMOTE-DAG:  %[[RET0:.+]] = hal.interface.binding.subspan @io::@ret0[%[[C0]]]
 
 //      PROMOTE:  %[[RESULT_SUBVIEW:.+]] = memref.subview %[[RET0]]
 //      PROMOTE:  %[[WGMEM_LHS_SUBVIEW:.+]] = memref.subview %[[ALLOC1]][0, 0] [128, 32] [1, 1]