Sanitize LLVM/AOT temporary file name components.
* Fixes #4439
diff --git a/bindings/python/tests/compiler_core_test.py b/bindings/python/tests/compiler_core_test.py
index 54c15bd..6ce6ea7 100644
--- a/bindings/python/tests/compiler_core_test.py
+++ b/bindings/python/tests/compiler_core_test.py
@@ -44,6 +44,15 @@
logging.info("Flatbuffer size = %d", len(binary))
self.assertTrue(binary)
+ # Compiling the string form means that the compiler does not have a valid
+ # source file name, which can cause issues on the AOT side. Verify
+ # specifically. See: https://github.com/google/iree/issues/4439
+ def testCompileStrLLVMAOT(self):
+ binary = compiler.compile_str(SIMPLE_MUL_ASM,
+ target_backends=["dylib-llvm-aot"])
+ logging.info("Flatbuffer size = %d", len(binary))
+ self.assertTrue(binary)
+
def testCompileInputFile(self):
with tempfile.NamedTemporaryFile("wt", delete=False) as f:
try:
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/LinkerTool.cpp b/iree/compiler/Dialect/HAL/Target/LLVM/LinkerTool.cpp
index bda6d8a..4d18971 100644
--- a/iree/compiler/Dialect/HAL/Target/LLVM/LinkerTool.cpp
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/LinkerTool.cpp
@@ -21,11 +21,30 @@
namespace IREE {
namespace HAL {
+// Sanitizes potentially user provided portions of a file name by replacing
+// all but a small set of alpha numeric and safe punctuation characters with
+// '_'. This is intended for components of temporary files that are uniqued
+// independently, where the input is meant to aid debugability but does not
+// need to be retained verbatim.
+static void sanitizeFilePart(llvm::SmallVectorImpl<char> &part) {
+ for (char &c : part) {
+ if ((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') ||
+ (c >= '0' && c <= '9') || c == '_' || c == '-' || c == '.')
+ continue;
+ c = '_';
+ }
+}
+
// static
Artifact Artifact::createTemporary(StringRef prefix, StringRef suffix) {
+ llvm::SmallString<8> prefixCopy(prefix);
+ llvm::SmallString<8> suffixCopy(suffix);
+ sanitizeFilePart(prefixCopy);
+ sanitizeFilePart(suffixCopy);
+
llvm::SmallString<32> filePath;
- if (std::error_code error =
- llvm::sys::fs::createTemporaryFile(prefix, suffix, filePath)) {
+ if (std::error_code error = llvm::sys::fs::createTemporaryFile(
+ prefixCopy, suffixCopy, filePath)) {
llvm::errs() << "failed to generate temporary file: " << error.message();
return {};
}