blob: 0476621aafe27966fdf94518fb8bf2290e82307c [file] [log] [blame]
// Copyright 2019 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Implements passes to convert complex operations to equivalent real value
// operations. This does not include removing complex values from function
// argument or return types.
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "stablehlo-iree/Conversion/Preprocessing/Passes.h"
#include "stablehlo-iree/Conversion/Preprocessing/Rewriters.h"
#include "stablehlo/dialect/StablehloOps.h"
namespace mlir::iree_compiler::stablehlo {
#define GEN_PASS_DEF_LOWERCOMPLEX
#include "stablehlo-iree/Conversion/Preprocessing/Passes.h.inc"
namespace {
struct ConvertComplexDot final : OpRewritePattern<mlir::stablehlo::DotOp> {
using OpRewritePattern::OpRewritePattern;
// Will decompose stablehlo::DotOp with complex parameters down to
// four Dot operations in the following fashion:
// result.real = lhs.real <DOT> rhs.real - lhs.imag <DOT> rhs.imag
// result.imag = lhs.imag <DOT> rhs.real + lhs.real <DOT> rhs.imag
// result = complex(result.real, result.imag)
LogicalResult matchAndRewrite(mlir::stablehlo::DotOp dot,
PatternRewriter &rewriter) const override {
ArrayAttr precision = dot.getPrecisionConfigAttr();
TypedValue<TensorType> lhs = dot.getLhs();
TypedValue<TensorType> rhs = dot.getRhs();
ShapedType lhsType = lhs.getType();
ShapedType rhsType = rhs.getType();
if (!isa<ComplexType>(lhsType.getElementType()) ||
!isa<ComplexType>(rhsType.getElementType())) {
return rewriter.notifyMatchFailure(dot, "lhs/rhs types are not complex");
}
Location loc = dot.getLoc();
Value lhsReal = rewriter.create<mlir::stablehlo::RealOp>(loc, lhs);
Value lhsImag = rewriter.create<mlir::stablehlo::ImagOp>(loc, lhs);
Value rhsReal = rewriter.create<mlir::stablehlo::RealOp>(loc, rhs);
Value rhsImag = rewriter.create<mlir::stablehlo::ImagOp>(loc, rhs);
TensorType resultType = dot.getType();
Type newType = mlir::hlo::createRealType(resultType);
Value realComponent = rewriter.create<mlir::stablehlo::SubtractOp>(
loc,
rewriter.create<mlir::stablehlo::DotOp>(loc, newType, lhsReal, rhsReal,
precision),
rewriter.create<mlir::stablehlo::DotOp>(loc, newType, lhsImag, rhsImag,
precision));
Value imagComponent = rewriter.create<mlir::stablehlo::AddOp>(
loc,
rewriter.create<mlir::stablehlo::DotOp>(loc, newType, lhsReal, rhsImag,
precision),
rewriter.create<mlir::stablehlo::DotOp>(loc, newType, lhsImag, rhsReal,
precision));
Value result = rewriter.create<mlir::stablehlo::ComplexOp>(
loc, resultType, realComponent, imagComponent);
rewriter.replaceOp(dot, result);
return success();
}
};
struct LowerComplex final : impl::LowerComplexBase<LowerComplex> {
void runOnOperation() override {
MLIRContext *ctx = &getContext();
RewritePatternSet patterns(ctx);
populatePreprocessingComplexPatterns(ctx, &patterns);
populateCanonicalizationPatterns(ctx, &patterns);
if (failed(applyPatternsAndFoldGreedily(getOperation(),
std::move(patterns)))) {
signalPassFailure();
}
}
};
// Get a constant splat for the given value of type. Requires value to be of
// type static shaped RankedTensorType.
template <typename T>
ElementsAttr getSplat(Builder *b, RankedTensorType ty, T constant) {
Type elementTy = getElementTypeOrSelf(ty);
if (elementTy.isSignlessInteger()) {
return DenseElementsAttr::get(ty, b->getIntegerAttr(elementTy, constant));
}
if (isa<FloatType>(elementTy)) {
return DenseElementsAttr::get(ty, b->getFloatAttr(elementTy, constant));
}
if (auto complexTy = dyn_cast<ComplexType>(elementTy)) {
auto complexElementTy = complexTy.getElementType();
if (complexElementTy.isF32())
return DenseElementsAttr::get(ty,
static_cast<std::complex<float>>(constant));
if (complexElementTy.isF64())
return DenseElementsAttr::get(
ty, static_cast<std::complex<double>>(constant));
}
llvm_unreachable("unhandled element type");
}
template <typename T>
ElementsAttr getSplat(Builder *b, Value val, T constant) {
return getSplat(b, cast<RankedTensorType>(val.getType()), constant);
}
} // end anonymous namespace
namespace {
#include "stablehlo-iree/Conversion/Preprocessing/ComplexLoweringPatterns.h.inc"
} // end anonymous namespace
void populatePreprocessingComplexPatterns(MLIRContext *context,
RewritePatternSet *patterns) {
patterns->add<ConvertComplexDot>(context);
populateWithGenerated(*patterns);
}
} // namespace mlir::iree_compiler::stablehlo