blob: c7797d57d955fd0e33dd40dc12b2a0b5cba875a5 [file] [log] [blame]
// Copyright 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "llvm/ADT/StringExtras.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Support/LogicalResult.h"
namespace mlir {
namespace iree_compiler {
namespace IREE {
namespace HAL {
//===----------------------------------------------------------------------===//
// Variables
//===----------------------------------------------------------------------===//
namespace {
/// Converts variable initializer functions that evaluate to a constant to a
/// specified initial value.
///
/// The pattern only fires when the initializer symbol resolves to a FuncOp
/// whose body is a single block holding exactly two operations: one
/// single-result operation that matches m_Constant, followed by the return.
/// Every other case - including an initializer symbol that cannot be resolved
/// or that does not name a FuncOp - is reported as a match failure and the
/// variable is left untouched.
struct InlineConstVariableOpInitializer : public OpRewritePattern<VariableOp> {
  using OpRewritePattern<VariableOp>::OpRewritePattern;

  PatternMatchResult matchAndRewrite(VariableOp op,
                                     PatternRewriter &rewriter) const override {
    // Nothing to inline when the variable has no initializer function.
    if (!op.initializer()) return matchFailure();

    auto *symbolOp =
        SymbolTable::lookupNearestSymbolFrom(op, op.initializer().getValue());
    // The lookup may come back empty or name something other than a function;
    // a checked cast keeps canonicalization from asserting on such IR.
    auto initializer = dyn_cast_or_null<FuncOp>(symbolOp);
    if (!initializer) return matchFailure();

    // Only the trivial `<op>; return` shape is handled.
    if (initializer.getBlocks().size() != 1) return matchFailure();
    auto &bodyOps = initializer.getBlocks().front().getOperations();
    if (bodyOps.size() != 2 || !isa<mlir::ReturnOp>(bodyOps.back())) {
      return matchFailure();
    }

    // getResult(0) is only meaningful when the op actually produces a result,
    // and a constant produces exactly one.
    auto &primaryOp = bodyOps.front();
    if (primaryOp.getNumResults() != 1) return matchFailure();

    Attribute constResult;
    if (!matchPattern(primaryOp.getResult(0), m_Constant(&constResult))) {
      return matchFailure();
    }

    // Recreate the variable with the folded constant as its initial value.
    rewriter.replaceOpWithNewOp<VariableOp>(op, op.sym_name(), op.is_mutable(),
                                            op.type(), constResult);
    return matchSuccess();
  }
};
} // namespace
/// Registers the canonicalization patterns for VariableOp by inserting
/// InlineConstVariableOpInitializer into `results`, constructed with `context`.
void VariableOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
results.insert<InlineConstVariableOpInitializer>(context);
}
namespace {
/// Turns a VariableLoadIndirectOp whose address operand is produced by a
/// VariableAddressOp into a direct VariableLoadOp of the variable referenced by
/// that address op. The result type of the original load is preserved.
class PropagateVariableLoadAddress
    : public OpRewritePattern<VariableLoadIndirectOp> {
  using OpRewritePattern::OpRewritePattern;

 public:
  PatternMatchResult matchAndRewrite(VariableLoadIndirectOp op,
                                     PatternRewriter &rewriter) const override {
    // The address may be a block argument (no defining op) or come from some
    // other op; only a directly visible VariableAddressOp can be propagated.
    auto *addressProducer = op.variable().getDefiningOp();
    auto addressOp = dyn_cast_or_null<VariableAddressOp>(addressProducer);
    if (!addressOp) return matchFailure();

    rewriter.replaceOpWithNewOp<VariableLoadOp>(op, op.result().getType(),
                                                addressOp.variable());
    return matchSuccess();
  }
};
} // namespace
/// Registers the canonicalization patterns for VariableLoadIndirectOp by
/// inserting PropagateVariableLoadAddress into `results`, constructed with
/// `context`.
void VariableLoadIndirectOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
results.insert<PropagateVariableLoadAddress>(context);
}
namespace {
/// Removes hal.variable.store ops that write back a value that was loaded from
/// the very same variable.
/// Such a pair shows up when a variable is loaded, the intermediate uses get
/// DCE'd, and the untouched value is stored back; erasing the store lets the
/// load be elided as well.
///
/// NOTE(review): the match only compares the variable referenced by the load
/// with the one referenced by the store; it does not look for other stores to
/// that variable between the two ops - confirm this is safe for all callers.
struct EraseUnusedVariableStoreOp : public OpRewritePattern<VariableStoreOp> {
  using OpRewritePattern<VariableStoreOp>::OpRewritePattern;

