Re-land "[Dispatch] extend CollapseDimensionsPass to more…" (#19461)

Re-lands https://github.com/iree-org/iree/pull/19326 after adding a few
more matmul specs that were missed. I left `@match_mmt_8192x640x2560`
commented out because the config was causing a slight regression in perf
but the overall time spent was very low.

This reverts commit 7177c29.

---------

Signed-off-by: Ian Wood <ianwood2024@u.northwestern.edu>
diff --git a/.github/workflows/pkgci_regression_test.yml b/.github/workflows/pkgci_regression_test.yml
index 0491d2b..649f7e6 100644
--- a/.github/workflows/pkgci_regression_test.yml
+++ b/.github/workflows/pkgci_regression_test.yml
@@ -220,7 +220,7 @@
             --goldentime-rocm-unet-ms 419.0 \
             --goldentime-rocm-clip-ms 18.5 \
             --goldentime-rocm-vae-ms 337.0 \
-            --goldendispatch-rocm-unet 1602 \
+            --goldendispatch-rocm-unet 1598 \
             --goldendispatch-rocm-clip 1139 \
             --goldendispatch-rocm-vae 246 \
             --goldensize-rocm-unet-bytes 2280000  \
@@ -242,17 +242,17 @@
             --goldentime-rocm-unet-ms 80.0 \
             --goldentime-rocm-clip-ms 15.5 \
             --goldentime-rocm-vae-ms 80.0 \
-            --goldendispatch-rocm-unet 1602 \
+            --goldendispatch-rocm-unet 1598 \
             --goldendispatch-rocm-clip 1139 \
             --goldendispatch-rocm-vae 246 \
             --goldensize-rocm-unet-bytes 2270000 \
             --goldensize-rocm-clip-bytes 860000  \
             --goldensize-rocm-vae-bytes 840000 \
-            --goldentime-rocm-punet-int8-fp16-ms 53 \
-            --goldendispatch-rocm-punet-int8-fp16 1424 \
+            --goldentime-rocm-punet-int8-fp16-ms 51 \
+            --goldendispatch-rocm-punet-int8-fp16 1416 \
             --goldensize-rocm-punet-int8-fp16-bytes 2560000 \
-            --goldentime-rocm-punet-int8-fp8-ms 53 \
-            --goldendispatch-rocm-punet-int8-fp8 1704 \
+            --goldentime-rocm-punet-int8-fp8-ms 51 \
+            --goldendispatch-rocm-punet-int8-fp8 1696 \
             --goldensize-rocm-punet-int8-fp8-bytes 2800000 \
             --rocm-chip gfx942 \
             --log-cli-level=info \
diff --git a/build_tools/pkgci/external_test_suite/attention_and_matmul_spec_punet.mlir b/build_tools/pkgci/external_test_suite/attention_and_matmul_spec_punet.mlir
index 16049fa..969f31a 100644
--- a/build_tools/pkgci/external_test_suite/attention_and_matmul_spec_punet.mlir
+++ b/build_tools/pkgci/external_test_suite/attention_and_matmul_spec_punet.mlir
@@ -101,6 +101,142 @@
 // Matmul tuning
 //===----------------------------------------------------------------------===//
 
+transform.named_sequence @match_mmt_i8_i8_i32(%root: !transform.any_op {transform.readonly}) -> (!transform.any_op) {
+  transform.match.operation_name %root ["linalg.generic"] : !transform.any_op
+  // transform.print %root {name = "Generic"} : !transform.any_op
+  %ins, %outs = transform.iree.match.cast_compatible_dag_from_root %root {
+    ^bb0(%lhs: tensor<?x?xi8>, %rhs: tensor<?x?xi8>, %out: tensor<?x?xi32>):
+    %7 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
+                                          affine_map<(d0, d1, d2) -> (d1, d2)>,
+                                          affine_map<(d0, d1, d2) -> (d0, d1)>],
+                         iterator_types = ["parallel", "parallel", "reduction"]}
+        ins(%lhs, %rhs : tensor<?x?xi8>, tensor<?x?xi8>) outs(%out : tensor<?x?xi32>) {
+      ^bb0(%in: i8, %in_0: i8, %acc: i32):
+        %18 = arith.extsi %in : i8 to i32
+        %19 = arith.extsi %in_0 : i8 to i32
+        %20 = arith.muli %18, %19 : i32
+        %21 = arith.addi %acc, %20 : i32
+        linalg.yield %21 : i32
+      } -> tensor<?x?xi32>
+  } : (!transform.any_op) -> (!transform.any_value, !transform.any_value)
+  transform.yield %root : !transform.any_op
+}
+
+transform.named_sequence @match_mmt_2048x10240x1280(%matmul: !transform.any_op {transform.readonly}) -> (!transform.any_op, !transform.any_param) {
+  %mmt = transform.include @match_mmt_i8_i8_i32 failures(propagate) (%matmul) : (!transform.any_op) -> !transform.any_op
+  %lhs = transform.get_operand %matmul[0] : (!transform.any_op) -> !transform.any_value
+  %rhs = transform.get_operand %matmul[1] : (!transform.any_op) -> !transform.any_value
+  transform.iree.match.cast_compatible_type %lhs = tensor<2048x1280xi8> : !transform.any_value
+  transform.iree.match.cast_compatible_type %rhs = tensor<10240x1280xi8> : !transform.any_value
+  %config = transform.param.constant #iree_codegen.compilation_info<
+    lowering_config = #iree_gpu.lowering_config<{promote_operands = [0, 1],
+                                                 mma_kind = #iree_gpu.mma_layout<MFMA_I32_16x16x32_I8>,
+                                                 subgroup_m_count = 4, subgroup_n_count = 2,
+                                                 reduction = [0, 0, 128],
+                                                 workgroup = [128, 320, 0]}>,
+    translation_info = #iree_codegen.translation_info<pipeline = LLVMGPUVectorDistribute
+      workgroup_size = [128, 4, 1] subgroup_size = 64,
+      {gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = true>
+      }>> -> !transform.any_param
+  transform.yield %matmul, %config : !transform.any_op, !transform.any_param
+}
+
+transform.named_sequence @match_mmt_2048x1280x5120(%matmul: !transform.any_op {transform.readonly}) -> (!transform.any_op, !transform.any_param) {
+  %mmt = transform.include @match_mmt_i8_i8_i32 failures(propagate) (%matmul) : (!transform.any_op) -> !transform.any_op
+  %lhs = transform.get_operand %matmul[0] : (!transform.any_op) -> !transform.any_value
+  %rhs = transform.get_operand %matmul[1] : (!transform.any_op) -> !transform.any_value
+  transform.iree.match.cast_compatible_type %lhs = tensor<2048x5120xi8> : !transform.any_value
+  transform.iree.match.cast_compatible_type %rhs = tensor<1280x5120xi8> : !transform.any_value
+  %config = transform.param.constant #iree_codegen.compilation_info<
+    lowering_config = #iree_gpu.lowering_config<{promote_operands = [0, 1],
+                                                 mma_kind = #iree_gpu.mma_layout<MFMA_I32_16x16x32_I8>,
+                                                 subgroup_m_count = 4, subgroup_n_count = 1,
+                                                 reduction = [0, 0, 256],
+                                                 workgroup = [128, 80, 0]}>,
+    translation_info = #iree_codegen.translation_info<pipeline = LLVMGPUVectorDistribute
+      workgroup_size = [64, 4, 1] subgroup_size = 64,
+      {gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = true>
+      }>> -> !transform.any_param
+  transform.yield %matmul, %config : !transform.any_op, !transform.any_param
+}
+
+transform.named_sequence @match_mmt_2048x1280x1280(%matmul: !transform.any_op {transform.readonly}) -> (!transform.any_op, !transform.any_param) {
+  %mmt = transform.include @match_mmt_i8_i8_i32 failures(propagate) (%matmul) : (!transform.any_op) -> !transform.any_op
+  %lhs = transform.get_operand %matmul[0] : (!transform.any_op) -> !transform.any_value
+  %rhs = transform.get_operand %matmul[1] : (!transform.any_op) -> !transform.any_value
+  transform.iree.match.cast_compatible_type %lhs = tensor<2048x1280xi8> : !transform.any_value
+  transform.iree.match.cast_compatible_type %rhs = tensor<1280x1280xi8> : !transform.any_value
+  %config = transform.param.constant #iree_codegen.compilation_info<
+    lowering_config = #iree_gpu.lowering_config<{promote_operands = [0, 1],
+                                                 mma_kind = #iree_gpu.mma_layout<MFMA_I32_16x16x32_I8>,
+                                                 subgroup_m_count = 2, subgroup_n_count = 2,
+                                                 reduction = [0, 0, 128],
+                                                 workgroup = [64, 160, 0]}>,
+    translation_info = #iree_codegen.translation_info<pipeline = LLVMGPUVectorDistribute
+      workgroup_size = [256, 1, 1] subgroup_size = 64,
+      {gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = true,
+                                                         reorder_workgroups_strategy = <Transpose>>}>
+  > -> !transform.any_param
+  transform.yield %matmul, %config : !transform.any_op, !transform.any_param
+}
+
+transform.named_sequence @match_mmt_8192x640x640(%matmul: !transform.any_op {transform.readonly}) -> (!transform.any_op, !transform.any_param) {
+  %mmt = transform.include @match_mmt_i8_i8_i32 failures(propagate) (%matmul) : (!transform.any_op) -> !transform.any_op
+  %lhs = transform.get_operand %matmul[0] : (!transform.any_op) -> !transform.any_value
+  %rhs = transform.get_operand %matmul[1] : (!transform.any_op) -> !transform.any_value
+  transform.iree.match.cast_compatible_type %lhs = tensor<8192x640xi8> : !transform.any_value
+  transform.iree.match.cast_compatible_type %rhs = tensor<640x640xi8> : !transform.any_value
+  %config = transform.param.constant #iree_codegen.compilation_info<
+    lowering_config = #iree_gpu.lowering_config<{promote_operands = [0, 1],
+                                                 mma_kind = #iree_gpu.mma_layout<MFMA_I32_16x16x32_I8>,
+                                                 subgroup_m_count = 8, subgroup_n_count = 1,
+                                                 reduction = [0, 0, 64],
+                                                 workgroup = [256, 64, 0]}>,
+    translation_info = #iree_codegen.translation_info<pipeline = LLVMGPUVectorDistribute
+      workgroup_size = [512, 1, 1] subgroup_size = 64,
+      {gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = true>}>
+  > -> !transform.any_param
+  transform.yield %matmul, %config : !transform.any_op, !transform.any_param
+}
+
+transform.named_sequence @match_mmt_8192x5120x640(%matmul: !transform.any_op {transform.readonly}) -> (!transform.any_op, !transform.any_param) {
+  %mmt = transform.include @match_mmt_i8_i8_i32 failures(propagate) (%matmul) : (!transform.any_op) -> !transform.any_op
+  %lhs = transform.get_operand %matmul[0] : (!transform.any_op) -> !transform.any_value
+  %rhs = transform.get_operand %matmul[1] : (!transform.any_op) -> !transform.any_value
+  transform.iree.match.cast_compatible_type %lhs = tensor<8192x640xi8> : !transform.any_value
+  transform.iree.match.cast_compatible_type %rhs = tensor<5120x640xi8> : !transform.any_value
+  %config = transform.param.constant #iree_codegen.compilation_info<
+    lowering_config = #iree_gpu.lowering_config<{promote_operands = [0, 1],
+                                                 mma_kind = #iree_gpu.mma_layout<MFMA_I32_32x32x16_I8>,
+                                                 subgroup_m_count = 2, subgroup_n_count = 4,
+                                                 reduction = [0, 0, 64],
+                                                 workgroup = [256, 128, 0]}>,
+    translation_info = #iree_codegen.translation_info<pipeline = LLVMGPUVectorDistribute
+      workgroup_size = [512, 1, 1] subgroup_size = 64,
+      {gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = true>}>
+  > -> !transform.any_param
+  transform.yield %matmul, %config : !transform.any_op, !transform.any_param
+}
+
+transform.named_sequence @match_mmt_8192x640x2560 (%matmul: !transform.any_op {transform.readonly}) -> (!transform.any_op, !transform.any_param) {
+  %mmt = transform.include @match_mmt_i8_i8_i32 failures(propagate) (%matmul) : (!transform.any_op) -> !transform.any_op
+  %lhs = transform.get_operand %matmul[0] : (!transform.any_op) -> !transform.any_value
+  %rhs = transform.get_operand %matmul[1] : (!transform.any_op) -> !transform.any_value
+  transform.iree.match.cast_compatible_type %lhs = tensor<8192x2560xi8> : !transform.any_value
+  transform.iree.match.cast_compatible_type %rhs = tensor<640x2560xi8> : !transform.any_value
+  %config = transform.param.constant #iree_codegen.compilation_info<
+    lowering_config = #iree_gpu.lowering_config<{promote_operands = [0, 1],
+                                                 mma_kind = #iree_gpu.mma_layout<MFMA_I32_16x16x32_I8>,
+                                                 subgroup_m_count = 8, subgroup_n_count = 1,
+                                                 reduction = [0, 0, 64],
+                                                 workgroup = [256, 64, 0]}>,
+    translation_info = #iree_codegen.translation_info<pipeline = LLVMGPUVectorDistribute
+      workgroup_size = [512, 1, 1] subgroup_size = 64,
+      {gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = true>}>
+  > -> !transform.any_param
+  transform.yield %matmul, %config : !transform.any_op, !transform.any_param
+}
+
 //===----------------------------------------------------------------------===//
 // Convolution tuning
 //===----------------------------------------------------------------------===//
