Removing the bespoke struct attr gen in favor of AttrDefs.
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/BUILD b/compiler/src/iree/compiler/Dialect/HAL/IR/BUILD index 6b77748..4cfb016 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/BUILD +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/BUILD
@@ -63,15 +63,12 @@ "HALOpInterfaces.h.inc", "HALOps.cpp.inc", "HALOps.h.inc", - "HALStructs.cpp.inc", - "HALStructs.h.inc", "HALTypeInterfaces.cpp.inc", "HALTypeInterfaces.h.inc", ], deps = [ ":HALInterfacesGen", ":HALOpsGen", - ":HALStructsGen", ":HALTypesGen", "//compiler/src/iree/compiler/Dialect/Util/IR", "@llvm-project//llvm:Support", @@ -158,27 +155,13 @@ ) iree_gentbl_cc_library( - name = "HALStructsGen", - tbl_outs = [ - ( - ["--gen-iree-struct-attr-decls"], - "HALStructs.h.inc", - ), - ( - ["--gen-iree-struct-attr-defs"], - "HALStructs.cpp.inc", - ), - ], - tblgen = "//tools:iree-tblgen", - td_file = "HALBase.td", - deps = [":td_files"], -) - -iree_gentbl_cc_library( name = "HALTypesGen", tbl_outs = [ ( - ["--gen-attrdef-decls", "--attrdefs-dialect=hal"], + [ + "--gen-attrdef-decls", + "--attrdefs-dialect=hal", + ], "HALAttrs.h.inc", ), (
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/HAL/IR/CMakeLists.txt index bdcddca..2ca1bef 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/CMakeLists.txt +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/CMakeLists.txt
@@ -29,8 +29,6 @@ "HALOpInterfaces.h.inc" "HALOps.cpp.inc" "HALOps.h.inc" - "HALStructs.cpp.inc" - "HALStructs.h.inc" "HALTypeInterfaces.cpp.inc" "HALTypeInterfaces.h.inc" SRCS @@ -40,7 +38,6 @@ DEPS ::HALInterfacesGen ::HALOpsGen - ::HALStructsGen ::HALTypesGen LLVMSupport MLIRArithmeticDialect @@ -105,18 +102,6 @@ iree_tablegen_library( NAME - HALStructsGen - TD_FILE - "HALBase.td" - OUTS - --gen-iree-struct-attr-decls HALStructs.h.inc - --gen-iree-struct-attr-defs HALStructs.cpp.inc - TBLGEN - IREE -) - -iree_tablegen_library( - NAME HALTypesGen TD_FILE "HALBase.td"
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALTypes.cpp b/compiler/src/iree/compiler/Dialect/HAL/IR/HALTypes.cpp index 2c8c139..c6bb755 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALTypes.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALTypes.cpp
@@ -17,9 +17,8 @@ // clang-format off: must be included after all LLVM/MLIR headers. #define GET_ATTRDEF_CLASSES -#include "iree/compiler/Dialect/HAL/IR/HALAttrs.cpp.inc" // IWYU pragma: keep -#include "iree/compiler/Dialect/HAL/IR/HALEnums.cpp.inc" // IWYU pragma: keep -#include "iree/compiler/Dialect/HAL/IR/HALStructs.cpp.inc" // IWYU pragma: keep +#include "iree/compiler/Dialect/HAL/IR/HALAttrs.cpp.inc" // IWYU pragma: keep +#include "iree/compiler/Dialect/HAL/IR/HALEnums.cpp.inc" // IWYU pragma: keep // clang-format on namespace mlir {
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALTypes.h b/compiler/src/iree/compiler/Dialect/HAL/IR/HALTypes.h index 77b566d..90612f2 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALTypes.h +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALTypes.h
@@ -26,8 +26,7 @@ #include "mlir/Support/LLVM.h" // clang-format off: must be included after all LLVM/MLIR headers. -#include "iree/compiler/Dialect/HAL/IR/HALEnums.h.inc" // IWYU pragma: keep -#include "iree/compiler/Dialect/HAL/IR/HALStructs.h.inc" // IWYU pragma: keep +#include "iree/compiler/Dialect/HAL/IR/HALEnums.h.inc" // IWYU pragma: keep // clang-format on namespace mlir {
diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilBase.td b/compiler/src/iree/compiler/Dialect/Util/IR/UtilBase.td index c778dc4..8f5e335 100644 --- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilBase.td +++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilBase.td
@@ -86,33 +86,6 @@ } //===----------------------------------------------------------------------===// -// Util_StructAttr -//===----------------------------------------------------------------------===// -// This has a custom tablegen generator in StructAttrGen.cpp to create the -// attribute and storage types. It differs from the core MLIR StructAttr -// by more closely matching what handwritten C++ would have (better typing -// and ergonomics and custom parser/printer). - -class Util_StructFieldAttr<string thisName, Attr thisType> { - string name = thisName; - Attr type = thisType; -} - -class Util_StructAttr<string thisKind, string name, Dialect dialect, - list<Util_StructFieldAttr> attributes> - : Attr<CPred<"$_self.isa<" # name # ">()">, - "structured attribute of " # name> { - string kind = thisKind; - string className = name; - string cppNamespace = ?; - let storageType = name; - let returnType = name; - let convertFromStorage = "$_self"; - Dialect structDialect = dialect; - list<Util_StructFieldAttr> fields = attributes; -} - -//===----------------------------------------------------------------------===// // Common traits //===----------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Dialect/Util/Tools/BUILD b/compiler/src/iree/compiler/Dialect/Util/Tools/BUILD deleted file mode 100644 index 746c662..0000000 --- a/compiler/src/iree/compiler/Dialect/Util/Tools/BUILD +++ /dev/null
@@ -1,18 +0,0 @@ -# Copyright 2019 The IREE Authors -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -package( - default_visibility = ["//visibility:public"], - features = ["layering_check"], - licenses = ["notice"], # Apache 2.0 -) - -filegroup( - name = "GenSrcs", - srcs = [ - "StructAttrGen.cpp", - ], -)
diff --git a/compiler/src/iree/compiler/Dialect/Util/Tools/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Util/Tools/CMakeLists.txt deleted file mode 100644 index cd98a50..0000000 --- a/compiler/src/iree/compiler/Dialect/Util/Tools/CMakeLists.txt +++ /dev/null
@@ -1,7 +0,0 @@ -# Copyright 2019 The IREE Authors -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -iree_add_all_subdirs()
diff --git a/compiler/src/iree/compiler/Dialect/Util/Tools/StructAttrGen.cpp b/compiler/src/iree/compiler/Dialect/Util/Tools/StructAttrGen.cpp deleted file mode 100644 index d1be382..0000000 --- a/compiler/src/iree/compiler/Dialect/Util/Tools/StructAttrGen.cpp +++ /dev/null
@@ -1,476 +0,0 @@ -// Copyright 2019 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#include "iree/compiler/Utils/StringUtils.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/SmallVector.h" -#include "llvm/ADT/StringExtras.h" -#include "llvm/Support/FormatVariadic.h" -#include "llvm/Support/raw_ostream.h" -#include "llvm/TableGen/Error.h" -#include "llvm/TableGen/Record.h" -#include "llvm/TableGen/TableGenBackend.h" -#include "mlir/TableGen/Attribute.h" -#include "mlir/TableGen/Format.h" -#include "mlir/TableGen/GenInfo.h" -#include "mlir/TableGen/Operator.h" - -namespace mlir { -namespace tblgen { -namespace iree_compiler { -namespace { - -using llvm::formatv; -using llvm::raw_ostream; -using llvm::Record; -using llvm::RecordKeeper; -using llvm::StringRef; - -class StructFieldAttr { - public: - explicit StructFieldAttr(const llvm::Record *record) : def(record) { - assert(def->isSubClassOf("Util_StructFieldAttr") && - "must be subclass of TableGen 'Util_StructFieldAttr' class"); - } - explicit StructFieldAttr(const llvm::Record &record) - : StructFieldAttr(&record) {} - explicit StructFieldAttr(const llvm::DefInit *init) - : StructFieldAttr(init->getDef()) {} - - StringRef getName() const { return def->getValueAsString("name"); } - Attribute getType() const { - auto init = def->getValueInit("type"); - return tblgen::Attribute(cast<llvm::DefInit>(init)); - } - - private: - const llvm::Record *def; -}; - -class StructAttr : public Attribute { - public: - explicit StructAttr(const llvm::Record *record) : Attribute(record) { - assert(isSubClassOf("Util_StructAttr") && - "must be subclass of TableGen 'Util_StructAttr' class"); - } - explicit StructAttr(const llvm::Record &record) : StructAttr(&record) {} - explicit StructAttr(const llvm::DefInit *init) : StructAttr(init->getDef()) {} - - StringRef getStructKind() const { return def->getValueAsString("kind"); } - StringRef getStructClassName() const { - return def->getValueAsString("className"); - } - StringRef getCppNamespace() const { - if (def->isValueUnset("cppNamespace")) { - Dialect dialect(def->getValueAsDef("structDialect")); - return dialect.getCppNamespace(); - } else { - return def->getValueAsString("cppNamespace"); - } - } - - std::vector<StructFieldAttr> getAllFields() const { - std::vector<StructFieldAttr> attributes; - const auto *inits = def->getValueAsListInit("fields"); - attributes.reserve(inits->size()); - for (const llvm::Init *init : *inits) { - attributes.emplace_back(cast<llvm::DefInit>(init)); - } - return attributes; - } -}; - -static void emitStructClass(const StructAttr &structAttr, raw_ostream &os) { - if (!structAttr.getAllFields().empty()) { - os << formatv(R"( -namespace detail { -struct {0}Storage; -} // namespace detail -)", - structAttr.getStructClassName()); - } - os << formatv(R"( -// {0} -class {1} : public mlir::Attribute::AttrBase<{1}, mlir::Attribute, {3}Storage> { - public: - using Base::Base; - - static StringRef getKindName() { return "{2}"; } - -)", - structAttr.getSummary(), structAttr.getStructClassName(), - structAttr.getStructKind(), - structAttr.getAllFields().empty() - ? "Attribute" - : "detail::" + structAttr.getStructClassName()); - - if (!structAttr.getAllFields().empty()) { - os << " static LogicalResult verify(\n"; - os << " function_ref<InFlightDiagnostic()> emitError,\n"; - interleave( - structAttr.getAllFields(), os, - [&](StructFieldAttr field) { - auto type = field.getType(); - os << formatv(" {0} {1}", type.getStorageType(), - field.getName()); - }, - ",\n"); - os << ");\n\n"; - } - - // Attribute storage type constructor (IntegerAttr, etc). - os << formatv(" static {0} get(", structAttr.getStructClassName()); - if (structAttr.getAllFields().empty()) { - os << "mlir::MLIRContext* context"; - } else { - interleaveComma(structAttr.getAllFields(), os, [&](StructFieldAttr field) { - auto type = field.getType(); - os << formatv("\n {0} {1}", type.getStorageType(), field.getName()); - }); - } - os << ");\n\n"; - - // Attribute return type constructor (APInt, etc). - if (!structAttr.getAllFields().empty()) { - os << formatv(" static {0} get(\n", structAttr.getStructClassName()); - for (auto field : structAttr.getAllFields()) { - auto type = field.getType(); - os << formatv(" {0} {1},\n", type.getReturnType(), field.getName()); - } - os << " mlir::MLIRContext* context);\n"; - } - - os << R"( - static Attribute parse(AsmParser &p); - void print(AsmPrinter &p) const; - -)"; - - for (auto field : structAttr.getAllFields()) { - auto type = field.getType(); - // Attribute storage type accessors (IntegerAttr, etc). - os << formatv(" {0} {1}Attr() const;\n", type.getStorageType(), - field.getName()); - // Attribute return type accessors (APInt, etc). - os << formatv(" {0} {1}() const;\n", type.getReturnType(), - field.getName()); - } - - os << " void walkStorage(const llvm::function_ref<void(mlir::Attribute " - "elementAttr)> &fn) const;\n"; - - os << "};\n\n"; -} - -static void emitStructDecl(const Record &structDef, raw_ostream &os) { - StructAttr structAttr(&structDef); - - // Forward declarations (to make including easier). - os << R"(namespace mlir { -class DialectAsmParser; -class DialectAsmPrinter; -} // namespace mlir - -)"; - - // Wrap in the appropriate namespace. - llvm::SmallVector<StringRef, 2> namespaces; - llvm::SplitString(structAttr.getCppNamespace(), namespaces, "::"); - - for (auto ns : namespaces) { - os << "namespace " << ns << " {\n"; - } - - // Emit the struct class definition - emitStructClass(structAttr, os); - - // Close the declared namespace. - for (auto ns : namespaces) { - os << "} // namespace " << ns << "\n"; - } -} - -static bool emitStructDecls(const RecordKeeper &recordKeeper, raw_ostream &os) { - llvm::emitSourceFileHeader("Struct Attr Declarations", os); - auto defs = recordKeeper.getAllDerivedDefinitions("Util_StructAttr"); - for (const auto *def : defs) { - emitStructDecl(*def, os); - } - return false; -} - -static void emitStorageDef(const StructAttr &structAttr, raw_ostream &os) { - os << "namespace detail {\n"; - os << formatv("struct {0}Storage : public mlir::AttributeStorage {{\n", - structAttr.getStructClassName()); - - os << " using KeyTy = std::tuple<"; - interleaveComma(structAttr.getAllFields(), os, [&](StructFieldAttr field) { - auto type = field.getType(); - os << type.getStorageType(); - }); - os << ">;\n\n"; - - os << formatv(" {0}Storage(", structAttr.getStructClassName()); - interleaveComma(structAttr.getAllFields(), os, [&](StructFieldAttr field) { - auto type = field.getType(); - os << formatv("{0} {1}", type.getStorageType(), field.getName()); - }); - os << ") : "; - interleaveComma(structAttr.getAllFields(), os, [&](StructFieldAttr field) { - os << formatv("{0}({0})", field.getName()); - }); - os << " {}\n\n"; - - os << " bool operator==(const KeyTy &key) const {\n"; - os << " return "; - int i = 0; - interleave( - structAttr.getAllFields(), os, - [&](StructFieldAttr field) { - os << formatv("std::get<{0}>(key) == {1}", i++, field.getName()); - }, - " && "); - os << ";\n }\n\n"; - - os << " static llvm::hash_code hashKey(const KeyTy &key) {\n"; - os << " return llvm::hash_combine("; - i = 0; - interleaveComma(structAttr.getAllFields(), os, [&](StructFieldAttr field) { - os << formatv("std::get<{0}>(key)", i++, field.getName()); - }); - os << ");\n"; - os << "}\n\n"; - - os << formatv( - " static {0}Storage *construct(AttributeStorageAllocator &allocator, " - "const KeyTy &key) {{\n", - structAttr.getStructClassName()); - os << formatv( - " return new (allocator.allocate<{0}Storage>()) {0}Storage(\n", - structAttr.getStructClassName()); - i = 0; - os << " "; - interleaveComma(structAttr.getAllFields(), os, [&](StructFieldAttr field) { - os << formatv("std::get<{0}>(key)", i++, field.getName()); - }); - os << ");\n"; - os << " }\n\n"; - - for (auto field : structAttr.getAllFields()) { - auto type = field.getType(); - os << formatv(" {0} {1};\n", type.getStorageType(), field.getName()); - } - - os << "};\n"; - os << "} // namespace detail\n\n"; -} - -static void emitVerifierDef(const StructAttr &structAttr, raw_ostream &os) { - os << "// static\n"; - os << formatv("LogicalResult {0}::verify(\n", - structAttr.getStructClassName()); - os << " function_ref<InFlightDiagnostic()> emitError,\n"; - interleave( - structAttr.getAllFields(), os, - [&](StructFieldAttr field) { - auto type = field.getType(); - os << formatv(" {0} {1}", type.getStorageType(), field.getName()); - }, - ",\n"); - os << ") {\n"; - - for (auto field : structAttr.getAllFields()) { - FmtContext fmt; - auto type = field.getType(); - os << formatv(R"( - if (!{0}) {{ - return emitError() << "'{1}' must be {2} but got " << {1}.getType(); - } -)", - tgfmt(type.getConditionTemplate(), - &fmt.withSelf(field.getName()), field.getName()), - field.getName(), type.getSummary()); - } - - os << " return success();\n"; - os << "}\n\n"; -} - -static void emitAttrFactoryDef(const StructAttr &structAttr, raw_ostream &os) { - os << "// static\n"; - os << formatv("{0} {0}::get(", structAttr.getStructClassName()); - if (structAttr.getAllFields().empty()) { - os << "mlir::MLIRContext* context"; - } else { - interleaveComma(structAttr.getAllFields(), os, [&](StructFieldAttr field) { - auto type = field.getType(); - os << formatv("\n {0} {1}", type.getStorageType(), field.getName()); - }); - } - os << ") {\n"; - - for (auto field : structAttr.getAllFields()) { - if (!field.getType().isOptional()) { - os << formatv(" assert({0} && \"{0} is required\");\n", field.getName()); - } - } - - if (!structAttr.getAllFields().empty()) { - os << formatv(" auto *context = {0}.getContext();\n", - structAttr.getAllFields().front().getName()); - } - - os << formatv(" return Base::get(context"); - if (!structAttr.getAllFields().empty()) { - os << ",\n "; - interleaveComma(structAttr.getAllFields(), os, - [&](StructFieldAttr field) { os << field.getName(); }); - } - os << ");\n"; - - os << "}\n\n"; -} - -static void emitTypedFactoryDef(const StructAttr &structAttr, raw_ostream &os) { - os << "// static\n"; - os << formatv("{0} {0}::get(", structAttr.getStructClassName()); - for (auto field : structAttr.getAllFields()) { - auto type = field.getType(); - os << formatv("\n {0} {1},", type.getReturnType(), field.getName()); - } - os << "\n mlir::MLIRContext* context) {\n"; - os << " mlir::Builder b(context);\n"; - - FmtContext ctx; - ctx.withBuilder("b"); - for (auto field : structAttr.getAllFields()) { - auto type = field.getType(); - - // For StringAttr, its constant builder call will wrap the input in - // quotes, which is correct for normal string literals, but incorrect - // here given we use function arguments. So we need to strip the - // wrapping quotes. - std::string builderTemplate = type.getConstBuilderTemplate().str(); - if (StringRef(builderTemplate).contains("\"$0\"")) { - builderTemplate = mlir::iree_compiler::replaceAllSubstrs(builderTemplate, - "\"$0\"", "$0"); - } - - os << formatv(" auto {0}Attr = {1};\n", field.getName(), - tgfmt(builderTemplate, &ctx, field.getName())); - } - - os << " return get("; - if (structAttr.getAllFields().empty()) { - os << "context"; - } else { - interleaveComma(structAttr.getAllFields(), os, [&](StructFieldAttr attr) { - os << attr.getName() << "Attr"; - }); - } - os << ");\n"; - - os << "}\n"; -} - -static void emitAccessorDefs(const StructAttr &structAttr, - const StructFieldAttr &field, raw_ostream &os) { - auto type = field.getType(); - - // Attribute storage type accessors (IntegerAttr, etc). - os << formatv(R"( -{1} {0}::{2}Attr() const {{ - return getImpl()->{2}; -} -)", - structAttr.getStructClassName(), type.getStorageType(), - field.getName()); - - // Attribute return type accessors (APInt, etc). - FmtContext ctx; - os << formatv( - R"( -{1} {0}::{2}() const {{ - return {3}; -} -)", - structAttr.getStructClassName(), type.getReturnType(), field.getName(), - tgfmt(type.getConvertFromStorageCall(), - &ctx.withSelf(field.getName() + "Attr()"))); -} - -static void emitWalkStorageDef(const StructAttr &structAttr, raw_ostream &os) { - os << formatv( - "void {0}::walkStorage(const llvm::function_ref<void(mlir::Attribute " - "elementAttr)> &fn) const {{\n", - structAttr.getStructClassName()); - for (auto field : structAttr.getAllFields()) { - os << formatv(" fn({0}Attr());\n", field.getName()); - } - os << "}\n"; -} - -static void emitStructDef(const Record &structDef, raw_ostream &os) { - StructAttr structAttr(&structDef); - StringRef cppNamespace = structAttr.getCppNamespace(); - - llvm::SmallVector<StringRef, 2> namespaces; - llvm::SplitString(cppNamespace, namespaces, "::"); - - for (auto ns : namespaces) { - os << "namespace " << ns << " {\n"; - } - os << "\n"; - - if (!structAttr.getAllFields().empty()) { - emitStorageDef(structAttr, os); - emitVerifierDef(structAttr, os); - } - emitAttrFactoryDef(structAttr, os); - if (!structAttr.getAllFields().empty()) { - emitTypedFactoryDef(structAttr, os); - for (auto field : structAttr.getAllFields()) { - emitAccessorDefs(structAttr, field, os); - } - } - emitWalkStorageDef(structAttr, os); - - os << "\n"; - for (auto ns : llvm::reverse(namespaces)) { - os << "} // namespace " << ns << "\n"; - } -} - -static bool emitStructDefs(const RecordKeeper &recordKeeper, raw_ostream &os) { - llvm::emitSourceFileHeader("Struct Attr Definitions", os); - auto defs = recordKeeper.getAllDerivedDefinitions("Util_StructAttr"); - for (const auto *def : defs) { - emitStructDef(*def, os); - } - return false; -} - -// Registers the struct utility generator to mlir-tblgen. -static GenRegistration genStructDecls("gen-iree-struct-attr-decls", - "Generate struct attr declarations", - [](const RecordKeeper &records, - raw_ostream &os) { - return emitStructDecls(records, os); - }); - -// Registers the struct utility generator to mlir-tblgen. -static GenRegistration genStructDefs("gen-iree-struct-attr-defs", - "Generate struct attr definitions", - [](const RecordKeeper &records, - raw_ostream &os) { - return emitStructDefs(records, os); - }); - -} // namespace -} // namespace iree_compiler -} // namespace tblgen -} // namespace mlir
diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp b/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp index 283a8b3..5d8770f 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp
@@ -876,7 +876,7 @@ "must be run before."; } - const int numGlobalRefs = ordinal_counts.getValue().global_refs(); + const int numGlobalRefs = ordinal_counts.getValue().getGlobalRefs(); if (numGlobalRefs > 0) { auto refs = emitc_builders::structPtrMember( @@ -962,7 +962,7 @@ << "ordinal_counts attribute not found. The OrdinalAllocationPass " "must be run before."; } - const int numGlobalRefs = ordinal_counts.getValue().global_refs(); + const int numGlobalRefs = ordinal_counts.getValue().getGlobalRefs(); if (numGlobalRefs > 0) { auto refs = emitc_builders::structPtrMember(
diff --git a/compiler/src/iree/compiler/Dialect/VM/IR/BUILD b/compiler/src/iree/compiler/Dialect/VM/IR/BUILD index de05340..aba745e 100644 --- a/compiler/src/iree/compiler/Dialect/VM/IR/BUILD +++ b/compiler/src/iree/compiler/Dialect/VM/IR/BUILD
@@ -43,32 +43,34 @@ name = "IR", srcs = [ "VMDialect.cpp", - "VMEnums.cpp.inc", - "VMOpEncoder.cpp.inc", "VMOpFolders.cpp", - "VMOpInterface.cpp.inc", "VMOps.cpp", - "VMOps.cpp.inc", - "VMStructs.cpp.inc", "VMTypes.cpp", ], hdrs = [ "VMDialect.h", - "VMEnums.h.inc", "VMFuncEncoder.h", - "VMOpInterface.h.inc", "VMOps.h", - "VMOps.h.inc", - "VMStructs.h.inc", "VMTraits.h", "VMTypes.h", ], + textual_hdrs = [ + "VMAttrs.cpp.inc", + "VMAttrs.h.inc", + "VMEnums.cpp.inc", + "VMEnums.h.inc", + "VMOpEncoder.cpp.inc", + "VMOpInterfaces.cpp.inc", + "VMOpInterfaces.h.inc", + "VMOps.cpp.inc", + "VMOps.h.inc", + ], deps = [ + ":VMAttrsGen", ":VMEnumsGen", ":VMOpEncoderGen", - ":VMOpInterfaceGen", + ":VMOpInterfacesGen", ":VMOpsGen", - ":VMStructsGen", "//compiler/src/iree/compiler/Dialect/Util/IR", "@llvm-project//llvm:Support", "@llvm-project//mlir:ControlFlowInterfaces", @@ -82,6 +84,29 @@ ) iree_gentbl_cc_library( + name = "VMAttrsGen", + tbl_outs = [ + ( + [ + "--gen-attrdef-decls", + "--attrdefs-dialect=vm", + ], + "VMAttrs.h.inc", + ), + ( + [ + "--gen-attrdef-defs", + "--attrdefs-dialect=vm", + ], + "VMAttrs.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "VMBase.td", + deps = [":td_files"], +) + +iree_gentbl_cc_library( name = "VMEnumsGen", tbl_outs = [ ( @@ -129,15 +154,15 @@ ) iree_gentbl_cc_library( - name = "VMOpInterfaceGen", + name = "VMOpInterfacesGen", tbl_outs = [ ( ["--gen-op-interface-decls"], - "VMOpInterface.h.inc", + "VMOpInterfaces.h.inc", ), ( ["--gen-op-interface-defs"], - "VMOpInterface.cpp.inc", + "VMOpInterfaces.cpp.inc", ), ], tblgen = "@llvm-project//mlir:mlir-tblgen", @@ -145,23 +170,6 @@ deps = [":td_files"], ) -iree_gentbl_cc_library( - name = "VMStructsGen", - tbl_outs = [ - ( - ["--gen-iree-struct-attr-decls"], - "VMStructs.h.inc", - ), - ( - ["--gen-iree-struct-attr-defs"], - "VMStructs.cpp.inc", - ), - ], - tblgen = "//tools:iree-tblgen", - td_file = "VMBase.td", - deps = [":td_files"], -) - iree_tablegen_doc( name = "VMDialectDocGen", tbl_outs = [
diff --git a/compiler/src/iree/compiler/Dialect/VM/IR/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/VM/IR/CMakeLists.txt index ad738ad..48aabdb 100644 --- a/compiler/src/iree/compiler/Dialect/VM/IR/CMakeLists.txt +++ b/compiler/src/iree/compiler/Dialect/VM/IR/CMakeLists.txt
@@ -15,30 +15,31 @@ IR HDRS "VMDialect.h" - "VMEnums.h.inc" "VMFuncEncoder.h" - "VMOpInterface.h.inc" "VMOps.h" - "VMOps.h.inc" - "VMStructs.h.inc" "VMTraits.h" "VMTypes.h" + TEXTUAL_HDRS + "VMAttrs.cpp.inc" + "VMAttrs.h.inc" + "VMEnums.cpp.inc" + "VMEnums.h.inc" + "VMOpEncoder.cpp.inc" + "VMOpInterfaces.cpp.inc" + "VMOpInterfaces.h.inc" + "VMOps.cpp.inc" + "VMOps.h.inc" SRCS "VMDialect.cpp" - "VMEnums.cpp.inc" - "VMOpEncoder.cpp.inc" "VMOpFolders.cpp" - "VMOpInterface.cpp.inc" "VMOps.cpp" - "VMOps.cpp.inc" - "VMStructs.cpp.inc" "VMTypes.cpp" DEPS + ::VMAttrsGen ::VMEnumsGen ::VMOpEncoderGen - ::VMOpInterfaceGen + ::VMOpInterfacesGen ::VMOpsGen - ::VMStructsGen LLVMSupport MLIRControlFlowInterfaces MLIRFuncDialect @@ -53,6 +54,16 @@ iree_tablegen_library( NAME + VMAttrsGen + TD_FILE + "VMBase.td" + OUTS + --gen-attrdef-decls --attrdefs-dialect=vm VMAttrs.h.inc + --gen-attrdef-defs --attrdefs-dialect=vm VMAttrs.cpp.inc +) + +iree_tablegen_library( + NAME VMEnumsGen TD_FILE "VMBase.td" @@ -84,24 +95,12 @@ iree_tablegen_library( NAME - VMOpInterfaceGen + VMOpInterfacesGen TD_FILE "VMBase.td" OUTS - --gen-op-interface-decls VMOpInterface.h.inc - --gen-op-interface-defs VMOpInterface.cpp.inc -) - -iree_tablegen_library( - NAME - VMStructsGen - TD_FILE - "VMBase.td" - OUTS - --gen-iree-struct-attr-decls VMStructs.h.inc - --gen-iree-struct-attr-defs VMStructs.cpp.inc - TBLGEN - IREE + --gen-op-interface-decls VMOpInterfaces.h.inc + --gen-op-interface-defs VMOpInterfaces.cpp.inc ) iree_tablegen_doc(
diff --git a/compiler/src/iree/compiler/Dialect/VM/IR/VMBase.td b/compiler/src/iree/compiler/Dialect/VM/IR/VMBase.td index 4fffde2..475a324 100644 --- a/compiler/src/iree/compiler/Dialect/VM/IR/VMBase.td +++ b/compiler/src/iree/compiler/Dialect/VM/IR/VMBase.td
@@ -386,19 +386,18 @@ // VM structs //===----------------------------------------------------------------------===// -def VM_OrdinalCountsAttr : - Util_StructAttr<"ordinal_counts", - "OrdinalCountsAttr", - VM_Dialect, [ - Util_StructFieldAttr<"import_funcs", I32Attr>, - Util_StructFieldAttr<"export_funcs", I32Attr>, - Util_StructFieldAttr<"internal_funcs", I32Attr>, - Util_StructFieldAttr<"global_bytes", I32Attr>, - Util_StructFieldAttr<"global_refs", I32Attr>, - Util_StructFieldAttr<"rodatas", I32Attr>, - Util_StructFieldAttr<"rwdatas", I32Attr>, - ]> { - let cppNamespace = "mlir::iree_compiler::IREE::VM"; +def VM_OrdinalCountsAttr : AttrDef<VM_Dialect, "OrdinalCounts"> { + let mnemonic = "ordinal_counts"; + let parameters = (ins + AttrParameter<"int32_t", "">:$import_funcs, + AttrParameter<"int32_t", "">:$export_funcs, + AttrParameter<"int32_t", "">:$internal_funcs, + AttrParameter<"int32_t", "">:$global_bytes, + AttrParameter<"int32_t", "">:$global_refs, + AttrParameter<"int32_t", "">:$rodatas, + AttrParameter<"int32_t", "">:$rwdatas + ); + let assemblyFormat = "`<` struct(params) `>`"; } #endif // IREE_DIALECT_VM_BASE
diff --git a/compiler/src/iree/compiler/Dialect/VM/IR/VMDialect.cpp b/compiler/src/iree/compiler/Dialect/VM/IR/VMDialect.cpp index 250e59d..d9c4e74 100644 --- a/compiler/src/iree/compiler/Dialect/VM/IR/VMDialect.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/IR/VMDialect.cpp
@@ -22,7 +22,7 @@ namespace IREE { namespace VM { -#include "iree/compiler/Dialect/VM/IR/VMOpInterface.cpp.inc" // IWYU pragma: keep +#include "iree/compiler/Dialect/VM/IR/VMOpInterfaces.cpp.inc" // IWYU pragma: keep // Fallback asm printer for ops that do not define their own. See op-specific // printers in the op implementations. @@ -169,30 +169,6 @@ } //===----------------------------------------------------------------------===// -// Attribute printing and parsing -//===----------------------------------------------------------------------===// - -Attribute VMDialect::parseAttribute(DialectAsmParser &parser, Type type) const { - StringRef attrKind; - if (failed(parser.parseKeyword(&attrKind))) return {}; - if (attrKind == OrdinalCountsAttr::getKindName()) { - return OrdinalCountsAttr::parse(parser); - } - parser.emitError(parser.getNameLoc()) << "unknown VM attribute: " << attrKind; - return {}; -} - -void VMDialect::printAttribute(Attribute attr, DialectAsmPrinter &p) const { - TypeSwitch<Attribute>(attr) - .Case<OrdinalCountsAttr>([&](auto typedAttr) { - p << typedAttr.getKindName(); - typedAttr.print(p); - }) - .Default( - [](Attribute) { assert(false && "unhandled VM attribute kind"); }); -} - -//===----------------------------------------------------------------------===// // Type printing and parsing //===----------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Dialect/VM/IR/VMDialect.h b/compiler/src/iree/compiler/Dialect/VM/IR/VMDialect.h index 1c60296..a709444 100644 --- a/compiler/src/iree/compiler/Dialect/VM/IR/VMDialect.h +++ b/compiler/src/iree/compiler/Dialect/VM/IR/VMDialect.h
@@ -24,7 +24,7 @@ namespace IREE { namespace VM { -#include "iree/compiler/Dialect/VM/IR/VMOpInterface.h.inc" // IWYU pragma: export +#include "iree/compiler/Dialect/VM/IR/VMOpInterfaces.h.inc" // IWYU pragma: export class VMDialect : public Dialect { public:
diff --git a/compiler/src/iree/compiler/Dialect/VM/IR/VMTypes.cpp b/compiler/src/iree/compiler/Dialect/VM/IR/VMTypes.cpp index 26aeedf..08a4b21 100644 --- a/compiler/src/iree/compiler/Dialect/VM/IR/VMTypes.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/IR/VMTypes.cpp
@@ -14,8 +14,9 @@ #include "mlir/IR/TypeSupport.h" // clang-format off: must be included after all LLVM/MLIR headers. -#include "iree/compiler/Dialect/VM/IR/VMEnums.cpp.inc" // IWYU pragma: keep -#include "iree/compiler/Dialect/VM/IR/VMStructs.cpp.inc" // IWYU pragma: keep +#define GET_ATTRDEF_CLASSES +#include "iree/compiler/Dialect/VM/IR/VMAttrs.cpp.inc" // IWYU pragma: keep +#include "iree/compiler/Dialect/VM/IR/VMEnums.cpp.inc" // IWYU pragma: keep // clang-format on namespace mlir { @@ -131,68 +132,44 @@ Type RefType::getObjectType() { return getImpl()->objectType; } //===----------------------------------------------------------------------===// -// Attribute printing and parsing -//===----------------------------------------------------------------------===// - -Attribute OrdinalCountsAttr::parse(AsmParser &p) { - Type i32 = p.getBuilder().getIntegerType(32); - IntegerAttr importFuncsAttr; - IntegerAttr exportFuncsAttr; - IntegerAttr internalFuncsAttr; - IntegerAttr globalBytesAttr; - IntegerAttr globalRefsAttr; - IntegerAttr rodatasAttr; - IntegerAttr rwdatasAttr; - if (failed(p.parseLess()) || failed(p.parseKeyword("import_funcs")) || - failed(p.parseEqual()) || - failed(p.parseAttribute(importFuncsAttr, i32)) || - failed(p.parseComma()) || failed(p.parseKeyword("export_funcs")) || - failed(p.parseEqual()) || - failed(p.parseAttribute(exportFuncsAttr, i32)) || - failed(p.parseComma()) || failed(p.parseKeyword("internal_funcs")) || - failed(p.parseEqual()) || - failed(p.parseAttribute(internalFuncsAttr, i32)) || - failed(p.parseComma()) || failed(p.parseKeyword("global_bytes")) || - failed(p.parseEqual()) || - failed(p.parseAttribute(globalBytesAttr, i32)) || - failed(p.parseComma()) || failed(p.parseKeyword("global_refs")) || - failed(p.parseEqual()) || failed(p.parseAttribute(globalRefsAttr, i32)) || - failed(p.parseComma()) || failed(p.parseKeyword("rodatas")) || - failed(p.parseEqual()) || failed(p.parseAttribute(rodatasAttr, i32)) || - failed(p.parseComma()) || failed(p.parseKeyword("rwdatas")) || - failed(p.parseEqual()) || failed(p.parseAttribute(rwdatasAttr, i32)) || - failed(p.parseGreater())) { - return {}; - } - return get(importFuncsAttr, exportFuncsAttr, internalFuncsAttr, - globalBytesAttr, globalRefsAttr, rodatasAttr, rwdatasAttr); -} - -void OrdinalCountsAttr::print(AsmPrinter &p) const { - auto &os = p.getStream(); - os << "<"; - os << "import_funcs = " << import_funcs() << ", "; - os << "export_funcs = " << export_funcs() << ", "; - os << "internal_funcs = " << internal_funcs() << ", "; - os << "global_bytes = " << global_bytes() << ", "; - os << "global_refs = " << global_refs() << ", "; - os << "rodatas = " << rodatas() << ", "; - os << "rwdatas = " << rwdatas(); - os << ">"; -} - -//===----------------------------------------------------------------------===// // VMDialect //===----------------------------------------------------------------------===// void VMDialect::registerAttributes() { - addAttributes<IREE::VM::OrdinalCountsAttr>(); + addAttributes< +#define GET_ATTRDEF_LIST +#include "iree/compiler/Dialect/VM/IR/VMAttrs.cpp.inc" // IWYU pragma: keep + >(); } void VMDialect::registerTypes() { addTypes<IREE::VM::BufferType, IREE::VM::ListType, IREE::VM::OpaqueType, IREE::VM::RefType>(); } +//===----------------------------------------------------------------------===// +// Attribute printing and parsing +//===----------------------------------------------------------------------===// + +Attribute VMDialect::parseAttribute(DialectAsmParser &parser, Type type) const { + StringRef mnemonic; + if (failed(parser.parseKeyword(&mnemonic))) return {}; + Attribute genAttr; + OptionalParseResult parseResult = + generatedAttributeParser(parser, mnemonic, type, genAttr); + if (parseResult.hasValue()) return genAttr; + parser.emitError(parser.getNameLoc()) + << "unknown HAL attribute: " << mnemonic; + return {}; +} + +void VMDialect::printAttribute(Attribute attr, DialectAsmPrinter &p) const { + TypeSwitch<Attribute>(attr).Default([&](Attribute) { + if (failed(generatedAttributePrinter(attr, p))) { + assert(false && "unhandled HAL attribute kind"); + } + }); +} + } // namespace VM } // namespace IREE } // namespace iree_compiler
diff --git a/compiler/src/iree/compiler/Dialect/VM/IR/VMTypes.h b/compiler/src/iree/compiler/Dialect/VM/IR/VMTypes.h index e006187..85417ee 100644 --- a/compiler/src/iree/compiler/Dialect/VM/IR/VMTypes.h +++ b/compiler/src/iree/compiler/Dialect/VM/IR/VMTypes.h
@@ -17,8 +17,9 @@ #include "mlir/Support/LLVM.h" // clang-format off: must be included after all LLVM/MLIR headers. -#include "iree/compiler/Dialect/VM/IR/VMEnums.h.inc" // IWYU pragma: keep -#include "iree/compiler/Dialect/VM/IR/VMStructs.h.inc" // IWYU pragma: export +#define GET_ATTRDEF_CLASSES +#include "iree/compiler/Dialect/VM/IR/VMAttrs.h.inc" // IWYU pragma: export +#include "iree/compiler/Dialect/VM/IR/VMEnums.h.inc" // IWYU pragma: keep // clang-format on namespace mlir {
diff --git a/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeModuleTarget.cpp b/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeModuleTarget.cpp index 29196ad..1924c77 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeModuleTarget.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeModuleTarget.cpp
@@ -337,9 +337,9 @@ std::vector<IREE::VM::ImportOp> importFuncOps; std::vector<IREE::VM::ExportOp> exportFuncOps; std::vector<IREE::VM::FuncOp> internalFuncOps; - importFuncOps.resize(ordinalCounts.import_funcs()); - exportFuncOps.resize(ordinalCounts.export_funcs()); - internalFuncOps.resize(ordinalCounts.internal_funcs()); + importFuncOps.resize(ordinalCounts.getImportFuncs()); + exportFuncOps.resize(ordinalCounts.getExportFuncs()); + internalFuncOps.resize(ordinalCounts.getInternalFuncs()); for (auto &op : moduleOp.getBlock().getOperations()) { if (auto funcOp = dyn_cast<IREE::VM::FuncOp>(op)) { @@ -487,8 +487,8 @@ auto importFuncsRef = fbb.createOffsetVecDestructive(importFuncRefs); auto typesRef = fbb.createOffsetVecDestructive(typeRefs); - int32_t globalRefs = ordinalCounts.global_refs(); - int32_t globalBytes = ordinalCounts.global_bytes(); + int32_t globalRefs = ordinalCounts.getGlobalRefs(); + int32_t globalBytes = ordinalCounts.getGlobalBytes(); iree_vm_ModuleStateDef_ref_t moduleStateDef = 0; if (globalBytes || globalRefs) { @@ -591,7 +591,7 @@ // it's small (like strings) we can avoid the extra seeks and keep it more // local by embedding it in the FlatBuffer. std::vector<IREE::VM::RodataOp> rodataOps; - rodataOps.resize(moduleOp.ordinal_counts().getValue().rodatas()); + rodataOps.resize(moduleOp.ordinal_counts().getValue().getRodatas()); for (auto rodataOp : moduleOp.getOps<IREE::VM::RodataOp>()) { rodataOps[rodataOp.ordinal().getValue().getLimitedValue()] = rodataOp; }
diff --git a/compiler/src/iree/compiler/Dialect/VM/Target/C/CModuleTarget.cpp b/compiler/src/iree/compiler/Dialect/VM/Target/C/CModuleTarget.cpp index dbb08b9..35eeefc 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Target/C/CModuleTarget.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/Target/C/CModuleTarget.cpp
@@ -117,14 +117,14 @@ auto ordinalCounts = moduleOp.ordinal_counts().getValue(); output << "iree_allocator_t allocator;\n"; - output << "uint8_t rwdata[" << countOrEmpty(ordinalCounts.global_bytes()) + output << "uint8_t rwdata[" << countOrEmpty(ordinalCounts.getGlobalBytes()) << "];\n"; - output << "iree_vm_ref_t refs[" << countOrEmpty(ordinalCounts.global_refs()) + output << "iree_vm_ref_t refs[" << countOrEmpty(ordinalCounts.getGlobalRefs()) << "];\n"; output << "iree_vm_buffer_t rodata_buffers[" - << countOrEmpty(ordinalCounts.rodatas()) << "];\n"; + << countOrEmpty(ordinalCounts.getRodatas()) << "];\n"; output << "iree_vm_function_t imports[" - << countOrEmpty(ordinalCounts.import_funcs()) << "];\n"; + << countOrEmpty(ordinalCounts.getImportFuncs()) << "];\n"; output << "};\n"; output << "typedef struct " << moduleName << "_t " << moduleName << "_t;\n";
diff --git a/compiler/src/iree/compiler/Dialect/VM/Transforms/OrdinalAllocation.cpp b/compiler/src/iree/compiler/Dialect/VM/Transforms/OrdinalAllocation.cpp index e923543..c10a90a 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Transforms/OrdinalAllocation.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/Transforms/OrdinalAllocation.cpp
@@ -95,8 +95,8 @@ // Assign ordinal counts to module op. getOperation().ordinal_countsAttr(OrdinalCountsAttr::get( - nextImportOrdinal, nextExportOrdinal, nextFuncOrdinal, globalBytes, - nextGlobalRefOrdinal, nextRodataOrdinal, 0, &getContext())); + &getContext(), nextImportOrdinal, nextExportOrdinal, nextFuncOrdinal, + globalBytes, nextGlobalRefOrdinal, nextRodataOrdinal, 0)); SymbolTable symbolTable(getOperation());
diff --git a/tools/BUILD b/tools/BUILD index 0cc77e0..f0f4831 100644 --- a/tools/BUILD +++ b/tools/BUILD
@@ -211,7 +211,6 @@ cc_binary( name = "iree-tblgen", srcs = [ - "//compiler/src/iree/compiler/Dialect/Util/Tools:GenSrcs", "//compiler/src/iree/compiler/Dialect/VM/Tools:GenSrcs", ], tags = ["hostonly"],
diff --git a/tools/CMakeLists.txt b/tools/CMakeLists.txt index 15cf5b1..069617d 100644 --- a/tools/CMakeLists.txt +++ b/tools/CMakeLists.txt
@@ -182,7 +182,6 @@ iree-tblgen SRCS "${IREE_ROOT_DIR}/third_party/llvm-project/mlir/tools/mlir-tblgen/mlir-tblgen.cpp" - "${IREE_SOURCE_DIR}/compiler/src/iree/compiler/Dialect/Util/Tools/StructAttrGen.cpp" "${IREE_SOURCE_DIR}/compiler/src/iree/compiler/Dialect/VM/Tools/VMOpEncoderGen.cpp" "${IREE_SOURCE_DIR}/compiler/src/iree/compiler/Dialect/VM/Tools/VMOpTableGen.cpp" DEPS