Merge pull request #2981 from rsuderman:main-to-google
PiperOrigin-RevId: 328179386
diff --git a/SUBMODULE_VERSIONS b/SUBMODULE_VERSIONS
index 76af324..e27d647 100644
--- a/SUBMODULE_VERSIONS
+++ b/SUBMODULE_VERSIONS
@@ -4,7 +4,7 @@
a5d9d0f7d368054fd1691aedf1db4116efcc233e third_party/flatbuffers
4fb0ff7069bd88ee85902f4d0bb62794e5f6d021 third_party/flatcc
f2fb48c3b3d79a75a88a99fba6576b25d42ec528 third_party/googletest
-e75bc5c791e0e8dbe79f7453e55af9e8d03c9cc0 third_party/llvm-project
+bad7d6b3735d1d855ffb07f32a272049cff085e6 third_party/llvm-project
17b12a4481daa150e2d1ea3ada086b551b856707 third_party/marl
a3479bbf9161df8c8cac55a08205864e6f371491 third_party/mlir-emitc
d8c7ee00a687ac369e62e2032514a93a9b413502 third_party/pybind11
@@ -12,7 +12,7 @@
a1390ed39ec77ecfb574bc6fcd5bfc5e3adbdea9 third_party/sdl2
f8bf11a0253a32375c32cad92c841237b96696c0 third_party/spirv_headers
57eb48aed36160c4876bc8310d9ca84d42ee9e2a third_party/swiftshader
-0e65b3a903cb5f0457d4972cd9ab1b1b8fa98e4d third_party/tensorflow
+051ed1cbfdefac1404cbfe0c2b1dd6e13c4e8fbd third_party/tensorflow
864d86e8b6d21449474db5e9313dbff90aa9c24f third_party/tracy
9bd3f561bcee3f01d22912de10bb07ce4e23d378 third_party/vulkan_headers
909f36b714c9239ee0b112a321220213a474ba53 third_party/vulkan_memory_allocator
diff --git a/bindings/python/pyiree/compiler/compiler.cc b/bindings/python/pyiree/compiler/compiler.cc
index 78a2bd4..7910e4c 100644
--- a/bindings/python/pyiree/compiler/compiler.cc
+++ b/bindings/python/pyiree/compiler/compiler.cc
@@ -287,7 +287,9 @@
void DiagnosticCapture::ClearDiagnostics() { diagnostics_.clear(); }
CompilerContextBundle::CompilerContextBundle()
- : default_capture_(&mlir_context_, nullptr) {}
+ : default_capture_(&mlir_context_, nullptr) {
+ mlir_context_.loadAllGloballyRegisteredDialects();
+}
CompilerContextBundle::~CompilerContextBundle() = default;
CompilerModuleBundle CompilerContextBundle::ParseAsm(
diff --git a/bindings/python/pyiree/rt/system_api.py b/bindings/python/pyiree/rt/system_api.py
index 17ac317..5ffd62e 100644
--- a/bindings/python/pyiree/rt/system_api.py
+++ b/bindings/python/pyiree/rt/system_api.py
@@ -58,16 +58,18 @@
driver_exceptions = {}
for driver_name in driver_names:
if driver_name not in available_driver_names:
- print("Could not create driver %s (not registered)" % driver_name,
- file=sys.stderr)
+ print(
+ "Could not create driver %s (not registered)" % driver_name,
+ file=sys.stderr)
continue
try:
driver = _binding.HalDriver.create(driver_name)
# TODO(laurenzo): Remove these prints to stderr (for now, more information
# is better and there is no better way to report it yet).
except Exception as ex: # pylint: disable=broad-except
- print("Could not create default driver %s: %r" % (driver_name, ex),
- file=sys.stderr)
+ print(
+ "Could not create default driver %s: %r" % (driver_name, ex),
+ file=sys.stderr)
driver_exceptions[driver_name] = ex
continue
@@ -78,8 +80,9 @@
try:
device = driver.create_default_device()
except Exception as ex:
- print("Could not create default driver device %s: %r" % (driver_name, ex),
- file=sys.stderr)
+ print(
+ "Could not create default driver device %s: %r" % (driver_name, ex),
+ file=sys.stderr)
driver_exceptions[driver_name] = ex
continue
@@ -226,8 +229,8 @@
else:
init_modules = None
- self._vm_context = _binding.VmContext(instance=self._config.vm_instance,
- modules=init_modules)
+ self._vm_context = _binding.VmContext(
+ instance=self._config.vm_instance, modules=init_modules)
if self._is_dynamic:
self._vm_context.register_modules(self._config.default_modules)
diff --git a/build_tools/bazel/third_party_import/llvm-project/overlay/llvm/BUILD.bazel b/build_tools/bazel/third_party_import/llvm-project/overlay/llvm/BUILD.bazel
index 3d5717b..92d1535 100644
--- a/build_tools/bazel/third_party_import/llvm-project/overlay/llvm/BUILD.bazel
+++ b/build_tools/bazel/third_party_import/llvm-project/overlay/llvm/BUILD.bazel
@@ -1759,6 +1759,7 @@
"lib/CodeGen/*.c",
"lib/CodeGen/*.cpp",
"lib/CodeGen/*.inc",
+ "lib/CodeGen/LiveDebugValues/*.cpp",
"lib/CodeGen/*.h",
]),
hdrs = glob([
diff --git a/build_tools/bazel/third_party_import/llvm-project/overlay/mlir/BUILD.bazel b/build_tools/bazel/third_party_import/llvm-project/overlay/mlir/BUILD.bazel
index 60284cc..94129a2 100644
--- a/build_tools/bazel/third_party_import/llvm-project/overlay/mlir/BUILD.bazel
+++ b/build_tools/bazel/third_party_import/llvm-project/overlay/mlir/BUILD.bazel
@@ -121,11 +121,13 @@
srcs = [
"lib/CAPI/IR/AffineMap.cpp",
"lib/CAPI/IR/IR.cpp",
+ "lib/CAPI/IR/StandardAttributes.cpp",
"lib/CAPI/IR/StandardTypes.cpp",
],
hdrs = [
"include/mlir-c/AffineMap.h",
"include/mlir-c/IR.h",
+ "include/mlir-c/StandardAttributes.h",
"include/mlir-c/StandardTypes.h",
"include/mlir/CAPI/AffineMap.h",
"include/mlir/CAPI/IR.h",
@@ -1124,6 +1126,7 @@
":ControlFlowInterfaces",
":IR",
":LLVMOpsIncGen",
+ ":OpenMPDialect",
":SideEffectInterfaces",
":Support",
"@llvm-project//llvm:AsmParser",
@@ -1762,6 +1765,61 @@
],
)
+cc_library(
+ name = "PDLDialect",
+ srcs = glob([
+ "lib/Dialect/PDL/IR/*.cpp",
+ "lib/Dialect/PDL/IR/*.h",
+ ]),
+ hdrs = glob([
+ "include/mlir/Dialect/PDL/IR/*.h",
+ ]),
+ includes = ["include"],
+ deps = [
+ ":IR",
+ ":InferTypeOpInterface",
+ ":PDLOpsIncGen",
+ ":SideEffects",
+ ":Support",
+ "@llvm-project//llvm:Support",
+ ],
+)
+
+filegroup(
+ name = "PDLOpsTdFiles",
+ srcs = [
+ "include/mlir/Dialect/PDL/IR/PDLBase.td",
+ "include/mlir/Dialect/PDL/IR/PDLOps.td",
+ "include/mlir/IR/SymbolInterfaces.td",
+ "include/mlir/Interfaces/SideEffectInterfaces.td",
+ ":OpBaseTdFiles",
+ ],
+)
+
+gentbl(
+ name = "PDLOpsIncGen",
+ strip_include_prefix = "include",
+ tbl_outs = [
+ (
+ "-gen-op-decls",
+ "include/mlir/Dialect/PDL/IR/PDLOps.h.inc",
+ ),
+ (
+ "-gen-op-defs",
+ "include/mlir/Dialect/PDL/IR/PDLOps.cpp.inc",
+ ),
+ (
+ "-gen-dialect-decls",
+ "include/mlir/Dialect/PDL/IR/PDLOpsDialect.h.inc",
+ ),
+ ],
+ tblgen = ":mlir-tblgen",
+ td_file = "include/mlir/Dialect/PDL/IR/PDLOps.td",
+ td_srcs = [
+ ":PDLOpsTdFiles",
+ ],
+)
+
# TODO(gcmn): Update SPIRV dependencies so that they map better to cmake files.
filegroup(
name = "SPIRVOpsTdFiles",
@@ -2876,6 +2934,7 @@
":NVVMDialect",
":OpenACCDialect",
":OpenMPDialect",
+ ":PDLDialect",
":QuantOps",
":QuantPassIncGen",
":ROCDLDialect",
@@ -3542,6 +3601,7 @@
":LinalgOps",
":LinalgTransforms",
":Pass",
+ ":SCFDialect",
":SCFToStandard",
":StandardOps",
":StandardToLLVM",
@@ -3779,6 +3839,7 @@
":EDSC",
":IR",
":LLVMDialect",
+ ":LinalgTransforms",
":Pass",
":SCFDialect",
":StandardOps",
diff --git a/build_tools/bazel/third_party_import/llvm-project/overlay/mlir/test/BUILD.bazel b/build_tools/bazel/third_party_import/llvm-project/overlay/mlir/test/BUILD.bazel
index ac27bab..bea0710 100644
--- a/build_tools/bazel/third_party_import/llvm-project/overlay/mlir/test/BUILD.bazel
+++ b/build_tools/bazel/third_party_import/llvm-project/overlay/mlir/test/BUILD.bazel
@@ -186,6 +186,7 @@
"@llvm-project//mlir:LinalgTransforms",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:SCFDialect",
+ "@llvm-project//mlir:SPIRVDialect",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:StandardOpsTransforms",
"@llvm-project//mlir:Support",
diff --git a/experimental/ModelBuilder/ModelBuilder.cpp b/experimental/ModelBuilder/ModelBuilder.cpp
index 56dace1..55dc7e8 100644
--- a/experimental/ModelBuilder/ModelBuilder.cpp
+++ b/experimental/ModelBuilder/ModelBuilder.cpp
@@ -39,15 +39,7 @@
thread_local MLIRContext mlir::ModelBuilder::ctx;
void ModelBuilder::registerAllDialects() {
- registerDialect<AffineDialect>();
- registerDialect<gpu::GPUDialect>();
- registerDialect<LLVM::LLVMDialect>();
- registerDialect<linalg::LinalgDialect>();
- registerDialect<scf::SCFDialect>();
- registerDialect<omp::OpenMPDialect>();
- registerDialect<spirv::SPIRVDialect>();
- registerDialect<StandardOpsDialect>();
- registerDialect<vector::VectorDialect>();
+ // TODO: remove.
}
mlir::ModelBuilder::ModelBuilder()
@@ -57,7 +49,17 @@
loc(module->getLoc()),
i8(IntegerType::get(8, &ctx)),
f32(FloatType::getF32(&ctx)),
- f64(FloatType::getF64(&ctx)) {}
+ f64(FloatType::getF64(&ctx)) {
+ ctx.getOrLoadDialect<AffineDialect>();
+ ctx.getOrLoadDialect<gpu::GPUDialect>();
+ ctx.getOrLoadDialect<LLVM::LLVMDialect>();
+ ctx.getOrLoadDialect<linalg::LinalgDialect>();
+ ctx.getOrLoadDialect<scf::SCFDialect>();
+ ctx.getOrLoadDialect<omp::OpenMPDialect>();
+ ctx.getOrLoadDialect<spirv::SPIRVDialect>();
+ ctx.getOrLoadDialect<StandardOpsDialect>();
+ ctx.getOrLoadDialect<vector::VectorDialect>();
+}
Value mlir::ModelBuilder::constant_f32(float v) {
return std_constant_float(llvm::APFloat(v),
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
index 8be5958..871346b 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
@@ -462,7 +462,7 @@
flaglines.append(f"--entry_function={self.calls[0].method}")
with open(os.path.join(trace_dir, "flagfile"), "w") as f:
- f.writelines(line + '\n' for line in flaglines)
+ f.writelines(line + "\n" for line in flaglines)
@staticmethod
def load(trace_dir: str) -> "Trace":
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils_test.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils_test.py
index c6c27da..3093028 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils_test.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils_test.py
@@ -178,8 +178,8 @@
def trace_function(module):
module.increment()
module.increment_by(np.array([81.], dtype=np.float32))
- module.increment_by_max(np.array([81], dtype=np.float32),
- np.array([92], dtype=np.float32))
+ module.increment_by_max(
+ np.array([81], dtype=np.float32), np.array([92], dtype=np.float32))
module.get_count()
module = tf_utils.IreeCompiledModule(StatefulCountingModule,
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py
index d424aa3..3137d81 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py
@@ -304,8 +304,8 @@
self.compiled_path = _create_reinitialized_dict["compiled_path"]
# Holds all of the module's mutable state.
- self._context = rt.SystemContext(modules=[self._module],
- config=self._config)
+ self._context = rt.SystemContext(
+ modules=[self._module], config=self._config)
def create_reinitialized(self) -> "IreeCompiledModule":
"""Duplicates this module with its initial state without recompiling."""
@@ -358,9 +358,8 @@
# which is sad).
if not isinstance(results, tuple):
results = (results,)
- return tf.nest.map_structure(self._convert_to_numpy,
- *results,
- check_types=False)
+ return tf.nest.map_structure(
+ self._convert_to_numpy, *results, check_types=False)
def get_serialized_values(self) -> Tuple[Tuple[str], Tuple[str]]:
"""Dummy function to match _IreeFunctionWrapper's API."""
diff --git a/integrations/tensorflow/compiler/TFSavedModelLowerGlobalTensors.cpp b/integrations/tensorflow/compiler/TFSavedModelLowerGlobalTensors.cpp
index 03c878d..c97c52b 100644
--- a/integrations/tensorflow/compiler/TFSavedModelLowerGlobalTensors.cpp
+++ b/integrations/tensorflow/compiler/TFSavedModelLowerGlobalTensors.cpp
@@ -108,7 +108,7 @@
}
func.eraseArguments(argsToErase);
Dialect *ireeFlowDialect =
- func.getContext()->getRegisteredDialect<IREE::Flow::FlowDialect>();
+ func.getContext()->getLoadedDialect<IREE::Flow::FlowDialect>();
while (!typeConversionWorklist.empty()) {
Value v = typeConversionWorklist.pop_back_val();
Type desiredType = v.getType();
diff --git a/integrations/tensorflow/compiler/dialect/tf_tensorlist/conversion/convert_tf_to_tf_tensorlist.cc b/integrations/tensorflow/compiler/dialect/tf_tensorlist/conversion/convert_tf_to_tf_tensorlist.cc
index 5c687a9..da966a3 100644
--- a/integrations/tensorflow/compiler/dialect/tf_tensorlist/conversion/convert_tf_to_tf_tensorlist.cc
+++ b/integrations/tensorflow/compiler/dialect/tf_tensorlist/conversion/convert_tf_to_tf_tensorlist.cc
@@ -102,7 +102,7 @@
// want to blindly update all variant types to tensorlist. So here we do a
// targeted rewrite.
auto *tfTensorListDialect =
- func.getContext()->getRegisteredDialect<TfTensorListDialect>();
+ func.getContext()->getLoadedDialect<TfTensorListDialect>();
auto tensorListType = TensorListType::get(func.getContext());
SmallVector<Value, 8> typeConversionWorklist;
func.walk([&](Operation *op) {
diff --git a/integrations/tensorflow/e2e/README.md b/integrations/tensorflow/e2e/README.md
index 2a4c9eb..6b75f23 100644
--- a/integrations/tensorflow/e2e/README.md
+++ b/integrations/tensorflow/e2e/README.md
@@ -215,9 +215,9 @@
--input_file=/path/to/custom/compiled.vmfb
```
-Currently, this only supports benchmarking the first module call in a trace.
-We plan to extend this to support benchmarking all of the calls in the trace,
-and also plan to support verifying outputs during the warm-up phase of the
+Currently, this only supports benchmarking the first module call in a trace. We
+plan to extend this to support benchmarking all of the calls in the trace, and
+also plan to support verifying outputs during the warm-up phase of the
benchmark.
## Debugging Tests
diff --git a/integrations/tensorflow/e2e/bool_test.py b/integrations/tensorflow/e2e/bool_test.py
index 1d62216..2b29f2e 100644
--- a/integrations/tensorflow/e2e/bool_test.py
+++ b/integrations/tensorflow/e2e/bool_test.py
@@ -20,40 +20,40 @@
class MathModule(tf.Module):
+
@tf.function(input_signature=[tf.TensorSpec([4], tf.float32)])
def greater_than(self, x):
return x > 1.0
- @tf.function(input_signature=[tf.TensorSpec([4], tf.bool),
- tf.TensorSpec([4], tf.bool)])
+ @tf.function(input_signature=[
+ tf.TensorSpec([4], tf.bool),
+ tf.TensorSpec([4], tf.bool)
+ ])
def logical_and(self, x, y):
return tf.math.logical_and(x, y)
-
@tf_test_utils.compile_module(MathModule)
class BooleanTest(tf_test_utils.TracedModuleTestCase):
def test_greater_than(self):
+
def greater_than(module):
module.greater_than(np.array([0.0, 1.2, 1.5, 3.75], dtype=np.float32))
self.compare_backends(greater_than)
-
def test_logical_and(self):
def logical_and(module):
module.logical_and(
- np.array([True, True, False, False], dtype=np.bool),
- np.array([True, False, False, True], dtype=np.bool))
+ np.array([True, True, False, False], dtype=np.bool_),
+ np.array([True, False, False, True], dtype=np.bool_))
self.compare_backends(logical_and)
-
if __name__ == "__main__":
if hasattr(tf, "enable_v2_behavior"):
tf.enable_v2_behavior()
tf.test.main()
-
diff --git a/integrations/tensorflow/e2e/keras/train_vision_models_on_cifar.py b/integrations/tensorflow/e2e/keras/train_vision_models_on_cifar.py
index fabc112..92ce1b4 100644
--- a/integrations/tensorflow/e2e/keras/train_vision_models_on_cifar.py
+++ b/integrations/tensorflow/e2e/keras/train_vision_models_on_cifar.py
@@ -89,19 +89,20 @@
train_labels = train_labels[:4000]
# It is a toy model for debugging (not optimized for accuracy or speed).
- model = APP_MODELS[FLAGS.model](weights=None,
- include_top=FLAGS.include_top,
- input_shape=INPUT_SHAPE[1:])
+ model = APP_MODELS[FLAGS.model](
+ weights=None, include_top=FLAGS.include_top, input_shape=INPUT_SHAPE[1:])
model.summary()
- model.compile(optimizer='adam',
- loss='sparse_categorical_crossentropy',
- metrics=['accuracy'])
+ model.compile(
+ optimizer='adam',
+ loss='sparse_categorical_crossentropy',
+ metrics=['accuracy'])
# train model
- model.fit(train_images,
- train_labels,
- epochs=1,
- validation_data=(test_images, test_labels))
+ model.fit(
+ train_images,
+ train_labels,
+ epochs=1,
+ validation_data=(test_images, test_labels))
file_name = os.path.join(
FLAGS.model_path,
diff --git a/integrations/tensorflow/e2e/keras/vision_model_test.py b/integrations/tensorflow/e2e/keras/vision_model_test.py
index f9bb302..d607060 100644
--- a/integrations/tensorflow/e2e/keras/vision_model_test.py
+++ b/integrations/tensorflow/e2e/keras/vision_model_test.py
@@ -114,9 +114,8 @@
# an external tf.keras URL.
weights = 'imagenet' if FLAGS.data == 'imagenet' else None
- model = APP_MODELS[FLAGS.model](weights=weights,
- include_top=FLAGS.include_top,
- input_shape=input_shape)
+ model = APP_MODELS[FLAGS.model](
+ weights=weights, include_top=FLAGS.include_top, input_shape=input_shape)
if FLAGS.data == 'cifar10' and FLAGS.url:
model = load_cifar10_weights(model)
@@ -132,7 +131,8 @@
# TODO(b/142948097): Add support for dynamic shapes in SPIR-V lowering.
# Replace input_shape with m.input_shape to make the batch size dynamic.
self.predict = tf.function(
- input_signature=[tf.TensorSpec(get_input_shape())])(self.m.call)
+ input_signature=[tf.TensorSpec(get_input_shape())])(
+ self.m.call)
@tf_test_utils.compile_module(VisionModule, exported_names=['predict'])
diff --git a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertFlowToHAL.cpp b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertFlowToHAL.cpp
index b26f0f3..0e406eb 100644
--- a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertFlowToHAL.cpp
+++ b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertFlowToHAL.cpp
@@ -76,7 +76,7 @@
SmallVector<const HALConversionDialectInterface *, 4> conversionInterfaces;
- // Gather all interfaces from registered dialects.
+ // Gather all interfaces from dialects loaded in the context.
// These will perform the tensor->buffer mapping for their ops.
- for (auto *dialect : context->getRegisteredDialects()) {
+ for (auto *dialect : context->getLoadedDialects()) {
if (auto *conversionInterface =
dialect
->getRegisteredInterface<HALConversionDialectInterface>()) {
diff --git a/iree/samples/custom_modules/dialect/custom_opt.cc b/iree/samples/custom_modules/dialect/custom_opt.cc
index 3264c3e..7fa25c2 100644
--- a/iree/samples/custom_modules/dialect/custom_opt.cc
+++ b/iree/samples/custom_modules/dialect/custom_opt.cc
@@ -107,7 +107,10 @@
if (showDialects) {
llvm::outs() << "Registered Dialects:\n";
mlir::MLIRContext context;
- for (mlir::Dialect *dialect : context.getRegisteredDialects()) {
+ // A fresh context does not load the globally registered dialects; load
+ // them so getLoadedDialects() reports every registered dialect.
+ context.loadAllGloballyRegisteredDialects();
+ for (mlir::Dialect *dialect : context.getLoadedDialects()) {
llvm::outs() << dialect->getNamespace() << "\n";
}
return 0;
@@ -127,7 +130,10 @@
exit(1);
}
+ mlir::DialectRegistry registry;
+ mlir::getGlobalDialectRegistry().appendTo(registry);
return failed(mlir::MlirOptMain(output->os(), std::move(file), passPipeline,
- splitInputFile, verifyDiagnostics,
- verifyPasses, allowUnregisteredDialects));
+ registry, splitInputFile, verifyDiagnostics,
+ verifyPasses, allowUnregisteredDialects,
+ /*preloadDialectsInContext=*/true));
}
diff --git a/iree/samples/custom_modules/dialect/custom_translate.cc b/iree/samples/custom_modules/dialect/custom_translate.cc
index d8d6da6..1cff7e5 100644
--- a/iree/samples/custom_modules/dialect/custom_translate.cc
+++ b/iree/samples/custom_modules/dialect/custom_translate.cc
@@ -102,6 +102,7 @@
auto processBuffer = [&](std::unique_ptr<llvm::MemoryBuffer> ownedBuffer,
llvm::raw_ostream &os) {
mlir::MLIRContext context;
+ context.loadAllGloballyRegisteredDialects();
llvm::SourceMgr sourceMgr;
sourceMgr.AddNewSourceBuffer(std::move(ownedBuffer), llvm::SMLoc());
mlir::SourceMgrDiagnosticHandler diagHandler(sourceMgr, &context);
diff --git a/iree/tools/opt_main.cc b/iree/tools/opt_main.cc
index 260901c..6108587 100644
--- a/iree/tools/opt_main.cc
+++ b/iree/tools/opt_main.cc
@@ -110,7 +110,10 @@
if (showDialects) {
llvm::outs() << "Registered Dialects:\n";
mlir::MLIRContext context;
- for (mlir::Dialect *dialect : context.getRegisteredDialects()) {
+ // A fresh context does not load the globally registered dialects; load
+ // them so getLoadedDialects() reports every registered dialect.
+ context.loadAllGloballyRegisteredDialects();
+ for (mlir::Dialect *dialect : context.getLoadedDialects()) {
llvm::outs() << dialect->getNamespace() << "\n";
}
return 0;
@@ -130,7 +133,12 @@
exit(1);
}
+ // Seed the registry from the global registrations so that
+ // preloadDialectsInContext has dialects to preload.
+ mlir::DialectRegistry registry;
+ mlir::getGlobalDialectRegistry().appendTo(registry);
return failed(mlir::MlirOptMain(output->os(), std::move(file), passPipeline,
- splitInputFile, verifyDiagnostics,
- verifyPasses, allowUnregisteredDialects));
+ registry, splitInputFile, verifyDiagnostics,
+ verifyPasses, allowUnregisteredDialects,
+ /*preloadDialectsInContext=*/true));
}
diff --git a/iree/tools/run_mlir_main.cc b/iree/tools/run_mlir_main.cc
index 144e7c0..289ac5f 100644
--- a/iree/tools/run_mlir_main.cc
+++ b/iree/tools/run_mlir_main.cc
@@ -178,6 +178,7 @@
IREE_TRACE_SCOPE0("PrepareModule");
mlir::MLIRContext context;
+ context.loadAllGloballyRegisteredDialects();
// Parse input MLIR module.
llvm::SourceMgr source_mgr;
diff --git a/iree/tools/translate_main.cc b/iree/tools/translate_main.cc
index 50edac7..3ecd201 100644
--- a/iree/tools/translate_main.cc
+++ b/iree/tools/translate_main.cc
@@ -109,6 +109,7 @@
auto processBuffer = [&](std::unique_ptr<llvm::MemoryBuffer> ownedBuffer,
llvm::raw_ostream &os) {
mlir::MLIRContext context;
+ context.loadAllGloballyRegisteredDialects();
llvm::SourceMgr sourceMgr;
sourceMgr.AddNewSourceBuffer(std::move(ownedBuffer), llvm::SMLoc());
mlir::SourceMgrDiagnosticHandler diagHandler(sourceMgr, &context);
diff --git a/third_party/llvm-project b/third_party/llvm-project
index e75bc5c..bad7d6b 160000
--- a/third_party/llvm-project
+++ b/third_party/llvm-project
@@ -1 +1 @@
-Subproject commit e75bc5c791e0e8dbe79f7453e55af9e8d03c9cc0
+Subproject commit bad7d6b3735d1d855ffb07f32a272049cff085e6
diff --git a/third_party/tensorflow b/third_party/tensorflow
index 0e65b3a..051ed1c 160000
--- a/third_party/tensorflow
+++ b/third_party/tensorflow
@@ -1 +1 @@
-Subproject commit 0e65b3a903cb5f0457d4972cd9ab1b1b8fa98e4d
+Subproject commit 051ed1cbfdefac1404cbfe0c2b1dd6e13c4e8fbd