blob: 68c60b6bd1c751a53f7a00a7134ea58a6c780e9f [file] [log] [blame]
// Copyright 2019 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree/compiler/Dialect/HAL/Conversion/ConversionTarget.h"
#include "iree/compiler/Dialect/HAL/Conversion/TypeConverter.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "iree/compiler/Dialect/HAL/Utils/TypeUtils.h"
#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
namespace mlir {
namespace iree_compiler {
/// Builds the legality rules used when converting into the HAL dialect.
///
/// `typeConverter` is captured by reference in the dynamic-legality callbacks
/// registered below, so it must outlive this conversion target.
HALConversionTarget::HALConversionTarget(MLIRContext *context,
                                         TypeConverter &typeConverter)
    : ConversionTarget(*context) {
  // Input may be partially processed already (for example constant pools may
  // have been lowered), so existing hal ops are accepted as-is.
  addLegalDialect("hal");

  // Executable bodies are opaque to this conversion: whatever dialects and
  // types they contain are left alone.
  addLegalOp<IREE::HAL::ExecutableOp>();
  markOpRecursivelyLegal<IREE::HAL::ExecutableOp>();

  // memref.dim, tensor.dim and std.rank are all rewritten into HAL ops by
  // dedicated patterns; none of them may survive the conversion.
  addIllegalOp<memref::DimOp, mlir::RankOp, tensor::DimOp>();

  // A tie_shape op is only acceptable once its result type is legal.
  addDynamicallyLegalOp<Shape::TieShapeOp>(
      [&typeConverter](Shape::TieShapeOp tieShapeOp) {
        return typeConverter.isLegal(tieShapeOp.result().getType());
      });

  // Fallback for every op without an explicit rule: it is legal only when
  // none of its operand or result types is rejected by the type converter.
  // Operands are inspected before results and scanning stops at the first
  // offending type.
  markUnknownOpDynamicallyLegal([&typeConverter](Operation *unknownOp) {
    auto rejectsType = [&typeConverter](Type candidate) {
      return !typeConverter.isLegal(candidate);
    };
    if (llvm::any_of(unknownOp->getOperandTypes(), rejectsType)) {
      return false;
    }
    return !llvm::any_of(unknownOp->getResultTypes(), rejectsType);
  });
}
// static
/// Replaces `srcOp` with a newly created op named `dstOpName` that carries the
/// same attributes.
///
/// Operands and results whose source type satisfies
/// `HALTypeConverter::shouldConvertToBuffer` are exchanged as
/// `!hal.buffer_view` values:
///   - such operands are wrapped through `TensorRewriteAdaptor` and passed as
///     their buffer view;
///   - such results are produced as buffer views and immediately unpacked to
///     `!hal.buffer` via `hal.buffer_view.buffer` before replacing `srcOp`.
/// All other operands are forwarded from `operands` (the already-converted
/// values). All other result types are converted with `typeConverter` and must
/// map to exactly one type each.
///
/// Returns failure when the operand count does not match or a result type
/// cannot be converted 1:1. Emits an op error and returns failure when an
/// operand adaptor or buffer view cannot be created.
LogicalResult HALConversionTarget::applyDefaultBufferRewrite(
    Operation *srcOp, ArrayRef<Value> operands, StringRef dstOpName,
    TypeConverter &typeConverter, ConversionPatternRewriter &rewriter) {
  // llvm::zip below silently truncates to the shorter range; refuse to build
  // an op with a dropped operand.
  if (srcOp->getNumOperands() != operands.size()) {
    return failure();
  }

  OperationState state{srcOp->getLoc(), dstOpName};
  state.addAttributes(srcOp->getAttrs());

  // Resolve result types first. This creates no IR, so bailing out here leaves
  // nothing behind.
  for (auto resultType : srcOp->getResultTypes()) {
    if (HALTypeConverter::shouldConvertToBuffer(resultType)) {
      state.addTypes(IREE::HAL::BufferViewType::get(rewriter.getContext()));
      continue;
    }
    // Normal pass-through result. The unpacking loop further down pairs each
    // source result with exactly one destination result. A conversion that
    // yields zero or several types would misalign every following result and
    // hand replaceOp the wrong number of values, so reject it here.
    SmallVector<Type, 1> convertedTypes;
    if (failed(typeConverter.convertType(resultType, convertedTypes)) ||
        convertedTypes.size() != 1) {
      return failure();
    }
    state.addTypes(convertedTypes.front());
  }

  for (auto srcDstOperand : llvm::zip(srcOp->getOperands(), operands)) {
    auto srcOperand = std::get<0>(srcDstOperand);
    auto dstOperand = std::get<1>(srcDstOperand);
    if (!HALTypeConverter::shouldConvertToBuffer(srcOperand.getType())) {
      // Normal pass-through operand.
      state.addOperands({dstOperand});
      continue;
    }
    // Create the buffer view that we'll pass to the function.
    // Note that we expect this to be CSE'd if there are multiple calls
    // using the same buffer.
    auto operand = IREE::HAL::TensorRewriteAdaptor::getChecked(
        srcOp->getLoc(), srcOperand, dstOperand, rewriter);
    if (!operand.hasValue()) {
      return srcOp->emitOpError() << "unable to create adaptor for operand";
    }
    auto bufferView = operand->getBufferView();
    if (!bufferView) {
      return srcOp->emitOpError() << "unable to get buffer view for operand";
    }
    state.addOperands({bufferView});
  }

  auto *dstOp = rewriter.createOperation(state);

  // Now unpack any of the buffer views we may have returned. The result-type
  // loop above guarantees dstOp has exactly one result per srcOp result.
  SmallVector<Value, 4> results;
  results.reserve(dstOp->getNumResults());
  for (auto resultTypeValue :
       llvm::zip(srcOp->getResultTypes(), dstOp->getResults())) {
    Type resultType;
    Value resultValue;
    std::tie(resultType, resultValue) = resultTypeValue;
    if (HALTypeConverter::shouldConvertToBuffer(resultType)) {
      results.push_back(rewriter.createOrFold<IREE::HAL::BufferViewBufferOp>(
          srcOp->getLoc(), IREE::HAL::BufferType::get(rewriter.getContext()),
          resultValue));
    } else {
      results.push_back(resultValue);
    }
  }
  rewriter.replaceOp(srcOp, results);
  return success();
}
} // namespace iree_compiler
} // namespace mlir