blob: 0bab4e0c8a204079deac7343e1c8ed5059a53d7a [file] [log] [blame]
// Copyright 2025 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "compiler/plugins/input/Torch/InputConversion/Passes.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
#include "iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h"
namespace mlir::iree_compiler::TorchInput {
#define GEN_PASS_DEF_CONVERTTORCHUNSTRUCTUREDTOLINALGEXTPASS
#include "compiler/plugins/input/Torch/InputConversion/Passes.h.inc"
namespace {
struct FftRfftOpConversion
: public OpRewritePattern<torch::Torch::AtenFftRfftOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(torch::Torch::AtenFftRfftOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value self = op.getSelf();
int64_t dim;
Value dimVal = op.getDim();
if (isa<torch::Torch::NoneType>(dimVal.getType())) {
dim = -1;
} else if (!matchPattern(dimVal, torch::Torch::m_TorchConstantInt(&dim))) {
return rewriter.notifyMatchFailure(
op, "unimplemented: requires dim to be constant");
}
if (!isa<torch::Torch::NoneType>(op.getN().getType())) {
return rewriter.notifyMatchFailure(op, "unimplemented: parameter n");
}
if (!isa<torch::Torch::NoneType>(op.getNorm().getType())) {
return rewriter.notifyMatchFailure(op, "unimplemented: parameter norm");
}
auto inputTensorType = cast<torch::Torch::ValueTensorType>(self.getType());
if (!inputTensorType || !inputTensorType.hasSizes()) {
return rewriter.notifyMatchFailure(op,
"expected input type having sizes");
}
ArrayRef<int64_t> inputShape = inputTensorType.getSizes();
dim += dim < 0 ? inputShape.size() : 0;
int64_t fftLength = inputShape[dim];
if (fftLength == torch::Torch::kUnknownSize) {
return rewriter.notifyMatchFailure(op,
"expected known FFT dimension size");
}
if (!llvm::isPowerOf2_64(fftLength)) {
return rewriter.notifyMatchFailure(
op, "expected FFT length to be a power of two");
}
// Transpose if FFT dimension is not the last one
SmallVector<int64_t> preFftShape(inputShape);
const int64_t lastDim = inputShape.size() - 1;
const bool needTranspose = dim != lastDim;
if (needTranspose) {
Value cstLastDim = torch::Torch::ConstantIntOp::create(
rewriter, loc, rewriter.getI64IntegerAttr(lastDim));
Value cstFftDim = torch::Torch::ConstantIntOp::create(
rewriter, loc, rewriter.getI64IntegerAttr(dim));
std::swap(preFftShape[dim], preFftShape[lastDim]);
self = torch::Torch::AtenTransposeIntOp::create(
rewriter, loc,
inputTensorType.getWithSizesAndDtype(preFftShape,
inputTensorType.getDtype()),
self, cstFftDim, cstLastDim);
}
// Cast to the builtin tensor type.
Value builtinCast = torch::TorchConversion::ToBuiltinTensorOp::create(
rewriter, loc,
cast<torch::Torch::ValueTensorType>(self.getType()).toBuiltinTensor(),
self);
auto rewriteRes =
IREE::LinalgExt::rewriteFft(op, builtinCast, fftLength, rewriter);
if (failed(rewriteRes)) {
return failure();
}
auto [real, imag] = rewriteRes.value();
// Cast back
SmallVector<int64_t> postFftShape(preFftShape);
postFftShape.back() = fftLength / 2 + 1;
Type postFftType = inputTensorType.getWithSizesAndDtype(
postFftShape, inputTensorType.getDtype());
Value torchReal = torch::TorchConversion::FromBuiltinTensorOp::create(
rewriter, loc, postFftType, real);
Value torchImag = torch::TorchConversion::FromBuiltinTensorOp::create(
rewriter, loc, postFftType, imag);
// Unsqueeze a 1 dimension at the end
SmallVector<int64_t> unsqueezedTensorSizes(postFftShape);
unsqueezedTensorSizes.push_back(1);
Type unsqueezedTensorType = inputTensorType.getWithSizesAndDtype(
unsqueezedTensorSizes, inputTensorType.getDtype());
Value axisUnsqueeze = torch::Torch::ConstantIntOp::create(
rewriter, loc, rewriter.getI64IntegerAttr(-1));
Value unsqueezedReal = torch::Torch::AtenUnsqueezeOp::create(
rewriter, loc, unsqueezedTensorType, torchReal, axisUnsqueeze);
Value unsqueezedImag = torch::Torch::AtenUnsqueezeOp::create(
rewriter, loc, unsqueezedTensorType, torchImag, axisUnsqueeze);
// Concatenate real and imag
Type listType = torch::Torch::ListType::get(unsqueezedTensorType);
Value slices = torch::Torch::PrimListConstructOp::create(
rewriter, loc, listType,
llvm::ArrayRef<Value>{unsqueezedReal, unsqueezedImag});
SmallVector<int64_t> concatenatedTensorSizes(unsqueezedTensorSizes);
concatenatedTensorSizes.back() = 2;
Type concatenatedTensorType = inputTensorType.getWithSizesAndDtype(
concatenatedTensorSizes, inputTensorType.getDtype());
Value concatenated = torch::Torch::AtenCatOp::create(
rewriter, loc, concatenatedTensorType, slices, axisUnsqueeze);
// View as complex (and transpose back)
SmallVector<int64_t> complexResultSizes(concatenatedTensorSizes);
complexResultSizes.pop_back();
torch::Torch::ValueTensorType complexResultType =
cast<torch::Torch::ValueTensorType>(
inputTensorType.getWithSizesAndDtype(
complexResultSizes,
mlir::ComplexType::get(inputTensorType.getDtype())));
if (needTranspose) {
Value complex = torch::Torch::AtenViewAsComplexOp::create(
rewriter, loc, complexResultType, concatenated);
Value cstLastDim = torch::Torch::ConstantIntOp::create(
rewriter, loc, rewriter.getI64IntegerAttr(lastDim));
Value cstFftDim = torch::Torch::ConstantIntOp::create(
rewriter, loc, rewriter.getI64IntegerAttr(dim));
std::swap(complexResultSizes[dim], complexResultSizes[lastDim]);
rewriter.replaceOpWithNewOp<torch::Torch::AtenTransposeIntOp>(
op,
complexResultType.getWithSizesAndDtype(complexResultSizes,
complexResultType.getDtype()),
complex, cstFftDim, cstLastDim);
} else {
rewriter.replaceOpWithNewOp<torch::Torch::AtenViewAsComplexOp>(
op, complexResultType, concatenated);
}
return success();
}
};
class ConvertTorchUnstructuredToLinalgExtPass final
: public impl::ConvertTorchUnstructuredToLinalgExtPassBase<
ConvertTorchUnstructuredToLinalgExtPass> {
public:
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<IREE::LinalgExt::IREELinalgExtDialect,
torch::Torch::TorchDialect, tensor::TensorDialect,
linalg::LinalgDialect, arith::ArithDialect,
torch::TorchConversion::TorchConversionDialect>();
}
void runOnOperation() override {
MLIRContext *context = &getContext();
RewritePatternSet patterns(context);
patterns.add<FftRfftOpConversion>(context);
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
signalPassFailure();
}
}
};
} // namespace
} // namespace mlir::iree_compiler::TorchInput