| // Copyright 2019 The IREE Authors |
| // |
| // Licensed under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| |
| #include "iree/compiler/Dialect/HAL/IR/HALOps.h" |
| #include "iree/compiler/Dialect/VM/Conversion/ImportUtils.h" |
| #include "iree/compiler/Dialect/VM/IR/VMOps.h" |
| #include "mlir/Transforms/DialectConversion.h" |
| |
| namespace mlir { |
| namespace iree_compiler { |
| namespace { |
| |
| class CommandBufferPushDescriptorSetOpConversion |
| : public OpConversionPattern<IREE::HAL::CommandBufferPushDescriptorSetOp> { |
| public: |
| CommandBufferPushDescriptorSetOpConversion(MLIRContext *context, |
| SymbolTable &importSymbols, |
| TypeConverter &typeConverter, |
| StringRef importName) |
| : OpConversionPattern(context) { |
| importOp = importSymbols.lookup<IREE::VM::ImportOp>(importName); |
| assert(importOp); |
| } |
| |
| LogicalResult matchAndRewrite( |
| IREE::HAL::CommandBufferPushDescriptorSetOp op, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| auto importType = importOp.getType(); |
| |
| SmallVector<Value, 8> callOperands = { |
| adaptor.command_buffer(), |
| adaptor.executable_layout(), |
| adaptor.set(), |
| }; |
| SmallVector<int16_t, 5> segmentSizes = { |
| /*command_buffer=*/-1, |
| /*executable_layout=*/-1, |
| /*set=*/-1, |
| /*bindings=*/ |
| static_cast<int16_t>(adaptor.binding_ordinals().size()), |
| }; |
| for (size_t i = 0; i < adaptor.binding_ordinals().size(); ++i) { |
| callOperands.push_back(adaptor.binding_ordinals()[i]); |
| callOperands.push_back(adaptor.binding_buffers()[i]); |
| callOperands.push_back(adaptor.binding_offsets()[i]); |
| callOperands.push_back(adaptor.binding_lengths()[i]); |
| } |
| |
| auto callOp = rewriter.replaceOpWithNewOp<IREE::VM::CallVariadicOp>( |
| op, SymbolRefAttr::get(importOp), importType.getResults(), segmentSizes, |
| importType.getInputs(), callOperands); |
| copyImportAttrs(importOp, callOp); |
| return success(); |
| } |
| |
| private: |
| mutable IREE::VM::ImportOp importOp; |
| }; |
| |
| } // namespace |
| |
| void populateHALCommandBufferToVMPatterns(MLIRContext *context, |
| SymbolTable &importSymbols, |
| TypeConverter &typeConverter, |
| OwningRewritePatternList &patterns) { |
| patterns.insert<VMImportOpConversion<IREE::HAL::CommandBufferCreateOp>>( |
| context, importSymbols, typeConverter, "hal.command_buffer.create"); |
| patterns.insert<VMImportOpConversion<IREE::HAL::CommandBufferBeginOp>>( |
| context, importSymbols, typeConverter, "hal.command_buffer.begin"); |
| patterns.insert<VMImportOpConversion<IREE::HAL::CommandBufferEndOp>>( |
| context, importSymbols, typeConverter, "hal.command_buffer.end"); |
| patterns |
| .insert<VMImportOpConversion<IREE::HAL::CommandBufferBeginDebugGroupOp>>( |
| context, importSymbols, typeConverter, |
| "hal.command_buffer.begin_debug_group"); |
| patterns |
| .insert<VMImportOpConversion<IREE::HAL::CommandBufferEndDebugGroupOp>>( |
| context, importSymbols, typeConverter, |
| "hal.command_buffer.end_debug_group"); |
| patterns |
| .insert<VMImportOpConversion<IREE::HAL::CommandBufferExecutionBarrierOp>>( |
| context, importSymbols, typeConverter, |
| "hal.command_buffer.execution_barrier"); |
| patterns.insert<VMImportOpConversion<IREE::HAL::CommandBufferFillBufferOp>>( |
| context, importSymbols, typeConverter, "hal.command_buffer.fill_buffer"); |
| patterns.insert<VMImportOpConversion<IREE::HAL::CommandBufferCopyBufferOp>>( |
| context, importSymbols, typeConverter, "hal.command_buffer.copy_buffer"); |
| patterns |
| .insert<VMImportOpConversion<IREE::HAL::CommandBufferPushConstantsOp>>( |
| context, importSymbols, typeConverter, |
| "hal.command_buffer.push_constants"); |
| patterns.insert<CommandBufferPushDescriptorSetOpConversion>( |
| context, importSymbols, typeConverter, |
| "hal.command_buffer.push_descriptor_set"); |
| patterns.insert< |
| VMImportOpConversion<IREE::HAL::CommandBufferBindDescriptorSetOp>>( |
| context, importSymbols, typeConverter, |
| "hal.command_buffer.bind_descriptor_set"); |
| patterns.insert<VMImportOpConversion<IREE::HAL::CommandBufferDispatchOp>>( |
| context, importSymbols, typeConverter, "hal.command_buffer.dispatch"); |
| patterns |
| .insert<VMImportOpConversion<IREE::HAL::CommandBufferDispatchIndirectOp>>( |
| context, importSymbols, typeConverter, |
| "hal.command_buffer.dispatch.indirect"); |
| } |
| |
| } // namespace iree_compiler |
| } // namespace mlir |