blob: 8ecdf2c71c9dba1fb9e64ca686c3a744ac593e81 [file] [log] [blame]
// Copyright 2021 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include <utility>
#include "iree/compiler/Dialect/Stream/IR/StreamDialect.h"
#include "iree/compiler/Dialect/Stream/IR/StreamOps.h"
#include "iree/compiler/Dialect/Stream/IR/StreamTraits.h"
#include "iree/compiler/Dialect/Stream/Transforms/PassDetail.h"
#include "iree/compiler/Dialect/Stream/Transforms/Passes.h"
#include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
#include "llvm/ADT/TypeSwitch.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/Pass/Pass.h"
namespace mlir {
namespace iree_compiler {
namespace IREE {
namespace Stream {
//===----------------------------------------------------------------------===//
// Base pass utility
//===----------------------------------------------------------------------===//
/// Walks an operation tree and reports ops and operand/result types that are
/// not legal for the current phase of stream dialect lowering.
///
/// Op legality is decided by the first rule that matches:
///   1. explicit per-op registrations (addLegalOp / addRecursivelyLegalOp /
///      addIllegalOp),
///   2. registered op verifiers, consulted in registration order,
///   3. per-dialect registrations (addIllegalDialect),
///   4. otherwise the op is legal.
/// Type legality is decided by the type verifier registered for the exact
/// type (by TypeID) if any, else by addIllegalType registrations, else the
/// type is legal.
///
/// Registrations use DenseMap::insert, so registering the same op, dialect or
/// type a second time keeps the FIRST registration.
class Verifier {
 public:
  /// Legality of an operation or type for the current lowering phase.
  enum class Legality {
    /// The op itself is legal; its operand/result types and nested ops are
    /// still checked.
    LEGAL,
    /// The op and everything nested within it are treated as legal; neither
    /// its operand/result types nor its nested ops are checked.
    RECURSIVELY_LEGAL,
    /// The op is illegal; an error is emitted and its nested ops are not
    /// visited.
    ILLEGAL,
  };

  /// Returns a legality decision for |op|, or None to defer to later rules.
  using OpVerifierFn = std::function<Optional<Legality>(Operation *op)>;
  /// Returns the legality of a type; any result other than ILLEGAL is legal.
  using TypeVerifierFn = std::function<Legality(Type type)>;

  /// Marks every op in the dialect with namespace |dialectName| illegal unless
  /// an op-specific registration or op verifier decides otherwise.
  /// The name is stored as a StringRef, so its backing storage must outlive
  /// this verifier.
  void addIllegalDialect(StringRef dialectName) {
    dialectLegality.insert({dialectName, Legality::ILLEGAL});
  }
  template <typename DialectT>
  void addIllegalDialect() {
    addIllegalDialect(DialectT::getDialectNamespace());
  }

  /// Marks |OpT| legal (its operand/result types and nested ops are still
  /// checked).
  template <typename OpT>
  void addLegalOp() {
    opLegality.insert({OpT::getOperationName(), Legality::LEGAL});
  }

  /// Marks |OpT| and everything nested within it as legal without checking.
  template <typename OpT>
  void addRecursivelyLegalOp() {
    opLegality.insert({OpT::getOperationName(), Legality::RECURSIVELY_LEGAL});
  }

  /// Marks |OpT| illegal regardless of dialect or op verifier settings.
  template <typename OpT>
  void addIllegalOp() {
    opLegality.insert({OpT::getOperationName(), Legality::ILLEGAL});
  }

  /// Registers a verifier consulted for ops without a per-op registration.
  /// Verifiers run in registration order; the first one returning a value
  /// decides the op's legality (later verifiers are not invoked).
  void addOpVerifier(std::function<Optional<Legality>(Operation *)> fn) {
    opVerifiers.push_back(std::move(fn));
  }

  /// Like addOpVerifier but |fn| is only invoked for ops that dyn_cast to
  /// |OpT| (an op class or op interface); all other ops yield None.
  template <typename OpT>
  void addOpVerifier(std::function<Optional<Legality>(OpT)> fn) {
    opVerifiers.push_back(
        [fn = std::move(fn)](Operation *baseOp) -> Optional<Legality> {
          if (auto op = dyn_cast<OpT>(baseOp)) {
            return fn(op);
          }
          return llvm::None;
        });
  }

  /// Marks all types of exactly |TypeT| (matched by TypeID) illegal.
  template <typename TypeT>
  void addIllegalType() {
    typeLegality.insert({TypeID::get<TypeT>(), Legality::ILLEGAL});
  }

