Collapse `linalg.generic` (#11295)
This PR implements collapsing of `linalg.generic` ops. It first identifies the
collapsible parallel dimensions and then collapses them all. The collapsing is
done on the tensor shapes rather than on the loops, so it does not introduce
any arithmetic to recompute loop indices. When `reduction` iterators are
present it can still collapse the `parallel` loops, but it never mixes the two
iterator types.
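For instance (a minimal before/after sketch mirroring the `collapse1` test added in this PR; op bodies elided), a fully parallel elementwise op over a 6-D tensor becomes a 1-D op on the collapsed shapes:
```
// Before: 6 parallel loops over tensor<2x4x8x16x32x64xf32>.
%0 = linalg.generic {
    indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>,
                     affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>],
    iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]}
    ins(%input : tensor<2x4x8x16x32x64xf32>) outs(%output : tensor<2x4x8x16x32x64xf32>) { ... } -> tensor<2x4x8x16x32x64xf32>

// After: only the shapes are reassociated, so no loop-index arithmetic is introduced.
%in  = tensor.collapse_shape %input [[0, 1, 2, 3, 4, 5]] : tensor<2x4x8x16x32x64xf32> into tensor<2097152xf32>
%out = tensor.collapse_shape %output [[0, 1, 2, 3, 4, 5]] : tensor<2x4x8x16x32x64xf32> into tensor<2097152xf32>
%1 = linalg.generic {
    indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
    iterator_types = ["parallel"]}
    ins(%in : tensor<2097152xf32>) outs(%out : tensor<2097152xf32>) { ... } -> tensor<2097152xf32>
%2 = tensor.expand_shape %1 [[0, 1, 2, 3, 4, 5]] : tensor<2097152xf32> into tensor<2x4x8x16x32x64xf32>
```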
The pass finds the longest dimension sequence that appears identically in
every `AffineMap`; there can be more than one such sequence. For the following
case, for example, it is `d1, d3, d0`. The `d1, d3, d0` loops are not directly
nested; there are other loops in between. But they are all parallel loops, so
it is safe to interchange them, and after interchanging it is also safe to
collapse them.
```
indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d0)>,
                 affine_map<(d0, d1, d2, d3, d4) -> (d2, d1, d3, d0, d4)>]
```
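After interchanging and collapsing, `d1, d3, d0` become a single dimension and the maps above reduce to the following (this is what the new `collapse10` test checks for; the iterator types become three `parallel` loops):
```
indexing_maps = [affine_map<(d0, d1, d2) -> (d0)>,
                 affine_map<(d0, d1, d2) -> (d1, d0, d2)>]
```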
After collapsing, IREE can parallelize more dimensions of the
`linalg.generic`, which yields a significant performance improvement.
Current limitations:
1. Dynamic tensor shapes: generating `tensor.expand_shape` with dynamic
tensors is ambiguous. This is a known limitation of MLIR, and it is possible
to solve it. See the RFC:
https://discourse.llvm.org/t/rfc-add-explicit-shape-inputs-to-tensor-expand-shape/65952
2. Non-contiguous loops: the current mechanism collapses tensor shapes, not
loops. If the loops are non-contiguous (as in a transpose), collapsing the
tensor would change the op's behavior (see the example below). #11385 tackles
this problem by linearizing the workgroup id; an alternative idea is to
linearize the loops.
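For example, a transposing `linalg.generic` with the maps below is left uncollapsed, because `d0, d1` and `d1, d0` are not the same sequence in both maps (this is presumably also why the existing `dispatch_linalg_on_tensors.mlir` test is updated to use a transposed output map in this PR):
```
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                 affine_map<(d0, d1) -> (d1, d0)>]
```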
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.cpp
index be65059..22bdfc2 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.cpp
@@ -22,6 +22,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -691,12 +692,147 @@
return result;
}
+/// Searches for the longest dimension sequence that appears identically in
+/// all the affine maps and collapses those dimensions. Only "parallel" loops
+/// are collapsed; they are never mixed with "reduction" loops.
+static SmallVector<ReassociationIndices> getCollapsibleLoops(
+ linalg::GenericOp genericOp) {
+ SmallVector<ReassociationIndices> contiguousLoops;
+
+ SmallVector<unsigned> pDims;
+ genericOp.getParallelDims(pDims);
+ if (pDims.size() < 2) return contiguousLoops;
+
+ llvm::SmallDenseSet<unsigned> pLoops(pDims.begin(), pDims.end());
+
+ auto hasAllMapsSameSequence = [&](AffineExpr preExpr, AffineExpr nextExpr) {
+ for (AffineMap map : genericOp.getIndexingMapsArray()) {
+ bool foundSeq = false;
+ for (auto [index, resultExpr] : llvm::enumerate(map.getResults())) {
+ if (resultExpr == nextExpr) {
+ foundSeq = (index > 0 && preExpr == map.getResult(index - 1));
+ break;
+ }
+ }
+ if (!foundSeq) return false;
+ }
+ return true;
+ };
+
+ ReassociationIndices range;
+ AffineExpr preExpr;
+ for (auto nextExpr : genericOp.getIndexingMapsArray().front().getResults()) {
+ unsigned pos = nextExpr.cast<AffineDimExpr>().getPosition();
+ if (!range.empty()) {
+ if (!hasAllMapsSameSequence(preExpr, nextExpr) || !pLoops.count(pos)) {
+ if (range.size() > 1)
+ contiguousLoops.push_back({range.begin(), range.end()});
+ range.clear();
+ }
+ }
+ preExpr = nextExpr;
+ if (pLoops.count(pos)) range.push_back(pos);
+ }
+ if (range.size() > 1) contiguousLoops.push_back(range);
+
+ LLVM_DEBUG({
+ llvm::dbgs() << "Collapsing dimensions if possible: ";
+ for (auto indices : contiguousLoops) {
+ llvm::dbgs() << "[";
+ for (auto idx : indices) llvm::dbgs() << idx << ",";
+ llvm::dbgs() << "]\t";
+ }
+ llvm::dbgs() << "\n";
+ });
+
+ return contiguousLoops;
+}
+
+/// Collapses the collapsible dimensions of the given linalg.generic and
+/// returns the new op.
+static FailureOr<linalg::GenericOp> collapseLinalgGeneric(
+ TensorDimTrackingRewriter &rewriter, linalg::GenericOp genericOp) {
+ SmallVector<ReassociationIndices> collapseIndices =
+ getCollapsibleLoops(genericOp);
+
+ if (collapseIndices.empty()) return genericOp;
+
+ rewriter.setInsertionPoint(genericOp);
+ FailureOr<SmallVector<Value>> replacements =
+ mlir::linalg::collapseGenericOpIterationDims(genericOp, collapseIndices,
+ rewriter);
+ if (failed(replacements) || replacements->empty()) {
+ return rewriter.notifyMatchFailure(genericOp,
+ "failed to collapse dimensions");
+ }
+
+ // Find and return collapsed linalg.generic
+ auto expandshapeOp =
+ replacements->front().getDefiningOp<tensor::ExpandShapeOp>();
+ if (!expandshapeOp) return failure();
+
+ auto newGenericOp =
+ expandshapeOp.getOperand().getDefiningOp<linalg::GenericOp>();
+ if (!newGenericOp) return failure();
+
+ rewriter.replaceOp(genericOp, *replacements);
+ return newGenericOp;
+}
+
+/// Returns true if the given op is collapsible.
+static bool isEligibleForCollapse(Operation *op,
+ ArrayRef<Operation *> producers) {
+ if (!producers.empty()) return false;
+
+ auto genericOp = dyn_cast<linalg::GenericOp>(op);
+ if (!genericOp) return false;
+
+  // TODO(guray) There is no mechanism to pass the collapsed indices to
+  // `tensor.expand_shape`. Once we have this support in MLIR, we can enable
+  // dynamic tensor shapes.
+ if (genericOp.hasDynamicShape()) return false;
+
+  // TODO(guray) Currently we can only collapse when the results of all the
+  // AffineMaps are dimensions. It is possible to collapse cases like
+  // affine_map<d0, d1+d2> with affine_map<d0, d1+d2>; however, this is not
+  // supported by the collapsing mechanism in MLIR. Once we have that support,
+  // we can remove this check.
+ if (llvm::any_of(genericOp.getIndexingMapsArray(), [](AffineMap map) {
+ return !map.isProjectedPermutation();
+ })) {
+ return false;
+ }
+
+  // IndexOp allows accessing induction variables. Collapsing might cause a
+  // performance regression, so we disable it here.
+ if (genericOp.hasIndexSemantics()) return false;
+
+ return true;
+}
+
+/// Traverses all the ops in `roots` and collapses the ops that are
+/// eligible.
+static LogicalResult collapseDimensions(
+ TensorDimTrackingRewriter &rewriter, SmallVectorImpl<Operation *> &roots,
+ DenseMap<unsigned, SmallVector<Operation *>> &producers) {
+ for (auto [index, op] : llvm::enumerate(roots)) {
+ if (!isEligibleForCollapse(op, producers[index])) continue;
+
+ if (auto genericOp = dyn_cast<linalg::GenericOp>(op)) {
+ auto maybeLinalgGeneric = collapseLinalgGeneric(rewriter, genericOp);
+ if (failed(maybeLinalgGeneric)) return failure();
+ roots[index] = *maybeLinalgGeneric;
+ }
+ }
+ return success();
+}
+
/// Create Flow::DispatchGroupsOps based on a fusion heuristic.
static LogicalResult createFusionGroups(TensorDimTrackingRewriter &rewriter,
FunctionOpInterface funcOp,
DominanceInfo const &dominanceInfo,
bool generateWorkloadRegion,
- bool aggressiveFusion) {
+ bool aggressiveFusion, bool collapse) {
// Step 1: Decide fusion groups (heuristic). This marks rootOps with an
// attribute
unsigned numRoots =
@@ -724,6 +860,17 @@
}
});
+ // TODO(guray): This can be extracted to a pass.
+ if (collapse) {
+ if (failed(collapseDimensions(rewriter, roots, producers)))
+ return failure();
+ LLVM_DEBUG({
+      llvm::dbgs() << "\n--- After collapsing dimensions ---\n";
+ funcOp->print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
+ llvm::dbgs() << "\n\n";
+ });
+ }
+
// Step 2. Create a DispatchRegionOp for every fusion group.
OpBuilder::InsertionGuard g(rewriter);
SmallVector<Flow::DispatchRegionOp> regionOps;
@@ -786,13 +933,15 @@
.insert<AffineDialect, IREE::Flow::FlowDialect, linalg::LinalgDialect,
scf::SCFDialect, tensor::TensorDialect>();
}
- FormDispatchRegionsPass(bool aggressiveFusion, bool generateWorkloadRegion) {
+ FormDispatchRegionsPass(bool aggressiveFusion, bool generateWorkloadRegion,
+ bool collapse) {
this->aggressiveFusion = aggressiveFusion;
this->generateWorkloadRegion = generateWorkloadRegion;
+ this->collapse = collapse;
}
FormDispatchRegionsPass(const FormDispatchRegionsPass &pass)
: FormDispatchRegionsPass(pass.aggressiveFusion,
- pass.generateWorkloadRegion) {}
+ pass.generateWorkloadRegion, pass.collapse) {}
void runOnOperation() override;
};
} // namespace
@@ -803,15 +952,16 @@
DominanceInfo const &dominanceInfo = getAnalysis<DominanceInfo>();
TensorDimTrackingRewriter rewriter(funcOp);
if (failed(createFusionGroups(rewriter, funcOp, dominanceInfo,
- generateWorkloadRegion, aggressiveFusion)))
+ generateWorkloadRegion, aggressiveFusion,
+ collapse)))
return signalPassFailure();
}
std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
createFormDispatchRegionsPass(bool aggressiveFusion,
- bool generateWorkloadRegion) {
- return std::make_unique<FormDispatchRegionsPass>(aggressiveFusion,
- generateWorkloadRegion);
+ bool generateWorkloadRegion, bool collapse) {
+ return std::make_unique<FormDispatchRegionsPass>(
+ aggressiveFusion, generateWorkloadRegion, collapse);
}
} // namespace Flow
} // namespace IREE
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
index 3199882..f0e9ac7 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
@@ -101,6 +101,10 @@
"iree-flow-dispatch-generate-workload-region",
llvm::cl::desc("Generate the workload region"), llvm::cl::init(true));
+static llvm::cl::opt<bool> clCollapseDimensions(
+ "iree-flow-form-dispatch-regions-collapse",
+ llvm::cl::desc("Collapse dimensions"), llvm::cl::init(true));
+
static llvm::cl::opt<bool> clEnableDataTiling(
"iree-flow-enable-data-tiling", llvm::cl::desc("Enable data tiling path"),
llvm::cl::init(false));
@@ -292,7 +296,8 @@
// the FormDispatchRegions handle the rest.
.addPass([&]() {
return createFormDispatchRegionsPass(clEnableAggressiveFusion,
- clDispatchGenerateWorkloadRegion);
+ clDispatchGenerateWorkloadRegion,
+ clCollapseDimensions);
})
// Form dispatch region into dispatch workgroups
.addPass([&]() {
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.h b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.h
index f8b2ab9..71cdde3 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.h
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.h
@@ -147,7 +147,8 @@
// is created for each tiled loop nest.
std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
createFormDispatchRegionsPass(bool aggressiveFusion = false,
- bool generateWorkloadRegion = true);
+ bool generateWorkloadRegion = true,
+ bool collapse = true);
//===----------------------------------------------------------------------===//
// Dispatches (flow.dispatch.workgroups)
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td
index 4517a69..a18508c 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td
@@ -77,6 +77,8 @@
/*default=*/"false", "Fuse with aggressive heuristics">,
Option<"generateWorkloadRegion", "genereate-workload-region", "bool",
/*default=*/"true", "Generate workload regions of WorkgroupOps">,
+ Option<"collapse", "collapse", "bool",
+           /*default=*/"true", "Collapse dimensions of eligible linalg.generic ops">,
];
}
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/BUILD b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/BUILD
index de4ce24..cc95697 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/BUILD
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/BUILD
@@ -26,6 +26,7 @@
"deduplicate_executables.mlir",
"detach_elementwise_from_named_ops.mlir",
"dispatch_linalg_on_tensors.mlir",
+ "collapse_linalg_generic_on_tensors.mlir",
"dispatch_linalg_on_tensors_default.mlir",
"dispatch_linalg_on_tensors_fusion_with_transpose.mlir",
"dispatch_linalg_transform_dialect.mlir",
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt
index 02c63fd..6c8be53 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt
@@ -17,6 +17,7 @@
"capture_dispatch_dynamic_dims.mlir"
"cleanup_numeric_narrowing.mlir"
"cleanup_tensor_shapes.mlir"
+ "collapse_linalg_generic_on_tensors.mlir"
"collapse_reduction.mlir"
"conv1x1_to_matmul.mlir"
"conv2d_to_img2col.mlir"
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/collapse_linalg_generic_on_tensors.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/collapse_linalg_generic_on_tensors.mlir
new file mode 100644
index 0000000..9f69b60
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/collapse_linalg_generic_on_tensors.mlir
@@ -0,0 +1,463 @@
+// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(func.func(iree-flow-form-dispatch-regions{aggressive-fusion=true}))" %s | FileCheck %s
+!type = tensor<2x4x8x16x32x64xf32>
+util.global private @"__transpose_10_input" {noinline} = dense<1.0> : !type
+
+func.func @collapse1() -> !type {
+ %cst = arith.constant 0.000000e+00 : f32
+ %c0 = arith.constant 0 : index
+ %input_ptr = util.global.address @"__transpose_10_input" : !util.ptr<!type>
+ %input = util.global.load.indirect %input_ptr : !util.ptr<!type> -> !type
+ %output = tensor.empty() : !type
+
+  // Can collapse all of (d0, d1, d2, d3, d4, d5)
+ %6 = linalg.generic { indexing_maps = [
+ affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>,
+ affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>],
+ iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]
+ }
+ ins(%input : !type) outs(%output : !type) {
+ ^bb0(%arg1: f32, %arg2: f32):
+ linalg.yield %arg1 : f32
+ } -> !type
+ return %6: !type
+
+}
+
+// CHECK: #[[$MAP:.+]] = affine_map<(d0) -> (d0)>
+// CHECK-LABEL: func.func @collapse1
+// CHECK: %[[IN:.+]] = tensor.collapse_shape %[[INPUT:.+]] {{\[}}[0, 1, 2, 3, 4, 5]] : tensor<2x4x8x16x32x64xf32> into tensor<2097152xf32>
+// CHECK: %[[OUT:.+]] = tensor.collapse_shape %[[OUTPUT:.+]] {{\[}}[0, 1, 2, 3, 4, 5]] : tensor<2x4x8x16x32x64xf32> into tensor<2097152xf32>
+// CHECK: %[[RES:.+]] = flow.dispatch.region
+// CHECK: linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP]]], iterator_types = ["parallel"]}
+// CHECK: ins(%[[IN]] : tensor<2097152xf32>) outs(%[[OUT]] : tensor<2097152xf32>)
+// CHECK: tensor.expand_shape %[[RES]] {{\[}}[0, 1, 2, 3, 4, 5]] : tensor<2097152xf32> into tensor<2x4x8x16x32x64xf32>
+
+// -----
+
+!type = tensor<2x4x8x32x32x64x128xf32>
+util.global private @"__transpose_10_input" {noinline} = dense<1.0> : !type
+
+func.func @collapse2() -> !type {
+ %cst = arith.constant 0.000000e+00 : f32
+ %c0 = arith.constant 0 : index
+ %input_ptr = util.global.address @"__transpose_10_input" : !util.ptr<!type>
+ %input = util.global.load.indirect %input_ptr : !util.ptr<!type> -> !type
+ %output = tensor.empty() : !type
+
+ // Can collapse (d0, d1) and (d5, d6)
+ %6 = linalg.generic { indexing_maps = [
+ affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d4, d3, d5, d6)>,
+ affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4, d5, d6)>],
+ iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "parallel", "parallel"]
+ }
+ ins(%input : !type) outs(%output : !type) {
+ ^bb0(%arg1: f32, %arg2: f32):
+ linalg.yield %arg1 : f32
+ } -> !type
+ return %6: !type
+
+}
+
+// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d2, d4)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
+// CHECK-LABEL: func.func @collapse2
+// CHECK: %[[IN:.+]] = tensor.collapse_shape %[[INPUT:.+]] {{\[}}[0, 1], [2], [3], [4], [5, 6]] : tensor<2x4x8x32x32x64x128xf32> into tensor<8x8x32x32x8192xf32>
+// CHECK: %[[OUT:.+]] = tensor.collapse_shape %[[OUTPUT:.+]] {{\[}}[0, 1], [2], [3], [4], [5, 6]] : tensor<2x4x8x32x32x64x128xf32> into tensor<8x8x32x32x8192xf32>
+// CHECK: %[[RES:.+]] = flow.dispatch.region
+// CHECK: linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP1]]], iterator_types = ["parallel", "reduction", "parallel", "parallel", "parallel"]}
+// CHECK: ins(%[[IN]] : tensor<8x8x32x32x8192xf32>) outs(%[[OUT]] : tensor<8x8x32x32x8192xf32>)
+// CHECK: tensor.expand_shape %[[RES]] {{\[}}[0, 1], [2], [3], [4], [5, 6]] : tensor<8x8x32x32x8192xf32> into tensor<2x4x8x32x32x64x128xf32>
+
+// -----
+!type = tensor<2x4x8x16x32x64x128x256xf32>
+util.global private @"__transpose_10_input" {noinline} = dense<1.0> : !type
+
+func.func @collapse3() -> !type {
+ %cst = arith.constant 0.000000e+00 : f32
+ %c0 = arith.constant 0 : index
+ %input_ptr = util.global.address @"__transpose_10_input" : !util.ptr<!type>
+ %input = util.global.load.indirect %input_ptr : !util.ptr<!type> -> !type
+ %output = tensor.empty() : !type
+
+ // Can collapse (d0, d1) and (d3, d4, d5, d6, d7)
+ %result = linalg.generic { indexing_maps = [
+ affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d5, d6, d7)>,
+ affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d5, d6, d7)>],
+ iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "parallel", "parallel", "parallel"]
+ }
+ ins(%input : !type) outs(%output : !type) {
+ ^bb0(%arg1: f32, %arg2: f32):
+ linalg.yield %arg1 : f32
+ } -> !type
+ return %result: !type
+
+}
+
+// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-LABEL: func.func @collapse3
+// CHECK: %[[IN:.+]] = tensor.collapse_shape %[[INPUT:.+]] {{\[}}[0, 1], [2], [3, 4, 5, 6, 7]] : tensor<2x4x8x16x32x64x128x256xf32> into tensor<8x8x1073741824xf32>
+// CHECK: %[[OUT:.+]] = tensor.collapse_shape %[[OUTPUT:.+]] {{\[}}[0, 1], [2], [3, 4, 5, 6, 7]] : tensor<2x4x8x16x32x64x128x256xf32> into tensor<8x8x1073741824xf32>
+// CHECK: %[[RES:.+]] = flow.dispatch.region
+// CHECK: linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP]]], iterator_types = ["parallel", "reduction", "parallel"]}
+// CHECK: ins(%[[IN]] : tensor<8x8x1073741824xf32>) outs(%[[OUT]] : tensor<8x8x1073741824xf32>)
+// CHECK: tensor.expand_shape %[[RES]] {{\[}}[0, 1], [2], [3, 4, 5, 6, 7]] : tensor<8x8x1073741824xf32> into tensor<2x4x8x16x32x64x128x256xf32>
+
+// -----
+
+!type = tensor<2x4x8x16x64x64x128x256xf32>
+util.global private @"__transpose_10_input" {noinline} = dense<1.0> : !type
+func.func @collapse4() -> !type {
+ %cst = arith.constant 0.000000e+00 : f32
+ %c0 = arith.constant 0 : index
+ %input_ptr = util.global.address @"__transpose_10_input" : !util.ptr<!type>
+ %input = util.global.load.indirect %input_ptr : !util.ptr<!type> -> !type
+ %output = tensor.empty() : !type
+
+ // Can collapse (d0, d1) and (d6, d7)
+ %result = linalg.generic { indexing_maps = [
+ affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d5, d6, d7)>,
+ affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d5, d4, d6, d7)>],
+ iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "parallel", "parallel", "parallel"]
+ }
+ ins(%input : !type) outs(%output : !type) {
+ ^bb0(%arg1: f32, %arg2: f32):
+ linalg.yield %arg1 : f32
+ } -> !type
+ return %result: !type
+
+}
+
+// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
+// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d4, d3, d5)>
+// CHECK-LABEL: func.func @collapse4
+// CHECK: %[[IN:.+]] = tensor.collapse_shape %[[INPUT:.+]] {{\[}}[0, 1], [2], [3], [4], [5], [6, 7]] : tensor<2x4x8x16x64x64x128x256xf32> into tensor<8x8x16x64x64x32768xf32>
+// CHECK: %[[OUT:.+]] = tensor.collapse_shape %[[OUTPUT:.+]] {{\[}}[0, 1], [2], [3], [4], [5], [6, 7]] : tensor<2x4x8x16x64x64x128x256xf32> into tensor<8x8x16x64x64x32768xf32>
+// CHECK: %[[RES:.+]] = flow.dispatch.region
+// CHECK: linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP2]]], iterator_types = ["parallel", "reduction", "parallel", "parallel", "parallel", "parallel"]}
+// CHECK: ins(%[[IN]] : tensor<8x8x16x64x64x32768xf32>) outs(%[[OUT]] : tensor<8x8x16x64x64x32768xf32>)
+// CHECK: tensor.expand_shape %[[RES]] {{\[}}[0, 1], [2], [3], [4], [5], [6, 7]] : tensor<8x8x16x64x64x32768xf32> into tensor<2x4x8x16x64x64x128x256xf32>
+
+// -----
+
+!type = tensor<2x4x32x32x32x64x128x256xf32>
+util.global private @"__transpose_10_input" {noinline} = dense<1.0> : !type
+func.func @collapse5() -> !type {
+ %cst = arith.constant 0.000000e+00 : f32
+ %c0 = arith.constant 0 : index
+ %input_ptr = util.global.address @"__transpose_10_input" : !util.ptr<!type>
+ %input = util.global.load.indirect %input_ptr : !util.ptr<!type> -> !type
+ %input2 = util.global.load.indirect %input_ptr : !util.ptr<!type> -> !type
+ %input3 = util.global.load.indirect %input_ptr : !util.ptr<!type> -> !type
+ %output = tensor.empty() : !type
+
+ // Can collapse (d0, d1) and (d6, d7)
+ %result = linalg.generic { indexing_maps = [
+ affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d5, d6, d7)>,
+ affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d4, d3, d5, d6, d7)>,
+ affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d4, d3, d2, d5, d6, d7)>,
+ affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d5, d6, d7)>],
+ iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "parallel", "parallel"]
+ }
+ ins(%input, %input2, %input3 : !type, !type, !type)
+ outs(%output : !type) {
+ ^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32):
+ linalg.yield %arg1 : f32
+ } -> !type
+ return %result: !type
+
+}
+
+// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d2, d4, d5)>
+// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d3, d2, d1, d4, d5)>
+// CHECK-LABEL: func.func @collapse5
+// CHECK: %[[IN:.+]] = tensor.collapse_shape %[[INPUT:.+]] {{\[}}[0, 1], [2], [3], [4], [5], [6, 7]] : tensor<2x4x32x32x32x64x128x256xf32> into tensor<8x32x32x32x64x32768xf32>
+// CHECK: %[[IN1:.+]] = tensor.collapse_shape %[[INPUT1:.+]] {{\[}}[0, 1], [2], [3], [4], [5], [6, 7]] : tensor<2x4x32x32x32x64x128x256xf32> into tensor<8x32x32x32x64x32768xf32>
+// CHECK: %[[IN2:.+]] = tensor.collapse_shape %[[INPUT2:.+]] {{\[}}[0, 1], [2], [3], [4], [5], [6, 7]] : tensor<2x4x32x32x32x64x128x256xf32> into tensor<8x32x32x32x64x32768xf32>
+// CHECK: %[[OUT:.+]] = tensor.collapse_shape %[[OUTPUT:.+]] {{\[}}[0, 1], [2], [3], [4], [5], [6, 7]] : tensor<2x4x32x32x32x64x128x256xf32> into tensor<8x32x32x32x64x32768xf32>
+// CHECK: %[[RES:.+]] = flow.dispatch.region
+// CHECK: linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]], #[[$MAP]]], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "parallel"]}
+// CHECK: ins(%[[IN]], %[[IN1]], %[[IN2]] : tensor<8x32x32x32x64x32768xf32>, tensor<8x32x32x32x64x32768xf32>, tensor<8x32x32x32x64x32768xf32>) outs(%[[OUT]] : tensor<8x32x32x32x64x32768xf32>)
+// CHECK: tensor.expand_shape %[[RES]] {{\[}}[0, 1], [2], [3], [4], [5], [6, 7]] : tensor<8x32x32x32x64x32768xf32> into tensor<2x4x32x32x32x64x128x256xf32>
+
+// -----
+
+!type = tensor<32x2x4x8x16x16x64x128xf32>
+util.global private @"__transpose_10_input" {noinline} = dense<1.0> : !type
+func.func @collapse6() -> !type {
+ %cst = arith.constant 0.000000e+00 : f32
+ %c0 = arith.constant 0 : index
+ %input_ptr = util.global.address @"__transpose_10_input" : !util.ptr<!type>
+ %input = util.global.load.indirect %input_ptr : !util.ptr<!type> -> !type
+ %output = tensor.empty() : !type
+
+ // Can collapse (d2, d3) and (d6, d7)
+ %result = linalg.generic { indexing_maps = [
+ affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d5, d6, d7)>,
+ affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d5, d4, d6, d7)>],
+ iterator_types = ["parallel", "reduction", "parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]
+ }
+ ins(%input : !type) outs(%output : !type) {
+ ^bb0(%arg1: f32, %arg2: f32):
+ linalg.yield %arg1 : f32
+ } -> !type
+ return %result: !type
+
+}
+
+// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
+// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d4, d3, d5)>
+// CHECK-LABEL: func.func @collapse6
+// CHECK: %[[IN:.+]] = tensor.collapse_shape %[[INPUT:.+]] {{\[}}[0], [1], [2, 3], [4], [5], [6, 7]] : tensor<32x2x4x8x16x16x64x128xf32> into tensor<32x2x32x16x16x8192xf32>
+// CHECK: %[[OUT:.+]] = tensor.collapse_shape %[[OUTPUT:.+]] {{\[}}[0], [1], [2, 3], [4], [5], [6, 7]] : tensor<32x2x4x8x16x16x64x128xf32> into tensor<32x2x32x16x16x8192xf32>
+// CHECK: %[[RES:.+]] = flow.dispatch.region
+// CHECK: linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP2]]], iterator_types = ["parallel", "reduction", "parallel", "parallel", "parallel", "parallel"]}
+// CHECK: ins(%[[IN]] : tensor<32x2x32x16x16x8192xf32>) outs(%[[OUT]] : tensor<32x2x32x16x16x8192xf32>)
+// CHECK: tensor.expand_shape %[[RES]] {{\[}}[0], [1], [2, 3], [4], [5], [6, 7]] : tensor<32x2x32x16x16x8192xf32> into tensor<32x2x4x8x16x16x64x128xf32>
+
+// -----
+
+!type_out = tensor<2x4x8x16xf32>
+!type_in = tensor<2x4x8xf32>
+util.global private @"__transpose_10_input" {noinline} = dense<1.0> : !type_in
+func.func @collapse7() -> !type_out {
+ %cst = arith.constant 0.000000e+00 : f32
+ %c0 = arith.constant 0 : index
+ %input_ptr = util.global.address @"__transpose_10_input" : !util.ptr<!type_in>
+ %input = util.global.load.indirect %input_ptr : !util.ptr<!type_in> -> !type_in
+ %output = tensor.empty() : !type_out
+
+ %result = linalg.generic { indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d1, d2, d3, d0)>],
+ iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+ }
+ ins(%input : !type_in) outs(%output : !type_out) {
+ ^bb0(%arg1: f32, %arg2: f32):
+ linalg.yield %arg1 : f32
+ } -> !type_out
+ return %result: !type_out
+}
+
+// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1) -> (d1)>
+// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1) -> (d1, d0)>
+// CHECK-LABEL: func.func @collapse7
+// CHECK: %[[IN:.+]] = tensor.collapse_shape %[[INPUT:.+]] {{\[}}[0, 1, 2]] : tensor<2x4x8xf32> into tensor<64xf32>
+// CHECK: %[[OUT:.+]] = tensor.collapse_shape %[[OUTPUT:.+]] {{\[}}[0, 1, 2], [3]] : tensor<2x4x8x16xf32> into tensor<64x16xf32>
+// CHECK: %[[RES:.+]] = flow.dispatch.region
+// CHECK: linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP2]]], iterator_types = ["parallel", "parallel"]}
+// CHECK: ins(%[[IN]] : tensor<64xf32>) outs(%[[OUT]] : tensor<64x16xf32>)
+// CHECK: tensor.expand_shape %[[RES]] {{\[}}[0, 1, 2], [3]] : tensor<64x16xf32> into tensor<2x4x8x16xf32>
+
+// -----
+
+!type_in = tensor<16x4x32x2xf32>
+!type_out = tensor<8x16x4x32x8x2xf32>
+func.func @collapse8() -> !type_out {
+ %cst = arith.constant 0.000000e+00 : f32
+ %c0 = arith.constant 0 : index
+ %input = tensor.empty() : !type_in
+ %output = tensor.empty() : !type_out
+ // Can collapse (d3, d0, d1)
+ %6 = linalg.generic { indexing_maps = [
+ affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d0, d1, d5)>,
+ affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d0, d1, d4, d5)>],
+ iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]
+ }
+ ins(%input : !type_in) outs(%output : !type_out) {
+ ^bb0(%arg1: f32, %arg2: f32):
+ %11 = arith.addf %arg1, %arg2 : f32
+ linalg.yield %11 : f32
+ } -> !type_out
+ return %6: !type_out
+}
+
+// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d3)>
+// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-LABEL: func.func @collapse8
+// CHECK: %[[IN:.+]] = tensor.collapse_shape %[[INPUT:.+]] {{\[}}[0, 1, 2], [3]] : tensor<16x4x32x2xf32> into tensor<2048x2xf32>
+// CHECK: %[[OUT:.+]] = tensor.collapse_shape %[[OUTPUT:.+]] {{\[}}[0], [1, 2, 3], [4], [5]] : tensor<8x16x4x32x8x2xf32> into tensor<8x2048x8x2xf32>
+// CHECK: %[[RES:.+]] = flow.dispatch.region
+// CHECK: linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+// CHECK: ins(%[[IN]] : tensor<2048x2xf32>) outs(%[[OUT]] : tensor<8x2048x8x2xf32
+// CHECK: tensor.expand_shape %[[RES]] {{\[}}[0], [1, 2, 3], [4], [5]] : tensor<8x2048x8x2xf32> into tensor<8x16x4x32x8x2xf32>
+
+// -----
+
+!type_in = tensor<16x4xf32>
+!type_out = tensor<16x32x4xf32>
+func.func @dont_collapse() -> !type_out {
+ %cst = arith.constant 0.000000e+00 : f32
+ %c0 = arith.constant 0 : index
+ %input = tensor.empty() : !type_in
+ %output = tensor.empty() : !type_out
+ %6 = linalg.generic { indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
+ iterator_types = ["parallel", "parallel", "parallel"]
+ }
+ ins(%input : !type_in) outs(%output : !type_out) {
+ ^bb0(%arg1: f32, %arg2: f32):
+ %11 = arith.addf %arg1, %arg2 : f32
+ linalg.yield %11 : f32
+ } -> !type_out
+ return %6: !type_out
+}
+// CHECK-LABEL: func.func @dont_collapse
+// CHECK: linalg.generic {indexing_maps = [#[[$MAP:.+]], #[[$MAP2:.+]]], iterator_types = ["parallel", "parallel", "parallel"]}
+
+// -----
+
+!type_in = tensor<2x4x8x16x32x64x128x256xf32>
+!type_out = tensor<2x4x16x64x32x128x256xf32>
+util.global private @"__transpose_10_input" {noinline} = dense<1.0> : !type_in
+
+func.func @collapse9() -> !type_out {
+ %cst = arith.constant 0.000000e+00 : f32
+ %c0 = arith.constant 0 : index
+ %input_ptr = util.global.address @"__transpose_10_input" : !util.ptr<!type_in>
+ %input = util.global.load.indirect %input_ptr : !util.ptr<!type_in> -> !type_in
+ %output = tensor.empty() : !type_out
+
+ // Can collapse (d0, d1) and (d6, d7)
+ %result = linalg.generic { indexing_maps = [
+ affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d5, d6, d7)>,
+ affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d3, d5, d4, d6, d7)>],
+ iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "parallel", "parallel", "parallel"]
+ }
+ ins(%input : !type_in) outs(%output : !type_out) {
+ ^bb0(%arg1: f32, %arg2: f32):
+ linalg.yield %arg1 : f32
+ } -> !type_out
+ return %result: !type_out
+}
+
+
+// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
+// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d4, d3, d5)>
+// CHECK-LABEL: func.func @collapse9
+// CHECK: %[[RES:.+]] = flow.dispatch.region
+// CHECK: linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP2]]], iterator_types = ["parallel", "reduction", "parallel", "parallel", "parallel", "parallel"]}
+
+
+// -----
+
+!type_in = tensor<10x10x30xf32>
+!type_out = tensor<20x10x10x30x20xf32>
+
+func.func @collapse10() -> !type_out {
+ %cst = arith.constant 0.000000e+00 : f32
+ %c0 = arith.constant 0 : index
+ %input = tensor.empty() : !type_in
+ %output = tensor.empty() : !type_out
+
+ // Can collapse (d1, d3, d0)
+ %result = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d0)>,
+ affine_map<(d0, d1, d2, d3, d4) -> (d2, d1, d3, d0, d4)>],
+ iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]}
+ ins(%input : !type_in) outs(%output : !type_out) {
+ ^bb0(%arg1: f32, %arg2: f32):
+ linalg.yield %arg1 : f32
+ } -> !type_out
+
+ return %result: !type_out
+}
+
+// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2) -> (d0)>
+// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
+// CHECK-LABEL: func.func @collapse10
+// CHECK: %[[RES:.+]] = flow.dispatch.region
+// CHECK: linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel"]}
+
+// -----
+
+!type_in = tensor<10x20xf32>
+!type_out = tensor<10x20xf32>
+
+func.func @collapse11() -> !type_out {
+ %cst = arith.constant 0.000000e+00 : f32
+ %c0 = arith.constant 0 : index
+ %input = tensor.empty() : !type_in
+ %output = tensor.empty() : !type_out
+
+ // Can collapse (d1, d0)
+ %result = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d1, d0)>],
+ iterator_types = ["parallel", "parallel"] }
+ ins(%input : !type_in) outs(%output : !type_out) {
+ ^bb0(%arg1: f32, %arg2: f32):
+ linalg.yield %arg1 : f32
+ } -> !type_out
+
+ return %result: !type_out
+}
+
+
+// CHECK: #[[$MAP:.+]] = affine_map<(d0) -> (d0)>
+// CHECK-LABEL: func.func @collapse11
+// CHECK: %[[RES:.+]] = flow.dispatch.region
+// CHECK: linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP]]], iterator_types = ["parallel"]}
+
+// -----
+
+!type = tensor<16x32xi32>
+func.func @dont_collapse_dueto_index(%height : index, %width : index) -> !type {
+ %init_source = tensor.empty() : !type
+ %source = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel"]}
+ outs(%init_source : !type) {
+ ^bb0(%b0 : i32):
+ %outer = linalg.index 0 : index
+ %inner = linalg.index 1 : index
+ %strided = arith.muli %outer, %width : index
+ %linearized = arith.addi %inner, %strided : index
+ %linearized_i32 = arith.index_cast %linearized : index to i32
+ linalg.yield %linearized_i32 : i32
+ } -> !type
+ return %source : !type
+}
+
+// CHECK-LABEL: func.func @dont_collapse_dueto_index
+// CHECK: linalg.generic {indexing_maps = [#[[$MAP:.+]]], iterator_types = ["parallel", "parallel"]}
+
+// -----
+
+!type = tensor<2x4x8x16x32x64xf32>
+util.global private @"__transpose_10_input" {noinline} = dense<1.0> : !type
+
+func.func @collapse12() -> (!type,!type,!type,!type) {
+ %cst = arith.constant 0.000000e+00 : f32
+ %c0 = arith.constant 0 : index
+ %input_ptr = util.global.address @"__transpose_10_input" : !util.ptr<!type>
+ %input = util.global.load.indirect %input_ptr : !util.ptr<!type> -> !type
+ %output = tensor.empty() : !type
+ %output1 = tensor.empty() : !type
+ %output2 = tensor.empty() : !type
+ %output3 = tensor.empty() : !type
+
+ %6, %7, %8, %9 = linalg.generic { indexing_maps = [
+ affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d4, d3, d5)>,
+ affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d4, d3, d5)>,
+ affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d4, d3, d5)>,
+ affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d4, d3, d5)>,
+ affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d4, d3, d5)>],
+ iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]
+ }
+ ins(%input : !type) outs(%output, %output1, %output2, %output3 : !type, !type, !type, !type) {
+ ^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32, %arg5: f32):
+ %0 = arith.addf %arg1, %arg2 : f32
+ %1 = arith.addf %0, %arg3 : f32
+ %2 = arith.addf %1, %arg4 : f32
+ %3 = arith.addf %2, %arg5 : f32
+ linalg.yield %0,%1,%2,%3 : f32, f32, f32, f32
+ } -> (!type,!type,!type,!type)
+ return %6, %7, %8, %9 : !type,!type,!type,!type
+}
+
+// CHECK: #[[$MAP:.+]] = affine_map<(d0) -> (d0)>
+// CHECK-LABEL: func.func @collapse12
+// CHECK: %[[RES:.+]] = flow.dispatch.region
+// CHECK: linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP]], #[[$MAP]], #[[$MAP]], #[[$MAP]]], iterator_types = ["parallel"]}
+
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir
index 680a9aa..7c0b8fc 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir
@@ -1212,7 +1212,7 @@
%cst = arith.constant dense<0.0> : tensor<3x3xf32>
%init = tensor.empty() : tensor<2x2xf32>
%0 = linalg.generic {
- indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
+ indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>],
iterator_types = ["parallel", "parallel"]}
ins(%arg0 : tensor<2x2xf32>) outs(%init : tensor<2x2xf32>) {
^bb0(%b0 : f32, %b1 : f32) :
diff --git a/tests/transform_dialect/cuda/vecadd2d.mlir b/tests/transform_dialect/cuda/vecadd2d.mlir
index 405914f..938afa1 100644
--- a/tests/transform_dialect/cuda/vecadd2d.mlir
+++ b/tests/transform_dialect/cuda/vecadd2d.mlir
@@ -32,6 +32,9 @@
}
// RUN: iree-opt %s --iree-hal-target-backends=cuda \
+/// We must disable collapsing of linalg.generic because the transform dialect
+/// maps dimensions explicitly and is not aware of the collapsing.
+// RUN: --iree-flow-form-dispatch-regions-collapse=false \
// RUN: --iree-abi-transformation-pipeline \
// RUN: --iree-flow-transformation-pipeline \
// RUN: --iree-stream-transformation-pipeline \
@@ -41,6 +44,7 @@
// RUN: FileCheck %s --check-prefix=CHECK
// RUN: iree-opt %s --iree-hal-target-backends=cuda \
+// RUN: --iree-flow-form-dispatch-regions-collapse=false \
// RUN: --iree-abi-transformation-pipeline \
// RUN: --iree-flow-transformation-pipeline \
// RUN: --iree-stream-transformation-pipeline \
@@ -50,6 +54,7 @@
// RUN: FileCheck %s --check-prefix=CHECK-PARTIAL-TILE
// RUN: iree-compile %s --iree-hal-target-backends=cuda \
+// RUN: --iree-flow-form-dispatch-regions-collapse=false \
// RUN: --iree-opt-const-expr-hoisting=false --iree-opt-const-eval=false \
/// Constant JIT'ing must be disabled because the transform-dialect debug
/// flags leak to the JIT session, which doesn't know what to do with them.