| // Copyright 2021 The IREE Authors |
| // |
| // Licensed under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| |
| #include "iree/compiler/Dialect/VMVX/Conversion/HALToVMVX/ConvertHALToVMVX.h" |
| |
| #include "iree/compiler/Dialect/HAL/IR/HALDialect.h" |
| #include "iree/compiler/Dialect/HAL/IR/HALOps.h" |
| #include "iree/compiler/Dialect/Util/IR/UtilOps.h" |
| #include "iree/compiler/Dialect/Util/IR/UtilTypes.h" |
| #include "iree/compiler/Dialect/VMVX/IR/VMVXDialect.h" |
| #include "iree/compiler/Dialect/VMVX/IR/VMVXOps.h" |
| #include "iree/compiler/Dialect/VMVX/IR/VMVXTypes.h" |
| #include "iree/compiler/Utils/IntegerSet.h" |
| #include "mlir/Dialect/Arith/IR/Arith.h" |
| #include "mlir/Dialect/MemRef/IR/MemRef.h" |
| #include "mlir/IR/Attributes.h" |
| #include "mlir/IR/Builders.h" |
| #include "mlir/IR/BuiltinOps.h" |
| #include "mlir/IR/BuiltinTypes.h" |
| #include "mlir/IR/Matchers.h" |
| #include "mlir/IR/SymbolTable.h" |
| #include "mlir/Interfaces/FunctionInterfaces.h" |
| #include "mlir/Transforms/DialectConversion.h" |
| |
| namespace mlir::iree_compiler { |
| |
| // Positional indices of the arguments on a rewritten entry point function. |
| // This is what the VM will receive at runtime from the HAL. |
| // |
| // The X/Y/Z ordinals of each workgroup triple must stay consecutive: the |
| // conversion patterns in this file compute an ordinal as `<X ordinal> + dim`. |
| enum EntryArgOrdinals { |
| kEntryArgLocalMemory = 0, |
| kEntryArgConstants = 1, |
| kEntryArgBindings = 2, |
| // Workgroup ID. |
| kEntryArgWorkgroupX = 3, |
| kEntryArgWorkgroupY = 4, |
| kEntryArgWorkgroupZ = 5, |
| // Workgroup size. |
| kEntryArgWorkgroupSizeX = 6, |
| kEntryArgWorkgroupSizeY = 7, |
| kEntryArgWorkgroupSizeZ = 8, |
| // Workgroup count. |
| kEntryArgWorkgroupCountX = 9, |
| kEntryArgWorkgroupCountY = 10, |
| kEntryArgWorkgroupCountZ = 11, |
| }; |
| |
| /// Gives an entry function the argument list the runtime provides during |
| /// invocation: local memory, the constants buffer, the bindings list, and the |
| /// XYZ workgroup ID, size, and count. The argument order matches |
| /// EntryArgOrdinals. |
| /// |
| /// Source: |
| /// func.func @entry() |
| /// |
| /// Target: |
| /// func.func @entry( |
| /// %local_memory: !util.buffer, |
| /// %constants: !util.buffer, |
| /// %bindings: !util.list<!util.buffer>, |
| /// %workgroup_id_x: i32, |
| /// %workgroup_id_y: i32, |
| /// %workgroup_id_z: i32, |
| /// %workgroup_size_x: i32, |
| /// %workgroup_size_y: i32, |
| /// %workgroup_size_z: i32, |
| /// %workgroup_count_x: i32, |
| /// %workgroup_count_y: i32, |
| /// %workgroup_count_z: i32 |
| /// ) |
| /// |
| /// Emits an error on `funcOp` and returns failure when the function already |
| /// has arguments or results. `typeConverter` is not used by this function. |
| LogicalResult updateHALToVMVXEntryFuncOp(mlir::FunctionOpInterface funcOp, |
| TypeConverter &typeConverter) { |
| // Only functions with an empty signature can be rewritten. |
| const bool hasIO = |
| funcOp.getNumArguments() != 0 || funcOp.getNumResults() != 0; |
| if (hasIO) { |
| return funcOp.emitError() << "exported functions must have no I/O"; |
| } |
| |
| MLIRContext *context = funcOp.getContext(); |
| auto bufferType = IREE::Util::BufferType::get(context); |
| auto bindingsType = IREE::Util::ListType::get(bufferType); // of i8 |
| auto i32Type = IntegerType::get(context, 32); |
| |
| // One entry per EntryArgOrdinals value, in the same order. |
| SmallVector<Type> inputTypes = { |
| /*local_memory=*/bufferType, // of i8 |
| /*constants=*/bufferType, // of i32 |
| /*bindings=*/bindingsType, |
| /*workgroup_id_x=*/i32Type, |
| /*workgroup_id_y=*/i32Type, |
| /*workgroup_id_z=*/i32Type, |
| /*workgroup_size_x=*/i32Type, |
| /*workgroup_size_y=*/i32Type, |
| /*workgroup_size_z=*/i32Type, |
| /*workgroup_count_x=*/i32Type, |
| /*workgroup_count_y=*/i32Type, |
| /*workgroup_count_z=*/i32Type, |
| }; |
| auto entryType = FunctionType::get(context, inputTypes, {}); |
| |
| // Update the signature first, then add matching block arguments to the |
| // entry block, all carrying the location of the function. |
| funcOp.setType(entryType); |
| SmallVector<Location> argLocs(entryType.getNumInputs(), funcOp.getLoc()); |
| funcOp.front().addArguments(entryType.getInputs(), argLocs); |
| |
| return success(); |
| } |
| |
| namespace { |
| |
| /// Replaces hal.interface.workgroup.id with the workgroup ID argument of the |
| /// enclosing function for the requested dimension, cast to index. |
| struct ConvertHALInterfaceWorkgroupIDOp |
| : public OpConversionPattern<IREE::HAL::InterfaceWorkgroupIDOp> { |
| using Base::Base; |
| LogicalResult |
| matchAndRewrite(IREE::HAL::InterfaceWorkgroupIDOp op, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| // Only dimensions 0, 1, and 2 (X, Y, Z) have a function argument. |
| const uint64_t dimension = op.getDimension().getZExtValue(); |
| if (dimension > 2) { |
| return op.emitOpError() << "out of bounds workgroup ID dimension"; |
| } |
| |
| // The ID arguments are consecutive, starting at kEntryArgWorkgroupX. |
| auto parentFuncOp = op->getParentOfType<mlir::FunctionOpInterface>(); |
| Value idAsI32 = parentFuncOp.getArgument(kEntryArgWorkgroupX + dimension); |
| |
| // Users of the original op get the argument as an index value. |
| Value idAsIndex = arith::IndexCastOp::create( |
| rewriter, op.getLoc(), rewriter.getIndexType(), idAsI32); |
| rewriter.replaceOp(op, idAsIndex); |
| return success(); |
| } |
| }; |
| |
| /// Replaces hal.interface.workgroup.size with the workgroup size argument of |
| /// the enclosing function for the requested dimension, cast to index. |
| struct ConvertHALInterfaceWorkgroupSizeOp |
| : public OpConversionPattern<IREE::HAL::InterfaceWorkgroupSizeOp> { |
| using Base::Base; |
| LogicalResult |
| matchAndRewrite(IREE::HAL::InterfaceWorkgroupSizeOp op, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| // Only dimensions 0, 1, and 2 (X, Y, Z) have a function argument. |
| const uint64_t dimension = op.getDimension().getZExtValue(); |
| if (dimension > 2) { |
| return op.emitOpError() << "out of bounds workgroup size dimension"; |
| } |
| |
| // The size arguments are consecutive, starting at kEntryArgWorkgroupSizeX. |
| auto parentFuncOp = op->getParentOfType<mlir::FunctionOpInterface>(); |
| Value sizeAsI32 = |
| parentFuncOp.getArgument(kEntryArgWorkgroupSizeX + dimension); |
| |
| // Users of the original op get the argument as an index value. |
| Value sizeAsIndex = arith::IndexCastOp::create( |
| rewriter, op.getLoc(), rewriter.getIndexType(), sizeAsI32); |
| rewriter.replaceOp(op, sizeAsIndex); |
| return success(); |
| } |
| }; |
| |
| /// Replaces hal.interface.workgroup.count with the workgroup count argument of |
| /// the enclosing function for the requested dimension, cast to index. |
| struct ConvertHALInterfaceWorkgroupCountOp |
| : public OpConversionPattern<IREE::HAL::InterfaceWorkgroupCountOp> { |
| using Base::Base; |
| LogicalResult |
| matchAndRewrite(IREE::HAL::InterfaceWorkgroupCountOp op, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| // Only dimensions 0, 1, and 2 (X, Y, Z) have a function argument. |
| const uint64_t dimension = op.getDimension().getZExtValue(); |
| if (dimension > 2) { |
| return op.emitOpError() << "out of bounds workgroup count dimension"; |
| } |
| |
| // The count arguments are consecutive, starting at |
| // kEntryArgWorkgroupCountX. |
| auto parentFuncOp = op->getParentOfType<mlir::FunctionOpInterface>(); |
| Value countAsI32 = |
| parentFuncOp.getArgument(kEntryArgWorkgroupCountX + dimension); |
| |
| // Users of the original op get the argument as an index value. |
| Value countAsIndex = arith::IndexCastOp::create( |
| rewriter, op.getLoc(), rewriter.getIndexType(), countAsI32); |
| rewriter.replaceOp(op, countAsIndex); |
| return success(); |
| } |
| }; |
| |
| /// Rewrites hal.interface.constant.load to a load from the constants buffer |
| /// argument of the enclosing function. The byte offset of the load is the |
| /// constant ordinal multiplied by the size of the converted result type. |
| struct ConvertHALInterfaceConstantLoadOp |
| : public OpConversionPattern<IREE::HAL::InterfaceConstantLoadOp> { |
| using Base::Base; |
| LogicalResult |
| matchAndRewrite(IREE::HAL::InterfaceConstantLoadOp op, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| // Locate the constants buffer argument on the enclosing function. |
| auto parentFuncOp = op->getParentOfType<mlir::FunctionOpInterface>(); |
| auto constantsArg = parentFuncOp.getArgument(kEntryArgConstants); |
| assert(constantsArg && "entry point not conforming to requirements"); |
| |
| Location loc = op.getLoc(); |
| |
| // HACK: we could find the total push constant count and avoid this size op |
| // but it'd require walking all the way up to the hal.executable export. |
| auto constantsSize = |
| IREE::Util::BufferSizeOp::create(rewriter, loc, constantsArg); |
| auto resultType = getTypeConverter()->convertType(op.getResult().getType()); |
| |
| // Ordinal -> byte offset: ordinal * sizeof(result type). |
| auto ordinalIndex = rewriter.createOrFold<arith::ConstantIndexOp>( |
| loc, op.getOrdinal().getZExtValue()); |
| auto elementSize = |
| rewriter.createOrFold<IREE::Util::SizeOfOp>(loc, resultType); |
| auto byteOffset = |
| rewriter.createOrFold<arith::MulIOp>(loc, elementSize, ordinalIndex); |
| |
| rewriter.replaceOpWithNewOp<IREE::Util::BufferLoadOp>( |
| op, resultType, constantsArg, constantsSize, byteOffset, elementSize); |
| return success(); |
| } |
| }; |
| |
| /// Rewrites IREE::VMVX::GetRawInterfaceBindingBufferOp to a list lookup that |
| /// fetches the requested binding from the bindings list argument of the |
| /// enclosing function. |
| struct ConvertGetRawInterfaceBindingBufferOp |
| : public OpConversionPattern<IREE::VMVX::GetRawInterfaceBindingBufferOp> { |
| using Base::Base; |
| LogicalResult |
| matchAndRewrite(IREE::VMVX::GetRawInterfaceBindingBufferOp op, |
| OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| // Locate the bindings list argument on the enclosing function. |
| auto bindingsArg = |
| op->getParentOfType<mlir::FunctionOpInterface>().getArgument( |
| kEntryArgBindings); |
| assert(bindingsArg && isa<IREE::Util::ListType>(bindingsArg.getType()) && |
| "entry point not conforming to requirements"); |
| |
| // The element type of the list is the type of each binding buffer. |
| auto bindingType = |
| cast<IREE::Util::ListType>(bindingsArg.getType()).getElementType(); |
| |
| // Replace the op with a lookup at the constant binding ordinal. |
| rewriter.replaceOpWithNewOp<IREE::Util::ListGetOp>( |
| op, bindingType, bindingsArg, |
| rewriter.createOrFold<arith::ConstantIndexOp>( |
| op.getLoc(), op.getBinding().getSExtValue())); |
| return success(); |
| } |
| }; |
| |
| /// Rewrites hal.interface.binding.subspan to a lookup of the binding in the |
| /// bindings list argument of the enclosing function. When the op has a byte |
| /// offset that does not match a constant zero, the looked-up buffer is |
| /// additionally wrapped in a buffer subspan of the computed size. |
| struct ConvertHALInterfaceBindingSubspanOp |
| : public OpConversionPattern<IREE::HAL::InterfaceBindingSubspanOp> { |
| using Base::Base; |
| LogicalResult |
| matchAndRewrite(IREE::HAL::InterfaceBindingSubspanOp op, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| // Locate the bindings list argument on the enclosing function. |
| auto parentFuncOp = op->getParentOfType<mlir::FunctionOpInterface>(); |
| auto bindingsArg = parentFuncOp.getArgument(kEntryArgBindings); |
| assert(bindingsArg && isa<IREE::Util::ListType>(bindingsArg.getType()) && |
| "entry point not conforming to requirements"); |
| |
| Location loc = op.getLoc(); |
| IndexSet indexSet(loc, rewriter); |
| |
| // Fetch the buffer for this binding from the list. |
| auto bindingsListType = cast<IREE::Util::ListType>(bindingsArg.getType()); |
| auto bindingOrdinal = rewriter.createOrFold<arith::ConstantIndexOp>( |
| loc, op.getBinding().getSExtValue()); |
| auto sourceBuffer = |
| IREE::Util::ListGetOp::create(rewriter, loc, |
| bindingsListType.getElementType(), |
| bindingsArg, bindingOrdinal) |
| .getResult(); |
| |
| // Zero offset. Just return the source buffer. |
| const bool hasNonZeroOffset = |
| op.getByteOffset() && !matchPattern(op.getByteOffset(), m_Zero()); |
| if (!hasNonZeroOffset) { |
| rewriter.replaceOp(op, sourceBuffer); |
| return success(); |
| } |
| |
| // Offsetted binding: replace with a BufferSubspanOp. |
| Value sourceSize = |
| rewriter.createOrFold<IREE::Util::BufferSizeOp>(loc, sourceBuffer); |
| |
| // The dest size is the element size multiplied by every extent of the |
| // memref result type. Static extents come from the type; dynamic extents |
| // are consumed in order from the converted dynamic dim operands. |
| auto memRefType = cast<MemRefType>(op.getResult().getType()); |
| Value destSize = rewriter.createOrFold<IREE::Util::SizeOfOp>( |
| loc, memRefType.getElementType()); |
| auto nextDynamicDim = adaptor.getDynamicDims().begin(); |
| for (int64_t dim = 0, rank = memRefType.getRank(); dim < rank; ++dim) { |
| Value extent; |
| if (!memRefType.isDynamicDim(dim)) { |
| extent = indexSet.get(memRefType.getDimSize(dim)); |
| } else { |
| extent = *nextDynamicDim; |
| ++nextDynamicDim; |
| } |
| destSize = rewriter.createOrFold<arith::MulIOp>(loc, destSize, extent); |
| } |
| |
| rewriter.replaceOpWithNewOp<IREE::Util::BufferSubspanOp>( |
| op, sourceBuffer, sourceSize, adaptor.getByteOffset(), destSize); |
| return success(); |
| } |
| }; |
| |
| } // namespace |
| |
| /// Marks the HAL dialect and IREE::VMVX::GetRawInterfaceBindingBufferOp as |
| /// illegal on `conversionTarget` and adds the conversion patterns defined in |
| /// this file to `patterns`, each constructed with `typeConverter` and |
| /// `context`. |
| void populateHALToVMVXPatterns(MLIRContext *context, |
| ConversionTarget &conversionTarget, |
| RewritePatternSet &patterns, |
| TypeConverter &typeConverter) { |
| // Everything the patterns below rewrite must be gone after conversion. |
| conversionTarget.addIllegalDialect<IREE::HAL::HALDialect>(); |
| conversionTarget.addIllegalOp<IREE::VMVX::GetRawInterfaceBindingBufferOp>(); |
| |
| // Registered in one call; the order of the list is the insertion order. |
| patterns.insert<ConvertGetRawInterfaceBindingBufferOp, |
| ConvertHALInterfaceWorkgroupIDOp, |
| ConvertHALInterfaceWorkgroupSizeOp, |
| ConvertHALInterfaceWorkgroupCountOp, |
| ConvertHALInterfaceConstantLoadOp, |
| ConvertHALInterfaceBindingSubspanOp>(typeConverter, context); |
| } |
| |
| } // namespace mlir::iree_compiler |