blob: c428091f6cf811909eb1a04a1eb4f85d1abbe4aa [file]
// Copyright 2024 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//===--- FusionUtils.h --- Implementation of fusion utility functions -----===//
//===----------------------------------------------------------------------===//
#include "compiler/src/iree/compiler/DispatchCreation/FusionUtils.h"
#include "compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
#include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
namespace mlir::iree_compiler::DispatchCreation {
bool areFusableAsElementwiseOps(MLIRContext *context, OpOperand *fusedOperand,
bool fuseMultiReduction) {
Operation *producerOp = fusedOperand->get().getDefiningOp();
Operation *consumerOp = fusedOperand->getOwner();
if (!producerOp)
return false;
// Check for i1 return types, if so aggressively fuse to avoid `i1` buffers.
if (llvm::all_of(producerOp->getResultTypes(), [](Type t) {
if (t.isInteger(1))
return true;
if (auto shapedType = llvm::dyn_cast<ShapedType>(t)) {
if (shapedType.getElementType().isInteger(1))
return true;
}
return false;
})) {
return true;
}
// Don't fuse if all of the consumer maps aren't projected permutations.
if (auto linalgConsumerOp = dyn_cast<linalg::LinalgOp>(consumerOp)) {
if (!llvm::all_of(
linalgConsumerOp.getIndexingMapsArray(),
[](AffineMap map) { return map.isProjectedPermutation(); })) {
return false;
}
}
// If the generic op is "just" copy, then fuse always.
Block &body = producerOp->getRegion(0).front();
if (std::begin(body)->hasTrait<OpTrait::IsTerminator>())
return true;
if (llvm::all_of(body.getArguments(),
[](BlockArgument arg) { return arg.use_empty(); })) {
// The operands aren't used, its just an `linalg.index` op.
return true;
}
// If producer does not have a single user, dont fuse.
if (!producerOp->hasOneUse()) {
return false;
}
// Do no fuse bitextend-like operations with producers. Such ops are cloned
// into all their use dispatches. So fusing producer with consumer here would
// then result in producer also getting cloned into many dispatches which is
// against the thumb rule of fusion to not introduce additional computation
// (except for bit-extend ops). If the consumer has only one use, then this
// fusion is fine since cloning wont result in redundant computation of the
// producer. (Also note that the producer is always an elementwise operation).
if (IREE::LinalgExt::isBitExtendOp(consumerOp) && !consumerOp->hasOneUse()) {
return false;
}
auto linalgConsumerOp = dyn_cast<linalg::LinalgOp>(consumerOp);
if (!linalgConsumerOp) {
return false;
}
// If the producer has a single use (this op), only fuse if
// - 1) The consumer op is all parallel loops. The parallelism of the consumer
// can be used as a way to amortize cost of redundant computation
// - 2) If consumer op is a reduction, only fuse if the indexing map in the
// consumer for the producer result is a permutation. If it is a
// broadcast this ends up redundantly computing operations without more
// parallelism.
if (linalgConsumerOp.getNumParallelLoops() !=
linalgConsumerOp.getNumLoops()) {
if (!linalgConsumerOp.getMatchingIndexingMap(fusedOperand)
.isPermutation()) {
return false;
}
if (!fuseMultiReduction && linalgConsumerOp.getNumReductionLoops() != 1) {
return false;
}
if (linalg::isaContractionOpInterface(linalgConsumerOp) ||
linalg::isaConvolutionOpInterface(linalgConsumerOp)) {
return false;
}
}
return true;
}
} // namespace mlir::iree_compiler::DispatchCreation