[Dispatch][GlobalOpt] Improve transpose fusion for conv (#21778)
Generalizes all Linalg convolution named ops and allows elementwise fusion
with conv ops (i.e. consumers whose indexing maps are not all projected
permutations). This is important when handling PyTorch input, e.g.
`transpose(NHWC->NCHW) -> conv -> transpose back`, so that the conv and its
surrounding transposes end up in a single dispatch.
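For reference, here is a hand-written sketch (plain MLIR, not verbatim pass
output) of roughly what the fused conv generic from the new
`fuse_transpose_with_conv` test looks like once both input-side transposes are
folded into the conv's indexing maps; the first map is no longer a projected
permutation, which is what the fusion check previously rejected:

```mlir
// Sketch only: shapes and maps assumed from the fuse_transpose_with_conv test;
// the trailing truncf/transpose generics are omitted.
func.func @fused_conv_sketch(%input: tensor<100x32x32x3xbf16>,
                             %filter: tensor<32x3x3x3xbf16>,
                             %init: tensor<100x32x30x30xf32>) -> tensor<100x32x30x30xf32> {
  %0 = linalg.generic {
      indexing_maps = [
        // NHWC activation read directly; the NHWC->NCHW transpose is folded in.
        affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d2 + d5, d3 + d6, d4)>,
        // The filter transpose is folded in the same way.
        affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d5, d6, d4)>,
        affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>],
      iterator_types = ["parallel", "parallel", "parallel", "parallel",
                        "reduction", "reduction", "reduction"]}
      ins(%input, %filter : tensor<100x32x32x3xbf16>, tensor<32x3x3x3xbf16>)
      outs(%init : tensor<100x32x30x30xf32>) {
  ^bb0(%in: bf16, %in_0: bf16, %out: f32):
    %1 = arith.extf %in : bf16 to f32
    %2 = arith.extf %in_0 : bf16 to f32
    %3 = arith.mulf %1, %2 : f32
    %4 = arith.addf %out, %3 : f32
    linalg.yield %4 : f32
  } -> tensor<100x32x30x30xf32>
  return %0 : tensor<100x32x30x30xf32>
}
```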
Also adds a workaround to `swapCollapseShapeWithSlice` that keeps the rewrite
pattern from failing to converge.
---------
Signed-off-by: Ian Wood <ianwood@u.northwestern.edu>
diff --git a/compiler/src/iree/compiler/Codegen/Common/Transforms.cpp b/compiler/src/iree/compiler/Codegen/Common/Transforms.cpp
index c0ba960..b3adfa9 100644
--- a/compiler/src/iree/compiler/Codegen/Common/Transforms.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/Transforms.cpp
@@ -7,6 +7,7 @@
#include "iree/compiler/Codegen/Common/Transforms.h"
#include "iree/compiler/Codegen/Common/CombineLayoutTransformation.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
+#include "llvm/ADT/ScopeExit.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
@@ -441,6 +442,17 @@
swapCollapseShapeWithSlice(RewriterBase &rewriter,
tensor::CollapseShapeOp collapseShapeOp,
tensor::ExtractSliceOp sliceOp) {
+ // FIXME: Workaround: the pattern can create `affine::AffineDelinearizeIndexOp`
+ // ops yet still return failure, which makes the rewrite driver loop forever.
+ SmallVector<Operation *> createdOps;
+ auto scope = llvm::make_scope_exit([&]() {
+ for (Operation *op : createdOps) {
+ if (op->use_empty()) {
+ rewriter.eraseOp(op);
+ }
+ }
+ });
+
// The tensor.extract_slice before applying the pattern works on the result
// of the tensor.collapse_shape, so variables (i.e. inputs for
// ExtractSliceOp) referring to the state before applying the pattern are
@@ -532,6 +544,7 @@
}
auto delinearizeOp = rewriter.create<affine::AffineDelinearizeIndexOp>(
sliceOp.getLoc(), cast<Value>(collapsedOffset), expandedBasis);
+ createdOps.push_back(delinearizeOp);
ValueRange offsets = delinearizeOp.getResults();
expandedOffsets.append(offsets.begin(), offsets.end());
diff --git a/compiler/src/iree/compiler/DispatchCreation/FusionUtils.cpp b/compiler/src/iree/compiler/DispatchCreation/FusionUtils.cpp
index aa18221..1a321e4 100644
--- a/compiler/src/iree/compiler/DispatchCreation/FusionUtils.cpp
+++ b/compiler/src/iree/compiler/DispatchCreation/FusionUtils.cpp
@@ -36,15 +36,6 @@
return true;
}
- // Don't fuse if all of the consumer maps aren't projected permutations.
- if (auto linalgConsumerOp = dyn_cast<linalg::LinalgOp>(consumerOp)) {
- if (!llvm::all_of(
- linalgConsumerOp.getIndexingMapsArray(),
- [](AffineMap map) { return map.isProjectedPermutation(); })) {
- return false;
- }
- }
-
// If the generic op is "just" copy, then fuse always.
Block &body = producerOp->getRegion(0).front();
if (std::begin(body)->hasTrait<OpTrait::IsTerminator>())
diff --git a/compiler/src/iree/compiler/DispatchCreation/test/elementwise_op_fusion.mlir b/compiler/src/iree/compiler/DispatchCreation/test/elementwise_op_fusion.mlir
index 5b23b96..4351630 100644
--- a/compiler/src/iree/compiler/DispatchCreation/test/elementwise_op_fusion.mlir
+++ b/compiler/src/iree/compiler/DispatchCreation/test/elementwise_op_fusion.mlir
@@ -536,3 +536,49 @@
// CHECK-SAME: ins(%[[ARG1]] : tensor<100xindex>
// CHECK: arith.index_cast %[[ARG2]]
// CHECK: return %[[GATHER]]
+
+// -----
+
+util.func public @fuse_transpose_with_conv(%arg0 : tensor<100x32x32x3xbf16>, %arg1 : tensor<32x3x3x3xbf16>) -> tensor<100x30x30x32xbf16> {
+ %cst = arith.constant 0.000000e+00 : f32
+ %2 = tensor.empty() : tensor<100x3x32x32xbf16>
+ %3 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<100x32x32x3xbf16>) outs(%2 : tensor<100x3x32x32xbf16>) {
+ ^bb0(%in: bf16, %out: bf16):
+ linalg.yield %in : bf16
+ } -> tensor<100x3x32x32xbf16>
+ %4 = tensor.empty() : tensor<32x3x3x3xbf16>
+ %5 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg1 : tensor<32x3x3x3xbf16>) outs(%4 : tensor<32x3x3x3xbf16>) {
+ ^bb0(%in: bf16, %out: bf16):
+ linalg.yield %in : bf16
+ } -> tensor<32x3x3x3xbf16>
+ %6 = tensor.empty() : tensor<100x32x30x30xf32>
+ %7 = linalg.fill ins(%cst : f32) outs(%6 : tensor<100x32x30x30xf32>) -> tensor<100x32x30x30xf32>
+ %8 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, d2 + d5, d3 + d6)>, affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d4, d5, d6)>, affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%3, %5 : tensor<100x3x32x32xbf16>, tensor<32x3x3x3xbf16>) outs(%7 : tensor<100x32x30x30xf32>) {
+ ^bb0(%in: bf16, %in_0: bf16, %out: f32):
+ %15 = arith.extf %in : bf16 to f32
+ %16 = arith.extf %in_0 : bf16 to f32
+ %17 = arith.mulf %15, %16 : f32
+ %18 = arith.addf %out, %17 : f32
+ linalg.yield %18 : f32
+ } -> tensor<100x32x30x30xf32>
+ %9 = tensor.empty() : tensor<100x32x30x30xbf16>
+ %10 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%8 : tensor<100x32x30x30xf32>) outs(%9 : tensor<100x32x30x30xbf16>) {
+ ^bb0(%in: f32, %out: bf16):
+ %15 = arith.truncf %in : f32 to bf16
+ linalg.yield %15 : bf16
+ } -> tensor<100x32x30x30xbf16>
+ %11 = tensor.empty() : tensor<100x30x30x32xbf16>
+ %12 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d3, d1)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%10 : tensor<100x32x30x30xbf16>) outs(%11 : tensor<100x30x30x32xbf16>) {
+ ^bb0(%in: bf16, %out: bf16):
+ linalg.yield %in : bf16
+ } -> tensor<100x30x30x32xbf16>
+ util.return %12 : tensor<100x30x30x32xbf16>
+}
+// CHECK-LABEL: util.func public @fuse_transpose_with_conv(
+// CHECK-SAME: %[[ARG0:[A-Za-z0-9]+]]: tensor
+// CHECK-SAME: %[[ARG1:[A-Za-z0-9]+]]: tensor
+// CHECK: %[[CONV:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]]
+// CHECK: %[[TRUNCF:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[CONV]]
+// CHECK: return %[[TRUNCF]]
diff --git a/compiler/src/iree/compiler/GlobalOptimization/GeneralizeLinalgNamedOps.cpp b/compiler/src/iree/compiler/GlobalOptimization/GeneralizeLinalgNamedOps.cpp
index f51fb1a..24b37e7 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/GeneralizeLinalgNamedOps.cpp
+++ b/compiler/src/iree/compiler/GlobalOptimization/GeneralizeLinalgNamedOps.cpp
@@ -35,50 +35,6 @@
};
} // namespace
-/// Returns true if `linalgOp` can be simplified to a basic GEMM.
-static bool isConvFoldableToContraction(linalg::LinalgOp linalgOp) {
- auto convDimsOrFailure = linalg::inferConvolutionDims(linalgOp);
- if (failed(convDimsOrFailure)) {
- return false;
- }
- auto &convDims = *convDimsOrFailure;
-
- if (!llvm::all_of(convDims.strides,
- [](int64_t element) { return element == 1; })) {
- LDBG("conv not foldable: non-unit strides");
- return false;
- }
-
- // Dont generalize pooling operations or depthwise convolutions. For pooling
- // ops, the input/output channel size will be categorized as the additional
- // batch dimension.
- if (convDims.outputChannel.empty() || convDims.inputChannel.empty()) {
- LDBG("conv not foldable: missing input or output channel dims");
- return false;
- }
-
- // Check if all filter dimensions are size 1.
- const int64_t kFilterInputIdx = 1;
- auto filterShapeType = llvm::dyn_cast<RankedTensorType>(
- linalgOp.getDpsInputOperand(kFilterInputIdx)->get().getType());
- if (!filterShapeType) {
- LDBG("conv not foldable: filter shape not ranked tensor");
- return false;
- }
- auto filterShape = filterShapeType.getShape();
- AffineMap filterMap = linalgOp.getIndexingMapsArray()[kFilterInputIdx];
- for (auto filterLoop : convDims.filterLoop) {
- std::optional<int64_t> maybeDim = filterMap.getResultPosition(
- getAffineDimExpr(filterLoop, filterMap.getContext()));
- if (!maybeDim || filterShape[*maybeDim] != 1) {
- LDBG("conv not foldable: non-unit filter dim");
- return false;
- }
- }
-
- return true;
-}
-
void GeneralizeLinalgNamedOpsPass::runOnOperation() {
auto funcOp = getOperation();
SmallVector<linalg::LinalgOp> namedOpCandidates;
@@ -98,7 +54,7 @@
linalg::MulOp, linalg::NegFOp, linalg::ReduceOp,
linalg::SubOp, linalg::TransposeOp>(
linalgOp.getOperation()) ||
- isConvFoldableToContraction(linalgOp)) {
+ linalg::isaConvolutionOpInterface(linalgOp)) {
namedOpCandidates.push_back(linalgOp);
}
});
diff --git a/compiler/src/iree/compiler/GlobalOptimization/test/generalize_named_ops.mlir b/compiler/src/iree/compiler/GlobalOptimization/test/generalize_named_ops.mlir
index 63d1ee0..5f5c799 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/test/generalize_named_ops.mlir
+++ b/compiler/src/iree/compiler/GlobalOptimization/test/generalize_named_ops.mlir
@@ -126,7 +126,7 @@
}
// CHECK-LABEL: @no_generalize_1x1_conv_2d_strides
-// CHECK-NOT: linalg.generic
+// CHECK: linalg.generic
// CHECK: util.return
// -----
@@ -141,5 +141,5 @@
}
// CHECK-LABEL: @no_generalize_1x1_depthwise_conv
-// CHECK-NOT: linalg.generic
+// CHECK: linalg.generic
// CHECK: util.return
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/test/attr_based_pipeline.mlir b/compiler/src/iree/compiler/Preprocessing/Common/test/attr_based_pipeline.mlir
index 406051c..27bb61f 100644
--- a/compiler/src/iree/compiler/Preprocessing/Common/test/attr_based_pipeline.mlir
+++ b/compiler/src/iree/compiler/Preprocessing/Common/test/attr_based_pipeline.mlir
@@ -12,10 +12,10 @@
// CHECK-LABEL: @single_dispatch_dropunitdims
// CHECK-SAME: %[[ARG0:[A-Za-z0-9]+]]: tensor<1x26x18x288xbf16>
// CHECK: %[[DISPATCH:.+]] = flow.dispatch.region
-// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[ARG0]]
-// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[EXPAND]]
+// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[ARG0]]
// CHECK: %[[CONV:.+]] = linalg.generic {{.*}} ins(%[[COLLAPSE]]
-// CHECK: flow.return %[[CONV]]
+// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[CONV]]
+// CHECK: flow.return %[[EXPAND]]
// CHECK: return %[[DISPATCH]]
// -----
diff --git a/compiler/src/iree/compiler/Preprocessing/Passes.cpp b/compiler/src/iree/compiler/Preprocessing/Passes.cpp
index 53e6799..cdd3f4b 100644
--- a/compiler/src/iree/compiler/Preprocessing/Passes.cpp
+++ b/compiler/src/iree/compiler/Preprocessing/Passes.cpp
@@ -147,6 +147,7 @@
// Generalize transposes and any other remaining named linalg ops that can
// now be represented as generics.
passManager.addPass(GlobalOptimization::createGeneralizeLinalgNamedOpsPass());
+ passManager.addPass(DispatchCreation::createFoldUnitExtentDimsForFuncPass());
passManager.addPass(
GlobalOptimization::createConvertStridedContractionToContractionPass());
passManager.addPass(DispatchCreation::createFusionPreprocessingPass());
diff --git a/tests/external/iree-test-suites/onnx_ops/onnx_ops_gpu_vulkan.json b/tests/external/iree-test-suites/onnx_ops/onnx_ops_gpu_vulkan.json
index db0f378..bb2377e 100644
--- a/tests/external/iree-test-suites/onnx_ops/onnx_ops_gpu_vulkan.json
+++ b/tests/external/iree-test-suites/onnx_ops/onnx_ops_gpu_vulkan.json
@@ -41,8 +41,6 @@
"onnx/node/generated/test_argmin_negative_axis_keepdims_random_select_last_index",
"onnx/node/generated/test_argmin_no_keepdims_example_select_last_index",
"onnx/node/generated/test_argmin_no_keepdims_random_select_last_index",
- "onnx/node/generated/test_averagepool_2d_same_lower",
- "onnx/node/generated/test_averagepool_2d_same_upper",
"onnx/node/generated/test_basic_deform_conv_with_padding",
"onnx/node/generated/test_basic_deform_conv_without_padding",
"onnx/node/generated/test_bernoulli_seed",