  PatternMatchResult matchAndRewrite(VariableStoreOp op,
                                     PatternRewriter &rewriter) const override {
    // The stored value must come straight from a VariableLoadOp.
    auto *valueProducer = op.value().getDefiningOp();
    auto loadOp = dyn_cast_or_null<VariableLoadOp>(valueProducer);
    if (!loadOp) return matchFailure();

    // A load from a different variable is a real copy, not a no-op.
    if (!(loadOp.variable() == op.variable())) return matchFailure();

    rewriter.eraseOp(op);
    return matchSuccess();
  }
};
} // namespace
/// Registers the canonicalization patterns for VariableStoreOp by inserting
/// EraseUnusedVariableStoreOp into `results`, constructed with `context`.
void VariableStoreOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
results.insert<EraseUnusedVariableStoreOp>(context);
}
namespace {
/// Turns a VariableStoreIndirectOp whose address operand is produced by a
/// VariableAddressOp into a direct VariableStoreOp of the same value to the
/// variable referenced by that address op.
class PropagateVariableStoreAddress
    : public OpRewritePattern<VariableStoreIndirectOp> {
  using OpRewritePattern::OpRewritePattern;

 public:
  PatternMatchResult matchAndRewrite(VariableStoreIndirectOp op,
                                     PatternRewriter &rewriter) const override {
    // The address may be a block argument (no defining op) or come from some
    // other op; only a directly visible VariableAddressOp can be propagated.
    auto *addressProducer = op.variable().getDefiningOp();
    auto addressOp = dyn_cast_or_null<VariableAddressOp>(addressProducer);
    if (!addressOp) return matchFailure();

    rewriter.replaceOpWithNewOp<VariableStoreOp>(op, op.value(),
                                                 addressOp.variable());
    return matchSuccess();
  }
};
} // namespace
/// Registers the canonicalization patterns for VariableStoreIndirectOp by
/// inserting PropagateVariableStoreAddress into `results`, constructed with
/// `context`.
void VariableStoreIndirectOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
results.insert<PropagateVariableStoreAddress>(context);
}
//===----------------------------------------------------------------------===//
// iree::hal::Buffer
//===----------------------------------------------------------------------===//
namespace {
/// Skips a BufferAllocatorOp accessor when the queried buffer is produced
/// directly by an AllocatorAllocateOp or AllocatorAllocateConstOp: the
/// accessor's result is replaced by the allocator operand of that allocation.
struct SkipBufferAllocatorOp : public OpRewritePattern<BufferAllocatorOp> {
  using OpRewritePattern<BufferAllocatorOp>::OpRewritePattern;