@@ -152,6 +288,65 @@
     transform.yield %generic, %config : !transform.any_op, !transform.any_param
   }
 
+  transform.named_sequence @match_broadcast_rhs_mmt_Bx64x1280x2480(%generic: !transform.any_op {transform.readonly}) -> (!transform.any_op, !transform.any_param) {
+    %mmt = transform.include @match_broadcast_rhs_mmt_i8_i8_i32 failures(propagate) (%generic) : (!transform.any_op) -> !transform.any_op
+    %lhs = transform.get_operand %generic[0] : (!transform.any_op) -> !transform.any_value
+    %rhs = transform.get_operand %generic[1] : (!transform.any_op) -> !transform.any_value
+    transform.iree.match.cast_compatible_type %lhs = tensor<?x64x2480xi8> : !transform.any_value
+    transform.iree.match.cast_compatible_type %rhs = tensor<1280x2480xi8> : !transform.any_value
+    %config = transform.param.constant #iree_codegen.compilation_info<
+      lowering_config = #iree_gpu.lowering_config<{promote_operands = [0, 1],
+                                                   mma_kind = #iree_gpu.mma_layout<MFMA_I32_16x16x32_I8>,
+                                                   subgroup_m_count = 2, subgroup_n_count = 2,
+                                                   reduction = [0, 0, 0, 128],
+                                                   workgroup = [1, 64, 160, 0]}>,
+      translation_info = #iree_codegen.translation_info<pipeline = LLVMGPUVectorDistribute
+        workgroup_size = [256, 1, 1] subgroup_size = 64,
+        {gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = true,
+                                                           reorder_workgroups_strategy = <Transpose>>
+        }>
+    > -> !transform.any_param
+    transform.yield %generic, %config : !transform.any_op, !transform.any_param
+  }
+
+  transform.named_sequence @match_broadcast_rhs_mmt_Bx4960x640x640(%generic: !transform.any_op {transform.readonly}) -> (!transform.any_op, !transform.any_param) {
+    %mmt = transform.include @match_broadcast_rhs_mmt_i8_i8_i32 failures(propagate) (%generic) : (!transform.any_op) -> !transform.any_op
+    %lhs = transform.get_operand %generic[0] : (!transform.any_op) -> !transform.any_value
+    %rhs = transform.get_operand %generic[1] : (!transform.any_op) -> !transform.any_value
+    transform.iree.match.cast_compatible_type %lhs = tensor<?x4960x640xi8> : !transform.any_value
+    transform.iree.match.cast_compatible_type %rhs = tensor<640x640xi8> : !transform.any_value
+    %config = transform.param.constant #iree_codegen.compilation_info<
+      lowering_config = #iree_gpu.lowering_config<{promote_operands = [0, 1],
+                                                   mma_kind = #iree_gpu.mma_layout<MFMA_I32_16x16x32_I8>,
+                                                   subgroup_m_count = 8, subgroup_n_count = 1,
+                                                   reduction = [0, 0, 0, 64],
+                                                   workgroup = [1, 256, 64, 0]}>,
+      translation_info = #iree_codegen.translation_info<pipeline = LLVMGPUVectorDistribute
+        workgroup_size = [512, 1, 1] subgroup_size = 64,
+        {gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = true>}>
+    > -> !transform.any_param
+    transform.yield %generic, %config : !transform.any_op, !transform.any_param
+  }
+
+  transform.named_sequence @match_broadcast_rhs_mmt_Bx64x640x2480(%generic: !transform.any_op {transform.readonly}) -> (!transform.any_op, !transform.any_param) {
+    %mmt = transform.include @match_broadcast_rhs_mmt_i8_i8_i32 failures(propagate) (%generic) : (!transform.any_op) -> !transform.any_op
+    %lhs = transform.get_operand %generic[0] : (!transform.any_op) -> !transform.any_value
+    %rhs = transform.get_operand %generic[1] : (!transform.any_op) -> !transform.any_value
+    transform.iree.match.cast_compatible_type %lhs = tensor<?x64x2480xi8> : !transform.any_value
+    transform.iree.match.cast_compatible_type %rhs = tensor<640x2480xi8> : !transform.any_value
+    %config = transform.param.constant #iree_codegen.compilation_info<
+      lowering_config = #iree_gpu.lowering_config<{promote_operands = [0, 1],
+                                                   mma_kind = #iree_gpu.mma_layout<MFMA_I32_16x16x32_I8>,
+                                                   subgroup_m_count = 2, subgroup_n_count = 1,
+                                                   reduction = [0, 0, 0, 128],
+                                                   workgroup = [1, 32, 320, 0]}>,
+      translation_info = #iree_codegen.translation_info<pipeline = LLVMGPUVectorDistribute
+        workgroup_size = [128, 1, 1] subgroup_size = 64,
+        {gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = true>}>
+    > -> !transform.any_param
+    transform.yield %generic, %config : !transform.any_op, !transform.any_param
+  }
+
   transform.named_sequence @match_broadcast_rhs_mmt_Bx4096x5120x640(%generic: !transform.any_op {transform.readonly}) -> (!transform.any_op, !transform.any_param) {
     %mmt = transform.include @match_broadcast_rhs_mmt_i8_i8_i32 failures(propagate) (%generic) : (!transform.any_op) -> !transform.any_op
     %lhs = transform.get_operand %generic[0] : (!transform.any_op) -> !transform.any_value
@@ -352,6 +547,12 @@
         // TUNING_MATCH_BEGIN DO NOT REMOVE
 
         // Matmul.
+        , @match_mmt_2048x10240x1280 -> @apply_op_config
+        , @match_mmt_2048x1280x5120 -> @apply_op_config
+        , @match_mmt_2048x1280x1280 -> @apply_op_config
+        , @match_mmt_8192x640x640 -> @apply_op_config
+        , @match_mmt_8192x5120x640 -> @apply_op_config
+        //, @match_mmt_8192x640x2560 -> @apply_op_config
 
         // Convolution.
 
@@ -363,6 +564,10 @@
         // Carried over from SPX.
         , @match_broadcast_rhs_mmt_Bx1024x10240x1280 -> @apply_op_config
         , @match_broadcast_rhs_mmt_Bx1024x1280x1280 -> @apply_op_config
+        , @match_broadcast_rhs_mmt_Bx64x1280x2480 -> @apply_op_config
+        , @match_broadcast_rhs_mmt_Bx4960x640x640 -> @apply_op_config
+        //, @match_broadcast_rhs_mmt_Bx64x640x2480 -> @apply_op_config
+
 
         // Contration.
         , @match_matmul_like_Bx20x1024x64x1280_i8xi8xi32 -> @apply_op_config
diff --git a/compiler/src/iree/compiler/DispatchCreation/CollapseDimensions.cpp b/compiler/src/iree/compiler/DispatchCreation/CollapseDimensions.cpp
index ba79578..f1fe77b 100644
--- a/compiler/src/iree/compiler/DispatchCreation/CollapseDimensions.cpp
+++ b/compiler/src/iree/compiler/DispatchCreation/CollapseDimensions.cpp
@@ -12,6 +12,7 @@
 #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
 #include "iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h"
 #include "iree/compiler/DispatchCreation/Passes.h"
+#include "llvm/ADT/DenseSet.h"
 #include "llvm/Support/Debug.h"
 #include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Dialect/Affine/Utils.h"
@@ -121,36 +122,45 @@
            (rDimsSet.count(prePos) && rDimsSet.count(nextPos));
   };
 
