blob: d472592533c03b9f97700e02145c0d13c66446c9 [file] [log] [blame]
// Copyright 2020 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "iree/compiler/Dialect/Shape/IR/Builders.h"
#include "iree/compiler/Dialect/Shape/IR/ShapeInterface.h"
#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
#include "iree/compiler/Dialect/Shape/IR/ShapeTypes.h"
#include "iree/compiler/Dialect/Shape/Plugins/XLA/XlaHloShapeBuilder.h"
#include "iree/compiler/Dialect/Shape/Transforms/Passes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Pass/PassRegistry.h"
namespace mlir {
namespace iree_compiler {
namespace Shape {
namespace {
// Expands a RankedBroadcastShapeOp into a MakeRankedShapeOp whose dynamic
// dimensions are read from the lhs/rhs operand shapes with RankedDimOp.
//
// For every dynamic dimension of the result shape, the dimension is taken from
// the lhs operand if the lhs maps a dynamic dimension onto it, otherwise from
// the rhs operand. Any involvement of a static operand dimension in a dynamic
// result dimension is reported with a remark as not implemented.
//
// Returns the new shape value, or a null Value if the broadcast cannot be
// expanded. The work is split into three phases (record, validate,
// materialize) so that no operation is created unless the expansion is known
// to succeed: callers turn a null result into a pattern failure, and a failed
// pattern should not leave newly created ops behind.
Value rewriteShapexRankedBroadcastShape(RankedBroadcastShapeOp bcastOp,
                                        OpBuilder &builder) {
  auto lhsRs = bcastOp.lhs().getType().cast<RankedShapeType>();
  auto rhsRs = bcastOp.rhs().getType().cast<RankedShapeType>();

  auto loc = bcastOp.getLoc();
  auto resultRs = bcastOp.getResult().getType().cast<RankedShapeType>();
  auto dimType = resultRs.getDimType();

  // What a single operand contributes to one result dimension.
  struct OperandDim {
    // Unset if the operand does not map onto this result dim, -1 if the
    // mapped operand dim is dynamic, otherwise the recorded static value.
    Optional<int> staticDim;
    // Index of the operand dim to read at runtime; set only when dynamic.
    Optional<int64_t> dynamicInputDim;
  };
  SmallVector<OperandDim, 4> lhsDims;
  SmallVector<OperandDim, 4> rhsDims;
  lhsDims.resize(resultRs.getRank());
  rhsDims.resize(resultRs.getRank());

  // Phase 1: record, per result dim, where each operand's contribution comes
  // from. Nothing is created in the IR here.
  auto populateDims = [&](RankedShapeType operandRs, auto &&broadcastDims,
                          SmallVector<OperandDim, 4> &dims) {
    for (auto dimMap : llvm::enumerate(broadcastDims)) {
      auto inputDimIndex = dimMap.index();
      auto outputDimIndex = dimMap.value().getZExtValue();
      assert(outputDimIndex < dims.size());
      if (!resultRs.isDimDynamic(outputDimIndex)) {
        // No need to populate fully static dimensions.
        continue;
      }

      OperandDim operandDim;
      if (operandRs.isDimDynamic(inputDimIndex)) {
        operandDim.staticDim = -1;
        operandDim.dynamicInputDim = static_cast<int64_t>(inputDimIndex);
      } else {
        // NOTE(review): this records the operand dim *index*, while the
        // validation below compares the value as if it were the dim size.
        // Kept as-is to preserve existing behavior; every static case is
        // rejected as not implemented either way. Confirm intent.
        operandDim.staticDim = static_cast<int>(inputDimIndex);
      }
      dims[outputDimIndex] = operandDim;
    }
  };
  populateDims(lhsRs, bcastOp.lhs_broadcast_dimensions(), lhsDims);
  populateDims(rhsRs, bcastOp.rhs_broadcast_dimensions(), rhsDims);

  // Phase 2: validate every dynamic result dim and choose the operand value
  // and operand dim index that will provide it.
  SmallVector<std::pair<Value, int64_t>, 4> selectedDims;
  for (int64_t i = 0, e = resultRs.getRank(); i < e; ++i) {
    if (!resultRs.isDimDynamic(i)) continue;
    const OperandDim &lhsDim = lhsDims[i];
    const OperandDim &rhsDim = rhsDims[i];
    int lhsDimSize = lhsDim.staticDim ? *lhsDim.staticDim : -1;
    int rhsDimSize = rhsDim.staticDim ? *rhsDim.staticDim : -1;

    if (lhsDimSize > 1) {
      // Non-degenerate static.
      bcastOp.emitRemark(
          "broadcast of non-degenerate lhs static dim not implemented");
      return nullptr;
    }
    if (rhsDimSize > 1) {
      // Non-degenerate static.
      bcastOp.emitRemark(
          "broadcast of non-degenerate rhs static dim not implemented");
      return nullptr;
    }
    if (lhsDimSize == 1) {
      // Degenerate static.
      bcastOp.emitRemark(
          "broadcast of degenerate lhs static dim not implemented");
      return nullptr;
    }
    if (rhsDimSize == 1) {
      // Degenerate static.
      bcastOp.emitRemark(
          "broadcast of degenerate rhs static dim not implemented");
      return nullptr;
    }

    // Dynamic.
    // TODO: Generate code to assert.
    if (lhsDim.dynamicInputDim) {
      selectedDims.emplace_back(bcastOp.lhs(), *lhsDim.dynamicInputDim);
    } else if (rhsDim.dynamicInputDim) {
      selectedDims.emplace_back(bcastOp.rhs(), *rhsDim.dynamicInputDim);
    } else {
      // Neither operand supplies a runtime value for this dim; give up
      // without a remark.
      return nullptr;
    }
  }

  // Phase 3: the expansion is known to succeed, so it is now safe to create
  // the dim reads and the result shape.
  SmallVector<Value, 4> dynamicDims;
  for (auto &selected : selectedDims) {
    Value dimValue = builder.create<RankedDimOp>(
        loc, dimType, selected.first,
        builder.getI64IntegerAttr(selected.second));
    dynamicDims.push_back(dimValue);
  }
  return builder.create<MakeRankedShapeOp>(loc, resultRs, dynamicDims);
}
// Rewrite pattern that replaces a RankedBroadcastShapeOp with the value
// produced by rewriteShapexRankedBroadcastShape, when that expansion succeeds.
class ExpandRankedBroadcastShapePattern
    : public OpRewritePattern<RankedBroadcastShapeOp> {
 public:
  using OpRewritePattern::OpRewritePattern;

  // Succeeds only if the helper produced a non-null replacement value.
  LogicalResult matchAndRewrite(RankedBroadcastShapeOp bcastOp,
                                PatternRewriter &rewriter) const override {
    if (Value expandedShape =
            rewriteShapexRankedBroadcastShape(bcastOp, rewriter)) {
      rewriter.replaceOp(bcastOp, expandedShape);
      return success();
    }
    return failure();
  }
};
// Fallback pattern that resolves a GetRankedShapeOp at runtime by reading each
// dynamic dimension of the operand with a DimOp and packing the results into a
// MakeRankedShapeOp.
//
// Registered with benefit 1; the compile-time materialization pattern in this
// file registers with benefit 10.
class MaterializeRunTimeRankedShapePattern
    : public OpRewritePattern<Shape::GetRankedShapeOp> {
 public:
  MaterializeRunTimeRankedShapePattern(MLIRContext *context)
      : OpRewritePattern(context, 1) {}

  LogicalResult matchAndRewrite(GetRankedShapeOp getShapeOp,
                                PatternRewriter &rewriter) const override {
    auto shapeType = getShapeOp.shape().getType().dyn_cast<RankedShapeType>();
    // dyn_cast yields a null type on mismatch; bail out before touching the IR
    // rather than dereferencing it below.
    if (!shapeType) return failure();

    // One DimOp per dynamic dimension, in dimension order.
    SmallVector<Value, 4> dynamicDims;
    for (int64_t i = 0, e = shapeType.getRank(); i < e; ++i) {
      if (!shapeType.isDimDynamic(i)) continue;
      dynamicDims.push_back(
          rewriter.create<DimOp>(getShapeOp.getLoc(), getShapeOp.operand(), i));
    }

    // TODO(laurenzo): Remove once further along (it is fine to be unsupported
    // as it will fall back to generic), but in these early phases, it is
    // extremely useful to be chatty about this fallback.
    auto inputOperation = getShapeOp.operand().getDefiningOp();
    if (inputOperation) {
      inputOperation->emitRemark()
          << "unable to materialize shape calculation (unsupported op '"
          << inputOperation->getName() << "'?): falling back to runtime "
          << "resolution";
    }

    rewriter.replaceOpWithNewOp<MakeRankedShapeOp>(getShapeOp, shapeType,
                                                   dynamicDims);
    return success();
  }
};
// Pattern (benefit 10) that resolves a GetRankedShapeOp without runtime
// queries: fully static operands fold to a ConstRankedShapeOp, and otherwise
// the shape is derived from the operation that defines the operand.
class MaterializeCompileTimeRankedShapePattern
    : public OpRewritePattern<Shape::GetRankedShapeOp> {
 public:
  // `customOpShapeBuilder` may be null, in which case only the trait-based
  // rewrite of the defining op is attempted.
  MaterializeCompileTimeRankedShapePattern(
      const CustomOpShapeBuilderList *customOpShapeBuilder,
      MLIRContext *context)
      : OpRewritePattern(context, 10),
        customOpShapeBuilder(customOpShapeBuilder) {}

  LogicalResult matchAndRewrite(GetRankedShapeOp getShapeOp,
                                PatternRewriter &rewriter) const override {
    // A ranked tensor with a static shape needs no computation at all.
    auto tensorType =
        getShapeOp.operand().getType().dyn_cast<RankedTensorType>();
    auto rankedShapeType =
        getShapeOp.shape().getType().dyn_cast<RankedShapeType>();
    const bool foldsToConstant =
        tensorType && rankedShapeType && tensorType.hasStaticShape();
    if (foldsToConstant) {
      rewriter.replaceOpWithNewOp<ConstRankedShapeOp>(getShapeOp,
                                                      rankedShapeType);
      return success();
    }

    // Otherwise the operand must be produced by an operation, and that
    // operation must not be a TieShapeOp.
    Operation *producer = getShapeOp.operand().getDefiningOp();
    const bool hasUsableProducer =
        producer && !llvm::isa<TieShapeOp>(producer);
    if (!hasUsableProducer) return failure();
    return rewriteInputOp(getShapeOp, producer, rewriter);
  }

 private:
  // Handles a GetRankedShapeOp whose operand is the result of another
  // operation. This is the primary supported case, as other rewrites should
  // have isolated function/block boundaries with TieShape ops.
  LogicalResult rewriteInputOp(GetRankedShapeOp getShapeOp,
                               Operation *inputOperation,
                               PatternRewriter &rewriter) const {
    // Ops declaring that operands and results share a shape (or a type).
    const bool sharesOperandShape =
        inputOperation->hasTrait<OpTrait::SameOperandsAndResultShape>() ||
        inputOperation->hasTrait<OpTrait::SameOperandsAndResultType>();
    if (sharesOperandShape) {
      return rewriteSameOperandsAndResultShape(getShapeOp, inputOperation,
                                               rewriter);
    }

    // Custom shapes: the first builder that yields a value wins.
    if (!customOpShapeBuilder) return failure();
    auto resultShape = getShapeOp.getRankedShape();
    for (auto &shapeBuilder : *customOpShapeBuilder) {
      Value customShape = shapeBuilder->buildRankedShape(
          resultShape, inputOperation, rewriter);
      if (!customShape) continue;
      rewriter.replaceOp(getShapeOp, customShape);
      return success();
    }
    return failure();
  }

  // Replaces the shape query with the value that
  // buildCastInputsToResultShape produces from the defining op's operands;
  // fails if that helper returns nothing.
  LogicalResult rewriteSameOperandsAndResultShape(
      GetRankedShapeOp getShapeOp, Operation *inputOperation,
      PatternRewriter &rewriter) const {
    llvm::SmallVector<Value, 4> producerOperands(
        inputOperation->getOperands());
    auto combinedShapeOp = buildCastInputsToResultShape(
        inputOperation->getLoc(), getShapeOp.getRankedShape(),
        producerOperands, rewriter);
    if (combinedShapeOp) {
      rewriter.replaceOp(getShapeOp, {combinedShapeOp});
      return success();
    }
    return failure();
  }

  const CustomOpShapeBuilderList *customOpShapeBuilder;
};
// Function pass that greedily applies the shape materialization patterns
// defined in this file together with the canonicalization patterns of the
// shape ops they interact with.
class MaterializeShapeCalculationsPass
    : public FunctionPass<MaterializeShapeCalculationsPass> {
 public:
  // Gets the CustomOpShapeBuilderList used for expanding shapes of custom ops.
  // The list is built once, on first call, and holds whatever
  // xla_hlo::populateXlaHloCustomOpShapeBuilder adds to it; the returned
  // pointer is never null.
  // TODO(laurenzo): Since it isn't clear yet whether we need this facility
  // long term (i.e. this should come from the ops themselves), we are just
  // hard-linking it here at the expense of a dependency problem. Decouple this
  // if the facility persists.
  const CustomOpShapeBuilderList *getCustomOpShapeBuilder() {
    static CustomOpShapeBuilderList globalBuilders = ([]() {
      CustomOpShapeBuilderList builders;
      xla_hlo::populateXlaHloCustomOpShapeBuilder(builders);
      return builders;
    })();
    return &globalBuilders;
  }

  void runOnFunction() override {
    OwningRewritePatternList patterns;
    // Always include certain canonicalizations that interop.
    CastCompatibleShapeOp::getCanonicalizationPatterns(patterns, &getContext());
    GetRankedShapeOp::getCanonicalizationPatterns(patterns, &getContext());
    MakeRankedShapeOp::getCanonicalizationPatterns(patterns, &getContext());
    RankedDimOp::getCanonicalizationPatterns(patterns, &getContext());
    TieShapeOp::getCanonicalizationPatterns(patterns, &getContext());

    // Materialization patterns defined in this file.
    patterns.insert<ExpandRankedBroadcastShapePattern>(&getContext());
    patterns.insert<MaterializeCompileTimeRankedShapePattern>(
        getCustomOpShapeBuilder(), &getContext());
    patterns.insert<MaterializeRunTimeRankedShapePattern>(&getContext());
    applyPatternsGreedily(getFunction(), patterns);
  }
};
} // namespace
} // namespace Shape
// Creates a new instance of the function pass that materializes shape
// calculations, i.e. greedily applies the rewrite patterns defined in this
// file. Ownership of the pass is transferred to the caller.
std::unique_ptr<OpPassBase<FuncOp>> createMaterializeShapeCalculationsPass() {
return std::make_unique<Shape::MaterializeShapeCalculationsPass>();
}
// Static registration of the pass under the argument
// "iree-shape-materialize-calculations" with a one-line description.
static PassRegistration<Shape::MaterializeShapeCalculationsPass> pass(
"iree-shape-materialize-calculations", "Materializes shape calculations.");
} // namespace iree_compiler
} // namespace mlir