Add a CLI `iree-ir-tool` with a command to strip data. (#14636)
Been meaning to do this for a while in order to have a place to stash
more power-user style things that core developers typically use iree-opt
for.
This one adds a `strip-data` sub-command which uses the passes from
#14627 to systematically replace tensor constants with synthetic values.
With an ASAN/asserts build, this was able to strip a 7GiB int4 vicuna
MLIR file in ~5s and a 23GiB h2ogpt model in about 30s (the latter has
some other characteristics which make it more expensive to load as well
as being bigger). Results were a 2.6MiB and 1.4MiB MLIR file
respectively, consisting just of program IR and annotations for
synthetic data.
Getting the opt pipeline right for arbitrary input is a bit tricky, so I
decided we should just armor this into a tool
From installed packages, this can be used as:
```
iree-ir-tool strip-data input.mlir -o output.mlir
```
From a build tree with Python setup:
```
python -m iree.compiler.tools.ir_tool strip-data input.mlir -o output.mlir
```
Required adding some additional compiler APIs:
* `ireeCompilerInvocationRunPassPipeline` to run an arbitrary textual
pass pipeline on an invocation.
* `ireeCompilerInvocationOutputIRBytecode` to emit bytecode from an
invocation.
diff --git a/compiler/bindings/c/iree/compiler/embedding_api.h b/compiler/bindings/c/iree/compiler/embedding_api.h
index d56dbcf..2cb5edd 100644
--- a/compiler/bindings/c/iree/compiler/embedding_api.h
+++ b/compiler/bindings/c/iree/compiler/embedding_api.h
@@ -253,11 +253,26 @@
ireeCompilerInvocationPipeline(iree_compiler_invocation_t *inv,
enum iree_compiler_pipeline_t pipeline);
+// Runs an arbitrary pass pipeline.
+// Returns false and emits diagnostics on failure.
+// Available since: 1.4
+IREE_EMBED_EXPORTED bool
+ireeCompilerInvocationRunPassPipeline(iree_compiler_invocation_t *inv,
+ const char *textPassPipeline);
+
// Outputs the current compiler state as textual IR to the output.
IREE_EMBED_EXPORTED iree_compiler_error_t *
ireeCompilerInvocationOutputIR(iree_compiler_invocation_t *inv,
iree_compiler_output_t *output);
+// Outputs the current compiler state as bytecode IR to the output.
+// Emits as the given bytecode version or most recent if -1.
+// Available since: 1.4
+IREE_EMBED_EXPORTED iree_compiler_error_t *
+ireeCompilerInvocationOutputIRBytecode(iree_compiler_invocation_t *inv,
+ iree_compiler_output_t *output,
+ int bytecodeVersion);
+
// Assuming that the compiler has produced VM IR, converts it to bytecode
// and outputs it. This is a valid next step after running the
// IREE_COMPILER_PIPELINE_STD pipeline.
diff --git a/compiler/bindings/c/iree/compiler/loader/handle_symbols.inc b/compiler/bindings/c/iree/compiler/loader/handle_symbols.inc
index 60ff4ce..665a031 100644
--- a/compiler/bindings/c/iree/compiler/loader/handle_symbols.inc
+++ b/compiler/bindings/c/iree/compiler/loader/handle_symbols.inc
@@ -24,7 +24,9 @@
HANDLE_SYMBOL(ireeCompilerInvocationSetCompileToPhase)
HANDLE_SYMBOL(ireeCompilerInvocationSetVerifyIR)
HANDLE_SYMBOL(ireeCompilerInvocationPipeline)
+HANDLE_VERSIONED_SYMBOL(ireeCompilerInvocationRunPassPipeline, 1, 4)
HANDLE_SYMBOL(ireeCompilerInvocationOutputIR)
+HANDLE_VERSIONED_SYMBOL(ireeCompilerInvocationOutputIRBytecode, 1, 4)
HANDLE_SYMBOL(ireeCompilerInvocationOutputVMBytecode)
HANDLE_SYMBOL(ireeCompilerInvocationOutputVMCSource)
HANDLE_SYMBOL(ireeCompilerInvocationOutputHALExecutable)
diff --git a/compiler/bindings/c/iree/compiler/loader/loader.cpp b/compiler/bindings/c/iree/compiler/loader/loader.cpp
index b4fd614..0469b2f 100644
--- a/compiler/bindings/c/iree/compiler/loader/loader.cpp
+++ b/compiler/bindings/c/iree/compiler/loader/loader.cpp
@@ -275,6 +275,11 @@
return __ireeCompilerInvocationPipeline(run, pipeline);
}
+bool ireeCompilerInvocationRunPassPipeline(iree_compiler_invocation_t *inv,
+ const char *textPassPipeline) {
+ return __ireeCompilerInvocationRunPassPipeline(inv, textPassPipeline);
+}
+
iree_compiler_error_t *
ireeCompilerInvocationOutputIR(iree_compiler_invocation_t *run,
iree_compiler_output_t *output) {
@@ -282,6 +287,13 @@
}
iree_compiler_error_t *
+ireeCompilerInvocationOutputIRBytecode(iree_compiler_invocation_t *inv,
+ iree_compiler_output_t *output,
+ int bytecodeVersion) {
+ return __ireeCompilerInvocationOutputIRBytecode(inv, output, bytecodeVersion);
+}
+
+iree_compiler_error_t *
ireeCompilerInvocationOutputVMBytecode(iree_compiler_invocation_t *run,
iree_compiler_output_t *output) {
return __ireeCompilerInvocationOutputVMBytecode(run, output);
diff --git a/compiler/bindings/python/CMakeLists.txt b/compiler/bindings/python/CMakeLists.txt
index b3566bb..206c085 100644
--- a/compiler/bindings/python/CMakeLists.txt
+++ b/compiler/bindings/python/CMakeLists.txt
@@ -60,6 +60,7 @@
SOURCES_GLOB
api/*.py
tools/*.py
+ tools/ir_tool/*.py
)
################################################################################
diff --git a/compiler/bindings/python/iree/compiler/api/ctypes_dl.py b/compiler/bindings/python/iree/compiler/api/ctypes_dl.py
index ee14ee1..40be17a 100644
--- a/compiler/bindings/python/iree/compiler/api/ctypes_dl.py
+++ b/compiler/bindings/python/iree/compiler/api/ctypes_dl.py
@@ -54,8 +54,14 @@
_setsig(_dylib.ireeCompilerInvocationEnableConsoleDiagnostics, None, [c_void_p])
_setsig(_dylib.ireeCompilerInvocationParseSource, c_bool, [c_void_p, c_void_p])
_setsig(_dylib.ireeCompilerInvocationPipeline, c_bool, [c_void_p, c_int])
+ _setsig(_dylib.ireeCompilerInvocationRunPassPipeline, c_bool, [c_void_p, c_char_p])
_setsig(_dylib.ireeCompilerInvocationOutputIR, c_void_p, [c_void_p, c_void_p])
_setsig(
+ _dylib.ireeCompilerInvocationOutputIRBytecode,
+ c_void_p,
+ [c_void_p, c_void_p, c_int],
+ )
+ _setsig(
_dylib.ireeCompilerInvocationOutputVMBytecode, c_void_p, [c_void_p, c_void_p]
)
@@ -328,6 +334,10 @@
# Invocation.
self._retained_module_op = None
+ @property
+ def session(self) -> Session:
+ return self._session
+
def __del__(self):
self.close()
@@ -365,11 +375,23 @@
) -> bool:
return _dylib.ireeCompilerInvocationPipeline(self._inv_p, pipeline)
+ def execute_text_pass_pipeline(self, text_pipeline_spec: str) -> bool:
+ return _dylib.ireeCompilerInvocationRunPassPipeline(
+ self._inv_p, text_pipeline_spec.encode()
+ )
+
def output_ir(self, output: Output):
_handle_error(
_dylib.ireeCompilerInvocationOutputIR(self._inv_p, output._output_p)
)
+ def output_ir_bytecode(self, output: Output, bytecode_version: int = -1):
+ _handle_error(
+ _dylib.ireeCompilerInvocationOutputIRBytecode(
+ self._inv_p, output._output_p, bytecode_version
+ )
+ )
+
def output_vm_bytecode(self, output: Output):
_handle_error(
_dylib.ireeCompilerInvocationOutputVMBytecode(self._inv_p, output._output_p)
diff --git a/compiler/bindings/python/iree/compiler/tools/ir_tool/__main__.py b/compiler/bindings/python/iree/compiler/tools/ir_tool/__main__.py
new file mode 100644
index 0000000..4448c94
--- /dev/null
+++ b/compiler/bindings/python/iree/compiler/tools/ir_tool/__main__.py
@@ -0,0 +1,106 @@
+# Copyright 2023 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import argparse
+import logging
+import sys
+
+from ...api import Invocation, Session, Source, Output
+
+
+def load_source(inv: Invocation, input_file: str) -> Source:
+ source = Source.open_file(inv.session, input_file)
+ if not inv.parse_source(source):
+ raise RuntimeError(f"Error parsing source file {input_file}")
+ return source
+
+
+def write_output(inv: Invocation, output: Output, args, keep: bool = True):
+ if args.emit_bytecode:
+ inv.output_ir_bytecode(output, args.bytecode_version)
+ else:
+ inv.output_ir(output)
+ if keep:
+ output.keep()
+
+
+###############################################################################
+# CLI handling
+###############################################################################
+
+
+def parse_arguments(argv=None):
+ parser = argparse.ArgumentParser(description="IREE IR Tool")
+ subparsers = parser.add_subparsers(
+ help="sub-command help", required=True, dest="sub_command"
+ )
+
+ def add_ouptut_options(subparser):
+ subparser.add_argument(
+ "--emit-bytecode", action="store_true", help="Emit bytecode"
+ )
+ subparser.add_argument(
+ "--bytecode-version",
+ default=-1,
+ type=int,
+ help="Bytecode version to emit or -1 for latest",
+ )
+
+ # strip-data command.
+ strip_data_parser = subparsers.add_parser(
+ "strip-data",
+ help="Strip large constants and values, "
+ "replacing them with pseudo data suitable for interactive "
+ "debugging of IR",
+ )
+ add_ouptut_options(strip_data_parser)
+ strip_data_parser.add_argument(
+ "--no-import",
+ action="store_true",
+ help="Disable import of public dialects to internal",
+ )
+ strip_data_parser.add_argument("input_file", help="File to process")
+ strip_data_parser.add_argument(
+ "-o", required=True, dest="output_file", help="Output file"
+ )
+ args = parser.parse_args(argv)
+ return args
+
+
+def main(args) -> int:
+ if args.sub_command == "strip-data":
+ return do_strip_data(args)
+ else:
+ print("error: Unrecognized sub-command {args.sub_command}", file=sys.stderr)
+ return 1
+ return 0
+
+
+def do_strip_data(args) -> int:
+ session = Session()
+ output = Output.open_file(args.output_file)
+ inv = session.invocation()
+ inv.enable_console_diagnostics()
+ load_source(inv, args.input_file)
+ if not args.no_import:
+ if not inv.execute_text_pass_pipeline(
+ "iree-import-public, iree-import-ml-program"
+ ):
+ return 1
+ if not inv.execute_text_pass_pipeline(
+ "iree-util-outline-constants, iree-util-strip-and-splat-constants"
+ ):
+ return 2
+ write_output(inv, output, args)
+ return 0
+
+
+def _cli_main():
+ sys.exit(main(parse_arguments()))
+
+
+if __name__ == "__main__":
+ _cli_main()
diff --git a/compiler/bindings/python/test/tools/CMakeLists.txt b/compiler/bindings/python/test/tools/CMakeLists.txt
index 427116b..01456f9 100644
--- a/compiler/bindings/python/test/tools/CMakeLists.txt
+++ b/compiler/bindings/python/test/tools/CMakeLists.txt
@@ -17,6 +17,13 @@
iree_py_test(
NAME
+ ir_tool_test
+ SRCS
+ "ir_tool_test.py"
+)
+
+iree_py_test(
+ NAME
compiler_tf_test
SRCS
"compiler_tf_test.py"
diff --git a/compiler/bindings/python/test/tools/ir_tool_test.py b/compiler/bindings/python/test/tools/ir_tool_test.py
new file mode 100644
index 0000000..2a7a538
--- /dev/null
+++ b/compiler/bindings/python/test/tools/ir_tool_test.py
@@ -0,0 +1,114 @@
+# Copyright 2023 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from iree.compiler.tools.ir_tool import __main__
+
+import os
+import tempfile
+import unittest
+
+
+def run_tool(*argv: str):
+ try:
+ args = __main__.parse_arguments(list(argv))
+ __main__.main(args)
+ except SystemExit as e:
+ if e.code != 0:
+ raise RuntimeError(f"Tool exited with code {e.code}")
+
+
+class IrToolTest(unittest.TestCase):
+ def setUp(self):
+ with tempfile.NamedTemporaryFile(delete=False) as f:
+ self.inputPath = f.name
+ with tempfile.NamedTemporaryFile(delete=False) as f:
+ self.outputPath = f.name
+
+ def tearDown(self) -> None:
+ if os.path.exists(self.inputPath):
+ os.unlink(self.inputPath)
+ if os.path.exists(self.outputPath):
+ os.unlink(self.outputPath)
+
+ def saveInput(self, contents, text=True):
+ with open(self.inputPath, "wt" if text else "wb") as f:
+ f.write(contents)
+
+ def loadOutput(self, text=True):
+ with open(self.outputPath, "rt" if text else "rb") as f:
+ return f.read()
+
+ def testStripDataWithImport(self):
+ self.saveInput(
+ r"""
+ builtin.module {
+ func.func @main() -> tensor<4xf32> {
+ %0 = arith.constant dense<[0.1, 0.2, 0.3, 0.4]> : tensor<4xf32>
+ func.return %0 : tensor<4xf32>
+ }
+ }
+ """
+ )
+ run_tool("strip-data", self.inputPath, "-o", self.outputPath)
+ output = self.loadOutput()
+ print("Output:", output)
+ self.assertIn("#util.byte_pattern", output)
+
+ def testStripDataNoImport(self):
+ # Without import, ml_program.global is not recognized.
+ self.saveInput(
+ r"""
+ builtin.module {
+ ml_program.global public @foobar(dense<[0.1, 0.2, 0.3, 0.4]> : tensor<4xf32>) : tensor<4xf32>
+ }
+ """
+ )
+ run_tool("strip-data", "--no-import", self.inputPath, "-o", self.outputPath)
+ output = self.loadOutput()
+ print("Output:", output)
+ self.assertNotIn("#util.byte_pattern", output)
+
+ def testStripDataParseError(self):
+ self.saveInput(
+ r"""
+ FOOBAR
+ """
+ )
+ with self.assertRaisesRegex(RuntimeError, "Error parsing source file"):
+ run_tool("strip-data", self.inputPath, "-o", self.outputPath)
+
+ def testStripDataEmitBytecode(self):
+ self.saveInput(
+ r"""
+ builtin.module {
+ }
+ """
+ )
+ run_tool("strip-data", "--emit-bytecode", self.inputPath, "-o", self.outputPath)
+ output = self.loadOutput(text=False)
+ self.assertIn(b"MLIR", output)
+
+ def testStripDataEmitBytecodeVersion(self):
+ self.saveInput(
+ r"""
+ builtin.module {
+ }
+ """
+ )
+ run_tool(
+ "strip-data",
+ "--emit-bytecode",
+ "--bytecode-version=0",
+ self.inputPath,
+ "-o",
+ self.outputPath,
+ )
+ output = self.loadOutput(text=False)
+ self.assertIn(b"MLIR", output)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/compiler/setup.py b/compiler/setup.py
index 37a3aa1..ea1859f 100644
--- a/compiler/setup.py
+++ b/compiler/setup.py
@@ -455,6 +455,7 @@
# TODO: We have renamed to iree-compile on 2022-03-18. Remove
# this alias once no longer needed.
"ireec = iree.compiler.tools.scripts.ireec.__main__:main",
+ "iree-ir-tool = iree.compiler.tools.ir_tool.__main__:_cli_main",
],
},
install_requires=[
diff --git a/compiler/src/iree/compiler/API/Internal/BUILD.bazel b/compiler/src/iree/compiler/API/Internal/BUILD.bazel
index f9cec5c..ab6a55f 100644
--- a/compiler/src/iree/compiler/API/Internal/BUILD.bazel
+++ b/compiler/src/iree/compiler/API/Internal/BUILD.bazel
@@ -38,6 +38,7 @@
"//compiler/src/iree/compiler/Utils",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:BuiltinToLLVMIRTranslation",
+ "@llvm-project//mlir:BytecodeWriter",
"@llvm-project//mlir:CAPIIR",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Parser",
diff --git a/compiler/src/iree/compiler/API/Internal/CMakeLists.txt b/compiler/src/iree/compiler/API/Internal/CMakeLists.txt
index aabf542..d14608f 100644
--- a/compiler/src/iree/compiler/API/Internal/CMakeLists.txt
+++ b/compiler/src/iree/compiler/API/Internal/CMakeLists.txt
@@ -21,6 +21,7 @@
DEPS
LLVMSupport
MLIRBuiltinToLLVMIRTranslation
+ MLIRBytecodeWriter
MLIRCAPIIR
MLIRIR
MLIRParser
diff --git a/compiler/src/iree/compiler/API/Internal/Embed.cpp b/compiler/src/iree/compiler/API/Internal/Embed.cpp
index 044cd69..36801a9 100644
--- a/compiler/src/iree/compiler/API/Internal/Embed.cpp
+++ b/compiler/src/iree/compiler/API/Internal/Embed.cpp
@@ -63,6 +63,7 @@
#include "llvm/Support/Signals.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/ToolOutputFile.h"
+#include "mlir/Bytecode/BytecodeWriter.h"
#include "mlir/CAPI/IR.h"
#include "mlir/CAPI/Wrap.h"
#include "mlir/IR/AsmState.h"
@@ -84,7 +85,7 @@
#endif
#define IREE_COMPILER_API_MAJOR 1
-#define IREE_COMPILER_API_MINOR 3
+#define IREE_COMPILER_API_MINOR 4
namespace mlir::iree_compiler::embed {
namespace {
@@ -529,19 +530,23 @@
// Invocation corresponds to iree_compiler_invocation_t
struct Invocation {
+ using PassManagerInitializer = std::function<void(PassManager &pm)>;
Invocation(Session &session);
~Invocation();
bool initializeInvocation();
+ std::unique_ptr<PassManager> createPassManager();
bool parseSource(Source &source);
bool importModule(Operation *inputModule, bool steal);
bool runPipeline(enum iree_compiler_pipeline_t pipeline);
+ bool runTextualPassPipeline(const char *textPassPipeline);
Error *outputIR(Output &output);
+ Error *outputIRBytecode(Output &output, int bytecodeVersion);
Error *outputVMBytecode(Output &output);
Error *outputVMCSource(Output &output);
Error *outputHALExecutable(Output &output);
Session &session;
- PassManager passManager;
+ llvm::SmallVector<PassManagerInitializer> passManagerInitializers;
IREEVMPipelineHooks pipelineHooks;
// Diagnostic handlers are instantiated upon parsing the source (when we
@@ -567,17 +572,7 @@
int diagnosticCallbackFlags = 0;
};
-Invocation::Invocation(Session &session)
- : session(session), passManager(&session.context) {
- if (session.globalInit.usesCommandLine) {
- if (failed(mlir::applyPassManagerCLOptions(passManager))) {
- emitError(UnknownLoc::get(&session.context))
- << "Failed to apply pass manager CL options";
- }
- mlir::applyDefaultTimingPassManagerCLOptions(passManager);
- }
- passManager.addInstrumentation(std::make_unique<PassTracing>());
-
+Invocation::Invocation(Session &session) : session(session) {
// Since the jitter invokes much of the top-level compiler recursively,
// it must be injected at the top-level here vs in the pass pipeline
// (or else the circular dependency cannot be resolved).
@@ -597,6 +592,23 @@
}
}
+std::unique_ptr<PassManager> Invocation::createPassManager() {
+ auto passManager = std::make_unique<PassManager>(&session.context);
+ if (session.globalInit.usesCommandLine) {
+ if (failed(mlir::applyPassManagerCLOptions(*passManager))) {
+ emitError(UnknownLoc::get(&session.context))
+ << "Failed to apply pass manager CL options";
+ }
+ mlir::applyDefaultTimingPassManagerCLOptions(*passManager);
+ }
+ passManager->addInstrumentation(std::make_unique<PassTracing>());
+ passManager->enableVerifier(enableVerifier);
+ for (auto &init : passManagerInitializers) {
+ init(*passManager);
+ }
+ return passManager;
+}
+
bool Invocation::initializeInvocation() {
// Initialize callback diagnostics.
if (diagnosticCallback && !callbackDiagnosticHandler) {
@@ -697,6 +709,7 @@
}
bool Invocation::runPipeline(enum iree_compiler_pipeline_t pipeline) {
+ auto passManager = createPassManager();
switch (pipeline) {
case IREE_COMPILER_PIPELINE_STD: {
// Parse the compile to phase name.
@@ -745,7 +758,7 @@
session.targetRegistry, session.bindingOptions, session.inputOptions,
session.preprocessingOptions, session.highLevelOptimizationOptions,
session.schedulingOptions, session.halTargetOptions,
- session.vmTargetOptions, pipelineHooks, passManager, *compileFromPhase,
+ session.vmTargetOptions, pipelineHooks, *passManager, *compileFromPhase,
*compileToPhase);
break;
}
@@ -764,7 +777,7 @@
return false;
}
IREE::HAL::buildHALTransformPassPipeline(
- passManager, session.targetRegistry, session.halTargetOptions);
+ *passManager, session.targetRegistry, session.halTargetOptions);
break;
}
default:
@@ -772,8 +785,18 @@
return false;
}
- passManager.enableVerifier(enableVerifier);
- if (failed(passManager.run(parsedModule))) {
+ if (failed(passManager->run(parsedModule))) {
+ return false;
+ }
+ return true;
+}
+
+bool Invocation::runTextualPassPipeline(const char *textPassPipeline) {
+ auto passManager = createPassManager();
+ if (failed(mlir::parsePassPipeline(textPassPipeline, *passManager,
+ llvm::errs())))
+ return false;
+ if (failed(passManager->run(parsedModule))) {
return false;
}
return true;
@@ -784,6 +807,17 @@
return output.getWriteError();
}
+Error *Invocation::outputIRBytecode(Output &output, int bytecodeVersion) {
+ mlir::BytecodeWriterConfig config;
+ if (bytecodeVersion >= 0)
+ config.setDesiredBytecodeVersion(bytecodeVersion);
+ if (failed(mlir::writeBytecodeToFile(parsedModule, *output.outputStream,
+ config))) {
+ return new Error("illegal bytecode version requested");
+ }
+ return output.getWriteError();
+}
+
Error *Invocation::outputVMBytecode(Output &output) {
auto vmModule = llvm::dyn_cast<IREE::VM::ModuleOp>(*parsedModule);
auto builtinModule = llvm::dyn_cast<mlir::ModuleOp>(*parsedModule);
@@ -1134,24 +1168,27 @@
iree_compiler_output_t *output;
};
- unwrap(inv)->passManager.enableCrashReproducerGeneration(
- [=](std::string &errorMessage)
- -> std::unique_ptr<mlir::PassManager::ReproducerStream> {
- iree_compiler_output_t *output = nullptr;
- auto error = onCrashCallback(&output, userData);
- if (error) {
- errorMessage = ireeCompilerErrorGetMessage(error);
- return nullptr;
- }
+ unwrap(inv)->passManagerInitializers.push_back(
+ [=](mlir::PassManager &passManager) {
+ passManager.enableCrashReproducerGeneration(
+ [=](std::string &errorMessage)
+ -> std::unique_ptr<mlir::PassManager::ReproducerStream> {
+ iree_compiler_output_t *output = nullptr;
+ auto error = onCrashCallback(&output, userData);
+ if (error) {
+ errorMessage = ireeCompilerErrorGetMessage(error);
+ return nullptr;
+ }
- if (!output) {
- errorMessage = "callback did not set output";
- return nullptr;
- }
+ if (!output) {
+ errorMessage = "callback did not set output";
+ return nullptr;
+ }
- return std::make_unique<StreamImpl>(output);
- },
- /*genLocalReproducer=*/genLocalReproducer);
+ return std::make_unique<StreamImpl>(output);
+ },
+ /*genLocalReproducer=*/genLocalReproducer);
+ });
}
bool ireeCompilerInvocationParseSource(iree_compiler_invocation_t *inv,
@@ -1179,6 +1216,11 @@
return unwrap(inv)->runPipeline(pipeline);
}
+bool ireeCompilerInvocationRunPassPipeline(iree_compiler_invocation_t *inv,
+ const char *textPassPipeline) {
+ return unwrap(inv)->runTextualPassPipeline(textPassPipeline);
+}
+
void ireeCompilerSourceDestroy(iree_compiler_source_t *source) {
delete unwrap(source);
}
@@ -1263,6 +1305,13 @@
}
iree_compiler_error_t *
+ireeCompilerInvocationOutputIRBytecode(iree_compiler_invocation_t *inv,
+ iree_compiler_output_t *output,
+ int bytecodeVersion) {
+ return wrap(unwrap(inv)->outputIRBytecode(*unwrap(output), bytecodeVersion));
+}
+
+iree_compiler_error_t *
ireeCompilerInvocationOutputVMBytecode(iree_compiler_invocation_t *inv,
iree_compiler_output_t *output) {
return wrap(unwrap(inv)->outputVMBytecode(*unwrap(output)));
diff --git a/compiler/src/iree/compiler/API/api_exports.c b/compiler/src/iree/compiler/API/api_exports.c
index 5a3cae7..114d7ca 100644
--- a/compiler/src/iree/compiler/API/api_exports.c
+++ b/compiler/src/iree/compiler/API/api_exports.c
@@ -25,10 +25,12 @@
extern void ireeCompilerInvocationImportStealModule();
extern void ireeCompilerInvocationOutputHALExecutable();
extern void ireeCompilerInvocationOutputIR();
+extern void ireeCompilerInvocationOutputIRBytecode();
extern void ireeCompilerInvocationOutputVMBytecode();
extern void ireeCompilerInvocationOutputVMCSource();
extern void ireeCompilerInvocationParseSource();
extern void ireeCompilerInvocationPipeline();
+extern void ireeCompilerInvocationRunPassPipeline();
extern void ireeCompilerInvocationSetCompileFromPhase();
extern void ireeCompilerInvocationSetCompileToPhase();
extern void ireeCompilerInvocationSetCrashHandler();
@@ -649,6 +651,7 @@
extern void mlirValuePrint();
extern void mlirValuePrintAsOperand();
extern void mlirValueReplaceAllUsesOfWith();
+extern void mlirValueSetType();
extern void mlirVectorTypeGet();
extern void mlirVectorTypeGetChecked();
extern void mlirVectorTypeGetTypeID();
@@ -672,10 +675,12 @@
x += (uintptr_t)&ireeCompilerInvocationImportStealModule;
x += (uintptr_t)&ireeCompilerInvocationOutputHALExecutable;
x += (uintptr_t)&ireeCompilerInvocationOutputIR;
+ x += (uintptr_t)&ireeCompilerInvocationOutputIRBytecode;
x += (uintptr_t)&ireeCompilerInvocationOutputVMBytecode;
x += (uintptr_t)&ireeCompilerInvocationOutputVMCSource;
x += (uintptr_t)&ireeCompilerInvocationParseSource;
x += (uintptr_t)&ireeCompilerInvocationPipeline;
+ x += (uintptr_t)&ireeCompilerInvocationRunPassPipeline;
x += (uintptr_t)&ireeCompilerInvocationSetCompileFromPhase;
x += (uintptr_t)&ireeCompilerInvocationSetCompileToPhase;
x += (uintptr_t)&ireeCompilerInvocationSetCrashHandler;
@@ -1296,6 +1301,7 @@
x += (uintptr_t)&mlirValuePrint;
x += (uintptr_t)&mlirValuePrintAsOperand;
x += (uintptr_t)&mlirValueReplaceAllUsesOfWith;
+ x += (uintptr_t)&mlirValueSetType;
x += (uintptr_t)&mlirVectorTypeGet;
x += (uintptr_t)&mlirVectorTypeGetChecked;
x += (uintptr_t)&mlirVectorTypeGetTypeID;
diff --git a/compiler/src/iree/compiler/API/api_exports.def b/compiler/src/iree/compiler/API/api_exports.def
index 986cb46..b07a8a4 100644
--- a/compiler/src/iree/compiler/API/api_exports.def
+++ b/compiler/src/iree/compiler/API/api_exports.def
@@ -17,10 +17,12 @@
ireeCompilerInvocationImportStealModule
ireeCompilerInvocationOutputHALExecutable
ireeCompilerInvocationOutputIR
+ ireeCompilerInvocationOutputIRBytecode
ireeCompilerInvocationOutputVMBytecode
ireeCompilerInvocationOutputVMCSource
ireeCompilerInvocationParseSource
ireeCompilerInvocationPipeline
+ ireeCompilerInvocationRunPassPipeline
ireeCompilerInvocationSetCompileFromPhase
ireeCompilerInvocationSetCompileToPhase
ireeCompilerInvocationSetCrashHandler
@@ -641,6 +643,7 @@
mlirValuePrint
mlirValuePrintAsOperand
mlirValueReplaceAllUsesOfWith
+ mlirValueSetType
mlirVectorTypeGet
mlirVectorTypeGetChecked
mlirVectorTypeGetTypeID
diff --git a/compiler/src/iree/compiler/API/api_exports.ld b/compiler/src/iree/compiler/API/api_exports.ld
index a252dcb..f02718f 100644
--- a/compiler/src/iree/compiler/API/api_exports.ld
+++ b/compiler/src/iree/compiler/API/api_exports.ld
@@ -18,10 +18,12 @@
ireeCompilerInvocationImportStealModule;
ireeCompilerInvocationOutputHALExecutable;
ireeCompilerInvocationOutputIR;
+ ireeCompilerInvocationOutputIRBytecode;
ireeCompilerInvocationOutputVMBytecode;
ireeCompilerInvocationOutputVMCSource;
ireeCompilerInvocationParseSource;
ireeCompilerInvocationPipeline;
+ ireeCompilerInvocationRunPassPipeline;
ireeCompilerInvocationSetCompileFromPhase;
ireeCompilerInvocationSetCompileToPhase;
ireeCompilerInvocationSetCrashHandler;
@@ -642,6 +644,7 @@
mlirValuePrint;
mlirValuePrintAsOperand;
mlirValueReplaceAllUsesOfWith;
+ mlirValueSetType;
mlirVectorTypeGet;
mlirVectorTypeGetChecked;
mlirVectorTypeGetTypeID;
diff --git a/compiler/src/iree/compiler/API/api_exports.macos.lst b/compiler/src/iree/compiler/API/api_exports.macos.lst
index 07fb753..14283da 100644
--- a/compiler/src/iree/compiler/API/api_exports.macos.lst
+++ b/compiler/src/iree/compiler/API/api_exports.macos.lst
@@ -16,10 +16,12 @@
_ireeCompilerInvocationImportStealModule
_ireeCompilerInvocationOutputHALExecutable
_ireeCompilerInvocationOutputIR
+_ireeCompilerInvocationOutputIRBytecode
_ireeCompilerInvocationOutputVMBytecode
_ireeCompilerInvocationOutputVMCSource
_ireeCompilerInvocationParseSource
_ireeCompilerInvocationPipeline
+_ireeCompilerInvocationRunPassPipeline
_ireeCompilerInvocationSetCompileFromPhase
_ireeCompilerInvocationSetCompileToPhase
_ireeCompilerInvocationSetCrashHandler
@@ -640,6 +642,7 @@
_mlirValuePrint
_mlirValuePrintAsOperand
_mlirValueReplaceAllUsesOfWith
+_mlirValueSetType
_mlirVectorTypeGet
_mlirVectorTypeGetChecked
_mlirVectorTypeGetTypeID