| // Copyright 2025 The IREE Authors |
| // |
| // Licensed under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| |
| #include "compiler/plugins/input/Torch/InputConversion/Passes.h" |
| #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h" |
| #include "iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h" |
| #include "mlir/Transforms/DialectConversion.h" |
| #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
| #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" |
| #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" |
| #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h" |
| #include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h" |
| |
| namespace mlir::iree_compiler::TorchInput { |
| |
| #define GEN_PASS_DEF_CONVERTTORCHUNSTRUCTUREDTOLINALGEXTPASS |
| #include "compiler/plugins/input/Torch/InputConversion/Passes.h.inc" |
| |
| namespace { |
| |
| struct FftRfftOpConversion |
| : public OpRewritePattern<torch::Torch::AtenFftRfftOp> { |
| using OpRewritePattern::OpRewritePattern; |
| LogicalResult matchAndRewrite(torch::Torch::AtenFftRfftOp op, |
| PatternRewriter &rewriter) const override { |
| |
| Location loc = op.getLoc(); |
| Value self = op.getSelf(); |
| |
| int64_t dim; |
| Value dimVal = op.getDim(); |
| if (isa<torch::Torch::NoneType>(dimVal.getType())) { |
| dim = -1; |
| } else if (!matchPattern(dimVal, torch::Torch::m_TorchConstantInt(&dim))) { |
| return rewriter.notifyMatchFailure( |
| op, "unimplemented: requires dim to be constant"); |
| } |
| |
| if (!isa<torch::Torch::NoneType>(op.getN().getType())) { |
| return rewriter.notifyMatchFailure(op, "unimplemented: parameter n"); |
| } |
| |
| if (!isa<torch::Torch::NoneType>(op.getNorm().getType())) { |
| return rewriter.notifyMatchFailure(op, "unimplemented: parameter norm"); |
| } |
| |
| auto inputTensorType = cast<torch::Torch::ValueTensorType>(self.getType()); |
| if (!inputTensorType || !inputTensorType.hasSizes()) { |
| return rewriter.notifyMatchFailure(op, |
| "expected input type having sizes"); |
| } |
| ArrayRef<int64_t> inputShape = inputTensorType.getSizes(); |
| dim += dim < 0 ? inputShape.size() : 0; |
| |
| int64_t fftLength = inputShape[dim]; |
| if (fftLength == torch::Torch::kUnknownSize) { |
| return rewriter.notifyMatchFailure(op, |
| "expected known FFT dimension size"); |
| } |
| if (!llvm::isPowerOf2_64(fftLength)) { |
| return rewriter.notifyMatchFailure( |
| op, "expected FFT length to be a power of two"); |
| } |
| |
| // Transpose if FFT dimension is not the last one |
| SmallVector<int64_t> preFftShape(inputShape); |
| const int64_t lastDim = inputShape.size() - 1; |
| const bool needTranspose = dim != lastDim; |
| if (needTranspose) { |
| Value cstLastDim = torch::Torch::ConstantIntOp::create( |
| rewriter, loc, rewriter.getI64IntegerAttr(lastDim)); |
| Value cstFftDim = torch::Torch::ConstantIntOp::create( |
| rewriter, loc, rewriter.getI64IntegerAttr(dim)); |
| std::swap(preFftShape[dim], preFftShape[lastDim]); |
| |
| self = torch::Torch::AtenTransposeIntOp::create( |
| rewriter, loc, |
| inputTensorType.getWithSizesAndDtype(preFftShape, |
| inputTensorType.getDtype()), |
| self, cstFftDim, cstLastDim); |
| } |
| |
| // Cast to the builtin tensor type. |
| Value builtinCast = torch::TorchConversion::ToBuiltinTensorOp::create( |
| rewriter, loc, |
| cast<torch::Torch::ValueTensorType>(self.getType()).toBuiltinTensor(), |
| self); |
| |
| auto rewriteRes = |
| IREE::LinalgExt::rewriteFft(op, builtinCast, fftLength, rewriter); |
| if (failed(rewriteRes)) { |
| return failure(); |
| } |
| |
| auto [real, imag] = rewriteRes.value(); |
| |
| // Cast back |
| SmallVector<int64_t> postFftShape(preFftShape); |
| postFftShape.back() = fftLength / 2 + 1; |
| Type postFftType = inputTensorType.getWithSizesAndDtype( |
| postFftShape, inputTensorType.getDtype()); |
| Value torchReal = torch::TorchConversion::FromBuiltinTensorOp::create( |
| rewriter, loc, postFftType, real); |
| Value torchImag = torch::TorchConversion::FromBuiltinTensorOp::create( |
| rewriter, loc, postFftType, imag); |
| |
| // Unsqueeze a 1 dimension at the end |
| SmallVector<int64_t> unsqueezedTensorSizes(postFftShape); |
| unsqueezedTensorSizes.push_back(1); |
| Type unsqueezedTensorType = inputTensorType.getWithSizesAndDtype( |
| unsqueezedTensorSizes, inputTensorType.getDtype()); |
| Value axisUnsqueeze = torch::Torch::ConstantIntOp::create( |
| rewriter, loc, rewriter.getI64IntegerAttr(-1)); |
| Value unsqueezedReal = torch::Torch::AtenUnsqueezeOp::create( |
| rewriter, loc, unsqueezedTensorType, torchReal, axisUnsqueeze); |
| Value unsqueezedImag = torch::Torch::AtenUnsqueezeOp::create( |
| rewriter, loc, unsqueezedTensorType, torchImag, axisUnsqueeze); |
| |
| // Concatenate real and imag |
| Type listType = torch::Torch::ListType::get(unsqueezedTensorType); |
| Value slices = torch::Torch::PrimListConstructOp::create( |
| rewriter, loc, listType, |
| llvm::ArrayRef<Value>{unsqueezedReal, unsqueezedImag}); |
| SmallVector<int64_t> concatenatedTensorSizes(unsqueezedTensorSizes); |
| concatenatedTensorSizes.back() = 2; |
| Type concatenatedTensorType = inputTensorType.getWithSizesAndDtype( |
| concatenatedTensorSizes, inputTensorType.getDtype()); |
| Value concatenated = torch::Torch::AtenCatOp::create( |
| rewriter, loc, concatenatedTensorType, slices, axisUnsqueeze); |
| |
| // View as complex (and transpose back) |
| SmallVector<int64_t> complexResultSizes(concatenatedTensorSizes); |
| complexResultSizes.pop_back(); |
| torch::Torch::ValueTensorType complexResultType = |
| cast<torch::Torch::ValueTensorType>( |
| inputTensorType.getWithSizesAndDtype( |
| complexResultSizes, |
| mlir::ComplexType::get(inputTensorType.getDtype()))); |
| if (needTranspose) { |
| Value complex = torch::Torch::AtenViewAsComplexOp::create( |
| rewriter, loc, complexResultType, concatenated); |
| |
| Value cstLastDim = torch::Torch::ConstantIntOp::create( |
| rewriter, loc, rewriter.getI64IntegerAttr(lastDim)); |
| Value cstFftDim = torch::Torch::ConstantIntOp::create( |
| rewriter, loc, rewriter.getI64IntegerAttr(dim)); |
| std::swap(complexResultSizes[dim], complexResultSizes[lastDim]); |
| |
| rewriter.replaceOpWithNewOp<torch::Torch::AtenTransposeIntOp>( |
| op, |
| complexResultType.getWithSizesAndDtype(complexResultSizes, |
| complexResultType.getDtype()), |
| complex, cstFftDim, cstLastDim); |
| } else { |
| rewriter.replaceOpWithNewOp<torch::Torch::AtenViewAsComplexOp>( |
| op, complexResultType, concatenated); |
| } |
| |
| return success(); |
| } |
| }; |
| |
| class ConvertTorchUnstructuredToLinalgExtPass final |
| : public impl::ConvertTorchUnstructuredToLinalgExtPassBase< |
| ConvertTorchUnstructuredToLinalgExtPass> { |
| public: |
| void getDependentDialects(DialectRegistry ®istry) const override { |
| registry.insert<IREE::LinalgExt::IREELinalgExtDialect, |
| torch::Torch::TorchDialect, tensor::TensorDialect, |
| linalg::LinalgDialect, arith::ArithDialect, |
| torch::TorchConversion::TorchConversionDialect>(); |
| } |
| void runOnOperation() override { |
| MLIRContext *context = &getContext(); |
| RewritePatternSet patterns(context); |
| |
| patterns.add<FftRfftOpConversion>(context); |
| |
| if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) { |
| signalPassFailure(); |
| } |
| } |
| }; |
| } // namespace |
| |
| } // namespace mlir::iree_compiler::TorchInput |