// blob: cd3eda8e4db253868ec5027258dfa58b2c39b8f4 [file] [log] [blame]
// Copyright 2019 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/StringExtras.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Support/LogicalResult.h"
namespace mlir {
namespace iree_compiler {
namespace IREE {
namespace HAL {
//===----------------------------------------------------------------------===//
// hal.tensor.import/export
//===----------------------------------------------------------------------===//
/// Folds an import whose source is produced by a TensorExportOp back to the
/// value that was originally exported, but only when that value's type is
/// identical to the type of this op's target. Returns an empty result when no
/// fold applies.
OpFoldResult TensorImportOp::fold(ArrayRef<Attribute> operands) {
  auto producerOp = source().getDefiningOp<TensorExportOp>();
  if (!producerOp) return {};
  // A type mismatch means the export/import pair is not a pure round-trip.
  if (producerOp.source().getType() != target().getType()) return {};
  return producerOp.source();
}
/// Folds an export whose source is produced by a TensorImportOp back to the
/// value that was originally imported, but only when that value's type is
/// identical to the type of this op's target. Returns an empty result when no
/// fold applies.
OpFoldResult TensorExportOp::fold(ArrayRef<Attribute> operands) {
  auto producerOp = source().getDefiningOp<TensorImportOp>();
  if (!producerOp) return {};
  // A type mismatch means the import/export pair is not a pure round-trip.
  if (producerOp.source().getType() != target().getType()) return {};
  return producerOp.source();
}
//===----------------------------------------------------------------------===//
// hal.buffer.*
//===----------------------------------------------------------------------===//
namespace {

/// Skips a hal.buffer.allocator accessor when the op defining the queried
/// buffer is visible:
///  - buffer defined by an AllocatorAllocateOp: the accessor is replaced with
///    the allocator operand of that allocation;
///  - buffer defined by a BufferSubspanOp: the accessor is re-issued against
///    the subspan's source buffer.
/// Any other producer (or a block argument) leaves the op untouched.
struct SkipBufferAllocatorOp : public OpRewritePattern<BufferAllocatorOp> {
  using OpRewritePattern<BufferAllocatorOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(BufferAllocatorOp op,
                                PatternRewriter &rewriter) const override {
    auto queriedBuffer = op.buffer();

    if (auto allocateOp = queriedBuffer.getDefiningOp<AllocatorAllocateOp>()) {
      rewriter.replaceOp(op, allocateOp.allocator());
      return success();
    }

    if (auto subspanOp = queriedBuffer.getDefiningOp<BufferSubspanOp>()) {
      // Look through the subspan; a later application of this pattern can then
      // resolve the query against the source buffer's own producer.
      rewriter.replaceOpWithNewOp<BufferAllocatorOp>(
          op, op.result().getType(), subspanOp.source_buffer());
      return success();
    }

    return failure();
  }
};

}  // namespace
/// Registers the canonicalization patterns for BufferAllocatorOp.
void BufferAllocatorOp::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  results.add<SkipBufferAllocatorOp>(context);
}
//===----------------------------------------------------------------------===//
// hal.buffer_view.*
//===----------------------------------------------------------------------===//
namespace {

/// Skips a hal.buffer_view.buffer accessor when the queried buffer view is
/// defined by a BufferViewCreateOp: the accessor is replaced with the buffer
/// operand that the view was created from.
struct SkipBufferViewBufferOp : public OpRewritePattern<BufferViewBufferOp> {
  using OpRewritePattern<BufferViewBufferOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(BufferViewBufferOp op,
                                PatternRewriter &rewriter) const override {
    auto createOp = op.buffer_view().getDefiningOp<BufferViewCreateOp>();
    // Views with an unknown producer (or block arguments) cannot be bypassed.
    if (!createOp) return failure();

    rewriter.replaceOp(op, createOp.buffer());
    return success();
  }
};

}  // namespace
/// Registers the canonicalization patterns for BufferViewBufferOp.
void BufferViewBufferOp::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  results.add<SkipBufferViewBufferOp>(context);
}
namespace {

/// Expands a hal.buffer_view.dims op into one BufferViewDimOp per result, with
/// result #i querying dimension index i of the same buffer view. The pattern
/// matches unconditionally.
struct ExpandBufferViewDimsOp : public OpRewritePattern<BufferViewDimsOp> {
  using OpRewritePattern<BufferViewDimsOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(BufferViewDimsOp op,
                                PatternRewriter &rewriter) const override {
    const unsigned dimCount = op.getNumResults();
    auto bufferView = op.buffer_view();
    auto indexType = rewriter.getIndexType();

    SmallVector<Value, 4> perDimValues;
    perDimValues.reserve(dimCount);
    for (unsigned dim = 0; dim < dimCount; ++dim) {
      // createOrFold may hand back an existing value instead of a new op.
      perDimValues.push_back(rewriter.createOrFold<BufferViewDimOp>(
          op.getLoc(), indexType, bufferView, rewriter.getIndexAttr(dim)));
    }

    rewriter.replaceOp(op, perDimValues);
    return success();
  }
};

}  // namespace
/// Registers the canonicalization patterns for BufferViewDimsOp.
void BufferViewDimsOp::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  results.add<ExpandBufferViewDimsOp>(context);
}
//===----------------------------------------------------------------------===//
// hal.command_buffer.*
//===----------------------------------------------------------------------===//
namespace {

/// Skips a hal.command_buffer.device accessor when the queried command buffer
/// is defined by a CommandBufferCreateOp: the accessor is replaced with the
/// device operand that the command buffer was created with.
struct SkipCommandBufferDeviceOp
    : public OpRewritePattern<CommandBufferDeviceOp> {
  using OpRewritePattern<CommandBufferDeviceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CommandBufferDeviceOp op,
                                PatternRewriter &rewriter) const override {
    auto createOp = op.command_buffer().getDefiningOp<CommandBufferCreateOp>();
    // Command buffers with an unknown producer cannot be bypassed.
    if (!createOp) return failure();

    rewriter.replaceOp(op, createOp.device());
    return success();
  }
};

}  // namespace
/// Registers the canonicalization patterns for CommandBufferDeviceOp.
void CommandBufferDeviceOp::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  results.add<SkipCommandBufferDeviceOp>(context);
}
namespace {

/// Folds hal.buffer.subspans into buffer fill offsets.
///
/// When the fill target is defined by a BufferSubspanOp the fill is retargeted
/// at the subspan's source buffer, and the target offset becomes
/// `subspan.source_offset + fill.target_offset`. No other operand of the fill
/// is modified.
struct FoldCommandBufferFillBufferSubspans
    : public OpRewritePattern<CommandBufferFillBufferOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(CommandBufferFillBufferOp op,
                                PatternRewriter &rewriter) const override {
    auto subspanOp = op.target_buffer().getDefiningOp<BufferSubspanOp>();
    // Nothing to fold unless the target comes directly from a subspan.
    if (!subspanOp) return failure();

    // Materialize the combined offset immediately before the fill, then put
    // the rewriter's insertion point back where the caller left it.
    auto savedInsertionPoint = rewriter.saveInsertionPoint();
    rewriter.setInsertionPoint(op);
    auto foldedBuffer = subspanOp.source_buffer();
    auto foldedOffset = rewriter.createOrFold<mlir::arith::AddIOp>(
        subspanOp.getLoc(), subspanOp.source_offset(), op.target_offset());
    rewriter.restoreInsertionPoint(savedInsertionPoint);

    rewriter.updateRootInPlace(op, [&]() {
      op.target_bufferMutable().assign(foldedBuffer);
      op.target_offsetMutable().assign(foldedOffset);
    });
    return success();
  }
};

}  // namespace
/// Registers the canonicalization patterns for CommandBufferFillBufferOp.
void CommandBufferFillBufferOp::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  results.add<FoldCommandBufferFillBufferSubspans>(context);
}
namespace {

/// Folds hal.buffer.subspans into buffer copy offsets.
///
/// The source and target operands are handled independently: whichever of them
/// is defined by a BufferSubspanOp is replaced with that subspan's source
/// buffer, and the matching offset becomes
/// `subspan.source_offset + original offset`. The pattern fails if neither
/// operand comes from a subspan.
struct FoldCommandBufferCopyBufferSubspans
    : public OpRewritePattern<CommandBufferCopyBufferOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(CommandBufferCopyBufferOp op,
                                PatternRewriter &rewriter) const override {
    // New offset ops are inserted immediately before the copy; the caller's
    // insertion point is restored once they have been built.
    auto savedInsertionPoint = rewriter.saveInsertionPoint();
    rewriter.setInsertionPoint(op);

    bool foldedAny = false;
    // Rewrites a (buffer, offset) pair in place so that it addresses the
    // subspan's source buffer; leaves the pair untouched when the buffer is
    // not produced by a subspan.
    auto foldSubspan = [&](Value &buffer, Value &offset) {
      auto subspanOp = buffer.getDefiningOp<BufferSubspanOp>();
      if (!subspanOp) return;
      buffer = subspanOp.source_buffer();
      offset = rewriter.createOrFold<mlir::arith::AddIOp>(
          subspanOp.getLoc(), subspanOp.source_offset(), offset);
      foldedAny = true;
    };

    Value sourceBuffer = op.source_buffer();
    Value sourceOffset = op.source_offset();
    foldSubspan(sourceBuffer, sourceOffset);

    Value targetBuffer = op.target_buffer();
    Value targetOffset = op.target_offset();
    foldSubspan(targetBuffer, targetOffset);

    rewriter.restoreInsertionPoint(savedInsertionPoint);
    if (!foldedAny) return failure();

    rewriter.updateRootInPlace(op, [&]() {
      op.source_bufferMutable().assign(sourceBuffer);
      op.source_offsetMutable().assign(sourceOffset);
      op.target_bufferMutable().assign(targetBuffer);
      op.target_offsetMutable().assign(targetOffset);
    });
    return success();
  }
};

}  // namespace
/// Registers the canonicalization patterns for CommandBufferCopyBufferOp.
void CommandBufferCopyBufferOp::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  results.add<FoldCommandBufferCopyBufferSubspans>(context);
}
namespace {

/// Folds hal.buffer.subspans into push descriptor bindings.
///
/// Every binding buffer defined by a BufferSubspanOp is replaced with that
/// subspan's source buffer, and its binding offset becomes
/// `subspan.source_offset + original offset`. Only buffers and offsets are
/// rewritten here; this relies on the binding range always being equal to or
/// a subset of the subspan.
struct FoldCommandBufferPushDescriptorSetBufferSubspan
    : public OpRewritePattern<CommandBufferPushDescriptorSetOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(CommandBufferPushDescriptorSetOp op,
                                PatternRewriter &rewriter) const override {
    // New offset ops are inserted immediately before the push op; the
    // caller's insertion point is restored once they have been built.
    auto savedInsertionPoint = rewriter.saveInsertionPoint();
    rewriter.setInsertionPoint(op);

    // Work on local copies so the op is only mutated inside the
    // updateRootInPlace callback below.
    auto foldedBuffers = llvm::to_vector<4>(op.binding_buffers());
    auto foldedOffsets = llvm::to_vector<4>(op.binding_offsets());

    bool foldedAny = false;
    for (size_t binding = 0, count = foldedBuffers.size(); binding < count;
         ++binding) {
      auto subspanOp =
          foldedBuffers[binding].getDefiningOp<BufferSubspanOp>();
      if (!subspanOp) continue;
      foldedAny = true;
      foldedBuffers[binding] = subspanOp.source_buffer();
      foldedOffsets[binding] = rewriter.createOrFold<mlir::arith::AddIOp>(
          subspanOp.getLoc(), subspanOp.source_offset(),
          foldedOffsets[binding]);
    }

    rewriter.restoreInsertionPoint(savedInsertionPoint);
    if (!foldedAny) return failure();

    rewriter.updateRootInPlace(op, [&]() {
      // Replace each variadic operand group wholesale with the folded list.
      auto buffersOperands = op.binding_buffersMutable();
      buffersOperands.clear();
      buffersOperands.append(foldedBuffers);

      auto offsetsOperands = op.binding_offsetsMutable();
      offsetsOperands.clear();
      offsetsOperands.append(foldedOffsets);
    });
    return success();
  }
};

}  // namespace
/// Registers the canonicalization patterns for
/// CommandBufferPushDescriptorSetOp.
void CommandBufferPushDescriptorSetOp::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  results.add<FoldCommandBufferPushDescriptorSetBufferSubspan>(context);
}
//===----------------------------------------------------------------------===//
// hal.device.switch
//===----------------------------------------------------------------------===//
// TODO(benvanik): fold conditions with the same IR tree.
// TODO(benvanik): remove duplicate conditions.
// TODO(benvanik): fold condition expressions (any(always, ...) -> always, etc).
// TODO(benvanik): completely replace switches with just one always block.
// TODO(benvanik): remove conditions with no side-effects.
//===----------------------------------------------------------------------===//
// hal.device.match.id
//===----------------------------------------------------------------------===//
// TODO(benvanik): fold matches that are known true based on device config.
} // namespace HAL
} // namespace IREE
} // namespace iree_compiler
} // namespace mlir