// Copyright 2022 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//===- DetachElementwiseFromNamedOps.cpp ----------------------------------===//
//
// Detaches elementwise ops from Linalg named ops in preparation for following
// fusion and bufferization.
//
//===----------------------------------------------------------------------===//
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h"
#include "iree/compiler/GlobalOptimization/Passes.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
namespace mlir::iree_compiler::GlobalOptimization {
#define GEN_PASS_DEF_DETACHELEMENTWISEFROMNAMEDOPSPASS
#include "iree/compiler/GlobalOptimization/Passes.h.inc"
namespace {
/// Detaches elementwise accumulation from Linalg contraction/convolution named
/// ops: rewrites the named op to accumulate into a zero-filled tensor, then
/// adds the original output operand back with a separate elementwise
/// `linalg.generic`, so the add can fuse independently later.
struct DetachElementwisePattern
    : public OpInterfaceRewritePattern<linalg::LinalgOp> {
  using OpInterfaceRewritePattern::OpInterfaceRewritePattern;

  LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp,
                                PatternRewriter &rewriter) const override {
    // Only contraction and convolution named ops are handled.
    if (!linalg::isaContractionOpInterface(linalgOp) &&
        !isa<linalg::ConvolutionOpInterface>(*linalgOp)) {
      return failure();
    }
    if (!linalgOp.hasPureTensorSemantics())
      return failure();

    // Pure tensor semantics (established above) excludes pure buffer
    // semantics, so the init operands can be collected unconditionally; the
    // previous `!hasPureBufferSemantics()` guard here was always true.
    SmallVector<OpOperand *> outputOperands = llvm::to_vector(
        llvm::map_range(linalgOp.getDpsInitsMutable(),
                        [](OpOperand &opOperand) { return &opOperand; }));

    // Right now all the cases we see have one output. This can be relaxed once
    // we see multiple output ops.
    if (outputOperands.size() != 1)
      return failure();
    Value outputOperand = outputOperands.front()->get();

    // Nothing to do if the output tensor operand is already a fill op, or is
    // not produced by a linalg op at all.
    auto outsDefiningOp = outputOperand.getDefiningOp<linalg::LinalgOp>();
    if (!outsDefiningOp || isa<linalg::FillOp>(outsDefiningOp.getOperation())) {
      return failure();
    }
    auto outputType = cast<RankedTensorType>(outputOperand.getType());
    if (!outputType.getElementType().isIntOrFloat())
      return failure();
    auto elementType = outputType.getElementType();

    Location loc = linalgOp.getLoc();

    // Check if the output tensor access is a projected permutation.
    if (!linalgOp.getMatchingIndexingMap(outputOperands.front())
             .isProjectedPermutation()) {
      return rewriter.notifyMatchFailure(
          linalgOp, "Output indexing map must be a projected permutation.");
    }

    int64_t outputRank = outputType.getRank();
    SmallVector<utils::IteratorType> iterators(outputRank,
                                               utils::IteratorType::parallel);
    // Identity maps for the two inputs and one output of the elementwise add
    // generic op created below.
    SmallVector<AffineMap> maps(3, rewriter.getMultiDimIdentityMap(outputRank));

    // Create a zero tensor as the new output tensor operand to the Linalg
    // contraction op.
    SmallVector<OpFoldResult> mixedSizes =
        tensor::getMixedSizes(rewriter, loc, outputOperand);
    auto initOp =
        tensor::EmptyOp::create(rewriter, loc, mixedSizes, elementType);
    Value zero = arith::ConstantOp::create(rewriter, loc,
                                           rewriter.getZeroAttr(elementType));
    Value fill = linalg::FillOp::create(rewriter, loc, zero, initOp.getResult())
                     .result();

    // Update the contraction op to use the new zero tensor as output operand.
    rewriter.modifyOpInPlace(linalgOp,
                             [&]() { linalgOp.setDpsInitOperand(0, fill); });

    // Create a generic op to add back the original output tensor operand.
    rewriter.setInsertionPointAfter(linalgOp);
    auto genericOp = linalg::GenericOp::create(
        rewriter, loc, outputType,
        ValueRange{linalgOp->getResult(0), outputOperand}, fill, maps,
        iterators, [&](OpBuilder &b, Location nestedLoc, ValueRange args) {
          // Pick the add op matching the element type.
          Value result;
          if (isa<FloatType>(elementType)) {
            result = arith::AddFOp::create(b, nestedLoc, args[0], args[1]);
          } else {
            result = arith::AddIOp::create(b, nestedLoc, args[0], args[1]);
          }
          linalg::YieldOp::create(b, nestedLoc, result);
        });
    // Redirect all other users of the named op's result through the new add.
    linalgOp->getResult(0).replaceAllUsesExcept(genericOp->getResult(0),
                                                genericOp);
    return success();
  }
};
/// Replaces splat constants used as `outs` operands of destination-passing
/// style operations with the more canonical `tensor.empty -> linalg.fill`
/// sequence. Splat constants pulled in as `outs` operands cause issues with
/// allocations; using a `linalg.fill` instead lets the producer fuse with the
/// consuming op the same way `fill -> linalg` op pairs fuse, and avoids stack
/// allocations in the fallback path.
template <typename InterfaceOp>
struct DetachSplatConstantOutsOperands
    : public OpInterfaceRewritePattern<InterfaceOp> {
  using OpInterfaceRewritePattern<InterfaceOp>::OpInterfaceRewritePattern;

  LogicalResult matchAndRewrite(InterfaceOp interfaceOp,
                                PatternRewriter &rewriter) const override {
    auto dpsInterfaceOp =
        dyn_cast<DestinationStyleOpInterface>(interfaceOp.getOperation());
    if (!dpsInterfaceOp) {
      return rewriter.notifyMatchFailure(
          interfaceOp, "expected op to implement DPS interface");
    }

    bool madeChanges = false;
    for (auto outOperand : llvm::enumerate(dpsInterfaceOp.getDpsInits())) {
      // Only rewrite inits that are splat int/float ranked-tensor constants.
      auto constOp =
          outOperand.value().template getDefiningOp<arith::ConstantOp>();
      if (!constOp)
        continue;
      auto resultType =
          dyn_cast<RankedTensorType>(constOp.getResult().getType());
      if (!resultType || !resultType.getElementType().isIntOrFloat())
        continue;
      auto attr = dyn_cast<ElementsAttr>(constOp.getValue());
      if (!attr || !attr.isSplat())
        continue;

      Location loc = constOp.getLoc();
      Type elementType = resultType.getElementType();
      Value emptyTensorOp = tensor::EmptyOp::create(
          rewriter, loc, resultType.getShape(), elementType);
      // Rebuild the splat value as a scalar constant to feed linalg.fill.
      TypedAttr constValue;
      if (isa<IntegerType>(elementType)) {
        constValue = rewriter.getIntegerAttr(
            elementType, attr.template getSplatValue<APInt>());
      } else {
        constValue = rewriter.getFloatAttr(
            elementType, attr.template getSplatValue<APFloat>());
      }
      Value scalarConstantOp =
          arith::ConstantOp::create(rewriter, loc, elementType, constValue);
      Value fillOp = linalg::FillOp::create(rewriter, loc, resultType,
                                            scalarConstantOp, emptyTensorOp)
                         .getResult(0);
      rewriter.modifyOpInPlace(dpsInterfaceOp, [&]() {
        dpsInterfaceOp.setDpsInitOperand(outOperand.index(), fillOp);
      });
      madeChanges = true;
    }
    // Succeed only if at least one operand was rewritten, so the greedy
    // driver does not loop on a no-op match.
    return success(madeChanges);
  }
};
/// Pass driver: registers the detach patterns above and applies them greedily
/// over the whole operation.
struct DetachElementwiseFromNamedOpsPass
    : public impl::DetachElementwiseFromNamedOpsPassBase<
          DetachElementwiseFromNamedOpsPass> {
  void getDependentDialects(DialectRegistry &registry) const override {
    // Dialects whose ops the rewrite patterns may create.
    registry.insert<arith::ArithDialect, linalg::LinalgDialect,
                    tensor::TensorDialect>();
  }

  void runOnOperation() override {
    MLIRContext *context = &getContext();
    RewritePatternSet patterns(context);
    patterns.add<DetachElementwisePattern,
                 DetachSplatConstantOutsOperands<IREE::LinalgExt::LinalgExtOp>,
                 DetachSplatConstantOutsOperands<linalg::LinalgOp>>(context);
    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
      signalPassFailure();
    }
  }
};
} // namespace
} // namespace mlir::iree_compiler::GlobalOptimization