blob: 15dfe3273aa1c6192a8e184f97fe24e75b223655 [file] [log] [blame]
// Copyright 2021 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree/compiler/Dialect/Stream/IR/StreamDialect.h"
#include "iree/compiler/Dialect/Stream/IR/StreamOps.h"
#include "iree/compiler/Dialect/Stream/IR/StreamTypes.h"
#include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
#include "llvm/Support/SourceMgr.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/Parser/Parser.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/InliningUtils.h"
namespace mlir::iree_compiler::IREE::Stream {
namespace {
// Used to control inlining behavior.
struct StreamInlinerInterface : public DialectInlinerInterface {
using DialectInlinerInterface::DialectInlinerInterface;
bool isLegalToInline(Operation *call, Operation *callable,
bool wouldBeCloned) const final {
// Sure!
return true;
}
bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
IRMapping &valueMapping) const final {
// Sure!
return true;
}
bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned,
IRMapping &valueMapping) const final {
// Sure!
return true;
}
};
struct StreamFolderInterface : public DialectFoldInterface {
using DialectFoldInterface::DialectFoldInterface;
bool shouldMaterializeInto(Region *region) const override {
// TODO(benvanik): redirect constants to the region scope when small.
return false;
}
};
// Tries to fold away unrealized_conversion_cast ops if the downstream consumers
// don't need the extra information. These are inserted during conversion or
// transforms that may interop with external dialects.
//
// Specifically matches:
// %0 = builtin.unrealized_conversion_cast %arg0, %arg1 :
// !stream.resource<transient>, index to !stream.resource<transient>
struct StripResourceConversionCastPattern
: public OpRewritePattern<UnrealizedConversionCastOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(UnrealizedConversionCastOp castOp,
PatternRewriter &rewriter) const override {
auto result = castOp.getResult(0);
if (!llvm::isa<IREE::Stream::ResourceType>(result.getType()))
return failure();
assert(castOp.getNumOperands() == 2 &&
"expect resource, index -> resource");
auto resourceValue = castOp.getOperand(0);
auto sizeValue = castOp.getOperand(1);
for (auto &use : llvm::make_early_inc_range(result.getUses())) {
if (auto sizeOp =
dyn_cast<IREE::Stream::ResourceSizeOp>(use.getOwner())) {
rewriter.replaceOp(sizeOp, sizeValue);
} else {
rewriter.updateRootInPlace(use.getOwner(),
[&]() { use.set(resourceValue); });
}
}
rewriter.eraseOp(castOp);
return success();
}
};
} // namespace
StreamDialect::StreamDialect(MLIRContext *context)
: Dialect(getDialectNamespace(), context, TypeID::get<StreamDialect>()) {
context->loadDialect<IREE::Util::UtilDialect>();
context->loadDialect<mlir::complex::ComplexDialect>();
registerAttributes();
registerTypes();
#define GET_OP_LIST
addOperations<
#include "iree/compiler/Dialect/Stream/IR/StreamOps.cpp.inc"
>();
addInterfaces<StreamInlinerInterface, StreamFolderInterface>();
}
void StreamDialect::getCanonicalizationPatterns(
RewritePatternSet &results) const {
results.insert<StripResourceConversionCastPattern>(getContext());
}
Operation *StreamDialect::materializeConstant(OpBuilder &builder,
Attribute value, Type type,
Location loc) {
if (mlir::func::ConstantOp::isBuildableWith(value, type)) {
return builder.create<mlir::func::ConstantOp>(
loc, type, llvm::cast<FlatSymbolRefAttr>(value));
} else if (arith::ConstantOp::isBuildableWith(value, type)) {
return builder.create<arith::ConstantOp>(loc, type, cast<TypedAttr>(value));
} else if (llvm::isa<IREE::Stream::TimepointAttr>(value)) {
return builder.create<IREE::Stream::TimepointImmediateOp>(loc);
}
return nullptr;
}
} // namespace mlir::iree_compiler::IREE::Stream