// blob: daf4e09ffc5095ad0a952b9d000ac671b6d43e8b [file] [log] [blame]
// Copyright 2019 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Record.h"
#include "llvm/TableGen/TableGenBackend.h"
#include "mlir/TableGen/Attribute.h"
#include "mlir/TableGen/Format.h"
#include "mlir/TableGen/GenInfo.h"
#include "mlir/TableGen/Operator.h"
namespace mlir {
namespace tblgen {
namespace iree_compiler {
namespace {
using llvm::formatv;
using llvm::raw_ostream;
using llvm::Record;
using llvm::RecordKeeper;
using llvm::StringRef;
// Thin wrapper around a TableGen record deriving from 'Util_StructFieldAttr'.
// Exposes the record's "name" string and its "type" attribute definition.
class StructFieldAttr {
public:
  // Wraps `record`; asserts that it derives from 'Util_StructFieldAttr'.
  explicit StructFieldAttr(const llvm::Record *record) : def(record) {
    assert(def->isSubClassOf("Util_StructFieldAttr") &&
           "must be subclass of TableGen 'Util_StructFieldAttr' class");
  }
  // Convenience overloads that forward to the pointer-based constructor.
  explicit StructFieldAttr(const llvm::Record &record)
      : StructFieldAttr(&record) {}
  explicit StructFieldAttr(const llvm::DefInit *init)
      : StructFieldAttr(init->getDef()) {}

  // Returns the value of the record's "name" field.
  StringRef getName() const { return def->getValueAsString("name"); }

