// Copyright 2020 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <numeric>

#include "iree/compiler/Dialect/Shape/IR/ShapeDialect.h"
#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
#include "iree/compiler/Dialect/Shape/IR/ShapeTypes.h"
#include "iree/compiler/Dialect/Shape/Transforms/Passes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Transforms/DialectConversion.h"
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"

namespace mlir {
namespace iree_compiler {
namespace Shape {
namespace {

// Returns a 1-d i64 elements attribute populated with numbers from |start|
// to |end|, exclusive of |end|.
static DenseIntElementsAttr getI64ElementsAttrForSeq(int start, int end,
                                                     Builder &builder) {
  int size = end - start;
  SmallVector<int64_t, 4> vals;
  vals.resize(size);
  std::iota(vals.begin(), vals.end(), start);
  TensorType ty = RankedTensorType::get({size}, builder.getIntegerType(64));
  return DenseIntElementsAttr::get(ty, vals);
}

// Returns true if a given HLO binary elementwise op does not broadcast.
template <typename HloOpTy>
bool IsSameRankedTypeBinaryElementwiseOp(HloOpTy op) {
  if (op.broadcast_dimensions()) {
    // Has an inter-operand broadcast.
    return false;
  }
  auto lhsType = op.lhs().getType().template dyn_cast<RankedTensorType>();
  auto rhsType = op.rhs().getType().template dyn_cast<RankedTensorType>();
  if (!lhsType || !rhsType) return false;
  return lhsType == rhsType;
}

// Converts a broadcasted binary elementwise HLO op with dynamic shapes to
// use explicit broadcasting.
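//
// Sketch of the rewrite (illustrative IR, not verbatim): an
//   xla_hlo.add %lhs, %rhs {broadcast_dimensions = ...}
// with a dynamic result shape becomes a ranked broadcast of each operand to
// the common result shape, followed by a broadcast-free xla_hlo.add.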
template <typename HloOpTy>
class BroadcastedRankedBinaryElementwiseConversion
    : public OpConversionPattern<HloOpTy> {
  using OpConversionPattern<HloOpTy>::OpConversionPattern;
  LogicalResult matchAndRewrite(
      HloOpTy op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    auto lhs = operands[0];
    auto rhs = operands[1];
    auto lhsType = lhs.getType().dyn_cast<RankedTensorType>();
    auto rhsType = rhs.getType().dyn_cast<RankedTensorType>();
    auto resultType = op.getOperation()
                          ->getResultTypes()[0]
                          .template dyn_cast<RankedTensorType>();
    if (!lhsType || !rhsType || !resultType) {
      // This conversion only supports ranked tensors.
      return failure();
    }
    if (lhsType.hasStaticShape() && rhsType.hasStaticShape() &&
        resultType.hasStaticShape()) {
      // This temporary pass is only used for dynamically shaped elementwise
      // ops.
      return failure();
    }
    // Get the shapes of the operands. Note that we assume a prior shape
    // inference pass has appropriately specialized the shapes, so we use
    // them as-is rather than recomputing the broadcast.
    auto lhsShape = rewriter.create<GetRankedShapeOp>(op.getLoc(), lhs);
    auto rhsShape = rewriter.create<GetRankedShapeOp>(op.getLoc(), rhs);
    auto resultShapeType =
        RankedShapeType::get(resultType.getShape(), rewriter.getContext());
    auto resultShapeDims = resultShapeType.getAllDims();

    // Rank-broadcast as appropriate.
    Value broadcastedLhs = lhs;
    Value broadcastedRhs = rhs;
    DenseIntElementsAttr lhsBroadcastDims;
    DenseIntElementsAttr rhsBroadcastDims;
    if (op.broadcast_dimensions()) {
      auto lhsRank = lhsType.getRank();
      auto rhsRank = rhsType.getRank();
      auto higherRankBroadcastDims =
          getI64ElementsAttrForSeq(0, std::max(lhsRank, rhsRank), rewriter);
      if (lhsRank > rhsRank) {
        lhsBroadcastDims = higherRankBroadcastDims;
        rhsBroadcastDims = *op.broadcast_dimensions();
      } else if (rhsRank > lhsRank) {
        lhsBroadcastDims = *op.broadcast_dimensions();
        rhsBroadcastDims = higherRankBroadcastDims;
      } else {
        op.emitOpError() << "broadcast_dimensions implies rank broadcast "
                         << "but operands are of the same rank";
        return failure();
      }
    } else if (lhsType != rhsType) {
      op.emitOpError() << "degenerate broadcast of same-rank operands "
                       << "not yet implemented";
      return failure();
    }
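
    // Compute the broadcasted result shape from the two operand shapes.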
    auto resultShape = rewriter.create<RankedBroadcastShapeOp>(
        op.getLoc(), resultShapeType, lhsShape, rhsShape, lhsBroadcastDims,
        rhsBroadcastDims);
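
    // Broadcast each operand explicitly to the result shape.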
    broadcastedLhs = rewriter.create<RankedBroadcastInDimOp>(
        op.getLoc(),
        RankedTensorType::get(resultShapeDims, lhsType.getElementType()),
        broadcastedLhs, resultShape, lhsBroadcastDims);
    broadcastedRhs = rewriter.create<RankedBroadcastInDimOp>(
        op.getLoc(),
        RankedTensorType::get(resultShapeDims, rhsType.getElementType()),
        broadcastedRhs, resultShape, rhsBroadcastDims);
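
    // Re-create the elementwise op on the now equal-shaped operands; the
    // trailing nullptr drops broadcast_dimensions, which no longer applies.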
    auto newOp = rewriter.create<HloOpTy>(
        op.getLoc(), resultType, broadcastedLhs, broadcastedRhs, nullptr);
    rewriter.replaceOp(op, {newOp});
    return success();
  }
};

class ConvertDynamicBroadcastInDim
    : public OpConversionPattern<xla_hlo::DynamicBroadcastInDimOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult matchAndRewrite(
      xla_hlo::DynamicBroadcastInDimOp op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
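    // Wrap the dynamic output_dimensions extent tensor as a ranked shape and
    // rewrite to the shape dialect's ranked broadcast op.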
    xla_hlo::DynamicBroadcastInDimOpOperandAdaptor adapter(operands);
    Value rankedShape = rewriter.create<Shape::FromExtentTensorOp>(
        op.getLoc(), adapter.output_dimensions());
    rewriter.replaceOpWithNewOp<Shape::RankedBroadcastInDimOp>(
        op, op.getType(), adapter.operand(), rankedShape,
        op.broadcast_dimensions());
    return success();
  }
};

class ConvertHLOToShapePass
    : public PassWrapper<ConvertHLOToShapePass, FunctionPass> {
  void runOnFunction() override {
    ConversionTarget conversionTarget(getContext());
    OwningRewritePatternList conversionPatterns;

    conversionTarget.addLegalDialect<ShapeDialect>();
    conversionTarget.addLegalDialect<StandardOpsDialect>();
    conversionTarget.addLegalDialect<xla_hlo::XlaHloDialect>();

    conversionTarget.addIllegalOp<xla_hlo::DynamicBroadcastInDimOp>();
    conversionPatterns.insert<ConvertDynamicBroadcastInDim>(&getContext());
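
    // For each binary elementwise HLO op: keep the op legal only when its
    // operands already share the same ranked type; otherwise the pattern
    // below rewrites it to use explicit broadcasting.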
#define CONVERT_BINARY_ELEMENTWISE_OP(HloOpTy)                             \
  conversionTarget.addDynamicallyLegalOp<HloOpTy>(                         \
      [](HloOpTy op) { return IsSameRankedTypeBinaryElementwiseOp(op); }); \
  conversionPatterns                                                       \
      .insert<BroadcastedRankedBinaryElementwiseConversion<HloOpTy>>(      \
          &getContext());
    CONVERT_BINARY_ELEMENTWISE_OP(xla_hlo::AddOp);
    CONVERT_BINARY_ELEMENTWISE_OP(xla_hlo::Atan2Op);
    CONVERT_BINARY_ELEMENTWISE_OP(xla_hlo::DivOp);
    CONVERT_BINARY_ELEMENTWISE_OP(xla_hlo::MaxOp);
    CONVERT_BINARY_ELEMENTWISE_OP(xla_hlo::MinOp);
    CONVERT_BINARY_ELEMENTWISE_OP(xla_hlo::MulOp);
    CONVERT_BINARY_ELEMENTWISE_OP(xla_hlo::PowOp);
    CONVERT_BINARY_ELEMENTWISE_OP(xla_hlo::RemOp);
    CONVERT_BINARY_ELEMENTWISE_OP(xla_hlo::ShiftLeftOp);
    CONVERT_BINARY_ELEMENTWISE_OP(xla_hlo::ShiftRightArithmeticOp);
    CONVERT_BINARY_ELEMENTWISE_OP(xla_hlo::ShiftRightLogicalOp);
    CONVERT_BINARY_ELEMENTWISE_OP(xla_hlo::SubOp);

    if (failed(applyPartialConversion(getFunction(), conversionTarget,
                                      conversionPatterns))) {
      return signalPassFailure();
    }
  }
};

}  // namespace

// Converts shape-sensitive HLOs to be based on facilities in the shape
// dialect.
std::unique_ptr<OperationPass<FuncOp>> createConvertHLOToShapePass() {
  return std::make_unique<Shape::ConvertHLOToShapePass>();
}
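
// With the registration below, the pass can be exercised standalone, e.g.
// (usage sketch, assuming the iree-opt tool):
//   iree-opt -iree-shape-convert-hlo input.mlir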
static PassRegistration<Shape::ConvertHLOToShapePass> pass(
    "iree-shape-convert-hlo",
    "Converts dynamic shape dependent HLO ops to shaped variants.");

}  // namespace Shape
}  // namespace iree_compiler
}  // namespace mlir