// Copyright 2019 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include <algorithm>
#include <numeric>
#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
#include "iree/compiler/Dialect/Util/IR/ClosureOpUtils.h"
#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/StringExtras.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/StandardOps/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LogicalResult.h"
namespace mlir {
namespace iree_compiler {
namespace IREE {
namespace Flow {
//===----------------------------------------------------------------------===//
// Folding utilities
//===----------------------------------------------------------------------===//
// Returns a new set of dynamic dimensions for a shape-carrying op when a type
// is being changed. This attempts to reuse the existing dimension values if
// they are available and will drop/insert new ones as required.
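// Illustrative example (hypothetical types and values): changing
// tensor<?x4xf32> with dims [%d0] to tensor<2x?xf32> produces dims [%c4]:
// new dim 0 is static and needs no value, while new dim 1 was previously
// static so a constant index 4 is materialized for it.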
static SmallVector<Value, 4> refreshDimsOnTypeChange(
Operation *op, Type oldType, Type newType, ValueRange oldDims,
PatternRewriter &rewriter) {
if (oldType == newType) return llvm::to_vector<4>(oldDims);
// Build an expanded list of all the dims - static dims will be nullptr.
// This lets us map back to the new type without worrying about whether some
// subset became static or dynamic.
auto oldShapedType = oldType.cast<ShapedType>();
SmallVector<Value, 4> allOldDims(oldShapedType.getRank());
for (unsigned i = 0; i < oldShapedType.getRank(); ++i) {
if (oldShapedType.isDynamicDim(i)) {
allOldDims[i] = oldDims.front();
oldDims = oldDims.drop_front();
}
}
auto newShapedType = newType.cast<ShapedType>();
SmallVector<Value, 4> newDims;
for (unsigned i = 0; i < newShapedType.getRank(); ++i) {
if (newShapedType.isDynamicDim(i)) {
auto oldValue = allOldDims[i];
if (oldValue) {
// Old value valid; reuse.
newDims.push_back(oldValue);
} else {
// Dimension has changed to be dynamic; insert a constant to use.
// This sometimes happens during folding of casts and usually is cleaned
// up pretty quickly.
newDims.push_back(rewriter.createOrFold<arith::ConstantIndexOp>(
op->getLoc(), oldShapedType.getDimSize(i)));
}
}
}
return newDims;
}
//===----------------------------------------------------------------------===//
// Streams
//===----------------------------------------------------------------------===//
namespace {
// Returns true if the given |value| is used again after |updateOp| consumes it.
static bool hasUsersInStreamAfterUpdate(Value value, Operation *updateOp) {
for (auto user : value.getUsers()) {
if (user == updateOp) continue;
if (user->getBlock() != updateOp->getBlock() ||
user->isBeforeInBlock(updateOp)) {
// The user is in a dominating block or earlier in this block, so it cannot
// consume the updated value.
continue;
}
return true;
}
return false;
}
// Returns true if the given |operand| is a constant tied to a result of
// |updateOp| and the |updateOp| has in-place update semantics.
static bool updatesConstantInStream(Value operand, Operation *updateOp) {
// Only two ops have in-place update semantics thus far. (TensorReshapeOp,
// which also implements TiedOpInterface, is fine.) Checking an explicit op
// list is not ideal; we should have an op interface.
if (!isa<DispatchOp, TensorUpdateOp>(updateOp)) return false;
// For loaded globals, check whether the global is immutable. Immutable
// globals will be aggregated into one read-only buffer, which must not be
// updated in place.
if (auto loadOp = operand.getDefiningOp<IREE::Util::GlobalLoadOp>()) {
return loadOp.isGlobalImmutable();
}
return false;
}
/// Inserts clones into the stream as required by tied results in order to
/// preserve the immutable tensor semantics implied by the SSA use-def chain.
///
/// Example:
/// %0 = flow.dispatch
/// // %0 will be updated in-place and renamed %1:
/// %1 = flow.dispatch %0 -> %0
/// // The original value of %0 (aka %1) is required but is not valid!
/// %2 = flow.dispatch %0
/// ->
/// %0 = flow.dispatch
/// // Capture the value of %0 before it is modified:
/// %clone = flow.tensor.clone %0
/// // Update %0 in-place and rename to %1, safe as %0 now has one use:
/// %1 = flow.dispatch %0 -> %0
/// // Use the cloned %0 value:
/// %2 = flow.dispatch %clone
struct InsertImmutabilityPreservingStreamClones
: public OpRewritePattern<ExStreamFragmentOp> {
using OpRewritePattern<ExStreamFragmentOp>::OpRewritePattern;
LogicalResult matchAndRewrite(ExStreamFragmentOp op,
PatternRewriter &rewriter) const override {
bool didClone = insertTiedClones(
cast<IREE::Util::TiedOpInterface>(op.getOperation()), rewriter);
for (auto &block : op.getClosureBodyRegion()) {
for (auto &innerOp : block) {
if (auto tiedOp = dyn_cast<IREE::Util::TiedOpInterface>(innerOp)) {
didClone |= insertTiedClones(tiedOp, rewriter);
}
}
}
return success(didClone);
}
bool insertTiedClones(IREE::Util::TiedOpInterface tiedOp,
PatternRewriter &rewriter) const {
bool didClone = false;
for (unsigned resultIndex = 0; resultIndex < tiedOp->getNumResults();
++resultIndex) {
auto tiedOperandIndex = tiedOp.getTiedResultOperandIndex(resultIndex);
if (!tiedOperandIndex.hasValue()) continue;
auto tiedOperand = tiedOp->getOperand(tiedOperandIndex.getValue());
if (hasUsersInStreamAfterUpdate(tiedOperand, tiedOp)) {
rewriter.setInsertionPoint(tiedOp);
auto clonedOperand = rewriter.createOrFold<TensorCloneOp>(
tiedOperand.getLoc(), tiedOperand);
SmallPtrSet<Operation *, 1> excludedOps;
excludedOps.insert(tiedOp.getOperation());
excludedOps.insert(clonedOperand.getDefiningOp());
tiedOperand.replaceUsesWithIf(clonedOperand, [&](OpOperand &use) {
Operation *user = use.getOwner();
return !excludedOps.count(user) &&
user->getBlock() ==
clonedOperand.getDefiningOp()->getBlock() &&
clonedOperand.getDefiningOp()->isBeforeInBlock(user);
});
didClone = true;
}
// TODO(#5492): This is a temporary solution to address the issue where we
// aggregate constants into a read-only buffer but still see in-place updates
// to them. Force clones for such constants for now.
if (updatesConstantInStream(tiedOperand, tiedOp)) {
rewriter.setInsertionPoint(tiedOp);
auto clonedOperand = rewriter.createOrFold<TensorCloneOp>(
tiedOperand.getLoc(), tiedOperand);
tiedOperand.replaceUsesWithIf(clonedOperand, [&](OpOperand &use) {
return use.getOwner() == tiedOp.getOperation();
});
didClone = true;
}
}
return didClone;
}
};
/// Ties the results of a stream to its operands when the operations within the
/// stream body are tied all the way back to the stream's block arguments.
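///
/// Illustrative example (hypothetical IR, syntax abbreviated):
///   %r = flow.ex.stream.fragment(%arg0 = %input) {
///     %0 = flow.dispatch @fn(%arg0) -> %arg0  // tied to %arg0
///     flow.return %0
///   }
/// %0 ties back through the dispatch to block argument %arg0, so result 0 of
/// the stream can be tied to operand 0 (%input).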
struct TieStreamResults : public OpRewritePattern<ExStreamFragmentOp> {
using OpRewritePattern<ExStreamFragmentOp>::OpRewritePattern;
LogicalResult matchAndRewrite(ExStreamFragmentOp op,
PatternRewriter &rewriter) const override {
assert(op.getRegion().getBlocks().size() == 1 &&
"only one stream block supported");
bool didModify = false;
op.walk([&](IREE::Flow::ReturnOp returnOp) {
for (auto result : llvm::enumerate(returnOp.getOperands())) {
if (op.getTiedResultOperandIndex(result.index()).hasValue()) {
continue; // Already tied.
}
auto baseValue =
IREE::Util::TiedOpInterface::findTiedBaseValue(result.value());
if (auto blockArg = baseValue.dyn_cast<BlockArgument>()) {
unsigned operandIndex = blockArg.getArgNumber();
op.setTiedResultOperandIndex(result.index(), operandIndex);
didModify = true;
}
}
});
return didModify ? success() : failure();
}
};
} // namespace
void ExStreamFragmentOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
results.insert<IREE::Util::ClosureOptimizationPattern<ExStreamFragmentOp>>(
context);
results.insert<InsertImmutabilityPreservingStreamClones>(context);
// TODO(#6420): fix HAL lowering of this (or wait until streams are gone).
// results.insert<TieStreamResults>(context);
}
//===----------------------------------------------------------------------===//
// Dispatch ops
//===----------------------------------------------------------------------===//
void DispatchWorkgroupsOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
results.insert<IREE::Util::ClosureOptimizationPattern<DispatchWorkgroupsOp>>(
context);
}
//===----------------------------------------------------------------------===//
// flow.dispatch.tensor.load
//===----------------------------------------------------------------------===//
namespace {
// Some linalg patterns, due to being upstream, tend to introduce `tensor.dim`
// ops. These generally fold with upstream patterns when tensors are involved,
// but when DispatchTensorLoadOps are involved (with dispatch tensor types)
// this starts to break down, causing the `dim` ops to survive arbitrarily
// late into the pipeline. Often they keep alive DispatchTensorLoadOps that
// would otherwise be dead!
//
// To fix this:
// (1) When loading the full tensor we convert the `tensor.dim` ops to
// `flow.dispatch.shape` ops:
// ```
// dim(flow.dispatch.tensor.load(%x), %const)
// ->
// shapex.ranked_dim(flow.dispatch.shape(%x), %const)
// ```
// (2) When loading a tile we replace the `dim` with the corresponding size
// from the load's `sizes`.
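//
// Illustrative example for (2), with hypothetical SSA values:
// ```
// %t = flow.dispatch.tensor.load %x, ..., sizes = [%s0, 16], ...
// %d = tensor.dim %t, %c1
// ->
// %d is replaced by the constant 16 (and tensor.dim %t, %c0 by %s0)
// ```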
struct ConvertDimOfDispatchInputLoadToDispatchShape
: public OpRewritePattern<tensor::DimOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(tensor::DimOp op,
PatternRewriter &rewriter) const override {
auto loadOp = op.source().getDefiningOp<DispatchTensorLoadOp>();
if (!loadOp) return failure();
Optional<int64_t> constantIndex = op.getConstantIndex();
if (!constantIndex.hasValue()) return failure();
// Full tensor:
if (loadOp.sizes().empty()) {
auto rankedShape =
rewriter.create<DispatchShapeOp>(op.getLoc(), loadOp.source());
rewriter.replaceOpWithNewOp<Shape::RankedDimOp>(op, rankedShape,
*constantIndex);
} else { // Tensor tile:
if (loadOp.getMixedSizes()[*constantIndex].is<Attribute>()) {
rewriter.replaceOpWithNewOp<arith::ConstantOp>(
op, loadOp.getMixedSizes()[*constantIndex]
.get<Attribute>()
.dyn_cast<IntegerAttr>());
} else {
rewriter.replaceOp(
op, {loadOp.getMixedSizes()[*constantIndex].get<Value>()});
}
}
return success();
}
};
// Inlining producers of an input to the dispatch region results in the
// `flow.dispatch.tensor.load` having a `tensor` type as its source. This fails
// verification. Since inlining happens during canonicalization, add a pattern
// to convert
//
// flow.dispatch.tensor.load %v, offsets .., sizes .., strides..
// : tensor<...> -> tensor<..>
//
// to
//
// tensor.extract_slice %v[..] [..] [..]
struct ConvertDispatchInputLoadOfTensorToSubTensor
: public OpRewritePattern<DispatchTensorLoadOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(DispatchTensorLoadOp loadOp,
PatternRewriter &rewriter) const override {
if (!loadOp.source().getType().isa<RankedTensorType>()) {
return failure();
}
// If the offsets, sizes, and strides are all empty, rely on folding to take
// care of it.
if (loadOp.offsets().empty() && loadOp.sizes().empty() &&
loadOp.strides().empty()) {
return failure();
}
rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
loadOp, loadOp.source(), loadOp.getMixedOffsets(),
loadOp.getMixedSizes(), loadOp.getMixedStrides());
return success();
}
};
/// Returns the canonical type of the result of the load op.
struct DispatchTensorLoadReturnTypeCanonicalizer {
RankedTensorType operator()(DispatchTensorLoadOp loadOp,
ArrayRef<OpFoldResult> mixedOffsets,
ArrayRef<OpFoldResult> mixedSizes,
ArrayRef<OpFoldResult> mixedStrides) {
return DispatchTensorLoadOp::inferResultType(
loadOp.source().getType().cast<DispatchTensorType>(), mixedSizes);
}
};
/// A canonicalizer wrapper to replace DispatchTensorLoadOps.
struct DispatchTensorLoadOpCanonicalizer {
void operator()(PatternRewriter &rewriter, DispatchTensorLoadOp op,
DispatchTensorLoadOp newOp) {
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getResult().getType(),
newOp.getResult());
}
};
} // namespace
void DispatchTensorLoadOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
results.insert<
ConvertDimOfDispatchInputLoadToDispatchShape,
ConvertDispatchInputLoadOfTensorToSubTensor,
OpWithOffsetSizesAndStridesConstantArgumentFolder<
DispatchTensorLoadOp, DispatchTensorLoadReturnTypeCanonicalizer,
DispatchTensorLoadOpCanonicalizer>>(context);
}
// Inlining producers of an input to the dispatch region results in the
// `flow.dispatch.tensor.load` having a `tensor` type as its source. This fails
// verification. Fold such ops when the offsets, sizes, and strides are empty,
// i.e., flow.dispatch.tensor.load %v -> %v.
OpFoldResult DispatchTensorLoadOp::fold(ArrayRef<Attribute> operands) {
if (source().getType() && source().getType().isa<RankedTensorType>() &&
getMixedOffsets().empty() && getMixedSizes().empty() &&
getMixedStrides().empty()) {
return source();
}
return {};
}
//===----------------------------------------------------------------------===//
// flow.dispatch.tensor.store
//===----------------------------------------------------------------------===//
namespace {
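// Folds a producing tensor.cast into flow.dispatch.tensor.store so that the
// store writes the cast's source directly. Illustrative example (hypothetical
// IR, syntax abbreviated):
//   %0 = tensor.cast %src : tensor<4x?xf32> to tensor<?x?xf32>
//   flow.dispatch.tensor.store %0, %target, ...
// ->
//   flow.dispatch.tensor.store %src, %target, ...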
struct FoldCastOpIntoDispatchStoreOp
: public OpRewritePattern<DispatchTensorStoreOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(DispatchTensorStoreOp storeOp,
PatternRewriter &rewriter) const override {
auto parentOp = storeOp.value().getDefiningOp<tensor::CastOp>();
if (!parentOp) return failure();
rewriter.replaceOpWithNewOp<DispatchTensorStoreOp>(
storeOp, parentOp.source(), storeOp.target(), storeOp.offsets(),
storeOp.sizes(), storeOp.strides(), storeOp.static_offsets(),
storeOp.static_sizes(), storeOp.static_strides());
return success();
}
};
} // namespace
void DispatchTensorStoreOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
results.insert<FoldCastOpIntoDispatchStoreOp>(context);
}
//===----------------------------------------------------------------------===//
// flow.dispatch.workgroup.*
//===----------------------------------------------------------------------===//
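// Folds the workgroup rank query to a constant index equal to the number of
// workgroup count dimensions on the enclosing DispatchWorkgroupsOp.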
OpFoldResult DispatchWorkgroupRankOp::fold(ArrayRef<Attribute> operands) {
if (auto dispatchOp = (*this)->getParentOfType<DispatchWorkgroupsOp>()) {
return IntegerAttr::get(IndexType::get(getContext()),
APInt(64, dispatchOp.workgroup_count().size()));
}
return {};
}
//===----------------------------------------------------------------------===//
// flow.dispatch.shape
//===----------------------------------------------------------------------===//
namespace {
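// Folds a DispatchShapeOp whose source has a fully static shape into a
// constant ranked shape (Shape::ConstRankedShapeOp).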
struct FoldConstantDispatchShape : public OpRewritePattern<DispatchShapeOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(DispatchShapeOp op,
PatternRewriter &rewriter) const override {
auto sourceType = op.source().getType().cast<DispatchTensorType>();
if (!sourceType.hasStaticShape()) return failure();
auto shapeType = Shape::RankedShapeType::get(sourceType.getShape(),
rewriter.getContext());
rewriter.replaceOpWithNewOp<Shape::ConstRankedShapeOp>(op, shapeType);
return success();
}
};
struct PropagateTiedDispatchShapeQuery
: public OpRewritePattern<DispatchShapeOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(DispatchShapeOp op,
PatternRewriter &rewriter) const override {
if (auto tieOp =
dyn_cast_or_null<DispatchTieShapeOp>(op.source().getDefiningOp())) {
rewriter.replaceOp(op, {tieOp.shape()});
return success();
}
return failure();
}
};
} // namespace
void DispatchShapeOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
results.insert<FoldConstantDispatchShape, PropagateTiedDispatchShapeQuery>(
context);
}
//===----------------------------------------------------------------------===//
// flow.dispatch.tie_shape
//===----------------------------------------------------------------------===//
namespace {
struct FoldConstantDispatchTieShape
: public OpRewritePattern<DispatchTieShapeOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(DispatchTieShapeOp op,
PatternRewriter &rewriter) const override {
auto shapeType = op.shape().getType().cast<Shape::RankedShapeType>();
if (!shapeType.isFullyStatic()) return failure();
rewriter.replaceOp(op, op.operand());
return success();
}
};
/// Elides the tie_shape op if its operand already carries shape information
/// (implements ShapeCarryingInterface).
struct ElideShapeCarryingOperandTieShape
: public OpRewritePattern<DispatchTieShapeOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(DispatchTieShapeOp op,
PatternRewriter &rewriter) const override {
auto definingOp = op.operand().getDefiningOp();
if (!definingOp) return failure();
if (!isa<ShapeCarryingInterface>(definingOp)) return failure();
rewriter.replaceOp(op, op.operand());
return success();
}
};
} // namespace
void DispatchTieShapeOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
results
.insert<ElideShapeCarryingOperandTieShape, FoldConstantDispatchTieShape>(
context);
}
//===----------------------------------------------------------------------===//
// Tensor ops
//===----------------------------------------------------------------------===//
/// Reduces the provided multidimensional index into a flattened 1D row-major
/// index. The |type| is expected to be statically shaped (as all constants
/// are).
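/// Worked example: for shape [2, 3, 4] and index [1, 2, 3] the result is
/// 1 * (3 * 4) + 2 * 4 + 3 = 23.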
static uint64_t getFlattenedIndex(ShapedType type, ArrayRef<uint64_t> index) {
assert(type.hasStaticShape() && "for use on statically shaped types only");
auto rank = type.getRank();
auto shape = type.getShape();
uint64_t valueIndex = 0;
uint64_t dimMultiplier = 1;
for (int i = rank - 1; i >= 0; --i) {
valueIndex += index[i] * dimMultiplier;
dimMultiplier *= shape[i];
}
return valueIndex;
}
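// Returns true if the two shaped types, together with the SSA values of their
// dynamic dimensions, are provably equal: identical static shapes fast-path to
// true; otherwise ranks, static sizes, and dynamic dim values must all match.
// Differing dynamic dim SSA values are conservatively treated as unequal.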
static bool compareShapesEqual(ShapedType lhsType, ValueRange lhsDynamicDims,
ShapedType rhsType, ValueRange rhsDynamicDims) {
if (lhsType.hasStaticShape() && rhsType.hasStaticShape() &&
lhsType == rhsType) {
// Static shape equivalence means we can fast-path the check.
return true;
}
if (lhsType.getRank() != rhsType.getRank()) {
return false;
}
unsigned dynamicDimIndex = 0;
for (unsigned i = 0; i < lhsType.getRank(); ++i) {
if (lhsType.isDynamicDim(i) != rhsType.isDynamicDim(i)) {
// Static/dynamic dimension mismatch - definitely differ.
return false;
} else if (lhsType.isDynamicDim(i)) {
unsigned j = dynamicDimIndex++;
if (lhsDynamicDims[j] != rhsDynamicDims[j]) {
// Dynamic dimensions with different SSA values - probably differ.
return false;
}
} else {
if (lhsType.getDimSize(i) != rhsType.getDimSize(i)) {
// Static dimensions differ.
return false;
}
}
}
return true;
}
OpFoldResult TensorReshapeOp::fold(ArrayRef<Attribute> operands) {
auto sourceType = source().getType().cast<ShapedType>();
auto resultType = result().getType().cast<ShapedType>();
if (compareShapesEqual(sourceType, source_dims(), resultType,
result_dims())) {
// Shapes match and this is a no-op so just fold to the source.
return source();
}
return {};
}
namespace {
// Flatten a chain of reshapes (reshape feeding into reshape) such that a
// reshape only ever pulls from a non-reshape source. This prevents big useless
// chains and makes it easier to track the original storage for the tensor.
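// Illustrative example (hypothetical IR, syntax abbreviated):
//   %1 = flow.tensor.reshape %0 : tensor<4x4xf32> -> tensor<16xf32>
//   %2 = flow.tensor.reshape %1 : tensor<16xf32> -> tensor<2x8xf32>
// becomes
//   %2 = flow.tensor.reshape %0 : tensor<4x4xf32> -> tensor<2x8xf32>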
struct FlattenTensorReshapeChain : public OpRewritePattern<TensorReshapeOp> {
using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
PatternRewriter &rewriter) const override {
auto sourceOp =
dyn_cast_or_null<TensorReshapeOp>(reshapeOp.source().getDefiningOp());
if (!sourceOp) return failure();
// We want the same result value/shape but to source from the ancestor. We
// need to pull any dynamic dims from that as we don't care about the
// intermediate reshapes.
rewriter.replaceOpWithNewOp<TensorReshapeOp>(
reshapeOp, reshapeOp.result().getType(), sourceOp.source(),
sourceOp.source_dims(), reshapeOp.result_dims());
return success();
}
};
// Replaces a `flow.tensor.load` of a `flow.tensor.splat` with the primitive
// value that was splatted.
struct FoldSplatLoadIntoPrimitive : public OpRewritePattern<TensorLoadOp> {
using OpRewritePattern<TensorLoadOp>::OpRewritePattern;
LogicalResult matchAndRewrite(TensorLoadOp loadOp,
PatternRewriter &rewriter) const override {
auto sourceOp =
dyn_cast_or_null<TensorSplatOp>(loadOp.source().getDefiningOp());
if (!sourceOp) return failure();
rewriter.replaceOp(loadOp, sourceOp.value());
return success();
}
};
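// Folds a reshape of a single-use splat into a splat of the reshaped type.
// Illustrative example (hypothetical IR, syntax abbreviated):
//   %0 = flow.tensor.splat %value : tensor<16xf32>
//   %1 = flow.tensor.reshape %0 : tensor<16xf32> -> tensor<4x4xf32>
// becomes
//   %1 = flow.tensor.splat %value : tensor<4x4xf32>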
struct FoldSplatReshapeIntoSplat : public OpRewritePattern<TensorSplatOp> {
using OpRewritePattern<TensorSplatOp>::OpRewritePattern;
LogicalResult matchAndRewrite(TensorSplatOp splatOp,
PatternRewriter &rewriter) const override {
if (!splatOp.result().hasOneUse()) return failure();
auto reshapeOp = dyn_cast_or_null<TensorReshapeOp>(
splatOp.result().use_begin()->getOwner());
if (!reshapeOp) return failure();
rewriter.replaceOpWithNewOp<TensorSplatOp>(
reshapeOp, reshapeOp.result().getType(), splatOp.value(),
reshapeOp.result_dims());
rewriter.eraseOp(splatOp);
return success();
}
};
} // namespace
void TensorReshapeOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
results.insert<FlattenTensorReshapeChain>(context);
}
void TensorLoadOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
results.insert<FoldSplatLoadIntoPrimitive>(context);
}
OpFoldResult TensorLoadOp::fold(ArrayRef<Attribute> operands) {
if (auto source = operands[0].dyn_cast_or_null<ElementsAttr>()) {
// Load directly from the constant source tensor.
auto indices = operands.drop_front();
if (llvm::count(indices, nullptr) == 0) {
return source.getValue(
llvm::to_vector<4>(llvm::map_range(indices, [](Attribute value) {
return value.cast<IntegerAttr>().getValue().getZExtValue();
})));
}
}
return {};
}
OpFoldResult TensorStoreOp::fold(ArrayRef<Attribute> operands) {
if (!operands[0]) return {};
auto &value = operands[0];
if (auto target = operands[1].dyn_cast_or_null<ElementsAttr>()) {
// Store into the constant target tensor.
if (target.getType().getRank() == 0) {
return DenseElementsAttr::get(target.getType(), {value});
}
auto indices = operands.drop_front(2);
if (llvm::count(indices, nullptr) == 0) {
uint64_t offset = getFlattenedIndex(
target.getType(),
llvm::to_vector<4>(llvm::map_range(indices, [](Attribute value) {
return value.cast<IntegerAttr>().getValue().getZExtValue();
})));
SmallVector<Attribute, 16> newContents(target.getValues<Attribute>());
newContents[offset] = value;
return DenseElementsAttr::get(target.getType(), newContents);
}
}
return {};
}
void TensorSplatOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
// TODO(benvanik): canonicalize splat+slice to smaller splat.
results.insert<FoldSplatReshapeIntoSplat>(context);
}
OpFoldResult TensorSplatOp::fold(ArrayRef<Attribute> operands) {
if (operands.size() == 1 && operands.front()) {
// Splat value is constant and we can fold the operation.
return SplatElementsAttr::get(result().getType().cast<ShapedType>(),
operands[0]);
}
return {};
}
OpFoldResult TensorCloneOp::fold(ArrayRef<Attribute> operands) {
if (operands[0]) {
// Constants always fold.
return operands[0];
}
// TODO(benvanik): elide clones when safe to do so. Right now clone is
// load-bearing to work around our lack of cross-stream scheduling. Clones are
// inserted to avoid mutating function arguments and any logic we perform here
// (without *also* checking all the conditions that may insert a clone) will
// just fight.
//
// Once the clones are not load-bearing we can remove them in all the normal
// cases (one user, no intervening uses between clone and consumers of
// operands, etc).
return {};
}
// Slices |tensor| along |dim| from |start| to (start + length), exclusive.
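// Worked example: for a 2x3 tensor [[a, b, c], [d, e, f]] stored row-major,
// slicing dim = 1 with start = 1 and length = 2 gives step = 3 and num = 2,
// copying two elements at offsets 1 and 4 to produce [[b, c], [e, f]].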
static ElementsAttr tensorSlice(ElementsAttr tensor, uint64_t dim,
uint64_t start, uint64_t length) {
auto shape = llvm::to_vector<4>(tensor.getType().getShape());
if (length == shape[dim]) {
// No need to slice.
return tensor;
}
auto outputShape = shape;
outputShape[dim] = length;
auto outputType =
RankedTensorType::get(outputShape, getElementTypeOrSelf(tensor));
llvm::SmallVector<Attribute, 4> newContents;
newContents.reserve(outputType.getNumElements());
auto valuesBegin = tensor.getValues<Attribute>().begin();
int64_t step =
std::accumulate(shape.rbegin(), shape.rbegin() + shape.size() - dim,
/*init=*/1, /*op=*/std::multiplies<int64_t>());
int64_t num = length * step / shape[dim];
for (int64_t offset = step / shape[dim] * start,
numElements = tensor.getType().getNumElements();
offset < numElements; offset += step) {
newContents.append(valuesBegin + offset, valuesBegin + offset + num);
}
return DenseElementsAttr::get(outputType, newContents);
}
OpFoldResult TensorSliceOp::fold(ArrayRef<Attribute> operands) {
if (llvm::count(operands, nullptr) == 0) {
// Fully constant arguments so we can perform the slice here.
auto tensor = operands[0].cast<ElementsAttr>();
int64_t rank = source().getType().cast<ShapedType>().getRank();
// start = operands[1:1+rank), and length = operands[1+rank:].
auto start = llvm::to_vector<4>(llvm::map_range(
operands.drop_front(1).drop_back(rank), [](Attribute value) {
return value.cast<IntegerAttr>().getValue().getZExtValue();
}));
auto length = llvm::to_vector<4>(
llvm::map_range(operands.drop_front(1 + rank), [](Attribute value) {
return value.cast<IntegerAttr>().getValue().getZExtValue();
}));
for (int64_t dim = 0; dim < rank; ++dim) {
tensor = tensorSlice(tensor, dim, start[dim], length[dim]);
}
return tensor;
}
return {};
}
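// Computes the result of writing |update| into |target| starting at the
// constant indices in |startIndicesAttrs| (constant folding of a tensor
// update). Illustrative example: a 2x2 update written into a 4x4 target at
// start indices [1, 1] overwrites rows 1-2, columns 1-2 of the target.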
static ElementsAttr tensorUpdate(ElementsAttr update, ElementsAttr target,
ArrayRef<Attribute> startIndicesAttrs) {
auto updateType = update.getType().cast<ShapedType>();
auto targetType = target.getType().cast<ShapedType>();
// If either the target or the update has zero elements then no update happens.
if (updateType.getNumElements() == 0 || targetType.getNumElements() == 0) {
return target;
}
int64_t rank = targetType.getRank();
// If the target is a scalar then the update is also a scalar and becomes the
// new content.
if (rank == 0) {
return update;
}
auto startIndex = llvm::to_vector<4>(
llvm::map_range(startIndicesAttrs, [](Attribute value) {
return value.cast<IntegerAttr>().getValue().getZExtValue();
}));
auto targetValues = llvm::to_vector<4>(target.getValues<Attribute>());
// Target indices start from |startIndicesAttrs| and update indices start from
// all zeros.
llvm::SmallVector<uint64_t, 4> targetIndex(startIndex);
llvm::SmallVector<uint64_t, 4> updateIndex(rank, 0);
int64_t numElements = updateType.getNumElements();
while (numElements--) {
targetValues[getFlattenedIndex(targetType, targetIndex)] =
update.getValue<Attribute>(updateIndex);
// Increment the index at last dim.
++updateIndex.back();
++targetIndex.back();
// If the index in dim j exceeds dim size, reset dim j and
// increment dim (j-1).
for (int64_t j = rank - 1;
j >= 0 && updateIndex[j] >= updateType.getDimSize(j); --j) {
updateIndex[j] = 0;
targetIndex[j] = startIndex[j];
if (j - 1 >= 0) {
++updateIndex[j - 1];
++targetIndex[j - 1];
}
}
}
return DenseElementsAttr::get(targetType, targetValues);
}
OpFoldResult TensorUpdateOp::fold(ArrayRef<Attribute> operands) {
auto targetIndex = getODSOperandIndexAndLength(0).first;
auto startIndices = getODSOperandIndexAndLength(2);
auto updateIndex = getODSOperandIndexAndLength(3).first;
auto indices = operands.slice(startIndices.first, startIndices.second);
bool allIndicesConstant = llvm::count(indices, nullptr) == 0;
if (operands[updateIndex] && operands[targetIndex] && allIndicesConstant) {
// Fully constant arguments so we can perform the update here.
return tensorUpdate(operands[updateIndex].cast<ElementsAttr>(),
operands[targetIndex].cast<ElementsAttr>(), indices);
} else {
// Replace the entire tensor when the sizes match.
auto updateType = update().getType().cast<ShapedType>();
auto targetType = target().getType().cast<ShapedType>();
if (updateType.hasStaticShape() && targetType.hasStaticShape() &&
updateType == targetType) {
return update();
}
}
return {};
}
namespace {
// When the target or update tensor is the result of a tensor.cast operation,
// the op needs to be updated to use the source of the cast instead, refreshing
// any dynamic dims to match the new types.
struct FoldTensorUpdateOpWithCasts : public OpRewritePattern<TensorUpdateOp> {
using OpRewritePattern<TensorUpdateOp>::OpRewritePattern;
LogicalResult matchAndRewrite(TensorUpdateOp updateOp,
PatternRewriter &rewriter) const override {
auto targetCastOp = updateOp.target().getDefiningOp<tensor::CastOp>();
auto updateCastOp = updateOp.update().getDefiningOp<tensor::CastOp>();
if (!targetCastOp && !updateCastOp) return failure();
auto target = (targetCastOp ? targetCastOp.source() : updateOp.target());
auto update = (updateCastOp ? updateCastOp.source() : updateOp.update());
auto newOp = rewriter.create<TensorUpdateOp>(
updateOp.getLoc(), target.getType(), target,
refreshDimsOnTypeChange(updateOp, updateOp.target().getType(),
target.getType(), updateOp.target_dims(),
rewriter),
updateOp.start_indices(), update,
refreshDimsOnTypeChange(updateOp, updateOp.update().getType(),
update.getType(), updateOp.update_dims(),
rewriter),
updateOp.tied_operandsAttr());
rewriter.replaceOpWithNewOp<tensor::CastOp>(
updateOp, updateOp.getResult().getType(), newOp.getResult());
return success();
}
};
} // namespace
void TensorUpdateOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
results.insert<FoldTensorUpdateOpWithCasts>(context);
}
} // namespace Flow
} // namespace IREE
} // namespace iree_compiler
} // namespace mlir