  /// Registers |fn| as the legality verifier for types of exactly |TypeT|
  /// (matched by TypeID). Type verifiers take precedence over addIllegalType.
  /// Registering a second verifier for the same type is a programming error.
  template <typename TypeT>
  void addTypeVerifier(std::function<Legality(TypeT)> fn) {
    auto wrapperFn = [fn = std::move(fn)](Type baseType) {
      return fn(baseType.cast<TypeT>());
    };
    bool inserted =
        typeVerifiers.insert({TypeID::get<TypeT>(), std::move(wrapperFn)})
            .second;
    if (!inserted) {
      llvm_unreachable("already registered for this type");
    }
  }

  /// Walks |rootOp| in pre-order and emits an error for every illegal op and
  /// every illegal operand/result type of the ops that are visited.
  /// Returns failure if anything illegal was found.
  LogicalResult run(Operation *rootOp) {
    bool foundAnyIllegal = false;
    rootOp->walk<WalkOrder::PreOrder>([&](Operation *op) {
      // Check for op legality - can skip the expensive work if known-illegal.
      switch (getOpLegality(op)) {
        case Legality::LEGAL:
          // Op itself is legal but may not have valid operands/results.
          break;
        case Legality::RECURSIVELY_LEGAL:
          // If the entire op w/ nested ops is legal then skip.
          return WalkResult::skip();
        default:
        case Legality::ILLEGAL:
          // Early-exit on illegal ops without recursing.
          emitIllegalOpError(op);
          foundAnyIllegal = true;
          return WalkResult::skip();
      }

      // Check types for operands/results; every illegal one is reported.
      for (auto operandType : llvm::enumerate(op->getOperandTypes())) {
        if (isTypeLegal(operandType.value())) continue;
        emitIllegalTypeError(op, "operand", operandType.index(),
                             operandType.value());
        foundAnyIllegal = true;
      }
      for (auto resultType : llvm::enumerate(op->getResultTypes())) {
        if (isTypeLegal(resultType.value())) continue;
        emitIllegalTypeError(op, "result", resultType.index(),
                             resultType.value());
        foundAnyIllegal = true;
      }
      return WalkResult::advance();
    });
    return success(!foundAnyIllegal);
  }

 private:
  /// Resolves the legality of |op| using the precedence documented on the
  /// class.
  Legality getOpLegality(Operation *op) const {
    auto opName = op->getName();

    // Check specific ops first (we may override dialect settings).
    {
      auto legalityIt = opLegality.find(opName.getStringRef());
      if (legalityIt != opLegality.end()) {
        return legalityIt->second;
      }
    }

    // Check all op verifiers (usually used for interface checks).
    for (const auto &opVerifier : opVerifiers) {
      if (auto legalOr = opVerifier(op)) {
        return *legalOr;
      }
    }

    // If no op carveout is applied then check to see if the dialect is
    // allowed at all.
    {
      auto legalityIt = dialectLegality.find(opName.getDialectNamespace());
      if (legalityIt != dialectLegality.end()) {
        return legalityIt->second;
      }
    }

    // Assume legal by default.
    return Legality::LEGAL;
  }

  /// Returns true unless a type verifier or addIllegalType registration marks
  /// |type| ILLEGAL.
  bool isTypeLegal(Type type) const {
    // TODO(benvanik): subelements interface checks using recursive legality.

    // Defer to verifiers first.
    auto it = typeVerifiers.find(type.getTypeID());
    if (it != typeVerifiers.end()) {
      return it->second(type) != Legality::ILLEGAL;
    }

    // Check legality of the base type.
    {
      auto legalityIt = typeLegality.find(type.getTypeID());
      if (legalityIt != typeLegality.end()) {
        return legalityIt->second != Legality::ILLEGAL;
      }
    }

    // Assume legal by default.
    return true;
  }

  static void emitIllegalOpError(Operation *op) {
    op->emitOpError()
        << "illegal for this phase of lowering in the stream dialect; "
           "expected to have been converted or removed";
  }

  static void emitIllegalTypeError(Operation *op, StringRef location,
                                   unsigned idx, Type type) {
    op->emitOpError()
        << location << " " << idx << " type " << type
        << " illegal for this phase of lowering in the stream dialect";
  }

