// blob: 1543705ce0a3cb971029a8647fafd383ff2b2f42 [file] [log] [blame]
// Copyright 2019 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "iree/compiler/Dialect/VM/Conversion/ImportUtils.h"
#include "iree/compiler/Dialect/VM/IR/VMOps.h"
#include "mlir/Transforms/DialectConversion.h"
namespace mlir {
namespace iree_compiler {
namespace {
/// Rewrites an IREE::HAL::CommandBufferPushDescriptorSetOp into an
/// IREE::VM::CallVariadicOp that targets a VM import.
///
/// The call operands are the command buffer, executable layout and set,
/// followed by one (ordinal, buffer, offset, length) group per binding.
class CommandBufferPushDescriptorSetOpConversion
    : public OpConversionPattern<IREE::HAL::CommandBufferPushDescriptorSetOp> {
 public:
  /// Looks up `importName` in `importSymbols` and keeps the resulting import
  /// op for use during rewriting. The lookup is only checked with an assert.
  /// `typeConverter` is accepted but not used by this constructor.
  CommandBufferPushDescriptorSetOpConversion(MLIRContext *context,
                                             SymbolTable &importSymbols,
                                             TypeConverter &typeConverter,
                                             StringRef importName)
      : OpConversionPattern(context),
        importOp(importSymbols.lookup<IREE::VM::ImportOp>(importName)) {
    assert(importOp);
  }

  /// Replaces `op` with a variadic call to the stored import and copies the
  /// import's attributes onto the new call. Always returns success().
  LogicalResult matchAndRewrite(
      IREE::HAL::CommandBufferPushDescriptorSetOp op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    auto calleeType = importOp.getType();

    // Per-binding operand ranges; all four are indexed in lockstep below.
    auto ordinals = adaptor.binding_ordinals();
    auto buffers = adaptor.binding_buffers();
    auto offsets = adaptor.binding_offsets();
    auto lengths = adaptor.binding_lengths();
    const size_t bindingCount = ordinals.size();

    // One entry per import parameter segment: the three leading operands each
    // use -1, and the trailing entry carries the number of binding groups.
    SmallVector<int16_t, 5> segmentSizes = {
        /*command_buffer=*/-1,
        /*executable_layout=*/-1,
        /*set=*/-1,
        /*bindings=*/static_cast<int16_t>(bindingCount),
    };

    // Leading operands first, then the flattened binding groups.
    SmallVector<Value, 8> callOperands;
    callOperands.push_back(adaptor.command_buffer());
    callOperands.push_back(adaptor.executable_layout());
    callOperands.push_back(adaptor.set());
    for (size_t binding = 0; binding != bindingCount; ++binding) {
      callOperands.push_back(ordinals[binding]);
      callOperands.push_back(buffers[binding]);
      callOperands.push_back(offsets[binding]);
      callOperands.push_back(lengths[binding]);
    }

    auto variadicCall = rewriter.replaceOpWithNewOp<IREE::VM::CallVariadicOp>(
        op, SymbolRefAttr::get(importOp), calleeType.getResults(), segmentSizes,
        calleeType.getInputs(), callOperands);
    copyImportAttrs(importOp, variadicCall);
    return success();
  }

 private:
  // Import op resolved at construction time. Declared mutable so that it can
  // be used from the const matchAndRewrite().
  mutable IREE::VM::ImportOp importOp;
};
} // namespace
/// Adds the conversion patterns for the HAL command buffer ops to `patterns`.
///
/// Each pattern is constructed with `context`, `importSymbols`,
/// `typeConverter` and the name of the VM import it should call.
void populateHALCommandBufferToVMPatterns(MLIRContext *context,
                                          SymbolTable &importSymbols,
                                          TypeConverter &typeConverter,
                                          OwningRewritePatternList &patterns) {
  // Local shorthand to keep the template arguments below readable.
  namespace HAL = IREE::HAL;

  // Ops handled by the generic VMImportOpConversion pattern.
  patterns.insert<VMImportOpConversion<HAL::CommandBufferCreateOp>>(
      context, importSymbols, typeConverter, "hal.command_buffer.create");
  patterns.insert<VMImportOpConversion<HAL::CommandBufferBeginOp>>(
      context, importSymbols, typeConverter, "hal.command_buffer.begin");
  patterns.insert<VMImportOpConversion<HAL::CommandBufferEndOp>>(
      context, importSymbols, typeConverter, "hal.command_buffer.end");
  patterns.insert<VMImportOpConversion<HAL::CommandBufferBeginDebugGroupOp>>(
      context, importSymbols, typeConverter,
      "hal.command_buffer.begin_debug_group");
  patterns.insert<VMImportOpConversion<HAL::CommandBufferEndDebugGroupOp>>(
      context, importSymbols, typeConverter,
      "hal.command_buffer.end_debug_group");
  patterns.insert<VMImportOpConversion<HAL::CommandBufferExecutionBarrierOp>>(
      context, importSymbols, typeConverter,
      "hal.command_buffer.execution_barrier");
  patterns.insert<VMImportOpConversion<HAL::CommandBufferFillBufferOp>>(
      context, importSymbols, typeConverter, "hal.command_buffer.fill_buffer");
  patterns.insert<VMImportOpConversion<HAL::CommandBufferCopyBufferOp>>(
      context, importSymbols, typeConverter, "hal.command_buffer.copy_buffer");
  patterns.insert<VMImportOpConversion<HAL::CommandBufferPushConstantsOp>>(
      context, importSymbols, typeConverter,
      "hal.command_buffer.push_constants");

  // push_descriptor_set is registered with its own dedicated pattern class
  // instead of the generic import conversion.
  patterns.insert<CommandBufferPushDescriptorSetOpConversion>(
      context, importSymbols, typeConverter,
      "hal.command_buffer.push_descriptor_set");

  // Remaining ops go through the generic pattern again.
  patterns.insert<VMImportOpConversion<HAL::CommandBufferBindDescriptorSetOp>>(
      context, importSymbols, typeConverter,
      "hal.command_buffer.bind_descriptor_set");
  patterns.insert<VMImportOpConversion<HAL::CommandBufferDispatchOp>>(
      context, importSymbols, typeConverter, "hal.command_buffer.dispatch");
  patterns.insert<VMImportOpConversion<HAL::CommandBufferDispatchIndirectOp>>(
      context, importSymbols, typeConverter,
      "hal.command_buffer.dispatch.indirect");
}
} // namespace iree_compiler
} // namespace mlir