blob: 77c212516f6ac30854aa9ba4fdd7ae649b8dd6c6 [file] [log] [blame]
// Copyright 2019 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
#include "llvm/ADT/STLExtras.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/FunctionImplementation.h"
namespace mlir {
namespace iree_compiler {
namespace IREE {
namespace HAL {
//===----------------------------------------------------------------------===//
// custom<DescriptorType>($descriptor_type)
//===----------------------------------------------------------------------===//
// Custom parser/printer that omits the wrapping `<` and `>` emitted by the
// autogenerated attribute parser/printer.
static ParseResult parseDescriptorType(OpAsmParser &parser,
                                       DescriptorTypeAttr &dtAttr) {
  // Remember where the keyword starts so an unknown value can be reported at
  // the right location.
  auto keywordLoc = parser.getCurrentLocation();
  StringRef enumKeyword;
  if (failed(parser.parseKeyword(&enumKeyword)))
    return failure();
  std::optional<DescriptorType> maybeEnum =
      symbolizeDescriptorType(enumKeyword);
  if (!maybeEnum) {
    // Emit a diagnostic: returning a bare failure() here leaves the user with
    // no indication of which keyword was rejected.
    return parser.emitError(keywordLoc)
           << "invalid descriptor type '" << enumKeyword << "'";
  }
  dtAttr = DescriptorTypeAttr::get(parser.getContext(), *maybeEnum);
  return success();
}
static void printDescriptorType(OpAsmPrinter &p, Operation *,
                                DescriptorTypeAttr dtAttr) {
  // Print only the bare enum keyword (no `<`/`>` wrapping).
  DescriptorType descriptorType = dtAttr.getValue();
  p << stringifyDescriptorType(descriptorType);
}
//===----------------------------------------------------------------------===//
// custom<DescriptorSetBindings>($binding_ordinals,
// $binding_buffers,
// type($binding_buffers),
// $binding_offsets,
// $binding_lengths)
//===----------------------------------------------------------------------===//
static ParseResult parseDescriptorSetBindings(
    OpAsmParser &parser,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &ordinals,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &buffers,
    SmallVectorImpl<Type> &bufferTypes,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &bufferOffsets,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &bufferLengths) {
  // Parses one binding of the form
  //   %ordinal = (%buffer : type)[%offset, %length]
  // and appends each component to its output list.
  auto parseSingleBinding = [&]() -> ParseResult {
    OpAsmParser::UnresolvedOperand ordinal;
    OpAsmParser::UnresolvedOperand buffer;
    OpAsmParser::UnresolvedOperand offset;
    OpAsmParser::UnresolvedOperand length;
    Type type;
    if (parser.parseOperand(ordinal) || parser.parseEqual() ||
        parser.parseLParen() || parser.parseOperand(buffer) ||
        parser.parseColonType(type) || parser.parseRParen() ||
        parser.parseLSquare() || parser.parseOperand(offset) ||
        parser.parseComma() || parser.parseOperand(length) ||
        parser.parseRSquare()) {
      return failure();
    }
    ordinals.push_back(ordinal);
    buffers.push_back(buffer);
    bufferTypes.push_back(type);
    bufferOffsets.push_back(offset);
    bufferLengths.push_back(length);
    return success();
  };

  // At least one binding is required; further bindings are comma-separated.
  do {
    if (failed(parseSingleBinding()))
      return failure();
  } while (succeeded(parser.parseOptionalComma()));
  return success();
}
static void printDescriptorSetBindings(OpAsmPrinter &p, Operation *op,
                                       ValueRange ordinals, ValueRange buffers,
                                       TypeRange bufferTypes,
                                       ValueRange bufferOffsets,
                                       ValueRange bufferLengths) {
  // Each binding goes on its own line as
  //   %ordinal = (%buffer : type)[%offset, %length]
  // with ", " separating consecutive bindings.
  bool needsSeparator = false;
  for (auto [ordinal, buffer, bufferType, offset, length] : llvm::zip_equal(
           ordinals, buffers, bufferTypes, bufferOffsets, bufferLengths)) {
    if (needsSeparator)
      p << ", ";
    needsSeparator = true;
    p.printNewline();
    p << " ";
    p.printOperand(ordinal);
    p << " = (";
    p.printOperand(buffer);
    p << " : ";
    p.printType(bufferType);
    p << ")[";
    p.printOperand(offset);
    p << ", ";
    p.printOperand(length);
    p << "]";
  }
  p.printNewline();
}
//===----------------------------------------------------------------------===//
// custom<TargetConditionRegion>($body)
//===----------------------------------------------------------------------===//
// Returns the signature of a target condition region: (!hal.device) -> i1.
static FunctionType getTargetConditionRegionType(MLIRContext *context) {
  Type deviceType = IREE::HAL::DeviceType::get(context);
  Type conditionType = IntegerType::get(context, 1);
  return FunctionType::get(context, TypeRange{deviceType},
                           TypeRange{conditionType});
}
static LogicalResult verifyTargetConditionRegion(Operation *op,
                                                 Region &region) {
  // An empty region means there is no condition to check.
  if (region.empty())
    return success();

  // The region must take exactly one argument of type !hal.device.
  bool takesSingleDevice =
      region.getNumArguments() == 1 &&
      isa<IREE::HAL::DeviceType>(region.getArgumentTypes().front());
  if (!takesSingleDevice) {
    return op->emitOpError()
           << "target condition region must take a !hal.device";
  }

  // Every hal.return in the region must yield exactly one i1 value.
  for (IREE::HAL::ReturnOp returnOp : region.getOps<IREE::HAL::ReturnOp>()) {
    bool yieldsSingleI1 = returnOp.getNumOperands() == 1 &&
                          returnOp->getOperand(0).getType().isInteger(1);
    if (!yieldsSingleI1) {
      return returnOp.emitOpError()
             << "target condition region must return a single i1 result";
    }
  }
  return success();
}
static ParseResult parseTargetConditionRegion(OpAsmParser &parser,
                                              Region &body) {
  // Expected form: `(%arg: type, ...) -> i1 { ... }`.
  SmallVector<OpAsmParser::Argument> regionArgs;
  if (parser.parseArgumentList(regionArgs, AsmParser::Delimiter::Paren,
                               /*allowType=*/true,
                               /*allowAttrs=*/true)) {
    return failure();
  }

  SmallVector<Type> resultTypes;
  if (parser.parseArrowTypeList(resultTypes))
    return failure();

  bool isSingleI1 =
      resultTypes.size() == 1 && resultTypes.front().isInteger(1);
  if (!isSingleI1) {
    return parser.emitError(parser.getCurrentLocation())
           << "target condition region must return one i1";
  }
  return parser.parseRegion(body, regionArgs, /*enableNameShadowing=*/false);
}
static void printTargetConditionRegion(OpAsmPrinter &p, Operation *op,
                                       Region &body) {
  // An absent condition region prints nothing at all.
  if (body.empty())
    return;
  p << '(';
  llvm::interleaveComma(body.getArguments(), p, [&](BlockArgument arg) {
    p.printRegionArgument(arg);
  });
  p << ')';
  // The region always yields a single i1.
  Type i1Type = IntegerType::get(body.getContext(), 1);
  p.printArrowTypeList(TypeRange{i1Type});
  p << ' ';
  p.printRegion(body, /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/true);
}
//===----------------------------------------------------------------------===//
// custom<WorkgroupCountRegion>($body)
//===----------------------------------------------------------------------===//
// Parses `(%args...) -> (index, index, index) { ... }` into |body| and checks
// that every hal.return in the parsed region yields exactly the declared
// result types.
static ParseResult parseWorkgroupCountRegion(OpAsmParser &parser,
                                             Region &body) {
  SmallVector<OpAsmParser::Argument> args;
  if (failed(parser.parseArgumentList(args, AsmParser::Delimiter::Paren,
                                      /*allowType=*/true,
                                      /*allowAttrs=*/true))) {
    return failure();
  }
  // Return types must be 3 dimensions (workgroup count XYZ).
  SmallVector<Type> returnTypes;
  if (failed(parser.parseArrowTypeList(returnTypes))) {
    return failure();
  }
  if (returnTypes.size() != 3 ||
      !llvm::all_of(returnTypes, [](Type type) { return type.isIndex(); })) {
    return parser.emitError(parser.getCurrentLocation())
           << "workgroup count region must return the XYZ dimension counts";
  }
  // Parse region contents.
  if (failed(parser.parseRegion(body, args, /*enableNameShadowing=*/false))) {
    return failure();
  }
  // Verify the return types match. The operand count is checked first:
  // zip_equal requires equal-length ranges (it asserts otherwise), and a
  // hal.return with the wrong number of operands must be a diagnostic.
  for (auto returnOp : body.getOps<IREE::HAL::ReturnOp>()) {
    if (returnOp.getNumOperands() != returnTypes.size()) {
      return returnOp.emitOpError()
             << "operands do not match expected region return types";
    }
    for (auto [resultType, returnType] :
         llvm::zip_equal(returnTypes, returnOp.getOperandTypes())) {
      if (resultType != returnType) {
        return returnOp.emitOpError()
               << "operands do not match expected region return types";
      }
    }
  }
  return success();
}
static void printWorkgroupCountRegion(OpAsmPrinter &p, Operation *op,
                                      Region &body) {
  // An absent workgroup count region prints nothing.
  if (body.empty())
    return;
  p << '(';
  llvm::interleaveComma(body.getArguments(), p, [&](BlockArgument arg) {
    p.printRegionArgument(arg);
  });
  p << ')';
  // The region always yields the XYZ workgroup counts as three indices.
  Type index = IndexType::get(body.getContext());
  SmallVector<Type, 3> xyzTypes(3, index);
  p.printArrowTypeList(TypeRange(xyzTypes));
  p << ' ';
  p.printRegion(body, /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/true);
}
//===----------------------------------------------------------------------===//
// hal.ex.*
//===----------------------------------------------------------------------===//
void ExSharedDeviceOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  // Suggest `%device` as the printed name of the result.
  Value device = getResult();
  setNameFn(device, "device");
}
void ExFileFromMemoryOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  // Suggest `%memory_file` as the printed name of the result.
  Value file = getResult();
  setNameFn(file, "memory_file");
}
//===----------------------------------------------------------------------===//
// hal.return
//===----------------------------------------------------------------------===//
LogicalResult ReturnOp::verify() {
  // When nested directly in a function-like op, the returned values must
  // match the parent's declared result types. Other parents are not checked
  // here.
  Operation *parentOp = getOperation()->getParentOp();
  auto parentFuncOp = dyn_cast_or_null<FunctionOpInterface>(parentOp);
  if (!parentFuncOp)
    return success();

  auto expectedTypes = parentFuncOp.getResultTypes();
  if (getNumOperands() != expectedTypes.size()) {
    return emitOpError() << "return must have the same number of operands "
                            "as the parent result signature (have "
                         << getNumOperands() << ", expected "
                         << expectedTypes.size() << ")";
  }

  size_t index = 0;
  for (auto [operand, expectedType] :
       llvm::zip_equal(getOperands(), expectedTypes)) {
    Type actualType = operand.getType();
    if (actualType != expectedType) {
      return emitOpError() << "parent expected result " << index << " to be "
                           << expectedType << " but returning " << actualType;
    }
    ++index;
  }
  return success();
}
//===----------------------------------------------------------------------===//
// hal.tensor.import/export
//===----------------------------------------------------------------------===//
void TensorImportOp::build(OpBuilder &builder, OperationState &result,
                           Type resultType, Value source,
                           TypeAttr targetEncoding, StringAttr name) {
  // Forward to the fence-aware overload without a wait fence.
  Value noWaitFence;
  build(builder, result, resultType, source, targetEncoding, noWaitFence,
        name);
}
void TensorImportOp::build(OpBuilder &builder, OperationState &result,
                           Type resultType, Value source,
                           TypeAttr targetEncoding, Value waitFence,
                           StringAttr name) {
  auto shapedType = llvm::cast<ShapedType>(resultType);
  // Dynamic dimensions are queried from |source| with hal.buffer_view.dim, so
  // any other source kind is only accepted for fully static result shapes.
  assert((llvm::isa<IREE::HAL::BufferViewType>(source.getType()) ||
          shapedType.hasStaticShape()) &&
         "can only use this constructor for buffer views when shape "
         "information is required");
  SmallVector<Value> dynamicDims;
  for (int64_t dim = 0, rank = shapedType.getRank(); dim < rank; ++dim) {
    if (shapedType.isDynamicDim(dim)) {
      dynamicDims.push_back(builder.createOrFold<IREE::HAL::BufferViewDimOp>(
          result.location, builder.getIndexType(), source,
          builder.getIndexAttr(dim)));
    }
  }
  build(builder, result, resultType, source, targetEncoding, dynamicDims,
        waitFence, name);
}
Value TensorImportOp::getTiedResult(unsigned resultIndex) {
  // The single result is tied to |source|; resolve its tied base value.
  Value source = getSource();
  return IREE::Util::TiedOpInterface::findTiedBaseValue(source);
}
::std::optional<unsigned>
TensorImportOp::getTiedResultOperandIndex(unsigned resultIndex) {
  // The result is always tied to operand 0 (|source|).
  constexpr unsigned kSourceOperandIndex = 0;
  return kSourceOperandIndex;
}
SmallVector<int64_t> TensorImportOp::getTiedResultOperandIndices() {
  // One result, tied to operand 0 (|source|).
  SmallVector<int64_t> tiedIndices;
  tiedIndices.push_back(0);
  return tiedIndices;
}
// Checks that |encodingType| and |storageType| are compatible for
// hal.tensor.import/export. Identical types and non-shaped types are always
// accepted. For two shaped types, only the number of dynamic dimensions is
// enforced today.
static LogicalResult verifyTypeStorageCompatibility(Operation *op,
                                                    Type encodingType,
                                                    Type storageType) {
  if (encodingType == storageType)
    return success();

  auto encodingShaped = llvm::dyn_cast<ShapedType>(encodingType);
  auto storageShaped = llvm::dyn_cast<ShapedType>(storageType);
  if (!encodingShaped || !storageShaped)
    return success();

  auto encodingByteWidth = IREE::Util::getRoundedElementByteWidth(
      encodingShaped.getElementType());
  auto storageByteWidth = IREE::Util::getRoundedElementByteWidth(
      storageShaped.getElementType());
  if (encodingByteWidth != storageByteWidth) {
    // TODO(benvanik): more sophisticated logic here. There are a lot of valid
    // cases that are difficult to account for here statically; for example,
    // packing 8xi1 into 1xi8 or complex<f32> into 2xf32. We could try to guess
    // the element count (at least the static part of it) and ensure the scaling
    // matches but that wouldn't account for user variance. Really with this op
    // we are letting the _user_ control the bitcasting and type reflection and
    // purposefully don't want to mess with it (users should be able to put
    // custom types here, etc).
    //
    // NOTE: we round to bytes first as the base type (such as i1) may not be
    // representable in an external form.
    // return op->emitOpError() << "encoding and storage types must be "
    //                             "bitcastable; adjusted encoding bit width "
    //                             "of "
    //                          << encodingShaped.getElementTypeBitWidth()
    //                          << " != adjusted storage bit width of "
    //                          << storageShaped.getElementTypeBitWidth();
  }

  // NOTE: we implicitly require that the dimensions are equivalent but
  // don't actually care about their order. For example, tensor<?x1xf32> is
  // compatible with tensor<?xf32>.
  if (encodingShaped.getNumDynamicDims() == storageShaped.getNumDynamicDims())
    return success();
  return op->emitOpError()
         << "encoding and storage types must have the same "
            "dynamic dimension values; encoding shape "
         << encodingShaped << " incompatible with storage shape "
         << storageShaped;
}
LogicalResult TensorImportOp::verify() {
  // One target_dims value is required per dynamic dimension of the result,
  // and the declared encoding must be storage-compatible with the result.
  auto targetType = llvm::cast<TensorType>(getTarget().getType());
  if (targetType.getNumDynamicDims() == getTargetDims().size())
    return verifyTypeStorageCompatibility(*this, getTargetEncoding(),
                                          targetType);
  return emitOpError() << "number of target_dims must match number of "
                          "dynamic dims in target type";
}
void TensorExportOp::build(OpBuilder &builder, OperationState &result,
                           Type resultType, Value source,
                           TypeAttr sourceEncoding, StringAttr name) {
  // Materialize one SSA value per dynamic dimension of |source| and forward
  // to the full builder without target storage.
  auto sourceDims =
      IREE::Util::buildDynamicDimsForValue(result.location, source, builder);
  build(builder, result, resultType, source, sourceEncoding, sourceDims,
        /*target_storage=*/nullptr, name);
}
Value TensorExportOp::getTiedResult(unsigned resultIndex) {
  // The single result is tied to |source|; resolve its tied base value.
  Value source = getSource();
  return IREE::Util::TiedOpInterface::findTiedBaseValue(source);
}
::std::optional<unsigned>
TensorExportOp::getTiedResultOperandIndex(unsigned resultIndex) {
  // The result is always tied to operand 0 (|source|).
  constexpr unsigned kSourceOperandIndex = 0;
  return kSourceOperandIndex;
}
SmallVector<int64_t> TensorExportOp::getTiedResultOperandIndices() {
  // One result, tied to operand 0 (|source|).
  SmallVector<int64_t> tiedIndices;
  tiedIndices.push_back(0);
  return tiedIndices;
}
LogicalResult TensorExportOp::verify() {
  // One source_dims value is required per dynamic dimension of the source,
  // and the declared encoding must be storage-compatible with the source.
  auto sourceType = llvm::cast<TensorType>(getSource().getType());
  if (sourceType.getNumDynamicDims() == getSourceDims().size())
    return verifyTypeStorageCompatibility(*this, getSourceEncoding(),
                                          getSource().getType());
  return emitOpError() << "number of source_dims must match number of "
                          "dynamic dims in source type";
}
//===----------------------------------------------------------------------===//
// hal.tensor.barrier
//===----------------------------------------------------------------------===//
void TensorBarrierOp::build(OpBuilder &builder, OperationState &result,
                            ValueRange sources, Value signalFence) {
  // Each result has the same type as its corresponding source.
  SmallVector<Type> resultTypes;
  resultTypes.reserve(sources.size());
  for (Value source : sources)
    resultTypes.push_back(source.getType());
  build(builder, result, resultTypes, sources, signalFence);
}
Value TensorBarrierOp::getTiedResult(unsigned resultIndex) {
  // Result i is tied to sources[i]; resolve that source's tied base value.
  Value source = getSources()[resultIndex];
  return IREE::Util::TiedOpInterface::findTiedBaseValue(source);
}
::std::optional<unsigned>
TensorBarrierOp::getTiedResultOperandIndex(unsigned resultIndex) {
  // Results map 1:1 onto the leading |sources| operands.
  return resultIndex;
}
SmallVector<int64_t> TensorBarrierOp::getTiedResultOperandIndices() {
  // Result i is tied to operand i for every source.
  int64_t sourceCount = getSources().size();
  SmallVector<int64_t> tiedIndices;
  tiedIndices.reserve(sourceCount);
  for (int64_t i = 0; i < sourceCount; ++i)
    tiedIndices.push_back(i);
  return tiedIndices;
}
//===----------------------------------------------------------------------===//
// hal.dispatch.extern
//===----------------------------------------------------------------------===//
void DispatchExternOp::build(OpBuilder &builder, OperationState &state,
                             ValueRange workload, TypeRange resultTypes,
                             ValueRange resultDims, ValueRange arguments,
                             ValueRange argumentDims,
                             ArrayRef<int64_t> tiedOperands,
                             ArrayRef<NamedAttribute> attributes) {
  state.addTypes(resultTypes);
  // Operand groups are appended in the same order as the segment sizes
  // recorded below.
  for (ValueRange operandGroup :
       {workload, arguments, argumentDims, resultDims}) {
    state.addOperands(operandGroup);
  }
  state.addAttributes(attributes);

  // Replace any caller-provided tied-operand attribute with |tiedOperands|.
  StringRef tiedAttrName = IREE::Util::TiedOpInterface::getStorageAttrName();
  state.attributes.erase(tiedAttrName);
  state.addAttribute(tiedAttrName, builder.getIndexArrayAttr(tiedOperands));

  // Replace any caller-provided segment sizes with the actual group sizes.
  auto sizeAsI32 = [](ValueRange group) {
    return static_cast<int32_t>(group.size());
  };
  state.attributes.erase(getOperandSegmentSizeAttr());
  state.addAttribute(getOperandSegmentSizeAttr(),
                     builder.getDenseI32ArrayAttr({
                         sizeAsI32(workload),
                         sizeAsI32(arguments),
                         sizeAsI32(argumentDims),
                         sizeAsI32(resultDims),
                     }));

  // NOTE: workgroup count region is empty; callers are expected to populate it.
  state.addRegion();
}
// Verifies that |dynamicDims| contains the appropriate number of dims for all
// of the dynamic dimensions in |values|.
static LogicalResult verifyOpDynamicDims(Operation *op, ValueRange values,
                                         ValueRange dynamicDims) {
  // Sum the dynamic dimensions across all shaped values; non-shaped values
  // contribute none.
  unsigned requiredCount = 0;
  for (Value value : values) {
    if (auto shapedType = llvm::dyn_cast<ShapedType>(value.getType()))
      requiredCount += shapedType.getNumDynamicDims();
  }
  if (dynamicDims.size() == requiredCount)
    return success();
  // Keep the historical wording ("but only N") for the too-few case. When
  // more values are attached than required, "only" would be misleading, so
  // it is dropped.
  bool tooFew = dynamicDims.size() < requiredCount;
  return op->emitOpError()
         << "value set has " << requiredCount << " dynamic dimensions but "
         << (tooFew ? "only " : "") << dynamicDims.size()
         << " dimension values are attached";
}
static LogicalResult
verifyWorkgroupCountRegion(Operation *op, ValueRange workload, Region &region) {
  // Every non-!hal.device region argument captures one workload operand, in
  // order, and must have the same type.
  auto capturedTypes =
      llvm::make_filter_range(region.getArgumentTypes(), [](Type type) {
        return !llvm::isa<IREE::HAL::DeviceType>(type);
      });
  auto capturedCount = llvm::range_size(capturedTypes);
  if (workload.size() != capturedCount) {
    return op->emitOpError()
           << "workload operands and workgroup count args mismatch ("
           << workload.size() << " vs " << capturedCount << ")";
  }

  size_t index = 0;
  for (auto [workloadValue, capturedType] :
       llvm::zip_equal(workload, capturedTypes)) {
    if (workloadValue.getType() != capturedType) {
      return op->emitOpError()
             << "workload value " << index << " type mismatch; operand is "
             << workloadValue.getType() << " but region captures "
             << capturedType;
    }
    ++index;
  }
  return success();
}
LogicalResult DispatchExternOp::verify() {
  Operation *op = getOperation();

  // Every dynamic dimension of the arguments and results must be supplied.
  if (failed(verifyOpDynamicDims(op, getArguments(), getArgumentDims())))
    return failure();
  if (failed(verifyOpDynamicDims(op, getResults(), getResultDims())))
    return failure();

  // Shaped operands/results with index element types are rejected.
  auto checkIOType = [&](Type type) -> LogicalResult {
    auto shapedType = llvm::dyn_cast<ShapedType>(type);
    if (!shapedType || !shapedType.getElementType().isIndex())
      return success();
    return op->emitOpError() << "I/O type " << type
                             << " is invalid: index types must not cross "
                                "the dispatch boundary";
  };
  for (Type operandType : getOperandTypes()) {
    if (failed(checkIOType(operandType)))
      return failure();
  }
  for (Type resultType : getResultTypes()) {
    if (failed(checkIOType(resultType)))
      return failure();
  }

  // The workgroup count region must capture the workload operands.
  return verifyWorkgroupCountRegion(op, getWorkload(), getWorkgroupCount());
}
std::pair<unsigned, unsigned>
DispatchExternOp::getTiedOperandsIndexAndLength() {
  // ODS operand group 1 is |arguments| (group 0 is |workload|); tied operand
  // indices are relative to that group.
  constexpr unsigned kArgumentsOperandGroup = 1;
  return getODSOperandIndexAndLength(kArgumentsOperandGroup);
}
//===----------------------------------------------------------------------===//
// hal.allocator.allocate
//===----------------------------------------------------------------------===//
void AllocatorAllocateOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  // Suggest `%buffer` as the printed name of the allocated result.
  Value buffer = getResult();
  setNameFn(buffer, "buffer");
}
// No operand carries a size; always returns a null Value.
Value AllocatorAllocateOp::getOperandSize(unsigned idx) { return Value(); }
Value AllocatorAllocateOp::getResultSize(unsigned idx) {
  // Regardless of |idx|, the size is the op's result_size operand.
  Value resultSize = getResultSize();
  return resultSize;
}
//===----------------------------------------------------------------------===//
// hal.allocator.import
//===----------------------------------------------------------------------===//
void AllocatorImportOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  // Name the import status flag and the mapped buffer result.
  Value didImport = getDidImport();
  Value mapped = getResult();
  setNameFn(didImport, "did_import");
  setNameFn(mapped, "mapped");
}
// No operand carries a size; always returns a null Value.
Value AllocatorImportOp::getOperandSize(unsigned idx) { return Value(); }
// The mapped result's size is the |length| operand regardless of |idx|.
Value AllocatorImportOp::getResultSize(unsigned idx) {
  Value length = getLength();
  return length;
}
//===----------------------------------------------------------------------===//
// hal.buffer.subspan
//===----------------------------------------------------------------------===//
void BufferSubspanOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  // Suggest `%buffer` as the printed name of the subspan result.
  Value subspan = getResult();
  setNameFn(subspan, "buffer");
}
// The subspan |length| is reported as the operand size regardless of |idx|.
Value BufferSubspanOp::getOperandSize(unsigned idx) {
  Value length = getLength();
  return length;
}
// The subspan |length| is reported as the result size regardless of |idx|.
Value BufferSubspanOp::getResultSize(unsigned idx) {
  Value length = getLength();
  return length;
}
//===----------------------------------------------------------------------===//
// hal.buffer.byte_length
//===----------------------------------------------------------------------===//
void BufferLengthOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  // Suggest `%len` as the printed name of the byte length result.
  Value length = getResult();
  setNameFn(length, "len");
}
//===----------------------------------------------------------------------===//
// hal.buffer_view.create
//===----------------------------------------------------------------------===//
// Convenience builder that takes the element and encoding types as integers.
// They are materialized as 32-bit integer constants and forwarded to the
// Value-based builder.
void BufferViewCreateOp::build(OpBuilder &builder, OperationState &state,
                               Value sourceBuffer, Value sourceOffset,
                               Value sourceLength, int32_t elementType,
                               int32_t encodingType, ValueRange shape) {
  // Create the constants in separate statements. As sibling call arguments
  // their creation order was unspecified, so the order of the emitted ops
  // could differ between host compilers.
  Value elementTypeValue = builder.createOrFold<arith::ConstantIntOp>(
      state.location, elementType, 32);
  Value encodingTypeValue = builder.createOrFold<arith::ConstantIntOp>(
      state.location, encodingType, 32);
  build(builder, state, sourceBuffer, sourceOffset, sourceLength,
        elementTypeValue, encodingTypeValue, shape);
}
void BufferViewCreateOp::build(OpBuilder &builder, OperationState &state,
                               Value sourceBuffer, Value sourceOffset,
                               Value sourceLength, Value elementType,
                               Value encodingType, ValueRange shape) {
  // Fixed operands first, then the variadic shape dimensions.
  Value fixedOperands[] = {sourceBuffer, sourceOffset, sourceLength,
                           elementType, encodingType};
  state.addOperands(fixedOperands);
  state.addOperands(shape);
  // The single result is a !hal.buffer_view.
  Type viewType = BufferViewType::get(builder.getContext());
  state.addTypes(viewType);
}
void BufferViewCreateOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  // Suggest `%view` as the printed name of the created buffer view.
  Value view = getResult();
  setNameFn(view, "view");
}
//===----------------------------------------------------------------------===//
// hal.buffer_view.buffer
//===----------------------------------------------------------------------===//
void BufferViewBufferOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  // Suggest `%buffer` as the printed name of the result.
  Value buffer = getResult();
  setNameFn(buffer, "buffer");
}
//===----------------------------------------------------------------------===//
// hal.channel.create
//===----------------------------------------------------------------------===//
void ChannelCreateOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  // Suggest `%channel` as the printed name of the result.
  Value channel = getResult();
  setNameFn(channel, "channel");
}
//===----------------------------------------------------------------------===//
// hal.channel.split
//===----------------------------------------------------------------------===//
void ChannelSplitOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  // Suggest `%channel` as the printed name of the split result.
  Value channel = getResult();
  setNameFn(channel, "channel");
}
//===----------------------------------------------------------------------===//
// hal.channel.rank_and_count
//===----------------------------------------------------------------------===//
void ChannelRankAndCountOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  // Name the two results `%ccl_rank` and `%ccl_count`.
  Value rank = getRank();
  Value count = getCount();
  setNameFn(rank, "ccl_rank");
  setNameFn(count, "ccl_count");
}
//===----------------------------------------------------------------------===//
// hal.command_buffer.create
//===----------------------------------------------------------------------===//
void CommandBufferCreateOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  // Suggest `%cmd` as the printed name of the command buffer.
  Value commandBuffer = getResult();
  setNameFn(commandBuffer, "cmd");
}
//===----------------------------------------------------------------------===//
// hal.command_buffer.push_descriptor_set
//===----------------------------------------------------------------------===//
void CommandBufferPushDescriptorSetOp::build(
    OpBuilder &builder, OperationState &state, Value commandBuffer,
    Value pipelineLayout, int64_t set,
    ArrayRef<DescriptorSetBindingValue> bindings) {
  // Materialize the set ordinal as an index constant and forward to the
  // Value-based builder.
  Value setValue =
      builder.createOrFold<arith::ConstantIndexOp>(state.location, set);
  build(builder, state, commandBuffer, pipelineLayout, setValue, bindings);
}
void CommandBufferPushDescriptorSetOp::build(
    OpBuilder &builder, OperationState &state, Value commandBuffer,
    Value pipelineLayout, Value set,
    ArrayRef<DescriptorSetBindingValue> bindings) {
  state.addOperands({commandBuffer, pipelineLayout, set});
  // Binding operands are grouped by component: all ordinals, then all
  // buffers, then all offsets, then all lengths.
  auto ordinals = llvm::map_to_vector(
      bindings,
      [](const DescriptorSetBindingValue &b) -> Value { return b.ordinal; });
  auto buffers = llvm::map_to_vector(
      bindings,
      [](const DescriptorSetBindingValue &b) -> Value { return b.buffer; });
  auto offsets = llvm::map_to_vector(
      bindings,
      [](const DescriptorSetBindingValue &b) -> Value { return b.byteOffset; });
  auto lengths = llvm::map_to_vector(
      bindings,
      [](const DescriptorSetBindingValue &b) -> Value { return b.byteLength; });
  state.addOperands(ordinals);
  state.addOperands(buffers);
  state.addOperands(offsets);
  state.addOperands(lengths);
}
//===----------------------------------------------------------------------===//
// hal.descriptor_set_layout.create
//===----------------------------------------------------------------------===//
void DescriptorSetLayoutCreateOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  // Suggest `%descriptor_set_layout` as the printed name of the result.
  Value layout = getResult();
  setNameFn(layout, "descriptor_set_layout");
}
//===----------------------------------------------------------------------===//
// hal.descriptor_set_layout.lookup
//===----------------------------------------------------------------------===//
void DescriptorSetLayoutLookupOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  // Suggest `%descriptor_set_layout` as the printed name of the result.
  Value layout = getResult();
  setNameFn(layout, "descriptor_set_layout");
}
//===----------------------------------------------------------------------===//
// hal.device.allocator
//===----------------------------------------------------------------------===//
void DeviceAllocatorOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  // Suggest `%allocator` as the printed name of the result.
  Value allocator = getResult();
  setNameFn(allocator, "allocator");
}
//===----------------------------------------------------------------------===//
// hal.device.query
//===----------------------------------------------------------------------===//
LogicalResult DeviceQueryOp::verify() {
  // When a default value is provided, its type must match the queried value.
  auto defaultValue = getDefaultValue();
  if (!defaultValue.has_value())
    return success();
  if (defaultValue->getType() == getValue().getType())
    return success();
  return emitOpError() << "type mismatch between result and default value";
}
//===----------------------------------------------------------------------===//
// hal.device.queue.*
//===----------------------------------------------------------------------===//
void DeviceQueueAllocaOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  // Suggest `%transient_buffer` as the printed name of the result.
  Value buffer = getResult();
  setNameFn(buffer, "transient_buffer");
}
// No operand carries a size; always returns a null Value.
Value DeviceQueueAllocaOp::getOperandSize(unsigned idx) { return Value(); }
Value DeviceQueueAllocaOp::getResultSize(unsigned idx) {
  // Regardless of |idx|, the size is the op's result_size operand.
  Value resultSize = getResultSize();
  return resultSize;
}
// Rejects queue operations whose wait and signal fences are the same SSA
// value, unless that value is produced by util.null.
static LogicalResult verifyDeviceQueueFences(Operation *queueOp,
                                             Value waitFence,
                                             Value signalFence) {
  // Distinct fences (or no fence at all) are always fine.
  if (!waitFence || waitFence != signalFence)
    return success();
  // The shared fence may be a block argument with no defining op. Plain isa<>
  // asserts on a null pointer, so use the null-tolerant form.
  if (llvm::isa_and_nonnull<IREE::Util::NullOp>(waitFence.getDefiningOp()))
    return success();
  return queueOp->emitOpError() << "device queue operations cannot wait and "
                                   "signal on the same fence.";
}
LogicalResult DeviceQueueAllocaOp::verify() {
  // Shared fence checks for all queue ops.
  Value waitFence = getWaitFence();
  Value signalFence = getSignalFence();
  return verifyDeviceQueueFences(getOperation(), waitFence, signalFence);
}
LogicalResult DeviceQueueDeallocaOp::verify() {
  // Shared fence checks for all queue ops.
  Value waitFence = getWaitFence();
  Value signalFence = getSignalFence();
  return verifyDeviceQueueFences(getOperation(), waitFence, signalFence);
}
LogicalResult DeviceQueueReadOp::verify() {
  // Shared fence checks for all queue ops.
  Value waitFence = getWaitFence();
  Value signalFence = getSignalFence();
  return verifyDeviceQueueFences(getOperation(), waitFence, signalFence);
}
LogicalResult DeviceQueueWriteOp::verify() {
  // Shared fence checks for all queue ops.
  Value waitFence = getWaitFence();
  Value signalFence = getSignalFence();
  return verifyDeviceQueueFences(getOperation(), waitFence, signalFence);
}
LogicalResult DeviceQueueExecuteOp::verify() {
  // Shared fence checks for all queue ops.
  Value waitFence = getWaitFence();
  Value signalFence = getSignalFence();
  return verifyDeviceQueueFences(getOperation(), waitFence, signalFence);
}
//===----------------------------------------------------------------------===//
// hal.executable.source
//===----------------------------------------------------------------------===//
LogicalResult ExecutableSourceOp::verify() {
  // At most one hal.executable.condition op may appear in the body.
  auto numConditionOps =
      llvm::range_size(getOps<IREE::HAL::ExecutableConditionOp>());
  if (numConditionOps <= 1)
    return success();
  return emitOpError() << "only one condition op is allowed in an executable";
}
//===----------------------------------------------------------------------===//
// hal.executable
//===----------------------------------------------------------------------===//
void ExecutableOp::build(OpBuilder &builder, OperationState &state,
                         StringRef name) {
  // Create the body region with its implicit terminator, then attach the
  // symbol name.
  Region *body = state.addRegion();
  ensureTerminator(*body, builder, state.location);
  StringAttr symName = builder.getStringAttr(name);
  state.addAttribute(mlir::SymbolTable::getSymbolAttrName(), symName);
}
LogicalResult ExecutableOp::verify() {
  // No executable-level invariants are checked yet.
  // TODO(benvanik): check export name conflicts.
  return LogicalResult::success();
}
//===----------------------------------------------------------------------===//
// hal.executable.export
//===----------------------------------------------------------------------===//
// Parses:
//   [visibility] @name [ordinal(N)] layout(#layout) [attributes {...}]
//   [{ workgroup count region }]
ParseResult ExecutableExportOp::parse(OpAsmParser &parser,
                                      OperationState &result) {
  // NOTE(review): |visibilityAttr| is parsed but never added to |result|, so
  // a visibility keyword is only preserved if `sym_visibility` also appears in
  // the attribute dictionary. Confirm this is intended.
  StringAttr visibilityAttr;
  if (failed(parseSymbolVisibility(parser, visibilityAttr))) {
    return failure();
  }
  StringAttr nameAttr;
  if (failed(parser.parseSymbolName(nameAttr,
                                    mlir::SymbolTable::getSymbolAttrName(),
                                    result.attributes))) {
    return failure();
  }
  // Optional `ordinal(N)` clause, parsed as an index attribute.
  if (succeeded(parser.parseOptionalKeyword("ordinal"))) {
    IntegerAttr ordinalAttr;
    if (failed(parser.parseLParen()) ||
        failed(parser.parseAttribute(ordinalAttr,
                                     parser.getBuilder().getIndexType())) ||
        failed(parser.parseRParen())) {
      return failure();
    }
    result.addAttribute("ordinal", ordinalAttr);
  }
  IREE::HAL::PipelineLayoutAttr layoutAttr;
  if (failed(parser.parseKeyword("layout")) || failed(parser.parseLParen()) ||
      failed(parser.parseAttribute(layoutAttr)) ||
      failed(parser.parseRParen()) ||
      failed(parser.parseOptionalAttrDictWithKeyword(result.attributes))) {
    return failure();
  }
  result.addAttribute("layout", layoutAttr);
  // A missing optional region is materialized as an empty region. A region
  // that is present but fails to parse must fail the whole op; previously the
  // result was discarded and the parse reported success.
  std::unique_ptr<Region> region;
  SmallVector<OpAsmParser::Argument> regionOperands;
  OptionalParseResult regionResult =
      parser.parseOptionalRegion(region, regionOperands);
  if (regionResult.has_value() && failed(*regionResult))
    return failure();
  result.addRegion(std::move(region));
  return success();
}
void ExecutableExportOp::print(OpAsmPrinter &p) {
  Operation *op = getOperation();
  // Visibility and symbol name.
  p << ' ';
  printSymbolVisibility(p, op, op->getAttrOfType<StringAttr>("sym_visibility"));
  p << ' ';
  p.printSymbolName(getSymName());
  // Optional `ordinal(N)` clause.
  if (auto ordinalAttr = getOrdinalAttr()) {
    p << " ordinal(";
    p.printAttributeWithoutType(ordinalAttr);
    p << ")";
  }
  // Mandatory `layout(...)` clause followed by any remaining attributes.
  p << " layout(";
  p.printAttribute(getLayout());
  p << ")";
  p.printOptionalAttrDictWithKeyword(
      op->getAttrs(),
      /*elidedAttrs=*/{"sym_name", "layout", "ordinal"});
  // The workgroup count region is printed only when present.
  Region &workgroupCount = getWorkgroupCount();
  if (!workgroupCount.empty()) {
    p << " ";
    p.printRegion(workgroupCount);
  }
}
// Verifies the optional workgroup count region. It must be a single block
// taking (!hal.device, index...) and ending in a hal.return that yields
// getNumWorkgroupDims() values.
LogicalResult ExecutableExportOp::verify() {
  ExecutableExportOp op = *this;
  Block *body = getWorkgroupCountBody();
  // When there is no body, nothing to verify.
  if (!body)
    return success();
  if (!llvm::hasSingleElement(getWorkgroupCount())) {
    return op.emitOpError() << "expected a single region block";
  }
  bool validArguments = true;
  if (body->getNumArguments() == 0) {
    // Need at least a !hal.device.
    validArguments = false;
  } else if (!llvm::isa<IREE::HAL::DeviceType>(
                 body->getArgument(0).getType())) {
    // !hal.device must come first.
    validArguments = false;
  } else {
    // All remaining arguments need to be of type index (today).
    for (BlockArgument &blockArg : body->getArguments().drop_front(1)) {
      if (!llvm::isa<IndexType>(blockArg.getType())) {
        validArguments = false;
        break;
      }
    }
  }
  if (!validArguments) {
    return op.emitOpError(
        "expected workgroup_count to take (%device: !hal.device, "
        "%workload_0: index, %workload_1: index, ...)");
  }
  // Check that the last operation in the block is a `hal.return`.
  // TODO(ravishankarm): The SingleBlockImplicitTerminator<"HAL::ReturnOp">
  // should generate this check, but it doesn't.
  // Inspect back() rather than calling getTerminator(): getTerminator()
  // asserts when the block is empty or does not end in a terminator, which
  // would crash instead of producing this diagnostic.
  ReturnOp returnOp;
  if (!body->empty())
    returnOp = dyn_cast<ReturnOp>(&body->back());
  if (!returnOp || returnOp.getOperands().size() != getNumWorkgroupDims()) {
    return op.emitOpError("expected operation to yield ")
           << getNumWorkgroupDims() << " values";
  }
  return success();
}
// Calculates the workgroup count (x, y, z) given the total N-dimensional
// |workload| and specific |workgroupSize|.
// Emits arith ops through |builder| at |loc| and returns one Value per X/Y/Z
// dimension. Constant operands may fold, because every op is created with
// createOrFold.
// NOTE(review): the order of the createOrFold calls below determines the order
// of the emitted ops; keep it stable when editing.
static std::array<Value, 3>
calculateWorkloadWorkgroupCount(Location loc, ValueRange workload,
const std::array<Value, 3> &workgroupSize,
OpBuilder &builder) {
std::array<Value, 3> result;
auto constantOne = builder.createOrFold<arith::ConstantIndexOp>(loc, 1);
if (workload.size() <= 3) {
// 1-D to 3-D are easy (pad 2 to 0 dimensions) and divide by workgroup
// size.
// Dimensions past the end of |workload| use a workload of 1.
for (int i = 0; i < 3; ++i) {
// Round up: (workload[i] + workgroup_size - 1) / workgroup_size;
// NOTE(review): `i < workload.size()` compares int with size_t; it is safe
// here only because i is in [0, 3).
Value workloadI = i < workload.size() ? workload[i] : constantOne;
workloadI = builder.createOrFold<arith::SubIOp>(
loc,
builder.createOrFold<arith::AddIOp>(loc, workloadI, workgroupSize[i]),
constantOne);
// Unsigned division (arith.divui) completes the ceil-div.
result[i] = builder.createOrFold<arith::DivUIOp>(loc, workloadI,
workgroupSize[i]);
}
} else {
// TODO(#4140): remapping of N-D to 3-D: this is not how you do this!
// All workload dimensions are first multiplied into a single flat count.
Value flatWorkload = constantOne;
for (auto workloadI : workload) {
flatWorkload =
builder.createOrFold<arith::MulIOp>(loc, flatWorkload, workloadI);
}
for (int i = 0; i < 3; ++i) {
// Round up: (workload[i] + workgroup_size - 1) / workgroup_size;
auto rounded = builder.createOrFold<arith::SubIOp>(
loc,
builder.createOrFold<arith::AddIOp>(loc, flatWorkload,
workgroupSize[i]),
constantOne);
auto workgroupCountI =
builder.createOrFold<arith::DivUIOp>(loc, rounded, workgroupSize[i]);
result[i] = workgroupCountI;
// Multiply back out and subtract from invocations.
// NOTE(review): this multiplies by |rounded| (flatWorkload + size - 1)
// rather than by workgroupSize[i], and the product can exceed
// flatWorkload, so the index subtraction may wrap. This is consistent with
// the TODO(#4140) above; confirm before relying on this path.
flatWorkload = builder.createOrFold<arith::SubIOp>(
loc, flatWorkload,
builder.createOrFold<arith::MulIOp>(loc, workgroupCountI, rounded));
}
}
return result;
}
// Materializes the workgroup_count region |body| at the builder's insertion
// point and returns the cloned x/y/z values its hal.return yields.
// Block argument 0 is bound to |device|; the following arguments are bound to
// the leading |workload| values.
static std::array<Value, 3>
calculateWorkgroupCountFromRegion(Location loc, Block *body, Value device,
                                  ValueRange workload, OpBuilder &builder) {
  // TODO(benvanik): replace with region inlining util.
  IRMapping mapping;
  mapping.map(body->getArgument(0), device);

  // For now bind only as many workload values as both the body and the caller
  // provide. Once the number of workload entries is propagated explicitly to
  // `hal.executable.variant` these will agree by construction.
  unsigned bodyWorkloadArgs = body->getNumArguments() - 1;
  unsigned boundCount =
      std::min<unsigned>(bodyWorkloadArgs, workload.size());
  for (unsigned idx = 0; idx != boundCount; ++idx) {
    // +1 skips the leading device argument.
    mapping.map(body->getArgument(1 + idx), workload[idx]);
  }

  // Clone everything except the terminator; its operands are read back below.
  for (Operation &bodyOp : body->without_terminator())
    builder.clone(bodyOp, mapping);

  auto returnOp = cast<IREE::HAL::ReturnOp>(body->getTerminator());
  assert(returnOp.getNumOperands() == 3 && "must return xyz");
  std::array<Value, 3> counts;
  for (unsigned dim = 0; dim < 3; ++dim)
    counts[dim] = mapping.lookup(returnOp.getOperands()[dim]);
  return counts;
}
// Calculates the workgroup count (x, y, z) for dispatching to the entry point.
// |workload| is the N-dimensional total invocation count produced by the
// generic workload logic (roughly, the number of output elements in tensors).
// A user-provided workgroup_count region takes precedence; otherwise the count
// is derived from the workload and the workgroup size.
std::array<Value, 3> ExecutableExportOp::calculateWorkgroupCount(
    Location loc, Value device, ValueRange workload, OpBuilder &builder) {
  Block *body = getWorkgroupCountBody();
  if (!body) {
    std::array<Value, 3> workgroupSize =
        calculateWorkgroupSize(loc, device, workload, builder);
    return calculateWorkloadWorkgroupCount(loc, workload, workgroupSize,
                                           builder);
  }
  return calculateWorkgroupCountFromRegion(loc, body, device, workload,
                                           builder);
}
// Calculates the workgroup size (x, y, z): the dimensions of one workgroup.
// No workgroup size is specified, so [1, 1, 1] is assumed. The workgroup
// count then mirrors the extents of the workload.
std::array<Value, 3> ExecutableExportOp::calculateWorkgroupSize(
    Location loc, Value device, ValueRange workload, OpBuilder &builder) {
  std::array<Value, 3> sizes;
  // One index constant per dimension, created x, then y, then z.
  for (Value &size : sizes)
    size = builder.createOrFold<arith::ConstantIndexOp>(loc, 1);
  return sizes;
}
//===----------------------------------------------------------------------===//
// hal.executable.variant
//===----------------------------------------------------------------------===//
// Builds a hal.executable.variant named |symName| for |target|, with a body
// region that already contains its implicit terminator.
void ExecutableVariantOp::build(OpBuilder &builder, OperationState &state,
                                StringRef symName,
                                IREE::HAL::ExecutableTargetAttr target) {
  Region *bodyRegion = state.addRegion();
  ensureTerminator(*bodyRegion, builder, state.location);
  StringAttr symNameAttr = builder.getStringAttr(symName);
  state.addAttribute(mlir::SymbolTable::getSymbolAttrName(), symNameAttr);
  state.addAttribute("target", target);
}
// A variant may contain at most one hal.executable.condition op.
LogicalResult ExecutableVariantOp::verify() {
  // Stop counting as soon as a second condition op is seen.
  if (llvm::hasNItemsOrMore(getOps<IREE::HAL::ExecutableConditionOp>(), 2))
    return emitOpError() << "only one condition op is allowed in a variant";
  return success();
}
// Assigns an ordinal to every constant key declared by the variant's constant
// block ops, walking blocks and their keys in order. Ordinals are dense
// (0, 1, 2, ...) and unique per distinct key. If a key appears more than once,
// only its first occurrence takes an ordinal and later occurrences reuse it.
DenseMap<Attribute, int> ExecutableVariantOp::gatherConstantOrdinals() {
  DenseMap<Attribute, int> map;
  for (auto blockOp : getConstantBlockOps()) {
    for (Attribute keyAttr : blockOp.getKeys()) {
      // The next ordinal is the number of distinct keys seen so far. Computing
      // it as `blockBase + indexInBlock` would skip ordinals after a repeated
      // key and could later give the same ordinal to two different keys.
      int nextOrdinal = static_cast<int>(map.size());
      map.try_emplace(keyAttr, nextOrdinal);
    }
  }
  return map;
}
// Builds an i1 value at the builder's insertion point that says whether this
// variant can run on |device|. The value combines the target's match
// expression with the variant's hal.executable.condition region, if any.
Value ExecutableVariantOp::buildCondition(Value device, OpBuilder &builder) {
  auto matchExpr =
      cast<IREE::HAL::MatchAttrInterface>(getTarget().getMatchExpression());
  Value targetMatches =
      matchExpr.buildConditionExpression(getLoc(), device, builder);

  // Without a condition op, the target match alone decides.
  auto conditionOp = getConditionOp();
  if (!conditionOp)
    return targetMatches;

  // Clone the condition region into an scf.execute_region that yields an i1,
  // binding the region's first argument to |device|.
  auto executeOp = builder.create<scf::ExecuteRegionOp>(conditionOp.getLoc(),
                                                        builder.getI1Type());
  IRMapping deviceMapping;
  deviceMapping.map(conditionOp.getRegion().getArgument(0), device);
  conditionOp.getRegion().cloneInto(&executeOp.getRegion(), deviceMapping);

  // The cloned blocks still end in hal.return; replace each with scf.yield.
  for (auto halReturn :
       llvm::make_early_inc_range(executeOp.getOps<IREE::HAL::ReturnOp>())) {
    OpBuilder yieldBuilder(halReturn);
    yieldBuilder.create<scf::YieldOp>(halReturn.getLoc(),
                                      halReturn.getOperands());
    halReturn.erase();
  }

  return builder.create<arith::AndIOp>(getLoc(), targetMatches,
                                       executeOp.getResult(0));
}
//===----------------------------------------------------------------------===//
// hal.executable.condition
//===----------------------------------------------------------------------===//
// Checks the condition body with the shared target-condition region verifier.
LogicalResult ExecutableConditionOp::verify() {
  return verifyTargetConditionRegion(*this, getBody());
}
// Builds an empty hal.executable.condition. The function_type is set to the
// fixed target-condition signature, and any caller-supplied |attrs| are
// appended after it.
void ExecutableConditionOp::build(OpBuilder &builder, OperationState &result,
                                  ArrayRef<NamedAttribute> attrs) {
  auto conditionType = getTargetConditionRegionType(builder.getContext());
  result.addAttribute("function_type", TypeAttr::get(conditionType));
  result.addRegion();
  result.attributes.append(attrs.begin(), attrs.end());
}
// Parses the condition region, records the fixed function_type, and then
// parses an optional `attributes {...}` dictionary.
ParseResult ExecutableConditionOp::parse(OpAsmParser &parser,
                                         OperationState &result) {
  Region *conditionRegion = result.addRegion();
  if (parseTargetConditionRegion(parser, *conditionRegion))
    return failure();
  auto conditionType = getTargetConditionRegionType(parser.getContext());
  result.addAttribute("function_type", TypeAttr::get(conditionType));
  return parser.parseOptionalAttrDictWithKeyword(result.attributes);
}
// Prints the condition region followed by any attributes except the implied
// function_type.
void ExecutableConditionOp::print(OpAsmPrinter &p) {
  Operation *self = getOperation();
  printTargetConditionRegion(p, self, getBody());
  ArrayRef<StringRef> elided = {"function_type"};
  p.printOptionalAttrDictWithKeyword(self->getAttrs(), elided);
}
// Creates the entry block with one argument per function argument type (all
// located at the op), appends it to the body, and returns it. The op must not
// already have a body.
Block *ExecutableConditionOp::addEntryBlock() {
  assert(empty() && "function already has an entry block");
  Block *entryBlock = new Block();
  auto inputTypes = getArgumentTypes();
  entryBlock->addArguments(
      inputTypes, SmallVector<Location>(inputTypes.size(), getLoc()));
  push_back(entryBlock); // The region takes ownership of the block.
  return entryBlock;
}
// Appends a new argument-less block after the existing entry block and
// returns it.
Block *ExecutableConditionOp::addBlock() {
  assert(!empty() && "function should at least have an entry block");
  Block *newBlock = new Block();
  push_back(newBlock); // The region takes ownership of the block.
  return newBlock;
}
//===----------------------------------------------------------------------===//
// hal.executable.constant.block
//===----------------------------------------------------------------------===//
// Parses a hal.executable.constant.block op of the form:
//   `(` args `)` `->` results `as` keys (`attributes` attr-dict)? region
// where `keys` is a single string when there is exactly one result, and a list
// of strings parsed with Delimiter::OptionalParen otherwise. The key count is
// not checked against the result count here; the op verifier does that.
ParseResult ExecutableConstantBlockOp::parse(OpAsmParser &parser,
                                             OperationState &result) {
  auto &builder = parser.getBuilder();
  // Parse the function signature. Variadic signatures are rejected
  // (allowVariadic=false); |isVariadic| is only the helper's out-parameter.
  SmallVector<OpAsmParser::Argument> entryArgs;
  bool isVariadic = false;
  SmallVector<DictionaryAttr> resultAttrs;
  SmallVector<Type> resultTypes;
  if (mlir::function_interface_impl::parseFunctionSignature(
          parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes,
          resultAttrs)) {
    return failure();
  }
  // Record the parsed signature as the op's function_type attribute.
  SmallVector<Type> argTypes;
  for (auto &arg : entryArgs)
    argTypes.push_back(arg.type);
  auto fnType = builder.getFunctionType(argTypes, resultTypes);
  result.addAttribute(getFunctionTypeAttrName(result.name),
                      TypeAttr::get(fnType));
  // Parse the keys used for each yielded constant value, introduced by `as`.
  // There should be one key per result. A single result takes a bare string
  // (no parens); any other result count goes through the optional-paren list.
  SmallVector<Attribute> keyAttrs;
  if (failed(parser.parseKeyword("as")))
    return failure();
  if (resultTypes.size() == 1) {
    std::string key;
    if (failed(parser.parseString(&key)))
      return failure();
    keyAttrs.push_back(builder.getStringAttr(key));
  } else {
    if (failed(parser.parseCommaSeparatedList(
            AsmParser::Delimiter::OptionalParen,
            [&]() {
              std::string key;
              if (failed(parser.parseString(&key)))
                return failure();
              keyAttrs.push_back(builder.getStringAttr(key));
              return success();
            },
            "containing a 1:1 list of keys per yielded value"))) {
      return failure();
    }
  }
  result.addAttribute("keys", builder.getArrayAttr(keyAttrs));
  // If an `attributes {...}` dictionary is present, parse and attach it.
  NamedAttrList parsedAttributes;
  if (parser.parseOptionalAttrDictWithKeyword(parsedAttributes)) {
    return failure();
  }
  result.attributes.append(parsedAttributes);
  // Attach the per-argument and per-result attribute dictionaries parsed with
  // the signature (stored under the arg_attrs/res_attrs attribute names).
  assert(resultAttrs.size() == resultTypes.size());
  mlir::function_interface_impl::addArgAndResultAttrs(
      builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name),
      getResAttrsAttrName(result.name));
  // Parse the body region; its entry block arguments are the signature's
  // arguments, and name shadowing is disallowed. A body is required, so an
  // empty region is rejected below.
  auto *body = result.addRegion();
  SMLoc loc = parser.getCurrentLocation();
  if (failed(parser.parseRegion(*body, entryArgs,
                                /*enableNameShadowing=*/false))) {
    return failure();
  }
  // The region was parsed; make sure it actually contains a body.
  if (body->empty()) {
    return parser.emitError(loc, "expected non-empty function body");
  }
  return success();
}
// Prints the op in the form accepted by the parser:
//   (args) -> results as keys [attributes {...}] { body }
// A single key is printed bare; any other number of keys is parenthesized.
void ExecutableConstantBlockOp::print(OpAsmPrinter &p) {
  Operation *self = getOperation();
  ArrayRef<Type> resultTypes = getResultTypes();
  mlir::function_interface_impl::printFunctionSignature(
      p, cast<FunctionOpInterface>(self), getArgumentTypes(),
      /*isVariadic=*/false, resultTypes);

  const bool parenthesizeKeys = resultTypes.size() != 1;
  p << " as ";
  if (parenthesizeKeys)
    p << '(';
  llvm::interleaveComma(getKeys().getValue(), p,
                        [&](Attribute keyAttr) { p << keyAttr; });
  if (parenthesizeKeys)
    p << ')';

  // function_type and keys are already expressed by the syntax above.
  mlir::function_interface_impl::printFunctionAttributes(
      p, self, {getFunctionTypeAttrName(), getKeysAttrName()});
  p << " ";
  p.printRegion(getBody(), /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/true);
}
// Verifies the constant block signature and keys:
//   * it takes either no arguments or exactly one !hal.device;
//   * every result is i32 (today);
//   * there is exactly one key per result.
LogicalResult ExecutableConstantBlockOp::verify() {
  auto inputTypes = getArgumentTypes();
  bool validInputs =
      inputTypes.empty() ||
      (inputTypes.size() == 1 &&
       llvm::isa<IREE::HAL::DeviceType>(inputTypes[0]));
  if (!validInputs) {
    return emitOpError() << "initializer must take a !hal.device or nothing";
  }

  for (auto [resultIndex, resultType] : llvm::enumerate(getResultTypes())) {
    if (resultType.isInteger(32))
      continue;
    return emitOpError() << "initializer must return only i32 values (result "
                         << resultIndex << " is " << resultType << ")";
  }

  if (getNumResults() != getKeys().size())
    return emitOpError() << "must have one key for every result";
  return success();
}
//===----------------------------------------------------------------------===//
// hal.executable.binary
//===----------------------------------------------------------------------===//
// Builds a hal.executable.binary from raw bytes. |data| is stored as a dense
// vector<Nxi8> attribute, where N is the byte count.
void ExecutableBinaryOp::build(OpBuilder &builder, OperationState &state,
                               StringRef symName, StringRef format,
                               std::vector<uint8_t> data) {
  auto byteType = builder.getIntegerType(8);
  auto dataType =
      VectorType::get({static_cast<int64_t>(data.size())}, byteType);
  state.addAttribute(mlir::SymbolTable::getSymbolAttrName(),
                     builder.getStringAttr(symName));
  state.addAttribute("format", builder.getStringAttr(format));
  state.addAttribute("data", DenseIntElementsAttr::get(dataType, data));
}
// Builds a hal.executable.binary from an existing format string attribute and
// a pre-built dense data attribute.
void ExecutableBinaryOp::build(OpBuilder &builder, OperationState &state,
                               StringRef symName, StringAttr format,
                               DenseIntElementsAttr data) {
  StringAttr symNameAttr = builder.getStringAttr(symName);
  state.addAttribute(mlir::SymbolTable::getSymbolAttrName(), symNameAttr);
  state.addAttribute("format", format);
  state.addAttribute("data", data);
}
//===----------------------------------------------------------------------===//
// hal.executable.create
//===----------------------------------------------------------------------===//
// Suggests %exe as the printed name of the created executable.
void ExecutableCreateOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "exe");
}
//===----------------------------------------------------------------------===//
// hal.executable.lookup
//===----------------------------------------------------------------------===//
// Suggests %exe as the printed name of the looked-up executable.
void ExecutableLookupOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  StringRef suggestedName = "exe";
  setNameFn(getResult(), suggestedName);
}
//===----------------------------------------------------------------------===//
// hal.interface.binding.subspan
//===----------------------------------------------------------------------===//
// Convenience builder that takes the descriptor flags as a plain optional enum.
// A null DescriptorFlagsAttr is forwarded when |flags| is not set.
void InterfaceBindingSubspanOp::build(
    OpBuilder &builder, OperationState &result, Type resultType, APInt set,
    APInt binding, IREE::HAL::DescriptorType descriptor_type, Value byte_offset,
    ValueRange dynamic_dims, IntegerAttr alignment,
    std::optional<DescriptorFlags> flags) {
  IREE::HAL::DescriptorFlagsAttr flagsAttr =
      flags ? IREE::HAL::DescriptorFlagsAttr::get(builder.getContext(), *flags)
            : IREE::HAL::DescriptorFlagsAttr();
  build(builder, result, resultType, set, binding, descriptor_type, byte_offset,
        dynamic_dims, alignment, flagsAttr);
}
// For shaped results, requires exactly one dynamic-dimension SSA value per
// dynamic dimension in the result type. Non-shaped results are accepted as-is.
LogicalResult InterfaceBindingSubspanOp::verify() {
  auto shapedType = llvm::dyn_cast<ShapedType>(getType());
  if (!shapedType)
    return success();

  auto expectedDynamicDims = shapedType.getNumDynamicDims();
  auto providedDynamicDims = getDynamicDims().size();
  if (expectedDynamicDims == providedDynamicDims)
    return success();
  return emitOpError("result type ")
         << getType() << " has " << expectedDynamicDims
         << " dynamic dimensions but " << providedDynamicDims
         << " associated dimension SSA values";
}
// Returns the binding's declared base alignment, or std::nullopt when none is
// attached.
llvm::MaybeAlign InterfaceBindingSubspanOp::getBaseAlignment() {
  auto declaredAlignment = getAlignment();
  if (!declaredAlignment)
    return std::nullopt;
  return llvm::MaybeAlign(declaredAlignment->getZExtValue());
}
// Computes the best known alignment of the subspan's base pointer.
//  * No base alignment on the binding: the natural element alignment is used
//    (e.g. memref<?xi32> is at least 4-byte aligned; non-shaped types use 1).
//  * Base alignment with no byte offset: the base alignment is used.
//  * Base alignment with an offset whose value or alignment can be looked up:
//    the alignment common to both.
//  * Otherwise: the natural element alignment.
llvm::Align InterfaceBindingSubspanOp::calculateAlignment() {
  llvm::Align naturalAlignment(1);
  if (auto shapedType = llvm::dyn_cast<ShapedType>(getType())) {
    naturalAlignment = llvm::Align(
        IREE::Util::getRoundedElementByteWidth(shapedType.getElementType()));
  }

  llvm::MaybeAlign baseAlignment = getBaseAlignment();
  if (!baseAlignment)
    return naturalAlignment;

  Value byteOffset = getByteOffset();
  if (!byteOffset)
    return *baseAlignment;

  auto offsetOrAlignment = lookupOffsetOrAlignment(byteOffset);
  if (!offsetOrAlignment.has_value())
    return naturalAlignment;
  return llvm::commonAlignment(*baseAlignment, offsetOrAlignment.value());
}
//===----------------------------------------------------------------------===//
// hal.interface.workgroup.*
//===----------------------------------------------------------------------===//
// Names |result| as <prefix>x, <prefix>y or <prefix>z for dimensions 0, 1 and
// 2. Any other dimension leaves the default name in place.
static void getAsmResultNamesForInterfaceWorkgroupOp(
    StringRef prefix, const APInt &dimension, Value result,
    function_ref<void(Value, StringRef)> setNameFn) {
  static constexpr char kAxisSuffixes[3] = {'x', 'y', 'z'};
  uint64_t axis = dimension.getZExtValue();
  if (axis >= 3)
    return;
  std::string name = prefix.str();
  name.push_back(kAxisSuffixes[axis]);
  setNameFn(result, name);
}
// Suggests %workgroup_id_{x,y,z} based on the queried dimension.
void InterfaceWorkgroupIDOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  getAsmResultNamesForInterfaceWorkgroupOp(/*prefix=*/"workgroup_id_",
                                           /*dimension=*/getDimension(),
                                           /*result=*/getResult(), setNameFn);
}
// Suggests %workgroup_count_{x,y,z} based on the queried dimension.
void InterfaceWorkgroupCountOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  getAsmResultNamesForInterfaceWorkgroupOp(/*prefix=*/"workgroup_count_",
                                           /*dimension=*/getDimension(),
                                           /*result=*/getResult(), setNameFn);
}
// Suggests %workgroup_size_{x,y,z} based on the queried dimension.
void InterfaceWorkgroupSizeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  getAsmResultNamesForInterfaceWorkgroupOp(/*prefix=*/"workgroup_size_",
                                           /*dimension=*/getDimension(),
                                           /*result=*/getResult(), setNameFn);
}
//===----------------------------------------------------------------------===//
// hal.pipeline_layout.create
//===----------------------------------------------------------------------===//
// Suggests %pipeline_layout as the printed name of the created layout.
void PipelineLayoutCreateOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  StringRef suggestedName = "pipeline_layout";
  setNameFn(getResult(), suggestedName);
}
//===----------------------------------------------------------------------===//
// hal.pipeline_layout.lookup
//===----------------------------------------------------------------------===//
// Suggests %pipeline_layout as the printed name of the looked-up layout.
void PipelineLayoutLookupOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  StringRef suggestedName = "pipeline_layout";
  setNameFn(getResult(), suggestedName);
}
//===----------------------------------------------------------------------===//
// hal.fence.*
//===----------------------------------------------------------------------===//
// Suggests %fence as the printed name of the created fence.
void FenceCreateOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  StringRef suggestedName = "fence";
  setNameFn(getResult(), suggestedName);
}
// Suggests %fence as the printed name of the joined fence.
void FenceJoinOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  StringRef suggestedName = "fence";
  setNameFn(getResult(), suggestedName);
}
// Suggests %status as the printed name of the await status result.
void FenceAwaitOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  StringRef suggestedName = "status";
  setNameFn(getStatus(), suggestedName);
}
} // namespace HAL
} // namespace IREE
} // namespace iree_compiler
} // namespace mlir
//===----------------------------------------------------------------------===//
// TableGen definitions (intentionally last)
//===----------------------------------------------------------------------===//
#define GET_OP_CLASSES
#include "iree/compiler/Dialect/HAL/IR/HALOps.cpp.inc"