-  ReassociationIndices range;
-  AffineExpr preExpr;
   // Find the largest sequence of dimensions that are
   // - Either preserved in all maps, or
   // - are completely absent
   // This sequence can be collapsed. To find the sequence,
-  // 1) Take the result expressions of one of the indexing maps
-  // 2) Find a sequence of 2 that is found in all maps
+  // 1) For each indexing map, take the result expressions
+  // 2) Find a sequence of 2 that is found in all maps (or absent)
   // 3) Then take last element of this sequence and the next
   //    result expression, and check if this sequence of 2 is
   //    found in all maps. If so, add to sequence (to get a sequence of 3)
   //    and repeat till the last element of sequence and the next result
   //    expression is not found as a sequence in all maps.
-  for (auto nextExpr :
-       fusionInterfaceOp.getIndexingMapsArray().front().getResults()) {
-    unsigned position = cast<AffineDimExpr>(nextExpr).getPosition();
-    if (!range.empty()) {
+
+  llvm::SmallSetVector<unsigned, 8> seenLoops;
+  for (auto map : fusionInterfaceOp.getIndexingMapsArray()) {
+    ReassociationIndices range;
+    AffineExpr preExpr;
+
+    auto appendAndClearRange = [&]() {
+      if (range.size() > 1) {
+        contiguousLoops.push_back(range);
+      }
+      range.clear();
+    };
+
+    for (auto nextExpr : map.getResults()) {
+      unsigned position = cast<AffineDimExpr>(nextExpr).getPosition();
+      if (seenLoops.contains(position)) {
+        appendAndClearRange();
+        continue;
+      }
       if (!hasAllMapsSameSequence(preExpr, nextExpr) ||
           !hasSameIteratorType(preExpr, nextExpr)) {
-        if (range.size() > 1) {
-          contiguousLoops.push_back({range.begin(), range.end()});
-        }
-        range.clear();
+        appendAndClearRange();
       }
+      range.push_back(position);
+      seenLoops.insert(position);
+      preExpr = nextExpr;
     }
-    range.push_back(position);
-    preExpr = nextExpr;
-  }
-  if (range.size() > 1) {
-    contiguousLoops.push_back(range);
+    appendAndClearRange();
   }
 
   return contiguousLoops;
@@ -192,21 +202,20 @@
   }
 
   // TODO(#17948) GPU codegen fails when we collapse the dimensions of softmax.
-  if (llvm::any_of(genericOp.getDpsInputOperands(),
-                   [&](OpOperand *operand) -> bool {
-                     auto genericOperand =
-                         operand->get().getDefiningOp<linalg::GenericOp>();
-                     if (!genericOperand) {
-                       return false;
-                     }
+  auto isPossiblySoftmax = [&](OpOperand *operand) -> bool {
+    auto genericOperand = operand->get().getDefiningOp<linalg::GenericOp>();
+    if (!genericOperand) {
+      return false;
+    }
 
-                     if (genericOperand.getNumReductionLoops() == 0) {
-                       return false;
-                     }
+    if (genericOperand.getNumReductionLoops() == 0) {
+      return false;
+    }
 
-                     return genericOp.getMatchingIndexingMap(operand)
-                         .isProjectedPermutation();
-                   })) {
+    auto map = genericOp.getMatchingIndexingMap(operand);
+    return !map.isPermutation() && map.isProjectedPermutation();
+  };
+  if (llvm::any_of(genericOp.getDpsInputOperands(), isPossiblySoftmax)) {
     return false;
   }
 
@@ -615,6 +624,7 @@
   // 1. Get the slice of operations within `dispatchOp` that produce the yielded
   // value.
   BackwardSliceOptions sliceOptions;
+  sliceOptions.omitBlockArguments = true;
   sliceOptions.filter = [&](Operation *op) {
     return op->getParentOfType<IREE::Flow::DispatchRegionOp>();
   };
@@ -868,6 +878,7 @@
   BackwardSliceOptions sliceOptions;
   sliceOptions.inclusive = true;
   sliceOptions.omitBlockArguments = true;
+  sliceOptions.omitUsesFromAbove = false;
   sliceOptions.filter = [&](Operation *op) -> bool {
     auto parentOp = op->getParentOfType<IREE::Flow::DispatchRegionOp>();
     return isEligibleForCollapse(op) && parentOp == regionOp;
diff --git a/compiler/src/iree/compiler/DispatchCreation/test/collapse_dimensions.mlir b/compiler/src/iree/compiler/DispatchCreation/test/collapse_dimensions.mlir
index ae4146f..880f7f9 100644
--- a/compiler/src/iree/compiler/DispatchCreation/test/collapse_dimensions.mlir
+++ b/compiler/src/iree/compiler/DispatchCreation/test/collapse_dimensions.mlir
@@ -619,3 +619,40 @@
 //       CHECK:   %[[TRUNC:.*]] = linalg.generic
 //  CHECK-SAME:      ins(%[[ATTN]] : tensor<20x4096x64xf32>
 //       CHECK:   flow.return %[[TRUNC]] : tensor<20x4096x64xf16>
+
+// -----
+
+util.func public @collapse(%10: tensor<2x32x32x1280xi8>, %11 : tensor<10240x1280xi8>, %12 : tensor<10240xi32>, %13 : tensor<10240xf32>) -> (tensor<2x32x32x10240xf16>) {
+  %c0_i32 = arith.constant 0 : i32
+  %c0 = arith.constant 0 : index
+  %14 = tensor.empty() : tensor<2x32x32x10240xf16>
+  %15 = tensor.empty() : tensor<2x32x32x10240xi32>
+  %16 = linalg.fill ins(%c0_i32 : i32) outs(%15 : tensor<2x32x32x10240xi32>) -> tensor<2x32x32x10240xi32>
+  %dispatch = flow.dispatch.region -> (tensor<2x32x32x10240xf16>) {
+    %17 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]} ins(%10, %11 : tensor<2x32x32x1280xi8>, tensor<10240x1280xi8>) outs(%16 : tensor<2x32x32x10240xi32>) {
+    ^bb0(%in: i8, %in_0: i8, %out: i32):
+      %19 = arith.extsi %in : i8 to i32
+      %20 = arith.extsi %in_0 : i8 to i32
+      %21 = arith.muli %19, %20 : i32
+      %22 = arith.addi %out, %21 : i32
+      linalg.yield %22 : i32
+    } -> tensor<2x32x32x10240xi32>
+    %18 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d3)>, affine_map<(d0, d1, d2, d3) -> (d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%17, %12, %13 : tensor<2x32x32x10240xi32>, tensor<10240xi32>, tensor<10240xf32>) outs(%14 : tensor<2x32x32x10240xf16>) {
+    ^bb0(%in: i32, %in_0: i32, %in_1: f32, %out: f16):
+      %19 = arith.addi %in, %in_0 : i32
+      %20 = arith.sitofp %19 : i32 to f32
+      %21 = arith.mulf %20, %in_1 : f32
+      %22 = arith.truncf %21 : f32 to f16
+      linalg.yield %22 : f16
+    } -> tensor<2x32x32x10240xf16>
+    flow.return %18 : tensor<2x32x32x10240xf16>
+  }
+  util.return %dispatch  : tensor<2x32x32x10240xf16>
+}
+
+// CHECK-LABEL: util.func public @collapse
+//       CHECK:   %[[GEN0:.*]] = linalg.generic
+//  CHECK-SAME:      iterator_types = ["parallel", "parallel", "reduction"]
+//       CHECK:   %[[GEN1:.*]] = linalg.generic
+//  CHECK-SAME:      iterator_types = ["parallel", "parallel"]
+//       CHECK:   flow.return %[[GEN1]] : tensor<2048x10240xf16>