[Codegen][LLVMGPU] Replace TransposeSharedMem pipeline (#21661)

Replaces all uses of the LLVMGPUTransposeSharedMem pipeline with
LLVMGPUTileAndFuse. The same configuration logic is kept, since it turns
out to be critical for performance.
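
The configuration is now expressed as a TileAndFuse-style lowering config
instead of a flat tile-sizes list. As a rough sketch (attribute values
derived from the configuration logic in the diff below, not copied from
compiler output; the exact assembly syntax may vary across IREE versions),
a 4096x4096 transpose whose transposed input is operand 0 would get
something like:

```mlir
// Hypothetical sketch of the resulting config for a 2-D transpose.
#config = #iree_gpu.lowering_config<{
  workgroup = [32, 32],   // 32x32 tile over the two fastest-moving dims
  thread = [1, 4],        // 4-wide thread tile on the fastest output dim
  promote_operands = [0]  // promote the transposed input to shared memory
}>
#translation = #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse
  workgroup_size = [8, 32, 1] subgroup_size = 32,
  {gpu_pipeline_options = #iree_gpu.pipeline_options<
    prefetch_shared_memory = false,
    no_reduce_shared_memory_bank_conflicts = false,
    use_igemm_convolution = false>}>
```

The workgroup tile matches the old pipeline's 32x32 shared memory tile, and
the 1x4 thread tile gives each thread a vectorized 4-element copy, which is
what made the original pipeline fast.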
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.td b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.td
index 1fcef3b..01797db 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.td
@@ -40,14 +40,12 @@
     : I32EnumAttrCase<"LLVMGPUDistribute", 102>;
 def LLVMGPU_Vectorize
     : I32EnumAttrCase<"LLVMGPUVectorize", 103>;
-def LLVMGPU_TransposeSharedMem
-    : I32EnumAttrCase<"LLVMGPUTransposeSharedMem", 104>;
 def LLVMGPU_VectorDistribute
-    : I32EnumAttrCase<"LLVMGPUVectorDistribute", 105>;
+    : I32EnumAttrCase<"LLVMGPUVectorDistribute", 104>;
 def LLVMGPU_WinogradVectorize
-    : I32EnumAttrCase<"LLVMGPUWinogradVectorize", 106>;
+    : I32EnumAttrCase<"LLVMGPUWinogradVectorize", 105>;
 def LLVMGPU_TileAndFuse
-    : I32EnumAttrCase<"LLVMGPUTileAndFuse", 107>;
+    : I32EnumAttrCase<"LLVMGPUTileAndFuse", 106>;
 
 def SPIRV_BaseLowering
     : I32EnumAttrCase<"SPIRVBaseLowering", 200>;
@@ -86,8 +84,8 @@
 
     // LLVMGPU CodeGen pipelines
     LLVMGPU_Default, LLVMGPU_BaseLowering, LLVMGPU_SimpleDistribute,
-    LLVMGPU_Vectorize, LLVMGPU_TransposeSharedMem,
-    LLVMGPU_VectorDistribute, LLVMGPU_WinogradVectorize, LLVMGPU_TileAndFuse,
+    LLVMGPU_Vectorize, LLVMGPU_VectorDistribute,
+    LLVMGPU_WinogradVectorize, LLVMGPU_TileAndFuse,
 
     // SPIR-V CodeGen pipelines
     SPIRV_BaseLowering, SPIRV_BaseDistribute, SPIRV_BaseVectorize,
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
index cfac82e..44259ce 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
@@ -1893,7 +1893,8 @@
 // Transpose Pipeline Configuration
 //====---------------------------------------------------------------------===//
 
-static LogicalResult setTransposeConfig(mlir::FunctionOpInterface entryPoint,
+static LogicalResult setTransposeConfig(IREE::GPU::TargetAttr target,
+                                        mlir::FunctionOpInterface entryPoint,
                                         linalg::LinalgOp linalgOp) {
   LinalgOpInfo opInfo(linalgOp, sharedMemTransposeFilter);
 
@@ -1922,12 +1923,16 @@
 
   int32_t tileM = 32;
   int32_t tileN = 32;
-  TileSizesListType tileSizes;
   // Set all tile sizes to 1 except for fastest moving dimensions.
-  SmallVector<int64_t> tileSizesTemp(linalgOp.getNumLoops(), 1);
-  tileSizesTemp[outputFastestDim] = 32;
-  tileSizesTemp[inputFastestDim] = 32;
-  tileSizes.push_back(tileSizesTemp);
+  SmallVector<int64_t> workgroupTileSizes(linalgOp.getNumLoops(), 1);
+  workgroupTileSizes[outputFastestDim] = 32;
+  workgroupTileSizes[inputFastestDim] = 32;
+
+  // Set the thread tile sizes to 1 for all dims except the fastest-varying
+  // output dim, which we set to 4. Because we promote the transposed input
+  // operands, this gives both vectorized global reads and writes.
+  SmallVector<int64_t> threadTileSizes(linalgOp.getNumLoops(), 1);
+  threadTileSizes[outputFastestDim] = 4;
 
   // Check alignment with tile size for each transpose. Only the fastest moving
   // dims need to match the transpose tile.
@@ -1941,9 +1946,38 @@
   // moving dimension so each thread can execute a vectorized copy of 4
   // contiguous elements at a time from the 32 block.
   std::array<int64_t, 3> workgroupSize = {8, 32, 1};
+
+  MLIRContext *context = linalgOp.getContext();
+  Builder b(context);
+  SmallVector<NamedAttribute> attrs{
+      NamedAttribute("workgroup", b.getI64ArrayAttr(workgroupTileSizes)),
+      NamedAttribute("thread", b.getI64ArrayAttr(threadTileSizes))};
+  SmallVector<int64_t> promotedOperands;
+  for (OpOperand *operand : transposedOperands) {
+    promotedOperands.push_back(operand->getOperandNumber());
+  }
+  IREE::GPU::appendPromotedOperandsList(context, attrs, promotedOperands);
+  DictionaryAttr configDict = DictionaryAttr::get(context, attrs);
+  IREE::GPU::LoweringConfigAttr loweringConfig =
+      IREE::GPU::LoweringConfigAttr::get(context, configDict);
+
+  IREE::GPU::GPUPipelineOptionsAttr pipelineOptions =
+      IREE::GPU::GPUPipelineOptionsAttr::get(
+          context, /*prefetchSharedMemory=*/false,
+          /*no_reduce_shared_memory_bank_conflicts=*/false,
+          /*use_igemm_convolution=*/false,
+          /*reorder_workgroups_strategy=*/std::nullopt);
+  DictionaryAttr pipelineConfig = DictionaryAttr::get(
+      context,
+      {NamedAttribute(IREE::GPU::GPUPipelineOptionsAttr::getDictKeyName(),
+                      pipelineOptions)});
+  const int64_t targetSubgroupSize = target.getPreferredSubgroupSize();
+
+  // TODO(qedawkins): Use a shared pipeline identifier here.
   return setOpConfigAndEntryPointFnTranslation(
-      entryPoint, linalgOp, tileSizes,
-      CodeGenPipeline::LLVMGPUTransposeSharedMem, workgroupSize);
+      entryPoint, linalgOp, loweringConfig,
+      IREE::Codegen::DispatchLoweringPassPipeline::LLVMGPUTileAndFuse,
+      workgroupSize, targetSubgroupSize, pipelineConfig);
 }
 
 //====---------------------------------------------------------------------===//
@@ -2257,7 +2291,7 @@
     }
     auto genericOp = dyn_cast<linalg::GenericOp>(computeOp);
     if (genericOp) {
-      if (succeeded(setTransposeConfig(entryPointFn, genericOp))) {
+      if (succeeded(setTransposeConfig(target, entryPointFn, genericOp))) {
         LDBG() << "Transpose Config";
         return success();
       } else if (ukernelConfig &&
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPULowerExecutableTarget.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPULowerExecutableTarget.cpp
index 6dcfb6a..6fa62b9 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPULowerExecutableTarget.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPULowerExecutableTarget.cpp
@@ -102,9 +102,6 @@
   case IREE::Codegen::DispatchLoweringPassPipeline::LLVMGPUWinogradVectorize:
     addGPUWinogradVectorizePassPipeline(pipeline);
     break;
-  case IREE::Codegen::DispatchLoweringPassPipeline::LLVMGPUTransposeSharedMem:
-    addGPUTransposePassPipeline(pipeline, pipelineOptions);
-    break;
   case IREE::Codegen::DispatchLoweringPassPipeline::LLVMGPUVectorDistribute:
     addGPUVectorDistributePassPipeline(pipeline, pipelineOptions, forROCDL);
     break;
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
index f988040..32edfce 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
@@ -650,49 +650,6 @@
 }
 
 //===---------------------------------------------------------------------===//
-// Transpose
-//===---------------------------------------------------------------------===//
-
-void addGPUTransposePassPipeline(OpPassManager &funcPassManager,
-                                 const GPUPipelineOptions &options) {
-  tileAndDistributeToWorkgroup(funcPassManager, /*useForall=*/true);
-
-  funcPassManager.addPass(createConfigTrackingCanonicalizerPass());
-  funcPassManager.addPass(createConfigTrackingCanonicalizerPass());
-  funcPassManager.addPass(createCSEPass());
-
-  funcPassManager.addPass(
-      createGPUTensorAlloc(GPUPromoteSharedMemPattern::TransposeOpPattern));
-  funcPassManager.addPass(createGPUTensorTilePass());
-
-  // Linalg -> vector
-  addGPUVectorizationPasses(funcPassManager);
-  funcPassManager.addPass(createOptimizeVectorTransferPass());
-  funcPassManager.addPass(createOptimizeTensorInsertExtractSlicesPass());
-
-  // tensor to memref
-  addBufferizePasses(funcPassManager);
-
-  // distribute foreach threads
-  funcPassManager.addPass(createGPUDistributePass());
-
-  funcPassManager.addPass(createMemrefCopyToLinalgPass());
-  funcPassManager.addPass(createGPUDistributeSharedMemoryCopyPass());
-  funcPassManager.addPass(createCanonicalizerPass());
-  funcPassManager.addPass(createCSEPass());
-
-  if (options.enableReduceSharedMemoryBankConflicts) {
-    // May or may not need to reduce shared mememory conflicts.
-    GPUReduceBankConflictsPassOptions options = {};
-    options.paddingBits = 32;
-    funcPassManager.addPass(createGPUReduceBankConflictsPass(options));
-  }
-
-  funcPassManager.addPass(createCanonicalizerPass());
-  funcPassManager.addPass(createCSEPass());
-}
-
-//===---------------------------------------------------------------------===//
 // Vector Distribution
 //===---------------------------------------------------------------------===//
 
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.h b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.h
index abc59bc..f9aa066 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.h
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.h
@@ -43,10 +43,6 @@
 void addGPUTransformDialectPasses(OpPassManager &funcPassManager,
                                   StringRef entryPoint);
 
-/// Lowering transpose using shared memory.
-void addGPUTransposePassPipeline(OpPassManager &funcPassManager,
-                                 const GPUPipelineOptions &options);
-
 /// Lowering calling vectorization patterns. Expects pass manager to be a
 /// module-level pass manager.
 void addGPUVectorizationPassPipeline(OpPassManager &funcPassManager);
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/nvvm_pipeline_test.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/nvvm_pipeline_test.mlir
index a444e64..83fed42 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/nvvm_pipeline_test.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/nvvm_pipeline_test.mlir
@@ -539,7 +539,8 @@
 #map1 = affine_map<(d0, d1) -> (d0, d1)>
 hal.executable private @shared_mem_transpose  {
   hal.executable.variant @cuda target(#executable_target_cuda_nvptx_fb) {
-    hal.executable.export public @shared_mem_transpose layout(#pipeline_layout) count(%arg0: !hal.device, %arg1: index, %arg2: index) -> (index, index, index) {
+    hal.executable.export public @shared_mem_transpose layout(#pipeline_layout)
+      count(%arg0: !hal.device, %arg1: index, %arg2: index) -> (index, index, index) {
         %x, %y, %z = iree_tensor_ext.dispatch.workgroup_count_from_dag_root(%arg1, %arg2)
         hal.return %x, %y, %z : index, index, index
     }
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transpose_pipeline_test.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transpose_pipeline_test.mlir
index 5a0f2ba..d88dcd8 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transpose_pipeline_test.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transpose_pipeline_test.mlir
@@ -8,7 +8,8 @@
 #executable_target_cuda_nvptx_fb = #hal.executable.target<"cuda", "cuda-nvptx-fb">
 hal.executable @transpose_dispatch_0 {
   hal.executable.variant public @cuda_nvptx_fb target(#executable_target_cuda_nvptx_fb) {
-    hal.executable.export public @transpose_dispatch_0_generic_4096x4096 ordinal(0) layout(#pipeline_layout) count(%arg0: !hal.device, %arg1: index, %arg2: index) -> (index, index, index) {
+    hal.executable.export public @transpose_dispatch_0_generic_4096x4096 ordinal(0) layout(#pipeline_layout)
+    count(%arg0: !hal.device, %arg1: index, %arg2: index) -> (index, index, index) {
       %x, %y, %z = iree_tensor_ext.dispatch.workgroup_count_from_dag_root(%arg1, %arg2)
       hal.return %x, %y, %z : index, index, index
     }
@@ -30,28 +31,19 @@
   }
 }
 
-// CHECK-LABEL:  hal.executable public @transpose_dispatch_0
-//   CHECK-DAG:  %[[PV:.*]] = ub.poison : f32
-//   CHECK-DAG:  %[[C0:.*]] = arith.constant 0 : index
-//   CHECK-DAG:  %[[TX:.*]] = gpu.thread_id  x
-//   CHECK-DAG:  %[[TY:.*]] = gpu.thread_id  y
-//   CHECK-DAG:  %[[ALLOC:.*]] = memref.alloc() : memref<32x33xf32, #gpu.address_space<workgroup>>
-//       CHECK:  %[[D0_BINDING:.*]] = hal.interface.binding.subspan layout({{.+}}) binding(0) alignment(64) offset(%[[C0]]) : memref<4096x4096xf32, #hal.descriptor_type<storage_buffer>>
-//       CHECK:  %[[D0:.+]] = memref.assume_alignment %[[D0_BINDING]], 64 : memref<4096x4096xf32, #hal.descriptor_type<storage_buffer>>
-//       CHECK:  %[[D1_BINDING:.*]] = hal.interface.binding.subspan layout({{.+}}) binding(1) alignment(64) offset(%[[C0]]) : memref<4096x4096xf32, #hal.descriptor_type<storage_buffer>>
-//       CHECK:  %[[D1:.+]] = memref.assume_alignment %[[D1_BINDING]], 64 : memref<4096x4096xf32, #hal.descriptor_type<storage_buffer>>
-//       CHECK:  gpu.barrier
-//       CHECK:  %[[D2:.*]] = affine.apply #{{.*}}()[%{{.+}}, %[[TY]]]
-//       CHECK:  %[[D3:.*]] = affine.apply #{{.*}}()[%{{.+}}, %[[TX]]]
-//       CHECK:  %[[D4:.*]] = vector.transfer_read %[[D0]][%[[D2]], %[[D3]]], %[[PV]] {in_bounds = [true, true]} : memref<4096x4096xf32, #hal.descriptor_type<storage_buffer>>, vector<1x4xf32>
-//       CHECK:  %[[D5:.*]] = affine.apply #{{.*}}()[%[[TX]]]
-//       CHECK:  vector.transfer_write %[[D4]], %[[ALLOC]][%[[TY]], %[[D5]]] {in_bounds = [true, true]} : vector<1x4xf32>, memref<32x33xf32, #gpu.address_space<workgroup>>
-//       CHECK:  gpu.barrier
-//       CHECK:  %[[D6:.*]] = vector.transfer_read %[[ALLOC]][%[[D5]], %[[TY]]], %[[PV]] {in_bounds = [true, true]} : memref<32x33xf32, #gpu.address_space<workgroup>>, vector<4x1xf32>
-//       CHECK:  %[[D7:.*]] = vector.shape_cast %[[D6]] : vector<4x1xf32> to vector<4xf32>
-//       CHECK:  %[[D8:.*]] = affine.apply #{{.*}}()[%{{.+}}, %[[TY]]]
-//       CHECK:  %[[D9:.*]] = affine.apply #{{.*}}()[%{{.+}}, %[[TX]]]
-//       CHECK:  vector.transfer_write %[[D7]], %[[D1]][%[[D8]], %[[D9]]] {in_bounds = [true]} : vector<4xf32>, memref<4096x4096xf32, #hal.descriptor_type<storage_buffer>>
+// CHECK-LABEL:   func @transpose_dispatch_0
+//       CHECK:   %[[A:.*]] = memref.alloc() : memref<32x34xf32, #gpu.address_space<workgroup>>
+//   CHECK-DAG:   %[[B0:.*]] = hal.interface.binding.subspan layout({{.+}}) binding(0)
+//   CHECK-DAG:   %[[A0:.*]] = memref.assume_alignment %[[B0]]
+//   CHECK-DAG:   %[[B1:.*]] = hal.interface.binding.subspan layout({{.+}}) binding(1)
+//   CHECK-DAG:   %[[A1:.*]] = memref.assume_alignment %[[B1]]
+//       CHECK:   gpu.barrier
+//       CHECK:   %[[GR0:.*]] = vector.transfer_read %[[A0]]{{.*}} vector<4xf32>
+//       CHECK:   vector.transfer_write %[[GR0]], %[[A]]{{.*}} : vector<4xf32>
+//       CHECK:   gpu.barrier
+//       CHECK:   %[[SR:.*]] = vector.transfer_read %[[A]]{{.*}} vector<4x1xf32>
+//       CHECK:   %[[SC:.*]] = vector.shape_cast %[[SR]] : vector<4x1xf32> to vector<4xf32>
+//       CHECK:   vector.transfer_write %[[SC]], %[[A1]]{{.*}} : vector<4xf32>
 
 // -----
 
@@ -63,7 +55,8 @@
 #executable_target_cuda_nvptx_fb = #hal.executable.target<"cuda", "cuda-nvptx-fb">
 hal.executable @transpose_single_operand_dispatch_0_generic_768x2048 {
   hal.executable.variant public @cuda_nvptx_fb target(#executable_target_cuda_nvptx_fb) {
-    hal.executable.export public @transpose_single_operand_dispatch_0_generic_768x2048 ordinal(0) layout(#pipeline_layout) count(%arg0: !hal.device, %arg1: index, %arg2: index) -> (index, index, index) {
+    hal.executable.export public @transpose_single_operand_dispatch_0_generic_768x2048 ordinal(0) layout(#pipeline_layout)
+    count(%arg0: !hal.device, %arg1: index, %arg2: index) -> (index, index, index) {
       %x, %y, %z = iree_tensor_ext.dispatch.workgroup_count_from_dag_root(%arg1, %arg2)
       hal.return %x, %y, %z : index, index, index
     }
@@ -88,34 +81,23 @@
   }
 }
 
-// CHECK-LABEL:  hal.executable public @transpose_single_operand_dispatch_0_generic_768x2048
-//       CHECK:  %[[PV:.*]] = ub.poison : f32
-//       CHECK:  %[[C0:.*]] = arith.constant 0 : index
-//       CHECK:  %[[TX:.*]] = gpu.thread_id  x
-//       CHECK:  %[[TY:.*]] = gpu.thread_id  y
-//       CHECK:  %[[ALLOC:.*]] = memref.alloc() : memref<32x33xf32, #gpu.address_space<workgroup>>
-//       CHECK:  %[[D0_BINDING:.*]] = hal.interface.binding.subspan layout({{.+}}) binding(0) alignment(64) offset(%[[C0]]) : memref<2048x768xf32, #hal.descriptor_type<storage_buffer>>
-//       CHECK:  %[[D0:.+]] = memref.assume_alignment %[[D0_BINDING]], 64 : memref<2048x768xf32, #hal.descriptor_type<storage_buffer>>
-//       CHECK:  %[[D1_BINDING:.*]] = hal.interface.binding.subspan layout({{.+}}) binding(1) alignment(64) offset(%[[C0]]) : memref<768x2048xf32, #hal.descriptor_type<storage_buffer>>
-//       CHECK:  %[[D1:.+]] = memref.assume_alignment %[[D1_BINDING]], 64 : memref<768x2048xf32, #hal.descriptor_type<storage_buffer>>
-//       CHECK:  %[[D2_BINDING:.*]] = hal.interface.binding.subspan layout({{.+}}) binding(2) alignment(64) offset(%[[C0]]) : memref<768x2048xf32, #hal.descriptor_type<storage_buffer>>
-//       CHECK:  %[[D2:.+]] = memref.assume_alignment %[[D2_BINDING]], 64 : memref<768x2048xf32, #hal.descriptor_type<storage_buffer>>
-//       CHECK:  gpu.barrier
-//       CHECK:  %[[D3:.*]] = affine.apply #{{.*}}()[%[[TX]]]
-//       CHECK:  %[[D4:.*]] = affine.apply #{{.*}}()[%{{.*}}, %[[TY]]]
-//       CHECK:  %[[D5:.*]] = affine.apply #{{.*}}()[%{{.*}}, %[[TX]]]
-//       CHECK:  %[[D6:.*]] = vector.transfer_read %[[D0]][%[[D4]], %[[D5]]], %[[PV]] {in_bounds = [true, true]} : memref<2048x768xf32, #hal.descriptor_type<storage_buffer>>, vector<1x4xf32>
-//       CHECK:  vector.transfer_write %[[D6]], %[[ALLOC]][%[[TY]], %[[D3]]] {in_bounds = [true, true]} : vector<1x4xf32>, memref<32x33xf32, #gpu.address_space<workgroup>>
-//       CHECK:  gpu.barrier
-//       CHECK:  %[[D7:.*]] = vector.transfer_read %[[ALLOC]][%[[D3]], %[[TY]]], %[[PV]] {in_bounds = [true, true]} : memref<32x33xf32, #gpu.address_space<workgroup>>, vector<4x1xf32>
-//       CHECK:  %[[D8:.*]] = arith.addi %[[TY]], %{{.*}}
-//       CHECK:  %[[D9:.*]] = arith.addi %[[D3]], %{{.*}}
-//       CHECK:  %[[D10:.*]] = vector.transfer_read %[[D1]][%[[D8]], %[[D9]]], %[[PV]] {in_bounds = [true]} : memref<768x2048xf32, #hal.descriptor_type<storage_buffer>>, vector<4xf32>
-//       CHECK:  %[[D11:.*]] = vector.shape_cast %[[D7]] : vector<4x1xf32> to vector<4xf32>
-//       CHECK:  %[[D12:.*]] = arith.addf %[[D11]], %[[D10]] : vector<4xf32>
-//       CHECK:  %[[D13:.*]] = affine.apply #{{.*}}()[%{{.*}}, %[[TY]]]
-//       CHECK:  %[[D14:.*]] = affine.apply #{{.*}}()[%{{.*}}, %[[TX]]]
-//       CHECK:  vector.transfer_write %[[D12]], %[[D2]][%[[D13]], %[[D14]]] {in_bounds = [true]} : vector<4xf32>, memref<768x2048xf32, #hal.descriptor_type<storage_buffer>>
+// CHECK-LABEL:   func @transpose_single_operand_dispatch_0_generic_768x2048
+//       CHECK:   %[[A:.*]] = memref.alloc() : memref<32x34xf32, #gpu.address_space<workgroup>>
+//   CHECK-DAG:   %[[B0:.*]] = hal.interface.binding.subspan layout({{.+}}) binding(0)
+//   CHECK-DAG:   %[[A0:.*]] = memref.assume_alignment %[[B0]]
+//   CHECK-DAG:   %[[B1:.*]] = hal.interface.binding.subspan layout({{.+}}) binding(1)
+//   CHECK-DAG:   %[[A1:.*]] = memref.assume_alignment %[[B1]]
+//   CHECK-DAG:   %[[B2:.*]] = hal.interface.binding.subspan layout({{.+}}) binding(2)
+//   CHECK-DAG:   %[[A2:.*]] = memref.assume_alignment %[[B2]]
+//       CHECK:   gpu.barrier
+//       CHECK:   %[[GR0:.*]] = vector.transfer_read %[[A0]]{{.*}} vector<4xf32>
+//       CHECK:   vector.transfer_write %[[GR0]], %[[A]]{{.*}} : vector<4xf32>
+//       CHECK:   gpu.barrier
+//       CHECK:   %[[SR:.*]] = vector.transfer_read %[[A]]{{.*}} vector<4x1xf32>
+//       CHECK:   %[[GR1:.*]] = vector.transfer_read %[[A1]]{{.*}} vector<4xf32>
+//       CHECK:   %[[SC:.*]] = vector.shape_cast %[[SR]] : vector<4x1xf32> to vector<4xf32>
+//       CHECK:   %[[ADD:.*]] = arith.addf %[[SC]], %[[GR1]] : vector<4xf32>
+//       CHECK:   vector.transfer_write %[[ADD]], %[[A2]]{{.*}} : vector<4xf32>
 
 // -----
 
@@ -127,7 +109,8 @@
 #executable_target_cuda_nvptx_fb = #hal.executable.target<"cuda", "cuda-nvptx-fb">
 hal.executable @transpose_3d_no_dispatch_0_generic_768x2048x1024 {
   hal.executable.variant public @cuda_nvptx_fb target(#executable_target_cuda_nvptx_fb) {
-    hal.executable.export public @transpose_3d_no_dispatch_0_generic_768x2048x1024 ordinal(0) layout(#pipeline_layout) count(%arg0: !hal.device, %arg1: index, %arg2: index, %arg3: index) -> (index, index, index) {
+    hal.executable.export public @transpose_3d_no_dispatch_0_generic_768x2048x1024 ordinal(0) layout(#pipeline_layout)
+    count(%arg0: !hal.device, %arg1: index, %arg2: index, %arg3: index) -> (index, index, index) {
       %x, %y, %z = iree_tensor_ext.dispatch.workgroup_count_from_dag_root(%arg1, %arg2, %arg3)
       hal.return %x, %y, %z : index, index, index
     }
@@ -167,7 +150,8 @@
 #executable_target_cuda_nvptx_fb = #hal.executable.target<"cuda", "cuda-nvptx-fb">
 hal.executable @transpose_3d_yes_dispatch_0_generic_10x768x2048 {
   hal.executable.variant public @cuda_nvptx_fb target(#executable_target_cuda_nvptx_fb) {
-    hal.executable.export public @transpose_3d_yes_dispatch_0_generic_10x768x2048 ordinal(0) layout(#pipeline_layout) count(%arg0: !hal.device, %arg1: index, %arg2: index, %arg3: index) -> (index, index, index) {
+    hal.executable.export public @transpose_3d_yes_dispatch_0_generic_10x768x2048 ordinal(0) layout(#pipeline_layout)
+    count(%arg0: !hal.device, %arg1: index, %arg2: index, %arg3: index) -> (index, index, index) {
       %x, %y, %z = iree_tensor_ext.dispatch.workgroup_count_from_dag_root(%arg1, %arg2, %arg3)
       hal.return %x, %y, %z : index, index, index
     }
@@ -192,34 +176,23 @@
   }
 }
 
-// CHECK-LABEL:   hal.executable public @transpose_3d_yes_dispatch_0_generic_10x768x2048 {
-//   CHECK-DAG:   %[[PV:.*]] = ub.poison : f32
-//   CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
-//       CHECK:   %[[TX:.*]] = gpu.thread_id  x
-//       CHECK:   %[[TY:.*]] = gpu.thread_id  y
-//       CHECK:   %[[ALLOC:.*]] = memref.alloc() : memref<1x32x33xf32, #gpu.address_space<workgroup>>
-//       CHECK:   %[[D0_BINDING:.*]] = hal.interface.binding.subspan layout({{.+}}) binding(0) alignment(64) offset(%[[C0]]) : memref<10x2048x768xf32, #hal.descriptor_type<storage_buffer>>
-//       CHECK:   %[[D0:.+]] = memref.assume_alignment %[[D0_BINDING]], 64 : memref<10x2048x768xf32, #hal.descriptor_type<storage_buffer>>
-//       CHECK:   %[[D1_BINDING:.*]] = hal.interface.binding.subspan layout({{.+}}) binding(1) alignment(64) offset(%[[C0]]) : memref<10x768x2048xf32, #hal.descriptor_type<storage_buffer>>
-//       CHECK:   %[[D1:.+]] = memref.assume_alignment %[[D1_BINDING]], 64 : memref<10x768x2048xf32, #hal.descriptor_type<storage_buffer>>
-//       CHECK:   %[[D2_BINDING:.*]] = hal.interface.binding.subspan layout({{.+}}) binding(2) alignment(64) offset(%[[C0]]) : memref<10x768x2048xf32, #hal.descriptor_type<storage_buffer>>
-//       CHECK:   %[[D2:.+]] = memref.assume_alignment %[[D2_BINDING]], 64 : memref<10x768x2048xf32, #hal.descriptor_type<storage_buffer>>
+// CHECK-LABEL:   func @transpose_3d_yes_dispatch_0_generic_10x768x2048
+//       CHECK:   %[[A:.*]] = memref.alloc() : memref<1x32x34xf32, #gpu.address_space<workgroup>>
+//   CHECK-DAG:   %[[B0:.*]] = hal.interface.binding.subspan layout({{.+}}) binding(0)
+//   CHECK-DAG:   %[[A0:.*]] = memref.assume_alignment %[[B0]]
+//   CHECK-DAG:   %[[B1:.*]] = hal.interface.binding.subspan layout({{.+}}) binding(1)
+//   CHECK-DAG:   %[[A1:.*]] = memref.assume_alignment %[[B1]]
+//   CHECK-DAG:   %[[B2:.*]] = hal.interface.binding.subspan layout({{.+}}) binding(2)
+//   CHECK-DAG:   %[[A2:.*]] = memref.assume_alignment %[[B2]]
 //       CHECK:   gpu.barrier
-//       CHECK:   %[[D3:.*]] = affine.apply #{{.*}}()[%[[TX]]]
-//       CHECK:   %[[D4:.*]] = affine.apply #{{.*}}()[%{{.*}}, %[[TY]]]
-//       CHECK:   %[[D5:.*]] = affine.apply #{{.*}}()[%{{.*}}, %[[TX]]]
-//       CHECK:   %[[D6:.*]] = vector.transfer_read %[[D0]][%{{.*}}, %[[D4]], %[[D5]]], %[[PV]] {in_bounds = [true, true, true]} : memref<10x2048x768xf32, #hal.descriptor_type<storage_buffer>>, vector<1x1x4xf32>
-//       CHECK:   vector.transfer_write %[[D6]], %[[ALLOC]][%[[C0]], %[[TY]], %[[D3]]] {in_bounds = [true, true, true]} : vector<1x1x4xf32>, memref<1x32x33xf32, #gpu.address_space<workgroup>>
+//       CHECK:   %[[GR0:.*]] = vector.transfer_read %[[A0]]{{.*}} vector<4xf32>
+//       CHECK:   vector.transfer_write %[[GR0]], %[[A]]{{.*}} : vector<4xf32>
 //       CHECK:   gpu.barrier
-//       CHECK:   %[[D7:.*]] = vector.transfer_read %[[ALLOC]][%[[C0]], %[[D3]], %[[TY]]], %[[PV]] {in_bounds = [true, true]} : memref<1x32x33xf32, #gpu.address_space<workgroup>>, vector<4x1xf32>
-//       CHECK:   %[[D8:.*]] = arith.addi %[[TY]], %{{.*}}
-//       CHECK:   %[[D9:.*]] = arith.addi %[[D3]], %{{.*}}
-//       CHECK:   %[[D10:.*]] = vector.transfer_read %[[D1]][%{{.*}}, %[[D8]], %[[D9]]], %[[PV]] {in_bounds = [true]} : memref<10x768x2048xf32, #hal.descriptor_type<storage_buffer>>, vector<4xf32>
-//       CHECK:   %[[D11:.*]] = vector.shape_cast %[[D7]] : vector<4x1xf32> to vector<4xf32>
-//       CHECK:   %[[D12:.*]] = arith.addf %[[D11]], %[[D10]] : vector<4xf32>
-//       CHECK:   %[[D13:.*]] = affine.apply #{{.*}}()[%{{.*}}, %[[TY]]]
-//       CHECK:   %[[D14:.*]] = affine.apply #{{.*}}()[%{{.*}}, %[[TX]]]
-//       CHECK:   vector.transfer_write %[[D12]], %[[D2]][%{{.*}}, %[[D13]], %[[D14]]] {in_bounds = [true]} : vector<4xf32>, memref<10x768x2048xf32, #hal.descriptor_type<storage_buffer>>
+//       CHECK:   %[[SR:.*]] = vector.transfer_read %[[A]]{{.*}} vector<4x1xf32>
+//       CHECK:   %[[GR1:.*]] = vector.transfer_read %[[A1]]{{.*}} vector<4xf32>
+//       CHECK:   %[[SC:.*]] = vector.shape_cast %[[SR]] : vector<4x1xf32> to vector<4xf32>
+//       CHECK:   %[[ADD:.*]] = arith.addf %[[SC]], %[[GR1]] : vector<4xf32>
+//       CHECK:   vector.transfer_write %[[ADD]], %[[A2]]{{.*}} : vector<4xf32>
 
 // -----
 
@@ -231,7 +204,8 @@
 #executable_target_cuda_nvptx_fb = #hal.executable.target<"cuda", "cuda-nvptx-fb">
 hal.executable @transpose_3d_trans_out_dispatch_0_generic_10x2048x768 {
   hal.executable.variant public @cuda_nvptx_fb target(#executable_target_cuda_nvptx_fb) {
-    hal.executable.export public @transpose_3d_trans_out_dispatch_0_generic_10x2048x768 ordinal(0) layout(#pipeline_layout) count(%arg0: !hal.device, %arg1: index, %arg2: index, %arg3: index) -> (index, index, index) {
+    hal.executable.export public @transpose_3d_trans_out_dispatch_0_generic_10x2048x768 ordinal(0) layout(#pipeline_layout)
+    count(%arg0: !hal.device, %arg1: index, %arg2: index, %arg3: index) -> (index, index, index) {
       %x, %y, %z = iree_tensor_ext.dispatch.workgroup_count_from_dag_root(%arg1, %arg2, %arg3)
       hal.return %x, %y, %z : index, index, index
     }
@@ -256,35 +230,26 @@
   }
 }
 
-// CHECK-LABEL:   hal.executable public @transpose_3d_trans_out_dispatch_0_generic_10x2048x768 {
-//   CHECK-DAG:   %[[PV:.*]] = ub.poison
-//   CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
-//       CHECK:   %[[TX:.*]] = gpu.thread_id  x
-//       CHECK:   %[[TY:.*]] = gpu.thread_id  y
-//       CHECK:   %[[ALLOC:.*]] = memref.alloc() : memref<1x32x33xf32, #gpu.address_space<workgroup>>
-//       CHECK:   %[[ALLOC1:.*]] = memref.alloc() : memref<1x32x33xf32, #gpu.address_space<workgroup>>
-//       CHECK:   %[[D0_BINDING:.*]] = hal.interface.binding.subspan layout({{.+}}) binding(0) alignment(64) offset(%[[C0]]) : memref<10x768x2048xf32, #hal.descriptor_type<storage_buffer>>
-//       CHECK:   %[[D0:.+]] = memref.assume_alignment %[[D0_BINDING]], 64 : memref<10x768x2048xf32, #hal.descriptor_type<storage_buffer>>
-//       CHECK:   %[[D1_BINDING:.*]] = hal.interface.binding.subspan layout({{.+}}) binding(1) alignment(64) offset(%[[C0]]) : memref<10x768x2048xf32, #hal.descriptor_type<storage_buffer>>
-//       CHECK:   %[[D1:.+]] = memref.assume_alignment %[[D1_BINDING]], 64 : memref<10x768x2048xf32, #hal.descriptor_type<storage_buffer>>
-//       CHECK:   %[[D2_BINDING:.*]] = hal.interface.binding.subspan layout({{.+}}) binding(2) alignment(64) offset(%[[C0]]) : memref<10x2048x768xf32, #hal.descriptor_type<storage_buffer>>
-//       CHECK:   %[[D2:.+]] = memref.assume_alignment %[[D2_BINDING]], 64 : memref<10x2048x768xf32, #hal.descriptor_type<storage_buffer>>
+// CHECK-LABEL:   func @transpose_3d_trans_out_dispatch_0_generic_10x2048x768
+//       CHECK:   %[[A0:.*]] = memref.alloc() : memref<1x32x34xf32, #gpu.address_space<workgroup>>
+//       CHECK:   %[[A1:.*]] = memref.alloc() : memref<1x32x34xf32, #gpu.address_space<workgroup>>
+//   CHECK-DAG:   %[[B0:.*]] = hal.interface.binding.subspan layout({{.+}}) binding(0)
+//   CHECK-DAG:   %[[BA0:.*]] = memref.assume_alignment %[[B0]]
+//   CHECK-DAG:   %[[B1:.*]] = hal.interface.binding.subspan layout({{.+}}) binding(1)
+//   CHECK-DAG:   %[[BA1:.*]] = memref.assume_alignment %[[B1]]
+//   CHECK-DAG:   %[[B2:.*]] = hal.interface.binding.subspan layout({{.+}}) binding(2)
+//   CHECK-DAG:   %[[BA2:.*]] = memref.assume_alignment %[[B2]]
 //       CHECK:   gpu.barrier
-//       CHECK:   %[[D3:.*]] = affine.apply #{{.*}}()[%{{.*}}, %[[TY]]]
-//       CHECK:   %[[D4:.*]] = affine.apply #{{.*}}()[%{{.*}}, %[[TX]]]
-//       CHECK:   %[[D5:.*]] = vector.transfer_read %[[D0]][%{{.*}}, %[[D3]], %[[D4]]], %[[PV]] {in_bounds = [true, true, true]} : memref<10x768x2048xf32, #hal.descriptor_type<storage_buffer>>, vector<1x1x4xf32>
-//       CHECK:   %[[D6:.*]] = affine.apply #{{.*}}()[%[[TX]]]
-//       CHECK:   vector.transfer_write %[[D5]], %[[ALLOC1]][%[[C0]], %[[TY]], %[[D6]]] {in_bounds = [true, true, true]} : vector<1x1x4xf32>, memref<1x32x33xf32, #gpu.address_space<workgroup>>
-//       CHECK:   %[[D7:.*]] = vector.transfer_read %[[D1]][%{{.*}}, %[[D3]], %[[D4]]], %[[PV]] {in_bounds = [true, true, true]} : memref<10x768x2048xf32, #hal.descriptor_type<storage_buffer>>, vector<1x1x4xf32>
-//       CHECK:   vector.transfer_write %[[D7]], %[[ALLOC]][%[[C0]], %[[TY]], %[[D6]]] {in_bounds = [true, true, true]} : vector<1x1x4xf32>, memref<1x32x33xf32, #gpu.address_space<workgroup>>
+//       CHECK:   %[[GR0:.*]] = vector.transfer_read %[[BA0]]{{.*}} vector<4xf32>
+//       CHECK:   vector.transfer_write %[[GR0]], %[[A0]]{{.*}} : vector<4xf32>
+//       CHECK:   %[[GR1:.*]] = vector.transfer_read %[[BA1]]{{.*}} vector<4xf32>
+//       CHECK:   vector.transfer_write %[[GR1]], %[[A1]]{{.*}} : vector<4xf32>
 //       CHECK:   gpu.barrier
-//       CHECK:   %[[D8:.*]] = vector.transfer_read %[[ALLOC1]][%[[C0]], %[[D6]], %[[TY]]], %[[PV]] {in_bounds = [true, true]} : memref<1x32x33xf32, #gpu.address_space<workgroup>>, vector<4x1xf32>
-//       CHECK:   %[[D9:.*]] = vector.transfer_read %[[ALLOC]][%[[C0]], %[[D6]], %[[TY]]], %[[PV]] {in_bounds = [true, true]} : memref<1x32x33xf32, #gpu.address_space<workgroup>>, vector<4x1xf32>
-//       CHECK:   %[[D10:.*]] = arith.addf %[[D8]], %[[D9]] : vector<4x1xf32>
-//       CHECK:   %[[D11:.*]] = vector.shape_cast %[[D10]] : vector<4x1xf32> to vector<4xf32>
-//       CHECK:   %[[D12:.*]] = affine.apply #{{.*}}()[%{{.*}}, %[[TY]]]
-//       CHECK:   %[[D13:.*]] = affine.apply #{{.*}}()[%{{.*}}, %[[TX]]]
-//       CHECK:   vector.transfer_write %[[D11]], %[[D2]][%{{.*}}, %[[D12]], %[[D13]]] {in_bounds = [true]} : vector<4xf32>, memref<10x2048x768xf32, #hal.descriptor_type<storage_buffer>>
+//       CHECK:   %[[SR0:.*]] = vector.transfer_read %[[A0]]{{.*}} vector<4x1xf32>
+//       CHECK:   %[[SR1:.*]] = vector.transfer_read %[[A1]]{{.*}} vector<4x1xf32>
+//       CHECK:   %[[ADD:.*]] = arith.addf %[[SR0]], %[[SR1]] : vector<4x1xf32>
+//       CHECK:   %[[SC:.*]] = vector.shape_cast %[[ADD]] : vector<4x1xf32> to vector<4xf32>
+//       CHECK:   vector.transfer_write %[[SC]], %[[BA2]]{{.*}} : vector<4xf32>
 
 // -----
 
@@ -296,7 +261,8 @@
 #executable_target_cuda_nvptx_fb = #hal.executable.target<"cuda", "cuda-nvptx-fb">
 hal.executable @transpose_3d_diff_dispatch_0_generic_10x768x2048 {
   hal.executable.variant public @cuda_nvptx_fb target(#executable_target_cuda_nvptx_fb) {
-    hal.executable.export public @transpose_3d_diff_dispatch_0_generic_10x768x2048 ordinal(0) layout(#pipeline_layout) count(%arg0: !hal.device, %arg1: index, %arg2: index, %arg3: index) -> (index, index, index) {
+    hal.executable.export public @transpose_3d_diff_dispatch_0_generic_10x768x2048 ordinal(0) layout(#pipeline_layout)
+    count(%arg0: !hal.device, %arg1: index, %arg2: index, %arg3: index) -> (index, index, index) {
       %x, %y, %z = iree_tensor_ext.dispatch.workgroup_count_from_dag_root(%arg1, %arg2, %arg3)
       hal.return %x, %y, %z : index, index, index
     }