  DenseMap<StringRef, Legality> dialectLegality;
  DenseMap<StringRef, Legality> opLegality;
  SmallVector<OpVerifierFn> opVerifiers;
  DenseMap<TypeID, Legality> typeLegality;
  DenseMap<TypeID, TypeVerifierFn> typeVerifiers;
};
static void markStreamTensorOpsIllegal(Verifier &verifier) {
verifier.addOpVerifier([](Operation *op) -> Optional<Verifier::Legality> {
if (op->hasTrait<OpTrait::IREE::Stream::TensorPhaseOp>()) {
return Verifier::Legality::ILLEGAL;
}
return llvm::None;
});
}
static void markStreamAsyncOpsIllegal(Verifier &verifier) {
verifier.addOpVerifier([](Operation *op) -> Optional<Verifier::Legality> {
if (op->hasTrait<OpTrait::IREE::Stream::AsyncPhaseOp>()) {
return Verifier::Legality::ILLEGAL;
}
return llvm::None;
});
}
static void markStreamCmdOpsIllegal(Verifier &verifier) {
verifier.addOpVerifier([](Operation *op) -> Optional<Verifier::Legality> {
if (op->hasTrait<OpTrait::IREE::Stream::CmdPhaseOp>()) {
return Verifier::Legality::ILLEGAL;
}
return llvm::None;
});
}
//===----------------------------------------------------------------------===//
// -iree-stream-verify-input
//===----------------------------------------------------------------------===//
namespace {
/// Rejects input programs that use ops the stream dialect conversion cannot
/// handle yet.
class VerifyInputPass : public VerifyInputBase<VerifyInputPass> {
 public:
  VerifyInputPass() = default;

  void runOnOperation() override {
    Verifier inputVerifier;

    // These ops are not yet supported by the conversion to streams.
    // TODO(#7432): add indirect global expansion support to streams.
    inputVerifier.addIllegalOp<IREE::Util::GlobalAddressOp>();
    inputVerifier.addIllegalOp<IREE::Util::GlobalLoadIndirectOp>();
    inputVerifier.addIllegalOp<IREE::Util::GlobalStoreIndirectOp>();

    LogicalResult verifyResult = inputVerifier.run(getOperation());
    if (failed(verifyResult)) signalPassFailure();
  }
};
}  // namespace
/// Creates the -iree-stream-verify-input pass.
std::unique_ptr<OperationPass<mlir::ModuleOp>> createVerifyInputPass() {
  std::unique_ptr<OperationPass<mlir::ModuleOp>> pass =
      std::make_unique<VerifyInputPass>();
  return pass;
}
//===----------------------------------------------------------------------===//
// -iree-stream-verify-lowering-to-tensors
//===----------------------------------------------------------------------===//
/// Marks the tensor, linalg and flow dialects illegal, while treating
/// IREE::Stream::ExecutableOp ops (and everything nested in them) as legal.
static void markTensorInputsIllegal(Verifier &verifier) {
  // Flow ops are still allowed inside of executables, as we don't yet have a
  // full mapping for them in the stream dialect.
  // TODO(#7277): remove this carveout once we switch over to streams fully.
  verifier.addRecursivelyLegalOp<IREE::Stream::ExecutableOp>();

  // Tensorish dialects should all be either converted or outlined into
  // executables; everything else should be in resources now.
  for (StringRef dialectName : {"tensor", "linalg", "flow"}) {
    verifier.addIllegalDialect(dialectName);
  }
}
namespace {
/// Verifies that tensor-level inputs have been lowered away and that no ops
/// with the CmdPhaseOp trait are present yet.
class VerifyLoweringToTensorsPass
    : public VerifyLoweringToTensorsBase<VerifyLoweringToTensorsPass> {
 public:
  VerifyLoweringToTensorsPass() = default;

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<IREE::Stream::StreamDialect, IREE::Util::UtilDialect>();
  }

  void runOnOperation() override {
    Verifier tensorsVerifier;
    markTensorInputsIllegal(tensorsVerifier);
    // stream.cmd.* ops use a different memory model than stream.tensor/async.*
    // ops, so the two cannot be mixed.
    markStreamCmdOpsIllegal(tensorsVerifier);

    if (succeeded(tensorsVerifier.run(getOperation()))) return;
    signalPassFailure();
  }
};
}  // namespace
/// Creates the -iree-stream-verify-lowering-to-tensors pass.
std::unique_ptr<OperationPass<mlir::ModuleOp>>
createVerifyLoweringToTensorsPass() {
  std::unique_ptr<OperationPass<mlir::ModuleOp>> pass =
      std::make_unique<VerifyLoweringToTensorsPass>();
  return pass;
}
//===----------------------------------------------------------------------===//
// -iree-stream-verify-lowering-to-async
//===----------------------------------------------------------------------===//
namespace {
/// Verifies the program after lowering to the async phase:
/// - tensor-level inputs and ops with the TensorPhaseOp/CmdPhaseOp traits are
///   illegal,
/// - every stream resource has an assigned lifetime,
/// - non-metadata streamable ops live inside an AsyncExecuteOp.
class VerifyLoweringToAsyncPass
    : public VerifyLoweringToAsyncBase<VerifyLoweringToAsyncPass> {
 public:
  VerifyLoweringToAsyncPass() = default;

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<IREE::Stream::StreamDialect, IREE::Util::UtilDialect>();
  }

