blob: 600b002b9adddca804a59e47256818239ac15c2f [file] [log] [blame]
// Copyright 2023 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "compiler/plugins/input/TOSA/InputConversion/PassDetail.h"
#include "compiler/plugins/input/TOSA/InputConversion/Passes.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Transforms/DialectConversion.h"
using namespace mlir;
using namespace mlir::tosa;
namespace mlir::iree_compiler {
// Converts tosa.scatter to the iree_linalg_ext.scatter operation.
//
// iree_linalg_ext.scatter is not batched, so the TOSA batch dimension is
// lowered away in two steps:
//   1. A batch-id column is materialized and prepended to the index tensor,
//      turning the implicit index depth of 1 into an explicit depth of 2
//      ([batch_id, index]).
//   2. The batch and update-count dimensions of both the indices and the
//      updates are collapsed into a single leading dimension.
class ScatterConversion : public OpRewritePattern<tosa::ScatterOp> {
public:
  using OpRewritePattern<tosa::ScatterOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ScatterOp op,
                                PatternRewriter &rewriter) const final {
    auto values = op.getValuesIn();
    auto indices = llvm::cast<Value>(op.getIndices());
    auto updates = llvm::cast<Value>(op.getInput());
    auto valuesTy = llvm::dyn_cast<RankedTensorType>(values.getType());
    auto indicesTy = llvm::dyn_cast<RankedTensorType>(indices.getType());
    auto updatesTy = llvm::dyn_cast<RankedTensorType>(updates.getType());
    ImplicitLocOpBuilder builder(op.getLoc(), rewriter);

    // All three operands must be ranked for the reshapes below to be valid.
    // NOTE: the message previously said "tosa.gather"; this pattern matches
    // tosa.scatter.
    if (!valuesTy || !indicesTy || !updatesTy)
      return rewriter.notifyMatchFailure(op,
                                         "tosa.scatter has unknown input rank");

    // TOSA's scatter does not include an index dimension, instead it
    // implicitly supports an index depth of one. We materialize that implicit
    // index of one as follows: [batch, updates] -> [batch, updates,
    // index_depth=1] with a reassociation map of [[0], [1, 2]].
    llvm::SmallVector<int64_t> expandIndShape{indicesTy.getDimSize(0),
                                              indicesTy.getDimSize(1), 1};
    SmallVector<ReassociationExprs> expandIndMap;
    expandIndMap.push_back({
        builder.getAffineDimExpr(0),
    });
    expandIndMap.push_back({
        builder.getAffineDimExpr(1),
        builder.getAffineDimExpr(2),
    });

    indices = builder.create<tensor::ExpandShapeOp>(
        indicesTy.clone(expandIndShape), indices, expandIndMap);
    indicesTy = llvm::dyn_cast<RankedTensorType>(indices.getType());

    // Materialize the batch index as LinalgExt scatter is not batched: build
    // a tensor of the same shape as `indices` whose value at [b, u, 0] is b,
    // then concatenate it in front of the original index column.
    {
      llvm::SmallVector<Value> dynDims;
      for (int i = 0, s = indicesTy.getRank(); i < s; ++i)
        if (indicesTy.isDynamicDim(i))
          dynDims.push_back(builder.create<tensor::DimOp>(indices, i));

      Value empty = builder.create<tensor::EmptyOp>(
          indicesTy.getShape(), indicesTy.getElementType(), dynDims);

      Value batchIdx = nullptr;

      if (indicesTy.getDimSize(0) == 1) {
        // With a static batch of 1 every batch id is 0, so a fill suffices.
        Value zero = builder.create<arith::ConstantOp>(
            rewriter.getZeroAttr(indicesTy.getElementType()));
        batchIdx = builder.create<linalg::FillOp>(zero, empty).getResult(0);
      } else {
        // Otherwise emit a linalg.generic whose body yields the batch
        // (dimension-0) iteration index, cast to the index element type.
        SmallVector<utils::IteratorType> iterators(
            indicesTy.getRank(), utils::IteratorType::parallel);
        SmallVector<AffineMap, 3> indexingMaps(
            2, builder.getMultiDimIdentityMap(indicesTy.getRank()));

        auto blockBuilder = [&](OpBuilder &nestedBuilder, Location nestedLoc,
                                ValueRange blockArgs) {
          ImplicitLocOpBuilder b(op.getLoc(), nestedBuilder);
          auto index = b.create<linalg::IndexOp>(0);
          auto cast =
              b.create<arith::IndexCastOp>(indicesTy.getElementType(), index);
          b.create<linalg::YieldOp>(cast.getResult());
        };
        batchIdx = builder
                       .create<linalg::GenericOp>(indicesTy, indices, empty,
                                                  indexingMaps, iterators,
                                                  blockBuilder)
                       .getResult(0);
      }

      // Index depth grows from 1 to 2: [batch_id, original_index].
      indicesTy = llvm::cast<RankedTensorType>(indicesTy.clone(
          {indicesTy.getDimSize(0), indicesTy.getDimSize(1), 2}));
      indices = builder.create<tosa::ConcatOp>(indicesTy,
                                               ValueRange{batchIdx, indices},
                                               rewriter.getI32IntegerAttr(2));
    }

    // Collapses the leading (batch) and second (update-count) dimensions of
    // `value` into one. The collapsed extent is dynamic if either input
    // extent is dynamic.
    auto collapseBatch = [](Value value, ImplicitLocOpBuilder &b) -> Value {
      auto valueTy = llvm::cast<ShapedType>(value.getType());

      llvm::SmallVector<int64_t> collapseShape(valueTy.getShape().drop_front());
      llvm::SmallVector<ReassociationExprs> collapseMap(valueTy.getRank() - 1);
      collapseMap.front().push_back(b.getAffineDimExpr(0));
      for (int i = 0, s = collapseMap.size(); i < s; ++i) {
        collapseMap[i].push_back(b.getAffineDimExpr(i + 1));
      }

      int64_t batch = valueTy.getShape().front();
      int64_t rows = collapseShape.front();
      bool batchDyn = ShapedType::isDynamic(batch);
      bool rowsDyn = ShapedType::isDynamic(rows);
      collapseShape[0] =
          (batchDyn || rowsDyn) ? ShapedType::kDynamic : batch * rows;

      return b.create<tensor::CollapseShapeOp>(valueTy.clone(collapseShape),
                                               value, collapseMap);
    };

    indices = collapseBatch(indices, builder);
    updates = collapseBatch(updates, builder);

    // Create the LinalgExt scatter operation. The {0, 1} array maps both
    // index-depth columns to the first two dimensions of `values`.
    auto scatter = builder.create<IREE::LinalgExt::ScatterOp>(
        TypeRange{values.getType()}, ValueRange{updates, indices},
        ValueRange{values}, builder.getDenseI64ArrayAttr({0, 1}),
        builder.getBoolAttr(true));

    // The combiner region simply overwrites the destination element with the
    // update (argument 0).
    llvm::SmallVector<Type> args(2, valuesTy.getElementType());
    Block *scatterBody =
        builder.createBlock(&scatter.getRegion(), {}, args,
                            llvm::SmallVector<Location>(2, op.getLoc()));
    builder.setInsertionPointToStart(scatterBody);
    builder.create<IREE::LinalgExt::YieldOp>(scatterBody->getArgument(0));
    rewriter.replaceOp(op, scatter.getResult(0));
    return success();
  }
};
// Function-level pass that converts the TOSA ops handled in this file
// (currently tosa.scatter) into their LinalgExt equivalents.
struct TosaToLinalgExtPass : public TosaToLinalgExtBase<TosaToLinalgExtPass> {
  void runOnOperation() override {
    MLIRContext *context = &getContext();

    // Gather the conversion patterns defined by this file.
    RewritePatternSet patternSet(context);
    mlir::iree_compiler::populateTosaToLinalgExtPatterns(&patternSet);

    // Everything is legal except the TOSA ops we intend to lower away.
    ConversionTarget conversionTarget(*context);
    conversionTarget.addIllegalOp<tosa::ScatterOp>();
    conversionTarget.markUnknownOpDynamicallyLegal(
        [](Operation *) { return true; });

    FunctionOpInterface funcOp = getOperation();
    if (failed(applyFullConversion(funcOp, conversionTarget,
                                   std::move(patternSet))))
      signalPassFailure();
  }
};
// Registers the TOSA -> LinalgExt lowering patterns into `patterns`.
void populateTosaToLinalgExtPatterns(RewritePatternSet *patterns) {
  MLIRContext *context = patterns->getContext();
  patterns->add<ScatterConversion>(context);
}
// Factory for the TOSA-to-LinalgExt conversion pass.
std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
createTosaToLinalgExt() {
  auto pass = std::make_unique<TosaToLinalgExtPass>();
  return pass; // implicit move + derived-to-base unique_ptr conversion
}
} // namespace mlir::iree_compiler