[Dispatch][GlobalOpt] Improve transpose fusion for conv (#21778)

Generalizes all named Linalg conv ops and allows elementwise fusion with
conv ops (i.e. consumers whose indexing maps are not projected
permutations). This matters for PyTorch input such as
`transpose(NHWC->NCHW) -> conv -> transpose back`, where the conv and
both transposes should end up in a single dispatch (see the new
`@fuse_transpose_with_conv` test).

Also adds a workaround to `swapCollapseShapeWithSlice`: the pattern could
create `affine::AffineDelinearizeIndexOp` ops and still return failure,
which kept pattern application from converging. The workaround erases any
such ops that end up unused.
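
For reference, a condensed sketch of the cleanup idiom used in
`Transforms.cpp` (the function name here is illustrative and the
surrounding pattern logic is elided): ops created during the rewrite are
recorded, and any left unused are erased on every exit path, so a failed
attempt no longer leaves new IR behind.

```cpp
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

// Sketch only: record ops created while attempting the rewrite and erase
// the ones that end up unused when the function returns, whether the
// rewrite succeeds or bails out with failure().
static LogicalResult rewriteWithCleanup(RewriterBase &rewriter) {
  SmallVector<Operation *> createdOps;
  auto scope = llvm::make_scope_exit([&]() {
    for (Operation *op : createdOps) {
      if (op->use_empty())
        rewriter.eraseOp(op);
    }
  });

  // ... create ops as part of the rewrite and record them, e.g.
  //   createdOps.push_back(delinearizeOp);
  // Returning failure() past this point no longer strands those ops.
  return failure();
}
```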

---------

Signed-off-by: Ian Wood <ianwood@u.northwestern.edu>
diff --git a/compiler/src/iree/compiler/Codegen/Common/Transforms.cpp b/compiler/src/iree/compiler/Codegen/Common/Transforms.cpp
index c0ba960..b3adfa9 100644
--- a/compiler/src/iree/compiler/Codegen/Common/Transforms.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/Transforms.cpp
@@ -7,6 +7,7 @@
 #include "iree/compiler/Codegen/Common/Transforms.h"
 #include "iree/compiler/Codegen/Common/CombineLayoutTransformation.h"
 #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
+#include "llvm/ADT/ScopeExit.h"
 #include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
@@ -441,6 +442,17 @@
 swapCollapseShapeWithSlice(RewriterBase &rewriter,
                            tensor::CollapseShapeOp collapseShapeOp,
                            tensor::ExtractSliceOp sliceOp) {
+  // FIXME: This is a workaround for the pattern failing to converge: it
+  // creates `affine::AffineDelinearizeIndexOp` ops but still returns failure.
+  SmallVector<Operation *> createdOps;
+  auto scope = llvm::make_scope_exit([&]() {
+    for (Operation *op : createdOps) {
+      if (op->use_empty()) {
+        rewriter.eraseOp(op);
+      }
+    }
+  });
+
   // The tensor.extract_slice before applying the pattern works on the result
   // of the tensor.collapse_shape, so variables (i.e. inputs for
   // ExtractSliceOp) referring to the state before applying the pattern are
@@ -532,6 +544,7 @@
         }
         auto delinearizeOp = rewriter.create<affine::AffineDelinearizeIndexOp>(
             sliceOp.getLoc(), cast<Value>(collapsedOffset), expandedBasis);
+        createdOps.push_back(delinearizeOp);
         ValueRange offsets = delinearizeOp.getResults();
         expandedOffsets.append(offsets.begin(), offsets.end());
 
diff --git a/compiler/src/iree/compiler/DispatchCreation/FusionUtils.cpp b/compiler/src/iree/compiler/DispatchCreation/FusionUtils.cpp
index aa18221..1a321e4 100644
--- a/compiler/src/iree/compiler/DispatchCreation/FusionUtils.cpp
+++ b/compiler/src/iree/compiler/DispatchCreation/FusionUtils.cpp
@@ -36,15 +36,6 @@
     return true;
   }
 
-  // Don't fuse if all of the consumer maps aren't projected permutations.
-  if (auto linalgConsumerOp = dyn_cast<linalg::LinalgOp>(consumerOp)) {
-    if (!llvm::all_of(
-            linalgConsumerOp.getIndexingMapsArray(),
-            [](AffineMap map) { return map.isProjectedPermutation(); })) {
-      return false;
-    }
-  }
-
   // If the generic op is "just" copy, then fuse always.
   Block &body = producerOp->getRegion(0).front();
   if (std::begin(body)->hasTrait<OpTrait::IsTerminator>())
diff --git a/compiler/src/iree/compiler/DispatchCreation/test/elementwise_op_fusion.mlir b/compiler/src/iree/compiler/DispatchCreation/test/elementwise_op_fusion.mlir
index 5b23b96..4351630 100644
--- a/compiler/src/iree/compiler/DispatchCreation/test/elementwise_op_fusion.mlir
+++ b/compiler/src/iree/compiler/DispatchCreation/test/elementwise_op_fusion.mlir
@@ -536,3 +536,49 @@
 //  CHECK-SAME:     ins(%[[ARG1]] : tensor<100xindex>
 //       CHECK:     arith.index_cast %[[ARG2]]
 //       CHECK:   return %[[GATHER]]
+
+// -----
+
+util.func public @fuse_transpose_with_conv(%arg0 : tensor<100x32x32x3xbf16>, %arg1 : tensor<32x3x3x3xbf16>) -> tensor<100x30x30x32xbf16> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %2 = tensor.empty() : tensor<100x3x32x32xbf16>
+  %3 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<100x32x32x3xbf16>) outs(%2 : tensor<100x3x32x32xbf16>) {
+  ^bb0(%in: bf16, %out: bf16):
+    linalg.yield %in : bf16
+  } -> tensor<100x3x32x32xbf16>
+  %4 = tensor.empty() : tensor<32x3x3x3xbf16>
+  %5 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg1 : tensor<32x3x3x3xbf16>) outs(%4 : tensor<32x3x3x3xbf16>) {
+  ^bb0(%in: bf16, %out: bf16):
+    linalg.yield %in : bf16
+  } -> tensor<32x3x3x3xbf16>
+  %6 = tensor.empty() : tensor<100x32x30x30xf32>
+  %7 = linalg.fill ins(%cst : f32) outs(%6 : tensor<100x32x30x30xf32>) -> tensor<100x32x30x30xf32>
+  %8 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, d2 + d5, d3 + d6)>, affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d4, d5, d6)>, affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%3, %5 : tensor<100x3x32x32xbf16>, tensor<32x3x3x3xbf16>) outs(%7 : tensor<100x32x30x30xf32>) {
+  ^bb0(%in: bf16, %in_0: bf16, %out: f32):
+    %15 = arith.extf %in : bf16 to f32
+    %16 = arith.extf %in_0 : bf16 to f32
+    %17 = arith.mulf %15, %16 : f32
+    %18 = arith.addf %out, %17 : f32
+    linalg.yield %18 : f32
+  } -> tensor<100x32x30x30xf32>
+  %9 = tensor.empty() : tensor<100x32x30x30xbf16>
+  %10 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%8 : tensor<100x32x30x30xf32>) outs(%9 : tensor<100x32x30x30xbf16>) {
+  ^bb0(%in: f32, %out: bf16):
+    %15 = arith.truncf %in : f32 to bf16
+    linalg.yield %15 : bf16
+  } -> tensor<100x32x30x30xbf16>
+  %11 = tensor.empty() : tensor<100x30x30x32xbf16>
+  %12 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d3, d1)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%10 : tensor<100x32x30x30xbf16>) outs(%11 : tensor<100x30x30x32xbf16>) {
+  ^bb0(%in: bf16, %out: bf16):
+    linalg.yield %in : bf16
+  } -> tensor<100x30x30x32xbf16>
+  util.return %12 : tensor<100x30x30x32xbf16>
+}
+// CHECK-LABEL: util.func public @fuse_transpose_with_conv(
+//  CHECK-SAME:   %[[ARG0:[A-Za-z0-9]+]]: tensor
+//  CHECK-SAME:   %[[ARG1:[A-Za-z0-9]+]]: tensor
+//       CHECK:   %[[CONV:.+]] = linalg.generic
+//  CHECK-SAME:     ins(%[[ARG0]], %[[ARG1]]
+//       CHECK:   %[[TRUNCF:.+]] = linalg.generic
+//  CHECK-SAME:     ins(%[[CONV]]
+//       CHECK:   return %[[TRUNCF]]
diff --git a/compiler/src/iree/compiler/GlobalOptimization/GeneralizeLinalgNamedOps.cpp b/compiler/src/iree/compiler/GlobalOptimization/GeneralizeLinalgNamedOps.cpp
index f51fb1a..24b37e7 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/GeneralizeLinalgNamedOps.cpp
+++ b/compiler/src/iree/compiler/GlobalOptimization/GeneralizeLinalgNamedOps.cpp
@@ -35,50 +35,6 @@
 };
 } // namespace
 
-/// Returns true if `linalgOp` can be simplified to a basic GEMM.
-static bool isConvFoldableToContraction(linalg::LinalgOp linalgOp) {
-  auto convDimsOrFailure = linalg::inferConvolutionDims(linalgOp);
-  if (failed(convDimsOrFailure)) {
-    return false;
-  }
-  auto &convDims = *convDimsOrFailure;
-
-  if (!llvm::all_of(convDims.strides,
-                    [](int64_t element) { return element == 1; })) {
-    LDBG("conv not foldable: non-unit strides");
-    return false;
-  }
-
-  // Dont generalize pooling operations or depthwise convolutions. For pooling
-  // ops, the input/output channel size will be categorized as the additional
-  // batch dimension.
-  if (convDims.outputChannel.empty() || convDims.inputChannel.empty()) {
-    LDBG("conv not foldable: missing input or output channel dims");
-    return false;
-  }
-
-  // Check if all filter dimensions are size 1.
-  const int64_t kFilterInputIdx = 1;
-  auto filterShapeType = llvm::dyn_cast<RankedTensorType>(
-      linalgOp.getDpsInputOperand(kFilterInputIdx)->get().getType());
-  if (!filterShapeType) {
-    LDBG("conv not foldable: filter shape not ranked tensor");
-    return false;
-  }
-  auto filterShape = filterShapeType.getShape();
-  AffineMap filterMap = linalgOp.getIndexingMapsArray()[kFilterInputIdx];
-  for (auto filterLoop : convDims.filterLoop) {
-    std::optional<int64_t> maybeDim = filterMap.getResultPosition(
-        getAffineDimExpr(filterLoop, filterMap.getContext()));
-    if (!maybeDim || filterShape[*maybeDim] != 1) {
-      LDBG("conv not foldable: non-unit filter dim");
-      return false;
-    }
-  }
-
-  return true;
-}
-
 void GeneralizeLinalgNamedOpsPass::runOnOperation() {
   auto funcOp = getOperation();
   SmallVector<linalg::LinalgOp> namedOpCandidates;
@@ -98,7 +54,7 @@
                         linalg::MulOp, linalg::NegFOp, linalg::ReduceOp,
                         linalg::SubOp, linalg::TransposeOp>(
             linalgOp.getOperation()) ||
-        isConvFoldableToContraction(linalgOp)) {
+        linalg::isaConvolutionOpInterface(linalgOp)) {
       namedOpCandidates.push_back(linalgOp);
     }
   });
diff --git a/compiler/src/iree/compiler/GlobalOptimization/test/generalize_named_ops.mlir b/compiler/src/iree/compiler/GlobalOptimization/test/generalize_named_ops.mlir
index 63d1ee0..5f5c799 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/test/generalize_named_ops.mlir
+++ b/compiler/src/iree/compiler/GlobalOptimization/test/generalize_named_ops.mlir
@@ -126,7 +126,7 @@
 }
 
 // CHECK-LABEL: @no_generalize_1x1_conv_2d_strides
-//   CHECK-NOT:   linalg.generic
+//       CHECK:   linalg.generic
 //       CHECK:   util.return
 
 // -----
@@ -141,5 +141,5 @@
 }
 
 // CHECK-LABEL: @no_generalize_1x1_depthwise_conv
-//   CHECK-NOT:   linalg.generic
+//       CHECK:   linalg.generic
 //       CHECK:   util.return
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/test/attr_based_pipeline.mlir b/compiler/src/iree/compiler/Preprocessing/Common/test/attr_based_pipeline.mlir
index 406051c..27bb61f 100644
--- a/compiler/src/iree/compiler/Preprocessing/Common/test/attr_based_pipeline.mlir
+++ b/compiler/src/iree/compiler/Preprocessing/Common/test/attr_based_pipeline.mlir
@@ -12,10 +12,10 @@
 // CHECK-LABEL: @single_dispatch_dropunitdims
 //  CHECK-SAME: %[[ARG0:[A-Za-z0-9]+]]: tensor<1x26x18x288xbf16>
 //       CHECK:   %[[DISPATCH:.+]] = flow.dispatch.region
-//       CHECK:     %[[EXPAND:.+]] = tensor.expand_shape %[[ARG0]]
-//       CHECK:     %[[COLLAPSE:.+]] = tensor.collapse_shape %[[EXPAND]]
+//       CHECK:     %[[COLLAPSE:.+]] = tensor.collapse_shape %[[ARG0]]
 //       CHECK:     %[[CONV:.+]] = linalg.generic {{.*}} ins(%[[COLLAPSE]]
-//       CHECK:     flow.return %[[CONV]]
+//       CHECK:     %[[EXPAND:.+]] = tensor.expand_shape %[[CONV]]
+//       CHECK:     flow.return %[[EXPAND]]
 //       CHECK:   return %[[DISPATCH]]
 
 // -----
diff --git a/compiler/src/iree/compiler/Preprocessing/Passes.cpp b/compiler/src/iree/compiler/Preprocessing/Passes.cpp
index 53e6799..cdd3f4b 100644
--- a/compiler/src/iree/compiler/Preprocessing/Passes.cpp
+++ b/compiler/src/iree/compiler/Preprocessing/Passes.cpp
@@ -147,6 +147,7 @@
   // Generalize transposes and any other remaining named linalg ops that can
   // now be represented as generics.
   passManager.addPass(GlobalOptimization::createGeneralizeLinalgNamedOpsPass());
+  passManager.addPass(DispatchCreation::createFoldUnitExtentDimsForFuncPass());
   passManager.addPass(
       GlobalOptimization::createConvertStridedContractionToContractionPass());
   passManager.addPass(DispatchCreation::createFusionPreprocessingPass());
diff --git a/tests/external/iree-test-suites/onnx_ops/onnx_ops_gpu_vulkan.json b/tests/external/iree-test-suites/onnx_ops/onnx_ops_gpu_vulkan.json
index db0f378..bb2377e 100644
--- a/tests/external/iree-test-suites/onnx_ops/onnx_ops_gpu_vulkan.json
+++ b/tests/external/iree-test-suites/onnx_ops/onnx_ops_gpu_vulkan.json
@@ -41,8 +41,6 @@
     "onnx/node/generated/test_argmin_negative_axis_keepdims_random_select_last_index",
     "onnx/node/generated/test_argmin_no_keepdims_example_select_last_index",
     "onnx/node/generated/test_argmin_no_keepdims_random_select_last_index",
-    "onnx/node/generated/test_averagepool_2d_same_lower",
-    "onnx/node/generated/test_averagepool_2d_same_upper",
     "onnx/node/generated/test_basic_deform_conv_with_padding",
     "onnx/node/generated/test_basic_deform_conv_without_padding",
     "onnx/node/generated/test_bernoulli_seed",