[Flow] Allow CollapseDimensions pass to fold reduction dimensions as well (#14656)
This makes the CollapseReductionDims pass redundant, so it is dropped.
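
For illustration, here is a sketch of the rewrite this enables, based on the
`multi_reduce_dim` test added below (SSA names and the elided bodies are
illustrative, not literal pass output). Both the contiguous parallel dims
(d0, d1) and the contiguous reduction dims (d2, d3) are folded, and the
reshapes are hoisted out of the dispatch region:

    // Before: 2 parallel + 2 reduction loops inside the dispatch.
    %3 = linalg.generic {
        indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
                         affine_map<(d0, d1, d2, d3) -> (d0, d1)>],
        iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
        ins(%0 : tensor<2x32x10x4096xf32>) outs(%2 : tensor<2x32xf32>) {...}

    // After: 1 parallel + 1 reduction loop; the reshapes sit outside the
    // flow.dispatch.region.
    %in = tensor.collapse_shape %0 [[0, 1], [2, 3]]
        : tensor<2x32x10x4096xf32> into tensor<64x40960xf32>
    %r = flow.dispatch.region -> (tensor<64xf32>) {
      %sum = linalg.generic {
          indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                           affine_map<(d0, d1) -> (d0)>],
          iterator_types = ["parallel", "reduction"]}
          ins(%in : tensor<64x40960xf32>) outs(%fill : tensor<64xf32>) {...}
      flow.return %sum : tensor<64xf32>
    }
    %out = tensor.expand_shape %r [[0, 1]] : tensor<64xf32> into tensor<2x32xf32>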
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD.bazel
index a3ad95f..515bb3b 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD.bazel
@@ -35,7 +35,6 @@
"CleanupTensorShapes.cpp",
"CloneProducersIntoDispatchRegions.cpp",
"CollapseDimensions.cpp",
- "CollapseReductionDims.cpp",
"Convert1X1FilterConv2DToMatmul.cpp",
"ConvertRegionToWorkgroups.cpp",
"ConvertToFlow.cpp",
@@ -100,6 +99,7 @@
"//llvm-external-projects/iree-dialects:IREELinalgTransformDialect",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:AffineDialect",
+ "@llvm-project//mlir:AffineUtils",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:ArithUtils",
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
index c3b86cc..bab9e87 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
@@ -34,7 +34,6 @@
"CleanupTensorShapes.cpp"
"CloneProducersIntoDispatchRegions.cpp"
"CollapseDimensions.cpp"
- "CollapseReductionDims.cpp"
"Convert1X1FilterConv2DToMatmul.cpp"
"ConvertRegionToWorkgroups.cpp"
"ConvertToFlow.cpp"
@@ -82,6 +81,7 @@
IREELinalgTransformDialect
LLVMSupport
MLIRAffineDialect
+ MLIRAffineUtils
MLIRAnalysis
MLIRArithDialect
MLIRArithUtils
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/CollapseDimensions.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/CollapseDimensions.cpp
index 6b0445e..19a0a43 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/CollapseDimensions.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/CollapseDimensions.cpp
@@ -12,6 +12,8 @@
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
+#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
@@ -27,6 +29,8 @@
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include <deque>
+
#define DEBUG_TYPE "iree-flow-collapse-dimensions"
namespace mlir {
@@ -48,47 +52,88 @@
/// Searches the same sequence in all the affine maps and collapses these
/// dimensions. It only applies these to "parallel" loops without mixing them
-/// with "reduction" types.
+/// with "reduction" types. It is expected that all indexing maps of the
+/// `genericOp` are projected permutations (checked by `isEligibleForCollapse`).
static SmallVector<ReassociationIndices>
getCollapsibleLoops(linalg::GenericOp genericOp) {
SmallVector<ReassociationIndices> contiguousLoops;
- SmallVector<unsigned> pDims;
+ SmallVector<unsigned> pDims, rDims;
genericOp.getParallelDims(pDims);
- if (pDims.size() < 2)
- return contiguousLoops;
-
- llvm::SmallDenseSet<unsigned> pLoops(pDims.begin(), pDims.end());
+ genericOp.getReductionDims(rDims);
+ llvm::SmallDenseSet<unsigned> pDimsSet, rDimsSet;
+ pDimsSet.insert(pDims.begin(), pDims.end());
+ rDimsSet.insert(rDims.begin(), rDims.end());
auto hasAllMapsSameSequence = [&](AffineExpr preExpr, AffineExpr nextExpr) {
+    // Check that in all indexing maps of the `genericOp`, `preExpr` and
+    // `nextExpr` either
+    // - appear contiguously (`preExpr` immediately followed by `nextExpr`), or
+    // - are both absent.
+    // If so, `preExpr` and `nextExpr` can be collapsed.
for (AffineMap map : genericOp.getIndexingMapsArray()) {
- bool foundSeq = false;
+ // If map has no results, no need to check.
+ if (map.getNumResults() == 0) {
+ continue;
+ }
for (auto [index, resultExpr] : llvm::enumerate(map.getResults())) {
+        // If we find `preExpr`, the next result must be `nextExpr`.
+ if (resultExpr == preExpr) {
+ if (index == map.getNumResults() - 1) {
+            // Reached the end of the list; `nextExpr` cannot follow.
+ return false;
+ }
+ if (map.getResult(index + 1) != nextExpr) {
+ return false;
+ }
+ }
+        // If we find `nextExpr`, the previous result must be `preExpr`. This
+        // check is mostly redundant with the one above, but it is cheap
+        // enough, so keep it.
if (resultExpr == nextExpr) {
- foundSeq = (index > 0 && preExpr == map.getResult(index - 1));
- break;
+ if (index == 0) {
+            // Matched at the beginning of the list; nothing precedes it.
+ return false;
+ }
+ if (map.getResult(index - 1) != preExpr) {
+ return false;
+ }
}
}
- if (!foundSeq)
- return false;
}
return true;
};
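+  // Two dims can only be collapsed together if both are parallel loops or
+  // both are reduction loops.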
+ auto hasSameIteratorType = [&](AffineExpr preExpr, AffineExpr nextExpr) {
+ unsigned prePos = preExpr.cast<AffineDimExpr>().getPosition();
+ unsigned nextPos = nextExpr.cast<AffineDimExpr>().getPosition();
+ return (pDimsSet.count(prePos) && pDimsSet.count(nextPos)) ||
+ (rDimsSet.count(prePos) && rDimsSet.count(nextPos));
+ };
ReassociationIndices range;
AffineExpr preExpr;
+  // Find the longest sequences of dimensions that, in every indexing map,
+  // either appear contiguously or are completely absent. Such a sequence can
+  // be collapsed. To find a sequence,
+  // 1) Take the result expressions of one of the indexing maps.
+  // 2) Find a pair of adjacent expressions that forms a sequence in all maps.
+  // 3) Take the last element of the sequence and the next result expression,
+  //    and check whether this pair forms a sequence in all maps. If so,
+  //    append it (giving a sequence of 3), and repeat until the pair formed
+  //    by the last element of the sequence and the next result expression is
+  //    no longer a sequence in all maps.
for (auto nextExpr : genericOp.getIndexingMapsArray().front().getResults()) {
- unsigned pos = nextExpr.cast<AffineDimExpr>().getPosition();
if (!range.empty()) {
- if (!hasAllMapsSameSequence(preExpr, nextExpr) || !pLoops.count(pos)) {
- if (range.size() > 1)
+ if (!hasAllMapsSameSequence(preExpr, nextExpr) ||
+ !hasSameIteratorType(preExpr, nextExpr)) {
+ if (range.size() > 1) {
contiguousLoops.push_back({range.begin(), range.end()});
+ }
range.clear();
}
}
+ range.push_back(nextExpr.cast<AffineDimExpr>().getPosition());
preExpr = nextExpr;
- if (pLoops.count(pos))
- range.push_back(pos);
}
if (range.size() > 1)
contiguousLoops.push_back(range);
@@ -107,22 +152,6 @@
return contiguousLoops;
}
-/// Collapse possible dimension of the given linalg.generic
-static FailureOr<SmallVector<Value>>
-collapseLinalgGeneric(IRRewriter &rewriter, linalg::GenericOp genericOp,
- SmallVector<ReassociationIndices> &collapseIndices) {
- rewriter.setInsertionPoint(genericOp->getParentOp());
- FailureOr<SmallVector<Value>> replacements =
- mlir::linalg::collapseGenericOpIterationDims(genericOp, collapseIndices,
- rewriter);
- if (failed(replacements) || replacements->empty()) {
- return rewriter.notifyMatchFailure(genericOp,
- "failed to collapse dimensions");
- }
-
- return replacements;
-}
-
/// Returns true if the given op is collapsable.
static bool isEligibleForCollapse(linalg::GenericOp genericOp) {
// TODO(guray) There is no mechanism to tell the collapsed indexes to
@@ -154,101 +183,298 @@
/// without any producers.
static FailureOr<linalg::GenericOp>
findRootGenericOp(DispatchRegionOp regionOp) {
- SmallVector<Operation *> computeOps;
- auto &ops = regionOp.getBody().front().getOperations();
- for (Operation &op : ops) {
- if (isa<TilingInterface>(op))
- computeOps.push_back(&op);
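+  // Only consider dispatches whose region has a single block.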
+ if (!llvm::hasSingleElement(regionOp.getBody())) {
+ return failure();
}
- // Looking for root without producer
- if (computeOps.size() != 1 || ops.size() != 2)
+
+ // Check the yielded value is from a single `linalg.generic`.
+ auto returnOp =
+ cast<Flow::ReturnOp>(regionOp.getBody().front().getTerminator());
+ auto collapsibleOp = dyn_cast_or_null<linalg::GenericOp>(
+ returnOp->getOperand(0).getDefiningOp());
+ if (!collapsibleOp) {
return failure();
- auto genericOp = llvm::dyn_cast<linalg::GenericOp>(computeOps.front());
- if (!genericOp)
+ }
+ for (auto returnVal : returnOp->getOperands().drop_front()) {
+ if (returnVal.getDefiningOp() != collapsibleOp.getOperation()) {
+ return failure();
+ }
+ }
+
+ // Check that the operands of the generic op are defined outside the dispatch.
+ for (OpOperand *inputOperands : collapsibleOp.getDpsInputOperands()) {
+ Operation *definingOp = inputOperands->get().getDefiningOp();
+ if (definingOp &&
+ definingOp->getParentOfType<DispatchRegionOp>() == regionOp) {
+ return failure();
+ }
+ }
+
+ // Check that the output is either a `tensor.empty` or a `linalg.fill` op by
+ // traversing the operations that define the `init` operands of the
+ // `collapsibleOp`.
+ std::deque<Operation *> worklist;
+ llvm::SmallDenseSet<Operation *> visited;
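+  // Enqueue the producer of `v` if it lies inside this dispatch and has not
+  // been visited yet.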
+ auto addDefiningOpToWorklist = [&](Value v) {
+ Operation *definingOp = v.getDefiningOp();
+ if (definingOp &&
+ definingOp->getParentOfType<DispatchRegionOp>() == regionOp &&
+ !visited.count(definingOp)) {
+ worklist.push_back(definingOp);
+ visited.insert(definingOp);
+ }
+ };
+ for (Value initOperand : collapsibleOp.getDpsInits()) {
+ addDefiningOpToWorklist(initOperand);
+ }
+
+ while (!worklist.empty()) {
+ Operation *op = worklist.front();
+ worklist.pop_front();
+ if (auto fillOp = dyn_cast<linalg::FillOp>(op)) {
+ addDefiningOpToWorklist(fillOp.getDpsInitOperand(0)->get());
+ continue;
+ }
+ if (isa<tensor::EmptyOp>(op)) {
+ continue;
+ }
return failure();
- return genericOp;
+ }
+ return collapsibleOp;
}
-/// Generate a new dispatch.region and workload according with the collapsed
-/// linalg Generic Op
-static LogicalResult
-generateNewDispatchRegion(IRRewriter &rewriter, DispatchRegionOp regionOp,
- SmallVector<Value> collapseResults,
- linalg::GenericOp newGenericOp) {
- OpBuilder::InsertionGuard g(rewriter);
- rewriter.setInsertionPoint(regionOp->getParentOp());
-
- auto maybeRegionOp = Flow::wrapOpInDispatchRegion(rewriter, newGenericOp);
- if (failed(maybeRegionOp))
+/// Hoist the `tensor.collapse_shape` ops at the beginning of the `dispatchOp`
+/// and the `tensor.expand_shape` ops at the end of the `dispatchOp` out of
+/// the dispatch.
+static FailureOr<DispatchRegionOp>
+hoistTensorReshapesOutOfDispatchRegion(RewriterBase &rewriter,
+ DispatchRegionOp dispatchOp) {
+  // Only do this for `dispatchOp`s whose region has a single block.
+ if (!llvm::hasSingleElement(dispatchOp.getBody())) {
return failure();
+ }
+ Block &body = dispatchOp.getBody().front();
+ auto returnOp = cast<Flow::ReturnOp>(body.getTerminator());
- // Replace old regionOp with the result of collapse
- rewriter.replaceOp(regionOp, collapseResults);
+  // 1. Get the slice of operations within `dispatchOp` that produce the
+  // yielded values.
+ BackwardSliceOptions sliceOptions;
+ sliceOptions.filter = [&](Operation *op) {
+ return op->getParentOfType<DispatchRegionOp>();
+ };
+ SetVector<Operation *> slice;
+ getBackwardSlice(returnOp, &slice, sliceOptions);
- return success();
+ // 2. Get the leaf operations that are tensor.collapse_shape ops.
+ SmallVector<tensor::CollapseShapeOp> leafs;
+ for (Operation *op : slice) {
+ auto collapseShapeOp = dyn_cast<tensor::CollapseShapeOp>(op);
+ if (!collapseShapeOp) {
+ continue;
+ }
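+    // A `tensor.collapse_shape` op is a leaf if none of its operands are
+    // produced by an op within the slice.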
+ if (llvm::all_of(op->getOperands(), [&](Value operand) {
+ Operation *definingOp = operand.getDefiningOp();
+ return !definingOp || slice.count(definingOp) == 0;
+ })) {
+ leafs.push_back(collapseShapeOp);
+ }
+ }
+
+ // 3. Clone the leaf `tensor.collapse_shape` ops outside the dispatch.
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(dispatchOp);
+ for (auto reshapeOp : leafs) {
+ Operation *clonedOp = rewriter.clone(*reshapeOp.getOperation());
+ rewriter.replaceOp(reshapeOp, clonedOp->getResults());
+ }
+
+  // 4. From the yielded values, find any that are produced by a
+  // `tensor.expand_shape` op and move those ops out of the dispatch. This
+  // requires a new `DispatchRegionOp`: for yielded values produced by a
+  // `tensor.expand_shape`, the result type changes, and the dynamic
+  // dimensions of the result type need to be updated as well.
+ SmallVector<Type> newReturnTypes;
+ SmallVector<Value> newDynamicDims;
+ SmallVector<Value> newYieldVals;
+ SmallVector<SmallVector<ReassociationIndices>> allReassociationIndices;
+ ValueRange dynamicDimsList = dispatchOp.getResultDims();
+ Location loc = dispatchOp.getLoc();
+ for (Value yieldedValue : returnOp->getOperands()) {
+ auto expandShapeOp = yieldedValue.getDefiningOp<tensor::ExpandShapeOp>();
+ if (!expandShapeOp) {
+ // 4a. Keep the same yield value if the producer is not a
+ // `tensor.expand_shape` op.
+ newReturnTypes.push_back(yieldedValue.getType());
+ newYieldVals.push_back(yieldedValue);
+ continue;
+ }
+
+ // 4b. The return type is same as the type of the source of the
+ // `tensor.expand_shape`.
+ RankedTensorType collapsedShapeType = expandShapeOp.getSrcType();
+ newReturnTypes.push_back(collapsedShapeType);
+ newYieldVals.push_back(expandShapeOp.getSrc());
+ SmallVector<ReassociationIndices> reassociation =
+ expandShapeOp.getReassociationIndices();
+ ArrayRef<int64_t> expandedShape = expandShapeOp.getResultType().getShape();
+
+    // 4c. The dynamic dims of the result shape are obtained by taking the
+    // static shape + dynamic dims and collapsing them using the same
+    // reassociation map as the `tensor.expand_shape`.
+ for (auto [index, shape] : llvm::enumerate(collapsedShapeType.getShape())) {
+ int64_t staticCollapsedShape = 1;
+ SmallVector<OpFoldResult> dynamicCollapsedDims;
+ for (auto collapsedDim : reassociation[index]) {
+ if (expandedShape[collapsedDim] == ShapedType::kDynamic) {
+ dynamicCollapsedDims.push_back(dynamicDimsList.front());
+ dynamicDimsList = dynamicDimsList.drop_front();
+ } else {
+ staticCollapsedShape *= expandedShape[collapsedDim];
+ }
+ }
+
+ if (dynamicCollapsedDims.empty()) {
+ // If there are no dynamic dims, there is nothing to do.
+ continue;
+ }
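+      // The collapsed dim size is the product of the dynamic sizes in the
+      // group times the product of the static sizes; compute it as an
+      // affine.apply over the dynamic dims.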
+ SmallVector<AffineExpr> exprs(dynamicCollapsedDims.size());
+ bindSymbolsList(rewriter.getContext(),
+ MutableArrayRef<AffineExpr>(exprs));
+ AffineExpr multiplyAll = exprs.front();
+ for (auto expr : ArrayRef<AffineExpr>(exprs).drop_front()) {
+ multiplyAll = multiplyAll * expr;
+ }
+ if (staticCollapsedShape != 1) {
+ multiplyAll = multiplyAll * staticCollapsedShape;
+ }
+ OpFoldResult collapsedShape = affine::makeComposedFoldedAffineApply(
+ rewriter, loc, multiplyAll, dynamicCollapsedDims);
+ newDynamicDims.push_back(
+ getValueOrCreateConstantIndexOp(rewriter, loc, collapsedShape));
+ }
+ allReassociationIndices.emplace_back(std::move(reassociation));
+ }
+
+ // 5. Create the new dispatch op.
+ auto newDispatchOp = rewriter.create<DispatchRegionOp>(
+ loc, newReturnTypes, newDynamicDims, dispatchOp.getWorkload());
+
+  // 5a. Move the body over, replacing the `flow.return` with one that uses
+  // the new yield values.
+ Region &newBody = newDispatchOp.getBody();
+ rewriter.inlineRegionBefore(dispatchOp.getBody(), newBody, newBody.begin());
+ {
+ Operation *terminator = newBody.front().getTerminator();
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(terminator);
+ rewriter.replaceOpWithNewOp<Flow::ReturnOp>(terminator, newYieldVals);
+ }
+
+ // 5b. Move the workgroup count region over.
+ Region &workgroupCountRegion = dispatchOp.getWorkgroupCount();
+ if (!workgroupCountRegion.empty()) {
+ Region &newWorkgroupCountRegion = newDispatchOp.getWorkgroupCount();
+ rewriter.inlineRegionBefore(workgroupCountRegion, newWorkgroupCountRegion,
+ newWorkgroupCountRegion.begin());
+ }
+
+ // 6. Map the modified result values back to their original shape using
+ // `tensor.expand_shape` operations.
+ ArrayRef<SmallVector<ReassociationIndices>> allReassociationIndicesRef(
+ allReassociationIndices);
+ for (auto [index, returnValue] :
+ llvm::enumerate(newDispatchOp.getResults())) {
+ Value origResult = dispatchOp->getResult(index);
+ if (returnValue.getType() == origResult.getType()) {
+ rewriter.replaceAllUsesWith(origResult, returnValue);
+ continue;
+ }
+ auto newExpandShapeOp = rewriter.create<tensor::ExpandShapeOp>(
+ loc, origResult.getType(), returnValue,
+ allReassociationIndicesRef.front());
+ allReassociationIndicesRef = allReassociationIndicesRef.drop_front();
+ rewriter.replaceAllUsesWith(origResult, newExpandShapeOp.getResult());
+ }
+ rewriter.eraseOp(dispatchOp);
+ return newDispatchOp;
}
/// Traverses DispatchRegionOps to find linalg genericOps that has no
/// producers and tries to collapse its dimensions.
-static LogicalResult collapseDimensions(IRRewriter &rewriter,
- DispatchRegionOp ®ionOp) {
+static bool collapseDimensions(IRRewriter &rewriter,
+ DispatchRegionOp ®ionOp) {
// Step 1. Find the root linalg.generic Op with no producer
std::optional<linalg::GenericOp> genericOp = findRootGenericOp(regionOp);
if (!genericOp.has_value())
- return success();
+ return false;
// Step 2. Check whether it is possible to collapse
if (!isEligibleForCollapse(genericOp.value()))
- return success();
+ return false;
SmallVector<ReassociationIndices> collapseIndices;
collapseIndices = getCollapsibleLoops(genericOp.value());
if (collapseIndices.empty())
- return success();
+ return false;
// Step 3. Collapse dimensions
- auto maybeReplacements =
- collapseLinalgGeneric(rewriter, genericOp.value(), collapseIndices);
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(genericOp.value());
+
+ FailureOr<SmallVector<Value>> maybeReplacements =
+ mlir::linalg::collapseGenericOpIterationDims(genericOp.value(),
+ collapseIndices, rewriter);
if (failed(maybeReplacements))
- return failure();
- auto expandshapeOp =
- maybeReplacements->front().getDefiningOp<tensor::ExpandShapeOp>();
- if (!expandshapeOp)
- return failure();
- auto newGenericOp =
- expandshapeOp.getOperand().getDefiningOp<linalg::GenericOp>();
- if (!newGenericOp)
- return failure();
-
- // Step 4. Generate new dispatch region and replace old one users
- if (failed(generateNewDispatchRegion(rewriter, regionOp, *maybeReplacements,
- newGenericOp)))
- return failure();
-
- return success();
+ return false;
+ rewriter.replaceOp(genericOp.value(), maybeReplacements.value());
+ return true;
}
void CollapseDimensionsPass::runOnOperation() {
mlir::FunctionOpInterface funcOp = getOperation();
- IRRewriter rewriter(funcOp->getContext());
+ MLIRContext *context = funcOp->getContext();
+ IRRewriter rewriter(context);
- auto walkResult = funcOp->walk([&](DispatchRegionOp regionOp) {
- if (failed(collapseDimensions(rewriter, regionOp)))
- return WalkResult::interrupt();
- return WalkResult::advance();
+ SmallVector<DispatchRegionOp> modifiedDispatchOps;
+ funcOp->walk([&](DispatchRegionOp dispatchOp) {
+ if (collapseDimensions(rewriter, dispatchOp)) {
+ modifiedDispatchOps.push_back(dispatchOp);
+ }
});
- if (walkResult.wasInterrupted()) {
- funcOp->emitOpError("failed in collapsing dimensions pass");
- return signalPassFailure();
- }
- RewritePatternSet canonicalizationPatterns(&getContext());
- memref::populateResolveRankedShapedTypeResultDimsPatterns(
- canonicalizationPatterns);
- tensor::populateFoldTensorEmptyPatterns(canonicalizationPatterns);
- if (failed(applyPatternsAndFoldGreedily(
- funcOp, std::move(canonicalizationPatterns)))) {
- funcOp->emitOpError("failed to apply cleanup patterns");
- return signalPassFailure();
+ LLVM_DEBUG({
+ llvm::dbgs() << "[CollapseDims] : After collapsing generic ops: \n";
+ funcOp.print(llvm::dbgs());
+ llvm::dbgs() << "\n";
+ });
+
+  // Move all the `tensor.collapse_shape` leaves and `tensor.expand_shape`
+  // roots of the modified dispatches out of the dispatch.
+ for (auto dispatchOp : modifiedDispatchOps) {
+ Region &body = dispatchOp.getBody();
+    assert(llvm::hasSingleElement(body) &&
+           "expected region with a single block");
+ Block &block = body.front();
+ RewritePatternSet moveReshapeOps(&getContext());
+ linalg::FillOp::getCanonicalizationPatterns(moveReshapeOps, context);
+ memref::populateResolveRankedShapedTypeResultDimsPatterns(moveReshapeOps);
+ tensor::populateFoldTensorEmptyPatterns(moveReshapeOps);
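+    // Apply the patterns only at the `tensor.collapse_shape` ops introduced
+    // by collapsing the generic op.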
+ SmallVector<Operation *> candidateOps;
+ block.walk([&](Operation *op) {
+ if (isa<tensor::CollapseShapeOp>(op)) {
+ candidateOps.push_back(op);
+ }
+ });
+ if (failed(
+ applyOpPatternsAndFold(candidateOps, std::move(moveReshapeOps)))) {
+ funcOp.emitOpError(
+ "failed to propagate reshape ops introduced during collapse");
+ return signalPassFailure();
+ }
+
+ if (failed(hoistTensorReshapesOutOfDispatchRegion(
+ rewriter, cast<DispatchRegionOp>(dispatchOp)))) {
+ dispatchOp->emitOpError("failed to hoist reshapes out of dispatch");
+ return signalPassFailure();
+ }
}
}
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/CollapseReductionDims.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/CollapseReductionDims.cpp
deleted file mode 100644
index 407777b..0000000
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/CollapseReductionDims.cpp
+++ /dev/null
@@ -1,95 +0,0 @@
-// Copyright 2022 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#include "iree/compiler/Dialect/Flow/Transforms/PassDetail.h"
-#include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
-#include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
-#include "mlir/Dialect/Linalg/IR/Linalg.h"
-#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-
-namespace mlir {
-namespace iree_compiler {
-namespace IREE {
-namespace Flow {
-
-namespace {
-
-/// Check whether the given dimensions are contiguous in the result map.
-/// If non of the dimension are present in the map return true as well.
-static bool hasContiguousDims(AffineMap map, ArrayRef<unsigned> dims) {
- if (!map.isProjectedPermutation())
- return false;
- llvm::SmallDenseSet<unsigned> existingDims(dims.begin(), dims.end());
- for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
- if (map.getDimPosition(i) != dims[0]) {
- if (existingDims.count(map.getDimPosition(i))) {
- return false;
- }
- continue;
- }
- // Check that the following dimensions are match the order of `dims`
- for (unsigned j = 1, numDims = dims.size(); j < numDims; j++) {
- unsigned pos = i + j;
- if (pos >= map.getNumResults() || map.getDimPosition(pos) != dims[j]) {
- return false;
- }
- }
- break;
- }
- return true;
-}
-
-static SmallVector<ReassociationIndices>
-collapseDimensions(linalg::GenericOp genericOp) {
- SmallVector<ReassociationIndices> collapseIndices;
-
- if (!isNonNullAndOutsideDispatch(genericOp)) {
- return collapseIndices;
- }
-
- SmallVector<unsigned> reductionDims;
- genericOp.getReductionDims(reductionDims);
- if (reductionDims.size() < 2)
- return collapseIndices;
-
- for (AffineMap map : genericOp.getIndexingMapsArray()) {
- if (!hasContiguousDims(map, reductionDims))
- return collapseIndices;
- }
- ReassociationIndices indices;
- for (unsigned dim : reductionDims) {
- indices.push_back(int64_t(dim));
- }
- collapseIndices.push_back(indices);
- return collapseIndices;
-}
-
-struct CollapseDimsPass : public CollapseDimsBase<CollapseDimsPass> {
- void getDependentDialects(DialectRegistry ®istry) const override {
- registry.insert<linalg::LinalgDialect>();
- }
-
- void runOnOperation() override {
- RewritePatternSet patterns(&getContext());
- linalg::populateCollapseDimensions(patterns, collapseDimensions);
- if (failed(applyPatternsAndFoldGreedily(getOperation(),
- std::move(patterns)))) {
- return signalPassFailure();
- }
- }
-};
-
-} // namespace
-
-std::unique_ptr<Pass> createCollapseDimsPass() {
- return std::make_unique<CollapseDimsPass>();
-}
-
-} // namespace Flow
-} // namespace IREE
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionOfTensorOps.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionOfTensorOps.cpp
index 20746e7..a3f1f67 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionOfTensorOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionOfTensorOps.cpp
@@ -114,10 +114,16 @@
// broadcast this ends up redundantly computing operations without more
// parallelism.
if (auto linalgConsumerOp = dyn_cast<linalg::LinalgOp>(consumerOp)) {
- return linalgConsumerOp.getNumParallelLoops() ==
- linalgConsumerOp.getNumLoops() ||
- linalgConsumerOp.getMatchingIndexingMap(fusedOperand)
- .isPermutation();
+ if (linalgConsumerOp.getNumParallelLoops() ==
+ linalgConsumerOp.getNumLoops()) {
+ return true;
+ }
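+      // Fuse only if the consumer has a single reduction loop and accesses
+      // the fused operand through a permutation (i.e. no broadcast).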
+ if (linalgConsumerOp.getNumReductionLoops() != 1 ||
+ !linalgConsumerOp.getMatchingIndexingMap(fusedOperand)
+ .isPermutation()) {
+ return false;
+ }
+ return true;
}
// All other cases dont fuse.
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
index beece70..524977e 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
@@ -129,7 +129,6 @@
// Preprocess the input to a form more amenable for fusion
.addPass(createRaiseSpecialOps)
.addPass(createInterchangeGenericOpsPass)
- .addPass(createCollapseDimsPass)
.addPass(memref::createResolveShapedTypeResultDimsPass)
.addPass(mlir::createCanonicalizerPass)
.addPass(mlir::createCSEPass)
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.h b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.h
index b8db1df..16f4c1e 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.h
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.h
@@ -239,9 +239,6 @@
// Create a pass to split reduction dimension.
std::unique_ptr<Pass> createSplitReductionPass();
-// Create a pass to collapse reduction dimensions
-std::unique_ptr<Pass> createCollapseDimsPass();
-
//===----------------------------------------------------------------------===//
// Module Analysis and Finalization
//===----------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td
index 2aa8ef3..52fa8fb 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td
@@ -26,12 +26,6 @@
let constructor = "mlir::iree_compiler::IREE::Flow::createCleanupNumericNarrowingPass()";
}
-def CollapseDims :
- Pass<"iree-flow-collapse-dims", ""> {
- let summary = "Collapse reduction dimensions when possible.";
- let constructor = "mlir::iree_compiler::IREE::Flow::createCollapseDimsPass()";
-}
-
def Convert1X1FilterConv2DToMatmul:
Pass<"iree-flow-convert-1x1-filter-conv2d-to-matmul", ""> {
let summary = "Convert linalg convolution ops with 1x1 kernels into linalg matrix multiplication ops.";
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/BUILD.bazel
index e506ec9..b8cd4bc 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/BUILD.bazel
@@ -20,7 +20,6 @@
"cleanup_numeric_narrowing.mlir",
"cleanup_tensor_shapes.mlir",
"clone_producers_into_dispatch_regions.mlir",
- "collapse_reduction.mlir",
"conv1x1_to_matmul.mlir",
"convert_region_to_workgroups.mlir",
"deduplicate_executables.mlir",
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt
index 9c72420..2bf5af2 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt
@@ -19,7 +19,6 @@
"cleanup_tensor_shapes.mlir"
"clone_producers_into_dispatch_regions.mlir"
"collapse_linalg_generic_on_tensors.mlir"
- "collapse_reduction.mlir"
"conv1x1_to_matmul.mlir"
"convert_region_to_workgroups.mlir"
"deduplicate_executables.mlir"
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/collapse_linalg_generic_on_tensors.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/collapse_linalg_generic_on_tensors.mlir
index 21b6fe9..5e4ec9b 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/collapse_linalg_generic_on_tensors.mlir
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/collapse_linalg_generic_on_tensors.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(func.func(iree-flow-form-dispatch-regions{fuse-multi-use=true}, iree-flow-collapse-dimensions))" %s | FileCheck %s
+// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(func.func(iree-flow-form-dispatch-regions{fuse-multi-use=true}, iree-flow-clone-producers-into-dispatch-regions, iree-flow-collapse-dimensions, cse))" %s | FileCheck %s
!type = tensor<2x4x8x16x32x64xf32>
util.global private @"__transpose_10_input" {noinline} = dense<1.0> : !type
@@ -23,14 +23,14 @@
}
-// CHECK: #[[$MAP:.+]] = affine_map<(d0) -> (d0)>
-// CHECK-LABEL: func.func @collapse1
-// CHECK: %[[IN:.+]] = tensor.collapse_shape %[[INPUT:.+]] {{\[}}[0, 1, 2, 3, 4, 5]] : tensor<2x4x8x16x32x64xf32> into tensor<2097152xf32>
-// CHECK: %[[OUT:.+]] = tensor.empty() : tensor<2097152xf32>
-// CHECK: %[[RES:.+]] = flow.dispatch.region
-// CHECK: linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP]]], iterator_types = ["parallel"]}
-// CHECK: ins(%[[IN]] : tensor<2097152xf32>) outs(%[[OUT]] : tensor<2097152xf32>)
-// CHECK: tensor.expand_shape %[[RES]] {{\[}}[0, 1, 2, 3, 4, 5]] : tensor<2097152xf32> into tensor<2x4x8x16x32x64xf32>
+// CHECK: #[[$MAP:.+]] = affine_map<(d0) -> (d0)>
+// CHECK-LABEL: func.func @collapse1
+// CHECK: %[[IN:.+]] = tensor.collapse_shape %[[INPUT:.+]] {{\[}}[0, 1, 2, 3, 4, 5]] : tensor<2x4x8x16x32x64xf32> into tensor<2097152xf32>
+// CHECK: %[[RES:.+]] = flow.dispatch.region
+// CHECK: %[[OUT:.+]] = tensor.empty() : tensor<2097152xf32>
+// CHECK: linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP]]], iterator_types = ["parallel"]}
+// CHECK-SAME: ins(%[[IN]] : tensor<2097152xf32>) outs(%[[OUT]] : tensor<2097152xf32>)
+// CHECK: tensor.expand_shape %[[RES]] {{\[}}[0, 1, 2, 3, 4, 5]] : tensor<2097152xf32> into tensor<2x4x8x16x32x64xf32>
// -----
@@ -58,15 +58,15 @@
}
-// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d2, d4)>
-// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
-// CHECK-LABEL: func.func @collapse2
-// CHECK: %[[IN:.+]] = tensor.collapse_shape %[[INPUT:.+]] {{\[}}[0, 1], [2], [3], [4], [5, 6]] : tensor<2x4x8x32x32x64x128xf32> into tensor<8x8x32x32x8192xf32>
-// CHECK: %[[OUT:.+]] = tensor.empty() : tensor<8x8x32x32x8192xf32>
-// CHECK: %[[RES:.+]] = flow.dispatch.region
-// CHECK: linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP1]]], iterator_types = ["parallel", "reduction", "parallel", "parallel", "parallel"]}
-// CHECK: ins(%[[IN]] : tensor<8x8x32x32x8192xf32>) outs(%[[OUT]] : tensor<8x8x32x32x8192xf32>)
-// CHECK: tensor.expand_shape %[[RES]] {{\[}}[0, 1], [2], [3], [4], [5, 6]] : tensor<8x8x32x32x8192xf32> into tensor<2x4x8x32x32x64x128xf32>
+// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d2, d4)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
+// CHECK-LABEL: func.func @collapse2
+// CHECK: %[[IN:.+]] = tensor.collapse_shape %[[INPUT:.+]] {{\[}}[0, 1], [2], [3], [4], [5, 6]] : tensor<2x4x8x32x32x64x128xf32> into tensor<8x8x32x32x8192xf32>
+// CHECK: %[[RES:.+]] = flow.dispatch.region
+// CHECK: %[[OUT:.+]] = tensor.empty() : tensor<8x8x32x32x8192xf32>
+// CHECK: linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP1]]], iterator_types = ["parallel", "reduction", "parallel", "parallel", "parallel"]}
+// CHECK-SAME: ins(%[[IN]] : tensor<8x8x32x32x8192xf32>) outs(%[[OUT]] : tensor<8x8x32x32x8192xf32>)
+// CHECK: tensor.expand_shape %[[RES]] {{\[}}[0, 1], [2], [3], [4], [5, 6]] : tensor<8x8x32x32x8192xf32> into tensor<2x4x8x32x32x64x128xf32>
// -----
!type = tensor<2x4x8x16x32x64x128x256xf32>
@@ -93,14 +93,14 @@
}
-// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-// CHECK-LABEL: func.func @collapse3
-// CHECK: %[[IN:.+]] = tensor.collapse_shape %[[INPUT:.+]] {{\[}}[0, 1], [2], [3, 4, 5, 6, 7]] : tensor<2x4x8x16x32x64x128x256xf32> into tensor<8x8x1073741824xf32>
-// CHECK: %[[OUT:.+]] = tensor.empty() : tensor<8x8x1073741824xf32>
-// CHECK: %[[RES:.+]] = flow.dispatch.region
-// CHECK: linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP]]], iterator_types = ["parallel", "reduction", "parallel"]}
-// CHECK: ins(%[[IN]] : tensor<8x8x1073741824xf32>) outs(%[[OUT]] : tensor<8x8x1073741824xf32>)
-// CHECK: tensor.expand_shape %[[RES]] {{\[}}[0, 1], [2], [3, 4, 5, 6, 7]] : tensor<8x8x1073741824xf32> into tensor<2x4x8x16x32x64x128x256xf32>
+// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-LABEL: func.func @collapse3
+// CHECK: %[[IN:.+]] = tensor.collapse_shape %[[INPUT:.+]] {{\[}}[0, 1], [2], [3, 4, 5, 6, 7]] : tensor<2x4x8x16x32x64x128x256xf32> into tensor<8x8x1073741824xf32>
+// CHECK: %[[RES:.+]] = flow.dispatch.region
+// CHECK: %[[OUT:.+]] = tensor.empty() : tensor<8x8x1073741824xf32>
+// CHECK: linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP]]], iterator_types = ["parallel", "reduction", "parallel"]}
+// CHECK-SAME: ins(%[[IN]] : tensor<8x8x1073741824xf32>) outs(%[[OUT]] : tensor<8x8x1073741824xf32>)
+// CHECK: tensor.expand_shape %[[RES]] {{\[}}[0, 1], [2], [3, 4, 5, 6, 7]] : tensor<8x8x1073741824xf32> into tensor<2x4x8x16x32x64x128x256xf32>
// -----
@@ -127,15 +127,15 @@
}
-// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
-// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d4, d3, d5)>
-// CHECK-LABEL: func.func @collapse4
-// CHECK: %[[IN:.+]] = tensor.collapse_shape %[[INPUT:.+]] {{\[}}[0, 1], [2], [3], [4], [5], [6, 7]] : tensor<2x4x8x16x64x64x128x256xf32> into tensor<8x8x16x64x64x32768xf32>
-// CHECK: %[[OUT:.+]] = tensor.empty() : tensor<8x8x16x64x64x32768xf32>
-// CHECK: %[[RES:.+]] = flow.dispatch.region
-// CHECK: linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP2]]], iterator_types = ["parallel", "reduction", "parallel", "parallel", "parallel", "parallel"]}
-// CHECK: ins(%[[IN]] : tensor<8x8x16x64x64x32768xf32>) outs(%[[OUT]] : tensor<8x8x16x64x64x32768xf32>)
-// CHECK: tensor.expand_shape %[[RES]] {{\[}}[0, 1], [2], [3], [4], [5], [6, 7]] : tensor<8x8x16x64x64x32768xf32> into tensor<2x4x8x16x64x64x128x256xf32>
+// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
+// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d4, d3, d5)>
+// CHECK-LABEL: func.func @collapse4
+// CHECK: %[[IN:.+]] = tensor.collapse_shape %[[INPUT:.+]] {{\[}}[0, 1], [2], [3], [4], [5], [6, 7]] : tensor<2x4x8x16x64x64x128x256xf32> into tensor<8x8x16x64x64x32768xf32>
+// CHECK: %[[RES:.+]] = flow.dispatch.region
+// CHECK: %[[OUT:.+]] = tensor.empty() : tensor<8x8x16x64x64x32768xf32>
+// CHECK: linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP2]]], iterator_types = ["parallel", "reduction", "parallel", "parallel", "parallel", "parallel"]}
+// CHECK-SAME: ins(%[[IN]] : tensor<8x8x16x64x64x32768xf32>) outs(%[[OUT]] : tensor<8x8x16x64x64x32768xf32>)
+// CHECK: tensor.expand_shape %[[RES]] {{\[}}[0, 1], [2], [3], [4], [5], [6, 7]] : tensor<8x8x16x64x64x32768xf32> into tensor<2x4x8x16x64x64x128x256xf32>
// -----
@@ -167,18 +167,18 @@
}
-// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
-// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d2, d4, d5)>
-// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d3, d2, d1, d4, d5)>
-// CHECK-LABEL: func.func @collapse5
-// CHECK: %[[IN:.+]] = tensor.collapse_shape %[[INPUT:.+]] {{\[}}[0, 1], [2], [3], [4], [5], [6, 7]] : tensor<2x4x32x32x32x64x128x256xf32> into tensor<8x32x32x32x64x32768xf32>
-// CHECK: %[[IN1:.+]] = tensor.collapse_shape %[[INPUT1:.+]] {{\[}}[0, 1], [2], [3], [4], [5], [6, 7]] : tensor<2x4x32x32x32x64x128x256xf32> into tensor<8x32x32x32x64x32768xf32>
-// CHECK: %[[IN2:.+]] = tensor.collapse_shape %[[INPUT2:.+]] {{\[}}[0, 1], [2], [3], [4], [5], [6, 7]] : tensor<2x4x32x32x32x64x128x256xf32> into tensor<8x32x32x32x64x32768xf32>
-// CHECK: %[[OUT:.+]] = tensor.empty() : tensor<8x32x32x32x64x32768xf32>
-// CHECK: %[[RES:.+]] = flow.dispatch.region
-// CHECK: linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]], #[[$MAP]]], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "parallel"]}
-// CHECK: ins(%[[IN]], %[[IN1]], %[[IN2]] : tensor<8x32x32x32x64x32768xf32>, tensor<8x32x32x32x64x32768xf32>, tensor<8x32x32x32x64x32768xf32>) outs(%[[OUT]] : tensor<8x32x32x32x64x32768xf32>)
-// CHECK: tensor.expand_shape %[[RES]] {{\[}}[0, 1], [2], [3], [4], [5], [6, 7]] : tensor<8x32x32x32x64x32768xf32> into tensor<2x4x32x32x32x64x128x256xf32>
+// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d2, d4, d5)>
+// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d3, d2, d1, d4, d5)>
+// CHECK-LABEL: func.func @collapse5
+// CHECK: %[[IN:.+]] = tensor.collapse_shape %[[INPUT:.+]] {{\[}}[0, 1], [2], [3], [4], [5], [6, 7]] : tensor<2x4x32x32x32x64x128x256xf32> into tensor<8x32x32x32x64x32768xf32>
+// CHECK: %[[IN1:.+]] = tensor.collapse_shape %[[INPUT1:.+]] {{\[}}[0, 1], [2], [3], [4], [5], [6, 7]] : tensor<2x4x32x32x32x64x128x256xf32> into tensor<8x32x32x32x64x32768xf32>
+// CHECK: %[[IN2:.+]] = tensor.collapse_shape %[[INPUT2:.+]] {{\[}}[0, 1], [2], [3], [4], [5], [6, 7]] : tensor<2x4x32x32x32x64x128x256xf32> into tensor<8x32x32x32x64x32768xf32>
+// CHECK: %[[RES:.+]] = flow.dispatch.region
+// CHECK: %[[OUT:.+]] = tensor.empty() : tensor<8x32x32x32x64x32768xf32>
+// CHECK: linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]], #[[$MAP]]], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "parallel"]}
+// CHECK-SAME: ins(%[[IN]], %[[IN1]], %[[IN2]] : tensor<8x32x32x32x64x32768xf32>, tensor<8x32x32x32x64x32768xf32>, tensor<8x32x32x32x64x32768xf32>) outs(%[[OUT]] : tensor<8x32x32x32x64x32768xf32>)
+// CHECK: tensor.expand_shape %[[RES]] {{\[}}[0, 1], [2], [3], [4], [5], [6, 7]] : tensor<8x32x32x32x64x32768xf32> into tensor<2x4x32x32x32x64x128x256xf32>
// -----
@@ -205,15 +205,15 @@
}
-// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
-// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d4, d3, d5)>
-// CHECK-LABEL: func.func @collapse6
-// CHECK: %[[IN:.+]] = tensor.collapse_shape %[[INPUT:.+]] {{\[}}[0], [1], [2, 3], [4], [5], [6, 7]] : tensor<32x2x4x8x16x16x64x128xf32> into tensor<32x2x32x16x16x8192xf32>
-// CHECK: %[[OUT:.+]] = tensor.empty() : tensor<32x2x32x16x16x8192xf32>
-// CHECK: %[[RES:.+]] = flow.dispatch.region
-// CHECK: linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP2]]], iterator_types = ["parallel", "reduction", "parallel", "parallel", "parallel", "parallel"]}
-// CHECK: ins(%[[IN]] : tensor<32x2x32x16x16x8192xf32>) outs(%[[OUT]] : tensor<32x2x32x16x16x8192xf32>)
-// CHECK: tensor.expand_shape %[[RES]] {{\[}}[0], [1], [2, 3], [4], [5], [6, 7]] : tensor<32x2x32x16x16x8192xf32> into tensor<32x2x4x8x16x16x64x128xf32>
+// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
+// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d4, d3, d5)>
+// CHECK-LABEL: func.func @collapse6
+// CHECK: %[[IN:.+]] = tensor.collapse_shape %[[INPUT:.+]] {{\[}}[0], [1], [2, 3], [4], [5], [6, 7]] : tensor<32x2x4x8x16x16x64x128xf32> into tensor<32x2x32x16x16x8192xf32>
+// CHECK: %[[RES:.+]] = flow.dispatch.region
+// CHECK: %[[OUT:.+]] = tensor.empty() : tensor<32x2x32x16x16x8192xf32>
+// CHECK: linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP2]]], iterator_types = ["parallel", "reduction", "parallel", "parallel", "parallel", "parallel"]}
+// CHECK-SAME: ins(%[[IN]] : tensor<32x2x32x16x16x8192xf32>) outs(%[[OUT]] : tensor<32x2x32x16x16x8192xf32>)
+// CHECK: tensor.expand_shape %[[RES]] {{\[}}[0], [1], [2, 3], [4], [5], [6, 7]] : tensor<32x2x32x16x16x8192xf32> into tensor<32x2x4x8x16x16x64x128xf32>
// -----
@@ -239,24 +239,23 @@
return %result: !type_out
}
-// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1) -> (d1)>
-// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1) -> (d1, d0)>
+// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1) -> (d1)>
+// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1) -> (d1, d0)>
// CHECK-LABEL: func.func @collapse7
-// CHECK: %[[IN:.+]] = tensor.collapse_shape %[[INPUT:.+]] {{\[}}[0, 1, 2]] : tensor<2x4x8xf32> into tensor<64xf32>
-// CHECK: %[[OUT:.+]] = tensor.empty() : tensor<64x16xf32>
-// CHECK: %[[RES:.+]] = flow.dispatch.region
-// CHECK: linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP2]]], iterator_types = ["parallel", "parallel"]}
-// CHECK: ins(%[[IN]] : tensor<64xf32>) outs(%[[OUT]] : tensor<64x16xf32>)
-// CHECK: tensor.expand_shape %[[RES]] {{\[}}[0, 1, 2], [3]] : tensor<64x16xf32> into tensor<2x4x8x16xf32>
+// CHECK: %[[IN:.+]] = tensor.collapse_shape %[[INPUT:.+]] {{\[}}[0, 1, 2]] : tensor<2x4x8xf32> into tensor<64xf32>
+// CHECK: %[[RES:.+]] = flow.dispatch.region
+// CHECK: %[[OUT:.+]] = tensor.empty() : tensor<64x16xf32>
+// CHECK: linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP2]]], iterator_types = ["parallel", "parallel"]}
+// CHECK-SAME: ins(%[[IN]] : tensor<64xf32>) outs(%[[OUT]] : tensor<64x16xf32>)
+// CHECK: tensor.expand_shape %[[RES]] {{\[}}[0, 1, 2], [3]] : tensor<64x16xf32> into tensor<2x4x8x16xf32>
// -----
!type_in = tensor<16x4x32x2xf32>
!type_out = tensor<8x16x4x32x8x2xf32>
-func.func @collapse8() -> !type_out {
+func.func @collapse8(%input : !type_in) -> !type_out {
%cst = arith.constant 0.000000e+00 : f32
%c0 = arith.constant 0 : index
- %input = tensor.empty() : !type_in
%output = tensor.empty() : !type_out
// Can collapse (d3, d0, d1)
%6 = linalg.generic { indexing_maps = [
@@ -272,15 +271,16 @@
return %6: !type_out
}
-// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d3)>
-// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d3)>
+// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
// CHECK-LABEL: func.func @collapse8
-// CHECK: %[[IN:.+]] = tensor.empty() : tensor<2048x2xf32>
-// CHECK: %[[OUT:.+]] = tensor.empty() : tensor<8x2048x8x2xf32>
-// CHECK: %[[RES:.+]] = flow.dispatch.region
-// CHECK: linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
-// CHECK: ins(%[[IN]] : tensor<2048x2xf32>) outs(%[[OUT]] : tensor<8x2048x8x2xf32
-// CHECK: tensor.expand_shape %[[RES]] {{\[}}[0], [1, 2, 3], [4], [5]] : tensor<8x2048x8x2xf32> into tensor<8x16x4x32x8x2xf32>
+// CHECK-SAME: (%[[IN:.+]]: tensor<16x4x32x2xf32>)
+// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[IN]] {{\[}}[0, 1, 2], [3]{{\]}}
+// CHECK: %[[RES:.+]] = flow.dispatch.region
+// CHECK: %[[OUT:.+]] = tensor.empty() : tensor<8x2048x8x2xf32>
+// CHECK: linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+// CHECK-SAME: ins(%[[COLLAPSE]] : tensor<2048x2xf32>) outs(%[[OUT]] : tensor<8x2048x8x2xf32
+// CHECK: tensor.expand_shape %[[RES]] {{\[}}[0], [1, 2, 3], [4], [5]] : tensor<8x2048x8x2xf32> into tensor<8x16x4x32x8x2xf32>
// -----
@@ -304,7 +304,7 @@
return %6: !type_out
}
// CHECK-LABEL: func.func @dont_collapse
-// CHECK: linalg.generic {indexing_maps = [#[[$MAP:.+]], #[[$MAP2:.+]]], iterator_types = ["parallel", "parallel", "parallel"]}
+// CHECK: linalg.generic {indexing_maps = [#[[$MAP:.+]], #[[$MAP2:.+]]], iterator_types = ["parallel", "parallel", "parallel"]}
// -----
@@ -333,11 +333,11 @@
}
-// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
-// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d4, d3, d5)>
+// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
+// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d4, d3, d5)>
// CHECK-LABEL: func.func @collapse9
-// CHECK: %[[RES:.+]] = flow.dispatch.region
-// CHECK: linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP2]]], iterator_types = ["parallel", "reduction", "parallel", "parallel", "parallel", "parallel"]}
+// CHECK: %[[RES:.+]] = flow.dispatch.region
+// CHECK: linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP2]]], iterator_types = ["parallel", "reduction", "parallel", "parallel", "parallel", "parallel"]}
// -----
@@ -345,10 +345,9 @@
!type_in = tensor<10x10x30xf32>
!type_out = tensor<20x10x10x30x20xf32>
-func.func @collapse10() -> !type_out {
+func.func @collapse10(%input : !type_in) -> !type_out {
%cst = arith.constant 0.000000e+00 : f32
%c0 = arith.constant 0 : index
- %input = tensor.empty() : !type_in
%output = tensor.empty() : !type_out
// Can collapse (d1, d3, d0)
@@ -364,21 +363,18 @@
return %result: !type_out
}
-// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2) -> (d0)>
-// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
// CHECK-LABEL: func.func @collapse10
-// CHECK: %[[RES:.+]] = flow.dispatch.region
-// CHECK: linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel"]}
+// CHECK: %[[RES:.+]] = flow.dispatch.region
+// CHECK: linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel"]}
// -----
!type_in = tensor<10x20xf32>
!type_out = tensor<10x20xf32>
-func.func @collapse11() -> !type_out {
+func.func @collapse11(%input : !type_in) -> !type_out {
%cst = arith.constant 0.000000e+00 : f32
%c0 = arith.constant 0 : index
- %input = tensor.empty() : !type_in
%output = tensor.empty() : !type_out
// Can collapse (d1, d0)
@@ -394,10 +390,10 @@
}
-// CHECK: #[[$MAP:.+]] = affine_map<(d0) -> (d0)>
+// CHECK: #[[$MAP:.+]] = affine_map<(d0) -> (d0)>
// CHECK-LABEL: func.func @collapse11
-// CHECK: %[[RES:.+]] = flow.dispatch.region
-// CHECK: linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP]]], iterator_types = ["parallel"]}
+// CHECK: %[[RES:.+]] = flow.dispatch.region
+// CHECK: linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP]]], iterator_types = ["parallel"]}
// -----
@@ -420,7 +416,7 @@
}
// CHECK-LABEL: func.func @dont_collapse
-// CHECK: linalg.generic {indexing_maps = [#[[$MAP:.+]]], iterator_types = ["parallel", "parallel"]}
+// CHECK: linalg.generic {indexing_maps = [#[[$MAP:.+]]], iterator_types = ["parallel", "parallel"]}
// -----
@@ -456,8 +452,146 @@
return %6, %7, %8, %9 : !type,!type,!type,!type
}
-// CHECK: #[[$MAP:.+]] = affine_map<(d0) -> (d0)>
+// CHECK: #[[$MAP:.+]] = affine_map<(d0) -> (d0)>
// CHECK-LABEL: func.func @collapse12
-// CHECK: %[[RES:.+]] = flow.dispatch.region
-// CHECK: linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP]], #[[$MAP]], #[[$MAP]], #[[$MAP]]], iterator_types = ["parallel"]}
+// CHECK: %[[RES:.+]] = flow.dispatch.region
+// CHECK: linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP]], #[[$MAP]], #[[$MAP]], #[[$MAP]]], iterator_types = ["parallel"]}
+// -----
+
+func.func @multi_reduce_dim(%arg0: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {
+ %cst = arith.constant -0.000000e+00 : f32
+ %0 = hal.tensor.import %arg0 : !hal.buffer_view -> tensor<2x32x10x4096xf32>
+ %1 = tensor.empty() : tensor<2x32xf32>
+ %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<2x32xf32>) -> tensor<2x32xf32>
+ %3 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%0 : tensor<2x32x10x4096xf32>) outs(%2 : tensor<2x32xf32>) {
+ ^bb0(%arg1: f32, %arg2: f32):
+ %6 = arith.addf %arg1, %arg2 : f32
+ linalg.yield %6 : f32
+ } -> tensor<2x32xf32>
+ %4 = tensor.expand_shape %3 [[0], [1, 2, 3]] : tensor<2x32xf32> into tensor<2x32x1x1xf32>
+ %5 = hal.tensor.export %4 : tensor<2x32x1x1xf32> -> !hal.buffer_view
+ return %5 : !hal.buffer_view
+}
+
+// Check that we collapse dimensions.
+// CHECK-LABEL: @multi_reduce_dim(
+// CHECK-DAG: %[[ARG0:.+]] = hal.tensor.import
+// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1], [2, 3]{{\]}}
+// CHECK: %[[DISPATCH:.+]] = flow.dispatch.region
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<64xf32>
+// CHECK: %[[FILL:.+]] = linalg.fill
+// CHECK-SAME: outs(%[[EMPTY]] :
+// CHECK: %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[COLLAPSE]] :
+// CHECK-SAME: outs(%[[FILL]] :
+// CHECK: flow.return %[[GENERIC]]
+// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[DISPATCH]] {{\[}}[0, 1]{{\]}}
+
+// -----
+
+// Collapsing is not supported when an input is broadcasted; we can't collapse
+// the input from tensor<4xf32> to tensor<32xf32> for example.
+
+func.func @input_broadcast(%arg0: tensor<4x8xf32>, %arg1: tensor<4xf32>) -> tensor<f32> {
+ %empty = tensor.empty() : tensor<f32>
+ %reduce = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> ()>], iterator_types = ["reduction", "reduction"]} ins(%arg0, %arg1 : tensor<4x8xf32>, tensor<4xf32>) outs(%empty : tensor<f32>) {
+ ^bb0(%arg2: f32, %arg3: f32, %out: f32):
+ %div = arith.divf %arg2, %arg3 : f32
+ %add = arith.addf %out, %div : f32
+ linalg.yield %add : f32
+ } -> tensor<f32>
+ return %reduce : tensor<f32>
+}
+
+// CHECK: @input_broadcast
+// CHECK-NOT: tensor.collapse_shape
+
+// -----
+
+// Do nothing if the dispatch does not contain a single generic op whose only
+// in-dispatch producers are tensor.empty/linalg.fill ops.
+
+#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d1)>
+#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>
+#map3 = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d4)>
+#map4 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
+module {
+ func.func @quantized_matmul(%arg0: tensor<4096x32x128xi8>, %arg1: tensor<1x1x32x128xf32>) -> tensor<1x1x4096xf32> {
+ %cst = arith.constant dense_resource<__elided__> : tensor<4096x32xf32>
+ %cst_0 = arith.constant dense_resource<__elided__> : tensor<4096x32xf32>
+ %0 = flow.dispatch.region -> (tensor<1x1x4096xf32>) {
+ %cst_1 = arith.constant 0.000000e+00 : f32
+ %1 = tensor.empty() : tensor<1x1x4096xf32>
+ %2 = tensor.empty() : tensor<4096x32x128xf32>
+ %3 = linalg.fill ins(%cst_1 : f32) outs(%1 : tensor<1x1x4096xf32>) -> tensor<1x1x4096xf32>
+ %4 = linalg.generic {indexing_maps = [#map, #map1, #map1, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0, %cst, %cst_0 : tensor<4096x32x128xi8>, tensor<4096x32xf32>, tensor<4096x32xf32>) outs(%2 : tensor<4096x32x128xf32>) {
+ ^bb0(%in: i8, %in_2: f32, %in_3: f32, %out: f32):
+ %6 = arith.extui %in : i8 to i32
+ %7 = arith.uitofp %6 : i32 to f32
+ %8 = arith.subf %7, %in_3 : f32
+ %9 = arith.mulf %8, %in_2 : f32
+ linalg.yield %9 : f32
+ } -> tensor<4096x32x128xf32>
+ %5 = linalg.generic {indexing_maps = [#map2, #map3, #map4], iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%arg1, %4 : tensor<1x1x32x128xf32>, tensor<4096x32x128xf32>) outs(%3 : tensor<1x1x4096xf32>) {
+ ^bb0(%in: f32, %in_2: f32, %out: f32):
+ %6 = arith.mulf %in, %in_2 : f32
+ %7 = arith.addf %6, %out : f32
+ linalg.yield %7 : f32
+ } -> tensor<1x1x4096xf32>
+ flow.return %5 : tensor<1x1x4096xf32>
+ }
+ return %0 : tensor<1x1x4096xf32>
+ }
+}
+
+// CHECK-LABEL: func.func @quantized_matmul
+// CHECK: %[[DISPATCH:.+]] = flow.dispatch.region
+// CHECK: linalg.generic
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]
+// CHECK: linalg.generic
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]
+// CHECK: flow.return
+// CHECK: return %[[DISPATCH]]
+
+// -----
+
+module {
+ func.func @batchnorm_failure_repro(%arg0 : tensor<2x4xf32>, %arg1 : tensor<4xf32>) -> tensor<2x4xf32> {
+ %0 = tensor.empty() : tensor<2x4xf32>
+ %1 = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%arg0, %arg1 : tensor<2x4xf32>, tensor<4xf32>) outs(%0 : tensor<2x4xf32>) {
+ ^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
+ %2 = arith.addf %b0, %b1 : f32
+ linalg.yield %2 : f32
+ } -> tensor<2x4xf32>
+ return %1 : tensor<2x4xf32>
+ }
+}
+// CHECK-LABEL: func @batchnorm_failure_repro
+// CHECK: %[[DISPATCH:.+]] = flow.dispatch.region
+// CHECK: %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME: iterator_types = ["parallel", "parallel"]
+// CHECK: flow.return %[[GENERIC]]
+// CHECK: return %[[DISPATCH]]
+
+// -----
+
+module {
+ func.func @catch_invalid_collapse(%arg0 : tensor<10x20x30xf32>) -> tensor<10x30x40xf32> {
+ %0 = tensor.empty() : tensor<10x30x40xf32>
+ %1 = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>],
+ iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+ ins(%arg0 : tensor<10x20x30xf32>) outs(%0 : tensor<10x30x40xf32>) {
+ ^bb0(%b0 : f32, %b1 : f32):
+ linalg.yield %b0 : f32
+ } -> tensor<10x30x40xf32>
+ return %1 : tensor<10x30x40xf32>
+ }
+}
+// CHECK-LABEL: func @catch_invalid_collapse
+// CHECK: linalg.generic
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/collapse_reduction.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/collapse_reduction.mlir
deleted file mode 100644
index 5409dfb..0000000
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/collapse_reduction.mlir
+++ /dev/null
@@ -1,64 +0,0 @@
-// RUN: iree-opt --split-input-file -iree-flow-collapse-dims %s | FileCheck %s
-
-func.func @multi_reduce_dim(%arg0: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {
- %cst = arith.constant -0.000000e+00 : f32
- %0 = hal.tensor.import %arg0 : !hal.buffer_view -> tensor<2x32x10x4096xf32>
- %1 = tensor.empty() : tensor<2x32xf32>
- %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<2x32xf32>) -> tensor<2x32xf32>
- %3 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%0 : tensor<2x32x10x4096xf32>) outs(%2 : tensor<2x32xf32>) {
- ^bb0(%arg1: f32, %arg2: f32):
- %6 = arith.addf %arg1, %arg2 : f32
- linalg.yield %6 : f32
- } -> tensor<2x32xf32>
- %4 = tensor.expand_shape %3 [[0], [1, 2, 3]] : tensor<2x32xf32> into tensor<2x32x1x1xf32>
- %5 = hal.tensor.export %4 : tensor<2x32x1x1xf32> -> !hal.buffer_view
- return %5 : !hal.buffer_view
-}
-
-// Check that we collapse dimensions.
-// CHECK: @multi_reduce_dim
-// CHECK: linalg.generic {{.*}} iterator_types = ["parallel", "parallel", "reduction"]
-
-// -----
-
-// Collapsing is not supported when an input is broadcasted; we can't collapse
-// the input from tensor<4xf32> to tensor<32xf32> for example.
-
-func.func @input_broadcast(%arg0: tensor<4x8xf32>, %arg1: tensor<4xf32>) -> tensor<f32> {
- %empty = tensor.empty() : tensor<f32>
- %reduce = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> ()>], iterator_types = ["reduction", "reduction"]} ins(%arg0, %arg1 : tensor<4x8xf32>, tensor<4xf32>) outs(%empty : tensor<f32>) {
- ^bb0(%arg2: f32, %arg3: f32, %out: f32):
- %div = arith.divf %arg2, %arg3 : f32
- %add = arith.addf %out, %div : f32
- linalg.yield %add : f32
- } -> tensor<f32>
- return %reduce : tensor<f32>
-}
-
-// CHECK: @input_broadcast
-// CHECK-NOT: tensor.collapse_shape
-
-// -----
-
-// Collapsing should not happen to ops in flow.dispatch.region or flow.dispatch.workgroups
-
-func.func @multi_reduce_dim_dispatch(%arg0: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {
- %cst = arith.constant -0.000000e+00 : f32
- %0 = hal.tensor.import %arg0 : !hal.buffer_view -> tensor<2x32x10x4096xf32>
- %1 = tensor.empty() : tensor<2x32xf32>
- %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<2x32xf32>) -> tensor<2x32xf32>
- %3 = flow.dispatch.region -> (tensor<2x32xf32>) {
- %6 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%0 : tensor<2x32x10x4096xf32>) outs(%2 : tensor<2x32xf32>) {
- ^bb0(%arg1: f32, %arg2: f32):
- %7 = arith.addf %arg1, %arg2 : f32
- linalg.yield %7 : f32
- } -> tensor<2x32xf32>
- flow.return %6 : tensor<2x32xf32>
- }
- %4 = tensor.expand_shape %3 [[0], [1, 2, 3]] : tensor<2x32xf32> into tensor<2x32x1x1xf32>
- %5 = hal.tensor.export %4 : tensor<2x32x1x1xf32> -> !hal.buffer_view
- return %5 : !hal.buffer_view
-}
-
-// CHECK: @multi_reduce_dim_dispatch
-// CHECK: linalg.generic {{.*}} iterator_types = ["parallel", "parallel", "reduction", "reduction"]