| // Copyright 2025 The IREE Authors |
| // |
| // Licensed under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| |
| #include "compiler/plugins/input/Torch/InputConversion/Passes.h" |
| #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h" |
| #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h" |
| #include "iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h" |
| #include "llvm/ADT/APFloat.h" |
| #include "mlir/Dialect/Arith/IR/Arith.h" |
| #include "mlir/Dialect/Arith/Utils/Utils.h" |
| #include "mlir/Dialect/Func/IR/FuncOps.h" |
| #include "mlir/Dialect/Linalg/IR/Linalg.h" |
| #include "mlir/Dialect/Math/IR/Math.h" |
| #include "mlir/Dialect/Tensor/IR/Tensor.h" |
| #include "mlir/IR/BuiltinTypes.h" |
| #include "mlir/IR/IRMapping.h" |
| #include "mlir/Transforms/DialectConversion.h" |
| #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
| #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" |
| #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" |
| #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h" |
| #include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h" |
| |
| namespace mlir::iree_compiler::TorchInput { |
| |
| #define GEN_PASS_DEF_CONVERTTORCHUNSTRUCTUREDTOLINALGEXTPASS |
| #include "compiler/plugins/input/Torch/InputConversion/Passes.h.inc" |
| |
| namespace { |
| |
| struct FftRfftOpConversion : OpRewritePattern<torch::Torch::AtenFftRfftOp> { |
| using Base::Base; |
| LogicalResult matchAndRewrite(torch::Torch::AtenFftRfftOp op, |
| PatternRewriter &rewriter) const override { |
| |
| Location loc = op.getLoc(); |
| Value self = op.getSelf(); |
| |
| int64_t dim; |
| Value dimVal = op.getDim(); |
| if (isa<torch::Torch::NoneType>(dimVal.getType())) { |
| dim = -1; |
| } else if (!matchPattern(dimVal, torch::Torch::m_TorchConstantInt(&dim))) { |
| return rewriter.notifyMatchFailure( |
| op, "unimplemented: requires dim to be constant"); |
| } |
| |
| if (!isa<torch::Torch::NoneType>(op.getN().getType())) { |
| return rewriter.notifyMatchFailure(op, "unimplemented: parameter n"); |
| } |
| |
| if (!isa<torch::Torch::NoneType>(op.getNorm().getType())) { |
| return rewriter.notifyMatchFailure(op, "unimplemented: parameter norm"); |
| } |
| |
    auto inputTensorType =
        dyn_cast<torch::Torch::ValueTensorType>(self.getType());
    if (!inputTensorType || !inputTensorType.hasSizes()) {
      return rewriter.notifyMatchFailure(
          op, "expected input to be a value tensor with known sizes");
    }
| ArrayRef<int64_t> inputShape = inputTensorType.getSizes(); |
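    // Normalize a negative dim (such as the default -1) to an absolute index.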
| dim += dim < 0 ? inputShape.size() : 0; |
| |
| int64_t fftLength = inputShape[dim]; |
| if (fftLength == torch::Torch::kUnknownSize) { |
| return rewriter.notifyMatchFailure(op, |
| "expected known FFT dimension size"); |
| } |
| if (!llvm::isPowerOf2_64(fftLength)) { |
| return rewriter.notifyMatchFailure( |
| op, "expected FFT length to be a power of two"); |
| } |
| |
    // Transpose so the FFT dimension becomes the last (innermost) dimension.
| SmallVector<int64_t> preFftShape(inputShape); |
| const int64_t lastDim = inputShape.size() - 1; |
| const bool needTranspose = dim != lastDim; |
| if (needTranspose) { |
| Value cstLastDim = torch::Torch::ConstantIntOp::create( |
| rewriter, loc, rewriter.getI64IntegerAttr(lastDim)); |
| Value cstFftDim = torch::Torch::ConstantIntOp::create( |
| rewriter, loc, rewriter.getI64IntegerAttr(dim)); |
| std::swap(preFftShape[dim], preFftShape[lastDim]); |
| |
| self = torch::Torch::AtenTransposeIntOp::create( |
| rewriter, loc, |
| inputTensorType.getWithSizesAndDtype(preFftShape, |
| inputTensorType.getDtype()), |
| self, cstFftDim, cstLastDim); |
| } |
| |
| // Cast to the builtin tensor type. |
| Value builtinCast = torch::TorchConversion::ToBuiltinTensorOp::create( |
| rewriter, loc, |
| cast<torch::Torch::ValueTensorType>(self.getType()).toBuiltinTensor(), |
| self); |
| |
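    // Perform the actual FFT rewrite; on success it yields separate real and
    // imaginary result tensors.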
| auto rewriteRes = |
| IREE::LinalgExt::rewriteRfft(op, builtinCast, fftLength, rewriter); |
| if (failed(rewriteRes)) { |
| return failure(); |
| } |
| |
| auto [real, imag] = rewriteRes.value(); |
| |
    // Cast the real and imaginary results back to torch tensors.
| SmallVector<int64_t> postFftShape(preFftShape); |
| postFftShape.back() = fftLength / 2 + 1; |
| Type postFftType = inputTensorType.getWithSizesAndDtype( |
| postFftShape, inputTensorType.getDtype()); |
| Value torchReal = torch::TorchConversion::FromBuiltinTensorOp::create( |
| rewriter, loc, postFftType, real); |
| Value torchImag = torch::TorchConversion::FromBuiltinTensorOp::create( |
| rewriter, loc, postFftType, imag); |
| |
    // Unsqueeze a unit dimension at the end of each half.
| SmallVector<int64_t> unsqueezedTensorSizes(postFftShape); |
| unsqueezedTensorSizes.push_back(1); |
| Type unsqueezedTensorType = inputTensorType.getWithSizesAndDtype( |
| unsqueezedTensorSizes, inputTensorType.getDtype()); |
| Value axisUnsqueeze = torch::Torch::ConstantIntOp::create( |
| rewriter, loc, rewriter.getI64IntegerAttr(-1)); |
| Value unsqueezedReal = torch::Torch::AtenUnsqueezeOp::create( |
| rewriter, loc, unsqueezedTensorType, torchReal, axisUnsqueeze); |
| Value unsqueezedImag = torch::Torch::AtenUnsqueezeOp::create( |
| rewriter, loc, unsqueezedTensorType, torchImag, axisUnsqueeze); |
| |
    // Concatenate real and imag along the new trailing dimension.
| Type listType = torch::Torch::ListType::get(unsqueezedTensorType); |
| Value slices = torch::Torch::PrimListConstructOp::create( |
| rewriter, loc, listType, |
| llvm::ArrayRef<Value>{unsqueezedReal, unsqueezedImag}); |
| SmallVector<int64_t> concatenatedTensorSizes(unsqueezedTensorSizes); |
| concatenatedTensorSizes.back() = 2; |
| Type concatenatedTensorType = inputTensorType.getWithSizesAndDtype( |
| concatenatedTensorSizes, inputTensorType.getDtype()); |
| Value concatenated = torch::Torch::AtenCatOp::create( |
| rewriter, loc, concatenatedTensorType, slices, axisUnsqueeze); |
| |
| // View as complex (and transpose back) |
| SmallVector<int64_t> complexResultSizes(concatenatedTensorSizes); |
| complexResultSizes.pop_back(); |
| torch::Torch::ValueTensorType complexResultType = |
| cast<torch::Torch::ValueTensorType>( |
| inputTensorType.getWithSizesAndDtype( |
| complexResultSizes, |
| mlir::ComplexType::get(inputTensorType.getDtype()))); |
| if (needTranspose) { |
| Value complex = torch::Torch::AtenViewAsComplexOp::create( |
| rewriter, loc, complexResultType, concatenated); |
| |
| Value cstLastDim = torch::Torch::ConstantIntOp::create( |
| rewriter, loc, rewriter.getI64IntegerAttr(lastDim)); |
| Value cstFftDim = torch::Torch::ConstantIntOp::create( |
| rewriter, loc, rewriter.getI64IntegerAttr(dim)); |
| std::swap(complexResultSizes[dim], complexResultSizes[lastDim]); |
| |
| rewriter.replaceOpWithNewOp<torch::Torch::AtenTransposeIntOp>( |
| op, |
| complexResultType.getWithSizesAndDtype(complexResultSizes, |
| complexResultType.getDtype()), |
| complex, cstFftDim, cstLastDim); |
| } else { |
| rewriter.replaceOpWithNewOp<torch::Torch::AtenViewAsComplexOp>( |
| op, complexResultType, concatenated); |
| } |
| |
| return success(); |
| } |
| }; |
| |
| //===----------------------------------------------------------------------===// |
| // FlexAttention -> OnlineAttention conversion |
| //===----------------------------------------------------------------------===// |
| |
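/// Cast a torch value tensor to its equivalent builtin tensor type.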
| static Value convertToBuiltinTensor(PatternRewriter &rewriter, Location loc, |
| Value torchTensor) { |
| auto tensorType = cast<torch::Torch::ValueTensorType>(torchTensor.getType()); |
| return torch::TorchConversion::ToBuiltinTensorOp::create( |
| rewriter, loc, tensorType.toBuiltinTensor(), torchTensor); |
| } |
| |
/// Inline a single-block torch function's body at the current insertion point.
/// Falls back to func.call for multi-block or external functions; the symbol
/// must resolve to a function within the enclosing module.
| static SmallVector<Value> inlineTorchFunction(PatternRewriter &rewriter, |
| Location loc, |
| FlatSymbolRefAttr funcSymbol, |
| ValueRange args, |
| Operation *contextOp) { |
  auto module = contextOp->getParentOfType<ModuleOp>();
  auto funcOp = module.lookupSymbol<func::FuncOp>(funcSymbol);
  assert(funcOp && "mod function symbol must resolve within the module");
  if (funcOp.isExternal() || !funcOp.getBody().hasOneBlock()) {
    auto callOp = func::CallOp::create(rewriter, loc, funcSymbol,
                                       funcOp.getResultTypes(), args);
    return SmallVector<Value>(callOp->getResults());
  }
| |
| Block &entryBlock = funcOp.getBody().front(); |
| IRMapping mapper; |
| for (auto [blockArg, callArg] : llvm::zip(entryBlock.getArguments(), args)) { |
| mapper.map(blockArg, callArg); |
| } |
| |
| for (Operation &op : entryBlock.without_terminator()) { |
| rewriter.clone(op, mapper); |
| } |
| |
| auto returnOp = cast<func::ReturnOp>(entryBlock.getTerminator()); |
| SmallVector<Value> results; |
| for (Value operand : returnOp.getOperands()) { |
| results.push_back(mapper.lookupOrDefault(operand)); |
| } |
| return results; |
| } |
| |
| /// Build the score modification region inside the OnlineAttention op. |
| /// Both mask_mod and score_mod are inlined into the region to avoid separate |
| /// mask materialization and to enable fusion during attention decomposition. |
| /// |
| /// The region computes: |
| /// 1. mask = mask_mod_fn(b, h, q_idx, kv_idx) [if mask_mod present] |
| /// 2. score = select(mask, score, -inf) [if mask_mod present] |
/// 3. score = score_mod_fn(score, b, h, q_idx, kv_idx) [if score_mod present]
| /// 4. yield score |
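///
/// As a rough sketch, with only mask_mod present the region resembles:
///   ^bb0(%score: f32):
///     %mask = <inlined mask_mod(b, h, q_idx, kv_idx)> : i1
///     %masked = arith.select %mask, %score, %neg_inf : f32
///     iree_linalg_ext.yield %masked : f32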
| static void |
| createScoreModificationRegion(PatternRewriter &rewriter, Location loc, |
| IREE::LinalgExt::OnlineAttentionOp onlineAttnOp, |
| FlatSymbolRefAttr scoreModSymbol, |
| FlatSymbolRefAttr maskModSymbol, |
| FloatType floatType, Operation *contextOp) { |
| Region ®ion = onlineAttnOp.getRegion(); |
| OpBuilder::InsertionGuard guard(rewriter); |
| Block *block = |
| rewriter.createBlock(®ion, region.end(), {floatType}, {loc}); |
| |
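  // The block receives the running score as its sole argument and must yield
  // the (possibly modified) score.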
| Value score = block->getArgument(0); |
| bool needIndices = scoreModSymbol || maskModSymbol; |
| |
| // Build index torch tensors: b, h, q_idx, kv_idx (dims 0-3). |
| SmallVector<Value> torchIndices; |
| if (needIndices) { |
| auto signlessI32 = rewriter.getIntegerType(32); |
| auto signedI32 = |
| IntegerType::get(rewriter.getContext(), 32, IntegerType::Signed); |
| auto torchI32Scalar = torch::Torch::ValueTensorType::get( |
| rewriter.getContext(), ArrayRef<int64_t>{}, signedI32); |
| for (int dim : {0, 1, 2, 3}) { |
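      // index -> signless i32 -> tensor<i32> -> !torch.vtensor<[],si32>.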
| Value idx = IREE::LinalgExt::IndexOp::create(rewriter, loc, dim); |
| Value idxI32 = |
| arith::IndexCastOp::create(rewriter, loc, signlessI32, idx); |
| Value idxTensor = tensor::FromElementsOp::create( |
| rewriter, loc, RankedTensorType::get({}, signlessI32), |
| ValueRange{idxI32}); |
| Value torchIdx = torch::TorchConversion::FromBuiltinTensorOp::create( |
| rewriter, loc, torchI32Scalar, idxTensor); |
| torchIndices.push_back(torchIdx); |
| } |
| } |
| |
| // Inline mask_mod: compute mask and apply select(mask, score, -inf). |
| if (maskModSymbol) { |
| Value maskResult = inlineTorchFunction(rewriter, loc, maskModSymbol, |
| torchIndices, contextOp)[0]; |
| |
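    // Extract the rank-0 i1 mask as a scalar usable by arith.select.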
| auto boolType = rewriter.getIntegerType(1); |
| Value builtinBool = torch::TorchConversion::ToBuiltinTensorOp::create( |
| rewriter, loc, RankedTensorType::get({}, boolType), maskResult); |
| Value boolScalar = |
| tensor::ExtractOp::create(rewriter, loc, builtinBool, ValueRange{}); |
| |
| Value negInf = arith::ConstantOp::create( |
| rewriter, loc, |
| rewriter.getFloatAttr( |
| floatType, |
| APFloat::getInf(floatType.getFloatSemantics(), /*Negative=*/true))); |
| score = arith::SelectOp::create(rewriter, loc, boolScalar, score, negInf); |
| } |
| |
| // Inline score_mod: transform the (possibly masked) score. |
| if (scoreModSymbol) { |
| auto f32ScalarTensor = RankedTensorType::get({}, floatType); |
| auto torchF32Scalar = torch::Torch::ValueTensorType::get( |
| rewriter.getContext(), ArrayRef<int64_t>{}, floatType); |
| Value scoreTensor = tensor::FromElementsOp::create( |
| rewriter, loc, f32ScalarTensor, ValueRange{score}); |
| Value torchScore = torch::TorchConversion::FromBuiltinTensorOp::create( |
| rewriter, loc, torchF32Scalar, scoreTensor); |
| |
| SmallVector<Value> scoreArgs = {torchScore}; |
| scoreArgs.append(torchIndices.begin(), torchIndices.end()); |
| Value torchResult = inlineTorchFunction(rewriter, loc, scoreModSymbol, |
| scoreArgs, contextOp)[0]; |
| |
| Value builtinResult = torch::TorchConversion::ToBuiltinTensorOp::create( |
| rewriter, loc, f32ScalarTensor, torchResult); |
| score = |
| tensor::ExtractOp::create(rewriter, loc, builtinResult, ValueRange{}); |
| } |
| |
| IREE::LinalgExt::YieldOp::create(rewriter, loc, score); |
| } |
| |
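/// Lower torch.higher_order.flex_attention to
/// iree_linalg_ext.online_attention.
///
/// score_mod / mask_mod functions are inlined into the op's
/// score-modification region, the unnormalized accumulator is divided by the
/// softmax row sums afterwards, and the optional logsumexp / max_scores
/// results are materialized only when requested.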
| struct FlexAttentionOpConversion |
| : OpRewritePattern<torch::Torch::HigherOrderFlexAttentionOp> { |
| using Base::Base; |
| |
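  // Q, K, and V are rank-4: [B, H, Seq, HeadDim].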
| static constexpr int64_t kAttentionRank = 4; |
| |
| LogicalResult matchAndRewrite(torch::Torch::HigherOrderFlexAttentionOp op, |
| PatternRewriter &rewriter) const override { |
| Location loc = op.getLoc(); |
| |
| Value query = op.getQuery(); |
| Value key = op.getKey(); |
| Value value = op.getValue(); |
| Value scaleVal = op.getScale(); |
| |
| auto scoreModSymbol = op.getScoreModFnAttr(); |
| auto maskModSymbol = op.getMaskModFnAttr(); |
| |
| // Extract return_lse and return_max_scores. |
| bool returnLse, returnMaxScores; |
| if (!matchPattern(op.getReturnLse(), |
| torch::Torch::m_TorchConstantBool(&returnLse))) { |
| return rewriter.notifyMatchFailure( |
| op, "expected return_lse to be a constant bool"); |
| } |
| if (!matchPattern(op.getReturnMaxScores(), |
| torch::Torch::m_TorchConstantBool(&returnMaxScores))) { |
| return rewriter.notifyMatchFailure( |
| op, "expected return_max_scores to be a constant bool"); |
| } |
| |
    // Extract shapes from Q, K, V: rank-4 value tensors [B, H, Seq, Dim].
    auto queryType = cast<torch::Torch::ValueTensorType>(query.getType());
    auto valueType = cast<torch::Torch::ValueTensorType>(value.getType());
    if (!queryType.hasSizes() || !valueType.hasSizes() ||
        queryType.getSizes().size() != size_t(kAttentionRank) ||
        valueType.getSizes().size() != size_t(kAttentionRank)) {
      return rewriter.notifyMatchFailure(
          op, "expected rank-4 Q and V value tensors with known sizes");
    }

    ArrayRef<int64_t> queryShape = queryType.getSizes();
    ArrayRef<int64_t> valueShape = valueType.getSizes();
| |
| // Q: [B, H, M, K1], K: [B, H, N, K1], V: [B, H, N, K2] |
| int64_t batch = queryShape[0]; |
| int64_t numHeads = queryShape[1]; |
| int64_t seqLenQ = queryShape[2]; |
| int64_t valueDim = valueShape[3]; |
| |
| auto floatType = Float32Type::get(rewriter.getContext()); |
| |
| Value builtinQ = convertToBuiltinTensor(rewriter, loc, query); |
| Value builtinK = convertToBuiltinTensor(rewriter, loc, key); |
| Value builtinV = convertToBuiltinTensor(rewriter, loc, value); |
| |
| // Resolve scale: try constant float, else compute rsqrt(headDim). |
| Value scale; |
| double scaleDouble; |
| if (matchPattern(scaleVal, |
| torch::Torch::m_TorchConstantFloat(&scaleDouble))) { |
| scale = arith::ConstantOp::create( |
| rewriter, loc, rewriter.getFloatAttr(floatType, scaleDouble)); |
| } else { |
| int64_t queryRank = queryShape.size(); |
| Value dimIdx = |
| tensor::DimOp::create(rewriter, loc, builtinQ, queryRank - 1); |
| Value dimI64 = arith::IndexCastOp::create(rewriter, loc, |
| rewriter.getI64Type(), dimIdx); |
| Value dimF32 = arith::SIToFPOp::create(rewriter, loc, floatType, dimI64); |
| scale = math::RsqrtOp::create(rewriter, loc, dimF32); |
| } |
| |
| // 6D iteration space: (b, h, m, n, k1, k2) |
| AffineExpr b, h, m, n, k1, k2; |
| bindDims(rewriter.getContext(), b, h, m, n, k1, k2); |
| int64_t numDims = 6; |
| |
| auto getMap = [&](ArrayRef<AffineExpr> results) { |
| return AffineMap::get(numDims, 0, results, rewriter.getContext()); |
| }; |
| |
| AffineMap qMap = getMap({b, h, m, k1}); |
| AffineMap kMap = getMap({b, h, n, k1}); |
| AffineMap vMap = getMap({b, h, n, k2}); |
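    // The scale is a scalar; a zero-result map broadcasts it over all dims.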
| AffineMap scaleMap = AffineMap::get(numDims, 0, rewriter.getContext()); |
| AffineMap outputMap = getMap({b, h, m, k2}); |
| AffineMap maxMap = getMap({b, h, m}); |
| AffineMap sumMap = getMap({b, h, m}); |
| |
| // No mask operand: mask_mod is inlined into the score region. |
| SmallVector<AffineMap> indexingMaps = {qMap, kMap, vMap, scaleMap}; |
| indexingMaps.push_back(outputMap); |
| indexingMaps.push_back(maxMap); |
| indexingMaps.push_back(sumMap); |
| |
| // Create output tensor. |
| auto outputShape = SmallVector<int64_t>{batch, numHeads, seqLenQ, valueDim}; |
| Value outputEmpty = |
| tensor::EmptyOp::create(rewriter, loc, outputShape, floatType); |
| |
| // Create and fill max/sum tensors. |
| auto rowRedShape = SmallVector<int64_t>{batch, numHeads, seqLenQ}; |
| Value rowRedEmpty = |
| tensor::EmptyOp::create(rewriter, loc, rowRedShape, floatType); |
| |
| Value zeroInit = arith::getIdentityValue(arith::AtomicRMWKind::addf, |
| floatType, rewriter, loc); |
| Value maxInit = arith::getIdentityValue(arith::AtomicRMWKind::maximumf, |
| floatType, rewriter, loc, |
| /*useOnlyFiniteValue=*/true); |
| |
| Value accFill = |
| linalg::FillOp::create(rewriter, loc, ValueRange{zeroInit}, outputEmpty) |
| .getResult(0); |
| Value maxFill = |
| linalg::FillOp::create(rewriter, loc, ValueRange{maxInit}, rowRedEmpty) |
| .getResult(0); |
| Value sumFill = |
| linalg::FillOp::create(rewriter, loc, ValueRange{zeroInit}, rowRedEmpty) |
| .getResult(0); |
| |
| // Create OnlineAttentionOp without a mask operand. |
| auto onlineAttnOp = IREE::LinalgExt::OnlineAttentionOp::create( |
| rewriter, loc, |
| TypeRange{accFill.getType(), maxFill.getType(), sumFill.getType()}, |
| builtinQ, builtinK, builtinV, scale, /*mask=*/Value(), accFill, maxFill, |
| sumFill, rewriter.getAffineMapArrayAttr(indexingMaps), |
| /*decomposition_config=*/DictionaryAttr::get(rewriter.getContext())); |
| |
| // Build score modification region with inlined mask_mod and score_mod. |
| createScoreModificationRegion(rewriter, loc, onlineAttnOp, scoreModSymbol, |
| maskModSymbol, floatType, op); |
| |
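    // Results: unnormalized accumulator, running row-wise max, row-wise sum.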
| Value attnResult = onlineAttnOp.getResult(0); |
| Value maxResult = onlineAttnOp.getResult(1); |
| Value sumResult = onlineAttnOp.getResult(2); |
| |
| // Post-process: output = (1/sum) * attnResult |
| SmallVector<AffineMap> postMaps = compressUnusedDims( |
| SmallVector<AffineMap>{sumMap, outputMap, outputMap}); |
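    // Compressing unused dims drops the reduction dims (n, k1), leaving a
    // purely parallel (b, h, m, k2) iteration space.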
| SmallVector<utils::IteratorType> postIterTypes( |
| postMaps[0].getNumDims(), utils::IteratorType::parallel); |
| |
| // Determine the output element type from the torch result type. |
| auto torchOutputType = |
| cast<torch::Torch::ValueTensorType>(op.getOutput().getType()); |
| auto builtinOutputType = torchOutputType.toBuiltinTensor(); |
| Type outputElemType = |
| cast<RankedTensorType>(builtinOutputType).getElementType(); |
| Value outputInit = |
| tensor::EmptyOp::create(rewriter, loc, outputShape, outputElemType); |
| |
| auto normalizeOp = linalg::GenericOp::create( |
| rewriter, loc, outputInit.getType(), ValueRange{sumResult, attnResult}, |
| ValueRange{outputInit}, postMaps, postIterTypes, |
| [&](OpBuilder &b, Location loc, ValueRange args) { |
| Value one = arith::ConstantOp::create( |
| b, loc, b.getFloatAttr(args[0].getType(), 1.0)); |
| Value reciprocal = arith::DivFOp::create(b, loc, one, args[0]); |
| Value result = arith::MulFOp::create(b, loc, reciprocal, args[1]); |
| result = convertScalarToDtype(b, loc, result, args[2].getType(), |
| /*isUnsignedCast=*/false); |
| linalg::YieldOp::create(b, loc, result); |
| }); |
| Value normalizedOutput = normalizeOp.getResult(0); |
| |
| // Convert output back to torch tensor. |
| Value torchOutput = torch::TorchConversion::FromBuiltinTensorOp::create( |
| rewriter, loc, torchOutputType, normalizedOutput); |
| |
    // Handle logsumexp: lse = log(sum) + max, with shape [B, H, M].
| Value logsumexpResult; |
| if (returnLse) { |
| auto lseType = |
| cast<torch::Torch::ValueTensorType>(op.getLogsumexp().getType()); |
| Value lseInit = tensor::EmptyOp::create(rewriter, loc, rowRedShape, |
| lseType.getDtype()); |
| |
| auto identityMap3D = |
| AffineMap::getMultiDimIdentityMap(3, rewriter.getContext()); |
| SmallVector<AffineMap> lseMaps(3, identityMap3D); |
| SmallVector<utils::IteratorType> lseIterTypes( |
| 3, utils::IteratorType::parallel); |
| |
| auto lseOp = linalg::GenericOp::create( |
| rewriter, loc, lseInit.getType(), ValueRange{sumResult, maxResult}, |
| ValueRange{lseInit}, lseMaps, lseIterTypes, |
| [&](OpBuilder &b, Location loc, ValueRange args) { |
| Value logSum = math::LogOp::create(b, loc, args[0]); |
| Value lse = arith::AddFOp::create(b, loc, logSum, args[1]); |
| linalg::YieldOp::create(b, loc, lse); |
| }); |
| logsumexpResult = torch::TorchConversion::FromBuiltinTensorOp::create( |
| rewriter, loc, lseType, lseOp.getResult(0)); |
| } else { |
| logsumexpResult = torch::Torch::ConstantNoneOp::create(rewriter, loc); |
| } |
| |
| // Handle max_scores: directly from max result. |
| Value maxScoresResult; |
| if (returnMaxScores) { |
| auto maxType = |
| cast<torch::Torch::ValueTensorType>(op.getMaxScores().getType()); |
| maxScoresResult = torch::TorchConversion::FromBuiltinTensorOp::create( |
| rewriter, loc, maxType, maxResult); |
| } else { |
| maxScoresResult = torch::Torch::ConstantNoneOp::create(rewriter, loc); |
| } |
| |
| rewriter.replaceOp(op, {torchOutput, logsumexpResult, maxScoresResult}); |
| return success(); |
| } |
| }; |
| |
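/// Converts "unstructured" Torch ops (currently aten.fft_rfft and
/// higher_order.flex_attention) into IREE LinalgExt ops.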
| class ConvertTorchUnstructuredToLinalgExtPass final |
| : public impl::ConvertTorchUnstructuredToLinalgExtPassBase< |
| ConvertTorchUnstructuredToLinalgExtPass> { |
| public: |
| void getDependentDialects(DialectRegistry ®istry) const override { |
| registry.insert<IREE::LinalgExt::IREELinalgExtDialect, |
| torch::Torch::TorchDialect, tensor::TensorDialect, |
| linalg::LinalgDialect, arith::ArithDialect, |
| math::MathDialect, func::FuncDialect, |
| torch::TorchConversion::TorchConversionDialect>(); |
| } |
| void runOnOperation() override { |
| MLIRContext *context = &getContext(); |
| RewritePatternSet patterns(context); |
| |
| patterns.add<FftRfftOpConversion>(context); |
| patterns.add<FlexAttentionOpConversion>(context); |
| |
| if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) { |
| signalPassFailure(); |
| } |
| } |
| }; |
| } // namespace |
| |
| } // namespace mlir::iree_compiler::TorchInput |