  PatternMatchResult matchAndRewrite(BufferAllocatorOp op,
                                     PatternRewriter &rewriter) const override {
    // Null when the buffer is a block argument; both casts below tolerate it.
    auto *bufferProducer = op.buffer().getDefiningOp();

    if (auto allocOp = dyn_cast_or_null<AllocatorAllocateOp>(bufferProducer)) {
      rewriter.replaceOp(op, allocOp.allocator());
      return matchSuccess();
    }

    if (auto constAllocOp =
            dyn_cast_or_null<AllocatorAllocateConstOp>(bufferProducer)) {
      rewriter.replaceOp(op, constAllocOp.allocator());
      return matchSuccess();
    }

    return matchFailure();
  }
};
} // namespace
/// Registers the canonicalization patterns for BufferAllocatorOp by inserting
/// SkipBufferAllocatorOp into `results`, constructed with `context`.
void BufferAllocatorOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
results.insert<SkipBufferAllocatorOp>(context);
}
//===----------------------------------------------------------------------===//
// iree::hal::BufferView
//===----------------------------------------------------------------------===//
namespace {
/// Expands hal.buffer_view.const into its two constituent parts: an
/// AllocatorAllocateConstOp that holds the constant value and a
/// BufferViewCreateOp that wraps the resulting buffer together with one
/// constant per dimension of the value's shape and its element type value.
/// The pattern does not match when getElementTypeValue yields no value for the
/// element type.
struct ExpandBufferViewConstOp : public OpRewritePattern<BufferViewConstOp> {
  using OpRewritePattern<BufferViewConstOp>::OpRewritePattern;

  PatternMatchResult matchAndRewrite(BufferViewConstOp op,
                                     PatternRewriter &rewriter) const override {
    auto valueType = op.value().getType();
    auto elementTypeValue = getElementTypeValue(valueType.getElementType());
    if (!elementTypeValue.hasValue()) return matchFailure();

    auto loc = op.getLoc();

    // Materialize the constant contents as a buffer first.
    auto constBuffer = rewriter.createOrFold<AllocatorAllocateConstOp>(
        loc, op.allocator(), op.memory_types(), op.buffer_usage(), op.value());

    // Emit one i32 constant per dimension; rank-0 values get an empty shape.
    // NOTE(review): each dimension is narrowed to int32_t - confirm extents
    // beyond the i32 range cannot reach this pattern.
    SmallVector<Value, 4> dims;
    if (valueType.getRank() >= 1) {
      for (auto extent : valueType.getShape()) {
        auto extentAttr =
            rewriter.getI32IntegerAttr(static_cast<int32_t>(extent));
        dims.push_back(rewriter.createOrFold<mlir::ConstantOp>(loc, extentAttr));
      }
    }

    rewriter.replaceOpWithNewOp<BufferViewCreateOp>(
        op, constBuffer, dims, elementTypeValue.getValue());
    return matchSuccess();
  }
};
} // namespace
/// Registers the canonicalization patterns for BufferViewConstOp by inserting
/// ExpandBufferViewConstOp into `results`, constructed with `context`.
void BufferViewConstOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
results.insert<ExpandBufferViewConstOp>(context);
}
namespace {
/// Skips a hal.buffer_view.buffer accessor when the buffer view it queries is
/// produced directly by a BufferViewCreateOp: the accessor's result is replaced
/// by the buffer operand that the view was created from.
struct SkipBufferViewBufferOp : public OpRewritePattern<BufferViewBufferOp> {
  using OpRewritePattern<BufferViewBufferOp>::OpRewritePattern;

  PatternMatchResult matchAndRewrite(BufferViewBufferOp op,
                                     PatternRewriter &rewriter) const override {
    // Without a visible BufferViewCreateOp the origin buffer is unknown.
    auto *viewProducer = op.buffer_view().getDefiningOp();
    auto createOp = dyn_cast_or_null<BufferViewCreateOp>(viewProducer);
    if (!createOp) return matchFailure();

    rewriter.replaceOp(op, createOp.buffer());
    return matchSuccess();
  }
};
} // namespace
/// Registers the canonicalization patterns for BufferViewBufferOp by inserting
/// SkipBufferViewBufferOp into `results`, constructed with `context`.
void BufferViewBufferOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
results.insert<SkipBufferViewBufferOp>(context);
}
} // namespace HAL
} // namespace IREE
} // namespace iree_compiler
} // namespace mlir