| // Copyright 2023 The IREE Authors |
| // |
| // Licensed under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| |
| #include "compiler/plugins/input/TOSA/InputConversion/PassDetail.h" |
| #include "compiler/plugins/input/TOSA/InputConversion/Passes.h" |
| #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h" |
| #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h" |
| #include "mlir/Dialect/Arith/IR/Arith.h" |
| #include "mlir/Dialect/Linalg/IR/Linalg.h" |
| #include "mlir/Dialect/Tensor/IR/Tensor.h" |
| #include "mlir/Dialect/Tosa/IR/TosaOps.h" |
| #include "mlir/Dialect/Tosa/Transforms/Passes.h" |
| #include "mlir/Interfaces/FunctionInterfaces.h" |
| #include "mlir/Transforms/DialectConversion.h" |
| |
| using namespace mlir; |
| using namespace mlir::tosa; |
| |
| namespace mlir::iree_compiler { |
| |
| // Converts tosa.scatter to the iree_linalg_ext.scatter operation. As the |
| // LinalgExt version is not batched therefore we materialize the batch index |
| // for each update. |
| class ScatterConversion : public OpRewritePattern<tosa::ScatterOp> { |
| public: |
| using OpRewritePattern<tosa::ScatterOp>::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(tosa::ScatterOp op, |
| PatternRewriter &rewriter) const final { |
| auto values = op.getValuesIn(); |
| auto indices = llvm::cast<Value>(op.getIndices()); |
| auto updates = llvm::cast<Value>(op.getInput()); |
| auto valuesTy = llvm::dyn_cast<RankedTensorType>(values.getType()); |
| auto indicesTy = llvm::dyn_cast<RankedTensorType>(indices.getType()); |
| auto updatesTy = llvm::dyn_cast<RankedTensorType>(updates.getType()); |
| ImplicitLocOpBuilder builder(op.getLoc(), rewriter); |
| |
| if (!valuesTy || !indicesTy || !updatesTy) |
| return rewriter.notifyMatchFailure(op, |
| "tosa.gather has unknown input rank"); |
| |
| // TOSA's scatter does not include a index dimension, instead it implicitly |
| // supports an index depth of one. We materialize that implicit index of |
| // one as follows: [batch, updates] -> [batch, updates, index_depth=1] With |
| // a indexing map of [[0], [1, 2]]. |
| llvm::SmallVector<int64_t> expandIndShape{indicesTy.getDimSize(0), |
| indicesTy.getDimSize(1), 1}; |
| SmallVector<ReassociationExprs> expandIndMap; |
| expandIndMap.push_back({ |
| builder.getAffineDimExpr(0), |
| }); |
| expandIndMap.push_back({ |
| builder.getAffineDimExpr(1), |
| builder.getAffineDimExpr(2), |
| }); |
| |
| indices = builder.create<tensor::ExpandShapeOp>( |
| indicesTy.clone(expandIndShape), indices, expandIndMap); |
| indicesTy = llvm::dyn_cast<RankedTensorType>(indices.getType()); |
| |
| // Materialize the batch indice as LinalgExt scatter is not batched. |
| { |
| llvm::SmallVector<Value> dynDims; |
| for (int i = 0, s = indicesTy.getRank(); i < s; ++i) |
| if (indicesTy.isDynamicDim(i)) |
| dynDims.push_back(builder.create<tensor::DimOp>(indices, i)); |
| |
| Value empty = builder.create<tensor::EmptyOp>( |
| indicesTy.getShape(), indicesTy.getElementType(), dynDims); |
| |
| Value batchIdx = nullptr; |
| |
| if (indicesTy.getDimSize(0) == 1) { |
| Value zero = builder.create<arith::ConstantOp>( |
| rewriter.getZeroAttr(indicesTy.getElementType())); |
| batchIdx = builder.create<linalg::FillOp>(zero, empty).getResult(0); |
| } else { |
| SmallVector<utils::IteratorType> iterators( |
| indicesTy.getRank(), utils::IteratorType::parallel); |
| SmallVector<AffineMap, 3> indexingMaps( |
| 2, builder.getMultiDimIdentityMap(indicesTy.getRank())); |
| |
| auto blockBuilder = [&](OpBuilder &nestedBuilder, Location nestedLoc, |
| ValueRange blockArgs) { |
| ImplicitLocOpBuilder b(op.getLoc(), nestedBuilder); |
| auto index = b.create<linalg::IndexOp>(0); |
| auto cast = |
| b.create<arith::IndexCastOp>(indicesTy.getElementType(), index); |
| b.create<linalg::YieldOp>(cast.getResult()); |
| }; |
| batchIdx = builder |
| .create<linalg::GenericOp>(indicesTy, indices, empty, |
| indexingMaps, iterators, |
| blockBuilder) |
| .getResult(0); |
| } |
| |
| indicesTy = llvm::cast<RankedTensorType>(indicesTy.clone( |
| {indicesTy.getDimSize(0), indicesTy.getDimSize(1), 2})); |
| indices = builder.create<tosa::ConcatOp>(indicesTy, |
| ValueRange{batchIdx, indices}, |
| rewriter.getI32IntegerAttr(2)); |
| } |
| |
| auto collapseBatch = [](Value value, ImplicitLocOpBuilder &b) -> Value { |
| auto valueTy = llvm::cast<ShapedType>(value.getType()); |
| llvm::SmallVector<int64_t> collapseShape(valueTy.getShape().drop_front()); |
| llvm::SmallVector<ReassociationExprs> collapseMap(valueTy.getRank() - 1); |
| collapseMap.front().push_back(b.getAffineDimExpr(0)); |
| for (int i = 0, s = collapseMap.size(); i < s; ++i) { |
| collapseMap[i].push_back(b.getAffineDimExpr(i + 1)); |
| } |
| |
| int64_t batch = valueTy.getShape().front(); |
| int64_t rows = collapseShape.front(); |
| bool batchDyn = ShapedType::isDynamic(batch); |
| bool rowsDyn = ShapedType::isDynamic(rows); |
| collapseShape[0] = |
| (batchDyn || rowsDyn) ? ShapedType::kDynamic : batch * rows; |
| |
| return b.create<tensor::CollapseShapeOp>(valueTy.clone(collapseShape), |
| value, collapseMap); |
| }; |
| |
| indices = collapseBatch(indices, builder); |
| updates = collapseBatch(updates, builder); |
| |
| // Create the LinalgExt scatter operation. |
| auto scatter = builder.create<IREE::LinalgExt::ScatterOp>( |
| TypeRange{values.getType()}, ValueRange{updates, indices}, |
| ValueRange{values}, builder.getDenseI64ArrayAttr({0, 1}), |
| builder.getBoolAttr(true)); |
| |
| llvm::SmallVector<Type> args(2, valuesTy.getElementType()); |
| Block *scatterBody = |
| builder.createBlock(&scatter.getRegion(), {}, args, |
| llvm::SmallVector<Location>(2, op.getLoc())); |
| builder.setInsertionPointToStart(scatterBody); |
| builder.create<IREE::LinalgExt::YieldOp>(scatterBody->getArgument(0)); |
| rewriter.replaceOp(op, scatter.getResult(0)); |
| return success(); |
| } |
| }; |
| |
| struct TosaToLinalgExtPass : public TosaToLinalgExtBase<TosaToLinalgExtPass> { |
| void runOnOperation() override { |
| RewritePatternSet patterns(&getContext()); |
| ConversionTarget target(getContext()); |
| target.addIllegalOp<tosa::ScatterOp>(); |
| target.markUnknownOpDynamicallyLegal([](Operation *) { return true; }); |
| |
| FunctionOpInterface func = getOperation(); |
| mlir::iree_compiler::populateTosaToLinalgExtPatterns(&patterns); |
| if (failed(applyFullConversion(func, target, std::move(patterns)))) |
| signalPassFailure(); |
| } |
| }; |
| |
| void populateTosaToLinalgExtPatterns(RewritePatternSet *patterns) { |
| patterns->add<ScatterConversion>(patterns->getContext()); |
| } |
| |
| std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>> |
| createTosaToLinalgExt() { |
| return std::make_unique<TosaToLinalgExtPass>(); |
| } |
| |
| } // namespace mlir::iree_compiler |