blob: 937233e7af0867f24b9ee1431331979d2b54efec [file] [log] [blame]
// Copyright 2019 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
#include "iree/compiler/Dialect/IREE/IR/IREETypes.h"
#include "iree/compiler/Dialect/Shape/IR/Builders.h"
#include "llvm/ADT/Hashing.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/SMLoc.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/TypeUtilities.h"
namespace mlir {
namespace iree_compiler {
namespace IREE {
namespace HAL {
template <typename T>
static LogicalResult parseEnumAttr(OpAsmParser &parser, StringRef attrName,
NamedAttrList &attrs) {
Attribute genericAttr;
NamedAttrList attrList;
auto loc = parser.getCurrentLocation();
if (failed(parser.parseAttribute(genericAttr,
parser.getBuilder().getNoneType(), attrName,
attrList))) {
return parser.emitError(loc) << "failed to parse enum string value";
}
auto stringAttr = genericAttr.dyn_cast<StringAttr>();
if (!stringAttr) {
return parser.emitError(loc)
<< "expected " << attrName << " attribute specified as string";
}
auto symbolized = symbolizeEnum<T>(stringAttr.getValue());
if (!symbolized.hasValue()) {
return parser.emitError(loc) << "failed to parse enum value";
}
attrs.push_back(parser.getBuilder().getNamedAttr(
attrName, parser.getBuilder().getI32IntegerAttr(
static_cast<int32_t>(symbolized.getValue()))));
return success();
}
//===----------------------------------------------------------------------===//
// custom<SizeAwareType>
//===----------------------------------------------------------------------===//
// type{%size}
static ParseResult parseSizeAwareType(OpAsmParser &parser, Type &type,
OpAsmParser::OperandType &size) {
if (failed(parser.parseType(type)) || failed(parser.parseLBrace()) ||
failed(parser.parseOperand(size)) || failed(parser.parseRBrace())) {
return failure();
}
return success();
}
static void printSizeAwareType(OpAsmPrinter &p, Operation *op, Type type,
Value size) {
p.printType(type);
p << "{";
p.printOperand(size);
p << "}";
}
//===----------------------------------------------------------------------===//
// custom<SizeAwareTypeList>
//===----------------------------------------------------------------------===//
// (type{%size0}, type, type{%size1})
static ParseResult parseSizeAwareTypeList(
OpAsmParser &parser, SmallVectorImpl<Type> &types,
SmallVectorImpl<OpAsmParser::OperandType> &sizes) {
do {
Type type;
if (failed(parser.parseType(type))) return failure();
if (type.isa<SizeAwareTypeInterface>()) {
OpAsmParser::OperandType size;
if (failed(parser.parseLBrace()) || failed(parser.parseOperand(size)) ||
failed(parser.parseRBrace())) {
return failure();
}
sizes.push_back(size);
}
types.push_back(type);
} while (succeeded(parser.parseOptionalComma()));
return success();
}
static void printSizeAwareTypeList(OpAsmPrinter &p, Operation *op,
TypeRange types, OperandRange sizes) {
int sizeIndex = 0;
llvm::interleaveComma(types, p, [&](Type type) {
p.printType(type);
if (type.isa<SizeAwareTypeInterface>()) {
p << "{";
p.printOperand(sizes[sizeIndex++]);
p << "}";
}
});
}
//===----------------------------------------------------------------------===//
// custom<DescriptorSetBindings>($binding_ordinals,
// $binding_buffers,
// type($binding_buffers),
// $binding_offsets,
// $binding_lengths)
//===----------------------------------------------------------------------===//
static ParseResult parseDescriptorSetBindings(
OpAsmParser &parser, SmallVectorImpl<OpAsmParser::OperandType> &ordinals,
SmallVectorImpl<OpAsmParser::OperandType> &buffers,
SmallVectorImpl<Type> &bufferTypes,
SmallVectorImpl<OpAsmParser::OperandType> &bufferOffsets,
SmallVectorImpl<OpAsmParser::OperandType> &bufferLengths) {
do {
OpAsmParser::OperandType ordinal;
OpAsmParser::OperandType buffer;
Type bufferType;
OpAsmParser::OperandType bufferOffset;
OpAsmParser::OperandType bufferLength;
if (failed(parser.parseOperand(ordinal)) || failed(parser.parseEqual()) ||
failed(parser.parseLParen()) || failed(parser.parseOperand(buffer)) ||
failed(parser.parseColonType(bufferType)) ||
failed(parser.parseRParen()) || failed(parser.parseLSquare()) ||
failed(parser.parseOperand(bufferOffset)) ||
failed(parser.parseComma()) ||
failed(parser.parseOperand(bufferLength)) ||
failed(parser.parseRSquare())) {
return failure();
}
ordinals.push_back(ordinal);
buffers.push_back(buffer);
bufferTypes.push_back(bufferType);
bufferOffsets.push_back(bufferOffset);
bufferLengths.push_back(bufferLength);
} while (succeeded(parser.parseOptionalComma()));
return success();
}
static void printDescriptorSetBindings(OpAsmPrinter &p, Operation *op,
ValueRange ordinals, ValueRange buffers,
TypeRange bufferTypes,
ValueRange bufferOffsets,
ValueRange bufferLengths) {
llvm::interleaveComma(
llvm::zip(ordinals, buffers, bufferTypes, bufferOffsets, bufferLengths),
p, [&](std::tuple<Value, Value, Type, Value, Value> it) {
p.printNewline();
p << " ";
p.printOperand(std::get<0>(it));
p << " = (";
p.printOperand(std::get<1>(it));
p << " : ";
p.printType(std::get<2>(it));
p << ")[";
p.printOperand(std::get<3>(it));
p << ", ";
p.printOperand(std::get<4>(it));
p << "]";
});
p.printNewline();
}
//===----------------------------------------------------------------------===//
// custom<PackSliceRanges>($lifetime_intervals,
// $dynamic_slice_sizes,
// type($packed_offsets))
//===----------------------------------------------------------------------===//
static ParseResult parsePackSliceRanges(
OpAsmParser &parser, ArrayAttr &lifetimeIntervals,
SmallVectorImpl<OpAsmParser::OperandType> &dynamicSliceSizes,
SmallVectorImpl<Type> &packedOffsetTypes) {
auto indexType = parser.getBuilder().getIndexType();
SmallVector<Attribute> lifetimeRangeValues;
do {
if (failed(parser.parseOptionalLSquare())) break;
IntegerAttr lifetimeStart;
IntegerAttr lifetimeEnd;
OpAsmParser::OperandType dynamicSliceSize;
if (failed(parser.parseAttribute(lifetimeStart, indexType)) ||
failed(parser.parseComma()) ||
failed(parser.parseAttribute(lifetimeEnd, indexType)) ||
failed(parser.parseRSquare()) || failed(parser.parseEqual()) ||
failed(parser.parseOperand(dynamicSliceSize))) {
return failure();
}
lifetimeRangeValues.push_back(lifetimeStart);
lifetimeRangeValues.push_back(lifetimeEnd);
dynamicSliceSizes.push_back(dynamicSliceSize);
packedOffsetTypes.push_back(indexType);
} while (succeeded(parser.parseOptionalComma()));
lifetimeIntervals = parser.getBuilder().getArrayAttr(lifetimeRangeValues);
return success();
}
static void printPackSliceRanges(OpAsmPrinter &p, Operation *op,
ArrayAttr lifetimeIntervals,
ValueRange dynamicSliceSizes,
TypeRange packedOffsetTypes) {
if (packedOffsetTypes.empty()) return;
for (unsigned i = 0; i < packedOffsetTypes.size(); ++i) {
auto lifetimeStart = lifetimeIntervals[i * 2];
auto lifetimeEnd = lifetimeIntervals[i * 2 + 1];
auto sliceSize = dynamicSliceSizes[i];
p.printNewline();
p << " [";
p.printAttributeWithoutType(lifetimeStart);
p << ", ";
p.printAttributeWithoutType(lifetimeEnd);
p << "] = ";
p.printOperand(sliceSize);
if (i < packedOffsetTypes.size() - 1) p << ",";
}
p.printNewline();
}
//===----------------------------------------------------------------------===//
// hal.ex.shared_device
//===----------------------------------------------------------------------===//
void ExSharedDeviceOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
setNameFn(result(), "device");
}
//===----------------------------------------------------------------------===//
// hal.tensor.cast
//===----------------------------------------------------------------------===//
void TensorCastOp::build(OpBuilder &builder, OperationState &result,
Type resultType, Value source,
ArrayRef<NamedAttribute> attrs) {
SmallVector<Value> dynamicDims;
if (source.getType().isa<IREE::HAL::BufferViewType>()) {
auto shapedType = resultType.cast<ShapedType>();
for (int64_t i = 0; i < shapedType.getRank(); ++i) {
if (!shapedType.isDynamicDim(i)) continue;
dynamicDims.push_back(builder.createOrFold<IREE::HAL::BufferViewDimOp>(
result.location, builder.getIndexType(), source,
builder.getIndexAttr(i)));
}
} else {
dynamicDims =
Shape::buildOrFindDynamicDimsForValue(result.location, source, builder);
}
build(builder, result, resultType, source, dynamicDims, attrs);
}
void TensorCastOp::build(OpBuilder &builder, OperationState &result,
Type resultType, Value source, ValueRange dynamicDims,
ArrayRef<NamedAttribute> attrs) {
result.addTypes({resultType});
result.addOperands({source});
result.addOperands({dynamicDims});
result.addAttributes(attrs);
result.addAttribute(
"operand_segment_sizes",
builder.getI32VectorAttr({
static_cast<int32_t>(1),
static_cast<int32_t>(
source.getType().isa<TensorType>() ? dynamicDims.size() : 0),
static_cast<int32_t>(resultType.isa<TensorType>() ? dynamicDims.size()
: 0),
}));
}
Value TensorCastOp::buildOperandRankedShape(unsigned idx, OpBuilder &builder) {
if (source().getType().isa<TensorType>()) {
return Shape::buildRankedShapeForValue(getLoc(), source(), source_dims(),
builder);
} else {
return buildResultRankedShape(idx, builder);
}
}
Value TensorCastOp::buildResultRankedShape(unsigned idx, OpBuilder &builder) {
if (target().getType().isa<TensorType>()) {
return Shape::buildRankedShapeForValue(getLoc(), target(), target_dims(),
builder);
} else {
return buildOperandRankedShape(idx, builder);
}
}
Value TensorCastOp::getTiedResult(unsigned resultIndex) {
return IREE::TiedOpInterface::findTiedBaseValue(source());
}
::llvm::Optional<unsigned> TensorCastOp::getTiedResultOperandIndex(
unsigned resultIndex) {
return {0}; // source
}
SmallVector<int64_t, 4> TensorCastOp::getTiedResultOperandIndices() {
return {0}; // source
}
//===----------------------------------------------------------------------===//
// hal.variable
//===----------------------------------------------------------------------===//
// Returns true if the given |accessType| is compatible with the |variableType|.
// For example, this will return true if the variable type is a tensor<?xf32>
// and the access is tensor<4xf32>.
static bool isVariableTypeCompatible(Type variableType, Type accessType) {
// If one is a shaped type, then they both must be and have compatible
// shapes.
if (variableType.isa<ShapedType>() || accessType.isa<ShapedType>()) {
return succeeded(mlir::verifyCompatibleShape(variableType, accessType));
}
// Otherwise, the types must be the same.
return variableType == accessType;
}
static ParseResult parseVariableOp(OpAsmParser &parser,
OperationState *result) {
StringAttr nameAttr;
if (failed(parser.parseSymbolName(nameAttr,
mlir::SymbolTable::getSymbolAttrName(),
result->attributes))) {
return failure();
}
if (succeeded(parser.parseOptionalKeyword("mutable"))) {
result->addAttribute("is_mutable", UnitAttr::get(result->getContext()));
}
if (succeeded(parser.parseOptionalKeyword("init"))) {
FlatSymbolRefAttr initializerAttr;
if (failed(parser.parseLParen()) ||
failed(parser.parseAttribute(initializerAttr, "initializer",
result->attributes)) ||
failed(parser.parseRParen())) {
return failure();
}
}
if (failed(parser.parseOptionalAttrDictWithKeyword(result->attributes))) {
return failure();
}
Type type;
if (succeeded(parser.parseOptionalEqual())) {
// @foo = 4 : i32
Attribute initialValueAttr;
if (failed(parser.parseAttribute(initialValueAttr, "initial_value",
result->attributes))) {
return failure();
}
type = initialValueAttr.getType();
} else {
// @foo : index = 4 : i32
if (failed(parser.parseColonType(type)) ||
failed(parser.parseOptionalAttrDictWithKeyword(result->attributes))) {
return failure();
}
if (succeeded(parser.parseOptionalEqual())) {
Attribute initialValueAttr;
if (failed(parser.parseAttribute(initialValueAttr, "initial_value",
result->attributes))) {
return failure();
}
}
}
result->addAttribute("type", TypeAttr::get(type));
return success();
}
static void printVariableOp(OpAsmPrinter &p, VariableOp op) {
p << op.getOperationName() << ' ';
p.printSymbolName(op.sym_name());
if (op.is_mutable()) {
p << " mutable";
}
if (op.initializer().hasValue()) {
p << " init(";
p.printSymbolName(op.initializer().getValue());
p << ')';
}
if (op.initial_value().hasValue() &&
op.type() == op.initial_value().getValue().getType()) {
// @foo = 4 : i32
} else {
// @foo : index = 4 : i32
p << " : ";
p.printType(op.type());
}
p.printOptionalAttrDictWithKeyword(op->getAttrs(), /*elidedAttrs=*/{
"sym_name",
"type",
"is_mutable",
"initializer",
"initial_value",
});
if (op.initial_value().hasValue()) {
p << " = ";
p.printAttribute(op.initial_value().getValue());
}
}
static LogicalResult verifyVariableOp(VariableOp op) {
if (op.initializer().hasValue() && op.initial_value().hasValue()) {
return op.emitOpError()
<< "variables can have either an initializer or an initial value";
} else if (op.initializer().hasValue()) {
// Ensure initializer returns the same type as the variable.
auto *symbolOp =
SymbolTable::lookupNearestSymbolFrom(op, op.initializer().getValue());
if (!symbolOp) {
return op.emitOpError() << "initializer function "
<< op.initializer().getValue() << " not found";
}
auto initializerOp = dyn_cast<FuncOp>(symbolOp);
if (initializerOp.getNumArguments() != 0 ||
initializerOp.getNumResults() != 1 ||
initializerOp.getType().getResult(0) != op.type()) {
return op.emitOpError()
<< "initializer type mismatch; variable " << op.sym_name()
<< " is " << op.type() << " but initializer function "
<< initializerOp.getName() << " is " << initializerOp.getType();
}
}
return success();
}
void VariableOp::build(OpBuilder &builder, OperationState &result,
StringRef name, bool isMutable, Type type,
Optional<StringRef> initializer,
Optional<Attribute> initialValue,
ArrayRef<NamedAttribute> attrs) {
result.addAttribute(SymbolTable::getSymbolAttrName(),
builder.getStringAttr(name));
if (isMutable) {
result.addAttribute("is_mutable", builder.getUnitAttr());
}
if (initializer.hasValue()) {
result.addAttribute("initializer",
builder.getSymbolRefAttr(initializer.getValue()));
} else if (initialValue.hasValue()) {
result.addAttribute("initial_value", initialValue.getValue());
}
result.addAttribute("type", TypeAttr::get(type));
result.attributes.append(attrs.begin(), attrs.end());
}
void VariableOp::build(OpBuilder &builder, OperationState &result,
StringRef name, bool isMutable, mlir::FuncOp initializer,
ArrayRef<NamedAttribute> attrs) {
build(builder, result, name, isMutable, initializer.getType().getResult(0),
initializer.getName(), llvm::None, attrs);
}
void VariableOp::build(OpBuilder &builder, OperationState &result,
StringRef name, bool isMutable, Type type,
Attribute initialValue, ArrayRef<NamedAttribute> attrs) {
build(builder, result, name, isMutable, type, llvm::None, initialValue,
attrs);
}
void VariableOp::build(OpBuilder &builder, OperationState &result,
StringRef name, bool isMutable, Type type,
ArrayRef<NamedAttribute> attrs) {
build(builder, result, name, isMutable, type, llvm::None, llvm::None, attrs);
}
//===----------------------------------------------------------------------===//
// hal.variable.load
//===----------------------------------------------------------------------===//
void VariableLoadOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
// HACK: works around the lack of symbol side effects in mlir by only saying
// we have a side-effect if the variable we are loading is mutable.
auto *symbolOp = SymbolTable::lookupNearestSymbolFrom(*this, variable());
assert(symbolOp);
auto variableOp = dyn_cast<VariableOp>(symbolOp);
if (variableOp.is_mutable()) {
effects.emplace_back(MemoryEffects::Read::get());
}
}
static LogicalResult verifyVariableLoadOp(VariableLoadOp &op) {
auto *symbolOp = SymbolTable::lookupNearestSymbolFrom(op, op.variable());
if (!symbolOp) {
return op.emitOpError() << "undefined variable: " << op.variable();
}
auto variableOp = dyn_cast<VariableOp>(symbolOp);
auto loadType = op.result().getType();
if (!isVariableTypeCompatible(variableOp.type(), loadType)) {
return op.emitOpError()
<< "variable type mismatch; variable " << op.variable() << " is "
<< variableOp.type() << " but load is " << loadType;
}
return success();
}
//===----------------------------------------------------------------------===//
// hal.variable.load.indirect
//===----------------------------------------------------------------------===//
static LogicalResult verifyVariableLoadIndirectOp(VariableLoadIndirectOp &op) {
auto variableType =
op.variable().getType().cast<IREE::PtrType>().getTargetType();
auto loadType = op.result().getType();
if (!isVariableTypeCompatible(variableType, loadType)) {
return op.emitOpError() << "variable type mismatch; variable pointer is "
<< variableType << " but load is " << loadType;
}
return success();
}
//===----------------------------------------------------------------------===//
// hal.variable.store
//===----------------------------------------------------------------------===//
static LogicalResult verifyVariableStoreOp(VariableStoreOp &op) {
auto *symbolOp = SymbolTable::lookupNearestSymbolFrom(op, op.variable());
if (!symbolOp) {
return op.emitOpError() << "undefined variable: " << op.variable();
}
auto variableOp = dyn_cast<VariableOp>(symbolOp);
auto storeType = op.value().getType();
if (!isVariableTypeCompatible(variableOp.type(), storeType)) {
return op.emitOpError()
<< "variable type mismatch; variable " << op.variable() << " is "
<< variableOp.type() << " but store is " << storeType;
}
if (!variableOp.is_mutable()) {
return op.emitOpError() << "variable " << op.variable()
<< " is not mutable and cannot be stored to";
}
return success();
}
//===----------------------------------------------------------------------===//
// hal.variable.store.indirect
//===----------------------------------------------------------------------===//
static LogicalResult verifyVariableStoreIndirectOp(
VariableStoreIndirectOp &op) {
auto variableType =
op.variable().getType().cast<IREE::PtrType>().getTargetType();
auto storeType = op.value().getType();
if (!isVariableTypeCompatible(variableType, storeType)) {
return op.emitOpError() << "variable type mismatch; variable pointer is "
<< variableType << " but store is " << storeType;
}
return success();
}
//===----------------------------------------------------------------------===//
// hal.allocator.compute_size
//===----------------------------------------------------------------------===//
void AllocatorComputeSizeOp::build(OpBuilder &builder, OperationState &state,
Value allocator, ValueRange shape,
int32_t elementType) {
build(builder, state, allocator, shape,
builder.createOrFold<ConstantIntOp>(state.location, elementType, 32));
}
void AllocatorComputeSizeOp::build(OpBuilder &builder, OperationState &state,
Value allocator, ValueRange shape,
Value elementType) {
state.addOperands({allocator});
state.addOperands(shape);
state.addOperands(elementType);
state.addTypes({builder.getIndexType()});
}
void AllocatorComputeSizeOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
setNameFn(result(), "sz");
}
//===----------------------------------------------------------------------===//
// hal.allocator.compute_offset
//===----------------------------------------------------------------------===//
void AllocatorComputeOffsetOp::build(OpBuilder &builder, OperationState &state,
Value allocator, ValueRange shape,
int32_t elementType, ValueRange indices) {
build(builder, state, allocator, shape,
builder.createOrFold<ConstantIntOp>(state.location, elementType, 32),
indices);
}
void AllocatorComputeOffsetOp::build(OpBuilder &builder, OperationState &state,
Value allocator, ValueRange shape,
Value elementType, ValueRange indices) {
state.addOperands({allocator});
state.addOperands(shape);
state.addOperands(elementType);
state.addOperands(indices);
state.addTypes({builder.getIndexType()});
}
void AllocatorComputeOffsetOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
setNameFn(offset(), "off");
}
//===----------------------------------------------------------------------===//
// hal.allocator.compute_range
//===----------------------------------------------------------------------===//
void AllocatorComputeRangeOp::build(OpBuilder &builder, OperationState &state,
Value allocator, ValueRange shape,
int32_t elementType, ValueRange indices,
ValueRange lengths) {
build(builder, state, allocator, shape,
builder.createOrFold<ConstantIntOp>(state.location, elementType, 32),
indices, lengths);
}
void AllocatorComputeRangeOp::build(OpBuilder &builder, OperationState &state,
Value allocator, ValueRange shape,
Value elementType, ValueRange indices,
ValueRange lengths) {
state.addOperands({allocator});
state.addOperands(shape);
state.addOperands(elementType);
state.addOperands(indices);
state.addOperands(lengths);
state.addTypes({builder.getIndexType(), builder.getIndexType()});
}
void AllocatorComputeRangeOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
setNameFn(offset(), "off");
setNameFn(length(), "len");
}
//===----------------------------------------------------------------------===//
// hal.allocator.allocate
//===----------------------------------------------------------------------===//
void AllocatorAllocateOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
setNameFn(result(), "buffer");
}
Value AllocatorAllocateOp::getOperandSize(unsigned idx) { return {}; }
Value AllocatorAllocateOp::getResultSize(unsigned idx) { return result_size(); }
//===----------------------------------------------------------------------===//
// hal.allocator.constant
//===----------------------------------------------------------------------===//
void AllocatorConstantOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
setNameFn(result(), "cbuffer");
}
//===----------------------------------------------------------------------===//
// hal.allocator.map
//===----------------------------------------------------------------------===//
void AllocatorMapOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
setNameFn(result(), "mapped");
}
Value AllocatorMapOp::getOperandSize(unsigned idx) { return {}; }
Value AllocatorMapOp::getResultSize(unsigned idx) { return length(); }
//===----------------------------------------------------------------------===//
// hal.allocator.pack
//===----------------------------------------------------------------------===//
void AllocatorPackOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
// TODO(benvanik): figure out if we can get the names to coalesce when there
// are multiple results. Ideally we'd have `%total_length, %offsets:123` but
// unfortunately all get splatted out and create 10k+ char lines that are a
// pain to read.
// setNameFn(total_length(), "total_length");
// for (auto packedOffset : llvm::enumerate(packed_offsets())) {
// setNameFn(packedOffset.value(),
// "offset" + std::to_string(packedOffset.index()));
// }
}
static LogicalResult verifyAllocatorPackOp(AllocatorPackOp op) {
size_t sliceCount = op.packed_offsets().size();
if (op.lifetime_intervals().size() != sliceCount * 2) {
return op.emitOpError() << "requires a [start, end] range for each slice";
}
if (op.dynamic_slice_sizes().size() != sliceCount) {
return op.emitOpError() << "requires a size for each slice";
}
return success();
}
SmallVector<AllocatorPackOp::Slice> AllocatorPackOp::getSlices() {
auto intervalPairs = lifetime_intervals().getValue();
auto sizes = dynamic_slice_sizes();
auto offsets = packed_offsets();
SmallVector<AllocatorPackOp::Slice> slices(offsets.size());
for (size_t i = 0; i < offsets.size(); ++i) {
int64_t start = intervalPairs[i * 2 + 0].cast<IntegerAttr>().getInt();
int64_t end = intervalPairs[i * 2 + 1].cast<IntegerAttr>().getInt();
slices[i] = {start, end, sizes[i], offsets[i]};
}
return slices;
}
//===----------------------------------------------------------------------===//
// hal.buffer.allocator
//===----------------------------------------------------------------------===//
void BufferAllocatorOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
setNameFn(result(), "allocator");
}
//===----------------------------------------------------------------------===//
// hal.buffer.subspan
//===----------------------------------------------------------------------===//
void BufferSubspanOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
setNameFn(result(), "buffer");
}
Value BufferSubspanOp::getOperandSize(unsigned idx) { return length(); }
Value BufferSubspanOp::getResultSize(unsigned idx) { return length(); }
//===----------------------------------------------------------------------===//
// hal.buffer.byte_length
//===----------------------------------------------------------------------===//
void BufferLengthOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
setNameFn(result(), "len");
}
//===----------------------------------------------------------------------===//
// hal.buffer_view.create
//===----------------------------------------------------------------------===//
void BufferViewCreateOp::build(OpBuilder &builder, OperationState &state,
Value buffer, int32_t elementType,
ValueRange shape) {
build(builder, state, buffer,
builder.createOrFold<ConstantIntOp>(state.location, elementType, 32),
shape);
}
void BufferViewCreateOp::build(OpBuilder &builder, OperationState &state,
Value buffer, Value elementType,
ValueRange shape) {
state.addOperands({buffer, elementType});
state.addOperands(shape);
state.addTypes({BufferViewType::get(builder.getContext())});
}
void BufferViewCreateOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
setNameFn(result(), "view");
}
//===----------------------------------------------------------------------===//
// hal.buffer_view.subview
//===----------------------------------------------------------------------===//
void BufferViewSubviewOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
setNameFn(result(), "view");
}
//===----------------------------------------------------------------------===//
// hal.buffer_view.buffer
//===----------------------------------------------------------------------===//
void BufferViewBufferOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
setNameFn(result(), "buffer");
}
//===----------------------------------------------------------------------===//
// hal.buffer_view.byte_length
//===----------------------------------------------------------------------===//
void BufferViewByteLengthOp::build(OpBuilder &builder, OperationState &state,
Value bufferView) {
state.addOperands({bufferView});
state.addTypes({builder.getIndexType()});
}
void BufferViewByteLengthOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
setNameFn(result(), "len");
}
//===----------------------------------------------------------------------===//
// hal.buffer_view.compute_offset
//===----------------------------------------------------------------------===//
void BufferViewComputeOffsetOp::build(OpBuilder &builder, OperationState &state,
Value bufferView, ValueRange indices) {
state.addOperands({bufferView});
state.addOperands(indices);
state.addTypes({builder.getIndexType()});
}
void BufferViewComputeOffsetOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
setNameFn(offset(), "off");
}
//===----------------------------------------------------------------------===//
// hal.buffer_view.compute_range
//===----------------------------------------------------------------------===//
void BufferViewComputeRangeOp::build(OpBuilder &builder, OperationState &state,
Value bufferView, ValueRange indices,
ValueRange lengths) {
state.addOperands({bufferView});
state.addOperands(indices);
state.addOperands(lengths);
state.addTypes({builder.getIndexType(), builder.getIndexType()});
}
void BufferViewComputeRangeOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
setNameFn(offset(), "off");
setNameFn(length(), "len");
}
//===----------------------------------------------------------------------===//
// hal.command_buffer.create
//===----------------------------------------------------------------------===//
void CommandBufferCreateOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
setNameFn(result(), "cmd");
}
//===----------------------------------------------------------------------===//
// hal.command_buffer.execution_barrier
//===----------------------------------------------------------------------===//
static ParseResult parseCommandBufferExecutionBarrierOp(
OpAsmParser &parser, OperationState *result) {
OpAsmParser::OperandType commandBuffer;
if (failed(parser.parseOperand(commandBuffer)) ||
failed(parser.resolveOperand(commandBuffer,
CommandBufferType::get(result->getContext()),
result->operands)) ||
failed(parser.parseComma()) ||
failed(parseEnumAttr<ExecutionStageBitfield>(parser, "source_stage_mask",
result->attributes)) ||
failed(parser.parseComma()) ||
failed(parseEnumAttr<ExecutionStageBitfield>(parser, "target_stage_mask",
result->attributes)) ||
failed(parser.parseComma()) ||
failed(parseEnumAttr<ExecutionBarrierFlagBitfield>(parser, "flags",
result->attributes))) {
return failure();
}
if (failed(parser.parseOptionalAttrDictWithKeyword(result->attributes))) {
return failure();
}
return success();
}
static void printCommandBufferExecutionBarrierOp(
OpAsmPrinter &p, CommandBufferExecutionBarrierOp op) {
p << op.getOperationName() << ' ';
p.printOperand(op.command_buffer());
p << ", \"";
p << stringifyExecutionStageBitfield(op.source_stage_mask());
p << "\", \"";
p << stringifyExecutionStageBitfield(op.target_stage_mask());
p << "\", \"";
p << stringifyExecutionBarrierFlagBitfield(op.flags());
p << "\"";
p.printOptionalAttrDictWithKeyword(op->getAttrs(),
/*elidedAttrs=*/{
"source_stage_mask",
"target_stage_mask",
"flags",
});
}
//===----------------------------------------------------------------------===//
// hal.command_buffer.push_descriptor_set
//===----------------------------------------------------------------------===//
void CommandBufferPushDescriptorSetOp::build(
OpBuilder &builder, OperationState &state, Value commandBuffer,
Value executableLayout, int64_t set,
ArrayRef<DescriptorSetBindingValue> bindings) {
build(builder, state, commandBuffer, executableLayout,
builder.createOrFold<ConstantIndexOp>(state.location, set), bindings);
}
void CommandBufferPushDescriptorSetOp::build(
OpBuilder &builder, OperationState &state, Value commandBuffer,
Value executableLayout, Value set,
ArrayRef<DescriptorSetBindingValue> bindings) {
state.addOperands({commandBuffer, executableLayout, set});
SmallVector<Value, 4> bindingOrdinals;
SmallVector<Value, 4> bindingBuffers;
SmallVector<Value, 4> bindingOffsets;
SmallVector<Value, 4> bindingLengths;
for (auto binding : bindings) {
bindingOrdinals.push_back(std::get<0>(binding));
bindingBuffers.push_back(std::get<1>(binding));
bindingOffsets.push_back(std::get<2>(binding));
bindingLengths.push_back(std::get<3>(binding));
}
state.addOperands(bindingOrdinals);
state.addOperands(bindingBuffers);
state.addOperands(bindingOffsets);
state.addOperands(bindingLengths);
}
//===----------------------------------------------------------------------===//
// hal.constant_pool
//===----------------------------------------------------------------------===//
void ConstantPoolOp::build(OpBuilder &builder, OperationState &state,
StringRef name,
BufferConstraintsAttr bufferConstraints) {
ensureTerminator(*state.addRegion(), builder, state.location);
state.addAttribute(mlir::SymbolTable::getSymbolAttrName(),
builder.getStringAttr(name));
state.addAttribute("buffer_constraints", bufferConstraints);
}
static ParseResult parseConstantPoolOp(OpAsmParser &parser,
OperationState *result) {
StringAttr nameAttr;
if (failed(parser.parseSymbolName(nameAttr,
mlir::SymbolTable::getSymbolAttrName(),
result->attributes)) ||
failed(parser.parseOptionalAttrDictWithKeyword(result->attributes))) {
return failure();
}
// Parse the module body.
auto *body = result->addRegion();
if (failed(parser.parseRegion(*body, llvm::None, llvm::None))) {
return failure();
}
// Ensure that this module has a valid terminator.
ConstantPoolOp::ensureTerminator(*body, parser.getBuilder(),
result->location);
return success();
}
static void printConstantPoolOp(OpAsmPrinter &p, ConstantPoolOp op) {
p << op.getOperationName() << ' ';
p.printSymbolName(op.sym_name());
p.printOptionalAttrDictWithKeyword(
op->getAttrs(),
/*elidedAttrs=*/{mlir::SymbolTable::getSymbolAttrName()});
p.printRegion(op.body(), /*printEntryBlockArgs=*/false,
/*printBlockTerminators=*/false);
}
//===----------------------------------------------------------------------===//
// hal.constant_pool.load
//===----------------------------------------------------------------------===//
void ConstantPoolLoadOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
setNameFn(result(), "const");
}
//===----------------------------------------------------------------------===//
// hal.constant_storage.lookup
//===----------------------------------------------------------------------===//
void ConstantStorageLookupOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
setNameFn(result(), "storage");
}
//===----------------------------------------------------------------------===//
// hal.constant.subspan
//===----------------------------------------------------------------------===//
void ConstantSubspanOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
setNameFn(result(), "const_span");
}
//===----------------------------------------------------------------------===//
// hal.descriptor_set.create
//===----------------------------------------------------------------------===//
void DescriptorSetCreateOp::build(
OpBuilder &builder, OperationState &state, Value device, Value setLayout,
ArrayRef<DescriptorSetBindingValue> bindings) {
state.addOperands({device, setLayout});
SmallVector<Value, 4> bindingOrdinals;
SmallVector<Value, 4> bindingBuffers;
SmallVector<Value, 4> bindingOffsets;
SmallVector<Value, 4> bindingLengths;
for (auto binding : bindings) {
bindingOrdinals.push_back(std::get<0>(binding));
bindingBuffers.push_back(std::get<1>(binding));
bindingOffsets.push_back(std::get<2>(binding));
bindingLengths.push_back(std::get<3>(binding));
}
state.addOperands(bindingOrdinals);
state.addOperands(bindingBuffers);
state.addOperands(bindingOffsets);
state.addOperands(bindingLengths);
}
void DescriptorSetCreateOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
setNameFn(result(), "descriptor_set");
}
//===----------------------------------------------------------------------===//
// hal.descriptor_set_layout.create
//===----------------------------------------------------------------------===//
void DescriptorSetLayoutCreateOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
setNameFn(result(), "descriptor_set_layout");
}
//===----------------------------------------------------------------------===//
// hal.descriptor_set_layout.lookup
//===----------------------------------------------------------------------===//
void DescriptorSetLayoutLookupOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
setNameFn(result(), "descriptor_set_layout");
}
//===----------------------------------------------------------------------===//
// hal.device.allocator
//===----------------------------------------------------------------------===//
void DeviceAllocatorOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
setNameFn(result(), "allocator");
}
//===----------------------------------------------------------------------===//
// hal.device.query
//===----------------------------------------------------------------------===//
static LogicalResult verifyDeviceQueryOp(DeviceQueryOp op) {
if (op.default_value().hasValue()) {
if (op.default_value()->getType() != op.value().getType()) {
return op.emitOpError()
<< "type mismatch between result and default value";
}
}
return success();
}
//===----------------------------------------------------------------------===//
// hal.device.switch
//===----------------------------------------------------------------------===//
void DeviceSwitchOp::build(OpBuilder &builder, OperationState &state,
TypeRange resultTypes, Value device,
ArrayRef<Attribute> conditions,
ArrayRef<NamedAttribute> attributes) {
state.addOperands({device});
state.addAttribute("conditions", builder.getArrayAttr(conditions));
for (size_t i = 0; i < conditions.size(); ++i) {
state.addRegion();
}
state.addTypes(resultTypes);
state.addAttributes(attributes);
}
static ParseResult parseDeviceSwitchOp(OpAsmParser &parser,
OperationState *result) {
OpAsmParser::OperandType device;
Type deviceType;
if (failed(parser.parseLess()) || failed(parser.parseOperand(device)) ||
failed(parser.parseColonType(deviceType)) ||
failed(parser.resolveOperand(device, deviceType, result->operands)) ||
failed(parser.parseGreater()) ||
failed(parser.parseOptionalArrowTypeList(result->types))) {
return failure();
}
// Parses each switch condition attribute and region, like:
// #hal.device.match.id<"vulkan-v1.?-*"> {
// hal.return %c1 : i32
// }, ...
SmallVector<Attribute, 4> conditionAttrs;
do {
Attribute conditionAttr;
NamedAttrList dummyAttrs;
if (failed(parser.parseAttribute(conditionAttr, "condition", dummyAttrs))) {
return failure();
}
conditionAttrs.push_back(conditionAttr);
SmallVector<OpAsmParser::OperandType> regionArgs;
SmallVector<Type> regionArgTypes;
auto *regionBody = result->addRegion();
if (failed(parser.parseRegion(*regionBody, regionArgs, regionArgTypes))) {
return failure();
}
} while (succeeded(parser.parseOptionalComma()));
result->addAttribute("conditions",
ArrayAttr::get(result->getContext(), conditionAttrs));
if (failed(parser.parseOptionalAttrDictWithKeyword(result->attributes))) {
return failure();
}
return success();
}
static void printDeviceSwitchOp(OpAsmPrinter &p, DeviceSwitchOp op) {
p << op.getOperationName() << "<";
p.printOperand(op.device());
p << " : ";
p.printType(op.device().getType());
p << ">";
p.printOptionalArrowTypeList(op.getResultTypes());
p << "\n";
p.getStream().indent(4);
interleave(
llvm::zip(op.conditions(), op.condition_regions()),
[&](std::tuple<Attribute, Region &> it) {
auto &conditionAttr = std::get<0>(it);
auto &conditionRegion = std::get<1>(it);
p.printAttribute(conditionAttr);
p.printRegion(conditionRegion,
/*printEntryBlockArgs=*/false,
/*printBlockTerminators=*/true);
},
[&]() {
p << ",\n";
p.getStream().indent(4);
});
p.printOptionalAttrDictWithKeyword(op->getAttrs(),
/*elidedAttrs=*/{"conditions"});
}
static LogicalResult verifyDeviceSwitchOp(DeviceSwitchOp op) {
if (op.conditions().size() != op.condition_regions().size()) {
return op.emitOpError() << "requires conditions and regions be matched 1:1";
} else if (op.condition_regions().empty()) {
return op.emitOpError() << "requires at least one condition";
}
for (auto &region : op.condition_regions()) {
for (auto &block : region) {
if (auto returnOp =
dyn_cast_or_null<IREE::HAL::ReturnOp>(block.getTerminator())) {
if (!std::equal(returnOp.getOperandTypes().begin(),
returnOp.getOperandTypes().end(),
op.getResultTypes().begin())) {
return op.emitOpError()
<< "requires all regions return the same types";
}
}
}
}
return success();
}
//===----------------------------------------------------------------------===//
// hal.executable
//===----------------------------------------------------------------------===//
void ExecutableOp::build(OpBuilder &builder, OperationState &state,
StringRef name) {
ensureTerminator(*state.addRegion(), builder, state.location);
state.addAttribute(mlir::SymbolTable::getSymbolAttrName(),
builder.getStringAttr(name));
}
static ParseResult parseExecutableOp(OpAsmParser &parser,
OperationState *result) {
StringAttr nameAttr;
if (failed(parser.parseSymbolName(nameAttr,
mlir::SymbolTable::getSymbolAttrName(),
result->attributes)) ||
failed(parser.parseOptionalAttrDictWithKeyword(result->attributes))) {
return failure();
}
// Parse the module body.
auto *body = result->addRegion();
if (failed(parser.parseRegion(*body, llvm::None, llvm::None))) {
return failure();
}
// Ensure that this module has a valid terminator.
ExecutableOp::ensureTerminator(*body, parser.getBuilder(), result->location);
return success();
}
static void printExecutableOp(OpAsmPrinter &p, ExecutableOp op) {
p << op.getOperationName() << ' ';
p.printSymbolName(op.sym_name());
p.printOptionalAttrDictWithKeyword(
op->getAttrs(),
/*elidedAttrs=*/{mlir::SymbolTable::getSymbolAttrName()});
p.printRegion(op.body(), /*printEntryBlockArgs=*/false,
/*printBlockTerminators=*/false);
}
static LogicalResult verifyExecutableOp(ExecutableOp op) {
// TODO(benvanik): check export name conflicts.
return success();
}
//===----------------------------------------------------------------------===//
// hal.executable.entry_point
//===----------------------------------------------------------------------===//
static ParseResult parseExecutableEntryPointOp(OpAsmParser &parser,
OperationState *result) {
StringAttr nameAttr;
if (failed(parser.parseSymbolName(nameAttr,
mlir::SymbolTable::getSymbolAttrName(),
result->attributes)) ||
failed(parser.parseOptionalAttrDictWithKeyword(result->attributes))) {
return failure();
}
// For now assume that the workload is at max 3D. So arguments to the region
// are workload along x, y and z.
std::unique_ptr<Region> region;
SmallVector<OpAsmParser::OperandType, 4> regionOperands;
SmallVector<Type, 4> regionTypes;
OptionalParseResult parseResult =
parser.parseOptionalRegion(region, regionOperands, regionTypes);
if (!parseResult.hasValue()) return success();
if (failed(*parseResult)) return failure();
result->addRegion(std::move(region));
return success();
}
static void printExecutableEntryPointOp(OpAsmPrinter &p,
ExecutableEntryPointOp op) {
p << op.getOperationName() << ' ';
p.printSymbolName(op.sym_name());
p.printOptionalAttrDictWithKeyword(op->getAttrs(),
/*elidedAttrs=*/{"sym_name"});
if (op.workgroup_count_region().empty()) return;
p.printRegion(op.workgroup_count_region().front());
}
static LogicalResult verifyExecutableEntryPointOp(ExecutableEntryPointOp op) {
Region *region = op.getBody();
// When there is no region, nothing to verify.
if (!region) return success();
if (!llvm::hasSingleElement(*region)) {
return op.emitOpError() << "expected a single region";
}
if (region->getNumArguments() != 3) {
return op.emitOpError(
"expected three arguments for workgroup_count_region for workload "
"along "
"x, y, and z");
}
for (BlockArgument &blockArg : region->getArguments()) {
if (!blockArg.getType().isa<IndexType>()) {
return op.emitOpError(
"expected arguments to workgroup_count_region be index type");
}
}
// Check that the last statement in the block is `hal.yield` operation.
// TODO(ravishankarm): The SingleBlockImplicitTerminator<"HAL::ReturnOp">
// should generate this check, but it doesnt.
auto returnOp = dyn_cast<ReturnOp>(region->front().getTerminator());
if (!returnOp || returnOp.operands().size() != 3) {
return op.emitOpError("expected operation to yield 3 values");
}
return success();
}
//===----------------------------------------------------------------------===//
// hal.executable.variant
//===----------------------------------------------------------------------===//
void ExecutableVariantOp::build(OpBuilder &builder, OperationState &state,
StringRef symName,
IREE::HAL::ExecutableTargetAttr target) {
ensureTerminator(*state.addRegion(), builder, state.location);
state.addAttribute(mlir::SymbolTable::getSymbolAttrName(),
builder.getStringAttr(symName));
state.addAttribute("target", target);
}
static ParseResult parseExecutableVariantOp(OpAsmParser &parser,
OperationState *result) {
auto *body = result->addRegion();
StringAttr nameAttr;
IREE::HAL::ExecutableTargetAttr targetAttr;
if (failed(parser.parseSymbolName(nameAttr,
mlir::SymbolTable::getSymbolAttrName(),
result->attributes)) ||
failed(parser.parseComma()) || failed(parser.parseKeyword("target")) ||
failed(parser.parseEqual()) ||
failed(parser.parseAttribute(targetAttr, "target", result->attributes)) ||
failed(parser.parseOptionalAttrDictWithKeyword(result->attributes))) {
return failure();
}
OptionalParseResult parseResult = parser.parseOptionalRegion(*body);
if (parseResult.hasValue() && failed(*parseResult)) {
return failure();
}
// Ensure that this module has a valid terminator.
ExecutableVariantOp::ensureTerminator(*body, parser.getBuilder(),
result->location);
return success();
}
static void printExecutableVariantOp(OpAsmPrinter &p, ExecutableVariantOp op) {
p << op.getOperationName() << ' ';
p.printSymbolName(op.sym_name());
p << ", target = " << op.target();
p.printOptionalAttrDictWithKeyword(
op->getAttrs(),
/*elidedAttrs=*/{mlir::SymbolTable::getSymbolAttrName(), "target"});
if (!op.body().empty()) {
p.printRegion(op.body(), /*printEntryBlockArgs=*/false,
/*printBlockTerminators=*/false);
}
}
//===----------------------------------------------------------------------===//
// hal.executable.binary
//===----------------------------------------------------------------------===//
void ExecutableBinaryOp::build(OpBuilder &builder, OperationState &state,
StringRef symName, StringRef format,
std::vector<uint8_t> data) {
state.addAttribute(mlir::SymbolTable::getSymbolAttrName(),
builder.getStringAttr(symName));
state.addAttribute("format", builder.getStringAttr(format));
state.addAttribute("data",
DenseIntElementsAttr::get(
VectorType::get({static_cast<int64_t>(data.size())},
builder.getIntegerType(8)),
data));
}
void ExecutableBinaryOp::build(OpBuilder &builder, OperationState &state,
StringRef symName, StringAttr format,
DenseIntElementsAttr data) {
state.addAttribute(mlir::SymbolTable::getSymbolAttrName(),
builder.getStringAttr(symName));
state.addAttribute("format", format);
state.addAttribute("data", data);
}
static ParseResult parseExecutableBinaryOp(OpAsmParser &parser,
OperationState *result) {
StringAttr nameAttr;
if (failed(parser.parseSymbolName(nameAttr,
mlir::SymbolTable::getSymbolAttrName(),
result->attributes)) ||
failed(parser.parseOptionalAttrDictWithKeyword(result->attributes))) {
return failure();
}
return success();
}
static void printExecutableBinaryOp(OpAsmPrinter &p, ExecutableBinaryOp op) {
p << op.getOperationName() << ' ';
p.printSymbolName(op.sym_name());
p.printOptionalAttrDictWithKeyword(
op->getAttrs(),
/*elidedAttrs=*/{mlir::SymbolTable::getSymbolAttrName()});
}
//===----------------------------------------------------------------------===//
// hal.executable.create
//===----------------------------------------------------------------------===//
void ExecutableCreateOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
setNameFn(result(), StringRef("exe"));
}
//===----------------------------------------------------------------------===//
// hal.executable.lookup
//===----------------------------------------------------------------------===//
void ExecutableLookupOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
setNameFn(result(), "exe");
}
//===----------------------------------------------------------------------===//
// hal.interface
//===----------------------------------------------------------------------===//
void InterfaceOp::build(OpBuilder &builder, OperationState &state,
StringRef name, IntegerAttr pushConstants) {
ensureTerminator(*state.addRegion(), builder, state.location);
state.addAttribute(mlir::SymbolTable::getSymbolAttrName(),
builder.getStringAttr(name));
if (pushConstants) {
state.addAttribute("push_constants", pushConstants);
}
}
static ParseResult parseInterfaceOp(OpAsmParser &parser,
OperationState *result) {
StringAttr nameAttr;
if (failed(parser.parseSymbolName(nameAttr,
mlir::SymbolTable::getSymbolAttrName(),
result->attributes)) ||
failed(parser.parseOptionalAttrDictWithKeyword(result->attributes))) {
return failure();
}
// Parse the module body.
auto *body = result->addRegion();
if (failed(parser.parseRegion(*body, llvm::None, llvm::None))) {
return failure();
}
// Ensure that this module has a valid terminator.
InterfaceOp::ensureTerminator(*body, parser.getBuilder(), result->location);
return success();
}
static void printInterfaceOp(OpAsmPrinter &p, InterfaceOp op) {
p << op.getOperationName() << ' ';
p.printSymbolName(op.sym_name());
p.printOptionalAttrDictWithKeyword(
op->getAttrs(),
/*elidedAttrs=*/{mlir::SymbolTable::getSymbolAttrName()});
p.printRegion(op.body(), /*printEntryBlockArgs=*/false,
/*printBlockTerminators=*/false);
}
ArrayAttr InterfaceOp::getExecutableSetLayoutsAttr() {
Builder builder(getContext());
SmallVector<SmallVector<Attribute, 4>, 4> setAttrs;
for (auto bindingOp : getBlock().getOps<InterfaceBindingOp>()) {
unsigned set = bindingOp.set().getZExtValue();
unsigned binding = bindingOp.binding().getZExtValue();
if (set >= setAttrs.size()) setAttrs.resize(set + 1);
auto &bindingAttrs = setAttrs[set];
if (binding >= bindingAttrs.size()) bindingAttrs.resize(binding + 1);
bindingAttrs[binding] = DescriptorSetLayoutBindingAttr::get(
bindingOp.bindingAttr(), bindingOp.typeAttr(), bindingOp.accessAttr());
}
return builder.getArrayAttr(llvm::to_vector<4>(
llvm::map_range(setAttrs, [&](ArrayRef<Attribute> bindingsArray) {
return builder.getArrayAttr(bindingsArray).cast<Attribute>();
})));
}
bool InterfaceOp::isEquivalentTo(InterfaceOp other) {
auto bindings = llvm::to_vector<4>(getBlock().getOps<InterfaceBindingOp>());
auto otherBindings =
llvm::to_vector<4>(other.getBlock().getOps<InterfaceBindingOp>());
return push_constantsAttr() == other.push_constantsAttr() &&
bindings.size() == otherBindings.size() &&
llvm::all_of(llvm::zip(bindings, otherBindings), [](auto bindings) {
return OperationEquivalence::isEquivalentTo(
std::get<0>(bindings), std::get<1>(bindings),
OperationEquivalence::exactValueMatch,
OperationEquivalence::exactValueMatch,
OperationEquivalence::Flags::IgnoreLocations);
});
}
llvm::hash_code InterfaceOp::getInterfaceHash() {
auto range = llvm::map_range(getBlock().getOps<InterfaceBindingOp>(),
[](InterfaceBindingOp bindingOp) {
return bindingOp.getDescriptorHash();
});
return llvm::hash_combine(
push_constants(), llvm::hash_combine_range(range.begin(), range.end()));
}
//===----------------------------------------------------------------------===//
// hal.interface.binding
//===----------------------------------------------------------------------===//
static ParseResult parseInterfaceBindingOp(OpAsmParser &parser,
OperationState *result) {
StringAttr nameAttr;
IntegerAttr setAttr;
IntegerAttr bindingAttr;
if (failed(parser.parseSymbolName(nameAttr,
mlir::SymbolTable::getSymbolAttrName(),
result->attributes)) ||
failed(parser.parseComma()) || failed(parser.parseKeyword("set")) ||
failed(parser.parseEqual()) ||
failed(parser.parseAttribute(setAttr, parser.getBuilder().getIndexType(),
"set", result->attributes)) ||
failed(parser.parseComma()) || failed(parser.parseKeyword("binding")) ||
failed(parser.parseEqual()) ||
failed(parser.parseAttribute(bindingAttr,
parser.getBuilder().getIndexType(),
"binding", result->attributes)) ||
failed(parser.parseComma()) || failed(parser.parseKeyword("type")) ||
failed(parser.parseEqual()) ||
failed(
parseEnumAttr<DescriptorType>(parser, "type", result->attributes)) ||
failed(parser.parseComma()) || failed(parser.parseKeyword("access")) ||
failed(parser.parseEqual()) ||
failed(parseEnumAttr<MemoryAccessBitfield>(parser, "access",
result->attributes)) ||
failed(parser.parseOptionalAttrDictWithKeyword(result->attributes))) {
return failure();
}
return success();
}
static void printInterfaceBindingOp(OpAsmPrinter &p, InterfaceBindingOp op) {
p << op.getOperationName() << ' ';
p.printSymbolName(op.sym_name());
p << ", set=" << op.set();
p << ", binding=" << op.binding();
p << ", type=\"" << stringifyDescriptorType(op.type()) << "\"";
p << ", access=\"" << stringifyMemoryAccessBitfield(op.access()) << "\"";
p.printOptionalAttrDictWithKeyword(op->getAttrs(),
/*elidedAttrs=*/{
mlir::SymbolTable::getSymbolAttrName(),
"set",
"binding",
"type",
"access",
});
}
llvm::hash_code InterfaceBindingOp::getDescriptorHash() {
// Use the unwrapped attribute accessors so that we can have determinstic
// hashes. Hashing against the wrapped attributes are hashing against pointer
// values, which change per run.
return llvm::hash_combine(set(), binding(), type(), access());
}
//===----------------------------------------------------------------------===//
// hal.interface.binding.subspan
//===----------------------------------------------------------------------===//
InterfaceBindingOp InterfaceBindingSubspanOp::queryBindingOp() {
return dyn_cast_or_null<InterfaceBindingOp>(
SymbolTable::lookupNearestSymbolFrom(getOperation(), binding()));
}
//===----------------------------------------------------------------------===//
// hal.interface.workgroup.*
//===----------------------------------------------------------------------===//
static void getAsmResultNamesForInterfaceWorkgroupOp(
StringRef prefix, const APInt &dimension, Value result,
function_ref<void(Value, StringRef)> setNameFn) {
switch (dimension.getZExtValue()) {
case 0:
setNameFn(result, (prefix + "x").str());
return;
case 1:
setNameFn(result, (prefix + "y").str());
return;
case 2:
setNameFn(result, (prefix + "z").str());
return;
}
}
void InterfaceWorkgroupIDOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
getAsmResultNamesForInterfaceWorkgroupOp("workgroup_id_", dimension(),
result(), setNameFn);
}
void InterfaceWorkgroupCountOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
getAsmResultNamesForInterfaceWorkgroupOp("workgroup_count_", dimension(),
result(), setNameFn);
}
void InterfaceWorkgroupSizeOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
getAsmResultNamesForInterfaceWorkgroupOp("workgroup_size_", dimension(),
result(), setNameFn);
}
//===----------------------------------------------------------------------===//
// hal.executable_layout.create
//===----------------------------------------------------------------------===//
void ExecutableLayoutCreateOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
setNameFn(result(), "executable_layout");
}
//===----------------------------------------------------------------------===//
// hal.executable_layout.lookup
//===----------------------------------------------------------------------===//
void ExecutableLayoutLookupOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
setNameFn(result(), "executable_layout");
}
//===----------------------------------------------------------------------===//
// hal.semaphore.create
//===----------------------------------------------------------------------===//
void SemaphoreCreateOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
setNameFn(result(), "semaphore");
}
} // namespace HAL
} // namespace IREE
} // namespace iree_compiler
} // namespace mlir
//===----------------------------------------------------------------------===//
// TableGen definitions (intentionally last)
//===----------------------------------------------------------------------===//
#define GET_OP_CLASSES
#include "iree/compiler/Dialect/HAL/IR/HALOps.cpp.inc"