// Copyright 2020 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//===- XLAToLinalgOnTensors.cpp - Pass to convert XLA to Linalg on tensors-===//
//
// Pass to convert from XLA to linalg on tensors. Uses the patterns from
// tensorflow/compiler/mlir/xla/transforms/legalize_to_linalg.cc along with
// some IREE specific patterns.
//
//===----------------------------------------------------------------------===//
#include <memory>
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
#include "iree/compiler/Dialect/Shape/IR/ShapeDialect.h"
#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
#include "iree/compiler/InputConversion/MHLO/ConvertMHLOToFlow.h"
#include "iree/compiler/InputConversion/MHLO/PassDetail.h"
#include "iree/compiler/InputConversion/MHLO/Passes.h"
#include "iree/compiler/InputConversion/MHLO/Rewriters.h"
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/Passes.h"
namespace mlir {
namespace iree_compiler {
namespace {
//===----------------------------------------------------------------------===//
// mhlo.concatenate conversion patterns.
//===----------------------------------------------------------------------===//
namespace {
/// Converts an mhlo.concatenate operation into a zero-filled destination
/// tensor plus one tensor.insert_slice op per operand.
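///
/// A rough sketch of the lowering (shapes and types are illustrative only):
///   %0 = "mhlo.concatenate"(%a, %b) {dimension = 0 : i64}
///       : (tensor<2x4xf32>, tensor<3x4xf32>) -> tensor<5x4xf32>
/// becomes a linalg.init_tensor + linalg.fill producing the 5x4 destination,
/// followed by a tensor.insert_slice for each operand at increasing offsets
/// along dimension 0.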
struct ConcatenateOpConversion
: public OpConversionPattern<mhlo::ConcatenateOp> {
using OpConversionPattern<mhlo::ConcatenateOp>::OpConversionPattern;
LogicalResult matchAndRewrite(
mhlo::ConcatenateOp op, ArrayRef<Value> args,
ConversionPatternRewriter &rewriter) const override {
auto resultType = this->typeConverter->convertType(op.getResult().getType())
.dyn_cast<RankedTensorType>();
if (!resultType || !resultType.hasStaticShape()) {
return rewriter.notifyMatchFailure(op,
"expected static shape for output");
}
Location loc = op.getLoc();
int dim = op.dimension();
int rank = resultType.getRank();
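// Start from the first operand's full extents with zero offsets and unit
// strides; only the entries along the concatenation dimension change below.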
SmallVector<Value, 3> offsets, sizes, strides;
for (int i = 0; i < rank; ++i) {
offsets.push_back(rewriter.create<ConstantIndexOp>(loc, 0));
sizes.push_back(rewriter.create<tensor::DimOp>(loc, args[0], i));
strides.push_back(rewriter.create<ConstantIndexOp>(loc, 1));
}
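// The output extent along the concatenated dimension is the sum of the
// operand extents along that dimension.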
Value resultDimSize = rewriter.create<ConstantIndexOp>(loc, 0);
for (auto arg : args) {
auto size = rewriter.create<tensor::DimOp>(loc, arg, dim);
resultDimSize = rewriter.create<AddIOp>(loc, resultDimSize, size);
}
sizes[dim] = resultDimSize;
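// Create a zero-filled destination tensor of the (static) result shape.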
auto initTensor = rewriter.create<linalg::InitTensorOp>(
loc, resultType.getShape(), resultType.getElementType());
auto zeroAttr = rewriter.getZeroAttr(resultType.getElementType());
Value zero = rewriter.create<ConstantOp>(loc, zeroAttr);
Value result =
rewriter.create<linalg::FillOp>(loc, zero, initTensor).getResult(0);
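// Insert each operand at a running offset along the concatenation dimension.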
Value accBound = rewriter.create<ConstantIndexOp>(loc, 0);
for (auto arg : args) {
offsets[dim] = accBound;
sizes[dim] = rewriter.create<tensor::DimOp>(loc, arg, dim);
result = rewriter.create<tensor::InsertSliceOp>(loc, arg, result, offsets,
sizes, strides);
accBound = rewriter.create<AddIOp>(loc, accBound, sizes[dim]);
}
rewriter.replaceOp(op, result);
return success();
}
};
} // namespace
//===----------------------------------------------------------------------===//
// mhlo.fft conversion patterns.
//===----------------------------------------------------------------------===//
namespace {
/// Creates coefficients based on DFT definition, see
/// https://en.wikipedia.org/wiki/Discrete_Fourier_transform
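/// Entry (i, j) of the returned matrix is cos(2*pi*i*j/N) for the real part
/// and -sin(2*pi*i*j/N) for the imaginary part, where N is the size of the
/// first dimension of `matrixType`.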
Value getDFTMatmulCoeff(OpBuilder b, Location loc, RankedTensorType matrixType,
bool isRealPart) {
// scale = 2 * pi / N
double scale = 2 * M_PI / matrixType.getDimSize(0);
SmallVector<Attribute> values;
assert(matrixType.getRank() == 2 && "expected 2D matrix");
for (auto i : llvm::seq<unsigned>(0, matrixType.getDimSize(0))) {
for (auto j : llvm::seq<unsigned>(0, matrixType.getDimSize(1))) {
double v = scale * i * j;
if (isRealPart) {
v = cos(v);
} else {
v = -sin(v);
}
values.push_back(b.getF32FloatAttr(v));
}
}
return b.create<ConstantOp>(loc, matrixType,
DenseFPElementsAttr::get(matrixType, values));
}
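/// Emits a zero-initialized linalg.matmul (or linalg.vecmat when `lhs` is
/// 1-D) on tensors computing lhs * rhs into a tensor of `resultType`.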
Value createLinalgMatmulOnTensors(OpBuilder b, Location loc,
RankedTensorType resultType, Value lhs,
Value rhs) {
Value zero =
b.create<ConstantOp>(loc, b.getZeroAttr(resultType.getElementType()));
auto initTensor = b.create<linalg::InitTensorOp>(
loc, /*dyn_size=*/ValueRange{}, resultType.getShape(),
resultType.getElementType());
Value zeroTensor =
b.create<linalg::FillOp>(loc, zero, initTensor).getResult(0);
switch (lhs.getType().cast<RankedTensorType>().getRank()) {
case 1:
return b
.create<linalg::VecmatOp>(loc, TypeRange{resultType},
ValueRange{lhs, rhs},
ValueRange{zeroTensor})
.getResult(0);
case 2:
return b
.create<linalg::MatmulOp>(loc, TypeRange{resultType},
ValueRange{lhs, rhs},
ValueRange{zeroTensor})
.getResult(0);
default:
llvm_unreachable("unhandled matmul type");
}
}
/// Converts mhlo.fft operation to Linalg ops.
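///
/// Only the real-to-complex case (RFFT) is handled. A sketch of the lowering:
/// the input is multiplied by two precomputed DFT coefficient matrices of
/// shape [n, fft_length / 2 + 1] (real and imaginary parts), and the two real
/// results are packed back together with an mhlo.complex op.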
struct FftOpConversion : public OpConversionPattern<mhlo::FftOp> {
using OpConversionPattern<mhlo::FftOp>::OpConversionPattern;
LogicalResult matchAndRewrite(
mhlo::FftOp op, ArrayRef<Value> args,
ConversionPatternRewriter &rewriter) const override {
if (op.fft_type() != "RFFT") {
return rewriter.notifyMatchFailure(op,
"non RFFT types are supported yet");
}
mhlo::FftOpAdaptor adaptor(args);
auto inputType = adaptor.operand().getType().dyn_cast<RankedTensorType>();
if (!inputType || !inputType.hasStaticShape() || inputType.getRank() > 2) {
return rewriter.notifyMatchFailure(op, "only static 1D or 2D dft ops");
}
int rank = inputType.getRank();
int n = inputType.getDimSize(rank - 1);
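// An RFFT over fft_length real inputs yields fft_length / 2 + 1 complex
// output elements.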
int fftLength =
op.fft_length().getSplatValue().cast<IntegerAttr>().getInt() / 2 + 1;
Location loc = op.getLoc();
auto matrixType =
RankedTensorType::get({n, fftLength}, inputType.getElementType());
auto resultType =
RankedTensorType::get(op.getType().cast<RankedTensorType>().getShape(),
inputType.getElementType());
auto realMatrix =
getDFTMatmulCoeff(rewriter, loc, matrixType, /*isRealPart=*/true);
auto real = createLinalgMatmulOnTensors(rewriter, loc, resultType,
adaptor.operand(), realMatrix);
auto imagMatrix =
getDFTMatmulCoeff(rewriter, loc, matrixType, /*isRealPart=*/false);
auto imag = createLinalgMatmulOnTensors(rewriter, loc, resultType,
adaptor.operand(), imagMatrix);
// Pack the results back to mhlo::ComplexOp.
rewriter.replaceOpWithNewOp<mhlo::ComplexOp>(op, op.getType(), real, imag);
return success();
}
};
} // namespace
struct ConvertMHLOToLinalgOnTensorsPass
: public ConvertMHLOToLinalgOnTensorsBase<
ConvertMHLOToLinalgOnTensorsPass> {
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<IREE::Flow::FlowDialect, linalg::LinalgDialect,
mhlo::MhloDialect, ShapeDialect, math::MathDialect,
memref::MemRefDialect, complex::ComplexDialect>();
}
void runOnOperation() override {
OwningRewritePatternList patterns(&getContext());
MLIRContext *context = &getContext();
auto typeConverter = mhlo::createHloToLinalgSignedIntegerConverter();
// NOTE: not using corresponding setupMHLOToFlowPatterns because the MHLO
// dialects are marked entirely illegal by this pass.
// TODO: Collapse/rework all of these patterns once the consolidation
// lands. There is little reason to have these so spread out.
populateMHLOToFlowPatterns(context, patterns);
chlo::PopulateDecomposeChloPatterns(context, &patterns);
populateMHLOBroadcastingToLinalgPatterns(context, *typeConverter, patterns);
populateMHLOToLinalgOnTensorsConversionPatterns(context, *typeConverter,
patterns);
populateMHLOComplexToRealPatterns(context, *typeConverter, patterns);
ConversionTarget target(getContext());
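// Require that all CHLO and MHLO ops are rewritten away; any op from these
// dialects that survives causes the conversion to fail.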
target.addIllegalDialect<chlo::HloClientDialect>();
target.addIllegalDialect<mhlo::MhloDialect>();
// Let the rest fall through.
target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns)))) {
return signalPassFailure();
}
}
};
/// Converts an mhlo.constant op into a std.constant op.
struct ConstOpConversion : public OpConversionPattern<mhlo::ConstOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(
mhlo::ConstOp op, ArrayRef<Value> /*operands*/,
ConversionPatternRewriter &rewriter) const override {
auto valueAttr = op.value();
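// If the element type changes under the type converter, rebuild the attribute
// with the new element type while keeping the underlying bits.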
Type oldElType = valueAttr.getType().getElementType();
Type newElType = this->typeConverter->convertType(oldElType);
ElementsAttr newValueAttr = valueAttr;
if (newElType != oldElType) {
// Values don't change, just their reported type.
newValueAttr = valueAttr.mapValues(
newElType, [](const APInt &oldEl) { return oldEl; });
}
rewriter.replaceOpWithNewOp<ConstantOp>(op, newValueAttr);
return success();
}
};
} // namespace
void populateMHLOToLinalgOnTensorsConversionPatterns(
MLIRContext *context, TypeConverter &typeConverter,
OwningRewritePatternList &patterns) {
mhlo::populateHLOToLinalgConversionPattern(context, typeConverter, &patterns);
// TODO(#5809): Drop the ConcatenateOp lowering in favor of the upstream
// version, then remove the PatternBenefit here.
patterns.insert<ConstOpConversion, ConcatenateOpConversion, FftOpConversion>(
typeConverter, context, PatternBenefit(1000));
}
std::unique_ptr<OperationPass<FuncOp>> createMHLOToLinalgOnTensorsPass() {
return std::make_unique<ConvertMHLOToLinalgOnTensorsPass>();
}
} // namespace iree_compiler
} // namespace mlir