| // Copyright 2019 The IREE Authors |
| // |
| // Licensed under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| |
| #include "llvm/ADT/STLExtras.h" |
| #include "llvm/ADT/SmallVector.h" |
| #include "llvm/ADT/StringExtras.h" |
| #include "llvm/Support/FormatVariadic.h" |
| #include "llvm/Support/raw_ostream.h" |
| #include "llvm/TableGen/Error.h" |
| #include "llvm/TableGen/Record.h" |
| #include "llvm/TableGen/TableGenBackend.h" |
| #include "mlir/TableGen/Attribute.h" |
| #include "mlir/TableGen/Format.h" |
| #include "mlir/TableGen/GenInfo.h" |
| #include "mlir/TableGen/Operator.h" |
| |
| namespace mlir { |
| namespace tblgen { |
| namespace iree_compiler { |
| namespace { |
| |
| using llvm::formatv; |
| using llvm::raw_ostream; |
| using llvm::Record; |
| using llvm::RecordKeeper; |
| using llvm::StringRef; |
| |
| class StructFieldAttr { |
| public: |
| explicit StructFieldAttr(const llvm::Record *record) : def(record) { |
| assert(def->isSubClassOf("Util_StructFieldAttr") && |
| "must be subclass of TableGen 'Util_StructFieldAttr' class"); |
| } |
| explicit StructFieldAttr(const llvm::Record &record) |
| : StructFieldAttr(&record) {} |
| explicit StructFieldAttr(const llvm::DefInit *init) |
| : StructFieldAttr(init->getDef()) {} |
| |
| StringRef getName() const { return def->getValueAsString("name"); } |
| Attribute getType() const { |
| auto init = def->getValueInit("type"); |
| return tblgen::Attribute(cast<llvm::DefInit>(init)); |
| } |
| |
| private: |
| const llvm::Record *def; |
| }; |
| |
| class StructAttr : public Attribute { |
| public: |
| explicit StructAttr(const llvm::Record *record) : Attribute(record) { |
| assert(isSubClassOf("Util_StructAttr") && |
| "must be subclass of TableGen 'Util_StructAttr' class"); |
| } |
| explicit StructAttr(const llvm::Record &record) : StructAttr(&record) {} |
| explicit StructAttr(const llvm::DefInit *init) : StructAttr(init->getDef()) {} |
| |
| StringRef getStructKind() const { return def->getValueAsString("kind"); } |
| StringRef getStructClassName() const { |
| return def->getValueAsString("className"); |
| } |
| StringRef getCppNamespace() const { |
| if (def->isValueUnset("cppNamespace")) { |
| Dialect dialect(def->getValueAsDef("structDialect")); |
| return dialect.getCppNamespace(); |
| } else { |
| return def->getValueAsString("cppNamespace"); |
| } |
| } |
| |
| std::vector<StructFieldAttr> getAllFields() const { |
| std::vector<StructFieldAttr> attributes; |
| const auto *inits = def->getValueAsListInit("fields"); |
| attributes.reserve(inits->size()); |
| for (const llvm::Init *init : *inits) { |
| attributes.emplace_back(cast<llvm::DefInit>(init)); |
| } |
| return attributes; |
| } |
| }; |
| |
| static void emitStructClass(const StructAttr &structAttr, raw_ostream &os) { |
| if (!structAttr.getAllFields().empty()) { |
| os << formatv(R"( |
| namespace detail { |
| struct {0}Storage; |
| } // namespace detail |
| )", |
| structAttr.getStructClassName()); |
| } |
| os << formatv(R"( |
| // {0} |
| class {1} : public mlir::Attribute::AttrBase<{1}, mlir::Attribute, {3}Storage> { |
| public: |
| using Base::Base; |
| |
| static StringRef getKindName() { return "{2}"; } |
| |
| )", |
| structAttr.getSummary(), structAttr.getStructClassName(), |
| structAttr.getStructKind(), |
| structAttr.getAllFields().empty() |
| ? "Attribute" |
| : "detail::" + structAttr.getStructClassName()); |
| |
| if (!structAttr.getAllFields().empty()) { |
| os << " static LogicalResult verify(\n"; |
| os << " function_ref<InFlightDiagnostic()> emitError,\n"; |
| interleave( |
| structAttr.getAllFields(), os, |
| [&](StructFieldAttr field) { |
| auto type = field.getType(); |
| os << formatv(" {0} {1}", type.getStorageType(), |
| field.getName()); |
| }, |
| ",\n"); |
| os << ");\n\n"; |
| } |
| |
| // Attribute storage type constructor (IntegerAttr, etc). |
| os << formatv(" static {0} get(", structAttr.getStructClassName()); |
| if (structAttr.getAllFields().empty()) { |
| os << "mlir::MLIRContext* context"; |
| } else { |
| interleaveComma(structAttr.getAllFields(), os, [&](StructFieldAttr field) { |
| auto type = field.getType(); |
| os << formatv("\n {0} {1}", type.getStorageType(), field.getName()); |
| }); |
| } |
| os << ");\n\n"; |
| |
| // Attribute return type constructor (APInt, etc). |
| if (!structAttr.getAllFields().empty()) { |
| os << formatv(" static {0} get(\n", structAttr.getStructClassName()); |
| for (auto field : structAttr.getAllFields()) { |
| auto type = field.getType(); |
| os << formatv(" {0} {1},\n", type.getReturnType(), field.getName()); |
| } |
| os << " mlir::MLIRContext* context);\n"; |
| } |
| |
| os << R"( |
| static Attribute parse(DialectAsmParser &p); |
| void print(DialectAsmPrinter &p) const; |
| |
| )"; |
| |
| for (auto field : structAttr.getAllFields()) { |
| auto type = field.getType(); |
| // Attribute storage type accessors (IntegerAttr, etc). |
| os << formatv(" {0} {1}Attr() const;\n", type.getStorageType(), |
| field.getName()); |
| // Attribute return type accessors (APInt, etc). |
| os << formatv(" {0} {1}() const;\n", type.getReturnType(), |
| field.getName()); |
| } |
| |
| os << " void walkStorage(const llvm::function_ref<void(mlir::Attribute " |
| "elementAttr)> &fn) const;\n"; |
| |
| os << "};\n\n"; |
| } |
| |
| static void emitStructDecl(const Record &structDef, raw_ostream &os) { |
| StructAttr structAttr(&structDef); |
| |
| // Forward declarations (to make including easier). |
| os << R"(namespace mlir { |
| class DialectAsmParser; |
| class DialectAsmPrinter; |
| } // namespace mlir |
| |
| )"; |
| |
| // Wrap in the appropriate namespace. |
| llvm::SmallVector<StringRef, 2> namespaces; |
| llvm::SplitString(structAttr.getCppNamespace(), namespaces, "::"); |
| |
| for (auto ns : namespaces) { |
| os << "namespace " << ns << " {\n"; |
| } |
| |
| // Emit the struct class definition |
| emitStructClass(structAttr, os); |
| |
| // Close the declared namespace. |
| for (auto ns : namespaces) { |
| os << "} // namespace " << ns << "\n"; |
| } |
| } |
| |
| static bool emitStructDecls(const RecordKeeper &recordKeeper, raw_ostream &os) { |
| llvm::emitSourceFileHeader("Struct Attr Declarations", os); |
| auto defs = recordKeeper.getAllDerivedDefinitions("Util_StructAttr"); |
| for (const auto *def : defs) { |
| emitStructDecl(*def, os); |
| } |
| return false; |
| } |
| |
| static void emitStorageDef(const StructAttr &structAttr, raw_ostream &os) { |
| os << "namespace detail {\n"; |
| os << formatv("struct {0}Storage : public mlir::AttributeStorage {{\n", |
| structAttr.getStructClassName()); |
| |
| os << " using KeyTy = std::tuple<"; |
| interleaveComma(structAttr.getAllFields(), os, [&](StructFieldAttr field) { |
| auto type = field.getType(); |
| os << type.getStorageType(); |
| }); |
| os << ">;\n\n"; |
| |
| os << formatv(" {0}Storage(", structAttr.getStructClassName()); |
| interleaveComma(structAttr.getAllFields(), os, [&](StructFieldAttr field) { |
| auto type = field.getType(); |
| os << formatv("{0} {1}", type.getStorageType(), field.getName()); |
| }); |
| os << ") : "; |
| interleaveComma(structAttr.getAllFields(), os, [&](StructFieldAttr field) { |
| os << formatv("{0}({0})", field.getName()); |
| }); |
| os << " {}\n\n"; |
| |
| os << " bool operator==(const KeyTy &key) const {\n"; |
| os << " return "; |
| int i = 0; |
| interleave( |
| structAttr.getAllFields(), os, |
| [&](StructFieldAttr field) { |
| os << formatv("std::get<{0}>(key) == {1}", i++, field.getName()); |
| }, |
| " && "); |
| os << ";\n }\n\n"; |
| |
| os << " static llvm::hash_code hashKey(const KeyTy &key) {\n"; |
| os << " return llvm::hash_combine("; |
| i = 0; |
| interleaveComma(structAttr.getAllFields(), os, [&](StructFieldAttr field) { |
| os << formatv("std::get<{0}>(key)", i++, field.getName()); |
| }); |
| os << ");\n"; |
| os << "}\n\n"; |
| |
| os << formatv( |
| " static {0}Storage *construct(AttributeStorageAllocator &allocator, " |
| "const KeyTy &key) {{\n", |
| structAttr.getStructClassName()); |
| os << formatv( |
| " return new (allocator.allocate<{0}Storage>()) {0}Storage(\n", |
| structAttr.getStructClassName()); |
| i = 0; |
| os << " "; |
| interleaveComma(structAttr.getAllFields(), os, [&](StructFieldAttr field) { |
| os << formatv("std::get<{0}>(key)", i++, field.getName()); |
| }); |
| os << ");\n"; |
| os << " }\n\n"; |
| |
| for (auto field : structAttr.getAllFields()) { |
| auto type = field.getType(); |
| os << formatv(" {0} {1};\n", type.getStorageType(), field.getName()); |
| } |
| |
| os << "};\n"; |
| os << "} // namespace detail\n\n"; |
| } |
| |
| static void emitVerifierDef(const StructAttr &structAttr, raw_ostream &os) { |
| os << "// static\n"; |
| os << formatv("LogicalResult {0}::verify(\n", |
| structAttr.getStructClassName()); |
| os << " function_ref<InFlightDiagnostic()> emitError,\n"; |
| interleave( |
| structAttr.getAllFields(), os, |
| [&](StructFieldAttr field) { |
| auto type = field.getType(); |
| os << formatv(" {0} {1}", type.getStorageType(), field.getName()); |
| }, |
| ",\n"); |
| os << ") {\n"; |
| |
| for (auto field : structAttr.getAllFields()) { |
| FmtContext fmt; |
| auto type = field.getType(); |
| os << formatv(R"( |
| if (!{0}) {{ |
| return emitError() << "'{1}' must be {2} but got " << {1}.getType(); |
| } |
| )", |
| tgfmt(type.getConditionTemplate(), |
| &fmt.withSelf(field.getName()), field.getName()), |
| field.getName(), type.getSummary()); |
| } |
| |
| os << " return success();\n"; |
| os << "}\n\n"; |
| } |
| |
| static void emitAttrFactoryDef(const StructAttr &structAttr, raw_ostream &os) { |
| os << "// static\n"; |
| os << formatv("{0} {0}::get(", structAttr.getStructClassName()); |
| if (structAttr.getAllFields().empty()) { |
| os << "mlir::MLIRContext* context"; |
| } else { |
| interleaveComma(structAttr.getAllFields(), os, [&](StructFieldAttr field) { |
| auto type = field.getType(); |
| os << formatv("\n {0} {1}", type.getStorageType(), field.getName()); |
| }); |
| } |
| os << ") {\n"; |
| |
| for (auto field : structAttr.getAllFields()) { |
| if (!field.getType().isOptional()) { |
| os << formatv(" assert({0} && \"{0} is required\");\n", field.getName()); |
| } |
| } |
| |
| if (!structAttr.getAllFields().empty()) { |
| os << formatv(" auto *context = {0}.getContext();\n", |
| structAttr.getAllFields().front().getName()); |
| } |
| |
| os << formatv(" return Base::get(context"); |
| if (!structAttr.getAllFields().empty()) { |
| os << ",\n "; |
| interleaveComma(structAttr.getAllFields(), os, |
| [&](StructFieldAttr field) { os << field.getName(); }); |
| } |
| os << ");\n"; |
| |
| os << "}\n\n"; |
| } |
| |
// Replaces all non-overlapping occurrences of `match` in `str` with
// `substitute`, scanning left to right. Text inserted by a substitution is
// never re-scanned, so a `substitute` that itself contains `match` does not
// recurse. An empty `match` is a no-op; otherwise it would match at every
// position and, with a non-empty `substitute`, the loop would never end.
static std::string replaceAllSubstrs(std::string str, const std::string &match,
                                     const std::string &substitute) {
  if (match.empty())
    return str;

  std::string::size_type scanLoc = 0;
  std::string::size_type matchLoc;
  while ((matchLoc = str.find(match, scanLoc)) != std::string::npos) {
    str.replace(matchLoc, match.size(), substitute);
    // Resume just past the inserted text.
    scanLoc = matchLoc + substitute.size();
  }
  return str;
}
| |
| static void emitTypedFactoryDef(const StructAttr &structAttr, raw_ostream &os) { |
| os << "// static\n"; |
| os << formatv("{0} {0}::get(", structAttr.getStructClassName()); |
| for (auto field : structAttr.getAllFields()) { |
| auto type = field.getType(); |
| os << formatv("\n {0} {1},", type.getReturnType(), field.getName()); |
| } |
| os << "\n mlir::MLIRContext* context) {\n"; |
| os << " mlir::Builder b(context);\n"; |
| |
| FmtContext ctx; |
| ctx.withBuilder("b"); |
| for (auto field : structAttr.getAllFields()) { |
| auto type = field.getType(); |
| |
| // For StringAttr, its constant builder call will wrap the input in |
| // quotes, which is correct for normal string literals, but incorrect |
| // here given we use function arguments. So we need to strip the |
| // wrapping quotes. |
| std::string builderTemplate = type.getConstBuilderTemplate().str(); |
| if (StringRef(builderTemplate).contains("\"$0\"")) { |
| builderTemplate = replaceAllSubstrs(builderTemplate, "\"$0\"", "$0"); |
| } |
| |
| os << formatv(" auto {0}Attr = {1};\n", field.getName(), |
| tgfmt(builderTemplate, &ctx, field.getName())); |
| } |
| |
| os << " return get("; |
| if (structAttr.getAllFields().empty()) { |
| os << "context"; |
| } else { |
| interleaveComma(structAttr.getAllFields(), os, [&](StructFieldAttr attr) { |
| os << attr.getName() << "Attr"; |
| }); |
| } |
| os << ");\n"; |
| |
| os << "}\n"; |
| } |
| |
| static void emitAccessorDefs(const StructAttr &structAttr, |
| const StructFieldAttr &field, raw_ostream &os) { |
| auto type = field.getType(); |
| |
| // Attribute storage type accessors (IntegerAttr, etc). |
| os << formatv(R"( |
| {1} {0}::{2}Attr() const {{ |
| return getImpl()->{2}; |
| } |
| )", |
| structAttr.getStructClassName(), type.getStorageType(), |
| field.getName()); |
| |
| // Attribute return type accessors (APInt, etc). |
| FmtContext ctx; |
| os << formatv( |
| R"( |
| {1} {0}::{2}() const {{ |
| return {3}; |
| } |
| )", |
| structAttr.getStructClassName(), type.getReturnType(), field.getName(), |
| tgfmt(type.getConvertFromStorageCall(), |
| &ctx.withSelf(field.getName() + "Attr()"))); |
| } |
| |
| static void emitWalkStorageDef(const StructAttr &structAttr, raw_ostream &os) { |
| os << formatv( |
| "void {0}::walkStorage(const llvm::function_ref<void(mlir::Attribute " |
| "elementAttr)> &fn) const {{\n", |
| structAttr.getStructClassName()); |
| for (auto field : structAttr.getAllFields()) { |
| os << formatv(" fn({0}Attr());\n", field.getName()); |
| } |
| os << "}\n"; |
| } |
| |
| static void emitStructDef(const Record &structDef, raw_ostream &os) { |
| StructAttr structAttr(&structDef); |
| StringRef cppNamespace = structAttr.getCppNamespace(); |
| |
| llvm::SmallVector<StringRef, 2> namespaces; |
| llvm::SplitString(cppNamespace, namespaces, "::"); |
| |
| for (auto ns : namespaces) { |
| os << "namespace " << ns << " {\n"; |
| } |
| os << "\n"; |
| |
| if (!structAttr.getAllFields().empty()) { |
| emitStorageDef(structAttr, os); |
| emitVerifierDef(structAttr, os); |
| } |
| emitAttrFactoryDef(structAttr, os); |
| if (!structAttr.getAllFields().empty()) { |
| emitTypedFactoryDef(structAttr, os); |
| for (auto field : structAttr.getAllFields()) { |
| emitAccessorDefs(structAttr, field, os); |
| } |
| } |
| emitWalkStorageDef(structAttr, os); |
| |
| os << "\n"; |
| for (auto ns : llvm::reverse(namespaces)) { |
| os << "} // namespace " << ns << "\n"; |
| } |
| } |
| |
| static bool emitStructDefs(const RecordKeeper &recordKeeper, raw_ostream &os) { |
| llvm::emitSourceFileHeader("Struct Attr Definitions", os); |
| auto defs = recordKeeper.getAllDerivedDefinitions("Util_StructAttr"); |
| for (const auto *def : defs) { |
| emitStructDef(*def, os); |
| } |
| return false; |
| } |
| |
| // Registers the struct utility generator to mlir-tblgen. |
| static GenRegistration genStructDecls("gen-iree-struct-attr-decls", |
| "Generate struct attr declarations", |
| [](const RecordKeeper &records, |
| raw_ostream &os) { |
| return emitStructDecls(records, os); |
| }); |
| |
| // Registers the struct utility generator to mlir-tblgen. |
| static GenRegistration genStructDefs("gen-iree-struct-attr-defs", |
| "Generate struct attr definitions", |
| [](const RecordKeeper &records, |
| raw_ostream &os) { |
| return emitStructDefs(records, os); |
| }); |
| |
| } // namespace |
| } // namespace iree_compiler |
| } // namespace tblgen |
| } // namespace mlir |