blob: 30169fd264e2dd72a7d21c53361dd714bf234493 [file] [log] [blame]
// Copyright 2020 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "iree/compiler/Dialect/Shape/IR/ShapeDialect.h"
#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
#include "iree/compiler/Dialect/Shape/IR/ShapeTypes.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Traits.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Transforms/DialectConversion.h"
namespace mlir {
namespace iree_compiler {
namespace Shape {
namespace {
// This conversion is currently quite limited, such as not handling multiple
// basic blocks in general, due to doing a type conversion that the MLIR core
// conversion infra doesn't handle well.
//
// In particular, we convert `!shape.shape` to `!shapex.ranked_shape<...>`, but
// the contents of the `...` are context-dependent. Thus, one could say that
// this pass does a context-dependent type conversion.
//
// The current MLIR conversion infra doesn't handle context-dependent type
// conversions.
//
// I can see two solutions:
//
// 1. Extend the MLIR conversion infra to better support context-dependent type
// conversions. One way to do this would be for the conversion infra to convert
// blocks in RPO and use the type of the converted successor operand in a
// dominating predecessor as the type for the block argument when converting a
// block. A similar thing could be done with an RPO traversal of the callgraph.
// This algorithm wouldn't work in the presence of recursively dead cycles. And
// of course linkage boundaries cannot have a context-dependent type conversion
// (by definition).
//
// 2. Avoid needing to convert to !shapex.ranked_shape in the first place. This
// could be accomplished by generalizing !shape.shape to be able to support the
// use case of !shapex.ranked_shape. One important requirement here is that
// !shapex.ranked_shape models a partially-specified shape (hardcoded for the
// ranked case). !shape.shape could be extended to capture partially-specified
// shapes in the type, such as allowing `!shape.shape<*>` to model an unranked
// shape (which is the default; no information), `!shape.shape<?x?x5x?>` to
// model a rank-4 shape with dimension 2 being of extent 5, etc.
//
// Once we have this, we could do this lowering from generic !shape.shape to
// statically-known ranked shapes more progressively and treat it more like a
// type refinement algorithm.
//
// The main risk is that we are trying to shove too much stuff into the
// !shape.shape type. There's a risk that "progressive lowering" becomes "no
// clear boundaries" and we end up with code deep into the compiler continuously
// needing to double-check that the !shape.shape's at this point are in fact
// statically known to be ranked, or silently making that assumption and
// triggering assertions on verifier-valid IR. Pipelines and legalization
// targets could make these assertions not fire in practice, but it would
// be a maintenance burden.
// Lowers `shape.const_shape` to `shapex.const_ranked_shape`.
//
// The extents carried by the constant-shape attribute become the static
// dimensions of the resulting !shapex.ranked_shape type.
class ConvertConstShapeOp : public OpConversionPattern<shape::ConstShapeOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult matchAndRewrite(
      shape::ConstShapeOp op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    SmallVector<int64_t, 4> dims;
    for (APInt dim : op.shape()) {
      dims.push_back(dim.getZExtValue());
    }
    rewriter.replaceOpWithNewOp<ConstRankedShapeOp>(
        op, RankedShapeType::get(dims, rewriter.getContext()));
    return success();
  }
};
// Lowers `shape.shape_of` to `shapex.get_ranked_shape`.
//
// Requires the (converted) operand to have a RankedTensorType; otherwise the
// pattern does not apply.
class ConvertShapeOfOp : public OpConversionPattern<shape::ShapeOfOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult matchAndRewrite(
      shape::ShapeOfOp op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    auto tensorType = operands[0].getType().dyn_cast<RankedTensorType>();
    if (!tensorType) {
      return failure();
    }
    // Result is a ranked shape mirroring the tensor's static/dynamic dims.
    auto resultType =
        RankedShapeType::get(tensorType.getShape(), rewriter.getContext());
    // TODO(jpienaar): The following needs to be re-evaluated once the patch
    // train from 2020/07/23 integrates properly. This is required to make
    // it forward and backwards compatible. Also, tests need to be added once
    // upstream integrates (and this can be tested).
    // rewriter.replaceOpWithNewOp<Shape::GetRankedShapeOp>(op, resultType,
    // operands[0]);
    auto getRanked = rewriter.create<Shape::GetRankedShapeOp>(
        op.getLoc(), resultType, operands[0]);
    // For FromExtentTensorOp users, just forward the result from GetRanked.
    // NOTE(review): these users are rewired here instead of via their own
    // pattern; candidates are collected into a set first so the use list is
    // not mutated while being iterated.
    SmallPtrSet<Operation *, 2> toDelete;
    for (auto use : op.getOperation()->getUsers()) {
      if (isa<FromExtentTensorOp>(use)) {
        use->replaceAllUsesWith(getRanked);
        toDelete.insert(use);
      }
    }
    // Erase the forwarded users through the rewriter so the conversion
    // driver can track/roll back the deletions.
    for (Operation *use : toDelete) {
      rewriter.eraseOp(use);
    }
    rewriter.replaceOp(op.getOperation(), getRanked.getResult());
    return success();
  }
};
// Lowers `tensor.extract` on a ranked-shape value to `shapex.ranked_dim`.
//
// Only applies when the (converted) source is a !shapex.ranked_shape and the
// extraction index is a constant.
class ConvertTensorExtract : public OpConversionPattern<tensor::ExtractOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult matchAndRewrite(
      tensor::ExtractOp op, ArrayRef<Value> rawOperands,
      ConversionPatternRewriter &rewriter) const override {
    tensor::ExtractOpAdaptor adaptor(rawOperands);
    if (!adaptor.tensor().getType().isa<RankedShapeType>()) {
      return rewriter.notifyMatchFailure(op, "not acting on a ranked shape");
    }
    Value index = adaptor.indices().front();
    auto indexOp = dyn_cast_or_null<ConstantIndexOp>(index.getDefiningOp());
    if (!indexOp) {
      return rewriter.notifyMatchFailure(op, "extract index not constant");
    }
    rewriter.replaceOpWithNewOp<Shape::RankedDimOp>(
        op, rewriter.getIndexType(), adaptor.tensor(),
        indexOp.value().cast<IntegerAttr>());
    return success();
  }
};
// Lowers `shape.get_extent` to `shapex.ranked_dim`.
//
// Requires the extent index to be constant. Previously a non-constant index
// would dereference an empty Optional (assert/UB); now it is reported as a
// match failure so the driver can leave the op for another pattern.
class ConvertGetExtent : public OpConversionPattern<shape::GetExtentOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult matchAndRewrite(
      shape::GetExtentOp op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    auto constantDim = op.getConstantDim();
    if (!constantDim.hasValue()) {
      return rewriter.notifyMatchFailure(op, "requires constant dim");
    }
    rewriter.replaceOpWithNewOp<Shape::RankedDimOp>(
        op, rewriter.getIndexType(), operands[0], constantDim.getValue());
    return success();
  }
};
// Lowers `shape.from_extents` to `shapex.make_ranked_shape`.
//
// Constant extents are folded into the static dims of the result type (so
// they need no operand); non-constant extents become dynamic dims (-1) fed
// by the corresponding SSA values.
class ConvertFromExtents : public OpConversionPattern<shape::FromExtentsOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult matchAndRewrite(
      shape::FromExtentsOp op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    SmallVector<int64_t, 4> dims;
    SmallVector<Value, 4> dynamicDims;
    for (Value extent : operands) {
      IntegerAttr constExtent;
      if (matchPattern(extent, m_Constant(&constExtent))) {
        dims.push_back(constExtent.getValue().getSExtValue());
      } else {
        dims.push_back(-1);
        dynamicDims.push_back(extent);
      }
    }
    auto shapeType = RankedShapeType::get(dims, rewriter.getContext());
    auto makeOp = rewriter.create<Shape::MakeRankedShapeOp>(
        op.getLoc(), shapeType, dynamicDims);
    rewriter.replaceOp(op, makeOp.getResult());
    return success();
  }
};
// Lowers `shape.split_at` to a pair of `shapex.gather_extents` ops.
//
// Requires a constant split index. A negative index counts from the back,
// matching the shape dialect's semantics.
class ConvertSplitAtOp : public OpConversionPattern<shape::SplitAtOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult matchAndRewrite(
      shape::SplitAtOp op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    IntegerAttr splitAttr;
    if (!matchPattern(op.index(), m_Constant(&splitAttr))) {
      return rewriter.notifyMatchFailure(op, "requires constant `index`");
    }
    auto rank = operands[0].getType().cast<RankedShapeType>().getRank();
    int64_t splitPoint = splitAttr.getInt();
    if (splitPoint < 0) {
      splitPoint += rank;
    }
    auto loc = op.getLoc();
    // Emits a gather_extents selecting the half-open extent range
    // [begin, end).
    auto gatherRange = [&](int64_t begin, int64_t end) -> Value {
      auto dims = llvm::to_vector<4>(llvm::seq<int64_t>(begin, end));
      return rewriter.create<GatherExtentsOp>(
          loc, operands[0], rewriter.getI64TensorAttr(dims));
    };
    rewriter.replaceOp(
        op, {gatherRange(0, splitPoint), gatherRange(splitPoint, rank)});
    return success();
  }
};
// Lowers `shape.broadcast` to `shapex.ranked_broadcast_shape`.
//
// The lower-ranked operand is right-aligned against the higher-ranked one
// (numpy-style broadcasting), expressed via the two *_broadcast_dimensions
// attributes.
class ConvertBroadcastOp : public OpConversionPattern<shape::BroadcastOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult matchAndRewrite(
      shape::BroadcastOp op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    Value lhs = operands[0];
    Value rhs = operands[1];
    auto lhsType = lhs.getType().dyn_cast<RankedShapeType>();
    auto rhsType = rhs.getType().dyn_cast<RankedShapeType>();
    if (!lhsType || !rhsType) {
      return failure();
    }
    // Establish invariant that rank(lhs) <= rank(rhs).
    if (lhsType.getRank() > rhsType.getRank()) {
      std::swap(lhsType, rhsType);
      std::swap(lhs, rhs);
    }
    SmallVector<int64_t, 6> resultShape;
    // getBroadcastedShape returns false when the static extents are
    // incompatible; previously this was ignored, which could silently
    // produce a bogus result type.
    if (!OpTrait::util::getBroadcastedShape(lhsType.getAllDims(),
                                            rhsType.getAllDims(),
                                            resultShape)) {
      return rewriter.notifyMatchFailure(op, "shapes are not broadcastable");
    }
    auto resultType = RankedShapeType::get(resultShape, rewriter.getContext());
    auto iota = llvm::to_vector<4>(llvm::seq<int64_t>(0, rhsType.getRank()));
    Value broadcasted = rewriter.replaceOpWithNewOp<RankedBroadcastShapeOp>(
        op, resultType, lhs, rhs,
        /*lhs_broadcast_dimensions=*/
        rewriter.getI64TensorAttr(makeArrayRef(iota).drop_front(
            rhsType.getRank() - lhsType.getRank())),
        /*rhs_broadcast_dimensions=*/
        rewriter.getI64TensorAttr(iota));
    // For FromExtentTensorOp users, just forward the RankedShapeType result.
    for (Operation *user : op.getResult().getUsers()) {
      if (isa<Shape::FromExtentTensorOp>(user)) {
        rewriter.replaceOp(user, ValueRange{broadcasted});
      }
    }
    return success();
  }
};
// Lowers `shape.concat` to a single `shapex.gather_extents` that selects all
// extents of both operands in order.
class ConvertConcatOp : public OpConversionPattern<shape::ConcatOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult matchAndRewrite(
      shape::ConcatOp op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    auto lhsRank = operands[0].getType().cast<RankedShapeType>().getRank();
    auto rhsRank = operands[1].getType().cast<RankedShapeType>().getRank();
    auto allDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, lhsRank + rhsRank));
    rewriter.replaceOpWithNewOp<Shape::GatherExtentsOp>(
        op, ValueRange({operands[0], operands[1]}),
        rewriter.getI64TensorAttr(allDims));
    return success();
  }
};
// Lowers `shape.to_extent_tensor` to `shapex.to_extent_tensor`, preserving
// the original result type.
class ConvertToExtentTensorOp
    : public OpConversionPattern<shape::ToExtentTensorOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult matchAndRewrite(
      shape::ToExtentTensorOp op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    auto converted = rewriter.create<Shape::ToExtentTensorOp>(
        op.getLoc(), op.getType(), operands[0]);
    rewriter.replaceOp(op, converted.getResult());
    return success();
  }
};
// Lowers `shape.from_extent_tensor`.
//
// If the converted operand is still a ranked extent tensor, emit the shapex
// equivalent; if it has already become a ranked shape, the op is a no-op and
// is folded away.
class ConvertFromExtentTensorOp
    : public OpConversionPattern<shape::FromExtentTensorOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult matchAndRewrite(
      shape::FromExtentTensorOp op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    Value input = operands.front();
    Type inputType = input.getType();
    if (inputType.isa<RankedTensorType>()) {
      rewriter.replaceOpWithNewOp<Shape::FromExtentTensorOp>(op, input);
      return success();
    }
    if (inputType.isa<RankedShapeType>()) {
      rewriter.replaceOp(op, input);
      return success();
    }
    return failure();
  }
};
// Lowers shape-related `tensor.cast` ops.
//
// Upstream shape lowering can use tensor<?xindex> to represent a shape and
// inserts tensor_cast ops to reach specific extent tensor types. Not all
// tensor_cast ops are shape-related, so only casts whose operand has been
// converted to a ranked shape are handled here.
class ConvertTensorCastOp : public OpConversionPattern<tensor::CastOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult matchAndRewrite(
      tensor::CastOp op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    if (!operands[0].getType().isa<RankedShapeType>()) {
      return rewriter.notifyMatchFailure(op, "not a shape-related tensor_cast");
    }
    rewriter.replaceOpWithNewOp<Shape::ToExtentTensorOp>(op, op.getType(),
                                                         operands[0]);
    return success();
  }
};
// Module pass converting ops of the upstream `shape` dialect to the IREE
// `shapex` dialect via partial conversion.
class ConvertShapeToShapex
    : public PassWrapper<ConvertShapeToShapex, OperationPass<ModuleOp>> {
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<iree_compiler::ShapeDialect>();
  }

  void runOnOperation() override {
    ModuleOp module = getOperation();
    MLIRContext *context = &getContext();

    // Conversion target definition: everything from `shape` must be
    // converted; `shapex` ops are the legal result.
    ConversionTarget conversionTarget(*context);
    conversionTarget.addIllegalDialect<shape::ShapeDialect>();
    conversionTarget.addLegalDialect<iree_compiler::ShapeDialect>();

    // Patterns.
    OwningRewritePatternList patterns(&getContext());
    patterns.insert<ConvertConstShapeOp, ConvertShapeOfOp,
                    ConvertTensorExtract, ConvertGetExtent, ConvertFromExtents,
                    ConvertFromExtentTensorOp, ConvertSplitAtOp,
                    ConvertBroadcastOp, ConvertConcatOp,
                    ConvertToExtentTensorOp, ConvertTensorCastOp>(context);

    if (failed(applyPartialConversion(module, conversionTarget,
                                      std::move(patterns)))) {
      return signalPassFailure();
    }
  }
};
} // namespace
// Factory for the shape->shapex conversion pass (see ConvertShapeToShapex).
std::unique_ptr<OperationPass<ModuleOp>> createConvertShapeToShapexPass() {
  return std::make_unique<ConvertShapeToShapex>();
}
// Registers the pass under -convert-shape-to-shapex for mlir-opt-style tools.
static PassRegistration<ConvertShapeToShapex> registration(
    "convert-shape-to-shapex", "Convert `shape` dialect to `shapex` dialect");
} // namespace Shape
} // namespace iree_compiler
} // namespace mlir