// Copyright 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "compiler/IR/Dialect.h"
#include "compiler/IR/Ops.h"
#include "compiler/IR/Sequencer/HLDialect.h"
#include "compiler/IR/Sequencer/HLOps.h"
#include "compiler/IR/StructureOps.h"
#include "compiler/Transforms/ConversionUtils.h"
#include "compiler/Utils/MemRefUtils.h"
#include "compiler/Utils/OpCreationUtils.h"
#include "compiler/Utils/TypeConversionUtils.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
namespace mlir {
namespace iree_compiler {
namespace {
// TODO(suderman): tablegen this? or something a bit more flexible.
#define UNARY_OP_LOWERING(XlaOpType, IREEOpType) \
struct XlaOpType##Lowering \
: public UnaryOpLowering<xla_hlo::XlaOpType, IREEOpType> { \
using UnaryOpLowering::UnaryOpLowering; \
};
#define TERNARY_OP_LOWERING(XlaOpType, IREEOpType) \
struct XlaOpType##Lowering \
: public TernaryOpLowering<xla_hlo::XlaOpType, IREEOpType> { \
using TernaryOpLowering::TernaryOpLowering; \
};
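// Expands to a `CopyOpLowering` pattern mapping xla_hlo.copy onto
// iree_hl_seq.clone; the generated struct is registered below in
// populateLowerXlaToSequencerPatterns().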
UNARY_OP_LOWERING(CopyOp, IREESeq::HL::CloneOp);
#undef UNARY_OP_LOWERING
#undef TERNARY_OP_LOWERING
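// Creates an op of type |T| that retargets |input| to |targetType|,
// materializing the target shape as a constant array operand.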
template <typename T>
static Operation *createShapeTargetingOp(ConversionPatternRewriter &rewriter,
Location loc, Value *input,
MemRefType targetType) {
auto shapeOp = createArrayConstant(rewriter, loc, targetType.getShape());
return rewriter.create<T>(loc, targetType, input, shapeOp);
}
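// Loads the access value for |tensor| and wraps it as a memref so that
// sequencer HL ops can operate on it.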
static Value *inputAsMemref(ConversionPatternRewriter &rewriter, Operation *op,
Value *tensor) {
return wrapAsMemRef(loadAccessValue(op->getLoc(), tensor, rewriter), op,
rewriter);
}
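// Base conversion pattern for XLA HLO ops: wraps tensor operands as memrefs,
// delegates the actual lowering to rewriteInternal(), and wraps the result
// back into a tensor before replacing the original op.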
template <typename SrcOp>
class XlaOpLowering : public OpConversionPattern<SrcOp> {
public:
using OpConversionPattern<SrcOp>::OpConversionPattern;
PatternMatchResult matchAndRewrite(
SrcOp op, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const override {
auto srcOp = cast<SrcOp>(op);
SmallVector<Value *, 4> memrefOperands;
for (auto operand : operands) {
memrefOperands.push_back(inputAsMemref(rewriter, op, operand));
}
auto dstOp = rewriteInternal(&srcOp, memrefOperands, rewriter);
rewriter.replaceOp(op, wrapAsTensor(dstOp->getResult(0), srcOp, rewriter));
return this->matchSuccess();
}
protected:
virtual Operation *rewriteInternal(
SrcOp *op, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const {
llvm_unreachable("unimplemented rewriteInternal for XLA op lowering");
}
};
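// Lowers xla_hlo.concatenate to iree_hl_seq.concat, carrying the
// concatenation dimension as an i32 attribute.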
struct ConcatOpLowering : public XlaOpLowering<xla_hlo::ConcatenateOp> {
using XlaOpLowering::XlaOpLowering;
Operation *rewriteInternal(
xla_hlo::ConcatenateOp *op, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const override {
auto finalType = convertTypeToMemRef(*op);
return rewriter.create<IREESeq::HL::ConcatOp>(
op->getLoc(), finalType, operands,
rewriter.getI32IntegerAttr(op->dimension().getZExtValue()));
}
};
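// Lowers xla_hlo.dynamic-update-slice by cloning the original operand and
// copying the update buffer into the clone at the offset formed from the
// variadic start indices.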
struct DynamicUpdateSliceLowering
: public XlaOpLowering<xla_hlo::DynamicUpdateSliceOp> {
using XlaOpLowering::XlaOpLowering;
Operation *rewriteInternal(
xla_hlo::DynamicUpdateSliceOp *op, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const override {
auto operand = operands[0];
auto update = operands[1];
auto updateType = update->getType().cast<ShapedType>();
Value *lengthConstant =
createArrayConstant(rewriter, op->getLoc(), updateType.getShape());
auto startIndices = makeArrayRef(operands).drop_front(2);
const int rank = startIndices.size();
llvm::SmallVector<Value *, 4> valuesToConcat;
valuesToConcat.reserve(startIndices.size());
auto type = getElementTypeOrSelf(startIndices.front());
// To generate the destination offset we reshape each scalar start index to a
// rank-1 memref and concatenate them into a single offset vector.
for (auto index : startIndices) {
auto reshapedIndex = rewriter.create<IREESeq::HL::ReshapeOp>(
op->getLoc(), MemRefType::get({1}, type), index,
createArrayConstant(rewriter, op->getLoc(), {1}));
valuesToConcat.push_back(reshapedIndex);
}
auto dstOffset = rewriter
.create<IREESeq::HL::ConcatOp>(
op->getLoc(), MemRefType::get({rank}, type),
valuesToConcat, rewriter.getI32IntegerAttr(0))
.getResult();
llvm::SmallVector<int64_t, 4> zero_offset;
zero_offset.resize(updateType.getRank(), 0);
auto srcOffset = createArrayConstant(rewriter, op->getLoc(), zero_offset);
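// Clone the operand so the update is written into a fresh buffer, preserving
// the value semantics of dynamic-update-slice.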
auto copiedOperand = rewriter.create<IREESeq::HL::CloneOp>(
op->getLoc(), operand->getType(), operand);
rewriter
.create<IREESeq::HL::CopyOp>(op->getLoc(), update, srcOffset,
copiedOperand, dstOffset, lengthConstant)
.getOperation();
return copiedOperand;
}
};
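// Lowers xla_hlo.slice (value semantics) to an allocation plus an
// iree_hl_seq.copy from the sliced region of the source.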
struct SliceLowering : public XlaOpLowering<xla_hlo::SliceOp> {
using XlaOpLowering::XlaOpLowering;
Operation *rewriteInternal(
xla_hlo::SliceOp *op, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const override {
// XLA slice has value semantics, whereas the IREE slice creates a view. When
// all strides are one we lower to a copy instead; later optimizations may
// turn that copy back into a slice/view.
auto isNotOne = [](APInt stride) { return stride != 1; };
if (llvm::any_of(op->strides(), isNotOne)) {
op->emitRemark() << "Could not lower slice op with non-unit strides";
return nullptr;
}
auto finalType = convertTypeToMemRef(*op);
auto src = operands[0];
std::vector<Value *> dim_pieces;
auto dst = rewriter.create<IREESeq::HL::AllocHeapOp>(op->getLoc(),
finalType, dim_pieces);
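// Copy |lengths| elements from |src| starting at the XLA start_indices into
// the zero offset of the newly allocated destination.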
auto srcIndices =
rewriter.create<IREE::ConstantOp>(op->getLoc(), op->start_indices());
auto lengths =
createArrayConstant(rewriter, op->getLoc(), finalType.getShape());
llvm::SmallVector<int64_t, 4> zero_offset;
zero_offset.resize(finalType.getRank(), 0);
auto dstIndices = createArrayConstant(rewriter, op->getLoc(), zero_offset);
rewriter.create<IREESeq::HL::CopyOp>(op->getLoc(), src, srcIndices, dst,
dstIndices, lengths);
return dst;
}
};
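// Lowers xla_hlo.reshape to iree_hl_seq.reshape targeting the converted
// memref result type.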
struct ReshapeOpLowering : public XlaOpLowering<xla_hlo::ReshapeOp> {
using XlaOpLowering::XlaOpLowering;
Operation *rewriteInternal(
xla_hlo::ReshapeOp *op, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const override {
return createShapeTargetingOp<IREESeq::HL::ReshapeOp>(
rewriter, op->getLoc(), operands[0], convertTypeToMemRef(*op));
}
};
} // namespace
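// Populates |patterns| with the XLA HLO -> IREE sequencer HL lowerings
// defined above.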
void populateLowerXlaToSequencerPatterns(OwningRewritePatternList &patterns,
MLIRContext *ctx) {
patterns.insert<ConcatOpLowering, CopyOpLowering, DynamicUpdateSliceLowering,
ReshapeOpLowering, SliceLowering>(ctx);
}
} // namespace iree_compiler
} // namespace mlir