blob: 780774746532c9a26f74566d95fe3bae0730b430 [file] [log] [blame]
// Copyright 2019 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
#include "iree/compiler/Dialect/VM/Conversion/ImportUtils.h"
#include "iree/compiler/Dialect/VM/IR/VMOps.h"
#include "mlir/Transforms/DialectConversion.h"
namespace mlir {
namespace iree_compiler {
namespace {
// Rewrites hal.command_buffer.fill_buffer into a vm.call targeting the import
// resolved at construction time. Besides the converted command buffer, target
// buffer, offset and length operands, the call receives the fill pattern
// widened to 32 bits and a constant holding the pattern's byte length.
class CommandBufferFillBufferOpConversion
    : public OpConversionPattern<IREE::HAL::CommandBufferFillBufferOp> {
 public:
  // Resolves |importName| within |importSymbols|; the import is required to
  // exist (checked with an assert).
  CommandBufferFillBufferOpConversion(MLIRContext *context,
                                      SymbolTable &importSymbols,
                                      TypeConverter &typeConverter,
                                      StringRef importName)
      : OpConversionPattern(typeConverter, context),
        importOp(importSymbols.lookup<IREE::VM::ImportOp>(importName)) {
    assert(importOp);
  }

  LogicalResult matchAndRewrite(
      IREE::HAL::CommandBufferFillBufferOp op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();

    // Look at the pattern type as it appears on the source op so that the
    // byte length reflects the width the fill was originally written with.
    auto sourcePatternType = op.pattern().getType();
    auto sourceBitWidth = sourcePatternType.getIntOrFloatBitWidth();

    // The byte length is handed to the runtime alongside the pattern. The
    // pattern value itself always travels in a 32 bit integer; the fill
    // operation uses the length to slice a potentially smaller range of bits
    // out of that full value.
    auto sourceByteLength =
        IREE::Util::getRoundedElementByteWidth(sourcePatternType);
    auto byteLengthValue = rewriter.createOrFold<mlir::arith::ConstantIntOp>(
        loc, sourceByteLength, 32);

    // Patterns narrower than 32 bits are zero-extended up to 32 bits.
    Value widenedPattern = op.pattern();
    if (sourceBitWidth < 32) {
      widenedPattern = rewriter.createOrFold<arith::ExtUIOp>(
          loc, rewriter.getIntegerType(32), widenedPattern);
    }

    // Operand order: command buffer, target buffer, target offset, length,
    // pattern, pattern length.
    SmallVector<Value, 8> callOperands = {
        adaptor.command_buffer(),
        adaptor.target_buffer(),
        adaptor.target_offset(),
        adaptor.length(),
        widenedPattern,
        byteLengthValue,
    };

    auto callOp = rewriter.replaceOpWithNewOp<IREE::VM::CallOp>(
        op, SymbolRefAttr::get(importOp),
        importOp.getFunctionType().getResults(), callOperands);
    copyImportAttrs(importOp, callOp);
    return success();
  }

 private:
  // Mutable so that non-const accessors on the op handle can be used from the
  // const matchAndRewrite.
  mutable IREE::VM::ImportOp importOp;
};
// Rewrites hal.command_buffer.push_descriptor_set into a vm.call.variadic
// targeting the import resolved at construction time. The fixed operands
// (command buffer, executable layout, set) are followed by one
// (ordinal, buffer, offset, length) tuple per binding.
class CommandBufferPushDescriptorSetOpConversion
    : public OpConversionPattern<IREE::HAL::CommandBufferPushDescriptorSetOp> {
 public:
  // Resolves |importName| within |importSymbols|; the import is required to
  // exist (checked with an assert). |typeConverter| is forwarded to the base
  // pattern so that the adaptor operands are remapped through it, matching
  // the other command buffer conversion patterns in this file.
  CommandBufferPushDescriptorSetOpConversion(MLIRContext *context,
                                             SymbolTable &importSymbols,
                                             TypeConverter &typeConverter,
                                             StringRef importName)
      : OpConversionPattern(typeConverter, context) {
    importOp = importSymbols.lookup<IREE::VM::ImportOp>(importName);
    assert(importOp);
  }

  LogicalResult matchAndRewrite(
      IREE::HAL::CommandBufferPushDescriptorSetOp op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    auto importType = importOp.getFunctionType();
    const size_t bindingCount = adaptor.binding_ordinals().size();

    // Fixed leading operands.
    SmallVector<Value, 8> callOperands = {
        adaptor.command_buffer(),
        adaptor.executable_layout(),
        adaptor.set(),
    };

    // One entry per import parameter. The final entry carries the number of
    // binding tuples; NOTE(review): -1 presumably marks a fixed (non-variadic)
    // operand -- confirm against the vm.call.variadic definition.
    SmallVector<int16_t, 5> segmentSizes = {
        /*command_buffer=*/-1,
        /*executable_layout=*/-1,
        /*set=*/-1,
        /*bindings=*/
        static_cast<int16_t>(bindingCount),
    };

    // Flatten the parallel binding operand lists into per-binding tuples.
    for (size_t i = 0; i < bindingCount; ++i) {
      callOperands.push_back(adaptor.binding_ordinals()[i]);
      callOperands.push_back(adaptor.binding_buffers()[i]);
      callOperands.push_back(adaptor.binding_offsets()[i]);
      callOperands.push_back(adaptor.binding_lengths()[i]);
    }

    auto callOp = rewriter.replaceOpWithNewOp<IREE::VM::CallVariadicOp>(
        op, SymbolRefAttr::get(importOp), importType.getResults(), segmentSizes,
        importType.getInputs(), callOperands);
    copyImportAttrs(importOp, callOp);
    return success();
  }

 private:
  // Mutable so that non-const accessors on the op handle can be used from the
  // const matchAndRewrite.
  mutable IREE::VM::ImportOp importOp;
};
} // namespace
void populateHALCommandBufferToVMPatterns(MLIRContext *context,
SymbolTable &importSymbols,
TypeConverter &typeConverter,
RewritePatternSet &patterns) {
patterns.insert<VMImportOpConversion<IREE::HAL::CommandBufferCreateOp>>(
context, importSymbols, typeConverter, "hal.command_buffer.create");
patterns.insert<VMImportOpConversion<IREE::HAL::CommandBufferBeginOp>>(
context, importSymbols, typeConverter, "hal.command_buffer.begin");
patterns.insert<VMImportOpConversion<IREE::HAL::CommandBufferEndOp>>(
context, importSymbols, typeConverter, "hal.command_buffer.end");
patterns
.insert<VMImportOpConversion<IREE::HAL::CommandBufferBeginDebugGroupOp>>(
context, importSymbols, typeConverter,
"hal.command_buffer.begin_debug_group");
patterns
.insert<VMImportOpConversion<IREE::HAL::CommandBufferEndDebugGroupOp>>(
context, importSymbols, typeConverter,
"hal.command_buffer.end_debug_group");
patterns
.insert<VMImportOpConversion<IREE::HAL::CommandBufferExecutionBarrierOp>>(
context, importSymbols, typeConverter,
"hal.command_buffer.execution_barrier");
patterns.insert<CommandBufferFillBufferOpConversion>(
context, importSymbols, typeConverter, "hal.command_buffer.fill_buffer");
patterns.insert<VMImportOpConversion<IREE::HAL::CommandBufferCopyBufferOp>>(
context, importSymbols, typeConverter, "hal.command_buffer.copy_buffer");
patterns
.insert<VMImportOpConversion<IREE::HAL::CommandBufferPushConstantsOp>>(
context, importSymbols, typeConverter,
"hal.command_buffer.push_constants");
patterns.insert<CommandBufferPushDescriptorSetOpConversion>(
context, importSymbols, typeConverter,
"hal.command_buffer.push_descriptor_set");
patterns.insert<
VMImportOpConversion<IREE::HAL::CommandBufferBindDescriptorSetOp>>(
context, importSymbols, typeConverter,
"hal.command_buffer.bind_descriptor_set");
patterns.insert<VMImportOpConversion<IREE::HAL::CommandBufferDispatchOp>>(
context, importSymbols, typeConverter, "hal.command_buffer.dispatch");
patterns
.insert<VMImportOpConversion<IREE::HAL::CommandBufferDispatchIndirectOp>>(
context, importSymbols, typeConverter,
"hal.command_buffer.dispatch.indirect");
}
} // namespace iree_compiler
} // namespace mlir