  // Returns the attribute described by the record's "type" field, which must
  // be a def initializer (enforced by `cast`).
  Attribute getType() const {
    const auto *typeInit = cast<llvm::DefInit>(def->getValueInit("type"));
    return tblgen::Attribute(typeInit);
  }

private:
  // The wrapped TableGen record.
  const llvm::Record *def;
};
class StructAttr : public Attribute {
public:
explicit StructAttr(const llvm::Record *record) : Attribute(record) {
assert(isSubClassOf("Util_StructAttr") &&
"must be subclass of TableGen 'Util_StructAttr' class");
}
explicit StructAttr(const llvm::Record &record) : StructAttr(&record) {}
explicit StructAttr(const llvm::DefInit *init) : StructAttr(init->getDef()) {}
StringRef getStructKind() const { return def->getValueAsString("kind"); }
StringRef getStructClassName() const {
return def->getValueAsString("className");
}
StringRef getCppNamespace() const {
if (def->isValueUnset("cppNamespace")) {
Dialect dialect(def->getValueAsDef("structDialect"));
return dialect.getCppNamespace();
} else {
return def->getValueAsString("cppNamespace");
}
}
std::vector<StructFieldAttr> getAllFields() const {
std::vector<StructFieldAttr> attributes;
const auto *inits = def->getValueAsListInit("fields");
attributes.reserve(inits->size());
for (const llvm::Init *init : *inits) {
attributes.emplace_back(cast<llvm::DefInit>(init));
}
return attributes;
}
};
static void emitStructClass(const StructAttr &structAttr, raw_ostream &os) {
if (!structAttr.getAllFields().empty()) {
os << formatv(R"(
namespace detail {
struct {0}Storage;
} // namespace detail
)",
structAttr.getStructClassName());
}
os << formatv(R"(
// {0}
class {1} : public mlir::Attribute::AttrBase<{1}, mlir::Attribute, {3}Storage> {
public:
using Base::Base;
static StringRef getKindName() { return "{2}"; }
)",
structAttr.getSummary(), structAttr.getStructClassName(),
structAttr.getStructKind(),
structAttr.getAllFields().empty()
? "Attribute"
: "detail::" + structAttr.getStructClassName());
if (!structAttr.getAllFields().empty()) {
os << " static LogicalResult verify(\n";
os << " function_ref<InFlightDiagnostic()> emitError,\n";
interleave(
structAttr.getAllFields(), os,
[&](StructFieldAttr field) {
auto type = field.getType();
os << formatv(" {0} {1}", type.getStorageType(),
field.getName());
},
",\n");
os << ");\n\n";
}
// Attribute storage type constructor (IntegerAttr, etc).
os << formatv(" static {0} get(", structAttr.getStructClassName());
if (structAttr.getAllFields().empty()) {
os << "mlir::MLIRContext* context";
} else {
interleaveComma(structAttr.getAllFields(), os, [&](StructFieldAttr field) {
auto type = field.getType();
os << formatv("\n {0} {1}", type.getStorageType(), field.getName());
});
}
os << ");\n\n";
// Attribute return type constructor (APInt, etc).
if (!structAttr.getAllFields().empty()) {
os << formatv(" static {0} get(\n", structAttr.getStructClassName());
for (auto field : structAttr.getAllFields()) {
auto type = field.getType();
os << formatv(" {0} {1},\n", type.getReturnType(), field.getName());
}
os << " mlir::MLIRContext* context);\n";
}
os << R"(
static Attribute parse(DialectAsmParser &p);
void print(DialectAsmPrinter &p) const;
)";
for (auto field : structAttr.getAllFields()) {
auto type = field.getType();
// Attribute storage type accessors (IntegerAttr, etc).
os << formatv(" {0} {1}Attr() const;\n", type.getStorageType(),
field.getName());
// Attribute return type accessors (APInt, etc).
os << formatv(" {0} {1}() const;\n", type.getReturnType(),
field.getName());
}
os << " void walkStorage(const llvm::function_ref<void(mlir::Attribute "
"elementAttr)> &fn) const;\n";
os << "};\n\n";
}
static void emitStructDecl(const Record &structDef, raw_ostream &os) {
StructAttr structAttr(&structDef);
// Forward declarations (to make including easier).
os << R"(namespace mlir {
class DialectAsmParser;
class DialectAsmPrinter;
} // namespace mlir
)";
// Wrap in the appropriate namespace.
llvm::SmallVector<StringRef, 2> namespaces;
llvm::SplitString(structAttr.getCppNamespace(), namespaces, "::");
for (auto ns : namespaces) {
os << "namespace " << ns << " {\n";
}
// Emit the struct class definition
emitStructClass(structAttr, os);
// Close the declared namespace.
for (auto ns : namespaces) {
os << "} // namespace " << ns << "\n";
}
}
static bool emitStructDecls(const RecordKeeper &recordKeeper, raw_ostream &os) {
llvm::emitSourceFileHeader("Struct Attr Declarations", os);
auto defs = recordKeeper.getAllDerivedDefinitions("Util_StructAttr");
for (const auto *def : defs) {
emitStructDecl(*def, os);
}
return false;
}
static void emitStorageDef(const StructAttr &structAttr, raw_ostream &os) {
os << "namespace detail {\n";
os << formatv("struct {0}Storage : public mlir::AttributeStorage {{\n",
structAttr.getStructClassName());
os << " using KeyTy = std::tuple<";
interleaveComma(structAttr.getAllFields(), os, [&](StructFieldAttr field) {
auto type = field.getType();
os << type.getStorageType();
});
os << ">;\n\n";
os << formatv(" {0}Storage(", structAttr.getStructClassName());
interleaveComma(structAttr.getAllFields(), os, [&](StructFieldAttr field) {
auto type = field.getType();
os << formatv("{0} {1}", type.getStorageType(), field.getName());
});
os << ") : ";
interleaveComma(structAttr.getAllFields(), os, [&](StructFieldAttr field) {
os << formatv("{0}({0})", field.getName());
});
os << " {}\n\n";
os << " bool operator==(const KeyTy &key) const {\n";
os << " return ";
int i = 0;
interleave(
structAttr.getAllFields(), os,
[&](StructFieldAttr field) {
os << formatv("std::get<{0}>(key) == {1}", i++, field.getName());
},
" && ");
os << ";\n }\n\n";
os << " static llvm::hash_code hashKey(const KeyTy &key) {\n";
os << " return llvm::hash_combine(";
i = 0;
interleaveComma(structAttr.getAllFields(), os, [&](StructFieldAttr field) {
os << formatv("std::get<{0}>(key)", i++, field.getName());
});
os << ");\n";
os << "}\n\n";
os << formatv(
" static {0}Storage *construct(AttributeStorageAllocator &allocator, "
"const KeyTy &key) {{\n",
structAttr.getStructClassName());
os << formatv(
" return new (allocator.allocate<{0}Storage>()) {0}Storage(\n",
structAttr.getStructClassName());
i = 0;
os << " ";
interleaveComma(structAttr.getAllFields(), os, [&](StructFieldAttr field) {
os << formatv("std::get<{0}>(key)", i++, field.getName());
});
os << ");\n";
os << " }\n\n";
for (auto field : structAttr.getAllFields()) {
auto type = field.getType();
os << formatv(" {0} {1};\n", type.getStorageType(), field.getName());
}
os << "};\n";
os << "} // namespace detail\n\n";
}
static void emitVerifierDef(const StructAttr &structAttr, raw_ostream &os) {
os << "// static\n";
os << formatv("LogicalResult {0}::verify(\n",
structAttr.getStructClassName());
os << " function_ref<InFlightDiagnostic()> emitError,\n";
interleave(
structAttr.getAllFields(), os,
[&](StructFieldAttr field) {
auto type = field.getType();
os << formatv(" {0} {1}", type.getStorageType(), field.getName());
},
",\n");
os << ") {\n";
for (auto field : structAttr.getAllFields()) {
FmtContext fmt;
auto type = field.getType();
os << formatv(R"(
if (!{0}) {{
return emitError() << "'{1}' must be {2} but got " << {1}.getType();
}
)",
tgfmt(type.getConditionTemplate(),
&fmt.withSelf(field.getName()), field.getName()),
field.getName(), type.getSummary());
}
os << " return success();\n";
os << "}\n\n";
}
static void emitAttrFactoryDef(const StructAttr &structAttr, raw_ostream &os) {
os << "// static\n";
os << formatv("{0} {0}::get(", structAttr.getStructClassName());
if (structAttr.getAllFields().empty()) {
os << "mlir::MLIRContext* context";
} else {
interleaveComma(structAttr.getAllFields(), os, [&](StructFieldAttr field) {
auto type = field.getType();
os << formatv("\n {0} {1}", type.getStorageType(), field.getName());
});
}
os << ") {\n";
for (auto field : structAttr.getAllFields()) {
if (!field.getType().isOptional()) {
os << formatv(" assert({0} && \"{0} is required\");\n", field.getName());
}
}
if (!structAttr.getAllFields().empty()) {
os << formatv(" auto *context = {0}.getContext();\n",
structAttr.getAllFields().front().getName());
}
os << formatv(" return Base::get(context");
if (!structAttr.getAllFields().empty()) {
os << ",\n ";
interleaveComma(structAttr.getAllFields(), os,
[&](StructFieldAttr field) { os << field.getName(); });
}
os << ");\n";
os << "}\n\n";
}
// Replaces all non-overlapping occurrences of `match` in `str` with
// `substitute` and returns the result.
//
// Scanning proceeds left to right and resumes after each inserted
// `substitute`, so text introduced by a replacement is never rescanned.
// An empty `match` returns `str` unchanged: an empty pattern matches at every
// position, so replacing it would never terminate.
static std::string replaceAllSubstrs(std::string str, const std::string &match,
                                     const std::string &substitute) {
  if (match.empty())
    return str;
  for (std::string::size_type matchLoc = str.find(match);
       matchLoc != std::string::npos;
       matchLoc = str.find(match, matchLoc + substitute.size())) {
    str.replace(matchLoc, match.size(), substitute);
  }
  return str;
}
static void emitTypedFactoryDef(const StructAttr &structAttr, raw_ostream &os) {
os << "// static\n";
os << formatv("{0} {0}::get(", structAttr.getStructClassName());
for (auto field : structAttr.getAllFields()) {
auto type = field.getType();
os << formatv("\n {0} {1},", type.getReturnType(), field.getName());
}
os << "\n mlir::MLIRContext* context) {\n";
os << " mlir::Builder b(context);\n";
FmtContext ctx;
ctx.withBuilder("b");
for (auto field : structAttr.getAllFields()) {
auto type = field.getType();
// For StringAttr, its constant builder call will wrap the input in
// quotes, which is correct for normal string literals, but incorrect
// here given we use function arguments. So we need to strip the
// wrapping quotes.
std::string builderTemplate = type.getConstBuilderTemplate().str();
if (StringRef(builderTemplate).contains("\"$0\"")) {
builderTemplate = replaceAllSubstrs(builderTemplate, "\"$0\"", "$0");
}
os << formatv(" auto {0}Attr = {1};\n", field.getName(),
tgfmt(builderTemplate, &ctx, field.getName()));
}
os << " return get(";
if (structAttr.getAllFields().empty()) {
os << "context";
} else {
interleaveComma(structAttr.getAllFields(), os, [&](StructFieldAttr attr) {
os << attr.getName() << "Attr";
});
}
os << ");\n";
os << "}\n";
}
static void emitAccessorDefs(const StructAttr &structAttr,
const StructFieldAttr &field, raw_ostream &os) {
auto type = field.getType();
// Attribute storage type accessors (IntegerAttr, etc).
os << formatv(R"(
{1} {0}::{2}Attr() const {{
return getImpl()->{2};
}
)",
structAttr.getStructClassName(), type.getStorageType(),
field.getName());
// Attribute return type accessors (APInt, etc).
FmtContext ctx;
os << formatv(
R"(
{1} {0}::{2}() const {{
return {3};
}
)",
structAttr.getStructClassName(), type.getReturnType(), field.getName(),
tgfmt(type.getConvertFromStorageCall(),
&ctx.withSelf(field.getName() + "Attr()")));
}
static void emitWalkStorageDef(const StructAttr &structAttr, raw_ostream &os) {
os << formatv(
"void {0}::walkStorage(const llvm::function_ref<void(mlir::Attribute "
"elementAttr)> &fn) const {{\n",
structAttr.getStructClassName());
for (auto field : structAttr.getAllFields()) {
os << formatv(" fn({0}Attr());\n", field.getName());
}
os << "}\n";
}
static void emitStructDef(const Record &structDef, raw_ostream &os) {
StructAttr structAttr(&structDef);
StringRef cppNamespace = structAttr.getCppNamespace();
llvm::SmallVector<StringRef, 2> namespaces;
llvm::SplitString(cppNamespace, namespaces, "::");
for (auto ns : namespaces) {
os << "namespace " << ns << " {\n";
}
os << "\n";
if (!structAttr.getAllFields().empty()) {
emitStorageDef(structAttr, os);
emitVerifierDef(structAttr, os);
}
emitAttrFactoryDef(structAttr, os);
if (!structAttr.getAllFields().empty()) {
emitTypedFactoryDef(structAttr, os);
for (auto field : structAttr.getAllFields()) {
emitAccessorDefs(structAttr, field, os);
}
}
emitWalkStorageDef(structAttr, os);
os << "\n";
for (auto ns : llvm::reverse(namespaces)) {
os << "} // namespace " << ns << "\n";
}
}
static bool emitStructDefs(const RecordKeeper &recordKeeper, raw_ostream &os) {
llvm::emitSourceFileHeader("Struct Attr Definitions", os);
auto defs = recordKeeper.getAllDerivedDefinitions("Util_StructAttr");
for (const auto *def : defs) {
emitStructDef(*def, os);
}
return false;
}
// Registers the "gen-iree-struct-attr-decls" generator with mlir-tblgen.
// When selected, it forwards the record keeper and output stream to
// emitStructDecls and returns that function's result.
static GenRegistration genStructDecls("gen-iree-struct-attr-decls",
"Generate struct attr declarations",
[](const RecordKeeper &records,
raw_ostream &os) {
return emitStructDecls(records, os);
});
// Registers the "gen-iree-struct-attr-defs" generator with mlir-tblgen.
// When selected, it forwards the record keeper and output stream to
// emitStructDefs and returns that function's result.
static GenRegistration genStructDefs("gen-iree-struct-attr-defs",
"Generate struct attr definitions",
[](const RecordKeeper &records,
raw_ostream &os) {
return emitStructDefs(records, os);
});
} // namespace
} // namespace iree_compiler
} // namespace tblgen
} // namespace mlir