blob: 38186579b4db290a7c702bd58cdb619fc0430ed1 [file] [log] [blame]
// Copyright 2024 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree/compiler/ExternalInterfaces/LinalgExtExternalModels.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinAttributes.h"
namespace mlir::iree_compiler {
namespace {
// Used to register the LinalgFusionOpInterface with the linalg ops.
template <typename ConcreteType>
struct LinalgFusionOpInterfaceAdapter
: public IREE::LinalgExt::LinalgFusionOpInterface::ExternalModel<
LinalgFusionOpInterfaceAdapter<ConcreteType>, ConcreteType> {
public:
SmallVector<AffineMap> getIndexingMapsForOperands(mlir::Operation *op) const {
auto maps = llvm::cast<ConcreteType>(op)
.getIndexingMaps()
.template getAsValueRange<AffineMapAttr>();
return {maps.begin(),
maps.end() - llvm::cast<ConcreteType>(op).getNumResults()};
}
SmallVector<AffineMap> getIndexingMapsForResults(mlir::Operation *op) const {
auto maps = llvm::cast<ConcreteType>(op)
.getIndexingMaps()
.template getAsValueRange<AffineMapAttr>();
return {maps.end() - llvm::cast<ConcreteType>(op).getNumResults(),
maps.end()};
}
// Forward all the interface methods to the corresponding linalg op.
unsigned getNumParallelLoops(mlir::Operation *op) const {
return (llvm::cast<ConcreteType>(op).getNumParallelLoops());
}
unsigned getNumLoops(mlir::Operation *op) const {
return (llvm::cast<ConcreteType>(op).getNumLoops());
}
FailureOr<SmallVector<int64_t>>
getStaticLoopRanges(mlir::Operation *op) const {
return SmallVector<int64_t>(
llvm::cast<ConcreteType>(op).getStaticLoopRanges());
}
AffineMap getIndexingMapMatchingResult(mlir::Operation *op,
OpResult result) const {
return (llvm::cast<ConcreteType>(op).getIndexingMapMatchingResult(result));
}
AffineMap getMatchingIndexingMap(mlir::Operation *op,
OpOperand *operand) const {
return (llvm::cast<ConcreteType>(op).getMatchingIndexingMap(operand));
}
SmallVector<AffineMap> getIndexingMapsArray(mlir::Operation *op) const {
auto inputMaps = getIndexingMapsForOperands(op);
llvm::append_range(inputMaps, getIndexingMapsForResults(op));
return inputMaps;
}
};
struct SoftmaxFusionOpInterfaceAdapter
: public IREE::LinalgExt::LinalgFusionOpInterface::ExternalModel<
SoftmaxFusionOpInterfaceAdapter, linalg::SoftmaxOp> {
public:
SmallVector<AffineMap> getIndexingMapsForOperands(mlir::Operation *op) const {
Builder b(op->getContext());
return llvm::to_vector(llvm::map_range(
llvm::cast<linalg::SoftmaxOp>(op).getDpsInputs(),
[&b](Value operand) -> AffineMap {
auto rank = cast<ShapedType>(operand.getType()).getRank();
return b.getMultiDimIdentityMap(rank);
}));
}
SmallVector<AffineMap> getIndexingMapsForResults(mlir::Operation *op) const {
Builder b(op->getContext());
return llvm::to_vector(llvm::map_range(
llvm::cast<linalg::SoftmaxOp>(op).getDpsInits(),
[&b](Value operand) -> AffineMap {
auto rank = cast<ShapedType>(operand.getType()).getRank();
return b.getMultiDimIdentityMap(rank);
}));
}
FailureOr<SmallVector<int64_t>> getStaticLoopRanges(Operation *op) const {
auto softmaxOp = cast<linalg::SoftmaxOp>(op);
// Softmax loop range is the input shape.
return SmallVector<int64_t>(softmaxOp.getInputOperandType().getShape());
}
AffineMap getIndexingMapMatchingResult(mlir::Operation *op,
OpResult result) const {
return getIndexingMapsForResults(op)[result.getResultNumber()];
}
AffineMap getMatchingIndexingMap(mlir::Operation *op,
OpOperand *operand) const {
return getIndexingMapsForOperands(op)[operand->getOperandNumber()];
}
SmallVector<AffineMap> getIndexingMapsArray(mlir::Operation *op) const {
auto inputMaps = getIndexingMapsForOperands(op);
llvm::append_range(inputMaps, getIndexingMapsForResults(op));
return inputMaps;
}
};
template <typename... Args>
void registerOpsWithLinalgExtOpInterface(mlir::MLIRContext *context) {
(Args::template attachInterface<LinalgFusionOpInterfaceAdapter<Args>>(
*context),
...);
}
} // namespace
void registerLinalgExtExternalModels(DialectRegistry &registry) {
registry.addExtension(+[](MLIRContext *ctx,
IREE::LinalgExt::IREELinalgExtDialect *dialect) {
ctx->loadDialect<mlir::linalg::LinalgDialect>();
#define GET_OP_LIST
registerOpsWithLinalgExtOpInterface<
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
>(ctx);
linalg::SoftmaxOp::attachInterface<SoftmaxFusionOpInterfaceAdapter>(*ctx);
});
}
} // namespace mlir::iree_compiler