  void runOnOperation() override {
    Verifier asyncVerifier;

    // stream.cmd.* ops use a different memory model than stream.tensor/async.*
    // ops, so they cannot be mixed. Registration order matters: op verifiers
    // are consulted in order and the first decision wins.
    markTensorInputsIllegal(asyncVerifier);
    markStreamTensorOpsIllegal(asyncVerifier);
    markStreamCmdOpsIllegal(asyncVerifier);

    // All resources should have had their lifetime assigned.
    asyncVerifier.addTypeVerifier<IREE::Stream::ResourceType>(
        [](IREE::Stream::ResourceType type) {
          return type.getLifetime() == IREE::Stream::Lifetime::Unknown
                     ? Verifier::Legality::ILLEGAL
                     : Verifier::Legality::LEGAL;
        });

    // All streamable ops (except metadata ops) should be inside of execution
    // regions.
    // NOTE(review): returning ILLEGAL here also makes Verifier::run emit its
    // generic illegal-op error, so such ops get two diagnostics — confirm this
    // is intended.
    asyncVerifier.addOpVerifier<IREE::Stream::StreamableOpInterface>(
        [](IREE::Stream::StreamableOpInterface op)
            -> Optional<Verifier::Legality> {
          if (op.isMetadata()) return Verifier::Legality::LEGAL;
          // TODO(benvanik): execution region interface to make this generic.
          if (op->getParentOfType<IREE::Stream::AsyncExecuteOp>()) {
            return llvm::None;
          }
          op->emitOpError()
              << ": streamable op expected to be in an execution region";
          return Verifier::Legality::ILLEGAL;
        });

    if (failed(asyncVerifier.run(getOperation()))) signalPassFailure();
  }
};
}  // namespace
/// Creates the pass that verifies the program after lowering to async ops.
std::unique_ptr<OperationPass<mlir::ModuleOp>>
createVerifyLoweringToAsyncPass() {
  std::unique_ptr<OperationPass<mlir::ModuleOp>> pass =
      std::make_unique<VerifyLoweringToAsyncPass>();
  return pass;
}
//===----------------------------------------------------------------------===//
// -iree-stream-verify-lowering-to-cmd
//===----------------------------------------------------------------------===//
namespace {
/// Verifies the program after lowering to the cmd phase: tensor-level inputs
/// and ops with the TensorPhaseOp/AsyncPhaseOp traits are illegal, and every
/// stream resource has an assigned lifetime.
class VerifyLoweringToCmdPass
    : public VerifyLoweringToCmdBase<VerifyLoweringToCmdPass> {
 public:
  VerifyLoweringToCmdPass() = default;

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<IREE::Stream::StreamDialect, IREE::Util::UtilDialect>();
  }

  void runOnOperation() override {
    Verifier cmdVerifier;
    markTensorInputsIllegal(cmdVerifier);
    markStreamTensorOpsIllegal(cmdVerifier);
    markStreamAsyncOpsIllegal(cmdVerifier);

    // All resources should have had their lifetime assigned.
    cmdVerifier.addTypeVerifier<IREE::Stream::ResourceType>(
        [](IREE::Stream::ResourceType type) {
          return type.getLifetime() == IREE::Stream::Lifetime::Unknown
                     ? Verifier::Legality::ILLEGAL
                     : Verifier::Legality::LEGAL;
        });

    if (succeeded(cmdVerifier.run(getOperation()))) return;
    signalPassFailure();
  }
};
}  // namespace
/// Creates the -iree-stream-verify-lowering-to-cmd pass.
std::unique_ptr<OperationPass<mlir::ModuleOp>> createVerifyLoweringToCmdPass() {
  std::unique_ptr<OperationPass<mlir::ModuleOp>> pass =
      std::make_unique<VerifyLoweringToCmdPass>();
  return pass;
}
} // namespace Stream
} // namespace IREE
} // namespace iree_compiler
} // namespace mlir