| // Copyright 2019 The IREE Authors |
| // |
| // Licensed under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| |
| #include <algorithm> |
| #include <numeric> |
| |
| #include "iree/compiler/Dialect/Flow/IR/FlowDialect.h" |
| #include "iree/compiler/Dialect/Flow/IR/FlowOps.h" |
| #include "iree/compiler/Dialect/Util/IR/ClosureOpUtils.h" |
| #include "iree/compiler/Dialect/Util/IR/UtilOps.h" |
| #include "iree/compiler/Dialect/Util/IR/UtilTypes.h" |
| #include "llvm/ADT/MapVector.h" |
| #include "llvm/ADT/Optional.h" |
| #include "llvm/ADT/STLExtras.h" |
| #include "llvm/ADT/SmallPtrSet.h" |
| #include "llvm/ADT/StringExtras.h" |
| #include "mlir/Dialect/Arith/Utils/Utils.h" |
| #include "mlir/Dialect/Func/IR/FuncOps.h" |
| #include "mlir/Dialect/MemRef/IR/MemRef.h" |
| #include "mlir/Dialect/Tensor/IR/Tensor.h" |
| #include "mlir/IR/Attributes.h" |
| #include "mlir/IR/Builders.h" |
| #include "mlir/IR/Matchers.h" |
| #include "mlir/IR/OpDefinition.h" |
| #include "mlir/IR/OpImplementation.h" |
| #include "mlir/IR/PatternMatch.h" |
| #include "mlir/IR/SymbolTable.h" |
| #include "mlir/IR/TypeUtilities.h" |
| #include "mlir/IR/Value.h" |
| #include "mlir/Support/LogicalResult.h" |
| |
| namespace mlir { |
| namespace iree_compiler { |
| namespace IREE { |
| namespace Flow { |
| |
| //===----------------------------------------------------------------------===// |
| // Folding utilities |
| //===----------------------------------------------------------------------===// |
| |
| // Returns true if |value| is definitely empty at runtime. |
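| // For example (illustrative): tensor<4x0x2xf32> is definitely empty, while |
| // tensor<?x2xf32> may be empty at runtime and so returns false here. |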
| static bool isTensorEmpty(Value value) { |
| auto type = value.getType().dyn_cast<ShapedType>(); |
| if (!type) return false; |
| // Any static dimension being zero is definitely empty. |
| for (int64_t i = 0; i < type.getRank(); ++i) { |
| int64_t dim = type.getDimSize(i); |
| if (dim == 0) return true; |
| } |
| return false; // may still be dynamically empty |
| } |
| |
| // Returns true if |value| is definitely empty at runtime. |
| // Returns false both when the value is definitely not empty and when it is |
| // only possibly empty at runtime (one or more dynamic dimensions). |
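| // For example (illustrative IR), a value tied back to |
| // %0 = flow.tensor.empty : tensor<?x4xf32>{%dim} |
| // is treated as empty even though %dim is dynamic. |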
| static bool isTensorOperandEmpty(Value value) { |
| // Any value produced by an empty sentinel op is empty. |
| auto baseValue = IREE::Util::TiedOpInterface::findTiedBaseValue(value); |
| if (isa_and_nonnull<IREE::Flow::TensorEmptyOp>(baseValue.getDefiningOp())) { |
| return true; |
| } |
| return isTensorEmpty(value); |
| } |
| |
| // Returns true if |value| is definitely empty at runtime. |
| // Returns false both when the value is definitely not empty and when it is |
| // only possibly empty at runtime (one or more dynamic dimensions). |
| static bool isTensorResultEmpty(Value value) { return isTensorEmpty(value); } |
| |
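| // Replaces |Op| with a flow.tensor.empty of the result type when the operand |
| // at |OperandIdx| is provably empty, as an empty operand implies an empty |
| // result for the ops this pattern is applied to. |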
| template <typename Op, int OperandIdx, int ResultIdx = 0> |
| struct ReplaceOpIfTensorOperandEmpty : public OpRewritePattern<Op> { |
| using OpRewritePattern<Op>::OpRewritePattern; |
| LogicalResult matchAndRewrite(Op op, |
| PatternRewriter &rewriter) const override { |
| auto operand = op->getOperand(OperandIdx); |
| if (!isTensorOperandEmpty(operand)) return failure(); |
| auto result = op->getResult(ResultIdx); |
| auto dynamicDims = op.getResultDynamicDims(result.getResultNumber()); |
| rewriter.replaceOpWithNewOp<IREE::Flow::TensorEmptyOp>(op, result.getType(), |
| dynamicDims); |
| return success(); |
| } |
| }; |
| |
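| // Replaces |Op| with a flow.tensor.empty when the result at |ResultIdx| is |
| // provably empty based on its static shape. |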
| template <typename Op, int ResultIdx> |
| struct ReplaceOpIfTensorResultEmpty : public OpRewritePattern<Op> { |
| using OpRewritePattern<Op>::OpRewritePattern; |
| LogicalResult matchAndRewrite(Op op, |
| PatternRewriter &rewriter) const override { |
| auto result = op->getResult(ResultIdx); |
| if (!isTensorResultEmpty(result)) return failure(); |
| auto dynamicDims = op.getResultDynamicDims(result.getResultNumber()); |
| rewriter.replaceOpWithNewOp<IREE::Flow::TensorEmptyOp>(op, result.getType(), |
| dynamicDims); |
| return success(); |
| } |
| }; |
| |
| // Turns a tensor type that may have one or more dynamic dimensions into a |
| // static type with dynamic dimensions replaced with 0. |
| // Example: tensor<?x0x1xf32> -> tensor<0x0x1xf32> |
| static Type makeEmptyStaticTensorType(Type type) { |
| auto tensorType = type.cast<RankedTensorType>(); |
| if (tensorType.hasStaticShape()) return type; |
| SmallVector<int64_t> dims; |
| dims.resize(tensorType.getRank()); |
| for (int64_t i = 0; i < tensorType.getRank(); ++i) { |
| int64_t dim = tensorType.getDimSize(i); |
| dims[i] = dim == ShapedType::kDynamicSize ? 0 : dim; |
| } |
| return RankedTensorType::get(dims, tensorType.getElementType(), |
| tensorType.getEncoding()); |
| } |
| |
| // Returns a new set of dynamic dimensions for a shape carrying op when a type |
| // is being changed. This attempts to reuse the existing dimension values if |
| // they are available and will drop/insert new ones as required. |
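| // For example (illustrative), changing tensor<2x?xf32> with dims [%d1] to |
| // tensor<?x?xf32> yields dims [%c2, %d1], materializing %c2 = constant 2. |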
| static SmallVector<Value, 4> refreshDimsOnTypeChange( |
| Operation *op, Type oldType, Type newType, ValueRange oldDims, |
| PatternRewriter &rewriter) { |
| if (oldType == newType) return llvm::to_vector<4>(oldDims); |
| |
| // Build an expanded list of all the dims - static dimensions will be |
| // nullptr. This lets us map the new type back without worrying about |
| // whether some subset became static or dynamic. |
| auto oldShapedType = oldType.cast<ShapedType>(); |
| SmallVector<Value, 4> allOldDims(oldShapedType.getRank()); |
| for (unsigned i = 0; i < oldShapedType.getRank(); ++i) { |
| if (oldShapedType.isDynamicDim(i)) { |
| allOldDims[i] = oldDims.front(); |
| oldDims = oldDims.drop_front(); |
| } |
| } |
| |
| auto newShapedType = newType.cast<ShapedType>(); |
| SmallVector<Value, 4> newDims; |
| for (unsigned i = 0; i < newShapedType.getRank(); ++i) { |
| if (newShapedType.isDynamicDim(i)) { |
| auto oldValue = allOldDims[i]; |
| if (oldValue) { |
| // Old value valid; reuse. |
| newDims.push_back(oldValue); |
| } else { |
| // Dimension has changed from static to dynamic; insert a constant with the |
| // old static size to use. This sometimes happens during folding of casts |
| // and usually is cleaned up pretty quickly. |
| newDims.push_back(rewriter.createOrFold<arith::ConstantIndexOp>( |
| op->getLoc(), oldShapedType.getDimSize(i))); |
| } |
| } |
| } |
| return newDims; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // flow.dispatch.workgroups |
| //===----------------------------------------------------------------------===// |
| |
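| // Replaces any used dispatch result that is provably empty with a |
| // flow.tensor.empty of the same type, e.g. (illustrative): |
| // %r = flow.dispatch.workgroups[...](...) : () -> tensor<4x0xf32> |
| // has its uses replaced with |
| // %e = flow.tensor.empty : tensor<4x0xf32> |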
| struct ReplaceDispatchResultIfEmpty |
| : public OpRewritePattern<DispatchWorkgroupsOp> { |
| using OpRewritePattern::OpRewritePattern; |
| LogicalResult matchAndRewrite(DispatchWorkgroupsOp op, |
| PatternRewriter &rewriter) const override { |
| // NOTE: we only look at used results; if unused then closure optimization |
| // will drop it. |
| bool didReplaceAny = false; |
| for (auto result : op.getResults()) { |
| if (result.use_empty()) continue; |
| if (isTensorResultEmpty(result)) { |
| auto dynamicDims = op.getResultDynamicDims(result.getResultNumber()); |
| auto emptyOp = rewriter.create<IREE::Flow::TensorEmptyOp>( |
| result.getLoc(), result.getType(), dynamicDims); |
| result.replaceAllUsesWith(emptyOp); |
| didReplaceAny = true; |
| } |
| } |
| return didReplaceAny ? success() : failure(); |
| } |
| }; |
| |
| void DispatchWorkgroupsOp::getCanonicalizationPatterns( |
| RewritePatternSet &results, MLIRContext *context) { |
| results.insert<IREE::Util::ClosureOptimizationPattern<DispatchWorkgroupsOp>>( |
| context); |
| results.insert<ReplaceDispatchResultIfEmpty>(context); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // flow.dispatch.tie_shape |
| //===----------------------------------------------------------------------===// |
| |
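| // With no dynamic dims there is nothing to tie and the op is an identity |
| // that folds to its operand. |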
| OpFoldResult DispatchTieShapeOp::fold(ArrayRef<Attribute> operands) { |
| if (getDynamicDims().empty()) { |
| return getOperand(); |
| } |
| return {}; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // flow.dispatch.tensor.load |
| //===----------------------------------------------------------------------===// |
| |
| namespace { |
| |
| // Updates the |dimValues| of |tensorValue| with dimensions inferred from IR. |
| // The dimension values may be derived values that are redundant with captured |
| // dimensions and by redirecting to the captured values we can simplify things. |
| // Returns true if the dims were changed. |
| static bool updateTensorOpDims(Operation *op, Value tensorValue, |
| MutableOperandRange mutableDimValues) { |
| auto dynamicDimsOr = IREE::Util::findDynamicDims(tensorValue, op->getBlock(), |
| Block::iterator(op)); |
| if (!dynamicDimsOr.has_value()) return false; |
| auto dynamicDims = dynamicDimsOr.value(); |
| bool anyChanged = false; |
| OperandRange oldValueRange = mutableDimValues; |
| auto oldValues = llvm::to_vector<4>(oldValueRange); |
| for (unsigned i = 0; i < dynamicDims.size(); ++i) { |
| if (oldValues[i] != dynamicDims[i]) { |
| mutableDimValues.slice(i, 1).assign(dynamicDims[i]); |
| anyChanged = true; |
| } |
| } |
| return anyChanged; |
| } |
| |
| struct ReuseDispatchTensorLoadShapeDims |
| : public OpRewritePattern<DispatchTensorLoadOp> { |
| using OpRewritePattern::OpRewritePattern; |
| LogicalResult matchAndRewrite(DispatchTensorLoadOp loadOp, |
| PatternRewriter &rewriter) const override { |
| return success(updateTensorOpDims(loadOp, loadOp.getSource(), |
| loadOp.getSourceDimsMutable())); |
| } |
| }; |
| |
| // Inlining producers of an input to the dispatch region results in the |
| // `flow.dispatch.tensor.load` having a `tensor` type as input. This fails |
| // verification. Since inlining happens during canonicalization, add a pattern |
| // to convert |
| // |
| // flow.dispatch.tensor.load %v, offsets .., sizes .., strides.. |
| // : tensor<...> -> tensor<..> |
| // |
| // to |
| // |
| // tensor.extract_slice %v[..] [..] [..] |
| struct ConvertDispatchInputLoadOfTensorToSubTensor |
| : public OpRewritePattern<DispatchTensorLoadOp> { |
| using OpRewritePattern::OpRewritePattern; |
| LogicalResult matchAndRewrite(DispatchTensorLoadOp loadOp, |
| PatternRewriter &rewriter) const override { |
| if (!loadOp.getSource().getType().isa<RankedTensorType>()) { |
| return failure(); |
| } |
| // If the offsets, sizes, and strides are all empty then rely on folding to |
| // take care of it. |
| if (loadOp.offsets().empty() && loadOp.sizes().empty() && |
| loadOp.strides().empty()) { |
| return failure(); |
| } |
| rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>( |
| loadOp, loadOp.getSource(), loadOp.getMixedOffsets(), |
| loadOp.getMixedSizes(), loadOp.getMixedStrides()); |
| return success(); |
| } |
| }; |
| |
| /// For `op` that implements the `OffsetsStridesAndSizesInterface`, canonicalize |
| /// the `offsets`, `sizes` and `strides` by replacing any value operand that is |
| /// defined by a constant with the integer value directly. The type of the slice |
| /// (result type for `flow.dispatch.tensor.load` and `value` type for |
| /// `flow.dispatch.tensor.store`) is also passed in. The type of the slice to |
| /// use in the canonicalized op is returned. |
| template <typename OpTy> |
| static FailureOr<RankedTensorType> canonicalizeSubViewParts( |
| OpTy op, RankedTensorType sliceType, |
| SmallVector<OpFoldResult> &mixedOffsets, |
| SmallVector<OpFoldResult> &mixedSizes, |
| SmallVector<OpFoldResult> &mixedStrides) { |
| // If there are no constant operands then we return early before the more |
| // expensive work below. |
| if (llvm::none_of(op.offsets(), |
| [](Value operand) { |
| return matchPattern(operand, matchConstantIndex()); |
| }) && |
| llvm::none_of(op.sizes(), |
| [](Value operand) { |
| return matchPattern(operand, matchConstantIndex()); |
| }) && |
| llvm::none_of(op.strides(), [](Value operand) { |
| return matchPattern(operand, matchConstantIndex()); |
| })) { |
| return failure(); |
| } |
| |
| // At least one of offsets/sizes/strides is a new constant. |
| // Form the new list of operands and constant attributes from the existing. |
| mixedOffsets.assign(op.getMixedOffsets()); |
| mixedSizes.assign(op.getMixedSizes()); |
| mixedStrides.assign(op.getMixedStrides()); |
| canonicalizeSubViewPart(mixedOffsets, ShapedType::isDynamicStrideOrOffset); |
| canonicalizeSubViewPart(mixedSizes, ShapedType::isDynamic); |
| canonicalizeSubViewPart(mixedStrides, ShapedType::isDynamicStrideOrOffset); |
| |
| // Drop the same dimensions that were dropped before. |
| llvm::SmallVector<int64_t> newShape; |
| llvm::SmallBitVector droppedDims = op.getDroppedDims(); |
| for (auto size : llvm::enumerate(mixedSizes)) { |
| if (droppedDims.test(size.index())) continue; |
| Optional<int64_t> staticSize = getConstantIntValue(size.value()); |
| newShape.push_back(staticSize ? staticSize.value() |
| : ShapedType::kDynamicSize); |
| } |
| |
| auto newSliceType = |
| RankedTensorType::get(newShape, sliceType.getElementType()); |
| return newSliceType; |
| } |
| |
| /// Pattern to rewrite a subview op with constant arguments. |
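| /// For example (illustrative), a load with offsets [%c0, %i] becomes a load |
| /// with mixed offsets [0, %i] and a re-inferred (more static) result type, |
| /// followed by a tensor.cast back to the original type where needed. |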
| struct DispatchTensorLoadOpWithOffsetSizesAndStridesConstantArgumentFolder final |
| : public OpRewritePattern<DispatchTensorLoadOp> { |
| using OpRewritePattern::OpRewritePattern; |
| LogicalResult matchAndRewrite(DispatchTensorLoadOp loadOp, |
| PatternRewriter &rewriter) const override { |
| SmallVector<OpFoldResult> mixedOffsets, mixedSizes, mixedStrides; |
| RankedTensorType resultType = loadOp.getType(); |
| auto newResultType = canonicalizeSubViewParts( |
| loadOp, resultType, mixedOffsets, mixedSizes, mixedStrides); |
| if (failed(newResultType)) return failure(); |
| |
| // We need to resolve the new inferred type with the specified type. |
| Location loc = loadOp.getLoc(); |
| Value replacement = rewriter.create<DispatchTensorLoadOp>( |
| loc, newResultType.value(), loadOp.getSource(), loadOp.getSourceDims(), |
| mixedOffsets, mixedSizes, mixedStrides); |
| if (newResultType.value() != resultType) { |
| replacement = |
| rewriter.create<tensor::CastOp>(loc, resultType, replacement); |
| } |
| rewriter.replaceOp(loadOp, replacement); |
| return success(); |
| } |
| }; |
| |
| } // namespace |
| |
| void DispatchTensorLoadOp::getCanonicalizationPatterns( |
| RewritePatternSet &results, MLIRContext *context) { |
| results.insert<ReuseDispatchTensorLoadShapeDims>(context); |
| results.insert<ConvertDispatchInputLoadOfTensorToSubTensor>(context); |
| results.insert< |
| DispatchTensorLoadOpWithOffsetSizesAndStridesConstantArgumentFolder>( |
| context); |
| } |
| |
| // Inlining producers of an input to the dispatch region results in the |
| // `flow.dispatch.tensor.load` having a `tensor` type as input. This fails |
| // verification. Fold such uses when the offsets, sizes, and strides are all |
| // empty; i.e. flow.dispatch.tensor.load %v -> %v. |
| OpFoldResult DispatchTensorLoadOp::fold(ArrayRef<Attribute> operands) { |
| if (getSource().getType() && getSource().getType().isa<RankedTensorType>() && |
| getMixedOffsets().empty() && getMixedSizes().empty() && |
| getMixedStrides().empty()) { |
| return getSource(); |
| } |
| return {}; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // flow.dispatch.tensor.store |
| //===----------------------------------------------------------------------===// |
| |
| namespace { |
| |
| struct ReuseDispatchTensorStoreShapeDims |
| : public OpRewritePattern<DispatchTensorStoreOp> { |
| using OpRewritePattern::OpRewritePattern; |
| LogicalResult matchAndRewrite(DispatchTensorStoreOp storeOp, |
| PatternRewriter &rewriter) const override { |
| return success(updateTensorOpDims(storeOp, storeOp.getTarget(), |
| storeOp.getTargetDimsMutable())); |
| } |
| }; |
| |
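| // Folds a tensor.cast on the stored value into the store when the cast only |
| // relaxes static shape information, e.g. (illustrative): |
| // %c = tensor.cast %v : tensor<4xf32> to tensor<?xf32> |
| // flow.dispatch.tensor.store %c, ... |
| // becomes |
| // flow.dispatch.tensor.store %v, ... |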
| struct FoldCastOpIntoDispatchStoreOp |
| : public OpRewritePattern<DispatchTensorStoreOp> { |
| using OpRewritePattern::OpRewritePattern; |
| LogicalResult matchAndRewrite(DispatchTensorStoreOp storeOp, |
| PatternRewriter &rewriter) const override { |
| auto parentOp = storeOp.getValue().getDefiningOp<tensor::CastOp>(); |
| if (!parentOp || !tensor::canFoldIntoConsumerOp(parentOp)) return failure(); |
| |
| rewriter.replaceOpWithNewOp<DispatchTensorStoreOp>( |
| storeOp, parentOp.getSource(), storeOp.getTarget(), |
| storeOp.getTargetDims(), storeOp.offsets(), storeOp.sizes(), |
| storeOp.strides(), storeOp.static_offsets(), storeOp.static_sizes(), |
| storeOp.static_strides()); |
| return success(); |
| } |
| }; |
| |
| struct DispatchTensorStoreOpWithOffsetSizesAndStridesConstantArgumentFolder |
| final : public OpRewritePattern<DispatchTensorStoreOp> { |
| using OpRewritePattern::OpRewritePattern; |
| LogicalResult matchAndRewrite(DispatchTensorStoreOp storeOp, |
| PatternRewriter &rewriter) const override { |
| SmallVector<OpFoldResult> mixedOffsets, mixedSizes, mixedStrides; |
| RankedTensorType valueType = storeOp.getValueType(); |
| auto newValueType = canonicalizeSubViewParts( |
| storeOp, valueType, mixedOffsets, mixedSizes, mixedStrides); |
| if (failed(newValueType)) return failure(); |
| |
| Value value = storeOp.getValue(); |
| Location loc = storeOp.getLoc(); |
| if (newValueType.value() != valueType) { |
| value = rewriter.create<tensor::CastOp>(loc, newValueType.value(), value); |
| } |
| rewriter.replaceOpWithNewOp<DispatchTensorStoreOp>( |
| storeOp, value, storeOp.getTarget(), storeOp.getTargetDims(), |
| mixedOffsets, mixedSizes, mixedStrides); |
| return success(); |
| } |
| }; |
| |
| } // namespace |
| |
| void DispatchTensorStoreOp::getCanonicalizationPatterns( |
| RewritePatternSet &results, MLIRContext *context) { |
| results.insert< |
| DispatchTensorStoreOpWithOffsetSizesAndStridesConstantArgumentFolder, |
| FoldCastOpIntoDispatchStoreOp, ReuseDispatchTensorStoreShapeDims>( |
| context); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Tensor ops |
| //===----------------------------------------------------------------------===// |
| |
| /// Reduces the provided multidimensional index into a flattened 1D row-major |
| /// index. The |type| is expected to be statically shaped (as all constants |
| /// are). |
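| /// For example, for a tensor<2x3xf32> the index [1, 2] flattens to |
| /// 1 * 3 + 2 = 5. |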
| static uint64_t getFlattenedIndex(ShapedType type, ArrayRef<uint64_t> index) { |
| assert(type.hasStaticShape() && "for use on statically shaped types only"); |
| auto rank = type.getRank(); |
| auto shape = type.getShape(); |
| uint64_t valueIndex = 0; |
| uint64_t dimMultiplier = 1; |
| for (int i = rank - 1; i >= 0; --i) { |
| valueIndex += index[i] * dimMultiplier; |
| dimMultiplier *= shape[i]; |
| } |
| return valueIndex; |
| } |
| |
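| // Returns true when the two shapes are provably equal: either both static |
| // and identical, or matching dim-for-dim with dynamic dims backed by the |
| // same SSA values. Conservatively returns false when equal dynamic extents |
| // hide behind different SSA values. |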
| static bool compareShapesEqual(ShapedType lhsType, ValueRange lhsDynamicDims, |
| ShapedType rhsType, ValueRange rhsDynamicDims) { |
| if (lhsType.hasStaticShape() && rhsType.hasStaticShape() && |
| lhsType == rhsType) { |
| // Static shape equivalence means we can fast-path the check. |
| return true; |
| } |
| if (lhsType.getRank() != rhsType.getRank()) { |
| return false; |
| } |
| unsigned dynamicDimIndex = 0; |
| for (unsigned i = 0; i < lhsType.getRank(); ++i) { |
| if (lhsType.isDynamicDim(i) != rhsType.isDynamicDim(i)) { |
| // Static/dynamic dimension mismatch - definitely differ. |
| return false; |
| } else if (lhsType.isDynamicDim(i)) { |
| unsigned j = dynamicDimIndex++; |
| if (lhsDynamicDims[j] != rhsDynamicDims[j]) { |
| // Dynamic dimensions with different SSA values - probably differ. |
| return false; |
| } |
| } else { |
| if (lhsType.getDimSize(i) != rhsType.getDimSize(i)) { |
| // Static dimensions differ. |
| return false; |
| } |
| } |
| } |
| return true; |
| } |
| |
| OpFoldResult TensorConstantOp::fold(ArrayRef<Attribute> operands) { |
| auto dynamicType = getType(); |
| if (dynamicType.getNumDynamicDims() == 0) { |
| return getValue(); |
| } |
| return {}; |
| } |
| |
| namespace { |
| |
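| // Expands a partially dynamic flow.tensor.constant into a static |
| // arith.constant reshaped back to the dynamic type, wrapping the dim values |
| // in util.do_not_optimize so later folding cannot re-staticize them, e.g. |
| // (illustrative IR): |
| // %t = flow.tensor.constant dense<...> : tensor<2x2xf32> -> tensor<?x2xf32> |
| // becomes |
| // %cst = arith.constant dense<...> : tensor<2x2xf32> |
| // %d0 = util.do_not_optimize(%c2) : index |
| // %t = flow.tensor.reshape %cst : tensor<2x2xf32> -> tensor<?x2xf32>{%d0} |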
| struct ExpandDynamicShapeConstant : public OpRewritePattern<TensorConstantOp> { |
| using OpRewritePattern<TensorConstantOp>::OpRewritePattern; |
| LogicalResult matchAndRewrite(TensorConstantOp op, |
| PatternRewriter &rewriter) const override { |
| auto constantOp = |
| rewriter.create<arith::ConstantOp>(op.getLoc(), op.getValue()); |
| auto dynamicType = op.getType(); |
| auto staticType = constantOp.getType().cast<ShapedType>(); |
| SmallVector<Value> dynamicDims; |
| for (int64_t i = 0; i < dynamicType.getRank(); ++i) { |
| // Only materialize values for the dynamic dimensions; static dims are |
| // carried by the result type. Walk all dims as the dynamic ones may not |
| // be leading. |
| if (!dynamicType.isDynamicDim(i)) continue; |
| auto dimValue = rewriter |
| .create<arith::ConstantIndexOp>( |
| op.getLoc(), staticType.getDimSize(i)) |
| .getResult(); |
| dynamicDims.push_back( |
| rewriter.create<IREE::Util::DoNotOptimizeOp>(op.getLoc(), dimValue) |
| .getResult(0)); |
| } |
| rewriter.replaceOpWithNewOp<IREE::Flow::TensorReshapeOp>( |
| op, dynamicType, constantOp.getResult(), dynamicDims); |
| return success(); |
| } |
| }; |
| |
| } // namespace |
| |
| void TensorConstantOp::getCanonicalizationPatterns(RewritePatternSet &results, |
| MLIRContext *context) { |
| results.insert<ExpandDynamicShapeConstant>(context); |
| } |
| |
| OpFoldResult TensorTieShapeOp::fold(ArrayRef<Attribute> operands) { |
| if (getDynamicDims().empty()) { |
| return getOperand(); |
| } |
| return {}; |
| } |
| |
| void TensorTieShapeOp::getCanonicalizationPatterns(RewritePatternSet &results, |
| MLIRContext *context) { |
| results.insert<ReplaceOpIfTensorOperandEmpty<TensorTieShapeOp, 0>>(context); |
| } |
| |
| OpFoldResult TensorReshapeOp::fold(ArrayRef<Attribute> operands) { |
| auto sourceType = getSource().getType().cast<ShapedType>(); |
| auto resultType = getResult().getType().cast<ShapedType>(); |
| if (compareShapesEqual(sourceType, getSourceDims(), resultType, |
| getResultDims())) { |
| // Shapes match and this is a no-op so just fold to the source. |
| return getSource(); |
| } |
| return {}; |
| } |
| |
| namespace { |
| |
| // Flatten a chain of reshapes (reshape feeding into reshape) such that a |
| // reshape only ever pulls from a non-reshape source. This prevents big useless |
| // chains and makes it easier to track the original storage for the tensor. |
| struct FlattenTensorReshapeChain : public OpRewritePattern<TensorReshapeOp> { |
| using OpRewritePattern<TensorReshapeOp>::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp, |
| PatternRewriter &rewriter) const override { |
| auto sourceOp = dyn_cast_or_null<TensorReshapeOp>( |
| reshapeOp.getSource().getDefiningOp()); |
| if (!sourceOp) return failure(); |
| |
| // We want the same result value/shape but to source from the ancestor. We |
| // need to pull any dynamic dims from that as we don't care about the |
| // intermediate reshapes. |
| rewriter.replaceOpWithNewOp<TensorReshapeOp>( |
| reshapeOp, reshapeOp.getResult().getType(), sourceOp.getSource(), |
| sourceOp.getSourceDims(), reshapeOp.getResultDims()); |
| return success(); |
| } |
| }; |
| |
| // Replaces a `flow.tensor.load` that reads from a `flow.tensor.splat` with |
| // the primitive value given to the splat op. |
| struct FoldSplatLoadIntoPrimitive : public OpRewritePattern<TensorLoadOp> { |
| using OpRewritePattern<TensorLoadOp>::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(TensorLoadOp loadOp, |
| PatternRewriter &rewriter) const override { |
| auto sourceOp = |
| dyn_cast_or_null<TensorSplatOp>(loadOp.getSource().getDefiningOp()); |
| |
| if (!sourceOp) return failure(); |
| |
| rewriter.replaceOp(loadOp, sourceOp.getValue()); |
| return success(); |
| } |
| }; |
| |
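| // Folds a reshape of a single-use splat into a splat of the reshaped type, |
| // e.g. (illustrative): |
| // %s = flow.tensor.splat %v : tensor<16xf32> |
| // %r = flow.tensor.reshape %s : tensor<16xf32> -> tensor<4x4xf32> |
| // becomes |
| // %r = flow.tensor.splat %v : tensor<4x4xf32> |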
| struct FoldSplatReshapeIntoSplat : public OpRewritePattern<TensorSplatOp> { |
| using OpRewritePattern<TensorSplatOp>::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(TensorSplatOp splatOp, |
| PatternRewriter &rewriter) const override { |
| if (!splatOp.getResult().hasOneUse()) return failure(); |
| |
| auto reshapeOp = dyn_cast_or_null<TensorReshapeOp>( |
| splatOp.getResult().use_begin()->getOwner()); |
| if (!reshapeOp) return failure(); |
| |
| rewriter.replaceOpWithNewOp<TensorSplatOp>( |
| reshapeOp, reshapeOp.getResult().getType(), splatOp.getValue(), |
| reshapeOp.getResultDims()); |
| rewriter.eraseOp(splatOp); |
| |
| return success(); |
| } |
| }; |
| |
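| // Resolves tensor.rank ops to a constant index derived from the statically |
| // known rank of the shaped type. |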
| struct ResolveShapedRank : public OpRewritePattern<tensor::RankOp> { |
| using OpRewritePattern<tensor::RankOp>::OpRewritePattern; |
| LogicalResult matchAndRewrite(tensor::RankOp op, |
| PatternRewriter &rewriter) const override { |
| auto shapedType = op.getTensor().getType().cast<ShapedType>(); |
| rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, |
| shapedType.getRank()); |
| return success(); |
| } |
| }; |
| |
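| // Resolves tensor.dim ops either to a constant (static dims) or to the |
| // matching dynamic dim SSA value found in scope. |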
| struct ResolveShapedDim : public OpRewritePattern<tensor::DimOp> { |
| using OpRewritePattern<tensor::DimOp>::OpRewritePattern; |
| LogicalResult matchAndRewrite(tensor::DimOp op, |
| PatternRewriter &rewriter) const override { |
| if (!op.getConstantIndex().has_value()) { |
| return rewriter.notifyMatchFailure( |
| op, "non-constant index dim ops are unsupported"); |
| } |
| auto idx = op.getConstantIndex().value(); |
| |
| auto shapedType = op.getSource().getType().cast<ShapedType>(); |
| if (!shapedType.isDynamicDim(idx)) { |
| rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>( |
| op, shapedType.getDimSize(idx)); |
| return success(); |
| } |
| |
| auto dynamicDims = IREE::Util::findDynamicDims( |
| op.getSource(), op->getBlock(), Block::iterator(op.getOperation())); |
| if (!dynamicDims.has_value()) { |
| return rewriter.notifyMatchFailure(op, "no dynamic dims found/usable"); |
| } |
| unsigned dimOffset = 0; |
| for (unsigned i = 0; i < idx; ++i) { |
| if (shapedType.isDynamicDim(i)) ++dimOffset; |
| } |
| rewriter.replaceOp(op, dynamicDims.value()[dimOffset]); |
| |
| return success(); |
| } |
| }; |
| |
| } // namespace |
| |
| void TensorReshapeOp::getCanonicalizationPatterns(RewritePatternSet &results, |
| MLIRContext *context) { |
| results.insert<ReplaceOpIfTensorOperandEmpty<TensorReshapeOp, 0>>(context); |
| results.insert<ReplaceOpIfTensorResultEmpty<TensorReshapeOp, 0>>(context); |
| results.insert<FlattenTensorReshapeChain>(context); |
| results.insert<ResolveShapedRank>(context); |
| results.insert<ResolveShapedDim>(context); |
| } |
| |
| void TensorLoadOp::getCanonicalizationPatterns(RewritePatternSet &results, |
| MLIRContext *context) { |
| results.insert<FoldSplatLoadIntoPrimitive>(context); |
| } |
| |
| OpFoldResult TensorLoadOp::fold(ArrayRef<Attribute> operands) { |
| if (auto source = operands[0].dyn_cast_or_null<ElementsAttr>()) { |
| // Load directly from the constant source tensor. |
| auto indices = operands.drop_front(); |
| if (llvm::count(indices, nullptr) == 0) { |
| return source.getValues<Attribute>()[llvm::to_vector<4>( |
| llvm::map_range(indices, [](Attribute value) { |
| return value.cast<IntegerAttr>().getValue().getZExtValue(); |
| }))]; |
| } |
| } |
| return {}; |
| } |
| |
| OpFoldResult TensorStoreOp::fold(ArrayRef<Attribute> operands) { |
| if (!operands[0]) return {}; |
| auto &value = operands[0]; |
| if (auto target = operands[1].dyn_cast_or_null<ElementsAttr>()) { |
| // Store into the constant target tensor. |
| if (target.getType().getRank() == 0) { |
| return DenseElementsAttr::get(target.getType(), {value}); |
| } |
| auto indices = operands.drop_front(2); |
| if (llvm::count(indices, nullptr) == 0) { |
| uint64_t offset = getFlattenedIndex( |
| target.getType(), |
| llvm::to_vector<4>(llvm::map_range(indices, [](Attribute value) { |
| return value.cast<IntegerAttr>().getValue().getZExtValue(); |
| }))); |
| SmallVector<Attribute, 16> newContents(target.getValues<Attribute>()); |
| newContents[offset] = value; |
| return DenseElementsAttr::get(target.getType(), newContents); |
| } |
| } |
| return {}; |
| } |
| |
| void TensorEmptyOp::getCanonicalizationPatterns(RewritePatternSet &results, |
| MLIRContext *context) { |
| // TODO(benvanik): fold static shapes into dims. |
| } |
| |
| void TensorSplatOp::getCanonicalizationPatterns(RewritePatternSet &results, |
| MLIRContext *context) { |
| // TODO(benvanik): canonicalize splat+slice to smaller splat. |
| results.insert<ReplaceOpIfTensorResultEmpty<TensorSplatOp, 0>>(context); |
| results.insert<FoldSplatReshapeIntoSplat>(context); |
| } |
| |
| OpFoldResult TensorCloneOp::fold(ArrayRef<Attribute> operands) { |
| if (operands[0]) { |
| // Constants always fold. |
| return operands[0]; |
| } |
| |
| // TODO(benvanik): elide clones when safe to do so. Right now clone is |
| // load-bearing to work around our lack of cross-stream scheduling. Clones are |
| // inserted to avoid mutating function arguments and any logic we perform here |
| // (without *also* checking all the conditions that may insert a clone) will |
| // just fight. |
| // |
| // Once the clones are not load-bearing we can remove them in all the normal |
| // cases (one user, no intervening uses between clone and consumers of |
| // operands, etc). |
| |
| return {}; |
| } |
| |
| void TensorCloneOp::getCanonicalizationPatterns(RewritePatternSet &results, |
| MLIRContext *context) { |
| results.insert<ReplaceOpIfTensorOperandEmpty<TensorCloneOp, 0>>(context); |
| } |
| |
| // Slices |tensor| along |dim| from |start| (inclusive) to start + length |
| // (exclusive). |
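| // For example (illustrative), slicing dim 1 of a row-major 2x4 tensor with |
| // start=1, length=2 copies flat elements [1, 2] and [5, 6]. |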
| static ElementsAttr tensorSlice(ElementsAttr tensor, uint64_t dim, |
| uint64_t start, uint64_t length) { |
| auto shape = llvm::to_vector<4>(tensor.getType().getShape()); |
| if (length == shape[dim]) { |
| // No need to slice. |
| return tensor; |
| } |
| auto outputShape = shape; |
| outputShape[dim] = length; |
| auto outputType = |
| RankedTensorType::get(outputShape, getElementTypeOrSelf(tensor)); |
| llvm::SmallVector<Attribute, 4> newContents; |
| newContents.reserve(outputType.getNumElements()); |
| auto valuesBegin = tensor.getValues<Attribute>().begin(); |
| int64_t step = |
| std::accumulate(shape.rbegin(), shape.rbegin() + shape.size() - dim, |
| /*init=*/1, /*op=*/std::multiplies<int64_t>()); |
| int64_t num = length * step / shape[dim]; |
| for (int64_t offset = step / shape[dim] * start, |
| numElements = tensor.getType().getNumElements(); |
| offset < numElements; offset += step) { |
| newContents.append(valuesBegin + offset, valuesBegin + offset + num); |
| } |
| return DenseElementsAttr::get(outputType, newContents); |
| } |
| |
| OpFoldResult TensorSliceOp::fold(ArrayRef<Attribute> operands) { |
| if (llvm::count(operands, nullptr) == 0) { |
| // Fully constant arguments so we can perform the slice here. |
| auto tensor = operands[0].cast<ElementsAttr>(); |
| int64_t rank = getSource().getType().cast<ShapedType>().getRank(); |
| // start = operands[1:1+rank), and length = operands[1+rank:]. |
| auto start = llvm::to_vector<4>(llvm::map_range( |
| operands.drop_front(1).drop_back(rank), [](Attribute value) { |
| return value.cast<IntegerAttr>().getValue().getZExtValue(); |
| })); |
| auto length = llvm::to_vector<4>( |
| llvm::map_range(operands.drop_front(1 + rank), [](Attribute value) { |
| return value.cast<IntegerAttr>().getValue().getZExtValue(); |
| })); |
| for (int64_t dim = 0; dim < rank; ++dim) { |
| tensor = tensorSlice(tensor, dim, start[dim], length[dim]); |
| } |
| return tensor; |
| } |
| return {}; |
| } |
| |
| void TensorSliceOp::getCanonicalizationPatterns(RewritePatternSet &results, |
| MLIRContext *context) { |
| // TODO(benvanik): canonicalize multiple slices (traverse upward through ssa). |
| results.insert<ReplaceOpIfTensorOperandEmpty<TensorSliceOp, 0>>(context); |
| results.insert<ReplaceOpIfTensorResultEmpty<TensorSliceOp, 0>>(context); |
| } |
| |
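| // Returns a copy of |target| with |update| written at the multi-dimensional |
| // offset |startIndicesAttrs|, walking the update elements in row-major |
| // order. |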
| static ElementsAttr tensorUpdate(ElementsAttr update, ElementsAttr target, |
| ArrayRef<Attribute> startIndicesAttrs) { |
| auto updateType = update.getType().cast<ShapedType>(); |
| auto targetType = target.getType().cast<ShapedType>(); |
| // If either target or update has zero element, then no update happens. |
| if (updateType.getNumElements() == 0 || targetType.getNumElements() == 0) { |
| return target; |
| } |
| |
| int64_t rank = targetType.getRank(); |
| // If target is scalar, update is also scalar and is the new content. |
| if (rank == 0) { |
| return update; |
| } |
| |
| auto startIndex = llvm::to_vector<4>( |
| llvm::map_range(startIndicesAttrs, [](Attribute value) { |
| return value.cast<IntegerAttr>().getValue().getZExtValue(); |
| })); |
| auto targetValues = llvm::to_vector<4>(target.getValues<Attribute>()); |
| // Target indices start from |startIndicesAttrs| and update indices start |
| // from all zeros. |
| llvm::SmallVector<uint64_t, 4> targetIndex(startIndex); |
| llvm::SmallVector<uint64_t, 4> updateIndex(rank, 0); |
| int64_t numElements = updateType.getNumElements(); |
| while (numElements--) { |
| targetValues[getFlattenedIndex(targetType, targetIndex)] = |
| update.getValues<Attribute>()[updateIndex]; |
| // Increment the index at last dim. |
| ++updateIndex.back(); |
| ++targetIndex.back(); |
| // If the index in dim j exceeds dim size, reset dim j and |
| // increment dim (j-1). |
| for (int64_t j = rank - 1; |
| j >= 0 && updateIndex[j] >= updateType.getDimSize(j); --j) { |
| updateIndex[j] = 0; |
| targetIndex[j] = startIndex[j]; |
| if (j - 1 >= 0) { |
| ++updateIndex[j - 1]; |
| ++targetIndex[j - 1]; |
| } |
| } |
| } |
| return DenseElementsAttr::get(targetType, targetValues); |
| } |
| |
| OpFoldResult TensorUpdateOp::fold(ArrayRef<Attribute> operands) { |
| auto targetIndex = getODSOperandIndexAndLength(0).first; |
| auto startIndices = getODSOperandIndexAndLength(2); |
| auto updateIndex = getODSOperandIndexAndLength(3).first; |
| auto indices = operands.slice(startIndices.first, startIndices.second); |
| bool allIndicesConstant = llvm::count(indices, nullptr) == 0; |
| if (operands[updateIndex] && operands[targetIndex] && allIndicesConstant) { |
| // Fully constant arguments so we can perform the update here. |
| return tensorUpdate(operands[updateIndex].cast<ElementsAttr>(), |
| operands[targetIndex].cast<ElementsAttr>(), indices); |
| } else { |
| // Replace the entire tensor when the sizes match. |
| auto updateType = getUpdate().getType().cast<ShapedType>(); |
| auto targetType = getTarget().getType().cast<ShapedType>(); |
| if (updateType.hasStaticShape() && targetType.hasStaticShape() && |
| updateType == targetType) { |
| return getUpdate(); |
| } |
| } |
| return {}; |
| } |
| |
| namespace { |
| |
| // When the target tensor is a result of a tensor.cast operation, the op needs |
| // to be updated to use the source of the cast as the target tensor. |
| struct FoldTensorUpdateOpWithCasts : public OpRewritePattern<TensorUpdateOp> { |
| using OpRewritePattern<TensorUpdateOp>::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(TensorUpdateOp updateOp, |
| PatternRewriter &rewriter) const override { |
| auto targetCastOp = updateOp.getTarget().getDefiningOp<tensor::CastOp>(); |
| auto updateCastOp = updateOp.getUpdate().getDefiningOp<tensor::CastOp>(); |
| if (!targetCastOp && !updateCastOp) return failure(); |
| auto target = |
| (targetCastOp ? targetCastOp.getSource() : updateOp.getTarget()); |
| auto update = |
| (updateCastOp ? updateCastOp.getSource() : updateOp.getUpdate()); |
| auto newOp = rewriter.create<TensorUpdateOp>( |
| updateOp.getLoc(), target.getType(), target, |
| refreshDimsOnTypeChange(updateOp, updateOp.getTarget().getType(), |
| target.getType(), updateOp.getTargetDims(), |
| rewriter), |
| updateOp.getStartIndices(), update, |
| refreshDimsOnTypeChange(updateOp, updateOp.getUpdate().getType(), |
| update.getType(), updateOp.getUpdateDims(), |
| rewriter), |
| updateOp.getTiedOperandsAttr()); |
| rewriter.replaceOpWithNewOp<tensor::CastOp>( |
| updateOp, updateOp.getResult().getType(), newOp.getResult()); |
| return success(); |
| } |
| }; |
| |
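| // Folds a flow.tensor.update whose update operand is provably empty into |
| // its target tensor, as writing an empty tensor is a no-op. |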
| struct ReplaceOpIfTensorUpdateOperandEmpty |
| : public OpRewritePattern<TensorUpdateOp> { |
| using OpRewritePattern<TensorUpdateOp>::OpRewritePattern; |
| LogicalResult matchAndRewrite(TensorUpdateOp op, |
| PatternRewriter &rewriter) const override { |
| auto operand = op.getUpdate(); |
| if (!isTensorOperandEmpty(operand)) return failure(); |
| rewriter.replaceOp(op, op.getTarget()); |
| return success(); |
| } |
| }; |
| |
| } // namespace |
| |
| void TensorUpdateOp::getCanonicalizationPatterns(RewritePatternSet &results, |
| MLIRContext *context) { |
| results.insert<FoldTensorUpdateOpWithCasts>(context); |
| // target: |
| results.insert<ReplaceOpIfTensorOperandEmpty<TensorUpdateOp, 0>>(context); |
| // update: |
| results.insert<ReplaceOpIfTensorUpdateOperandEmpty>(context); |
| } |
| |
| } // namespace Flow |
| } // namespace IREE |
| } // namespace iree_compiler |
| } // namespace mlir |