blob: 63a72e9e159bb95f94b76a7d598c7c8b1fd64a3b [file] [log] [blame]
// Copyright 2020 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree/compiler/Dialect/Vulkan/IR/VulkanDialect.h"
#include "iree/compiler/Dialect/Vulkan/IR/VulkanAttributes.h"
#include "iree/compiler/Dialect/Vulkan/IR/VulkanTypes.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/SMLoc.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
namespace mlir {
namespace iree_compiler {
namespace IREE {
namespace Vulkan {
/// Constructs the Vulkan dialect inside `context` and registers the dialect's
/// attributes via registerAttributes().
VulkanDialect::VulkanDialect(MLIRContext *context)
    : Dialect(getDialectNamespace(), context, TypeID::get<VulkanDialect>()) {
  this->registerAttributes();
}
//===----------------------------------------------------------------------===//
// Attribute Parsing
//===----------------------------------------------------------------------===//
namespace {
/// Parses a bracketed, comma-separated list of keywords (`[a, b, c]` or `[]`)
/// and passes each parsed keyword, together with its source location, to
/// `processKeyword`. Returns failure as soon as the parser fails or
/// `processKeyword` reports failure.
ParseResult parseKeywordList(
    DialectAsmParser &parser,
    function_ref<LogicalResult(llvm::SMLoc, StringRef)> processKeyword) {
  if (parser.parseLSquare()) return failure();

  // A `]` directly after `[` denotes the empty list.
  if (succeeded(parser.parseOptionalRSquare())) return success();

  // Consume `keyword (`,` keyword)*`: every consumed comma obliges another
  // keyword to follow.
  for (bool hasMore = true; hasMore;
       hasMore = succeeded(parser.parseOptionalComma())) {
    auto keywordLoc = parser.getCurrentLocation();
    StringRef keyword;
    if (parser.parseKeyword(&keyword)) return failure();
    if (failed(processKeyword(keywordLoc, keyword))) return failure();
  }

  return parser.parseRSquare();
}
/// Parses the body of a TargetEnvAttr, i.e. everything after the kind keyword:
///
///   `<` version `,` `r(` revision `)` `,` `[` extension-list `]` `,`
///       vendor `:` device-type (`:` device-id)? `,` capabilities `>`
///
/// On any error a diagnostic is emitted (by this function or by the failing
/// parser call) and a null Attribute is returned.
Attribute parseTargetAttr(DialectAsmParser &parser) {
  if (parser.parseLess()) return {};
  Builder &builder = parser.getBuilder();

  // Vulkan version keyword, stored as an i32 holding the enum value.
  IntegerAttr versionAttr;
  {
    auto loc = parser.getCurrentLocation();
    StringRef version;
    if (parser.parseKeyword(&version) || parser.parseComma()) return {};
    if (auto versionSymbol = symbolizeVersion(version)) {
      versionAttr =
          builder.getI32IntegerAttr(static_cast<uint32_t>(*versionSymbol));
    } else {
      parser.emitError(loc, "unknown Vulkan version: ") << version;
      return {};
    }
  }

  // Revision, spelled `r(N)`.
  IntegerAttr revisionAttr;
  {
    unsigned revision = 0;
    // TODO(antiagainst): it would be nice to parse rN instead of r(N).
    if (parser.parseKeyword("r") || parser.parseLParen() ||
        parser.parseInteger(revision) || parser.parseRParen() ||
        parser.parseComma())
      return {};
    revisionAttr = builder.getI32IntegerAttr(revision);
  }

  // Extension keywords; each must name a known Vulkan extension.
  ArrayAttr extensionsAttr;
  {
    SmallVector<Attribute, 1> extensions;
    // Location and spelling of an unrecognized extension. These stay empty
    // when the failure is syntactic, in which case the failing parser call is
    // expected to have reported the error already.
    llvm::SMLoc errorLoc;
    StringRef errorKeyword;
    auto processExtension = [&](llvm::SMLoc loc, StringRef extension) {
      if (symbolizeExtension(extension)) {
        extensions.push_back(builder.getStringAttr(extension));
        return success();
      }
      errorLoc = loc;
      errorKeyword = extension;
      return failure();
    };
    if (parseKeywordList(parser, processExtension) || parser.parseComma()) {
      if (!errorKeyword.empty())
        parser.emitError(errorLoc, "unknown Vulkan extension: ")
            << errorKeyword;
      return {};
    }
    extensionsAttr = builder.getArrayAttr(extensions);
  }

  // Parse vendor:device-type[:device-id]
  spirv::Vendor vendorID = spirv::Vendor::Unknown;
  spirv::DeviceType deviceType = spirv::DeviceType::Unknown;
  uint32_t deviceID = spirv::TargetEnvAttr::kUnknownDeviceID;
  {
    auto loc = parser.getCurrentLocation();
    StringRef vendorStr;
    if (parser.parseKeyword(&vendorStr)) return {};
    if (auto vendorSymbol = spirv::symbolizeVendor(vendorStr)) {
      vendorID = *vendorSymbol;
    } else {
      // Bail out like every other error path instead of returning an
      // attribute after a diagnostic has already been emitted.
      parser.emitError(loc, "unknown vendor: ") << vendorStr;
      return {};
    }

    if (parser.parseColon()) return {};
    // Take the location after the colon so the diagnostic points at the
    // device-type keyword itself.
    loc = parser.getCurrentLocation();
    StringRef deviceTypeStr;
    if (parser.parseKeyword(&deviceTypeStr)) return {};
    if (auto deviceTypeSymbol = spirv::symbolizeDeviceType(deviceTypeStr)) {
      deviceType = *deviceTypeSymbol;
    } else {
      parser.emitError(loc, "unknown device type: ") << deviceTypeStr;
      return {};
    }

    // The device ID is optional; it keeps kUnknownDeviceID when omitted.
    if (succeeded(parser.parseOptionalColon())) {
      if (parser.parseInteger(deviceID)) return {};
    }
    if (parser.parseComma()) return {};
  }

  // Capabilities dictionary, which must form a valid CapabilitiesAttr.
  DictionaryAttr capabilities;
  {
    auto loc = parser.getCurrentLocation();
    if (parser.parseAttribute(capabilities)) return {};
    if (!capabilities.isa<CapabilitiesAttr>()) {
      parser.emitError(loc,
                       "capabilities must be a vulkan::CapabilitiesAttr "
                       "dictionary attribute");
      return {};
    }
  }

  if (parser.parseGreater()) return {};
  return TargetEnvAttr::get(versionAttr, revisionAttr, extensionsAttr, vendorID,
                            deviceType, deviceID, capabilities);
}
} // anonymous namespace
/// Parses a Vulkan dialect attribute. The leading kind keyword selects the
/// attribute-specific parser; only TargetEnvAttr is handled here. Emits a
/// diagnostic and returns a null Attribute on error.
Attribute VulkanDialect::parseAttribute(DialectAsmParser &parser,
                                        Type type) const {
  // Vulkan attributes do not have type.
  if (type) {
    parser.emitError(parser.getNameLoc(), "unexpected type");
    return {};
  }

  // Parse the kind keyword first.
  StringRef attrKind;
  if (parser.parseKeyword(&attrKind)) return {};
  if (attrKind == TargetEnvAttr::getKindName()) return parseTargetAttr(parser);

  parser.emitError(parser.getNameLoc(), "unknown Vulkan attribute kind: ")
      << attrKind;
  return {};
}
//===----------------------------------------------------------------------===//
// Attribute Printing
//===----------------------------------------------------------------------===//
namespace {
/// Prints `targetEnv` in the textual form read back by parseTargetAttr: the
/// kind name followed by
///   `<` version `, r(` revision `), [` extensions `], `
///       vendor `:` device-type [`:` device-id] `, ` capabilities `>`
/// The device ID is printed only when it differs from kUnknownDeviceID.
void print(TargetEnvAttr targetEnv, DialectAsmPrinter &printer) {
  auto &os = printer.getStream();

  // Header: kind name, version and revision.
  printer << TargetEnvAttr::getKindName() << "<";
  printer << stringifyVersion(targetEnv.getVersion());
  printer << ", r(" << targetEnv.getRevision() << "), [";

  // Extensions are StringAttrs; emit their bare string values.
  interleaveComma(targetEnv.getExtensionsAttr(), os, [&](Attribute ext) {
    os << ext.cast<StringAttr>().getValue();
  });
  printer << "], ";

  // Device triple: vendor:device-type[:device-id].
  printer << spirv::stringifyVendor(targetEnv.getVendorID()) << ":"
          << spirv::stringifyDeviceType(targetEnv.getDeviceType());
  const auto deviceID = targetEnv.getDeviceID();
  if (deviceID != spirv::TargetEnvAttr::kUnknownDeviceID)
    printer << ":" << deviceID;

  // Trailer: capabilities dictionary and closing bracket.
  printer << ", " << targetEnv.getCapabilitiesAttr() << ">";
}
} // anonymous namespace
/// Prints a Vulkan dialect attribute. Only TargetEnvAttr is supported; any
/// other attribute trips an assertion in debug builds and prints nothing.
void VulkanDialect::printAttribute(Attribute attr,
                                   DialectAsmPrinter &printer) const {
  auto targetEnv = attr.dyn_cast<TargetEnvAttr>();
  if (!targetEnv) {
    assert(false && "unhandled Vulkan attribute kind");
    return;
  }
  print(targetEnv, printer);
}
} // namespace Vulkan
} // namespace IREE
} // namespace iree_compiler
} // namespace mlir