Synchronize submodules with LLVM at llvm/llvm-project@6d8d1338629c Updates LLVM dependencies to match [6d8d1338629c](https://github.com/llvm/llvm-project/commit/6d8d1338629c). - llvm-bazel to [52812526859e](https://github.com/google/llvm-bazel/commit/52812526859e) - TensorFlow to [2a5146c8a03c](https://github.com/tensorflow/tensorflow/commit/2a5146c8a03c) - MLIR-HLO to [f4414fcd666b](https://github.com/tensorflow/mlir-hlo/commit/f4414fcd666b) `./scripts/git/update_to_llvm_syncpoint.py` Automated submodule bump from .github/workflows/update_llvm_dependent_submodules.yml PiperOrigin-RevId: 371125346
diff --git a/bindings/python/iree/jax/README.md b/bindings/python/iree/jax/README.md index 557cfee..a4fac75 100644 --- a/bindings/python/iree/jax/README.md +++ b/bindings/python/iree/jax/README.md
@@ -95,8 +95,12 @@ adb push /tmp/mlp_apply.vmfb /data/local/tmp/ adb push ../iree-build-android/iree/tools/iree-run-module /data/local/tmp/ adb shell /data/local/tmp/iree-run-module \ - -module_file=/data/local/tmp/mlp_apply.vmfb \ - -function_inputs="128xf32,784x128xf32,10xf32,128x10xf32,1x28x28x1xf32" \ - -driver=dylib \ - -entry_function=main + --driver=dylib \ + --module_file=/data/local/tmp/mlp_apply.vmfb \ + --entry_function=main \ + --function_input=128xf32 \ + --function_input=784x128xf32 \ + --function_input=10xf32 \ + --function_input=128x10xf32 \ + --function_input=1x28x28x1xf32 ```
diff --git a/build_tools/cmake/iree_cc_binary.cmake b/build_tools/cmake/iree_cc_binary.cmake index d35a873..a20e550 100644 --- a/build_tools/cmake/iree_cc_binary.cmake +++ b/build_tools/cmake/iree_cc_binary.cmake
@@ -71,6 +71,11 @@ add_executable(${_NAME} "") add_executable(${_RULE_NAME} ALIAS ${_NAME}) + if(_RULE_OUT) + set_target_properties(${_NAME} PROPERTIES OUTPUT_NAME "${_RULE_OUT}") + else() + set_target_properties(${_NAME} PROPERTIES OUTPUT_NAME "${_RULE_NAME}") + endif() if(_RULE_SRCS) target_sources(${_NAME} PRIVATE @@ -84,11 +89,6 @@ ${_DUMMY_SRC} ) endif() - if(_RULE_OUT) - set_target_properties(${_NAME} PROPERTIES OUTPUT_NAME "${_RULE_OUT}") - else() - set_target_properties(${_NAME} PROPERTIES OUTPUT_NAME "${_RULE_NAME}") - endif() target_include_directories(${_NAME} SYSTEM PUBLIC "$<BUILD_INTERFACE:${IREE_SOURCE_DIR}>"
diff --git a/build_tools/cmake/iree_cc_test.cmake b/build_tools/cmake/iree_cc_test.cmake index 54456f1..72279dc 100644 --- a/build_tools/cmake/iree_cc_test.cmake +++ b/build_tools/cmake/iree_cc_test.cmake
@@ -72,6 +72,12 @@ set(_NAME "${_PACKAGE_NAME}_${_RULE_NAME}") add_executable(${_NAME} "") + add_executable(${_RULE_NAME} ALIAS ${_NAME}) + if(_RULE_OUT) + set_target_properties(${_NAME} PROPERTIES OUTPUT_NAME "${_RULE_OUT}") + else() + set_target_properties(${_NAME} PROPERTIES OUTPUT_NAME "${_RULE_NAME}") + endif() target_sources(${_NAME} PRIVATE ${_RULE_SRCS} @@ -131,7 +137,7 @@ ${_TEST_NAME} COMMAND "${CMAKE_SOURCE_DIR}/build_tools/cmake/run_android_test.${IREE_HOST_SCRIPT_EXT}" - "${_ANDROID_REL_DIR}/${_NAME}" + "${_ANDROID_REL_DIR}/$<TARGET_FILE_NAME:${_NAME}>" ) # Use environment variables to instruct the script to push artifacts # onto the Android device before running the test. This needs to match
diff --git a/build_tools/cmake/iree_lit_test.cmake b/build_tools/cmake/iree_lit_test.cmake index 6d18823..9d18495 100644 --- a/build_tools/cmake/iree_lit_test.cmake +++ b/build_tools/cmake/iree_lit_test.cmake
@@ -54,11 +54,13 @@ return() endif() + iree_package_ns(_PACKAGE_NS) iree_package_name(_PACKAGE_NAME) set(_NAME "${_PACKAGE_NAME}_${_RULE_NAME}") get_filename_component(_TEST_FILE_PATH ${_RULE_TEST_FILE} ABSOLUTE) + list(TRANSFORM _RULE_DATA REPLACE "^::" "${_PACKAGE_NS}::") set(_DATA_DEP_PATHS) foreach(_DATA_DEP ${_RULE_DATA}) string(REPLACE "::" "_" _DATA_DEP_NAME ${_DATA_DEP})
diff --git a/build_tools/mako/configuration.py b/build_tools/mako/configuration.py index 256848d..15d7cd3 100644 --- a/build_tools/mako/configuration.py +++ b/build_tools/mako/configuration.py
@@ -125,7 +125,7 @@ "--iree-flow-inline-constants-max-byte-length=2048", ], runtime_flags=[ - "--dylib_worker_count=3", + "--task_topology_group_count=3", ]), TargetInfo( driver="vulkan", @@ -178,7 +178,7 @@ "--iree-flow-inline-constants-max-byte-length=2048", ], runtime_flags=[ - "--dylib_worker_count=3", + "--task_topology_group_count=3", ]), TargetInfo( driver="vulkan",
diff --git a/docs/design_docs/cuda_backend.md b/docs/design_docs/cuda_backend.md index 5b84c25..eb0cac4 100644 --- a/docs/design_docs/cuda_backend.md +++ b/docs/design_docs/cuda_backend.md
@@ -79,15 +79,18 @@ # First translate into a VM bytecode module using linalg on tensors path. $ ../iree-build/iree/tools/iree-translate \ -iree-mlir-to-vm-bytecode-module \ - --iree-hal-target-backends=cuda \ + -iree-hal-target-backends=cuda \ -iree-flow-dispatch-linalg-on-tensors \ /tmp/add.mlir \ -o /tmp/mhlo-add.vmfb # Run the module through CUDA HAL backend. $ ../iree-build/iree/tools/iree-run-module \ --module_file=/tmp/mhlo-add.vmfb -driver=cuda -entry_function=add \ ---function_inputs='4xf32=[1 2 3 4], 4xf32=[2 2 2 2]' + --driver=cuda \ + --module_file=/tmp/mhlo-add.vmfb \ + --entry_function=add \ + --function_input="4xf32=[1 2 3 4]" \ + --function_input="4xf32=[2 2 2 2]" EXEC @add 4xf32=3 4 5 6
diff --git a/docs/developing_iree/benchmarking.md b/docs/developing_iree/benchmarking.md index b203c62..6956b67 100644 --- a/docs/developing_iree/benchmarking.md +++ b/docs/developing_iree/benchmarking.md
@@ -33,7 +33,7 @@ --module_file=/tmp/module.fb \ --driver=vmla \ --entry_function=abs \ - --function_inputs="i32=-2" + --function_input=i32=-2 ``` You'll see output like @@ -80,7 +80,7 @@ --module_file=/tmp/module.fb \ --driver=vmla \ --entry_function=abs \ - --function_inputs="i32=-2" + --function_input=i32=-2 ``` ```shell
diff --git a/docs/developing_iree/developer_overview.md b/docs/developing_iree/developer_overview.md index 13a3b69..e0caae5 100644 --- a/docs/developing_iree/developer_overview.md +++ b/docs/developing_iree/developer_overview.md
@@ -141,7 +141,7 @@ --module_file=/tmp/simple.vmfb \ --driver=vmla \ --entry_function=abs \ - --function_inputs="i32=-2" + --function_input=i32=-2 ``` ### iree-check-module
diff --git a/docs/developing_iree/profiling_vulkan_gpu.md b/docs/developing_iree/profiling_vulkan_gpu.md index b9421f0..8fb76b5 100644 --- a/docs/developing_iree/profiling_vulkan_gpu.md +++ b/docs/developing_iree/profiling_vulkan_gpu.md
@@ -45,10 +45,10 @@ # Then package the Android app $ /path/to/iree/source/iree/tools/android/run_module_app/build_apk.sh \ ./build-apk \ + --driver vulkan \ --module_file /tmp/mhlo-dot.vmfb \ --entry_function dot \ - --function_inputs_file /path/to/inputs/file \ - --driver vulkan + --function_input=... ``` Where `/path/to/input/file` is a file containing inputs to `dot`, for example:
diff --git a/docs/developing_iree/profiling_with_tracy.md b/docs/developing_iree/profiling_with_tracy.md index 4d5c58f..528d731 100644 --- a/docs/developing_iree/profiling_with_tracy.md +++ b/docs/developing_iree/profiling_with_tracy.md
@@ -137,9 +137,11 @@ ```shell TRACY_NO_EXIT=1 /data/local/tmp/iree-benchmark-module \ --driver=dylib \ - --function_inputs='1x384xi32,1x384xi32,1x384xi32' \ --module_file=/data/local/tmp/android_module.fbvm \ - --entry_function=serving_default + --entry_function=serving_default \ + --function_input=1x384xi32 \ + --function_input=1x384xi32 \ + --function_input=1x384xi32 ``` ## Running the Tracy profiler UI, connecting and visualizing
diff --git a/docs/get_started/getting_started_android_cmake.md b/docs/get_started/getting_started_android_cmake.md index 2420f83..8928f11 100644 --- a/docs/get_started/getting_started_android_cmake.md +++ b/docs/get_started/getting_started_android_cmake.md
@@ -164,10 +164,11 @@ $ adb shell android $ cd /data/local/tmp/ -android $ ./iree-run-module -driver=vmla \ - -module_file=simple-vmla.vmfb \ - -entry_function=abs \ - -function_inputs="i32=-5" +android $ ./iree-run-module \ + --driver=vmla \ + --module_file=simple-vmla.vmfb \ + --entry_function=abs \ + --function_input=i32=-5 EXEC @abs i32=5 @@ -202,10 +203,11 @@ $ adb shell android $ cd /data/local/tmp/ -android $ ./iree-run-module -driver=vulkan \ - -module_file=simple-vulkan.vmfb \ - -entry_function=abs \ - -function_inputs="i32=-5" +android $ ./iree-run-module \ + --driver=vulkan \ + --module_file=simple-vulkan.vmfb \ + --entry_function=abs \ + --function_input=i32=-5 EXEC @abs i32=5 @@ -279,10 +281,11 @@ $ adb shell android $ cd /data/local/tmp/ -android $ ./iree-run-module -driver=dylib \ - -module_file=simple-llvm_aot.vmfb \ - -entry_function=abs \ - -function_inputs="i32=-5" +android $ ./iree-run-module \ + --driver=dylib \ + --module_file=simple-llvm_aot.vmfb \ + --entry_function=abs \ + --function_input=i32=-5 EXEC @abs i32=5
diff --git a/docs/get_started/getting_started_linux_cmake.md b/docs/get_started/getting_started_linux_cmake.md index c0ec8d2..92a3dfd 100644 --- a/docs/get_started/getting_started_linux_cmake.md +++ b/docs/get_started/getting_started_linux_cmake.md
@@ -131,10 +131,11 @@ Then run the compiled module using the `dylib` HAL driver: ```shell -$ ../iree-build/iree/tools/iree-run-module -driver=dylib \ - -module_file=/tmp/simple-llvm_aot.vmfb \ - -entry_function=abs \ - -function_inputs="i32=-5" +$ ../iree-build/iree/tools/iree-run-module \ + --driver=dylib \ + --module_file=/tmp/simple-llvm_aot.vmfb \ + --entry_function=abs \ + --function_input=i32=-5 EXEC @abs i32=5
diff --git a/docs/get_started/getting_started_linux_vulkan.md b/docs/get_started/getting_started_linux_vulkan.md index c4ee13d..1264297 100644 --- a/docs/get_started/getting_started_linux_vulkan.md +++ b/docs/get_started/getting_started_linux_vulkan.md
@@ -134,10 +134,18 @@ ```shell # -- CMake -- $ cmake --build ../iree-build/ --target iree_tools_iree-run-module -$ ../iree-build/iree/tools/iree-run-module -module_file=/tmp/module.vmfb -driver=vulkan -entry_function=abs -function_inputs="i32=-2" +$ ../iree-build/iree/tools/iree-run-module \ + --driver=vulkan \ + --module_file=/tmp/module.vmfb \ + --entry_function=abs \ + --function_input=i32=-2 # -- Bazel -- -$ bazel run iree/tools:iree-run-module -- -module_file=/tmp/module.vmfb -driver=vulkan -entry_function=abs -function_inputs="i32=-2" +$ bazel run iree/tools:iree-run-module -- \ + --driver=vulkan \ + --module_file=/tmp/module.vmfb \ + --entry_function=abs \ + --function_input=i32=-2 ``` ## Running IREE's Vulkan Samples
diff --git a/docs/get_started/getting_started_riscv_cmake.md b/docs/get_started/getting_started_riscv_cmake.md index 276ca73..0e50fd0 100644 --- a/docs/get_started/getting_started_riscv_cmake.md +++ b/docs/get_started/getting_started_riscv_cmake.md
@@ -104,10 +104,11 @@ $ $HOME/riscv/qemu/linux/RISCV/bin/qemu-riscv64 \ -cpu rv64,x-v=true,x-k=true,vlen=256,elen=64,vext_spec=v1.0 \ -L $HOME/riscv/toolchain/clang/linux/RISCV/sysroot/ \ - ../iree-build-riscv/iree/tools/iree-run-module -driver=vmla \ - -module_file=/tmp/iree-run-module-vmla.vmfb \ - -entry_function=abs \ - -function_inputs="i32=-5" + ../iree-build-riscv/iree/tools/iree-run-module \ + --driver=vmla \ + --module_file=/tmp/iree-run-module-vmla.vmfb \ + --entry_function=abs \ + --function_input=i32=-5 ``` Output: @@ -147,10 +148,11 @@ $ $HOME/riscv/qemu/linux/RISCV/bin/qemu-riscv64 \ -cpu rv64 \ -L $HOME/riscv/toolchain/clang/linux/RISCV/sysroot/ \ - ../iree-build-riscv/iree/tools/iree-run-module -driver=dylib \ - -module_file=/tmp/iree-run-module-llvm_aot.vmfb \ - -entry_function=abs \ - -function_inputs="i32=-5" + ../iree-build-riscv/iree/tools/iree-run-module \ + --driver=dylib \ + --module_file=/tmp/iree-run-module-llvm_aot.vmfb \ + --entry_function=abs \ + --function_input=i32=-5 ``` Output:
diff --git a/docs/get_started/getting_started_windows_vulkan.md b/docs/get_started/getting_started_windows_vulkan.md index cb19e21..4b37faa 100644 --- a/docs/get_started/getting_started_windows_vulkan.md +++ b/docs/get_started/getting_started_windows_vulkan.md
@@ -132,10 +132,18 @@ ```powershell # -- CMake -- > cmake --build ..\iree-build\ --target iree_tools_iree-run-module -> ..\iree-build\iree\tools\iree-run-module.exe -module_file=.\build\module.vmfb -driver=vulkan -entry_function=abs -function_inputs="i32=-2" +> ..\iree-build\iree\tools\iree-run-module.exe \ + --driver=vulkan \ + --module_file=.\build\module.vmfb \ + --entry_function=abs \ + --function_input=i32=-2 # -- Bazel -- -> bazel run iree/tools:iree-run-module -- -module_file=.\build\module.vmfb -driver=vulkan -entry_function=abs -function_inputs="i32=-2" +> bazel run iree/tools:iree-run-module -- \ + --driver=vulkan \ + --module_file=.\build\module.vmfb \ + --entry_function=abs \ + --function_input=i32=-2 ``` ## Running IREE's Vulkan Samples
diff --git a/integrations/tensorflow/bindings/python/iree/tf/support/trace_utils.py b/integrations/tensorflow/bindings/python/iree/tf/support/trace_utils.py index f5ae8d3..5fa7eab 100644 --- a/integrations/tensorflow/bindings/python/iree/tf/support/trace_utils.py +++ b/integrations/tensorflow/bindings/python/iree/tf/support/trace_utils.py
@@ -291,13 +291,12 @@ compiled_path = self.compiled_paths[entry_function] if self.iree_serializable: - serialized_inputs = ", ".join(self.calls[0].serialized_inputs) + serialized_inputs = self.calls[0].serialized_inputs flagfile = [ f"--module_file={compiled_path}", f"--driver={self.backend_driver}", - f"--function_inputs={serialized_inputs}", f"--entry_function={entry_function}", - ] + ] + [f"--function_input={input}" for input in serialized_inputs] with open(os.path.join(trace_dir, "flagfile"), "w") as f: f.writelines(line + "\n" for line in flagfile) else:
diff --git a/integrations/tensorflow/e2e/slim_vision_models/BUILD b/integrations/tensorflow/e2e/slim_vision_models/BUILD index a9f7179..b91c95a 100644 --- a/integrations/tensorflow/e2e/slim_vision_models/BUILD +++ b/integrations/tensorflow/e2e/slim_vision_models/BUILD
@@ -141,6 +141,8 @@ "mobilenet_v2_100_224", "mobilenet_v2_130_224", "mobilenet_v2_140_224", + # MobileNetV3 + "mobilenet_v3_large_100_224", "nasnet_mobile", "nasnet_large", "pnasnet_large",
diff --git a/integrations/tensorflow/e2e/slim_vision_models/CMakeLists.txt b/integrations/tensorflow/e2e/slim_vision_models/CMakeLists.txt index b72c985..f1a6694 100644 --- a/integrations/tensorflow/e2e/slim_vision_models/CMakeLists.txt +++ b/integrations/tensorflow/e2e/slim_vision_models/CMakeLists.txt
@@ -23,7 +23,7 @@ "slim_vision_model_test.py" "tf" "https://tfhub.dev/google/imagenet/" - "amoebanet_a_n18_f448;inception_resnet_v2;inception_v1;inception_v2;inception_v3;mobilenet_v1_025_128;mobilenet_v1_025_160;mobilenet_v1_025_192;mobilenet_v1_025_224;mobilenet_v1_050_128;mobilenet_v1_050_160;mobilenet_v1_050_192;mobilenet_v1_050_224;mobilenet_v1_075_128;mobilenet_v1_075_160;mobilenet_v1_075_192;mobilenet_v1_075_224;mobilenet_v1_100_128;mobilenet_v1_100_160;mobilenet_v1_100_192;mobilenet_v1_100_224;mobilenet_v2_035_96;mobilenet_v2_035_128;mobilenet_v2_035_160;mobilenet_v2_035_192;mobilenet_v2_035_224;mobilenet_v2_050_96;mobilenet_v2_050_128;mobilenet_v2_050_160;mobilenet_v2_050_192;mobilenet_v2_050_224;mobilenet_v2_075_96;mobilenet_v2_075_128;mobilenet_v2_075_160;mobilenet_v2_075_192;mobilenet_v2_075_224;mobilenet_v2_100_96;mobilenet_v2_100_128;mobilenet_v2_100_160;mobilenet_v2_100_192;mobilenet_v2_100_224;mobilenet_v2_130_224;mobilenet_v2_140_224;nasnet_mobile;nasnet_large;pnasnet_large;resnet_v1_50;resnet_v1_101;resnet_v1_152;resnet_v2_50;resnet_v2_101;resnet_v2_152" + 
"amoebanet_a_n18_f448;inception_resnet_v2;inception_v1;inception_v2;inception_v3;mobilenet_v1_025_128;mobilenet_v1_025_160;mobilenet_v1_025_192;mobilenet_v1_025_224;mobilenet_v1_050_128;mobilenet_v1_050_160;mobilenet_v1_050_192;mobilenet_v1_050_224;mobilenet_v1_075_128;mobilenet_v1_075_160;mobilenet_v1_075_192;mobilenet_v1_075_224;mobilenet_v1_100_128;mobilenet_v1_100_160;mobilenet_v1_100_192;mobilenet_v1_100_224;mobilenet_v2_035_96;mobilenet_v2_035_128;mobilenet_v2_035_160;mobilenet_v2_035_192;mobilenet_v2_035_224;mobilenet_v2_050_96;mobilenet_v2_050_128;mobilenet_v2_050_160;mobilenet_v2_050_192;mobilenet_v2_050_224;mobilenet_v2_075_96;mobilenet_v2_075_128;mobilenet_v2_075_160;mobilenet_v2_075_192;mobilenet_v2_075_224;mobilenet_v2_100_96;mobilenet_v2_100_128;mobilenet_v2_100_160;mobilenet_v2_100_192;mobilenet_v2_100_224;mobilenet_v2_130_224;mobilenet_v2_140_224;mobilenet_v3_large_100_224;nasnet_mobile;nasnet_large;pnasnet_large;resnet_v1_50;resnet_v1_101;resnet_v1_152;resnet_v2_50;resnet_v2_101;resnet_v2_152" "tf;tflite;iree_vmla;iree_vulkan" FAILING_CONFIGURATIONS ",,,amoebanet_a_n18_f448,"
diff --git a/integrations/tensorflow/e2e/slim_vision_models/slim_vision_model_test.py b/integrations/tensorflow/e2e/slim_vision_models/slim_vision_model_test.py index f6d58c5..be762b5 100644 --- a/integrations/tensorflow/e2e/slim_vision_models/slim_vision_model_test.py +++ b/integrations/tensorflow/e2e/slim_vision_models/slim_vision_model_test.py
@@ -41,8 +41,6 @@ 'tf_hub_url', None, 'Base URL for the models to test. URL at the time of ' 'writing:\nhttps://tfhub.dev/google/imagenet/') -# Classification mode; 4 - is a format of the model (SavedModel TF v2). -MODE = 'classification/4' LARGE_MODELS = ['amoebanet_a_n18_f448', "nasnet_large", "pnasnet_large"] @@ -54,17 +52,28 @@ # from their TFHub name. size = int(FLAGS.model.split('_')[-1]) return (1, size, size, 3) + elif FLAGS.model.startswith('mobilenet_v3_large'): + size = int(FLAGS.model.split('_')[-1]) + return (1, size, size, 3) else: # Default input shape. return (1, 224, 224, 3) +def get_mode(model_name): + if model_name.startswith('mobilenet_v3'): + return 'classification/5' + # Classification mode; 4 - is a format of the model (SavedModel TF v2). + return 'classification/4' + + class SlimVisionModule(tf.Module): def __init__(self): super().__init__() tf_utils.set_random_seed() - model_path = posixpath.join(FLAGS.tf_hub_url, FLAGS.model, MODE) + model_path = posixpath.join(FLAGS.tf_hub_url, FLAGS.model, + get_mode(FLAGS.model)) hub_layer = hub.KerasLayer(model_path) self.m = tf.keras.Sequential([hub_layer]) input_shape = get_input_shape()
diff --git a/iree/base/BUILD b/iree/base/BUILD index 0f901c2..40423fc 100644 --- a/iree/base/BUILD +++ b/iree/base/BUILD
@@ -29,7 +29,13 @@ cc_library( name = "api", - srcs = ["api.c"], + srcs = [ + "allocator.c", + "api.c", + "status.c", + "string_view.c", + "time.c", + ], hdrs = ["api.h"], visibility = ["//visibility:public"], deps = [ @@ -93,7 +99,7 @@ deps = [ ":core_headers", ":tracing", - "@com_google_absl//absl/flags:flag", + "//iree/base/internal:flags", "@com_google_absl//absl/strings:str_format", ], ) @@ -137,6 +143,16 @@ ], ) +cc_test( + name = "string_view_test", + srcs = ["string_view_test.cc"], + deps = [ + ":api", + "//iree/testing:gtest", + "//iree/testing:gtest_main", + ], +) + cc_library( name = "synchronization", srcs = ["synchronization.c"],
diff --git a/iree/base/CMakeLists.txt b/iree/base/CMakeLists.txt index 380dae5..54ecc4b 100644 --- a/iree/base/CMakeLists.txt +++ b/iree/base/CMakeLists.txt
@@ -20,7 +20,11 @@ HDRS "api.h" SRCS + "allocator.c" "api.c" + "status.c" + "string_view.c" + "time.c" DEPS ::core_headers ::tracing @@ -80,8 +84,8 @@ DEPS ::core_headers ::tracing - absl::flags absl::str_format + iree::base::internal::flags PUBLIC ) @@ -131,6 +135,17 @@ iree::testing::gtest_main ) +iree_cc_test( + NAME + string_view_test + SRCS + "string_view_test.cc" + DEPS + ::api + iree::testing::gtest + iree::testing::gtest_main +) + iree_cc_library( NAME synchronization
diff --git a/iree/base/allocator.c b/iree/base/allocator.c new file mode 100644 index 0000000..3f52727 --- /dev/null +++ b/iree/base/allocator.c
@@ -0,0 +1,116 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iree/base/api.h"
+
+#include <stdlib.h>
+#include <string.h>
+
+#include "iree/base/target_platform.h"
+#include "iree/base/tracing.h"
+
+IREE_API_EXPORT iree_status_t IREE_API_CALL iree_allocator_malloc(
+    iree_allocator_t allocator, iree_host_size_t byte_length, void** out_ptr) {
+  if (!allocator.alloc) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "allocator has no alloc routine");
+  }
+  return allocator.alloc(allocator.self, IREE_ALLOCATION_MODE_ZERO_CONTENTS,
+                         byte_length, out_ptr);
+}
+
+IREE_API_EXPORT iree_status_t IREE_API_CALL iree_allocator_realloc(
+    iree_allocator_t allocator, iree_host_size_t byte_length, void** out_ptr) {
+  if (!allocator.alloc) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "allocator has no alloc routine");
+  }
+  return allocator.alloc(allocator.self,
+                         IREE_ALLOCATION_MODE_TRY_REUSE_EXISTING, byte_length,
+                         out_ptr);
+}
+
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_allocator_clone(iree_allocator_t allocator,
+                     iree_const_byte_span_t source_bytes, void** out_ptr) {
+  IREE_RETURN_IF_ERROR(
+      iree_allocator_malloc(allocator, source_bytes.data_length, out_ptr));
+  memcpy(*out_ptr, source_bytes.data, source_bytes.data_length);
+  return iree_ok_status();
+}
+
+IREE_API_EXPORT void IREE_API_CALL
+iree_allocator_free(iree_allocator_t allocator, void* ptr) {
+  if (ptr && allocator.free) {
+    allocator.free(allocator.self, ptr);
+  }
+}
+
+// Reallocates *out_ptr when non-NULL and TRY_REUSE_EXISTING is set, else
+// allocates fresh; ZERO_CONTENTS zero-fills the entire result in both cases.
+// *out_ptr is only written on success.
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_allocator_system_allocate(void* self, iree_allocation_mode_t mode,
+                               iree_host_size_t byte_length, void** out_ptr) {
+  IREE_ASSERT_ARGUMENT(out_ptr);
+  if (byte_length == 0) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "allocations must be >0 bytes");
+  }
+
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  void* existing_ptr = *out_ptr;
+  void* ptr = NULL;
+  if (existing_ptr && (mode & IREE_ALLOCATION_MODE_TRY_REUSE_EXISTING)) {
+    ptr = realloc(existing_ptr, byte_length);
+    if (ptr && (mode & IREE_ALLOCATION_MODE_ZERO_CONTENTS)) {
+      memset(ptr, 0, byte_length);
+    }
+  } else {
+    existing_ptr = NULL;
+    if (mode & IREE_ALLOCATION_MODE_ZERO_CONTENTS) {
+      ptr = calloc(1, byte_length);
+    } else {
+      ptr = malloc(byte_length);
+    }
+  }
+  if (!ptr) {
+    // realloc() failure leaves the original block intact and still owned by
+    // the caller. Close the zone opened above: this early return skips the
+    // IREE_TRACE_ZONE_END at the bottom of the function.
+    IREE_TRACE_ZONE_END(z0);
+    return iree_make_status(IREE_STATUS_RESOURCE_EXHAUSTED,
+                            "system allocator failed the request");
+  }
+
+  if (existing_ptr) {
+    IREE_TRACE_FREE(existing_ptr);
+  }
+  IREE_TRACE_ALLOC(ptr, byte_length);
+
+  *out_ptr = ptr;
+  IREE_TRACE_ZONE_END(z0);
+  return iree_ok_status();
+}
+
+IREE_API_EXPORT void IREE_API_CALL iree_allocator_system_free(void* self,
+                                                              void* ptr) {
+  IREE_TRACE_ZONE_BEGIN(z0);
+  IREE_TRACE_FREE(ptr);
+  if (ptr) {
+    free(ptr);
+  }
+  IREE_TRACE_ZONE_END(z0);
+}
diff --git a/iree/base/api.c b/iree/base/api.c index f883f8c..e2d06c9 100644 --- a/iree/base/api.c +++ b/iree/base/api.c
@@ -14,965 +14,6 @@ #include "iree/base/api.h" -#include <assert.h> -#include <errno.h> -#include <limits.h> -#include <stdarg.h> -#include <stdio.h> -#include <stdlib.h> -#include <string.h> -#include <time.h> - -#include "iree/base/target_platform.h" -#include "iree/base/tracing.h" - -static inline size_t iree_min_host_size(size_t a, size_t b) { - return a < b ? a : b; -} - -#if defined(IREE_PLATFORM_WINDOWS) -// https://docs.microsoft.com/en-us/cpp/c-runtime-library/reference/aligned-malloc -#define iree_aligned_alloc(alignment, size) _aligned_malloc(size, alignment) -#define iree_aligned_free(p) _aligned_free(p) -#elif defined(_ISOC11_SOURCE) -// https://en.cppreference.com/w/c/memory/aligned_alloc -#define iree_aligned_alloc(alignment, size) aligned_alloc(alignment, size) -#define iree_aligned_free(p) free(p) -#elif _POSIX_C_SOURCE >= 200112L -// https://pubs.opengroup.org/onlinepubs/9699919799/functions/posix_memalign.html -static inline void* iree_aligned_alloc(size_t alignment, size_t size) { - void* ptr = NULL; - return posix_memalign(&ptr, alignment, size) == 0 ? ptr : NULL; -} -#define iree_aligned_free(p) free(p) -#else -// Emulates alignment with normal malloc. We overallocate by at least the -// alignment + the size of a pointer, store the base pointer at p[-1], and -// return the aligned pointer. This lets us easily get the base pointer in free -// to pass back to the system. 
-static inline void* iree_aligned_alloc(size_t alignment, size_t size) { - void* base_ptr = malloc(size + alignment + sizeof(uintptr_t)); - if (!base_ptr) return NULL; - uintptr_t* aligned_ptr = (uintptr_t*)iree_math_align( - (uintptr_t)base_ptr + sizeof(uintptr_t), alignment); - aligned_ptr[-1] = (uintptr_t)base_ptr; - return aligned_ptr; -} -static inline void iree_aligned_free(void* p) { - if (IREE_UNLIKELY(!p)) return; - uintptr_t* aligned_ptr = (uintptr_t*)p; - void* base_ptr = (void*)aligned_ptr[-1]; - free(base_ptr); -} -#endif // IREE_PLATFORM_WINDOWS - -//===----------------------------------------------------------------------===// -// iree_string_view_t -//===----------------------------------------------------------------------===// - -IREE_API_EXPORT bool IREE_API_CALL -iree_string_view_equal(iree_string_view_t lhs, iree_string_view_t rhs) { - if (lhs.size != rhs.size) return false; - for (iree_host_size_t i = 0; i < lhs.size; ++i) { - if (lhs.data[i] != rhs.data[i]) return false; - } - return true; -} - -IREE_API_EXPORT int IREE_API_CALL -iree_string_view_compare(iree_string_view_t lhs, iree_string_view_t rhs) { - iree_host_size_t min_size = iree_min_host_size(lhs.size, rhs.size); - int cmp = strncmp(lhs.data, rhs.data, min_size); - if (cmp != 0) { - return cmp; - } else if (lhs.size == rhs.size) { - return 0; - } - return lhs.size < rhs.size ? 
-1 : 1; -} - -IREE_API_EXPORT bool IREE_API_CALL iree_string_view_starts_with( - iree_string_view_t value, iree_string_view_t prefix) { - if (!value.data || !prefix.data || prefix.size > value.size) { - return false; - } - return strncmp(value.data, prefix.data, prefix.size) == 0; -} - -IREE_API_EXPORT iree_host_size_t IREE_API_CALL iree_string_view_find_char( - iree_string_view_t value, char c, iree_host_size_t pos) { - if (iree_string_view_is_empty(value) || pos >= value.size) { - return IREE_STRING_VIEW_NPOS; - } - const char* result = - (const char*)(memchr(value.data + pos, c, value.size - pos)); - return result != NULL ? result - value.data : IREE_STRING_VIEW_NPOS; -} - -IREE_API_EXPORT iree_host_size_t IREE_API_CALL iree_string_view_find_first_of( - iree_string_view_t value, iree_string_view_t s, iree_host_size_t pos) { - if (iree_string_view_is_empty(value) || iree_string_view_is_empty(s)) { - return IREE_STRING_VIEW_NPOS; - } - if (s.size == 1) { - // Avoid the cost of the lookup table for a single-character search. 
- return iree_string_view_find_char(value, s.data[0], pos); - } - bool lookup_table[UCHAR_MAX + 1] = {0}; - for (iree_host_size_t i = 0; i < s.size; ++i) { - lookup_table[(uint8_t)s.data[i]] = true; - } - for (iree_host_size_t i = pos; i < value.size; ++i) { - if (lookup_table[(uint8_t)value.data[i]]) { - return i; - } - } - return IREE_STRING_VIEW_NPOS; -} - -IREE_API_EXPORT iree_host_size_t IREE_API_CALL iree_string_view_find_last_of( - iree_string_view_t value, iree_string_view_t s, iree_host_size_t pos) { - if (iree_string_view_is_empty(value) || iree_string_view_is_empty(s)) { - return IREE_STRING_VIEW_NPOS; - } - bool lookup_table[UCHAR_MAX + 1] = {0}; - for (iree_host_size_t i = 0; i < s.size; ++i) { - lookup_table[(uint8_t)s.data[i]] = true; - } - pos = iree_min(pos, value.size); - iree_host_size_t i = pos; - while (i != 0) { - --i; - if (lookup_table[(uint8_t)value.data[i]]) { - return i; - } - } - return IREE_STRING_VIEW_NPOS; -} - -IREE_API_EXPORT iree_string_view_t IREE_API_CALL -iree_string_view_remove_prefix(iree_string_view_t value, iree_host_size_t n) { - if (n >= value.size) { - return iree_string_view_empty(); - } - return iree_make_string_view(value.data + n, value.size - n); -} - -IREE_API_EXPORT iree_string_view_t IREE_API_CALL iree_string_view_substr( - iree_string_view_t value, iree_host_size_t pos, iree_host_size_t n) { - pos = iree_min_host_size(pos, value.size); - n = iree_min_host_size(n, value.size - pos); - return iree_make_string_view(value.data + pos, n); -} - -IREE_API_EXPORT intptr_t IREE_API_CALL iree_string_view_split( - iree_string_view_t value, char split_char, iree_string_view_t* out_lhs, - iree_string_view_t* out_rhs) { - *out_lhs = iree_string_view_empty(); - *out_rhs = iree_string_view_empty(); - if (!value.data || !value.size) { - return -1; - } - const void* first_ptr = memchr(value.data, split_char, value.size); - if (!first_ptr) { - *out_lhs = value; - return -1; - } - intptr_t offset = (intptr_t)((const 
char*)(first_ptr)-value.data); - if (out_lhs) { - out_lhs->data = value.data; - out_lhs->size = offset; - } - if (out_rhs) { - out_rhs->data = value.data + offset + 1; - out_rhs->size = value.size - offset - 1; - } - return offset; -} - -static bool iree_string_view_match_pattern_impl(iree_string_view_t value, - iree_string_view_t pattern) { - iree_host_size_t next_char_index = iree_string_view_find_first_of( - pattern, iree_make_cstring_view("*?"), /*pos=*/0); - if (next_char_index == IREE_STRING_VIEW_NPOS) { - return iree_string_view_equal(value, pattern); - } else if (next_char_index > 0) { - iree_string_view_t value_prefix = - iree_string_view_substr(value, 0, next_char_index); - iree_string_view_t pattern_prefix = - iree_string_view_substr(pattern, 0, next_char_index); - if (!iree_string_view_equal(value_prefix, pattern_prefix)) { - return false; - } - value = - iree_string_view_substr(value, next_char_index, IREE_STRING_VIEW_NPOS); - pattern = iree_string_view_substr(pattern, next_char_index, - IREE_STRING_VIEW_NPOS); - } - if (iree_string_view_is_empty(value) && iree_string_view_is_empty(pattern)) { - return true; - } - char pattern_char = pattern.data[0]; - if (pattern_char == '*' && pattern.size > 1 && - iree_string_view_is_empty(value)) { - return false; - } else if (pattern_char == '*' && pattern.size == 1) { - return true; - } else if (pattern_char == '?' 
|| value.data[0] == pattern_char) { - return iree_string_view_match_pattern_impl( - iree_string_view_substr(value, 1, IREE_STRING_VIEW_NPOS), - iree_string_view_substr(pattern, 1, IREE_STRING_VIEW_NPOS)); - } else if (pattern_char == '*') { - return iree_string_view_match_pattern_impl( - value, - iree_string_view_substr(pattern, 1, IREE_STRING_VIEW_NPOS)) || - iree_string_view_match_pattern_impl( - iree_string_view_substr(value, 1, IREE_STRING_VIEW_NPOS), - pattern); - } - return false; -} - -IREE_API_EXPORT bool IREE_API_CALL iree_string_view_match_pattern( - iree_string_view_t value, iree_string_view_t pattern) { - return iree_string_view_match_pattern_impl(value, pattern); -} - -IREE_API_EXPORT iree_host_size_t IREE_API_CALL -iree_string_view_append_to_buffer(iree_string_view_t source_value, - iree_string_view_t* target_value, - char* buffer) { - memcpy(buffer, source_value.data, source_value.size); - target_value->data = buffer; - target_value->size = source_value.size; - return source_value.size; -} - -//===----------------------------------------------------------------------===// -// iree_status_t canonical errors -//===----------------------------------------------------------------------===// - -IREE_API_EXPORT iree_status_code_t -iree_status_code_from_errno(int error_number) { - switch (error_number) { - case 0: - return IREE_STATUS_OK; - case EINVAL: // Invalid argument - case ENAMETOOLONG: // Filename too long - case E2BIG: // Argument list too long - case EDESTADDRREQ: // Destination address required - case EDOM: // Mathematics argument out of domain of function - case EFAULT: // Bad address - case EILSEQ: // Illegal byte sequence - case ENOPROTOOPT: // Protocol not available - case ENOSTR: // Not a STREAM - case ENOTSOCK: // Not a socket - case ENOTTY: // Inappropriate I/O control operation - case EPROTOTYPE: // Protocol wrong type for socket - case ESPIPE: // Invalid seek - return IREE_STATUS_INVALID_ARGUMENT; - case ETIMEDOUT: // Connection timed 
out - case ETIME: // Timer expired - return IREE_STATUS_DEADLINE_EXCEEDED; - case ENODEV: // No such device - case ENOENT: // No such file or directory -#ifdef ENOMEDIUM - case ENOMEDIUM: // No medium found -#endif - case ENXIO: // No such device or address - case ESRCH: // No such process - return IREE_STATUS_NOT_FOUND; - case EEXIST: // File exists - case EADDRNOTAVAIL: // Address not available - case EALREADY: // Connection already in progress -#ifdef ENOTUNIQ - case ENOTUNIQ: // Name not unique on network -#endif - return IREE_STATUS_ALREADY_EXISTS; - case EPERM: // Operation not permitted - case EACCES: // Permission denied -#ifdef ENOKEY - case ENOKEY: // Required key not available -#endif - case EROFS: // Read only file system - return IREE_STATUS_PERMISSION_DENIED; - case ENOTEMPTY: // Directory not empty - case EISDIR: // Is a directory - case ENOTDIR: // Not a directory - case EADDRINUSE: // Address already in use - case EBADF: // Invalid file descriptor -#ifdef EBADFD - case EBADFD: // File descriptor in bad state -#endif - case EBUSY: // Device or resource busy - case ECHILD: // No child processes - case EISCONN: // Socket is connected -#ifdef EISNAM - case EISNAM: // Is a named type file -#endif -#ifdef ENOTBLK - case ENOTBLK: // Block device required -#endif - case ENOTCONN: // The socket is not connected - case EPIPE: // Broken pipe -#ifdef ESHUTDOWN - case ESHUTDOWN: // Cannot send after transport endpoint shutdown -#endif - case ETXTBSY: // Text file busy -#ifdef EUNATCH - case EUNATCH: // Protocol driver not attached -#endif - return IREE_STATUS_FAILED_PRECONDITION; - case ENOSPC: // No space left on device -#ifdef EDQUOT - case EDQUOT: // Disk quota exceeded -#endif - case EMFILE: // Too many open files - case EMLINK: // Too many links - case ENFILE: // Too many open files in system - case ENOBUFS: // No buffer space available - case ENODATA: // No message is available on the STREAM read queue - case ENOMEM: // Not enough space - case ENOSR: // 
No STREAM resources -#ifdef EUSERS - case EUSERS: // Too many users -#endif - return IREE_STATUS_RESOURCE_EXHAUSTED; -#ifdef ECHRNG - case ECHRNG: // Channel number out of range -#endif - case EFBIG: // File too large - case EOVERFLOW: // Value too large to be stored in data type - case ERANGE: // Result too large - return IREE_STATUS_OUT_OF_RANGE; -#ifdef ENOPKG - case ENOPKG: // Package not installed -#endif - case ENOSYS: // Function not implemented - case ENOTSUP: // Operation not supported - case EAFNOSUPPORT: // Address family not supported -#ifdef EPFNOSUPPORT - case EPFNOSUPPORT: // Protocol family not supported -#endif - case EPROTONOSUPPORT: // Protocol not supported -#ifdef ESOCKTNOSUPPORT - case ESOCKTNOSUPPORT: // Socket type not supported -#endif - case EXDEV: // Improper link - return IREE_STATUS_UNIMPLEMENTED; - case EAGAIN: // Resource temporarily unavailable -#ifdef ECOMM - case ECOMM: // Communication error on send -#endif - case ECONNREFUSED: // Connection refused - case ECONNABORTED: // Connection aborted - case ECONNRESET: // Connection reset - case EINTR: // Interrupted function call -#ifdef EHOSTDOWN - case EHOSTDOWN: // Host is down -#endif - case EHOSTUNREACH: // Host is unreachable - case ENETDOWN: // Network is down - case ENETRESET: // Connection aborted by network - case ENETUNREACH: // Network unreachable - case ENOLCK: // No locks available - case ENOLINK: // Link has been severed -#ifdef ENONET - case ENONET: // Machine is not on the network -#endif - return IREE_STATUS_UNAVAILABLE; - case EDEADLK: // Resource deadlock avoided -#ifdef ESTALE - case ESTALE: // Stale file handle -#endif - return IREE_STATUS_ABORTED; - case ECANCELED: // Operation cancelled - return IREE_STATUS_CANCELLED; - default: - return IREE_STATUS_UNKNOWN; - } -} - -#if defined(IREE_PLATFORM_WINDOWS) -IREE_API_EXPORT iree_status_code_t -iree_status_code_from_win32_error(uint32_t error) { - switch (error) { - case ERROR_SUCCESS: - return IREE_STATUS_OK; - case 
ERROR_FILE_NOT_FOUND: - case ERROR_PATH_NOT_FOUND: - return IREE_STATUS_NOT_FOUND; - case ERROR_TOO_MANY_OPEN_FILES: - case ERROR_OUTOFMEMORY: - case ERROR_HANDLE_DISK_FULL: - case ERROR_HANDLE_EOF: - return IREE_STATUS_RESOURCE_EXHAUSTED; - case ERROR_ACCESS_DENIED: - return IREE_STATUS_PERMISSION_DENIED; - case ERROR_INVALID_HANDLE: - return IREE_STATUS_INVALID_ARGUMENT; - case ERROR_NOT_READY: - case ERROR_READ_FAULT: - return IREE_STATUS_UNAVAILABLE; - case ERROR_WRITE_FAULT: - return IREE_STATUS_DATA_LOSS; - case ERROR_NOT_SUPPORTED: - return IREE_STATUS_UNIMPLEMENTED; - default: - return IREE_STATUS_UNKNOWN; - } -} -#endif // IREE_PLATFORM_WINDOWS - -//===----------------------------------------------------------------------===// -// iree_status_t -//===----------------------------------------------------------------------===// - -IREE_API_EXPORT const char* IREE_API_CALL -iree_status_code_string(iree_status_code_t code) { - switch (code) { - case IREE_STATUS_OK: - return "OK"; - case IREE_STATUS_CANCELLED: - return "CANCELLED"; - case IREE_STATUS_UNKNOWN: - return "UNKNOWN"; - case IREE_STATUS_INVALID_ARGUMENT: - return "INVALID_ARGUMENT"; - case IREE_STATUS_DEADLINE_EXCEEDED: - return "DEADLINE_EXCEEDED"; - case IREE_STATUS_NOT_FOUND: - return "NOT_FOUND"; - case IREE_STATUS_ALREADY_EXISTS: - return "ALREADY_EXISTS"; - case IREE_STATUS_PERMISSION_DENIED: - return "PERMISSION_DENIED"; - case IREE_STATUS_UNAUTHENTICATED: - return "UNAUTHENTICATED"; - case IREE_STATUS_RESOURCE_EXHAUSTED: - return "RESOURCE_EXHAUSTED"; - case IREE_STATUS_FAILED_PRECONDITION: - return "FAILED_PRECONDITION"; - case IREE_STATUS_ABORTED: - return "ABORTED"; - case IREE_STATUS_OUT_OF_RANGE: - return "OUT_OF_RANGE"; - case IREE_STATUS_UNIMPLEMENTED: - return "UNIMPLEMENTED"; - case IREE_STATUS_INTERNAL: - return "INTERNAL"; - case IREE_STATUS_UNAVAILABLE: - return "UNAVAILABLE"; - case IREE_STATUS_DATA_LOSS: - return "DATA_LOSS"; - default: - return ""; - } -} - -// TODO(#55): move 
payload methods/types to header when API is stabilized. - -// Defines the type of an iree_status_payload_t. -typedef enum { - // Opaque; payload may still be formatted by a formatter but is not possible - // to retrieve by the programmatic APIs. - IREE_STATUS_PAYLOAD_TYPE_OPAQUE = 0u, - // A string message annotation of type iree_status_payload_message_t. - IREE_STATUS_PAYLOAD_TYPE_MESSAGE = 1u, - // Starting type ID for user payloads. IREE reserves all payloads with types - // less than this. - IREE_STATUS_PAYLOAD_TYPE_MIN_USER = 0x70000000u, -} iree_status_payload_type_t; - -typedef struct iree_status_payload_s iree_status_payload_t; - -// Function that formats a payload into a human-readable string form for logs. -typedef void(IREE_API_PTR* iree_status_payload_formatter_t)( - const iree_status_payload_t* payload, iree_host_size_t buffer_capacity, - char* buffer, iree_host_size_t* out_buffer_length); - -// Header for optional status payloads. -// Each status may have zero or more payloads associated with it that can later -// be used to produce more detailed logging or programmatically query -// information about an error. -struct iree_status_payload_s { - // Next payload in the status payload linked list. - struct iree_status_payload_s* next; - // Payload type identifier used for programmatic access to payloads. May be - // IREE_STATUS_PAYLOAD_TYPE_OPAQUE if the payload cannot be accessed directly. - iree_status_payload_type_t type; - // Allocator used for the payload and associated resources. - iree_allocator_t allocator; - // String formatter callback used to write the payload into a string buffer. - // If not present then the payload will be mentioned but not dumped when the - // status is logged. - iree_status_payload_formatter_t formatter; -}; - -// A string message (IREE_STATUS_PAYLOAD_TYPE_MESSAGE). -typedef struct { - iree_status_payload_t header; - // String data reference. 
May point to an address immediately following this - // struct (if copied) or a constant string reference in rodata. - iree_string_view_t message; -} iree_status_payload_message_t; - -// Allocated storage for an iree_status_t. -// Only statuses that have either source information or payloads will have -// storage allocated for them. -typedef struct { - // Optional doubly-linked list of payloads associated with the status. - // Head = first added, tail = last added. - iree_status_payload_t* payload_head; - iree_status_payload_t* payload_tail; - -#if (IREE_STATUS_FEATURES & IREE_STATUS_FEATURE_SOURCE_LOCATION) != 0 - // __FILE__ of the originating status allocation. - const char* file; - // __LINE__ of the originating status allocation. - uint32_t line; -#endif // has IREE_STATUS_FEATURE_SOURCE_LOCATION - -#if (IREE_STATUS_FEATURES & IREE_STATUS_FEATURE_ANNOTATIONS) != 0 - // Optional message that is allocated either as a constant string in rodata or - // present as a suffix on the storage. - iree_string_view_t message; -#endif // has IREE_STATUS_FEATURE_ANNOTATIONS -} iree_status_storage_t; - -#define iree_status_storage(status) \ - ((iree_status_storage_t*)(((uintptr_t)(status) & ~IREE_STATUS_CODE_MASK))) - -// Appends a payload to the storage doubly-linked list. -static iree_status_t iree_status_append_payload( - iree_status_t status, iree_status_storage_t* storage, - iree_status_payload_t* payload) { - if (!storage->payload_tail) { - storage->payload_head = payload; - } else { - storage->payload_tail->next = payload; - } - storage->payload_tail = payload; - return status; -} - -// Formats an iree_status_payload_message_t to the given output |buffer|. -// |out_buffer_length| will be set to the number of characters written excluding -// NUL. If |buffer| is omitted then |out_buffer_length| will be set to the -// total number of characters in |buffer_capacity| required to contain the -// entire message. 
-static void IREE_API_CALL iree_status_payload_message_formatter( - const iree_status_payload_t* payload, iree_host_size_t buffer_capacity, - char* buffer, iree_host_size_t* out_buffer_length) { - iree_status_payload_message_t* message_payload = - (iree_status_payload_message_t*)payload; - if (!buffer) { - *out_buffer_length = message_payload->message.size; - return; - } - iree_host_size_t n = buffer_capacity < message_payload->message.size - ? buffer_capacity - : message_payload->message.size; - memcpy(buffer, message_payload->message.data, n); - buffer[n] = '\0'; - *out_buffer_length = n; -} - -// Captures the current stack and attaches it to the status storage. -// A count of |skip_frames| will be skipped from the top of the stack. -// Setting |skip_frames|=0 will include the caller in the stack while -// |skip_frames|=1 will exclude it. -static void iree_status_attach_stack_trace(iree_status_storage_t* storage, - int skip_frames) { -#if (IREE_STATUS_FEATURES & IREE_STATUS_FEATURE_STACK_TRACE) != 0 - // TODO(#55): backtrace or other magic. -#endif // has IREE_STATUS_FEATURE_STACK_TRACE -} - -IREE_API_EXPORT IREE_MUST_USE_RESULT iree_status_t IREE_API_CALL -iree_status_allocate(iree_status_code_t code, const char* file, uint32_t line, - iree_string_view_t message) { -#if IREE_STATUS_FEATURES == 0 - // More advanced status code features like source location and messages are - // disabled. All statuses are just the codes. - return iree_status_from_code(code); -#else - // No-op for OK statuses; we won't get these from the macros but may be called - // with this from marshaling code. - if (IREE_UNLIKELY(code == IREE_STATUS_OK)) return iree_ok_status(); - - // Allocate storage with the appropriate alignment such that we can pack the - // code in the lower bits of the pointer. 
Since failed statuses are rare and - // likely have much larger costs (like string formatting) the extra bytes for - // alignment are worth being able to avoid pointer dereferences and other - // things during the normal code paths that just check codes. - // - // Note that we are using the CRT allocation function here, as we can't trust - // our allocator system to work when we are throwing errors (as we may be - // allocating this error from a failed allocation!). - size_t storage_alignment = (IREE_STATUS_CODE_MASK + 1); - size_t storage_size = - iree_math_align(sizeof(iree_status_storage_t), storage_alignment); - iree_status_storage_t* storage = (iree_status_storage_t*)iree_aligned_alloc( - storage_alignment, storage_size); - if (IREE_UNLIKELY(!storage)) return iree_status_from_code(code); - memset(storage, 0, sizeof(*storage)); - -#if (IREE_STATUS_FEATURES & IREE_STATUS_FEATURE_SOURCE_LOCATION) != 0 - storage->file = file; - storage->line = line; -#endif // has IREE_STATUS_FEATURE_SOURCE_LOCATION - -#if (IREE_STATUS_FEATURES & IREE_STATUS_FEATURE_ANNOTATIONS) != 0 - // NOTE: messages are rodata strings here and not retained. - storage->message = message; -#endif // has IREE_STATUS_FEATURE_ANNOTATIONS - - iree_status_attach_stack_trace(storage, /*skip_frames=*/1); - return (iree_status_t)((uintptr_t)storage | (code & IREE_STATUS_CODE_MASK)); -#endif // has any IREE_STATUS_FEATURES -} - -IREE_API_EXPORT IREE_MUST_USE_RESULT iree_status_t IREE_API_CALL -iree_status_allocate_f(iree_status_code_t code, const char* file, uint32_t line, - const char* format, ...) 
{ - va_list varargs_0, varargs_1; - va_start(varargs_0, format); - va_start(varargs_1, format); - iree_status_t ret = - iree_status_allocate_vf(code, file, line, format, varargs_0, varargs_1); - va_end(varargs_0); - va_end(varargs_1); - return ret; -} - -IREE_API_EXPORT IREE_MUST_USE_RESULT iree_status_t IREE_API_CALL -iree_status_allocate_vf(iree_status_code_t code, const char* file, - uint32_t line, const char* format, va_list varargs_0, - va_list varargs_1) { -#if (IREE_STATUS_FEATURES & IREE_STATUS_FEATURE_ANNOTATIONS) == 0 - // Annotations disabled; ignore the format string/args. - return iree_status_allocate(code, file, line, iree_string_view_empty()); -#else - // No-op for OK statuses; we won't get these from the macros but may be called - // with this from marshaling code. - if (IREE_UNLIKELY(code == IREE_STATUS_OK)) return iree_ok_status(); - - // Compute the total number of bytes (including NUL) required to store the - // message. - size_t message_size = - vsnprintf(/*buffer=*/NULL, /*buffer_count=*/0, format, varargs_0); - if (message_size < 0) return iree_status_from_code(code); - ++message_size; // NUL byte - - // Allocate storage with the additional room to store the formatted message. - // This avoids additional allocations for the common case of a message coming - // only from the original status error site. - size_t storage_alignment = (IREE_STATUS_CODE_MASK + 1); - size_t storage_size = iree_math_align( - sizeof(iree_status_storage_t) + message_size, storage_alignment); - iree_status_storage_t* storage = (iree_status_storage_t*)iree_aligned_alloc( - storage_alignment, storage_size); - if (IREE_UNLIKELY(!storage)) return iree_status_from_code(code); - memset(storage, 0, sizeof(*storage)); - -#if (IREE_STATUS_FEATURES & IREE_STATUS_FEATURE_SOURCE_LOCATION) != 0 - storage->file = file; - storage->line = line; -#endif // has IREE_STATUS_FEATURE_SOURCE_LOCATION - - // vsnprintf directly into message buffer. 
- storage->message.size = message_size - 1; - storage->message.data = (const char*)storage + sizeof(iree_status_storage_t); - int ret = - vsnprintf((char*)storage->message.data, message_size, format, varargs_1); - if (IREE_UNLIKELY(ret < 0)) { - iree_aligned_free(storage); - return (iree_status_t)code; - } - - iree_status_attach_stack_trace(storage, /*skip_frames=*/1); - return (iree_status_t)((uintptr_t)storage | (code & IREE_STATUS_CODE_MASK)); -#endif // has IREE_STATUS_FEATURE_ANNOTATIONS -} - -IREE_API_EXPORT IREE_MUST_USE_RESULT iree_status_t IREE_API_CALL -iree_status_clone(iree_status_t status) { -#if IREE_STATUS_FEATURES == 0 - // Statuses are just codes; nothing to do. - return status; -#else - iree_status_storage_t* storage = iree_status_storage(status); - if (!storage) return status; - -#if (IREE_STATUS_FEATURES & IREE_STATUS_FEATURE_SOURCE_LOCATION) != 0 - const char* file = storage->file; - uint32_t line = storage->line; -#else - const char* file = NULL; - uint32_t line = 0; -#endif // has IREE_STATUS_FEATURE_SOURCE_LOCATION - -#if (IREE_STATUS_FEATURES & IREE_STATUS_FEATURE_ANNOTATIONS) != 0 - iree_string_view_t message = storage->message; -#else - iree_string_view_t message = iree_string_view_empty(); -#endif // has IREE_STATUS_FEATURE_ANNOTATIONS - - // Always copy the message by performing the formatting as we don't know - // whether the original status has ownership or not. 
- return iree_status_allocate_f(iree_status_code(status), file, line, "%.*s", - (int)message.size, message.data); -#endif // has no IREE_STATUS_FEATURES -} - -IREE_API_EXPORT void IREE_API_CALL iree_status_free(iree_status_t status) { -#if IREE_STATUS_FEATURES != 0 - iree_status_storage_t* storage = iree_status_storage(status); - if (!storage) return; - iree_status_payload_t* payload = storage->payload_head; - while (payload) { - iree_status_payload_t* next = payload->next; - iree_allocator_free(payload->allocator, payload); - payload = next; - } - iree_aligned_free(storage); -#endif // has any IREE_STATUS_FEATURES -} - -IREE_API_EXPORT iree_status_t IREE_API_CALL -iree_status_ignore(iree_status_t status) { - // We can set an 'ignored' flag on the status so that we can otherwise assert - // in iree_status_free when statuses are freed without this being called. - // Hoping with the C++ Status wrapper we won't hit that often so that - // complexity is skipped for now. - iree_status_free(status); - return iree_ok_status(); -} - -IREE_API_EXPORT iree_status_code_t IREE_API_CALL -iree_status_consume_code(iree_status_t status) { - iree_status_code_t code = iree_status_code(status); - iree_status_free(status); - return code; -} - -IREE_API_EXPORT IREE_MUST_USE_RESULT iree_status_t IREE_API_CALL -iree_status_annotate(iree_status_t base_status, iree_string_view_t message) { -#if (IREE_STATUS_FEATURES & IREE_STATUS_FEATURE_ANNOTATIONS) == 0 - // Annotations are disabled so we ignore this entirely. - return base_status; -#else - if (iree_status_is_ok(base_status) || iree_string_view_is_empty(message)) { - return base_status; - } - - // If there's no storage yet we can just reuse normal allocation. Both that - // and this do not copy |message|. 
- iree_status_storage_t* storage = iree_status_storage(base_status); - if (!storage) { - return iree_status_allocate(iree_status_code(base_status), NULL, 0, - message); - } else if (iree_string_view_is_empty(storage->message)) { - storage->message = message; - return base_status; - } - - iree_allocator_t allocator = iree_allocator_system(); - iree_status_payload_message_t* payload = NULL; - iree_status_ignore( - iree_allocator_malloc(allocator, sizeof(*payload), (void**)&payload)); - if (IREE_UNLIKELY(!payload)) return base_status; - memset(payload, 0, sizeof(*payload)); - payload->header.type = IREE_STATUS_PAYLOAD_TYPE_MESSAGE; - payload->header.allocator = allocator; - payload->header.formatter = iree_status_payload_message_formatter; - payload->message = message; - return iree_status_append_payload(base_status, storage, - (iree_status_payload_t*)payload); -#endif // has IREE_STATUS_FEATURE_ANNOTATIONS -} - -IREE_API_EXPORT IREE_MUST_USE_RESULT iree_status_t IREE_API_CALL -IREE_PRINTF_ATTRIBUTE(2, 3) - iree_status_annotate_f(iree_status_t base_status, const char* format, ...) { - va_list varargs_0, varargs_1; - va_start(varargs_0, format); - va_start(varargs_1, format); - iree_status_t ret = - iree_status_annotate_vf(base_status, format, varargs_0, varargs_1); - va_end(varargs_0); - va_end(varargs_1); - return ret; -} - -IREE_API_EXPORT IREE_MUST_USE_RESULT iree_status_t IREE_API_CALL -iree_status_annotate_vf(iree_status_t base_status, const char* format, - va_list varargs_0, va_list varargs_1) { -#if (IREE_STATUS_FEATURES & IREE_STATUS_FEATURE_ANNOTATIONS) == 0 - return base_status; -#else - if (iree_status_is_ok(base_status)) return base_status; - - // If there's no storage yet we can just reuse normal allocation. Both that - // and this do not copy |message|. 
- iree_status_storage_t* storage = iree_status_storage(base_status); - if (!storage) { - return iree_status_allocate_vf(iree_status_code(base_status), NULL, 0, - format, varargs_0, varargs_1); - } - - // Compute the total number of bytes (including NUL) required to store the - // message. - size_t message_size = - vsnprintf(/*buffer=*/NULL, /*buffer_count=*/0, format, varargs_0); - va_end(varargs_0); - if (message_size < 0) return base_status; - ++message_size; // NUL byte - - // Allocate storage with the additional room to store the formatted message. - // This avoids additional allocations for the common case of a message coming - // only from the original status error site. - iree_allocator_t allocator = iree_allocator_system(); - iree_status_payload_message_t* payload = NULL; - iree_status_ignore(iree_allocator_malloc( - allocator, sizeof(*payload) + message_size, (void**)&payload)); - if (IREE_UNLIKELY(!payload)) return base_status; - memset(payload, 0, sizeof(*payload)); - payload->header.type = IREE_STATUS_PAYLOAD_TYPE_MESSAGE; - payload->header.allocator = allocator; - payload->header.formatter = iree_status_payload_message_formatter; - - // vsnprintf directly into message buffer. - payload->message.size = message_size - 1; - payload->message.data = - (const char*)payload + sizeof(iree_status_payload_message_t); - int ret = vsnprintf((char*)payload->message.data, payload->message.size + 1, - format, varargs_1); - if (IREE_UNLIKELY(ret < 0)) { - iree_aligned_free(payload); - return base_status; - } - return iree_status_append_payload(base_status, storage, - (iree_status_payload_t*)payload); -#endif // has IREE_STATUS_FEATURE_ANNOTATIONS -} - -IREE_API_EXPORT bool IREE_API_CALL -iree_status_format(iree_status_t status, iree_host_size_t buffer_capacity, - char* buffer, iree_host_size_t* out_buffer_length) { - *out_buffer_length = 0; - - // Grab storage which may have a message and zero or more payloads. 
- iree_status_storage_t* storage = iree_status_storage(status); - - // Prefix with source location and status code string (may be 'OK'). - iree_host_size_t buffer_length = 0; - iree_status_code_t status_code = iree_status_code(status); - int n = 0; -#if (IREE_STATUS_FEATURES & IREE_STATUS_FEATURE_SOURCE_LOCATION) != 0 - if (storage && storage->file) { - n = snprintf(buffer ? buffer + buffer_length : NULL, - buffer ? buffer_capacity - buffer_length : 0, "%s:%d: %s", - storage->file, storage->line, - iree_status_code_string(status_code)); - } else { - n = snprintf(buffer ? buffer + buffer_length : NULL, - buffer ? buffer_capacity - buffer_length : 0, "%s", - iree_status_code_string(status_code)); - } -#else - n = snprintf(buffer ? buffer + buffer_length : NULL, - buffer ? buffer_capacity - buffer_length : 0, "%s", - iree_status_code_string(status_code)); -#endif // has IREE_STATUS_FEATURE_SOURCE_LOCATION - if (IREE_UNLIKELY(n < 0)) { - return false; - } else if (buffer && n >= buffer_capacity - buffer_length) { - buffer = NULL; - } - buffer_length += n; - -#if (IREE_STATUS_FEATURES & IREE_STATUS_FEATURE_ANNOTATIONS) != 0 - // Append base storage message. - if (storage && !iree_string_view_is_empty(storage->message)) { - n = snprintf(buffer ? buffer + buffer_length : NULL, - buffer ? buffer_capacity - buffer_length : 0, "; %.*s", - (int)storage->message.size, storage->message.data); - if (IREE_UNLIKELY(n < 0)) { - return false; - } else if (buffer && n >= buffer_capacity - buffer_length) { - buffer = NULL; - } - buffer_length += n; - } -#endif // has IREE_STATUS_FEATURE_ANNOTATIONS - -#if IREE_STATUS_FEATURES != 0 - // Append each payload separated by a newline. - iree_status_payload_t* payload = storage ? storage->payload_head : NULL; - while (payload != NULL) { - // Skip payloads that have no textual representation. - if (!payload->formatter) { - payload = payload->next; - continue; - } - - // Append newline to join with message above and other payloads. 
- if (buffer) { - if (2 >= buffer_capacity - buffer_length) { - buffer = NULL; - } else { - buffer[buffer_length] = ';'; - buffer[buffer_length + 1] = ' '; - buffer[buffer_length + 2] = '\0'; - } - } - buffer_length += 2; // '; ' - - // Append payload via custom formatter callback. - iree_host_size_t payload_buffer_length = 0; - payload->formatter(payload, buffer ? buffer_capacity - buffer_length : 0, - buffer ? buffer + buffer_length : NULL, - &payload_buffer_length); - if (buffer && payload_buffer_length >= buffer_capacity - buffer_length) { - buffer = NULL; - } - buffer_length += payload_buffer_length; - - payload = payload->next; - } -#endif // has IREE_STATUS_FEATURES - - *out_buffer_length = buffer_length; - return true; -} - -IREE_API_EXPORT bool IREE_API_CALL -iree_status_to_string(iree_status_t status, char** out_buffer, - iree_host_size_t* out_buffer_length) { - *out_buffer_length = 0; - iree_host_size_t buffer_length = 0; - if (IREE_UNLIKELY(!iree_status_format(status, /*buffer_capacity=*/0, - /*buffer=*/NULL, &buffer_length))) { - return false; - } - // Buffer capacity needs to be +1 to account for the terminating null of - // snprintf. 
- buffer_length++; - char* buffer = (char*)malloc(buffer_length); - if (IREE_UNLIKELY(!buffer)) return false; - bool ret = - iree_status_format(status, buffer_length, buffer, out_buffer_length); - if (ret) { - *out_buffer = buffer; - return true; - } else { - free(buffer); - return false; - } -} - //===----------------------------------------------------------------------===// // IREE Core API //===----------------------------------------------------------------------===// @@ -992,150 +33,3 @@ "%d but IREE is compiled as %d", expected_version, actual_version); } - -//===----------------------------------------------------------------------===// -// iree_time_t and iree_duration_t -//===----------------------------------------------------------------------===// - -IREE_API_EXPORT iree_time_t iree_time_now() { -#if defined(IREE_PLATFORM_WINDOWS) - // GetSystemTimePreciseAsFileTime requires Windows 8, add a fallback - // (such as using std::chrono) if older support is needed. - FILETIME system_time; - GetSystemTimePreciseAsFileTime(&system_time); - - const int64_t kUnixEpochStartTicks = 116444736000000000i64; - const int64_t kFtToMicroSec = 10; - LARGE_INTEGER li; - li.LowPart = system_time.dwLowDateTime; - li.HighPart = system_time.dwHighDateTime; - li.QuadPart -= kUnixEpochStartTicks; - li.QuadPart /= kFtToMicroSec; - return li.QuadPart; -#elif defined(IREE_PLATFORM_ANDROID) || defined(IREE_PLATFORM_APPLE) || \ - defined(IREE_PLATFORM_LINUX) - struct timespec clock_time; - clock_gettime(CLOCK_REALTIME, &clock_time); - return clock_time.tv_nsec; -#else -#error "IREE system clock needs to be set up for your platform" -#endif // IREE_PLATFORM_* -} - -IREE_API_EXPORT iree_time_t -iree_relative_timeout_to_deadline_ns(iree_duration_t timeout_ns) { - if (timeout_ns == IREE_DURATION_ZERO) { - return IREE_TIME_INFINITE_PAST; - } else if (timeout_ns == IREE_DURATION_INFINITE) { - return IREE_TIME_INFINITE_FUTURE; - } - return iree_time_now() + timeout_ns; -} - 
-IREE_API_EXPORT iree_duration_t -iree_absolute_deadline_to_timeout_ns(iree_time_t deadline_ns) { - if (deadline_ns == IREE_TIME_INFINITE_PAST) { - return IREE_DURATION_ZERO; - } else if (deadline_ns == IREE_TIME_INFINITE_FUTURE) { - return IREE_DURATION_INFINITE; - } else { - // We have either already passed the deadline (and can turn this into a - // poll) or want to do nanos->millis. We round up so that a deadline of 1ns - // results in 1ms as it should still wait, vs. if it was actually 0ns - // indicating the user intended a poll. - iree_time_t now_ns = iree_time_now(); - return deadline_ns < now_ns ? IREE_DURATION_ZERO : deadline_ns - now_ns; - } -} - -//===----------------------------------------------------------------------===// -// iree_allocator_t -//===----------------------------------------------------------------------===// - -IREE_API_EXPORT iree_status_t IREE_API_CALL iree_allocator_malloc( - iree_allocator_t allocator, iree_host_size_t byte_length, void** out_ptr) { - if (!allocator.alloc) { - return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, - "allocator has no alloc routine"); - } - return allocator.alloc(allocator.self, IREE_ALLOCATION_MODE_ZERO_CONTENTS, - byte_length, out_ptr); -} - -IREE_API_EXPORT iree_status_t IREE_API_CALL iree_allocator_realloc( - iree_allocator_t allocator, iree_host_size_t byte_length, void** out_ptr) { - if (!allocator.alloc) { - return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, - "allocator has no alloc routine"); - } - return allocator.alloc(allocator.self, - IREE_ALLOCATION_MODE_TRY_REUSE_EXISTING, byte_length, - out_ptr); -} - -IREE_API_EXPORT iree_status_t IREE_API_CALL -iree_allocator_clone(iree_allocator_t allocator, - iree_const_byte_span_t source_bytes, void** out_ptr) { - IREE_RETURN_IF_ERROR( - iree_allocator_malloc(allocator, source_bytes.data_length, out_ptr)); - memcpy(*out_ptr, source_bytes.data, source_bytes.data_length); - return iree_ok_status(); -} - -IREE_API_EXPORT void IREE_API_CALL 
-iree_allocator_free(iree_allocator_t allocator, void* ptr) { - if (ptr && allocator.free) { - allocator.free(allocator.self, ptr); - } -} - -IREE_API_EXPORT iree_status_t IREE_API_CALL -iree_allocator_system_allocate(void* self, iree_allocation_mode_t mode, - iree_host_size_t byte_length, void** out_ptr) { - IREE_ASSERT_ARGUMENT(out_ptr); - if (byte_length == 0) { - return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, - "allocations must be >0 bytes"); - } - - IREE_TRACE_ZONE_BEGIN(z0); - - void* existing_ptr = *out_ptr; - void* ptr = NULL; - if (existing_ptr && (mode & IREE_ALLOCATION_MODE_TRY_REUSE_EXISTING)) { - ptr = realloc(existing_ptr, byte_length); - if (ptr && (mode & IREE_ALLOCATION_MODE_ZERO_CONTENTS)) { - memset(ptr, 0, byte_length); - } - } else { - existing_ptr = NULL; - if (mode & IREE_ALLOCATION_MODE_ZERO_CONTENTS) { - ptr = calloc(1, byte_length); - } else { - ptr = malloc(byte_length); - } - } - if (!ptr) { - return iree_make_status(IREE_STATUS_RESOURCE_EXHAUSTED, - "system allocator failed the request"); - } - - if (existing_ptr) { - IREE_TRACE_FREE(existing_ptr); - } - IREE_TRACE_ALLOC(ptr, byte_length); - - *out_ptr = ptr; - IREE_TRACE_ZONE_END(z0); - return iree_ok_status(); -} - -IREE_API_EXPORT void IREE_API_CALL iree_allocator_system_free(void* self, - void* ptr) { - IREE_TRACE_ZONE_BEGIN(z0); - IREE_TRACE_FREE(ptr); - if (ptr) { - free(ptr); - } - IREE_TRACE_ZONE_END(z0); -}
diff --git a/iree/base/api.h b/iree/base/api.h index 2a0cbd0..f5255aa 100644 --- a/iree/base/api.h +++ b/iree/base/api.h
@@ -339,6 +339,10 @@ IREE_API_EXPORT iree_string_view_t IREE_API_CALL iree_string_view_remove_prefix(iree_string_view_t value, iree_host_size_t n); +// Removes leading and trailing whitespace. +IREE_API_EXPORT iree_string_view_t IREE_API_CALL +iree_string_view_trim(iree_string_view_t value); + // Returns a substring of the string view at offset |pos| and length |n|. // Use |n| == INTPTR_MAX to take the remaineder of the string after |pos|. // Returns empty string on failure.
diff --git a/iree/base/internal/BUILD b/iree/base/internal/BUILD index 0906894..31833b4 100644 --- a/iree/base/internal/BUILD +++ b/iree/base/internal/BUILD
@@ -12,7 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -# Implementations for iree/base/ +# Implementations for iree/base/. +# These are not part of the IREE API. Though they may be used by external +# projects their API may change at any time. + +load("//iree:lit_test.bzl", "iree_lit_test_suite") package( default_visibility = ["//visibility:public"], @@ -151,14 +155,34 @@ cc_library( name = "flags", - srcs = ["flags.cc"], + srcs = ["flags.c"], hdrs = ["flags.h"], deps = [ + ":internal", "//iree/base:api", - "@com_google_absl//absl/flags:parse", + "//iree/base:tracing", ], ) +cc_binary( + name = "flags_demo", + srcs = ["flags_demo.c"], + deps = [ + ":flags", + "//iree/base:core_headers", + ], +) + +iree_lit_test_suite( + name = "flags_test", + srcs = ["flags_test.txt"], + data = [ + ":flags_demo", + "//iree/tools:IreeFileCheck", + ], + tags = ["hostonly"], +) + cc_library( name = "main", srcs = [
diff --git a/iree/base/internal/CMakeLists.txt b/iree/base/internal/CMakeLists.txt index 2ab0ae8..603a4bc 100644 --- a/iree/base/internal/CMakeLists.txt +++ b/iree/base/internal/CMakeLists.txt
@@ -147,13 +147,36 @@ HDRS "flags.h" SRCS - "flags.cc" + "flags.c" DEPS - absl::flags_parse + ::internal iree::base::api + iree::base::tracing PUBLIC ) +iree_cc_binary( + NAME + flags_demo + SRCS + "flags_demo.c" + DEPS + ::flags + iree::base::core_headers +) + +iree_lit_test_suite( + NAME + flags_test + SRCS + "flags_test.txt" + DATA + ::flags_demo + iree::tools::IreeFileCheck + LABELS + "hostonly" +) + iree_cc_library( NAME main
diff --git a/iree/base/internal/debugging.h b/iree/base/internal/debugging.h index 05f0594..e05773c 100644 --- a/iree/base/internal/debugging.h +++ b/iree/base/internal/debugging.h
@@ -87,6 +87,7 @@ #if defined(IREE_SANITIZER_ADDRESS) #include <sanitizer/asan_interface.h> +#include <sanitizer/lsan_interface.h> #endif // IREE_SANITIZER_ADDRESS // For whenever we want to provide specialized msan/tsan hooks:
diff --git a/iree/base/internal/flags.c b/iree/base/internal/flags.c new file mode 100644 index 0000000..0b67591 --- /dev/null +++ b/iree/base/internal/flags.c
@@ -0,0 +1,608 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iree/base/internal/flags.h"
+
+#include <errno.h>
+#include <inttypes.h>
+#include <limits.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+
+#include "iree/base/internal/debugging.h"
+#include "iree/base/tracing.h"
+
+#if IREE_FLAGS_ENABLE_CLI == 1
+
+//===----------------------------------------------------------------------===//
+// Flag manipulation utilities
+//===----------------------------------------------------------------------===//
+
+// iree_allocator_t alloc routine backing iree_flags_leaky_allocator.
+// Returns zero-filled memory allocated inside an
+// IREE_LEAK_CHECK_DISABLE_PUSH/POP region so intentionally-leaked flag storage
+// is not reported.
+//
+// NOTE: |self| and |mode| are ignored; IREE_ALLOCATION_MODE_TRY_REUSE_EXISTING
+// is not supported and a fresh block is always returned.
+static iree_status_t iree_flags_leaky_alloc(void* self,
+                                            iree_allocation_mode_t mode,
+                                            iree_host_size_t byte_length,
+                                            void** out_ptr) {
+  IREE_LEAK_CHECK_DISABLE_PUSH();
+  void* ptr = malloc(byte_length);
+  IREE_LEAK_CHECK_DISABLE_POP();
+  // malloc may fail; never memset through or hand back a NULL pointer as a
+  // successful allocation.
+  if (ptr == NULL) {
+    return iree_make_status(IREE_STATUS_RESOURCE_EXHAUSTED,
+                            "leaky flag allocator failed the request");
+  }
+  memset(ptr, 0, byte_length);
+  *out_ptr = ptr;
+  return iree_ok_status();
+}
+
+// iree_allocator_t free routine backing iree_flags_leaky_allocator.
+// |self| is ignored; free(NULL) is a no-op.
+static void iree_flags_leaky_free(void* self, void* ptr) { free(ptr); }
+
+// Allocates heap memory that is leaked without triggering leak checkers.
+// We do this so that we have valid memory for the lifetime of the process.
+// The memory may still be freed but if not will not hurt anything (besides the
+// private working set size).
+static iree_allocator_t iree_flags_leaky_allocator(void) {
+  iree_allocator_t allocator = {
+      .alloc = iree_flags_leaky_alloc,
+      .free = iree_flags_leaky_free,
+      .self = NULL,  // the leaky alloc/free routines ignore |self|
+  };
+  return allocator;
+}
+
+//===----------------------------------------------------------------------===//
+// Flag registry
+//===----------------------------------------------------------------------===//
+
+// Storage for registered flags.
+typedef struct {
+  // __FILE__ of flag definition.
+  const char* file;
+  // __LINE__ of flag definition.
+  int line;
+  // Defines what data is at |storage| and how to parse/print it.
+  iree_flag_type_t type;
+  // Registered callback to issue when the flag is parsed, if any.
+  iree_flag_parse_callback_fn_t parse_callback;
+  // Registered callback to issue when the flag is to be printed, if any.
+  iree_flag_print_callback_fn_t print_callback;
+  // Direct reference to the variable storing the flag value of |type|.
+  void* storage;
+  // Name of the flag on the command line ('foo' => '--foo=value').
+  iree_string_view_t name;
+  // Short description string.
+  iree_string_view_t description;
+} iree_flag_t;
+
+// State used for flag registration and reflection.
+typedef struct {
+  // Program name shown in the --help banner; set by iree_flags_set_usage.
+  const char* program_name;
+  // Usage text printed by --help after the banner; set by
+  // iree_flags_set_usage.
+  const char* usage;
+
+  // Total number of entries in the |flags| list.
+  int flag_count;
+  // All registered flags in the executable. Registration order is undefined;
+  // iree_flag_registry_sort reorders the entries by file > line before flags
+  // are parsed or dumped.
+  iree_flag_t flags[IREE_FLAGS_CAPACITY];
+} iree_flag_registry_t;
+
+// Global flags state.
+// This will persist for the lifetime of the program so that flags can be
+// reparsed/dumped. If you're concerned about the .data overhead then you
+// probably just want to disable the CLI support for flags entirely.
+static iree_flag_registry_t iree_flag_registry = {
+    .program_name = NULL,
+    .usage = NULL,
+    .flag_count = 0,
+};
+
+// Records a flag definition in the global registry and returns the index it
+// was stored at.
+// NOTE(review): the returned ordinal is only valid until the next
+// iree_flag_registry_sort (run by iree_flags_parse/iree_flags_dump), which
+// reorders the entries.
+int iree_flag_register(const char* file, int line, iree_flag_type_t type,
+                       void* storage,
+                       iree_flag_parse_callback_fn_t parse_callback,
+                       iree_flag_print_callback_fn_t print_callback,
+                       iree_string_view_t name,
+                       iree_string_view_t description) {
+  // TODO(benvanik): make the registry a linked list and externalize the
+  // flag storage - then no need for a fixed count. If you're hitting this then
+  // file an issue :)
+  iree_flag_registry_t* registry = &iree_flag_registry;
+  IREE_ASSERT_LE(registry->flag_count + 1, IREE_FLAGS_CAPACITY,
+                 "flag registry overflow; too many flags registered");
+  // The assertion above may be compiled out (e.g. in release builds); never
+  // write past the end of the fixed-size |flags| array regardless.
+  if (registry->flag_count >= IREE_FLAGS_CAPACITY) {
+    fprintf(stderr,
+            "%s:%d: flag registry overflow; too many flags registered "
+            "(capacity %d)\n",
+            file, line, IREE_FLAGS_CAPACITY);
+    abort();
+  }
+  int flag_ordinal = registry->flag_count++;
+  iree_flag_t* flag = &registry->flags[flag_ordinal];
+  flag->file = file;
+  flag->line = line;
+  flag->type = type;
+  flag->parse_callback = parse_callback;
+  flag->print_callback = print_callback;
+  flag->storage = storage;
+  flag->name = name;
+  flag->description = description;
+  return flag_ordinal;
+}
+
+// Returns the flag registration with the given |name| or NULL if not found.
+// Performs a linear scan over all registered flags.
+static iree_flag_t* iree_flag_lookup(iree_string_view_t name) {
+  iree_flag_registry_t* registry = &iree_flag_registry;
+  for (int i = 0; i < registry->flag_count; ++i) {
+    iree_flag_t* flag = &registry->flags[i];
+    if (iree_string_view_equal(flag->name, name)) {
+      return flag;
+    }
+  }
+  return NULL;
+}
+
+// qsort comparator ordering flags by defining file (strcmp) then by line.
+static int iree_flag_cmp(const void* lhs_ptr, const void* rhs_ptr) {
+  const iree_flag_t* lhs = (const iree_flag_t*)lhs_ptr;
+  const iree_flag_t* rhs = (const iree_flag_t*)rhs_ptr;
+  int ret = strcmp(lhs->file, rhs->file);
+  if (ret == 0) {
+    return lhs->line - rhs->line;
+  }
+  return ret;
+}
+
+// Sorts the flags in the flag registry by file > line.
+static void iree_flag_registry_sort(iree_flag_registry_t* registry) {
+  qsort(registry->flags, registry->flag_count, sizeof(iree_flag_t),
+        iree_flag_cmp);
+}
+
+//===----------------------------------------------------------------------===//
+// Flag parsing/printing
+//===----------------------------------------------------------------------===//
+
+// Sets the program name and usage text printed by --help.
+void iree_flags_set_usage(const char* program_name, const char* usage) {
+  iree_flag_registry_t* registry = &iree_flag_registry;
+  registry->program_name = program_name;
+  registry->usage = usage;
+}
+
+// Parses a flag value from the given string and stores it.
+// An empty |value| sets bools to true and numeric types to zero. Numeric values
+// are converted with atoi/atoll/atof and are not validated.
+static iree_status_t iree_flag_parse(iree_flag_t* flag,
+                                     iree_string_view_t value) {
+  IREE_TRACE_ZONE_BEGIN(z0);
+  IREE_TRACE_ZONE_APPEND_TEXT(z0, flag->name.data, flag->name.size);
+  IREE_TRACE_ZONE_APPEND_TEXT(z0, value.data, value.size);
+
+  // Insert NUL on the flag value. This is safe as the value is either coming
+  // from C argv memory which is mutable or a flagfile that we loaded into
+  // memory ourselves.
+  char* str_value = (char*)value.data;
+  if (value.size > 0) {
+    str_value[value.size] = 0;
+  }
+
+  iree_status_t status = iree_ok_status();
+  switch (flag->type) {
+    case IREE_FLAG_TYPE_callback:
+      status = flag->parse_callback(flag->name, flag->storage, value);
+      break;
+    case IREE_FLAG_TYPE_bool:
+      if (value.size == 0 || strcmp(str_value, "true") == 0 ||
+          strcmp(str_value, "1") == 0) {
+        *(bool*)flag->storage = true;
+      } else {
+        *(bool*)flag->storage = false;
+      }
+      break;
+    case IREE_FLAG_TYPE_int32_t:
+      *(int32_t*)flag->storage = value.size ? atoi(str_value) : 0;
+      break;
+    case IREE_FLAG_TYPE_int64_t:
+      *(int64_t*)flag->storage = value.size ? atoll(str_value) : 0;
+      break;
+    case IREE_FLAG_TYPE_float:
+      *(float*)flag->storage = value.size ? (float)atof(str_value) : 0.0f;
+      break;
+    case IREE_FLAG_TYPE_double:
+      *(double*)flag->storage = value.size ? atof(str_value) : 0.0;
+      break;
+    case IREE_FLAG_TYPE_string: {
+      // Strip surrounding double quotes: "foo" -> foo and "" -> (empty).
+      // A length of exactly 2 is included so that an explicitly quoted empty
+      // value does not keep its literal quote characters.
+      iree_host_size_t str_length = value.size;
+      if (str_length >= 2 && str_value[0] == '"' &&
+          str_value[str_length - 1] == '"') {
+        str_value[str_length - 1] = 0;
+        ++str_value;
+      }
+      *(const char**)flag->storage = str_value;
+      break;
+    }
+    default:
+      status = iree_make_status(IREE_STATUS_FAILED_PRECONDITION,
+                                "invalid flag type %u", (unsigned)flag->type);
+      break;
+  }
+  IREE_TRACE_ZONE_END(z0);
+  return status;
+}
+
+// Prints a flag value to |file| (like 'true' or '5.43').
+// Callback flags delegate entirely to their print callback; all other flags
+// are printed as a single `--name=value` line.
+static void iree_flag_print(FILE* file, iree_flag_t* flag) {
+  if (flag->type == IREE_FLAG_TYPE_callback) {
+    flag->print_callback(flag->name, flag->storage, file);
+    return;
+  }
+  fprintf(file, "--%.*s", (int)flag->name.size, flag->name.data);
+  if (flag->storage == NULL) {
+    // Still terminate the line so the next dumped flag starts on its own line.
+    fprintf(file, "\n");
+    return;
+  }
+  switch (flag->type) {
+    case IREE_FLAG_TYPE_bool:
+      fprintf(file, "=%s", (*(bool*)flag->storage) ? "true" : "false");
+      break;
+    case IREE_FLAG_TYPE_int32_t:
+      fprintf(file, "=%" PRId32, *(int32_t*)flag->storage);
+      break;
+    case IREE_FLAG_TYPE_int64_t:
+      fprintf(file, "=%" PRId64, *(int64_t*)flag->storage);
+      break;
+    case IREE_FLAG_TYPE_float:
+      fprintf(file, "=%g", *(float*)flag->storage);
+      break;
+    case IREE_FLAG_TYPE_double:
+      fprintf(file, "=%g", *(double*)flag->storage);
+      break;
+    case IREE_FLAG_TYPE_string: {
+      // The stored pointer may be NULL (e.g. a flag defaulted to NULL); passing
+      // NULL to %s is undefined behavior, so print it as an empty string.
+      const char* str_value = *(const char**)flag->storage;
+      fprintf(file, "=\"%s\"", str_value ? str_value : "");
+      break;
+    }
+    default:
+      fprintf(file, "=<INVALID>");
+      break;
+  }
+  fprintf(file, "\n");
+}
+
+// Dumps a flag definition and value to |file|.
+static void iree_flag_dump(iree_flag_dump_mode_t mode, FILE* file,
+                           iree_flag_t* flag) {
+  // In verbose mode every non-empty line of the description is emitted as a
+  // '# ' comment line ahead of the flag itself.
+  const bool verbose = iree_all_bits_set(mode, IREE_FLAG_DUMP_MODE_VERBOSE);
+  if (verbose) {
+    iree_string_view_t remaining = flag->description;
+    while (!iree_string_view_is_empty(remaining)) {
+      iree_string_view_t description_line;
+      iree_string_view_split(remaining, '\n', &description_line, &remaining);
+      if (iree_string_view_is_empty(description_line)) continue;
+      fprintf(file, "# %.*s\n", (int)description_line.size,
+              description_line.data);
+    }
+  }
+  iree_flag_print(file, flag);
+}
+
+// --help parse callback: prints a banner, the usage text (if set) and a
+// verbose dump of every registered flag to stdout. Always succeeds; whether
+// the process exits afterwards is decided by iree_flags_parse.
+static iree_status_t iree_flags_parse_help(iree_string_view_t flag_name,
+                                           void* storage,
+                                           iree_string_view_t value) {
+  const iree_flag_registry_t* registry = &iree_flag_registry;
+  const char* program_name =
+      registry->program_name != NULL ? registry->program_name : "";
+  FILE* out = stdout;
+
+  fprintf(out,
+          "# "
+          "===================================================================="
+          "========\n");
+  fprintf(out, "# 👻 IREE: %s\n", program_name);
+  fprintf(out,
+          "# "
+          "===================================================================="
+          "========\n\n");
+  if (registry->usage != NULL) fprintf(out, "%s\n", registry->usage);
+  iree_flags_dump(IREE_FLAG_DUMP_MODE_VERBOSE, out);
+  fprintf(out, "\n");
+
+  return iree_ok_status();
+}
+// --help print callback, invoked when flags are dumped.
+static void iree_flags_print_help(iree_string_view_t flag_name, void* storage,
+                                  FILE* file) {
+  fprintf(file, "# --%.*s\n", (int)flag_name.size, flag_name.data);
+}
+IREE_FLAG_CALLBACK(iree_flags_parse_help, iree_flags_print_help, NULL, help,
+                   "Displays command line usage information.");
+
+// Removes argument |arg| from the argument list.
+// Shifts argv[arg + 1 .. argc] down by one slot and decrements *argc_ptr.
+// The (argc - arg) element count includes argv[argc], so a trailing NULL
+// terminator (as C guarantees for main()'s argv) is carried along.
+static void iree_flags_remove_arg(int arg, int* argc_ptr, char*** argv_ptr) {
+  int argc = *argc_ptr;
+  char** argv = *argv_ptr;
+  memmove(&argv[arg], &argv[arg + 1], (argc - arg) * sizeof(char*));
+  *argc_ptr = argc - 1;
+}
+
+// Parses `--name[=value]` arguments and removes each consumed flag from argv,
+// compacting the array in place. Positional arguments (and, with
+// IREE_FLAGS_PARSE_MODE_UNDEFINED_OK, unrecognized flags) are left in place.
+// *argc_ptr always reflects the compacted argv, even when an error is returned.
+iree_status_t iree_flags_parse(iree_flags_parse_mode_t mode, int* argc_ptr,
+                               char*** argv_ptr) {
+  if (argc_ptr == NULL || argv_ptr == NULL || *argc_ptr == 0) {
+    // No flags; that's fine - in some environments flags aren't supported.
+    return iree_ok_status();
+  }
+
+  // Always sort the registry; though we may parse flags multiple times this is
+  // not a hot path and this is easier than trying to keep track of whether we
+  // need to or not.
+  iree_flag_registry_sort(&iree_flag_registry);
+
+  int argc = *argc_ptr;
+  char** argv = *argv_ptr;
+
+  iree_status_t status = iree_ok_status();
+  for (int arg_ordinal = 1; arg_ordinal < argc; ++arg_ordinal) {
+    iree_string_view_t arg = iree_make_cstring_view(argv[arg_ordinal]);
+
+    // Strip whitespace.
+    arg = iree_string_view_trim(arg);
+
+    // Position arguments are ignored; they may appear anywhere in the list.
+    if (!iree_string_view_starts_with(arg, iree_make_cstring_view("--"))) {
+      continue;
+    }
+
+    // Strip `--`.
+    arg = iree_string_view_remove_prefix(arg, 2);
+
+    // Split into `flag_name` = `flag_value`.
+    iree_string_view_t flag_name;
+    iree_string_view_t flag_value;
+    iree_string_view_split(arg, '=', &flag_name, &flag_value);
+    flag_name = iree_string_view_trim(flag_name);
+    flag_value = iree_string_view_trim(flag_value);
+
+    // Lookup the flag by name.
+    iree_flag_t* flag = iree_flag_lookup(flag_name);
+    if (!flag) {
+      // If IREE_FLAGS_PARSE_MODE_UNDEFINED_OK allows undefined flags then we
+      // just skip this one. Note that we leave it in the argument list so
+      // that subsequent flag parsers can try to handle it.
+      if (iree_all_bits_set(mode, IREE_FLAGS_PARSE_MODE_UNDEFINED_OK)) {
+        continue;
+      }
+      status = iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                                "flag '%.*s' not recognized",
+                                (int)flag_name.size, flag_name.data);
+      break;
+    }
+
+    // Parse and store the flag value.
+    status = iree_flag_parse(flag, flag_value);
+    if (!iree_status_is_ok(status)) break;
+
+    // --help gets special handling due to interop with external libraries that
+    // may also need to find it. If indicated we keep --help in the argument
+    // list and don't exit.
+    if (iree_string_view_equal(flag_name, iree_make_cstring_view("help"))) {
+      if (iree_all_bits_set(mode, IREE_FLAGS_PARSE_MODE_CONTINUE_AFTER_HELP)) {
+        continue;  // don't remove the arg below
+      }
+      exit(0);  // --help exits by default.
+    }
+
+    // Splice out the flag from the argv list.
+    iree_flags_remove_arg(arg_ordinal, &argc, &argv);
+    --arg_ordinal;
+  }
+
+  // argv has already been compacted in place for every flag consumed so far,
+  // so the matching count must be published on failure too; otherwise the
+  // caller's argc would cover stale/NULL slots at the end of argv.
+  *argc_ptr = argc;
+  return status;
+}
+
+// Parses flags via iree_flags_parse; on failure prints the formatted status to
+// stderr and exits the process with EXIT_FAILURE.
+void iree_flags_parse_checked(iree_flags_parse_mode_t mode, int* argc,
+                              char*** argv) {
+  IREE_TRACE_ZONE_BEGIN(z0);
+  // iree_flags_parse tolerates NULL argc/argv; do the same here.
+  for (int i = 0; argc && argv && i < *argc; ++i) {
+    IREE_TRACE_ZONE_APPEND_TEXT_CSTRING(z0, (*argv)[i]);
+  }
+  iree_status_t status = iree_flags_parse(mode, argc, argv);
+  IREE_TRACE_ZONE_END(z0);
+  if (iree_status_is_ok(status)) return;
+
+  fprintf(stderr, "\x1b[31mFLAGS ERROR: (╯°□°)╯︵👻\x1b[0m\n");
+  char* buffer = NULL;
+  iree_host_size_t buffer_length = 0;
+  if (iree_status_to_string(status, &buffer, &buffer_length)) {
+    fprintf(stderr, "%.*s\n\n", (int)buffer_length, buffer);
+    iree_allocator_free(iree_allocator_system(), buffer);
+  } else {
+    // Formatting failed (e.g. out of memory); still report the status code.
+    fprintf(stderr, "%s\n\n",
+            iree_status_code_string(iree_status_code(status)));
+  }
+  fflush(stderr);
+  iree_status_ignore(status);
+
+  exit(EXIT_FAILURE);
+}
+
+void iree_flags_dump(iree_flag_dump_mode_t mode, FILE* file) {
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  // Always sort the registry; though we may dump flags multiple times this is
+  // not a hot path and this is easier than trying to keep track of whether we
+  // need to or not.
+  iree_flag_registry_sort(&iree_flag_registry);
+
+  const char* last_file = NULL;
+  for (int i = 0; i < iree_flag_registry.flag_count; ++i) {
+    iree_flag_t* flag = &iree_flag_registry.flags[i];
+    if (iree_all_bits_set(mode, IREE_FLAG_DUMP_MODE_VERBOSE)) {
+      if (last_file) {
+        fprintf(file, "\n");
+      }
+      // Emit a section header whenever the defining file changes.
+      if (!last_file || strcmp(flag->file, last_file) != 0) {
+        fprintf(file,
+                "# "
+                "===-----------------------------------------------------------"
+                "-----------===\n");
+        fprintf(file, "# Flags in %s:%d\n", flag->file, flag->line);
+        fprintf(file,
+                "# "
+                "===-----------------------------------------------------------"
+                "-----------===\n\n");
+        last_file = flag->file;
+      }
+    }
+    iree_flag_dump(mode, file, flag);
+  }
+
+  IREE_TRACE_ZONE_END(z0);
+}
+
+//===----------------------------------------------------------------------===//
+// --flagfile= support
+//===----------------------------------------------------------------------===//
+// NOTE: this is conditionally enabled as some platforms may not have IO.
+
+#if IREE_FLAGS_ENABLE_FLAG_FILE == 1
+
+// TODO(benvanik): use this to replace file_io.cc.
+// Reads the entire contents of the file at |path| into a buffer allocated from
+// |allocator|. One extra byte is allocated and set to NUL so the contents can
+// be used as a C string; the returned span length excludes it. On failure
+// |out_contents| is left empty.
+static iree_status_t iree_file_read_contents(const char* path,
+                                             iree_allocator_t allocator,
+                                             iree_byte_span_t* out_contents) {
+  IREE_TRACE_ZONE_BEGIN(z0);
+  *out_contents = iree_make_byte_span(NULL, 0);
+  FILE* file = fopen(path, "rb");
+  if (file == NULL) {
+    IREE_TRACE_ZONE_END(z0);
+    return iree_make_status(iree_status_code_from_errno(errno),
+                            "failed to open file '%s'", path);
+  }
+  iree_status_t status = iree_ok_status();
+  // fseek reports failure with any nonzero value.
+  if (fseek(file, 0, SEEK_END) != 0) {
+    status = iree_make_status(iree_status_code_from_errno(errno), "seek (end)");
+  }
+  // ftell returns a signed long (-1L on failure); keep it signed until checked.
+  long file_size = 0;
+  if (iree_status_is_ok(status)) {
+    file_size = ftell(file);
+    if (file_size == -1L) {
+      status =
+          iree_make_status(iree_status_code_from_errno(errno), "size query");
+    }
+  }
+  if (iree_status_is_ok(status)) {
+    if (fseek(file, 0, SEEK_SET) != 0) {
+      status =
+          iree_make_status(iree_status_code_from_errno(errno), "seek (beg)");
+    }
+  }
+  // Allocate +1 to force a trailing \0 in case this is a string.
+  char* contents = NULL;
+  if (iree_status_is_ok(status)) {
+    status = iree_allocator_malloc(allocator, (iree_host_size_t)file_size + 1,
+                                   (void**)&contents);
+  }
+  // A zero-byte file has nothing to read: fread reports 0 items for it, which
+  // must not be treated as a read failure. A short read need not set errno
+  // (and errno may be stale), so errno is only consulted when the stream error
+  // indicator is set and a short read is never reported as OK.
+  if (iree_status_is_ok(status) && file_size > 0) {
+    if (fread(contents, (size_t)file_size, 1, file) != 1) {
+      iree_status_code_t code = ferror(file)
+                                    ? iree_status_code_from_errno(errno)
+                                    : IREE_STATUS_DATA_LOSS;
+      if (code == IREE_STATUS_OK) code = IREE_STATUS_DATA_LOSS;
+      status = iree_make_status(
+          code, "unable to read entire file contents of '%s'", path);
+    }
+  }
+  if (iree_status_is_ok(status)) {
+    contents[file_size] = 0;  // NUL
+    *out_contents = iree_make_byte_span(contents, (iree_host_size_t)file_size);
+  } else {
+    iree_allocator_free(allocator, contents);
+  }
+  fclose(file);
+  IREE_TRACE_ZONE_END(z0);
+  return status;
+}
+
+// Parses a newline-separated list of flags from a file.
+// Each line is trimmed; blank lines and lines starting with '#' or '//' are
+// skipped and every other line must be a `--name[=value]` flag.
+static iree_status_t iree_flags_parse_file(iree_string_view_t file_path) {
+  // Read file contents.
+  // NOTE: we intentionally leak the contents here so that the flags remain in
+  // memory in case they are referenced.
+  // NOTE: safe to use file_path.data here as it will always have a NUL
+  // terminator.
+  iree_allocator_t allocator = iree_flags_leaky_allocator();
+  iree_byte_span_t file_contents;
+  IREE_RETURN_IF_ERROR(
+      iree_file_read_contents(file_path.data, allocator, &file_contents),
+      "while trying to parse flagfile");
+
+  // Run through the file line-by-line.
+  int line_number = 0;
+  iree_string_view_t contents = iree_make_string_view(
+      (const char*)file_contents.data, file_contents.data_length);
+  while (!iree_string_view_is_empty(contents)) {
+    // Split into a single line and the entire rest of the file contents.
+    iree_string_view_t line;
+    iree_string_view_split(contents, '\n', &line, &contents);
+    ++line_number;
+
+    // Strip whitespace.
+    line = iree_string_view_trim(line);
+    if (iree_string_view_is_empty(line)) continue;
+
+    // Ignore comments.
+    if (iree_string_view_starts_with(line, iree_make_cstring_view("#")) ||
+        iree_string_view_starts_with(line, iree_make_cstring_view("//"))) {
+      continue;
+    }
+
+    // Strip `--`.
+    if (!iree_string_view_starts_with(line, iree_make_cstring_view("--"))) {
+      // Positional arguments can't be specified in flag files.
+      return iree_make_status(
+          IREE_STATUS_INVALID_ARGUMENT,
+          "%.*s:%d: positional arguments not allowed in flag files",
+          (int)file_path.size, file_path.data, line_number);
+    }
+    line = iree_string_view_remove_prefix(line, 2);
+
+    // Split into `flag_name` = `flag_value`.
+    iree_string_view_t flag_name;
+    iree_string_view_t flag_value;
+    iree_string_view_split(line, '=', &flag_name, &flag_value);
+    flag_name = iree_string_view_trim(flag_name);
+    flag_value = iree_string_view_trim(flag_value);
+
+    // Lookup the flag by name.
+    iree_flag_t* flag = iree_flag_lookup(flag_name);
+    if (!flag) {
+      return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                              "%.*s:%d: flag '%.*s' not recognized",
+                              (int)file_path.size, file_path.data, line_number,
+                              (int)flag_name.size, flag_name.data);
+    }
+
+    // Parse the flag value.
+    IREE_RETURN_IF_ERROR(iree_flag_parse(flag, flag_value),
+                         "%.*s:%d: while parsing flag '%.*s'",
+                         (int)file_path.size, file_path.data, line_number,
+                         (int)line.size, line.data);
+  }
+
+  // NOTE: we intentionally leak the memory as flags may continue to reference
+  // segments of it for their string values.
+  return iree_ok_status();
+}
+
+// --flagfile= parse callback: requires a non-empty path and parses that file.
+static iree_status_t iree_flags_parse_flagfile(iree_string_view_t flag_name,
+                                               void* storage,
+                                               iree_string_view_t value) {
+  if (iree_string_view_is_empty(value)) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "--%.*s= requires a file path", (int)flag_name.size,
+                            flag_name.data);
+  }
+
+  IREE_TRACE_ZONE_BEGIN(z0);
+  IREE_TRACE_ZONE_APPEND_TEXT(z0, value.data, value.size);
+  iree_status_t status = iree_flags_parse_file(value);
+  IREE_TRACE_ZONE_END(z0);
+
+  return status;
+}
+// --flagfile= print callback, invoked when flags are dumped.
+static void iree_flags_print_flagfile(iree_string_view_t flag_name,
+                                      void* storage, FILE* file) {
+  fprintf(file, "# --%.*s=[path]\n", (int)flag_name.size, flag_name.data);
+}
+IREE_FLAG_CALLBACK(iree_flags_parse_flagfile, iree_flags_print_flagfile, NULL,
+                   flagfile,
+                   "Parses a newline-separated list of flags from a file.\n"
+                   "Flags are parsed at the point where the flagfile is "
+                   "specified\nand following flags may override the parsed "
+                   "values.");
+
+#endif  // IREE_FLAGS_ENABLE_FLAG_FILE
+
+#endif  // IREE_FLAGS_ENABLE_CLI
diff --git a/iree/base/internal/flags.cc b/iree/base/internal/flags.cc deleted file mode 100644 index 8d0145a..0000000 --- a/iree/base/internal/flags.cc +++ /dev/null
@@ -1,54 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/base/internal/flags.h" - -#include <stdlib.h> -#include <string.h> - -// TODO(#3814): replace abseil with pretty much anything else. -#include "absl/flags/parse.h" - -iree_status_t iree_flags_parse(int* argc, char*** argv) { - if (argc == nullptr || argv == nullptr || *argc == 0) { - // No flags; that's fine - in some environments flags aren't supported. - return iree_ok_status(); - } - - auto positional_args = absl::ParseCommandLine(*argc, *argv); - if (positional_args.size() < *argc) { - // Edit the passed argument refs to only include positional args. - *argc = static_cast<int>(positional_args.size()); - for (int i = 0; i < *argc; ++i) { - (*argv)[i] = positional_args[i]; - } - (*argv)[*argc + 1] = nullptr; - } - - return iree_ok_status(); -} - -void iree_flags_parse_checked(int* argc, char*** argv) { - iree_status_t status = iree_flags_parse(argc, argv); - if (iree_status_is_cancelled(status)) { - exit(EXIT_SUCCESS); - return; - } - if (!iree_status_is_ok(status)) { - // TODO(#2843): replace C++ logging. - iree_status_ignore(status); - exit(EXIT_FAILURE); - return; - } -}
diff --git a/iree/base/internal/flags.h b/iree/base/internal/flags.h index ef2c9da..6f2b68a 100644 --- a/iree/base/internal/flags.h +++ b/iree/base/internal/flags.h
@@ -15,50 +15,289 @@
 #ifndef IREE_BASE_INTERNAL_FLAGS_H_
 #define IREE_BASE_INTERNAL_FLAGS_H_
 
+#include <stdio.h>
+
 #include "iree/base/api.h"
+#include "iree/base/target_platform.h"
 
 #ifdef __cplusplus
 extern "C" {
 #endif  // __cplusplus
 
+// 1 to enable command line parsing from argc/argv; 0 otherwise.
+// When parsing is disabled flags are just variables that can still be queried
+// and manually overridden by code if desired.
+#if !defined(IREE_FLAGS_ENABLE_CLI)
+#define IREE_FLAGS_ENABLE_CLI 1
+#endif  // !IREE_FLAGS_ENABLE_CLI
+
+// 1 to enable --flagfile= support.
+#if !defined(IREE_FLAGS_ENABLE_FLAG_FILE)
+#define IREE_FLAGS_ENABLE_FLAG_FILE 1
+#endif  // !IREE_FLAGS_ENABLE_FLAG_FILE
+
+// Maximum number of flags that can be registered in a single binary.
+#define IREE_FLAGS_CAPACITY 64
+
+//===----------------------------------------------------------------------===//
+// Static initialization utility
+//===----------------------------------------------------------------------===//
+// This declares a static initialization function with the given name.
+// Usage:
+//   IREE_STATIC_INITIALIZER(initializer_name) {
+//     // Do something here! Note that initialization order is undefined and
+//     // what you do should be tolerant to that.
+//
+//     // If you want a finalizer (you probably don't; they may not get run)
+//     // then you can use atexit:
+//     atexit(some_finalizer_fn);
+//   }
+
+#ifdef __cplusplus
+
+#define IREE_STATIC_INITIALIZER(f) \
+  static void f(void);             \
+  struct f##_t_ {                  \
+    f##_t_(void) { f(); }          \
+  };                               \
+  static f##_t_ f##_;              \
+  static void f(void)
+
+#elif defined(IREE_COMPILER_MSVC)
+
+// `__attribute__((constructor))`-like behavior in MSVC. See:
+// https://docs.microsoft.com/en-us/cpp/c-runtime-library/crt-initialization?view=msvc-160
+
+#pragma section(".CRT$XCU", read)
+#define IREE_STATIC_INITIALIZER_IMPL(f, p)                 \
+  static void f(void);                                     \
+  __declspec(allocate(".CRT$XCU")) void (*f##_)(void) = f; \
+  __pragma(comment(linker, "/include:" p #f "_")) static void f(void)
+#ifdef _WIN64
+#define IREE_STATIC_INITIALIZER(f) IREE_STATIC_INITIALIZER_IMPL(f, "")
+#else
+#define IREE_STATIC_INITIALIZER(f) IREE_STATIC_INITIALIZER_IMPL(f, "_")
+#endif  // _WIN64
+
+#else
+
+#define IREE_STATIC_INITIALIZER(f)                   \
+  static void f(void) __attribute__((constructor)); \
+  static void f(void)
+
+#endif  // __cplusplus / MSVC
+
+//===----------------------------------------------------------------------===//
+// Flag definition
+//===----------------------------------------------------------------------===//
+
+typedef enum {
+  IREE_FLAG_DUMP_MODE_DEFAULT = 0,
+  IREE_FLAG_DUMP_MODE_VERBOSE = 1u << 0,
+} iree_flag_dump_mode_t;
+
+// C storage type backing each IREE_FLAG |type|. Defined even when CLI parsing
+// is disabled so that IREE_FLAG always declares a typed FLAG_<name> variable.
+#define IREE_FLAG_CTYPE_bool bool
+#define IREE_FLAG_CTYPE_int32_t int32_t
+#define IREE_FLAG_CTYPE_int64_t int64_t
+#define IREE_FLAG_CTYPE_float float
+#define IREE_FLAG_CTYPE_double double
+#define IREE_FLAG_CTYPE_string const char*
+
+#if IREE_FLAGS_ENABLE_CLI == 1
+
+// Types of flags supported by the parser.
+typedef enum {
+  // Empty/unspecified sentinel.
+  IREE_FLAG_TYPE_none = 0,
+  // Custom parsing callback; see IREE_FLAG_CALLBACK.
+  IREE_FLAG_TYPE_callback = 1,
+  // Boolean flag:
+  //   --foo (set true)
+  //   --foo=true | --foo=false
+  IREE_FLAG_TYPE_bool,
+  // 32-bit integer flag:
+  //   --foo=123
+  IREE_FLAG_TYPE_int32_t,
+  // 64-bit integer flag:
+  //   --foo=123
+  IREE_FLAG_TYPE_int64_t,
+  // 32-bit floating-point flag:
+  //   --foo=1.2
+  IREE_FLAG_TYPE_float,
+  // 64-bit floating-point flag:
+  //   --foo=1.2
+  IREE_FLAG_TYPE_double,
+  // String flag:
+  //   --foo=abc
+  //   --foo="a b c"
+  // Holds a reference to constant string data; assigned values must remain
+  // live for as long as the flag value references them.
+  IREE_FLAG_TYPE_string,
+} iree_flag_type_t;
+
+// Custom callback issued for each time the flag is seen during parsing.
+// The |value| provided will already be trimmed and may be empty. For
+// compatibility with non-IREE APIs there will be a NUL terminator immediately
+// following the flag value in memory such that `value.data` can be used as a
+// C-string.
+typedef iree_status_t(IREE_API_PTR* iree_flag_parse_callback_fn_t)(
+    iree_string_view_t flag_name, void* storage, iree_string_view_t value);
+
+// Custom callback issued for each time the flag is to be printed.
+// The callback should print the flag and its value to |file|.
+// Example: `--my_flag=value\n`
+typedef void(IREE_API_PTR* iree_flag_print_callback_fn_t)(
+    iree_string_view_t flag_name, void* storage, FILE* file);
+
+int iree_flag_register(const char* file, int line, iree_flag_type_t type,
+                       void* storage,
+                       iree_flag_parse_callback_fn_t parse_callback,
+                       iree_flag_print_callback_fn_t print_callback,
+                       iree_string_view_t name, iree_string_view_t description);
+
+// Defines a flag with the given |type| and |name|.
+//
+// Conceptually the flag is just a variable and can be loaded/stored:
+//   IREE_FLAG(bool, foo, true, "hello");
+// =>
+//   static bool FLAG_foo = true;
+//   ...
+//   if (FLAG_foo) do_something();
+//
+// If flag parsing is enabled with IREE_FLAGS_ENABLE_CLI == 1 then the flag
+// value can be specified on the command line with --name:
+//   --foo
+//   --foo=true
+//
+// See iree_flag_type_t for the types supported and how they are parsed.
+#define IREE_FLAG(type, name, default_value, description)                     \
+  static IREE_FLAG_CTYPE_##type FLAG_##name = (default_value);                 \
+  IREE_STATIC_INITIALIZER(iree_flag_register_##name) {                         \
+    iree_flag_register(__FILE__, __LINE__, IREE_FLAG_TYPE_##type,              \
+                       (void*)&(FLAG_##name), /*parse_callback=*/NULL,         \
+                       /*print_callback=*/NULL, iree_make_cstring_view(#name), \
+                       iree_make_cstring_view(description));                   \
+  }
+
+// Defines a flag that issues |parse_callback| for custom parsing and
+// |print_callback| each time the flag is to be printed.
+//
+// Usage:
+//   iree_status_t parse_callback(iree_string_view_t flag_name, void* storage,
+//                                iree_string_view_t value) {
+//     // Parse |value| and store in |storage|, however you want.
+//     // Returning IREE_STATUS_INVALID_ARGUMENT will trigger --help.
+//     int* storage_ptr = (int*)storage;
+//     printf("hello! %d", (*storage_ptr)++);
+//     return iree_ok_status();
+//   }
+//   void print_callback(iree_string_view_t flag_name, void* storage,
+//                       FILE* file) {
+//     // Print the value in |storage|, however you want. For repeated fields
+//     // you can print multiple separated by newlines.
+//     int* storage_ptr = (int*)storage;
+//     fprintf(file, "--say_hello=%d\n", *storage_ptr);
+//   }
+//   int my_storage = 0;
+//   IREE_FLAG_CALLBACK(parse_callback, print_callback, &my_storage,
+//                      say_hello, "Say hello!");
+#define IREE_FLAG_CALLBACK(parse_callback, print_callback, storage, name, \
+                           description)                                    \
+  IREE_STATIC_INITIALIZER(iree_flag_register_##name) {                     \
+    iree_flag_register(__FILE__, __LINE__, IREE_FLAG_TYPE_callback,        \
+                       (void*)(storage), parse_callback, print_callback,   \
+                       iree_make_cstring_view(#name),                      \
+                       iree_make_cstring_view(description));               \
+  }
+
+#else
+
+// CLI parsing disabled: flags are plain FLAG_<name> variables that code may
+// still read or assign; callbacks are never registered.
+#define IREE_FLAG(type, name, default_value, description) \
+  static IREE_FLAG_CTYPE_##type FLAG_##name = (default_value);
+
+#define IREE_FLAG_CALLBACK(parse_callback, print_callback, storage, name, \
+                           description)
+
+#endif  // IREE_FLAGS_ENABLE_CLI
+
 //===----------------------------------------------------------------------===//
 // Flag parsing
 //===----------------------------------------------------------------------===//
 
+// Controls how flag parsing is performed.
+enum iree_flags_parse_mode_e {
+  IREE_FLAGS_PARSE_MODE_DEFAULT = 0,
+  // Do not error out on undefined flags; leave them in the list.
+  // Useful when needing to chain multiple flag parsers together.
+  IREE_FLAGS_PARSE_MODE_UNDEFINED_OK = 1u << 0,
+  // Continues parsing and returns success without exiting when `--help` is
+  // encountered. This allows for IREE flag parsing to happen before another
+  // external library parses its flags. `--help` will remain in the flag set
+  // such that the subsequent parsing can find it.
+  IREE_FLAGS_PARSE_MODE_CONTINUE_AFTER_HELP = 1u << 1,
+};
+typedef uint32_t iree_flags_parse_mode_t;
+
+#if IREE_FLAGS_ENABLE_CLI == 1
+
+// Sets the usage information printed when --help is passed on the command line.
+// Both strings must remain live for the lifetime of the program.
+void iree_flags_set_usage(const char* program_name, const char* usage);
+
 // Parses flags from the given command line arguments.
 // All flag-style arguments ('--foo', '-f', etc) will be consumed and argc/argv
 // will be updated to contain only the program name (index 0) and any remaining
 // positional arguments.
 //
-// Returns success if all flags were parsed and execution should continue.
-// May return IREE_STATUS_CANCELLED if execution should be cancelled gracefully
-// such as when --help is used.
+// Returns OK if all flags were parsed and execution should continue, or a
+// non-OK status if parsing fails. For --help handling see
+// IREE_FLAGS_PARSE_MODE_CONTINUE_AFTER_HELP.
 //
 // Usage:
 //   extern "C" int main(int argc, char** argv) {
-//     iree_status_t status = iree_flags_parse(&argc, &argv);
-//     if (iree_status_is_cancelled(status)) return 0;
-//     if (!iree_status_is_ok(status)) {
-//       // TODO(#2843): replace C++ logging.
-//       LOG(ERROR) << status;
-//       iree_status_ignore(status);
-//       return 1;
-//     }
+//     iree_status_t status =
+//         iree_flags_parse(IREE_FLAGS_PARSE_MODE_DEFAULT, &argc, &argv);
+//     if (!iree_status_is_ok(status)) { exit(1); }
 //     consume_positional_args(argc, argv);
 //     return 0;
 //   }
 //
 // Example:
-//   argc = 4, argv = ['program', 'abc', '--flag=2', '-p']
+//   argc = 3, argv = ['program', 'abc', '--flag=2']
 // Results:
 //   argc = 2, argv = ['program', 'abc']
-iree_status_t iree_flags_parse(int* argc, char*** argv);
+iree_status_t iree_flags_parse(iree_flags_parse_mode_t mode, int* argc,
+                               char*** argv);
 
 // Parses flags as with iree_flags_parse but will use exit() or abort().
 // WARNING: this almost always what you want in a command line tool and *never*
 // what you want when embedded in a host process. You don't want to have a flag
 // typo and shut down your entire server/sandbox/Android app/etc.
-void iree_flags_parse_checked(int* argc, char*** argv);
+void iree_flags_parse_checked(iree_flags_parse_mode_t mode, int* argc,
+                              char*** argv);
+
+// Dumps all flags and their current values to the given |file|.
+void iree_flags_dump(iree_flag_dump_mode_t mode, FILE* file);
+
+#else
+
+// CLI parsing disabled: parsing is a no-op that always succeeds. The stubs are
+// `static inline` so they have internal linkage in both C and C++.
+static inline void iree_flags_set_usage(const char* program_name,
+                                        const char* usage) {}
+static inline iree_status_t iree_flags_parse(iree_flags_parse_mode_t mode,
+                                             int* argc, char*** argv) {
+  return iree_ok_status();
+}
+static inline void iree_flags_parse_checked(iree_flags_parse_mode_t mode,
+                                            int* argc, char*** argv) {}
+static inline void iree_flags_dump(iree_flag_dump_mode_t mode, FILE* file) {}
+
+#endif  // IREE_FLAGS_ENABLE_CLI
 
 #ifdef __cplusplus
 }  // extern "C"
diff --git a/iree/base/internal/flags_demo.c b/iree/base/internal/flags_demo.c new file mode 100644 index 0000000..34cd671 --- /dev/null +++ b/iree/base/internal/flags_demo.c
@@ -0,0 +1,74 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Demo of iree/base/internal/flags.h; its output is checked by flags_test.txt.
+
+#include <inttypes.h>
+#include <stdint.h>
+#include <stdio.h>
+#include <stdlib.h>  // atoi
+#include <string.h>  // strcmp
+
+#include "iree/base/internal/flags.h"
+
+IREE_FLAG(bool, test_bool, false, "A boolean value.");
+IREE_FLAG(int32_t, test_int32, 123, "An int32_t value.");
+IREE_FLAG(int64_t, test_int64, 555, "An int64_t value.");
+IREE_FLAG(float, test_float, 1.0f, "A float value.");
+IREE_FLAG(string, test_string, "some default", "A string\nvalue.");
+
+// Adds the integer in |value| to the int at |storage| (so repeated flags
+// accumulate); the value FORCE_FAILURE returns an error instead.
+static iree_status_t parse_callback(iree_string_view_t flag_name, void* storage,
+                                    iree_string_view_t value) {
+  int* count_ptr = (int*)storage;
+  if (strcmp(value.data, "FORCE_FAILURE") == 0) {
+    return iree_make_status(IREE_STATUS_INTERNAL,
+                            "callbacks can do verification");
+  }
+  *count_ptr += atoi(value.data);
+  return iree_ok_status();
+}
+// Prints the accumulated count as `--<flag_name>=<count>`.
+static void print_callback(iree_string_view_t flag_name, void* storage,
+                           FILE* file) {
+  int* count_ptr = (int*)storage;
+  fprintf(file, "--%.*s=%d\n", (int)flag_name.size, flag_name.data, *count_ptr);
+}
+static int callback_count = 0;
+IREE_FLAG_CALLBACK(parse_callback, print_callback, &callback_count,
+                   test_callback, "Callback!");
+
+int main(int argc, char** argv) {
+  // Parse flags, updating argc/argv with positional arguments.
+  iree_flags_parse_checked(IREE_FLAGS_PARSE_MODE_DEFAULT, &argc, &argv);
+
+  // Report parsed flag values:
+  printf("FLAG[test_bool] = %s\n", FLAG_test_bool ? "true" : "false");
+  printf("FLAG[test_int32] = %" PRId32 "\n", FLAG_test_int32);
+  printf("FLAG[test_int64] = %" PRId64 "\n", FLAG_test_int64);
+  printf("FLAG[test_float] = %g\n", FLAG_test_float);
+  printf("FLAG[test_string] = %s\n", FLAG_test_string);
+  printf("FLAG[test_callback] = %d\n", callback_count);
+
+  // Report positional arguments:
+  for (int i = 0; i < argc; ++i) {
+    printf("ARG(%d) = %s\n", i, argv[i]);
+  }
+
+  // Dump all flags back out for round-tripping:
+  iree_flags_dump(IREE_FLAG_DUMP_MODE_DEFAULT, stdout);
+
+  return 0;
+}
diff --git a/iree/base/internal/flags_test.txt b/iree/base/internal/flags_test.txt new file mode 100644 index 0000000..2fca1be --- /dev/null +++ b/iree/base/internal/flags_test.txt
@@ -0,0 +1,94 @@ +// RUN: ( flags_demo ) | IreeFileCheck --check-prefix=NO-FLAGS %s +// NO-FLAGS: FLAG[test_bool] = false +// NO-FLAGS: FLAG[test_int32] = 123 +// NO-FLAGS: FLAG[test_int64] = 555 +// NO-FLAGS: FLAG[test_float] = 1 +// NO-FLAGS: FLAG[test_string] = some default +// NO-FLAGS: FLAG[test_callback] = 0 +// NO-FLAGS: ARG(0) ={{.+}}flags_demo + +// RUN: ( flags_demo --help ) | IreeFileCheck --check-prefix=FLAGS-HELP %s +// FLAGS-HELP: # {{.+}} IREE +// FLAGS-HELP: # Flags in {{.+}}flags.c +// FLAGS-HELP: # Displays command line usage information. +// FLAGS-HELP: --help +// FLAGS-HELP: # Flags in {{.+}}flags_demo.c +// FLAGS-HELP: # A boolean value. +// FLAGS-HELP: --test_bool=false +// FLAGS-HELP: # An int32_t value. +// FLAGS-HELP: --test_int32=123 +// FLAGS-HELP: # An int64_t value. +// FLAGS-HELP: --test_int64=555 +// FLAGS-HELP: # A float value. +// FLAGS-HELP: --test_float=1 +// FLAGS-HELP: # A string +// FLAGS-HELP: # value. +// FLAGS-HELP: --test_string="some default" +// FLAGS-HELP: # Callback! +// FLAGS-HELP: --test_callback=0 + +// RUN: ( flags_demo --unknown-flag 2>&1 || [[ $? 
== 1 ]] ) | IreeFileCheck --check-prefix=UNKNOWN-FLAG %s +// UNKNOWN-FLAG: INVALID_ARGUMENT; flag 'unknown-flag' not recognized + +// RUN: ( flags_demo --test_bool=true ) | IreeFileCheck --check-prefix=FLAG-BOOL-TRUE %s +// FLAG-BOOL-TRUE: FLAG[test_bool] = true +// RUN: ( flags_demo --test_bool=1 ) | IreeFileCheck --check-prefix=FLAG-BOOL-1 %s +// FLAG-BOOL-1: FLAG[test_bool] = true +// RUN: ( flags_demo --test_bool=true --test_bool=false ) | IreeFileCheck --check-prefix=FLAG-BOOL-OVERRIDE %s +// FLAG-BOOL-OVERRIDE: FLAG[test_bool] = false + +// RUN: ( flags_demo --test_int32=456 ) | IreeFileCheck --check-prefix=FLAG-INT32 %s +// FLAG-INT32: FLAG[test_int32] = 456 +// RUN: ( flags_demo --test_int32=-2147483648 ) | IreeFileCheck --check-prefix=FLAG-INT32-MIN %s +// FLAG-INT32-MIN: FLAG[test_int32] = -2147483648 +// RUN: ( flags_demo --test_int32=2147483647 ) | IreeFileCheck --check-prefix=FLAG-INT32-MAX %s +// FLAG-INT32-MAX: FLAG[test_int32] = 2147483647 + +// RUN: ( flags_demo --test_int64=902834 ) | IreeFileCheck --check-prefix=FLAG-INT64 %s +// FLAG-INT64: FLAG[test_int64] = 902834 +// RUN: ( flags_demo --test_int64=-9223372036854775808 ) | IreeFileCheck --check-prefix=FLAG-INT64-MIN %s +// FLAG-INT64-MIN: FLAG[test_int64] = -9223372036854775808 +// RUN: ( flags_demo --test_int64=9223372036854775807 ) | IreeFileCheck --check-prefix=FLAG-INT64-MAX %s +// FLAG-INT64-MAX: FLAG[test_int64] = 9223372036854775807 + +// RUN: ( flags_demo --test_float=1.1234 ) | IreeFileCheck --check-prefix=FLAG-FLOAT %s +// FLAG-FLOAT: FLAG[test_float] = 1.1234 + +// RUN: ( flags_demo --test_string= ) | IreeFileCheck --check-prefix=FLAG-STRING-EMPTY %s +// FLAG-STRING-EMPTY: FLAG[test_string] = +// RUN: ( flags_demo --test_string=abc ) | IreeFileCheck --check-prefix=FLAG-STRING-ABC %s +// FLAG-STRING-ABC: FLAG[test_string] = abc +// RUN: ( flags_demo --test_string="with some space" ) | IreeFileCheck --check-prefix=FLAG-STRING-SPACES %s +// FLAG-STRING-SPACES: FLAG[test_string] = with 
some space + +// RUN: ( flags_demo --test_callback=1 ) | IreeFileCheck --check-prefix=FLAG-CALLBACK-1 %s +// FLAG-CALLBACK-1: FLAG[test_callback] = 1 +// RUN: ( flags_demo --test_callback=4 ) | IreeFileCheck --check-prefix=FLAG-CALLBACK-4 %s +// FLAG-CALLBACK-4: FLAG[test_callback] = 4 +// RUN: ( flags_demo --test_callback=FORCE_FAILURE 2>&1 || [[ $? == 1 ]] ) | IreeFileCheck --check-prefix=FLAG-CALLBACK-ERROR %s +// FLAG-CALLBACK-ERROR: INTERNAL; callbacks can do verification + +// RUN: ( flags_demo arg1 ) | IreeFileCheck --check-prefix=FLAG-POSITIONAL-1 %s +// FLAG-POSITIONAL-1: ARG(1) = arg1 +// RUN: ( flags_demo arg1 arg2 arg3 ) | IreeFileCheck --check-prefix=FLAG-POSITIONAL-3 %s +// FLAG-POSITIONAL-3: ARG(1) = arg1 +// FLAG-POSITIONAL-3: ARG(2) = arg2 +// FLAG-POSITIONAL-3: ARG(3) = arg3 + +// RUN: ( flags_demo --test_bool=true --flagfile=not_found.txt 2>&1 || [[ $? == 1 ]] ) | IreeFileCheck --check-prefix=MISSING-FLAGFILE %s +// MISSING-FLAGFILE: NOT_FOUND; failed to open file 'not_found.txt' + +// RUN: ( flags_demo --test_bool=true --flagfile=%s ) | IreeFileCheck --check-prefix=FLAGFILE %s +# Comments are ignored. +// FLAGFILE: FLAG[test_bool] = false +--test_bool=false +// FLAGFILE: FLAG[test_int64] = 123111 +// Note that whitespace is ignored in case you are copy/pasting flags around. + --test_int64=123111 +// FLAGFILE: FLAG[test_float] = 55.1 +--test_float=55.1 +// FLAGFILE: FLAG[test_string] = override spaces +--test_string="override spaces" + + +# NOTE: above two lines are to test that vertical whitespace is ok.
diff --git a/iree/base/logging.cc b/iree/base/logging.cc index a8ecd73..eaacb99 100644 --- a/iree/base/logging.cc +++ b/iree/base/logging.cc
@@ -20,13 +20,13 @@
 #include <android/log.h>
 #endif
 
-#include "absl/flags/flag.h"
 #include "absl/strings/str_format.h"
+#include "iree/base/internal/flags.h"  // IREE_FLAG replaces ABSL_FLAG.
 #include "iree/base/tracing.h"
 
-ABSL_FLAG(int, iree_minloglevel, 0,
+IREE_FLAG(int32_t, iree_minloglevel, 0,  // Read below as FLAG_iree_minloglevel.
           "Minimum logging level. 0 = INFO and above.");
-ABSL_FLAG(int, iree_v, 0,
+IREE_FLAG(int32_t, iree_v, 0,  // Read below as FLAG_iree_v.
           "Verbosity level maximum. 1 = IREE_VLOG(0-1), 2 = IREE_VLOG(0-2).");
 
 namespace iree {
@@ -60,7 +60,7 @@
   if (LogLevelStrToInt(iree_env_var_val, &level)) {
     return level;
   }
-  return absl::GetFlag(FLAGS_iree_minloglevel);
+  return FLAG_iree_minloglevel;  // LogLevelStrToInt failed: use the flag.
 }
 
 int64_t MinVLogLevelFromEnv() {
@@ -69,7 +69,7 @@
   if (LogLevelStrToInt(iree_env_var_val, &level)) {
     return level;
   }
-  return absl::GetFlag(FLAGS_iree_v);
+  return FLAG_iree_v;  // LogLevelStrToInt failed: use the flag.
 }
 
 }  // namespace
diff --git a/iree/base/status.c b/iree/base/status.c new file mode 100644 index 0000000..fb58922 --- /dev/null +++ b/iree/base/status.c
@@ -0,0 +1,786 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <assert.h> +#include <errno.h> +#include <limits.h> +#include <stdarg.h> +#include <stdio.h> +#include <stdlib.h> +#include <string.h> + +#include "iree/base/api.h" +#include "iree/base/target_platform.h" +#include "iree/base/tracing.h" + +#if defined(IREE_PLATFORM_WINDOWS) +// https://docs.microsoft.com/en-us/cpp/c-runtime-library/reference/aligned-malloc +#define iree_aligned_alloc(alignment, size) _aligned_malloc(size, alignment) +#define iree_aligned_free(p) _aligned_free(p) +#elif defined(_ISOC11_SOURCE) +// https://en.cppreference.com/w/c/memory/aligned_alloc +#define iree_aligned_alloc(alignment, size) aligned_alloc(alignment, size) +#define iree_aligned_free(p) free(p) +#elif _POSIX_C_SOURCE >= 200112L +// https://pubs.opengroup.org/onlinepubs/9699919799/functions/posix_memalign.html +static inline void* iree_aligned_alloc(size_t alignment, size_t size) { + void* ptr = NULL; + return posix_memalign(&ptr, alignment, size) == 0 ? ptr : NULL; +} +#define iree_aligned_free(p) free(p) +#else +// Emulates alignment with normal malloc. We overallocate by at least the +// alignment + the size of a pointer, store the base pointer at p[-1], and +// return the aligned pointer. This lets us easily get the base pointer in free +// to pass back to the system. 
+static inline void* iree_aligned_alloc(size_t alignment, size_t size) { + void* base_ptr = malloc(size + alignment + sizeof(uintptr_t)); + if (!base_ptr) return NULL; + uintptr_t* aligned_ptr = (uintptr_t*)iree_math_align( + (uintptr_t)base_ptr + sizeof(uintptr_t), alignment); + aligned_ptr[-1] = (uintptr_t)base_ptr; + return aligned_ptr; +} +static inline void iree_aligned_free(void* p) { + if (IREE_UNLIKELY(!p)) return; + uintptr_t* aligned_ptr = (uintptr_t*)p; + void* base_ptr = (void*)aligned_ptr[-1]; + free(base_ptr); +} +#endif // IREE_PLATFORM_WINDOWS + +//===----------------------------------------------------------------------===// +// iree_status_t canonical errors +//===----------------------------------------------------------------------===// + +IREE_API_EXPORT iree_status_code_t +iree_status_code_from_errno(int error_number) { + switch (error_number) { + case 0: + return IREE_STATUS_OK; + case EINVAL: // Invalid argument + case ENAMETOOLONG: // Filename too long + case E2BIG: // Argument list too long + case EDESTADDRREQ: // Destination address required + case EDOM: // Mathematics argument out of domain of function + case EFAULT: // Bad address + case EILSEQ: // Illegal byte sequence + case ENOPROTOOPT: // Protocol not available + case ENOSTR: // Not a STREAM + case ENOTSOCK: // Not a socket + case ENOTTY: // Inappropriate I/O control operation + case EPROTOTYPE: // Protocol wrong type for socket + case ESPIPE: // Invalid seek + return IREE_STATUS_INVALID_ARGUMENT; + case ETIMEDOUT: // Connection timed out + case ETIME: // Timer expired + return IREE_STATUS_DEADLINE_EXCEEDED; + case ENODEV: // No such device + case ENOENT: // No such file or directory +#ifdef ENOMEDIUM + case ENOMEDIUM: // No medium found +#endif + case ENXIO: // No such device or address + case ESRCH: // No such process + return IREE_STATUS_NOT_FOUND; + case EEXIST: // File exists + case EADDRNOTAVAIL: // Address not available + case EALREADY: // Connection already in progress 
+#ifdef ENOTUNIQ + case ENOTUNIQ: // Name not unique on network +#endif + return IREE_STATUS_ALREADY_EXISTS; + case EPERM: // Operation not permitted + case EACCES: // Permission denied +#ifdef ENOKEY + case ENOKEY: // Required key not available +#endif + case EROFS: // Read only file system + return IREE_STATUS_PERMISSION_DENIED; + case ENOTEMPTY: // Directory not empty + case EISDIR: // Is a directory + case ENOTDIR: // Not a directory + case EADDRINUSE: // Address already in use + case EBADF: // Invalid file descriptor +#ifdef EBADFD + case EBADFD: // File descriptor in bad state +#endif + case EBUSY: // Device or resource busy + case ECHILD: // No child processes + case EISCONN: // Socket is connected +#ifdef EISNAM + case EISNAM: // Is a named type file +#endif +#ifdef ENOTBLK + case ENOTBLK: // Block device required +#endif + case ENOTCONN: // The socket is not connected + case EPIPE: // Broken pipe +#ifdef ESHUTDOWN + case ESHUTDOWN: // Cannot send after transport endpoint shutdown +#endif + case ETXTBSY: // Text file busy +#ifdef EUNATCH + case EUNATCH: // Protocol driver not attached +#endif + return IREE_STATUS_FAILED_PRECONDITION; + case ENOSPC: // No space left on device +#ifdef EDQUOT + case EDQUOT: // Disk quota exceeded +#endif + case EMFILE: // Too many open files + case EMLINK: // Too many links + case ENFILE: // Too many open files in system + case ENOBUFS: // No buffer space available + case ENODATA: // No message is available on the STREAM read queue + case ENOMEM: // Not enough space + case ENOSR: // No STREAM resources +#ifdef EUSERS + case EUSERS: // Too many users +#endif + return IREE_STATUS_RESOURCE_EXHAUSTED; +#ifdef ECHRNG + case ECHRNG: // Channel number out of range +#endif + case EFBIG: // File too large + case EOVERFLOW: // Value too large to be stored in data type + case ERANGE: // Result too large + return IREE_STATUS_OUT_OF_RANGE; +#ifdef ENOPKG + case ENOPKG: // Package not installed +#endif + case ENOSYS: // Function not 
implemented + case ENOTSUP: // Operation not supported + case EAFNOSUPPORT: // Address family not supported +#ifdef EPFNOSUPPORT + case EPFNOSUPPORT: // Protocol family not supported +#endif + case EPROTONOSUPPORT: // Protocol not supported +#ifdef ESOCKTNOSUPPORT + case ESOCKTNOSUPPORT: // Socket type not supported +#endif + case EXDEV: // Improper link + return IREE_STATUS_UNIMPLEMENTED; + case EAGAIN: // Resource temporarily unavailable +#ifdef ECOMM + case ECOMM: // Communication error on send +#endif + case ECONNREFUSED: // Connection refused + case ECONNABORTED: // Connection aborted + case ECONNRESET: // Connection reset + case EINTR: // Interrupted function call +#ifdef EHOSTDOWN + case EHOSTDOWN: // Host is down +#endif + case EHOSTUNREACH: // Host is unreachable + case ENETDOWN: // Network is down + case ENETRESET: // Connection aborted by network + case ENETUNREACH: // Network unreachable + case ENOLCK: // No locks available + case ENOLINK: // Link has been severed +#ifdef ENONET + case ENONET: // Machine is not on the network +#endif + return IREE_STATUS_UNAVAILABLE; + case EDEADLK: // Resource deadlock avoided +#ifdef ESTALE + case ESTALE: // Stale file handle +#endif + return IREE_STATUS_ABORTED; + case ECANCELED: // Operation cancelled + return IREE_STATUS_CANCELLED; + default: + return IREE_STATUS_UNKNOWN; + } +} + +#if defined(IREE_PLATFORM_WINDOWS) +IREE_API_EXPORT iree_status_code_t +iree_status_code_from_win32_error(uint32_t error) { + switch (error) { + case ERROR_SUCCESS: + return IREE_STATUS_OK; + case ERROR_FILE_NOT_FOUND: + case ERROR_PATH_NOT_FOUND: + return IREE_STATUS_NOT_FOUND; + case ERROR_TOO_MANY_OPEN_FILES: + case ERROR_OUTOFMEMORY: + case ERROR_HANDLE_DISK_FULL: + case ERROR_HANDLE_EOF: + return IREE_STATUS_RESOURCE_EXHAUSTED; + case ERROR_ACCESS_DENIED: + return IREE_STATUS_PERMISSION_DENIED; + case ERROR_INVALID_HANDLE: + return IREE_STATUS_INVALID_ARGUMENT; + case ERROR_NOT_READY: + case ERROR_READ_FAULT: + return 
IREE_STATUS_UNAVAILABLE; + case ERROR_WRITE_FAULT: + return IREE_STATUS_DATA_LOSS; + case ERROR_NOT_SUPPORTED: + return IREE_STATUS_UNIMPLEMENTED; + default: + return IREE_STATUS_UNKNOWN; + } +} +#endif // IREE_PLATFORM_WINDOWS + +//===----------------------------------------------------------------------===// +// iree_status_t +//===----------------------------------------------------------------------===// + +IREE_API_EXPORT const char* IREE_API_CALL +iree_status_code_string(iree_status_code_t code) { + switch (code) { + case IREE_STATUS_OK: + return "OK"; + case IREE_STATUS_CANCELLED: + return "CANCELLED"; + case IREE_STATUS_UNKNOWN: + return "UNKNOWN"; + case IREE_STATUS_INVALID_ARGUMENT: + return "INVALID_ARGUMENT"; + case IREE_STATUS_DEADLINE_EXCEEDED: + return "DEADLINE_EXCEEDED"; + case IREE_STATUS_NOT_FOUND: + return "NOT_FOUND"; + case IREE_STATUS_ALREADY_EXISTS: + return "ALREADY_EXISTS"; + case IREE_STATUS_PERMISSION_DENIED: + return "PERMISSION_DENIED"; + case IREE_STATUS_UNAUTHENTICATED: + return "UNAUTHENTICATED"; + case IREE_STATUS_RESOURCE_EXHAUSTED: + return "RESOURCE_EXHAUSTED"; + case IREE_STATUS_FAILED_PRECONDITION: + return "FAILED_PRECONDITION"; + case IREE_STATUS_ABORTED: + return "ABORTED"; + case IREE_STATUS_OUT_OF_RANGE: + return "OUT_OF_RANGE"; + case IREE_STATUS_UNIMPLEMENTED: + return "UNIMPLEMENTED"; + case IREE_STATUS_INTERNAL: + return "INTERNAL"; + case IREE_STATUS_UNAVAILABLE: + return "UNAVAILABLE"; + case IREE_STATUS_DATA_LOSS: + return "DATA_LOSS"; + default: + return ""; + } +} + +// TODO(#55): move payload methods/types to header when API is stabilized. + +// Defines the type of an iree_status_payload_t. +typedef enum { + // Opaque; payload may still be formatted by a formatter but is not possible + // to retrieve by the programmatic APIs. + IREE_STATUS_PAYLOAD_TYPE_OPAQUE = 0u, + // A string message annotation of type iree_status_payload_message_t. 
+  IREE_STATUS_PAYLOAD_TYPE_MESSAGE = 1u,
+  // Starting type ID for user payloads. IREE reserves all payloads with types
+  // less than this.
+  IREE_STATUS_PAYLOAD_TYPE_MIN_USER = 0x70000000u,
+} iree_status_payload_type_t;
+
+typedef struct iree_status_payload_s iree_status_payload_t;
+
+// Function that formats a payload into a human-readable string form for logs.
+// Formatters use snprintf-like semantics so iree_status_format can detect
+// truncation: |out_buffer_length| receives the full length of the formatted
+// payload excluding NUL (even when it does not fit) and, when |buffer| is
+// non-NULL, at most |buffer_capacity| - 1 characters plus a NUL terminator are
+// written. When |buffer| is NULL only the required length is computed.
+typedef void(IREE_API_PTR* iree_status_payload_formatter_t)(
+    const iree_status_payload_t* payload, iree_host_size_t buffer_capacity,
+    char* buffer, iree_host_size_t* out_buffer_length);
+
+// Header for optional status payloads.
+// Each status may have zero or more payloads associated with it that can later
+// be used to produce more detailed logging or programmatically query
+// information about an error.
+struct iree_status_payload_s {
+  // Next payload in the status payload linked list.
+  struct iree_status_payload_s* next;
+  // Payload type identifier used for programmatic access to payloads. May be
+  // IREE_STATUS_PAYLOAD_TYPE_OPAQUE if the payload cannot be accessed directly.
+  iree_status_payload_type_t type;
+  // Allocator used for the payload and associated resources; iree_status_free
+  // releases the payload through it.
+  iree_allocator_t allocator;
+  // String formatter callback used to write the payload into a string buffer.
+  // If NULL the payload is skipped when the status is formatted.
+  iree_status_payload_formatter_t formatter;
+};
+
+// A string message (IREE_STATUS_PAYLOAD_TYPE_MESSAGE).
+typedef struct {
+  iree_status_payload_t header;
+  // String data reference. May point to an address immediately following this
+  // struct (if copied) or a constant string reference in rodata.
+  iree_string_view_t message;
+} iree_status_payload_message_t;
+
+// Allocated storage for an iree_status_t.
+// Only statuses that have either source information or payloads will have
+// storage allocated for them.
+typedef struct {
+  // Optional singly-linked list of payloads associated with the status.
+  // Head = first added, tail = last added (kept so appends are O(1)).
+  iree_status_payload_t* payload_head;
+  iree_status_payload_t* payload_tail;
+
+#if (IREE_STATUS_FEATURES & IREE_STATUS_FEATURE_SOURCE_LOCATION) != 0
+  // __FILE__ of the originating status allocation.
+  const char* file;
+  // __LINE__ of the originating status allocation.
+  uint32_t line;
+#endif  // has IREE_STATUS_FEATURE_SOURCE_LOCATION
+
+#if (IREE_STATUS_FEATURES & IREE_STATUS_FEATURE_ANNOTATIONS) != 0
+  // Optional message that is allocated either as a constant string in rodata or
+  // present as a suffix on the storage.
+  iree_string_view_t message;
+#endif  // has IREE_STATUS_FEATURE_ANNOTATIONS
+} iree_status_storage_t;
+
+// Extracts the storage pointer packed into |status| by masking off the code
+// bits; yields NULL when |status| carries only a code (no storage).
+#define iree_status_storage(status) \
+  ((iree_status_storage_t*)(((uintptr_t)(status) & ~IREE_STATUS_CODE_MASK)))
+
+// Appends |payload| to the tail of the storage's singly-linked payload list.
+// |payload|->next must already be NULL. Returns |status| unchanged so callers
+// can tail-return it.
+static iree_status_t iree_status_append_payload(
+    iree_status_t status, iree_status_storage_t* storage,
+    iree_status_payload_t* payload) {
+  if (!storage->payload_tail) {
+    storage->payload_head = payload;
+  } else {
+    storage->payload_tail->next = payload;
+  }
+  storage->payload_tail = payload;
+  return status;
+}
+
+// Formats an iree_status_payload_message_t into |buffer| using snprintf-like
+// semantics: |out_buffer_length| is always set to the full message length
+// (excluding NUL) so callers can detect truncation, and when |buffer| is
+// non-NULL at most |buffer_capacity| - 1 characters are copied followed by a
+// NUL terminator. If |buffer| is NULL only the length is reported.
+static void IREE_API_CALL iree_status_payload_message_formatter(
+    const iree_status_payload_t* payload, iree_host_size_t buffer_capacity,
+    char* buffer, iree_host_size_t* out_buffer_length) {
+  const iree_status_payload_message_t* message_payload =
+      (const iree_status_payload_message_t*)payload;
+  *out_buffer_length = message_payload->message.size;
+  if (!buffer || buffer_capacity == 0) return;
+  // Reserve one byte for the NUL terminator so the write stays in bounds even
+  // when the message fills (or exceeds) the buffer.
+  iree_host_size_t n = message_payload->message.size < buffer_capacity - 1
+                           ? message_payload->message.size
+                           : buffer_capacity - 1;
+  if (n > 0) memcpy(buffer, message_payload->message.data, n);
+  buffer[n] = '\0';
+}
+
+// Captures the current stack and attaches it to the status storage.
+// A count of |skip_frames| will be skipped from the top of the stack.
+// Setting |skip_frames|=0 will include the caller in the stack while
+// |skip_frames|=1 will exclude it. Currently a no-op (see TODO).
+static void iree_status_attach_stack_trace(iree_status_storage_t* storage,
+                                           int skip_frames) {
+  (void)storage;
+  (void)skip_frames;
+#if (IREE_STATUS_FEATURES & IREE_STATUS_FEATURE_STACK_TRACE) != 0
+  // TODO(#55): backtrace or other magic.
+#endif  // has IREE_STATUS_FEATURE_STACK_TRACE
+}
+
+IREE_API_EXPORT IREE_MUST_USE_RESULT iree_status_t IREE_API_CALL
+iree_status_allocate(iree_status_code_t code, const char* file, uint32_t line,
+                     iree_string_view_t message) {
+#if IREE_STATUS_FEATURES == 0
+  // More advanced status code features like source location and messages are
+  // disabled. All statuses are just the codes.
+  return iree_status_from_code(code);
+#else
+  // No-op for OK statuses; we won't get these from the macros but may be called
+  // with this from marshaling code.
+  if (IREE_UNLIKELY(code == IREE_STATUS_OK)) return iree_ok_status();
+
+  // Allocate storage with the appropriate alignment such that we can pack the
+  // code in the lower bits of the pointer. Since failed statuses are rare and
+  // likely have much larger costs (like string formatting) the extra bytes for
+  // alignment are worth being able to avoid pointer dereferences and other
+  // things during the normal code paths that just check codes.
+  //
+  // Note that we are using the CRT allocation function here, as we can't trust
+  // our allocator system to work when we are throwing errors (as we may be
+  // allocating this error from a failed allocation!).
+  size_t storage_alignment = (IREE_STATUS_CODE_MASK + 1);
+  size_t storage_size =
+      iree_math_align(sizeof(iree_status_storage_t), storage_alignment);
+  iree_status_storage_t* storage = (iree_status_storage_t*)iree_aligned_alloc(
+      storage_alignment, storage_size);
+  if (IREE_UNLIKELY(!storage)) return iree_status_from_code(code);
+  memset(storage, 0, sizeof(*storage));
+
+#if (IREE_STATUS_FEATURES & IREE_STATUS_FEATURE_SOURCE_LOCATION) != 0
+  storage->file = file;
+  storage->line = line;
+#endif  // has IREE_STATUS_FEATURE_SOURCE_LOCATION
+
+#if (IREE_STATUS_FEATURES & IREE_STATUS_FEATURE_ANNOTATIONS) != 0
+  // NOTE: messages are rodata strings here and not retained.
+  storage->message = message;
+#endif  // has IREE_STATUS_FEATURE_ANNOTATIONS
+
+  iree_status_attach_stack_trace(storage, /*skip_frames=*/1);
+  return (iree_status_t)((uintptr_t)storage | (code & IREE_STATUS_CODE_MASK));
+#endif  // has any IREE_STATUS_FEATURES
+}
+
+IREE_API_EXPORT IREE_MUST_USE_RESULT iree_status_t IREE_API_CALL
+iree_status_allocate_f(iree_status_code_t code, const char* file, uint32_t line,
+                       const char* format, ...) {
+  // Two copies of the arguments: one to measure the message, one to print it.
+  va_list varargs_0, varargs_1;
+  va_start(varargs_0, format);
+  va_start(varargs_1, format);
+  iree_status_t ret =
+      iree_status_allocate_vf(code, file, line, format, varargs_0, varargs_1);
+  va_end(varargs_0);
+  va_end(varargs_1);
+  return ret;
+}
+
+// Allocates a status whose message is formatted from |format| and the
+// arguments. |varargs_0| is consumed to measure the message and |varargs_1| to
+// print it; both remain owned (and are va_end'ed) by the caller.
+IREE_API_EXPORT IREE_MUST_USE_RESULT iree_status_t IREE_API_CALL
+iree_status_allocate_vf(iree_status_code_t code, const char* file,
+                        uint32_t line, const char* format, va_list varargs_0,
+                        va_list varargs_1) {
+#if (IREE_STATUS_FEATURES & IREE_STATUS_FEATURE_ANNOTATIONS) == 0
+  // Annotations disabled; ignore the format string/args.
+  return iree_status_allocate(code, file, line, iree_string_view_empty());
+#else
+  // No-op for OK statuses; we won't get these from the macros but may be called
+  // with this from marshaling code.
+  if (IREE_UNLIKELY(code == IREE_STATUS_OK)) return iree_ok_status();
+
+  // Compute the number of characters (excluding NUL) in the formatted message.
+  // vsnprintf returns a negative int on encoding errors; it must be checked
+  // before widening to size_t, as an unsigned value is never less than zero.
+  int message_length =
+      vsnprintf(/*buffer=*/NULL, /*buffer_count=*/0, format, varargs_0);
+  if (IREE_UNLIKELY(message_length < 0)) return iree_status_from_code(code);
+  size_t message_size = (size_t)message_length + 1;  // NUL byte
+
+  // Allocate storage with the additional room to store the formatted message.
+  // This avoids additional allocations for the common case of a message coming
+  // only from the original status error site.
+  size_t storage_alignment = (IREE_STATUS_CODE_MASK + 1);
+  size_t storage_size = iree_math_align(
+      sizeof(iree_status_storage_t) + message_size, storage_alignment);
+  iree_status_storage_t* storage = (iree_status_storage_t*)iree_aligned_alloc(
+      storage_alignment, storage_size);
+  if (IREE_UNLIKELY(!storage)) return iree_status_from_code(code);
+  memset(storage, 0, sizeof(*storage));
+
+#if (IREE_STATUS_FEATURES & IREE_STATUS_FEATURE_SOURCE_LOCATION) != 0
+  storage->file = file;
+  storage->line = line;
+#endif  // has IREE_STATUS_FEATURE_SOURCE_LOCATION
+
+  // vsnprintf directly into the message buffer, which immediately follows the
+  // storage struct in the same allocation.
+  storage->message.size = message_size - 1;
+  storage->message.data = (const char*)storage + sizeof(iree_status_storage_t);
+  int ret =
+      vsnprintf((char*)storage->message.data, message_size, format, varargs_1);
+  if (IREE_UNLIKELY(ret < 0)) {
+    iree_aligned_free(storage);
+    return iree_status_from_code(code);
+  }
+
+  iree_status_attach_stack_trace(storage, /*skip_frames=*/1);
+  return (iree_status_t)((uintptr_t)storage | (code & IREE_STATUS_CODE_MASK));
+#endif  // has IREE_STATUS_FEATURE_ANNOTATIONS
+}
+
+// Returns a new status with the same code, source location (if tracked), and
+// base message as |status|. The message is always copied; payloads (including
+// annotations appended after the base message) are not carried over.
+IREE_API_EXPORT IREE_MUST_USE_RESULT iree_status_t IREE_API_CALL
+iree_status_clone(iree_status_t status) {
+#if IREE_STATUS_FEATURES == 0
+  // Statuses are just codes; nothing to do.
+  return status;
+#else
+  iree_status_storage_t* storage = iree_status_storage(status);
+  if (!storage) return status;
+
+#if (IREE_STATUS_FEATURES & IREE_STATUS_FEATURE_SOURCE_LOCATION) != 0
+  const char* file = storage->file;
+  uint32_t line = storage->line;
+#else
+  const char* file = NULL;
+  uint32_t line = 0;
+#endif  // has IREE_STATUS_FEATURE_SOURCE_LOCATION
+
+#if (IREE_STATUS_FEATURES & IREE_STATUS_FEATURE_ANNOTATIONS) != 0
+  iree_string_view_t message = storage->message;
+#else
+  iree_string_view_t message = iree_string_view_empty();
+#endif  // has IREE_STATUS_FEATURE_ANNOTATIONS
+
+  // Always copy the message by performing the formatting as we don't know
+  // whether the original status has ownership or not.
+  return iree_status_allocate_f(iree_status_code(status), file, line, "%.*s",
+                                (int)message.size, message.data);
+#endif  // has no IREE_STATUS_FEATURES
+}
+
+// Releases the storage of |status| and every attached payload (each through
+// its own allocator). No-op for statuses that carry only a code.
+IREE_API_EXPORT void IREE_API_CALL iree_status_free(iree_status_t status) {
+#if IREE_STATUS_FEATURES != 0
+  iree_status_storage_t* storage = iree_status_storage(status);
+  if (!storage) return;
+  iree_status_payload_t* payload = storage->payload_head;
+  while (payload) {
+    iree_status_payload_t* next = payload->next;
+    iree_allocator_free(payload->allocator, payload);
+    payload = next;
+  }
+  iree_aligned_free(storage);
+#endif  // has any IREE_STATUS_FEATURES
+}
+
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_status_ignore(iree_status_t status) {
+  // We can set an 'ignored' flag on the status so that we can otherwise assert
+  // in iree_status_free when statuses are freed without this being called.
+  // Hoping with the C++ Status wrapper we won't hit that often so that
+  // complexity is skipped for now.
+  iree_status_free(status);
+  return iree_ok_status();
+}
+
+// Frees |status| and returns the code it carried.
+IREE_API_EXPORT iree_status_code_t IREE_API_CALL
+iree_status_consume_code(iree_status_t status) {
+  iree_status_code_t code = iree_status_code(status);
+  iree_status_free(status);
+  return code;
+}
+
+// Attaches |message| to |base_status| without copying it (the string must
+// outlive the status). The first message fills the base storage message;
+// later ones are appended as message payloads.
+IREE_API_EXPORT IREE_MUST_USE_RESULT iree_status_t IREE_API_CALL
+iree_status_annotate(iree_status_t base_status, iree_string_view_t message) {
+#if (IREE_STATUS_FEATURES & IREE_STATUS_FEATURE_ANNOTATIONS) == 0
+  // Annotations are disabled so we ignore this entirely.
+  return base_status;
+#else
+  if (iree_status_is_ok(base_status) || iree_string_view_is_empty(message)) {
+    return base_status;
+  }
+
+  // If there's no storage yet we can just reuse normal allocation. Both that
+  // and this do not copy |message|.
+  iree_status_storage_t* storage = iree_status_storage(base_status);
+  if (!storage) {
+    return iree_status_allocate(iree_status_code(base_status), NULL, 0,
+                                message);
+  } else if (iree_string_view_is_empty(storage->message)) {
+    storage->message = message;
+    return base_status;
+  }
+
+  iree_allocator_t allocator = iree_allocator_system();
+  iree_status_payload_message_t* payload = NULL;
+  iree_status_ignore(
+      iree_allocator_malloc(allocator, sizeof(*payload), (void**)&payload));
+  if (IREE_UNLIKELY(!payload)) return base_status;
+  memset(payload, 0, sizeof(*payload));
+  payload->header.type = IREE_STATUS_PAYLOAD_TYPE_MESSAGE;
+  payload->header.allocator = allocator;
+  payload->header.formatter = iree_status_payload_message_formatter;
+  payload->message = message;
+  return iree_status_append_payload(base_status, storage,
+                                    (iree_status_payload_t*)payload);
+#endif  // has IREE_STATUS_FEATURE_ANNOTATIONS
+}
+
+IREE_API_EXPORT IREE_MUST_USE_RESULT iree_status_t IREE_API_CALL
+IREE_PRINTF_ATTRIBUTE(2, 3)
+    iree_status_annotate_f(iree_status_t base_status, const char* format, ...) {
+  // Two copies of the arguments: one to measure the message, one to print it.
+  // This function owns both va_lists and ends them after the call.
+  va_list varargs_0, varargs_1;
+  va_start(varargs_0, format);
+  va_start(varargs_1, format);
+  iree_status_t ret =
+      iree_status_annotate_vf(base_status, format, varargs_0, varargs_1);
+  va_end(varargs_0);
+  va_end(varargs_1);
+  return ret;
+}
+
+// Formats a message and attaches it to |base_status| (copied into the status).
+// |varargs_0| is consumed to measure the message and |varargs_1| to print it;
+// both remain owned by the caller, which is responsible for va_end.
+IREE_API_EXPORT IREE_MUST_USE_RESULT iree_status_t IREE_API_CALL
+iree_status_annotate_vf(iree_status_t base_status, const char* format,
+                        va_list varargs_0, va_list varargs_1) {
+#if (IREE_STATUS_FEATURES & IREE_STATUS_FEATURE_ANNOTATIONS) == 0
+  return base_status;
+#else
+  if (iree_status_is_ok(base_status)) return base_status;
+
+  // If there's no storage yet we can just reuse normal allocation, which
+  // formats the message directly into the new storage.
+  iree_status_storage_t* storage = iree_status_storage(base_status);
+  if (!storage) {
+    return iree_status_allocate_vf(iree_status_code(base_status), NULL, 0,
+                                   format, varargs_0, varargs_1);
+  }
+
+  // Compute the number of characters (excluding NUL) in the formatted message.
+  // The int result must be checked before widening to size_t. The va_lists are
+  // not ended here: the caller owns them (ending them twice is undefined).
+  int message_length =
+      vsnprintf(/*buffer=*/NULL, /*buffer_count=*/0, format, varargs_0);
+  if (IREE_UNLIKELY(message_length < 0)) return base_status;
+  size_t message_size = (size_t)message_length + 1;  // NUL byte
+
+  // Allocate the payload with trailing room for the formatted message so the
+  // message lives in the same allocation as its header.
+  iree_allocator_t allocator = iree_allocator_system();
+  iree_status_payload_message_t* payload = NULL;
+  iree_status_ignore(iree_allocator_malloc(
+      allocator, sizeof(*payload) + message_size, (void**)&payload));
+  if (IREE_UNLIKELY(!payload)) return base_status;
+  memset(payload, 0, sizeof(*payload));
+  payload->header.type = IREE_STATUS_PAYLOAD_TYPE_MESSAGE;
+  payload->header.allocator = allocator;
+  payload->header.formatter = iree_status_payload_message_formatter;
+
+  // vsnprintf directly into message buffer.
+  payload->message.size = message_size - 1;
+  payload->message.data =
+      (const char*)payload + sizeof(iree_status_payload_message_t);
+  int ret = vsnprintf((char*)payload->message.data, payload->message.size + 1,
+                      format, varargs_1);
+  if (IREE_UNLIKELY(ret < 0)) {
+    // The payload came from |allocator| (not iree_aligned_alloc) so it must be
+    // released through the same allocator.
+    iree_allocator_free(allocator, payload);
+    return base_status;
+  }
+  return iree_status_append_payload(base_status, storage,
+                                    (iree_status_payload_t*)payload);
+#endif  // has IREE_STATUS_FEATURE_ANNOTATIONS
+}
+
+// Formats |status| as "[file:line: ]CODE[; message][; payload]..." into
+// |buffer|. |out_buffer_length| receives the total length required excluding
+// NUL; pass |buffer| = NULL to query that length. Output is truncated when
+// |buffer_capacity| is too small. Returns false if formatting fails.
+IREE_API_EXPORT bool IREE_API_CALL
+iree_status_format(iree_status_t status, iree_host_size_t buffer_capacity,
+                   char* buffer, iree_host_size_t* out_buffer_length) {
+  *out_buffer_length = 0;
+
+  // Grab storage which may have a message and zero or more payloads.
+  iree_status_storage_t* storage = iree_status_storage(status);
+
+  // Prefix with source location and status code string (may be 'OK').
+  iree_host_size_t buffer_length = 0;
+  iree_status_code_t status_code = iree_status_code(status);
+  int n = 0;
+#if (IREE_STATUS_FEATURES & IREE_STATUS_FEATURE_SOURCE_LOCATION) != 0
+  if (storage && storage->file) {
+    n = snprintf(buffer ? buffer + buffer_length : NULL,
+                 buffer ? buffer_capacity - buffer_length : 0, "%s:%d: %s",
+                 storage->file, storage->line,
+                 iree_status_code_string(status_code));
+  } else {
+    n = snprintf(buffer ? buffer + buffer_length : NULL,
+                 buffer ? buffer_capacity - buffer_length : 0, "%s",
+                 iree_status_code_string(status_code));
+  }
+#else
+  n = snprintf(buffer ? buffer + buffer_length : NULL,
+               buffer ? buffer_capacity - buffer_length : 0, "%s",
+               iree_status_code_string(status_code));
+#endif  // has IREE_STATUS_FEATURE_SOURCE_LOCATION
+  if (IREE_UNLIKELY(n < 0)) {
+    return false;
+  } else if (buffer && n >= buffer_capacity - buffer_length) {
+    // Truncated: stop writing but keep accumulating the required length.
+    buffer = NULL;
+  }
+  buffer_length += n;
+
+#if (IREE_STATUS_FEATURES & IREE_STATUS_FEATURE_ANNOTATIONS) != 0
+  // Append base storage message.
+  if (storage && !iree_string_view_is_empty(storage->message)) {
+    n = snprintf(buffer ? buffer + buffer_length : NULL,
+                 buffer ? buffer_capacity - buffer_length : 0, "; %.*s",
+                 (int)storage->message.size, storage->message.data);
+    if (IREE_UNLIKELY(n < 0)) {
+      return false;
+    } else if (buffer && n >= buffer_capacity - buffer_length) {
+      buffer = NULL;
+    }
+    buffer_length += n;
+  }
+#endif  // has IREE_STATUS_FEATURE_ANNOTATIONS
+
+#if IREE_STATUS_FEATURES != 0
+  // Append each payload separated by "; ".
+  iree_status_payload_t* payload = storage ? storage->payload_head : NULL;
+  while (payload != NULL) {
+    // Skip payloads that have no textual representation.
+    if (!payload->formatter) {
+      payload = payload->next;
+      continue;
+    }
+
+    // Append the "; " separator to join with the message and other payloads.
+    if (buffer) {
+      if (2 >= buffer_capacity - buffer_length) {
+        buffer = NULL;
+      } else {
+        buffer[buffer_length] = ';';
+        buffer[buffer_length + 1] = ' ';
+        buffer[buffer_length + 2] = '\0';
+      }
+    }
+    buffer_length += 2;  // '; '
+
+    // Append payload via custom formatter callback.
+    iree_host_size_t payload_buffer_length = 0;
+    payload->formatter(payload, buffer ? buffer_capacity - buffer_length : 0,
+                       buffer ? buffer + buffer_length : NULL,
+                       &payload_buffer_length);
+    if (buffer && payload_buffer_length >= buffer_capacity - buffer_length) {
+      buffer = NULL;
+    }
+    buffer_length += payload_buffer_length;
+
+    payload = payload->next;
+  }
+#endif  // has IREE_STATUS_FEATURES
+
+  *out_buffer_length = buffer_length;
+  return true;
+}
+
+// Formats |status| into a newly malloc'ed NUL-terminated string returned in
+// |out_buffer| (release with free()). Returns false on formatting or
+// allocation failure, in which case |out_buffer| is not written.
+IREE_API_EXPORT bool IREE_API_CALL
+iree_status_to_string(iree_status_t status, char** out_buffer,
+                      iree_host_size_t* out_buffer_length) {
+  *out_buffer_length = 0;
+  iree_host_size_t buffer_length = 0;
+  if (IREE_UNLIKELY(!iree_status_format(status, /*buffer_capacity=*/0,
+                                        /*buffer=*/NULL, &buffer_length))) {
+    return false;
+  }
+  // Buffer capacity needs to be +1 to account for the terminating null of
+  // snprintf.
+  buffer_length++;
+  char* buffer = (char*)malloc(buffer_length);
+  if (IREE_UNLIKELY(!buffer)) return false;
+  bool ret =
+      iree_status_format(status, buffer_length, buffer, out_buffer_length);
+  if (ret) {
+    *out_buffer = buffer;
+    return true;
+  } else {
+    free(buffer);
+    return false;
+  }
+}
diff --git a/iree/base/string_view.c b/iree/base/string_view.c new file mode 100644 index 0000000..07a27ca --- /dev/null +++ b/iree/base/string_view.c
@@ -0,0 +1,240 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <ctype.h>
+#include <limits.h>
+#include <stdlib.h>
+#include <string.h>
+
+#include "iree/base/api.h"
+
+static inline size_t iree_min_host_size(size_t a, size_t b) {
+  return a < b ? a : b;
+}
+
+IREE_API_EXPORT bool IREE_API_CALL
+iree_string_view_equal(iree_string_view_t lhs, iree_string_view_t rhs) {
+  if (lhs.size != rhs.size) return false;
+  for (iree_host_size_t i = 0; i < lhs.size; ++i) {
+    if (lhs.data[i] != rhs.data[i]) return false;
+  }
+  return true;
+}
+
+// Lexicographically compares |lhs| and |rhs| byte-wise (<0, 0, or >0).
+// memcmp (not strncmp) is used so embedded NUL bytes are compared instead of
+// ending the comparison early; a shorter view that is a prefix sorts first.
+IREE_API_EXPORT int IREE_API_CALL
+iree_string_view_compare(iree_string_view_t lhs, iree_string_view_t rhs) {
+  iree_host_size_t min_size = iree_min_host_size(lhs.size, rhs.size);
+  // Skip memcmp for empty views, whose data pointers may be NULL.
+  int cmp = min_size > 0 ? memcmp(lhs.data, rhs.data, min_size) : 0;
+  if (cmp != 0) {
+    return cmp;
+  } else if (lhs.size == rhs.size) {
+    return 0;
+  }
+  return lhs.size < rhs.size ? -1 : 1;
+}
+
+// Returns true if |value| begins with |prefix|. Views with NULL data never
+// match. memcmp is used so embedded NUL bytes are compared like any other.
+IREE_API_EXPORT bool IREE_API_CALL iree_string_view_starts_with(
+    iree_string_view_t value, iree_string_view_t prefix) {
+  if (!value.data || !prefix.data || prefix.size > value.size) {
+    return false;
+  }
+  return memcmp(value.data, prefix.data, prefix.size) == 0;
+}
+
+IREE_API_EXPORT iree_host_size_t IREE_API_CALL iree_string_view_find_char(
+    iree_string_view_t value, char c, iree_host_size_t pos) {
+  if (iree_string_view_is_empty(value) || pos >= value.size) {
+    return IREE_STRING_VIEW_NPOS;
+  }
+  const char* result =
+      (const char*)(memchr(value.data + pos, c, value.size - pos));
+  return result != NULL ? result - value.data : IREE_STRING_VIEW_NPOS;
+}
+
+IREE_API_EXPORT iree_host_size_t IREE_API_CALL iree_string_view_find_first_of(
+    iree_string_view_t value, iree_string_view_t s, iree_host_size_t pos) {
+  if (iree_string_view_is_empty(value) || iree_string_view_is_empty(s)) {
+    return IREE_STRING_VIEW_NPOS;
+  }
+  if (s.size == 1) {
+    // Avoid the cost of the lookup table for a single-character search.
+    return iree_string_view_find_char(value, s.data[0], pos);
+  }
+  bool lookup_table[UCHAR_MAX + 1] = {0};
+  for (iree_host_size_t i = 0; i < s.size; ++i) {
+    lookup_table[(uint8_t)s.data[i]] = true;
+  }
+  for (iree_host_size_t i = pos; i < value.size; ++i) {
+    if (lookup_table[(uint8_t)value.data[i]]) {
+      return i;
+    }
+  }
+  return IREE_STRING_VIEW_NPOS;
+}
+
+// Returns the index of the last character of |value| that appears in |s|,
+// examining only positions strictly before |pos| (clamped to |value|.size).
+// NOTE(review): std::string::find_last_of also examines |pos| itself; confirm
+// the exclusive bound is what callers expect.
+IREE_API_EXPORT iree_host_size_t IREE_API_CALL iree_string_view_find_last_of(
+    iree_string_view_t value, iree_string_view_t s, iree_host_size_t pos) {
+  if (iree_string_view_is_empty(value) || iree_string_view_is_empty(s)) {
+    return IREE_STRING_VIEW_NPOS;
+  }
+  bool lookup_table[UCHAR_MAX + 1] = {0};
+  for (iree_host_size_t i = 0; i < s.size; ++i) {
+    lookup_table[(uint8_t)s.data[i]] = true;
+  }
+  pos = iree_min(pos, value.size);
+  iree_host_size_t i = pos;
+  while (i != 0) {
+    --i;
+    if (lookup_table[(uint8_t)value.data[i]]) {
+      return i;
+    }
+  }
+  return IREE_STRING_VIEW_NPOS;
+}
+
+IREE_API_EXPORT iree_string_view_t IREE_API_CALL
+iree_string_view_remove_prefix(iree_string_view_t value, iree_host_size_t n) {
+  if (n >= value.size) {
+    return iree_string_view_empty();
+  }
+  return iree_make_string_view(value.data + n, value.size - n);
+}
+
+// Returns |value| with leading and trailing isspace() characters removed.
+IREE_API_EXPORT iree_string_view_t IREE_API_CALL
+iree_string_view_trim(iree_string_view_t value) {
+  if (iree_string_view_is_empty(value)) return value;
+  iree_host_size_t start = 0;
+  iree_host_size_t end = value.size;  // exclusive
+  // isspace() is only defined for values representable as unsigned char (or
+  // EOF); plain char may be signed, so cast to avoid UB on bytes >= 0x80.
+  while (start < end && isspace((unsigned char)value.data[start])) ++start;
+  while (end > start && isspace((unsigned char)value.data[end - 1])) --end;
+  return iree_make_string_view(value.data + start, end - start);
+}
+
+IREE_API_EXPORT iree_string_view_t IREE_API_CALL iree_string_view_substr(
+    iree_string_view_t value, iree_host_size_t pos, iree_host_size_t n) {
+  pos = iree_min_host_size(pos, value.size);
+  n = iree_min_host_size(n, value.size - pos);
+  return iree_make_string_view(value.data + pos, n);
+}
+
+// Splits |value| at the first |split_char|: |out_lhs| receives the text before
+// it and |out_rhs| the text after it, and the split offset is returned. If the
+// character is not found -1 is returned and |out_lhs| receives all of |value|;
+// an empty |value| returns -1 with both outputs empty. Either output pointer
+// may be NULL.
+IREE_API_EXPORT intptr_t IREE_API_CALL iree_string_view_split(
+    iree_string_view_t value, char split_char, iree_string_view_t* out_lhs,
+    iree_string_view_t* out_rhs) {
+  // Both outputs are optional (as the guarded writes below assume), so only
+  // initialize the ones that were provided.
+  if (out_lhs) *out_lhs = iree_string_view_empty();
+  if (out_rhs) *out_rhs = iree_string_view_empty();
+  if (!value.data || !value.size) {
+    return -1;
+  }
+  const void* first_ptr = memchr(value.data, split_char, value.size);
+  if (!first_ptr) {
+    if (out_lhs) *out_lhs = value;
+    return -1;
+  }
+  intptr_t offset = (intptr_t)((const char*)(first_ptr)-value.data);
+  if (out_lhs) {
+    out_lhs->data = value.data;
+    out_lhs->size = offset;
+  }
+  if (out_rhs) {
+    out_rhs->data = value.data + offset + 1;
+    out_rhs->size = value.size - offset - 1;
+  }
+  return offset;
+}
+
+// Recursive glob matcher: '?' matches exactly one character, '*' matches any
+// (possibly empty) run of characters, and everything else matches literally.
+static bool iree_string_view_match_pattern_impl(iree_string_view_t value,
+                                                iree_string_view_t pattern) {
+  iree_host_size_t next_char_index = iree_string_view_find_first_of(
+      pattern, iree_make_cstring_view("*?"), /*pos=*/0);
+  if (next_char_index == IREE_STRING_VIEW_NPOS) {
+    // No wildcards remain: the rest must match exactly.
+    return iree_string_view_equal(value, pattern);
+  } else if (next_char_index > 0) {
+    // Consume the literal prefix that precedes the first wildcard.
+    iree_string_view_t value_prefix =
+        iree_string_view_substr(value, 0, next_char_index);
+    iree_string_view_t pattern_prefix =
+        iree_string_view_substr(pattern, 0, next_char_index);
+    if (!iree_string_view_equal(value_prefix, pattern_prefix)) {
+      return false;
+    }
+    value =
+        iree_string_view_substr(value, next_char_index, IREE_STRING_VIEW_NPOS);
+    pattern = iree_string_view_substr(pattern, next_char_index,
+                                      IREE_STRING_VIEW_NPOS);
+  }
+
+  // |pattern| now begins with a wildcard (so it is non-empty). The wildcard is
+  // never compared literally against |value|: a literal '*' in |value| must not
+  // be consumed as if it matched the pattern's '*'.
+  char pattern_char = pattern.data[0];
+  iree_string_view_t pattern_rest =
+      iree_string_view_substr(pattern, 1, IREE_STRING_VIEW_NPOS);
+  if (pattern_char == '?') {
+    // '?' must consume exactly one character; it cannot match an empty value.
+    if (iree_string_view_is_empty(value)) return false;
+    return iree_string_view_match_pattern_impl(
+        iree_string_view_substr(value, 1, IREE_STRING_VIEW_NPOS),
+        pattern_rest);
+  }
+
+  // '*': a trailing star matches anything; otherwise try every suffix of
+  // |value| (including the empty one) against the rest of the pattern.
+  if (iree_string_view_is_empty(pattern_rest)) return true;
+  for (iree_host_size_t i = 0; i <= value.size; ++i) {
+    if (iree_string_view_match_pattern_impl(
+            iree_string_view_substr(value, i, IREE_STRING_VIEW_NPOS),
+            pattern_rest)) {
+      return true;
+    }
+  }
+  return false;
+}
+
+IREE_API_EXPORT bool IREE_API_CALL iree_string_view_match_pattern(
+    iree_string_view_t value, iree_string_view_t pattern) {
+  return iree_string_view_match_pattern_impl(value, pattern);
+}
+
+IREE_API_EXPORT iree_host_size_t IREE_API_CALL
+iree_string_view_append_to_buffer(iree_string_view_t source_value,
+                                  iree_string_view_t* target_value,
+                                  char* buffer) {
+  memcpy(buffer, source_value.data, source_value.size);
+  target_value->data = buffer;
+  target_value->size = source_value.size;
+  return source_value.size;
+}
diff --git a/iree/base/string_view_test.cc b/iree/base/string_view_test.cc new file mode 100644 index 0000000..91fa6b8 --- /dev/null +++ b/iree/base/string_view_test.cc
@@ -0,0 +1,58 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <string>
+
+#include "iree/base/api.h"
+#include "iree/testing/gtest.h"
+
+namespace {
+
+// Copies the bytes referenced by |value| into an owning std::string.
+std::string ToString(iree_string_view_t value) {
+  return std::string(value.data, value.size);
+}
+
+// Trims a NUL-terminated |input| with iree_string_view_trim and returns the
+// result as a std::string so it can be compared with ASSERT_EQ.
+std::string TrimToString(const char* input) {
+  iree_string_view_t trimmed =
+      iree_string_view_trim(iree_make_cstring_view(input));
+  return ToString(trimmed);
+}
+
+TEST(StringViewTest, Trim) {
+  // Each case pairs an input with its expected trimmed output.
+  struct TrimCase {
+    const char* input;
+    const char* expected;
+  };
+  const TrimCase kCases[] = {
+      {"", ""},
+      {"a", "a"},
+      {" a", "a"},
+      {"a ", "a"},
+      {"a b", "a b"},
+      {" a b ", "a b"},
+      {"\t\t\na b\n \t ", "a b"},
+      {"\n", ""},
+      {"\r\n", ""},
+  };
+  for (const TrimCase& test_case : kCases) {
+    ASSERT_EQ(TrimToString(test_case.input), test_case.expected)
+        << "input: \"" << test_case.input << "\"";
+  }
+}
+
+}  // namespace
diff --git a/iree/base/time.c b/iree/base/time.c new file mode 100644 index 0000000..28e11fe --- /dev/null +++ b/iree/base/time.c
@@ -0,0 +1,80 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <limits.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <time.h>
+
+#include "iree/base/api.h"
+#include "iree/base/target_platform.h"
+
+// Returns the current wall-clock time in nanoseconds since the Unix epoch.
+IREE_API_EXPORT iree_time_t iree_time_now(void) {
+#if defined(IREE_PLATFORM_WINDOWS)
+  // GetSystemTimePreciseAsFileTime requires Windows 8, add a fallback
+  // (such as using std::chrono) if older support is needed.
+  FILETIME system_time;
+  GetSystemTimePreciseAsFileTime(&system_time);
+
+  // FILETIME counts 100ns ticks since 1601-01-01; rebase to the Unix epoch and
+  // scale to nanoseconds so both platform branches return the same units.
+  const int64_t kUnixEpochStartTicks = 116444736000000000i64;
+  const int64_t kFtToNanoSec = 100;
+  LARGE_INTEGER li;
+  li.LowPart = system_time.dwLowDateTime;
+  li.HighPart = system_time.dwHighDateTime;
+  li.QuadPart -= kUnixEpochStartTicks;
+  li.QuadPart *= kFtToNanoSec;
+  return li.QuadPart;
+#elif defined(IREE_PLATFORM_ANDROID) || defined(IREE_PLATFORM_APPLE) || \
+    defined(IREE_PLATFORM_LINUX)
+  struct timespec clock_time;
+  clock_gettime(CLOCK_REALTIME, &clock_time);
+  // tv_nsec is only the sub-second fraction; combine it with tv_sec.
+  return (iree_time_t)clock_time.tv_sec * 1000000000ll +
+         (iree_time_t)clock_time.tv_nsec;
+#else
+#error "IREE system clock needs to be set up for your platform"
+#endif  // IREE_PLATFORM_*
+}
+
+// Converts a relative timeout into an absolute deadline: a zero timeout maps
+// to IREE_TIME_INFINITE_PAST (a poll), an infinite timeout maps to
+// IREE_TIME_INFINITE_FUTURE, and anything else is offset from now.
+IREE_API_EXPORT iree_time_t
+iree_relative_timeout_to_deadline_ns(iree_duration_t timeout_ns) {
+  if (timeout_ns == IREE_DURATION_ZERO) {
+    return IREE_TIME_INFINITE_PAST;
+  } else if (timeout_ns == IREE_DURATION_INFINITE) {
+    return IREE_TIME_INFINITE_FUTURE;
+  }
+  return iree_time_now() + timeout_ns;
+}
+
+// Converts an absolute deadline back into a relative timeout in nanoseconds,
+// mapping the infinite sentinels to their duration equivalents.
+IREE_API_EXPORT iree_duration_t
+iree_absolute_deadline_to_timeout_ns(iree_time_t deadline_ns) {
+  if (deadline_ns == IREE_TIME_INFINITE_PAST) {
+    return IREE_DURATION_ZERO;
+  } else if (deadline_ns == IREE_TIME_INFINITE_FUTURE) {
+    return IREE_DURATION_INFINITE;
+  } else {
+    // A deadline that has already passed becomes a zero timeout (a poll);
+    // otherwise return the time remaining until the deadline.
+    iree_time_t now_ns = iree_time_now();
+    return deadline_ns < now_ns ? IREE_DURATION_ZERO : deadline_ns - now_ns;
+  }
+}
diff --git a/iree/base/tracing.h b/iree/base/tracing.h index f73ac4c..0293313 100644 --- a/iree/base/tracing.h +++ b/iree/base/tracing.h
@@ -410,6 +410,8 @@ #define IREE_TRACE_ZONE_SET_COLOR(zone_id, color_xrgb) #define IREE_TRACE_ZONE_APPEND_VALUE(zone_id, value) #define IREE_TRACE_ZONE_APPEND_TEXT(zone_id, ...) +#define IREE_TRACE_ZONE_APPEND_TEXT_CSTRING(zone_id, value) +#define IREE_TRACE_ZONE_APPEND_TEXT_STRING_VIEW(zone_id, value, value_length) #define IREE_TRACE_ZONE_END(zone_id) #define IREE_RETURN_AND_END_ZONE_IF_ERROR(zone_id, ...) \ IREE_RETURN_IF_ERROR(__VA_ARGS__)
diff --git a/iree/compiler/Conversion/LinalgToLLVM/LinalgTileAndVectorizePass.cpp b/iree/compiler/Conversion/LinalgToLLVM/LinalgTileAndVectorizePass.cpp index 25d8f4f..7a2e336 100644 --- a/iree/compiler/Conversion/LinalgToLLVM/LinalgTileAndVectorizePass.cpp +++ b/iree/compiler/Conversion/LinalgToLLVM/LinalgTileAndVectorizePass.cpp
@@ -257,6 +257,8 @@ .insert<ContractionOpToOuterProductOpLowering, ContractionOpToMatmulOpLowering, ContractionOpLowering>( vectorTransformsOptions, context); + vector::populateVectorTransferLoweringPatterns( + vectorContractLoweringPatterns); if (failed(applyPatternsAndFoldGreedily( funcOp, std::move(vectorContractLoweringPatterns)))) { return signalPassFailure();
diff --git a/iree/compiler/Conversion/LinalgToNVVM/VectorizationPass.cpp b/iree/compiler/Conversion/LinalgToNVVM/VectorizationPass.cpp index 69971dd..7030e23 100644 --- a/iree/compiler/Conversion/LinalgToNVVM/VectorizationPass.cpp +++ b/iree/compiler/Conversion/LinalgToNVVM/VectorizationPass.cpp
@@ -98,6 +98,18 @@ }); { + // Lower transfer op to canonical form. + OwningRewritePatternList lowerTransferOpPatterns(funcOp.getContext()); + vector::populateVectorToVectorCanonicalizationPatterns( + lowerTransferOpPatterns); + vector::populateVectorToVectorTransformationPatterns( + lowerTransferOpPatterns); + vector::populateVectorTransferLoweringPatterns(lowerTransferOpPatterns); + (void)applyPatternsAndFoldGreedily(funcOp, + std::move(lowerTransferOpPatterns)); + } + + { // Step 2. Unroll the vetors to native size and canonicalize. OwningRewritePatternList vectorUnrollPatterns(context); populateVectorUnrollPatterns(vectorUnrollPatterns);
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/TileAndVectorizeInOneWorkgroupPass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/TileAndVectorizeInOneWorkgroupPass.cpp index 8fc4463..3be9b52 100644 --- a/iree/compiler/Conversion/LinalgToSPIRV/TileAndVectorizeInOneWorkgroupPass.cpp +++ b/iree/compiler/Conversion/LinalgToSPIRV/TileAndVectorizeInOneWorkgroupPass.cpp
@@ -323,29 +323,102 @@ vector::UnrollVectorOptions().setNativeShapeFn(getNativeVectorSize)); } +namespace { + +/// Workaround for SPIR-V backend limitations. The SPIR-V vectorization pass +/// relies on unrolling to reduce instructions to a vector size we can convert +/// to SPIR-V. When vectorization creates transposes, those block unrolling and +/// result in large vectors we currently cannot lower. For now we always merge +/// lhs/rhs transposes into the contract op so that it can be unrolled. +// TODO(thomasraoux): Make transpose work with the current unrolling mechanism +// or replace unrolling. +class CombineContractTranspose final + : public OpRewritePattern<vector::ContractionOp> { + public: + using OpRewritePattern<vector::ContractionOp>::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::ContractionOp op, + PatternRewriter &rewriter) const override { + // Fold vector.transpose producers of lhs/rhs into the indexing maps. The + // acc map also defines the result layout, so a transposed acc is kept.
+ MLIRContext *ctx = op.getContext(); + Location loc = op.getLoc(); + bool foundTranspose = false; + std::array<Value, 3> sources = {op.lhs(), op.rhs(), op.acc()}; + SmallVector<AffineMap> newMaps; + SmallVector<Value> newSources; + for (auto source : llvm::enumerate(sources)) { + auto map = + op.indexing_maps()[source.index()].cast<AffineMapAttr>().getValue(); + auto transposeOp = source.value().getDefiningOp<vector::TransposeOp>(); + if (!transposeOp || source.index() == 2) { + newSources.push_back(source.value()); + newMaps.push_back(map); + continue; + } + SmallVector<int64_t, 3> perm; + transposeOp.getTransp(perm); + SmallVector<AffineExpr> exprs(perm.size()); + for (auto remap : llvm::enumerate(perm)) { + exprs[remap.value()] = map.getResult(remap.index()); + } + newMaps.push_back( + AffineMap::get(map.getNumDims(), map.getNumSymbols(), exprs, ctx)); + newSources.push_back(transposeOp.vector()); + foundTranspose = true; + } + if (!foundTranspose) return failure(); + + Value res = rewriter.create<vector::ContractionOp>( + loc, newSources[0], newSources[1], newSources[2], + rewriter.getAffineMapArrayAttr(newMaps), op.iterator_types()); + rewriter.replaceOp(op, res); + return success(); + } +}; + +} // namespace + //====---------------------------------------------------------------------===// // Vector patterns //====---------------------------------------------------------------------===// static void applyVectorTransformation(FuncOp funcOp) { { - OwningRewritePatternList vectorUnrollPatterns(funcOp.getContext()); - populateVectorUnrollPatterns(funcOp.getContext(), vectorUnrollPatterns); - (void)applyPatternsAndFoldGreedily(funcOp, std::move(vectorUnrollPatterns)); + { + OwningRewritePatternList lowerTransferOpPatterns(funcOp.getContext()); + vector::populateVectorToVectorCanonicalizationPatterns( + lowerTransferOpPatterns); + vector::populateVectorToVectorTransformationPatterns( + lowerTransferOpPatterns); + vector::populateVectorTransferLoweringPatterns(lowerTransferOpPatterns); +
lowerTransferOpPatterns.add<CombineContractTranspose>( + funcOp.getContext()); + (void)applyPatternsAndFoldGreedily(funcOp, + std::move(lowerTransferOpPatterns)); + } + { + OwningRewritePatternList vectorUnrollPatterns(funcOp.getContext()); + populateVectorUnrollPatterns(funcOp.getContext(), vectorUnrollPatterns); + (void)applyPatternsAndFoldGreedily(funcOp, + std::move(vectorUnrollPatterns)); + } + { + OwningRewritePatternList canonicalizationPatterns1(funcOp.getContext()); - OwningRewritePatternList canonicalizationPatterns1(funcOp.getContext()); - vector::populateVectorToVectorCanonicalizationPatterns( - canonicalizationPatterns1); - vector::populateVectorToVectorTransformationPatterns( - canonicalizationPatterns1); - vector::populateSplitVectorTransferPatterns(canonicalizationPatterns1); - (void)applyPatternsAndFoldGreedily(funcOp, - std::move(canonicalizationPatterns1)); + vector::populateVectorToVectorTransformationPatterns( + canonicalizationPatterns1); + vector::populateVectorToVectorCanonicalizationPatterns( + canonicalizationPatterns1); + vector::populateSplitVectorTransferPatterns(canonicalizationPatterns1); + (void)applyPatternsAndFoldGreedily(funcOp, + std::move(canonicalizationPatterns1)); - OwningRewritePatternList canonicalizationPatterns2(funcOp.getContext()); - vector::populateVectorSlicesLoweringPatterns(canonicalizationPatterns2); - (void)applyPatternsAndFoldGreedily(funcOp, - std::move(canonicalizationPatterns2)); + OwningRewritePatternList canonicalizationPatterns2(funcOp.getContext()); + vector::populateVectorSlicesLoweringPatterns(canonicalizationPatterns2); + (void)applyPatternsAndFoldGreedily(funcOp, + std::move(canonicalizationPatterns2)); + } LLVM_DEBUG({ llvm::dbgs() << "--- After Vector Unroll ---\n"; funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
diff --git a/iree/compiler/Dialect/VM/IR/VMOps.td b/iree/compiler/Dialect/VM/IR/VMOps.td index 408f4e6..f95fe08 100644 --- a/iree/compiler/Dialect/VM/IR/VMOps.td +++ b/iree/compiler/Dialect/VM/IR/VMOps.td
@@ -855,6 +855,7 @@ def VM_ListAllocOp : VM_PureOp<"list.alloc", [ DeclareOpInterfaceMethods<VM_SerializableOpInterface>, + MemoryEffects<[MemAlloc]>, ]> { let summary = [{allocates a new empty list}]; let description = [{ @@ -905,6 +906,7 @@ def VM_ListSizeOp : VM_Op<"list.size", [ DeclareOpInterfaceMethods<VM_SerializableOpInterface>, + MemoryEffects<[MemRead]>, ]> { let summary = [{the size of the list in elements}]; let description = [{ @@ -930,6 +932,7 @@ def VM_ListResizeOp : VM_Op<"list.resize", [ DeclareOpInterfaceMethods<VM_SerializableOpInterface>, + MemoryEffects<[MemWrite]>, ]> { let summary = [{resizes the list to a new count in elements}]; let description = [{ @@ -956,6 +959,7 @@ list<OpTrait> traits = []> : VM_PureOp<mnemonic, !listconcat(traits, [ DeclareOpInterfaceMethods<VM_SerializableOpInterface>, + MemoryEffects<[MemRead]>, ])> { let summary = [{primitive type element accessor}]; let description = [{ @@ -984,6 +988,7 @@ list<OpTrait> traits = []> : VM_Op<mnemonic, !listconcat(traits, [ DeclareOpInterfaceMethods<VM_SerializableOpInterface>, + MemoryEffects<[MemWrite]>, ])> { let summary = [{primitive type element mutator}]; let description = [{ @@ -1021,6 +1026,7 @@ def VM_ListGetRefOp : VM_PureOp<"list.get.ref", [ DeclareOpInterfaceMethods<VM_SerializableOpInterface>, + MemoryEffects<[MemRead]>, ]> { let summary = [{ref type element accessor}]; let description = [{ @@ -1052,6 +1058,7 @@ def VM_ListSetRefOp : VM_Op<"list.set.ref", [ DeclareOpInterfaceMethods<VM_SerializableOpInterface>, + MemoryEffects<[MemWrite]>, ]> { let summary = [{ref type element mutator}]; let description = [{
diff --git a/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeEncoder.cpp b/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeEncoder.cpp index 1da0509..349c56c 100644 --- a/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeEncoder.cpp +++ b/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeEncoder.cpp
@@ -199,8 +199,7 @@ } LogicalResult encodeResult(Value value) override { - uint16_t reg = - registerAllocation_->mapUseToRegister(value, currentOp_, 0).encode(); + uint16_t reg = registerAllocation_->mapToRegister(value).encode(); return writeUint16(reg); }
diff --git a/iree/hal/dylib/registration/BUILD b/iree/hal/dylib/registration/BUILD index 565906a..d5021a6 100644 --- a/iree/hal/dylib/registration/BUILD +++ b/iree/hal/dylib/registration/BUILD
@@ -31,17 +31,18 @@ cc_library( name = "registration", - srcs = ["driver_module.cc"], + srcs = ["driver_module.c"], hdrs = ["driver_module.h"], defines = [ "IREE_HAL_HAVE_DYLIB_DRIVER_MODULE=1", ], deps = [ + "//iree/base/internal:flags", "//iree/hal:api", "//iree/hal/local:task_driver", "//iree/hal/local/loaders:embedded_library_loader", "//iree/hal/local/loaders:legacy_library_loader", - "@com_google_absl//absl/flags:flag", + "//iree/task:api", ], )
diff --git a/iree/hal/dylib/registration/CMakeLists.txt b/iree/hal/dylib/registration/CMakeLists.txt index 2b7953c..702140e 100644 --- a/iree/hal/dylib/registration/CMakeLists.txt +++ b/iree/hal/dylib/registration/CMakeLists.txt
@@ -20,13 +20,14 @@ HDRS "driver_module.h" SRCS - "driver_module.cc" + "driver_module.c" DEPS - absl::flags + iree::base::internal::flags iree::hal::api iree::hal::local::loaders::embedded_library_loader iree::hal::local::loaders::legacy_library_loader iree::hal::local::task_driver + iree::task::api DEFINES "IREE_HAL_HAVE_DYLIB_DRIVER_MODULE=1" PUBLIC
diff --git a/iree/hal/dylib/registration/driver_module.cc b/iree/hal/dylib/registration/driver_module.c similarity index 74% rename from iree/hal/dylib/registration/driver_module.cc rename to iree/hal/dylib/registration/driver_module.c index f529c71..442870f 100644 --- a/iree/hal/dylib/registration/driver_module.cc +++ b/iree/hal/dylib/registration/driver_module.c
@@ -16,10 +16,11 @@ #include <inttypes.h> -#include "absl/flags/flag.h" +#include "iree/base/internal/flags.h" #include "iree/hal/local/loaders/embedded_library_loader.h" #include "iree/hal/local/loaders/legacy_library_loader.h" #include "iree/hal/local/task_driver.h" +#include "iree/task/api.h" // TODO(#4298): remove this driver registration and wrapper. // By having a single iree/hal/local/registration that then has the loaders @@ -28,11 +29,6 @@ // using an existing executor so that we can entirely externalize the task // system configuration from the HAL. -ABSL_FLAG(int, dylib_worker_count, 0, - "Specified number of workers to use or 0 for automatic."); -ABSL_FLAG(int, dylib_max_worker_count, 16, - "Maximum number of task system workers to use."); - #define IREE_HAL_DYLIB_DRIVER_ID 0x58444C4Cu // XDLL static iree_status_t iree_hal_dylib_driver_factory_enumerate( @@ -40,10 +36,10 @@ iree_host_size_t* out_driver_info_count) { static const iree_hal_driver_info_t driver_infos[1] = { { - /*.driver_id=*/IREE_HAL_DYLIB_DRIVER_ID, - /*.driver_name=*/iree_make_cstring_view("dylib"), - /*.full_name=*/ - iree_make_cstring_view("AOT compiled dynamic libraries"), + .driver_id = IREE_HAL_DYLIB_DRIVER_ID, + .driver_name = iree_string_view_literal("dylib"), + .full_name = + iree_string_view_literal("AOT compiled dynamic libraries"), }, }; *out_driver_info_count = IREE_ARRAYSIZE(driver_infos); @@ -64,17 +60,6 @@ iree_hal_task_device_params_t default_params; iree_hal_task_device_params_initialize(&default_params); - iree_task_topology_t topology; - iree_task_topology_initialize(&topology); - if (absl::GetFlag(FLAGS_dylib_worker_count) > 0) { - iree_task_topology_initialize_from_group_count( - absl::GetFlag(FLAGS_dylib_worker_count), &topology); - } else { - iree_task_topology_initialize_from_unique_l2_cache_groups( - /*max_group_count=*/absl::GetFlag(FLAGS_dylib_max_worker_count), - &topology); - } - iree_status_t status = iree_ok_status(); iree_hal_executable_loader_t*
loaders[2] = {NULL, NULL}; @@ -90,8 +75,7 @@ iree_task_executor_t* executor = NULL; if (iree_status_is_ok(status)) { - status = iree_task_executor_create(IREE_TASK_SCHEDULING_MODE_RESERVED, - &topology, allocator, &executor); + status = iree_task_executor_create_from_flags(allocator, &executor); } if (iree_status_is_ok(status)) { @@ -101,7 +85,6 @@ } iree_task_executor_release(executor); - iree_task_topology_deinitialize(&topology); for (iree_host_size_t i = 0; i < loader_count; ++i) { iree_hal_executable_loader_release(loaders[i]); } @@ -111,9 +94,9 @@ IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_dylib_driver_module_register(iree_hal_driver_registry_t* registry) { static const iree_hal_driver_factory_t factory = { - /*self=*/NULL, - iree_hal_dylib_driver_factory_enumerate, - iree_hal_dylib_driver_factory_try_create, + .self = NULL, + .enumerate = iree_hal_dylib_driver_factory_enumerate, + .try_create = iree_hal_dylib_driver_factory_try_create, }; return iree_hal_driver_registry_register_factory(registry, &factory); }
diff --git a/iree/hal/executable_cache.h b/iree/hal/executable_cache.h index 5b71a7c..cd5d854 100644 --- a/iree/hal/executable_cache.h +++ b/iree/hal/executable_cache.h
@@ -71,6 +71,11 @@ // Device must support the IREE_HAL_DEVICE_FEATURE_SUPPORTS_PROFILING feature // and executables must support the ExecutableFeature::kProfiling feature. IREE_HAL_EXECUTABLE_CACHING_MODE_ENABLE_PROFILING = 1u << 5, + // Disables verification of executable layouts and modes. + // This is useful when debugging with partial information but should never + // be enabled for real usage as the verification is the best way to catch + // API misuse. + IREE_HAL_EXECUTABLE_CACHING_MODE_DISABLE_VERIFICATION = 1u << 6, }; typedef uint32_t iree_hal_executable_caching_mode_t;
diff --git a/iree/hal/local/BUILD b/iree/hal/local/BUILD index f6eb991..f8397aa 100644 --- a/iree/hal/local/BUILD +++ b/iree/hal/local/BUILD
@@ -61,6 +61,20 @@ ) cc_test( + name = "executable_library_benchmark", + srcs = ["executable_library_benchmark.c"], + deps = [ + ":executable_library", + ":local", + "//iree/base:api", + "//iree/base:tracing", + "//iree/base/internal:flags", + "//iree/hal/local/loaders:embedded_library_loader", + "//iree/testing:benchmark", + ], +) + +cc_test( name = "executable_library_test", srcs = [ "executable_library_demo.c",
diff --git a/iree/hal/local/CMakeLists.txt b/iree/hal/local/CMakeLists.txt index c25cdab..960636e 100644 --- a/iree/hal/local/CMakeLists.txt +++ b/iree/hal/local/CMakeLists.txt
@@ -52,6 +52,21 @@ iree_cc_test( NAME + executable_library_benchmark + SRCS + "executable_library_benchmark.c" + DEPS + ::executable_library + ::local + iree::base::api + iree::base::internal::flags + iree::base::tracing + iree::hal::local::loaders::embedded_library_loader + iree::testing::benchmark +) + +iree_cc_test( + NAME executable_library_test SRCS "executable_library_demo.c"
diff --git a/iree/hal/local/executable_library_benchmark.c b/iree/hal/local/executable_library_benchmark.c new file mode 100644 index 0000000..8c9a148 --- /dev/null +++ b/iree/hal/local/executable_library_benchmark.c
@@ -0,0 +1,402 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <errno.h> +#include <stdlib.h> +#include <string.h> + +#include "iree/base/api.h" +#include "iree/base/internal/flags.h" +#include "iree/base/tracing.h" +#include "iree/hal/local/executable_library.h" +#include "iree/hal/local/executable_loader.h" +#include "iree/hal/local/local_descriptor_set_layout.h" +#include "iree/hal/local/local_executable.h" +#include "iree/testing/benchmark.h" + +IREE_FLAG(string, executable_format, "", + "Format of the executable file being loaded."); +IREE_FLAG(string, executable_file, "", + "Path to the executable library file to load."); + +IREE_FLAG(int32_t, entry_point, 0, "Entry point ordinal to run."); + +IREE_FLAG(int32_t, workgroup_count_x, 1, + "X dimension of the workgroup count defining the number of\n" + "workgroup invocations that will be run per benchmark iteration.\n" + "This is the fastest-changing dimension."); +IREE_FLAG(int32_t, workgroup_count_y, 1, + "Y dimension of the workgroup count defining the number of\n" + "workgroup invocations that will be run per benchmark iteration."); +IREE_FLAG(int32_t, workgroup_count_z, 1, + "Z dimension of the workgroup count defining the number of\n" + "workgroup invocations that will be run per benchmark iteration.\n" + "This is the slowest-changing dimension."); +IREE_FLAG(int32_t, workgroup_size_x, 1, + "X dimension of the workgroup size passed to the executable."); +IREE_FLAG(int32_t, workgroup_size_y, 1, +
"Y dimension of the workgroup size passed to the executable."); +IREE_FLAG(int32_t, workgroup_size_z, 1, + "Z dimension of the workgroup size passed to the executable."); + +// Total number of bindings we (currently) allow any executable to have. +#define IREE_HAL_LOCAL_MAX_TOTAL_BINDING_COUNT \ + (IREE_HAL_LOCAL_MAX_DESCRIPTOR_SET_COUNT * \ + IREE_HAL_LOCAL_MAX_DESCRIPTOR_BINDING_COUNT) + +// Parsed parameters from flags. +// Used to construct the dispatch parameters for the benchmark invocation. +static struct { + int32_t push_constant_count; + union { + uint32_t ui32; + } push_constants[IREE_HAL_LOCAL_MAX_PUSH_CONSTANT_COUNT]; + + int32_t binding_count; + iree_string_view_t bindings[IREE_HAL_LOCAL_MAX_TOTAL_BINDING_COUNT]; +} dispatch_params = { + .push_constant_count = 0, + .binding_count = 0, +}; + +static iree_status_t parse_push_constant(iree_string_view_t flag_name, + void* storage, + iree_string_view_t value) { + IREE_ASSERT_LE(dispatch_params.push_constant_count + 1, + IREE_ARRAYSIZE(dispatch_params.push_constants), + "too many push constants"); + dispatch_params.push_constants[dispatch_params.push_constant_count++].ui32 = + (uint32_t)strtoul(value.data, NULL, 10); + return iree_ok_status(); +} +static void print_push_constant(iree_string_view_t flag_name, void* storage, + FILE* file) { + if (dispatch_params.push_constant_count == 0) { + fprintf(file, "# --%.*s=[integer value]\n", (int)flag_name.size, + flag_name.data); + return; + } + for (int32_t i = 0; i < dispatch_params.push_constant_count; ++i) { + fprintf(file, "--%.*s=%u", (int)flag_name.size, flag_name.data, + dispatch_params.push_constants[i].ui32); + if (i < dispatch_params.push_constant_count - 1) { + fprintf(file, "\n"); + } + } +} +IREE_FLAG_CALLBACK(parse_push_constant, print_push_constant, &dispatch_params, + push_constant_callback, + "Appends a uint32_t push constant value.\n"); + +static iree_status_t parse_binding(iree_string_view_t flag_name, void* storage, + iree_string_view_t value) { +
IREE_ASSERT_LE(dispatch_params.binding_count + 1, + IREE_ARRAYSIZE(dispatch_params.bindings), "too many bindings"); + dispatch_params.bindings[dispatch_params.binding_count++] = value; + return iree_ok_status(); +} +static void print_binding(iree_string_view_t flag_name, void* storage, + FILE* file) { + if (dispatch_params.binding_count == 0) { + fprintf(file, "# --%.*s=\"shapextype[=values]\"\n", (int)flag_name.size, + flag_name.data); + return; + } + for (int32_t i = 0; i < dispatch_params.binding_count; ++i) { + const iree_string_view_t binding_str = dispatch_params.bindings[i]; + fprintf(file, "--%.*s=\"%.*s\"\n", (int)flag_name.size, flag_name.data, + (int)binding_str.size, binding_str.data); + } +} +IREE_FLAG_CALLBACK( + parse_binding, print_binding, &dispatch_params, binding, + "Appends a binding to the dispatch parameters.\n" + "Bindings are defined by their shape, element type, and their data.\n" + "Examples:\n" + " # 16 4-byte elements zero-initialized:\n" + " --binding=2x8xi32\n" + " # 10000 bytes all initialized to 123:\n" + " --binding=10000xi8=123\n" + " # 2 4-byte floating-point values with contents [[1.4], [2.1]]:\n" + " --binding=2x1xf32=1.4,2.1"); + +#if defined(IREE_HAL_HAVE_EMBEDDED_LIBRARY_LOADER) +#include "iree/hal/local/loaders/embedded_library_loader.h" +#endif // IREE_HAL_HAVE_EMBEDDED_LIBRARY_LOADER + +// Creates an executable loader based on the given format flag. +
+static iree_status_t iree_hal_executable_library_create_loader( + iree_allocator_t host_allocator, + iree_hal_executable_loader_t** out_executable_loader) { +#if defined(IREE_HAL_HAVE_EMBEDDED_LIBRARY_LOADER) + if (strcmp(FLAG_executable_format, "EX_ELF") == 0) { + return iree_hal_embedded_library_loader_create(host_allocator, + out_executable_loader); + } +#endif // IREE_HAL_HAVE_EMBEDDED_LIBRARY_LOADER + return iree_make_status( + IREE_STATUS_UNAVAILABLE, + "no loader available that can handle --executable_format=%s", + FLAG_executable_format); +} + +// TODO(benvanik): use this to replace file_io.cc. +static iree_status_t iree_file_read_contents(const char* path, + iree_allocator_t allocator, + iree_byte_span_t* out_contents) { + IREE_TRACE_ZONE_BEGIN(z0); + *out_contents = iree_make_byte_span(NULL, 0); + FILE* file = fopen(path, "rb"); + if (file == NULL) { + IREE_TRACE_ZONE_END(z0); + return iree_make_status(iree_status_code_from_errno(errno), + "failed to open file '%s'", path); + } + iree_status_t status = iree_ok_status(); + if (fseek(file, 0, SEEK_END) == -1) { + status = iree_make_status(iree_status_code_from_errno(errno), "seek (end)"); + } + size_t file_size = 0; + if (iree_status_is_ok(status)) { + file_size = (size_t)ftell(file); + if (file_size == (size_t)-1L) { + status = + iree_make_status(iree_status_code_from_errno(errno), "size query"); + } + } + if (iree_status_is_ok(status)) { + if (fseek(file, 0, SEEK_SET) == -1) { + status = + iree_make_status(iree_status_code_from_errno(errno), "seek (beg)"); + } + } + // Allocate +1 to force a trailing \0 in case this is a string. +
char* contents = NULL; + if (iree_status_is_ok(status)) { + status = iree_allocator_malloc(allocator, file_size + 1, (void**)&contents); + } + if (iree_status_is_ok(status)) { + if (file_size > 0 && fread(contents, file_size, 1, file) != 1) { + status = + iree_make_status(iree_status_code_from_errno(errno), + "unable to read entire file contents of '%s'", path); + } + } + if (iree_status_is_ok(status)) { + contents[file_size] = 0; // NUL + *out_contents = iree_make_byte_span(contents, file_size); + } else { + iree_allocator_free(allocator, contents); + } + fclose(file); + IREE_TRACE_ZONE_END(z0); + return status; +} + +// NOTE: error handling is here just for better diagnostics: it is not tracking +// allocations correctly and will leak. Don't use this as an example for how to +// write robust code. +static iree_status_t iree_hal_executable_library_run( + iree_benchmark_state_t* benchmark_state) { + iree_allocator_t host_allocator = benchmark_state->host_allocator; + + // Register the loader used to load (or find) the executable. + iree_hal_executable_loader_t* executable_loader = NULL; + IREE_RETURN_IF_ERROR(iree_hal_executable_library_create_loader( + host_allocator, &executable_loader)); + + // Setup the specification used to perform the executable load. + // This information is normally used to select the appropriate loader but in + // this benchmark we only have a single one. + iree_hal_executable_spec_t executable_spec; + iree_hal_executable_spec_initialize(&executable_spec); + executable_spec.caching_mode = + IREE_HAL_EXECUTABLE_CACHING_MODE_ALLOW_OPTIMIZATION | + IREE_HAL_EXECUTABLE_CACHING_MODE_ALIAS_PROVIDED_DATA | + IREE_HAL_EXECUTABLE_CACHING_MODE_DISABLE_VERIFICATION; + executable_spec.executable_format = + iree_make_cstring_view(FLAG_executable_format); + + // Load the executable data. +
IREE_RETURN_IF_ERROR(iree_file_read_contents( + FLAG_executable_file, host_allocator, + (iree_byte_span_t*)&executable_spec.executable_data)); + + // Setup the layouts defining how each entry point is interpreted. + // NOTE: we know for the embedded library loader that this is not required. + // Other loaders may need it in which case it'll have to be provided. + executable_spec.executable_layout_count = 0; + executable_spec.executable_layouts = NULL; + + // Perform the load, which will fail if the executable cannot be loaded or + // there was an issue with the layouts. + iree_hal_executable_t* executable = NULL; + IREE_RETURN_IF_ERROR(iree_hal_executable_loader_try_load( + executable_loader, &executable_spec, &executable)); + + // Allocate storage for buffers and populate them. + // They only need to remain valid for the duration of the invocation and all + // memory accessed by the invocation will come from here. + iree_hal_allocator_t* heap_allocator = NULL; + IREE_RETURN_IF_ERROR(iree_hal_allocator_create_heap( + iree_make_cstring_view("benchmark"), host_allocator, &heap_allocator)); + iree_hal_buffer_view_t* buffer_views[IREE_HAL_LOCAL_MAX_TOTAL_BINDING_COUNT]; + void* binding_ptrs[IREE_HAL_LOCAL_MAX_TOTAL_BINDING_COUNT]; + size_t binding_lengths[IREE_HAL_LOCAL_MAX_TOTAL_BINDING_COUNT]; + for (iree_host_size_t i = 0; i < dispatch_params.binding_count; ++i) { + IREE_RETURN_IF_ERROR( + iree_hal_buffer_view_parse(dispatch_params.bindings[i], heap_allocator, + host_allocator, &buffer_views[i])); + iree_hal_buffer_t* buffer = iree_hal_buffer_view_buffer(buffer_views[i]); + iree_device_size_t buffer_length = + iree_hal_buffer_view_byte_length(buffer_views[i]); + iree_hal_buffer_mapping_t buffer_mapping; + IREE_RETURN_IF_ERROR(iree_hal_buffer_map_range( + buffer, IREE_HAL_MEMORY_ACCESS_ALL, 0, buffer_length, &buffer_mapping)); + binding_ptrs[i] = buffer_mapping.contents.data; + binding_lengths[i] = (size_t)buffer_mapping.contents.data_length; + } + + // Setup
dispatch state. + iree_hal_executable_dispatch_state_v0_t dispatch_state = { + .workgroup_count = {{ + .x = FLAG_workgroup_count_x, + .y = FLAG_workgroup_count_y, + .z = FLAG_workgroup_count_z, + }}, + .workgroup_size = {{ + .x = FLAG_workgroup_size_x, + .y = FLAG_workgroup_size_y, + .z = FLAG_workgroup_size_z, + }}, + .push_constant_count = dispatch_params.push_constant_count, + .push_constants = &dispatch_params.push_constants[0].ui32, + .binding_count = dispatch_params.binding_count, + .binding_ptrs = binding_ptrs, + .binding_lengths = binding_lengths, + .imports = NULL, // not yet implemented + }; + + // Benchmark the workgroup invocations. + // Note that each iteration runs through the whole grid as it's important that + // we are testing the memory access patterns: if we just ran the same single + // tile processing the same exact region of memory over and over we are not + // testing cache effects. + IREE_TRACE_ZONE_BEGIN(z1); + int64_t dispatch_count = 0; + while (iree_benchmark_keep_running(benchmark_state, /*batch_count=*/1)) { + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z1, iree_hal_local_executable_issue_dispatch_inline( + iree_hal_local_executable_cast(executable), FLAG_entry_point, + &dispatch_state)); + ++dispatch_count; + } + IREE_TRACE_ZONE_END(z1); + + // To get a total time per invocation we set the item count to the total + // invocations dispatched. That gives us both total dispatch and single + // invocation times in the reporter output. + int64_t total_invocations = + dispatch_count * dispatch_state.workgroup_count.x * + dispatch_state.workgroup_count.y * dispatch_state.workgroup_count.z; + iree_benchmark_set_items_processed(benchmark_state, total_invocations); + + // Deallocate buffers. + for (iree_host_size_t i = 0; i < dispatch_params.binding_count; ++i) { + iree_hal_buffer_view_release(buffer_views[i]); + } + iree_hal_allocator_release(heap_allocator); + + // Unload. +
+ iree_allocator_free(host_allocator, + (void*)executable_spec.executable_data.data); + iree_hal_executable_release(executable); + iree_hal_executable_loader_release(executable_loader); + + return iree_ok_status(); +} + +int main(int argc, char** argv) { + iree_flags_set_usage( + "executable_library_benchmark", + "Benchmarks a single entry point within an executable library.\n" + "Executable libraries can be found in your temp path when compiling\n" + "with `-iree-llvm-keep-linker-artifacts`. The parameters used can be\n" + "inferred from the entry point `hal.interface` and dispatches to it.\n" + "\n" + "Note that this tool is intentionally low level: you must specify all\n" + "of the push constant/binding parameters precisely as they are expected\n" + "by the executable. `iree-benchmark-module` is the user-friendly\n" + "benchmarking tool while this one favors direct access to the\n" + "executables (bypassing all of the IREE VM, HAL APIs, task system,\n" + "etc).\n" + "\n" + "Example --flagfile:\n" + " --executable_format=EX_ELF\n" + " --executable_file=iree/hal/local/elf/testdata/" + "simple_mul_dispatch_x86_64.so\n" + " --entry_point=0\n" + " --workgroup_count_x=1\n" + " --workgroup_count_y=1\n" + " --workgroup_count_z=1\n" + " --workgroup_size_x=1\n" + " --workgroup_size_y=1\n" + " --workgroup_size_z=1\n" + " --binding=4xf32=1,2,3,4\n" + " --binding=4xf32=100,200,300,400\n" + " --binding=4xf32=0,0,0,0\n" + "\n"); + + iree_flags_parse_checked(IREE_FLAGS_PARSE_MODE_UNDEFINED_OK, &argc, &argv); + iree_benchmark_initialize(&argc, argv); + +#if IREE_TRACING_FEATURES & IREE_TRACING_FEATURE_INSTRUMENTATION + // clang-format off + fprintf(stderr, +"\x1b[31m" +"===----------------------------------------------------------------------===\n" +"\n" +" ██ ██ █████ ██████ ███ ██ ██ ███ ██ ██████\n" +" ██ ██ ██ ██ ██ ██ ████ ██ ██ ████ ██ ██\n" +" ██ █ ██ ███████ ██████ ██ ██ ██ ██ ██ ██ ██ ██ ███\n" +" ██ ███ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██\n" +" ███ ███ ██ ██ ██
██ ██ ████ ██ ██ ████ ██████\n" +"\n" +"===----------------------------------------------------------------------===\n" +"\n" +"Tracing is enabled and will skew your results!\n" +"The timings involved here can be an order of magnitude off due to the tracing\n" +"time sampling, recording, and instrumentation overhead. Disable tracing with\n" +"IREE_ENABLE_RUNTIME_TRACING=OFF and rebuild.\n" +"\x1b[0m" +"\n" + ); + fflush(stderr); + // clang-format on +#endif // IREE_TRACING_FEATURES & IREE_TRACING_FEATURE_INSTRUMENTATION + + // TODO(benvanik): override these with our own flags. + iree_benchmark_def_t benchmark_def = { + .flags = IREE_BENCHMARK_FLAG_MEASURE_PROCESS_CPU_TIME | + IREE_BENCHMARK_FLAG_USE_REAL_TIME, + .time_unit = IREE_BENCHMARK_UNIT_NANOSECOND, + .minimum_duration_ns = 0, + .iteration_count = 0, + .run = iree_hal_executable_library_run, + }; + iree_benchmark_register(iree_make_cstring_view("dispatch"), &benchmark_def); + + iree_benchmark_run_specified(); + return 0; +}
diff --git a/iree/hal/local/loaders/BUILD b/iree/hal/local/loaders/BUILD index c3e4713..51df798 100644 --- a/iree/hal/local/loaders/BUILD +++ b/iree/hal/local/loaders/BUILD
@@ -28,7 +28,7 @@ srcs = ["embedded_library_loader.c"], hdrs = ["embedded_library_loader.h"], defines = [ - "IREE_HAL_HAVE_ELF_LIBRARY_LOADER=1", + "IREE_HAL_HAVE_EMBEDDED_LIBRARY_LOADER=1", ], deps = [ "//iree/base:api",
diff --git a/iree/hal/local/loaders/CMakeLists.txt b/iree/hal/local/loaders/CMakeLists.txt index 7ca0508..4e652c0 100644 --- a/iree/hal/local/loaders/CMakeLists.txt +++ b/iree/hal/local/loaders/CMakeLists.txt
@@ -25,7 +25,7 @@ iree::hal::local iree::hal::local::elf::elf_module DEFINES - "IREE_HAL_HAVE_ELF_LIBRARY_LOADER=1" + "IREE_HAL_HAVE_EMBEDDED_LIBRARY_LOADER=1" PUBLIC )
diff --git a/iree/hal/local/loaders/embedded_library_loader.c b/iree/hal/local/loaders/embedded_library_loader.c index dc19c84..b1e0f04 100644 --- a/iree/hal/local/loaders/embedded_library_loader.c +++ b/iree/hal/local/loaders/embedded_library_loader.c
@@ -85,6 +85,7 @@ } static iree_status_t iree_hal_elf_executable_create( + iree_hal_executable_caching_mode_t caching_mode, iree_const_byte_span_t elf_data, iree_host_size_t executable_layout_count, iree_hal_executable_layout_t* const* executable_layouts, iree_allocator_t host_allocator, iree_hal_executable_t** out_executable) { @@ -118,7 +119,10 @@ // Query metadata and get the entry point function pointers. status = iree_hal_elf_executable_query_library(executable); } - if (iree_status_is_ok(status)) { + if (iree_status_is_ok(status) && + !iree_all_bits_set( + caching_mode, + IREE_HAL_EXECUTABLE_CACHING_MODE_DISABLE_VERIFICATION)) { // Check to make sure that the entry point count matches the layouts // provided. if (executable->library.v0->entry_point_count != executable_layout_count) { @@ -267,7 +271,7 @@ // Perform the load of the ELF and wrap it in an executable handle. iree_status_t status = iree_hal_elf_executable_create( - executable_spec->executable_data, + executable_spec->caching_mode, executable_spec->executable_data, executable_spec->executable_layout_count, executable_spec->executable_layouts, executable_loader->host_allocator, out_executable);
diff --git a/iree/hal/vulkan/registration/BUILD b/iree/hal/vulkan/registration/BUILD index c8e40ac..a7a645a 100644 --- a/iree/hal/vulkan/registration/BUILD +++ b/iree/hal/vulkan/registration/BUILD
@@ -41,7 +41,6 @@ "//iree/base/internal:flags", "//iree/hal:api", "//iree/hal/vulkan", - "@com_google_absl//absl/flags:flag", ], )
diff --git a/iree/hal/vulkan/registration/CMakeLists.txt b/iree/hal/vulkan/registration/CMakeLists.txt index 51c41bc..87ddb12 100644 --- a/iree/hal/vulkan/registration/CMakeLists.txt +++ b/iree/hal/vulkan/registration/CMakeLists.txt
@@ -20,7 +20,6 @@ SRCS "driver_module.cc" DEPS - absl::flags iree::base::core_headers iree::base::internal::flags iree::base::status
diff --git a/iree/hal/vulkan/registration/driver_module.cc b/iree/hal/vulkan/registration/driver_module.cc index 761943b..2d3faf0 100644 --- a/iree/hal/vulkan/registration/driver_module.cc +++ b/iree/hal/vulkan/registration/driver_module.cc
@@ -16,7 +16,6 @@ #include <inttypes.h> -#include "absl/flags/flag.h" #include "iree/base/internal/flags.h" #include "iree/base/status.h" #include "iree/base/target_platform.h" @@ -25,17 +24,18 @@ #define IREE_HAL_VULKAN_1_X_DRIVER_ID 0x564C4B31u // VLK1 -ABSL_FLAG(bool, vulkan_validation_layers, true, +IREE_FLAG(bool, vulkan_validation_layers, true, "Enables standard Vulkan validation layers."); -ABSL_FLAG(bool, vulkan_debug_utils, true, +IREE_FLAG(bool, vulkan_debug_utils, true, "Enables VK_EXT_debug_utils, records markers, and logs errors."); -ABSL_FLAG(int, vulkan_default_index, 0, "Index of the default Vulkan device."); +IREE_FLAG(int32_t, vulkan_default_index, 0, + "Index of the default Vulkan device."); -ABSL_FLAG(bool, vulkan_force_timeline_semaphore_emulation, false, +IREE_FLAG(bool, vulkan_force_timeline_semaphore_emulation, false, "Uses timeline semaphore emulation even if native support exists."); -ABSL_FLAG(bool, vulkan_tracing, true, +IREE_FLAG(bool, vulkan_tracing, true, "Enables Vulkan tracing (if IREE tracing is enabled)."); static iree_status_t iree_hal_vulkan_create_driver_with_flags( @@ -58,22 +58,21 @@ driver_options.api_version = VK_API_VERSION_1_2; #endif // IREE_PLATFORM_ANDROID - if (absl::GetFlag(FLAGS_vulkan_validation_layers)) { + if (FLAG_vulkan_validation_layers) { driver_options.requested_features |= IREE_HAL_VULKAN_FEATURE_ENABLE_VALIDATION_LAYERS; } - if (absl::GetFlag(FLAGS_vulkan_debug_utils)) { + if (FLAG_vulkan_debug_utils) { driver_options.requested_features |= IREE_HAL_VULKAN_FEATURE_ENABLE_DEBUG_UTILS; } - if (absl::GetFlag(FLAGS_vulkan_tracing)) { + if (FLAG_vulkan_tracing) { driver_options.requested_features |= IREE_HAL_VULKAN_FEATURE_ENABLE_TRACING; } - driver_options.default_device_index = - absl::GetFlag(FLAGS_vulkan_default_index); + driver_options.default_device_index = FLAG_vulkan_default_index; - if (absl::GetFlag(FLAGS_vulkan_force_timeline_semaphore_emulation)) { + if 
(FLAG_vulkan_force_timeline_semaphore_emulation) { driver_options.device_options.flags |= IREE_HAL_VULKAN_DEVICE_FORCE_TIMELINE_SEMAPHORE_EMULATION; }
diff --git a/iree/task/BUILD b/iree/task/BUILD index d5ef358..b973b5a 100644 --- a/iree/task/BUILD +++ b/iree/task/BUILD
@@ -19,6 +19,17 @@ ) cc_library( + name = "api", + srcs = ["api.c"], + hdrs = ["api.h"], + deps = [ + ":task", + "//iree/base:tracing", + "//iree/base/internal:flags", + ], +) + +cc_library( name = "task", srcs = [ "executor.c",
diff --git a/iree/task/CMakeLists.txt b/iree/task/CMakeLists.txt index 46ea469..2fafc78 100644 --- a/iree/task/CMakeLists.txt +++ b/iree/task/CMakeLists.txt
@@ -12,6 +12,20 @@ iree_cc_library( NAME + api + HDRS + "api.h" + SRCS + "api.c" + DEPS + ::task + iree::base::internal::flags + iree::base::tracing + PUBLIC +) + +iree_cc_library( + NAME task HDRS "affinity_set.h"
diff --git a/iree/task/api.c b/iree/task/api.c new file mode 100644 index 0000000..558397b --- /dev/null +++ b/iree/task/api.c
@@ -0,0 +1,125 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "iree/task/api.h" + +#include "iree/base/internal/flags.h" +#include "iree/base/tracing.h" + +//===----------------------------------------------------------------------===// +// Executor configuration +//===----------------------------------------------------------------------===// + +IREE_FLAG( + bool, task_scheduling_defer_worker_startup, false, + "Creates all workers suspended and waits until work is first scheduled to\n" + "them to resume. This trades off initial blocking startup time waking the\n" + "threads for potential latency additions later on as threads take longer\n" + "to wake on their first use."); + +IREE_FLAG( + bool, task_scheduling_dedicated_wait_thread, false, + "Creates a dedicated thread performing waits on root wait handles. On\n" + "workloads with many short-duration waits this will reduce total\n" + "latency as the waits are aggressively processed and dependent tasks are\n" + "scheduled. 
It also keeps any wait-related syscalls off the worker\n" + "threads that would otherwise need to perform the syscalls during\n" + "coordination."); + +//===----------------------------------------------------------------------===// +// Topology configuration +//===----------------------------------------------------------------------===// + +IREE_FLAG( + string, task_topology_mode, "physical_cores", + "Available modes:\n" + " --task_topology_group_count=non-zero:\n" + " Uses whatever the specified group count is and ignores the set mode.\n" + " 'physical_cores':\n" + " Creates one group per physical core in the machine up to\n" + " the value specified by --task_topology_max_group_count.\n" + " 'unique_l2_cache_groups':\n" + " Creates one group for each unique L2 cache group across all available\n" + " cores up to the value specified by --task_topology_max_group_count.\n" + " This optimizes for temporal and spatial cache locality but may suffer\n" + " from oversubscription if there are other processes trying to use the\n" + " same cores.\n"); + +IREE_FLAG( + int32_t, task_topology_group_count, 0, + "Defines the total number of task system workers that will be created.\n" + "Workers will be distributed across cores. Specifying 0 will use a\n" + "heuristic defined by --task_topology_mode= to automatically select the\n" + "worker count and distribution."); + +IREE_FLAG( + int32_t, task_topology_max_group_count, 8, + "Sets a maximum value on the worker count that can be automatically\n" + "detected and used when --task_topology_group_count=0 and is ignored\n" + "otherwise.\n"); + +// TODO(benvanik): add --task_topology_dump to dump out the current machine +// configuration as seen by the topology utilities. 
+ +//===----------------------------------------------------------------------===// +// Task system factory functions +//===----------------------------------------------------------------------===// + +IREE_API_EXPORT iree_status_t IREE_API_CALL +iree_task_executor_create_from_flags(iree_allocator_t host_allocator, + iree_task_executor_t** out_executor) { + IREE_ASSERT_ARGUMENT(out_executor); + *out_executor = NULL; + IREE_TRACE_ZONE_BEGIN(z0); + + iree_task_scheduling_mode_t scheduling_mode = 0; + if (FLAG_task_scheduling_defer_worker_startup) { + scheduling_mode |= IREE_TASK_SCHEDULING_MODE_DEFER_WORKER_STARTUP; + } + if (FLAG_task_scheduling_dedicated_wait_thread) { + scheduling_mode |= IREE_TASK_SCHEDULING_MODE_DEDICATED_WAIT_THREAD; + } + + iree_status_t status = iree_ok_status(); + + iree_task_topology_t topology; + iree_task_topology_initialize(&topology); + + if (FLAG_task_topology_group_count != 0) { + iree_task_topology_initialize_from_group_count( + FLAG_task_topology_group_count, &topology); + } else if (strcmp(FLAG_task_topology_mode, "physical_cores") == 0) { + iree_task_topology_initialize_from_physical_cores( + FLAG_task_topology_max_group_count, &topology); + } else if (strcmp(FLAG_task_topology_mode, "unique_l2_cache_groups") == 0) { + iree_task_topology_initialize_from_unique_l2_cache_groups( + FLAG_task_topology_max_group_count, &topology); + } else { + status = iree_make_status( + IREE_STATUS_INVALID_ARGUMENT, + "one of --task_topology_group_count or --task_topology_mode must be " + "specified and be a valid value; have --task_topology_mode=%s.", + FLAG_task_topology_mode); + } + + if (iree_status_is_ok(status)) { + status = iree_task_executor_create(scheduling_mode, &topology, + host_allocator, out_executor); + } + + iree_task_topology_deinitialize(&topology); + + IREE_TRACE_ZONE_END(z0); + return status; +}
diff --git a/iree/task/api.h b/iree/task/api.h new file mode 100644 index 0000000..4e36600 --- /dev/null +++ b/iree/task/api.h
@@ -0,0 +1,54 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_TASK_API_H_ +#define IREE_TASK_API_H_ + +#include "iree/task/executor.h" +#include "iree/task/topology.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +//===----------------------------------------------------------------------===// +// Task system factory functions +//===----------------------------------------------------------------------===// + +// Creates a task system executor from the current command line flags. +// This configures a topology and all of the executor parameters and returns +// a newly created instance in |out_executor| that must be released by the +// caller. +// +// This utility method is useful when only a single executor exists within a +// process as the flags are globals. When multiple executors may exist or +// programmatic configuration is needed use the iree_task_executor_create method +// directly. +IREE_API_EXPORT iree_status_t IREE_API_CALL +iree_task_executor_create_from_flags(iree_allocator_t host_allocator, + iree_task_executor_t** out_executor); + +//===----------------------------------------------------------------------===// +// Task system simple invocation utilities +//===----------------------------------------------------------------------===// + +// TODO(benvanik): simple IO completion event callback. +// TODO(benvanik): simple async function call dispatch. 
+// TODO(benvanik): simple parallel-for grid-style function call dispatch. + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus + +#endif // IREE_TASK_API_H_
diff --git a/iree/task/topology.c b/iree/task/topology.c index e901528..d96b627 100644 --- a/iree/task/topology.c +++ b/iree/task/topology.c
@@ -248,9 +248,10 @@ for (iree_host_size_t j = 0; j < topology->group_count; ++j) { if (i == j) continue; const iree_task_topology_group_t* other_group = &topology->groups[j]; - uint64_t group_processor_bits = 1ull << other_group->processor_index; + uint64_t group_processor_bits = + iree_math_rotl_u64(1ull, other_group->processor_index); if (constructive_sharing_mask & group_processor_bits) { - group_mask |= 1ull << other_group->group_index; + group_mask |= iree_math_rotl_u64(1ull, other_group->group_index); } }
diff --git a/iree/testing/benchmark.h b/iree/testing/benchmark.h index 76fa307..f91e1a4 100644 --- a/iree/testing/benchmark.h +++ b/iree/testing/benchmark.h
@@ -126,7 +126,7 @@ // Registers a benchmark with the given definition. void iree_benchmark_register(iree_string_view_t name, - iree_benchmark_def_t* benchmark_def); + const iree_benchmark_def_t* benchmark_def); //===----------------------------------------------------------------------===// // Benchmark infra management
diff --git a/iree/testing/benchmark_main.c b/iree/testing/benchmark_main.c index 2fcf14b..3990911 100644 --- a/iree/testing/benchmark_main.c +++ b/iree/testing/benchmark_main.c
@@ -16,8 +16,11 @@ #include "iree/testing/benchmark.h" int main(int argc, char** argv) { + // Pass through flags to benchmark (allowing --help to fall through). + iree_flags_parse_checked(IREE_FLAGS_PARSE_MODE_UNDEFINED_OK | + IREE_FLAGS_PARSE_MODE_CONTINUE_AFTER_HELP, + &argc, &argv); iree_benchmark_initialize(&argc, argv); - iree_flags_parse_checked(&argc, &argv); iree_benchmark_run_specified(); return 0; }
diff --git a/iree/testing/gtest_main.cc b/iree/testing/gtest_main.cc index c3d456c..e455555 100644 --- a/iree/testing/gtest_main.cc +++ b/iree/testing/gtest_main.cc
@@ -16,8 +16,11 @@ #include "iree/testing/gtest.h" extern "C" int main(int argc, char** argv) { + // Pass through flags to gtest (allowing --help to fall through). + iree_flags_parse_checked(IREE_FLAGS_PARSE_MODE_UNDEFINED_OK | + IREE_FLAGS_PARSE_MODE_CONTINUE_AFTER_HELP, + &argc, &argv); ::testing::InitGoogleTest(&argc, argv); - iree_flags_parse_checked(&argc, &argv); return RUN_ALL_TESTS(); }
diff --git a/iree/testing/vulkan/CMakeLists.txt b/iree/testing/vulkan/CMakeLists.txt index f263f0e..3a84ccc 100644 --- a/iree/testing/vulkan/CMakeLists.txt +++ b/iree/testing/vulkan/CMakeLists.txt
@@ -63,7 +63,6 @@ "iree-run-module-vulkan-gui-main.cc" DEPS ::vulkan_gui_util - absl::flags iree::base::internal::file_io iree::base::internal::flags iree::base::internal::main
diff --git a/iree/testing/vulkan/iree-run-module-vulkan-gui-main.cc b/iree/testing/vulkan/iree-run-module-vulkan-gui-main.cc index 29491ba..1e3f42b 100644 --- a/iree/testing/vulkan/iree-run-module-vulkan-gui-main.cc +++ b/iree/testing/vulkan/iree-run-module-vulkan-gui-main.cc
@@ -18,7 +18,6 @@ #include "iree/testing/vulkan/vulkan_gui_util.h" // Other dependencies (helpers, etc.) -#include "absl/flags/flag.h" #include "iree/base/internal/file_io.h" #include "iree/base/internal/flags.h" #include "iree/base/internal/main.h" @@ -29,29 +28,44 @@ #include "iree/vm/api.h" #include "iree/vm/bytecode_module.h" -ABSL_FLAG(std::string, module_file, "-", +IREE_FLAG(string, module_file, "-", "File containing the module to load that contains the entry " "function. Defaults to stdin."); -ABSL_FLAG(std::string, entry_function, "", +IREE_FLAG(string, entry_function, "", "Name of a function contained in the module specified by input_file " "to run."); -ABSL_FLAG(std::vector<std::string>, function_inputs, {}, - "A comma-separated list of of input buffers of the format:" - "[shape]xtype=[value]\n" - "2x2xi32=1 2 3 4\n" - "Optionally, brackets may be used to separate the element values. " - "They are ignored by the parser.\n" - "2x2xi32=[[1 2][3 4]]\n" - "Due to the absence of repeated flags in absl, commas should not be " - "used to separate elements. 
They are reserved for separating input " - "values:\n" - "2x2xi32=[[1 2][3 4]], 1x2xf32=[[1 2]]"); - -ABSL_FLAG(std::string, function_inputs_file, "", - "Provides a file for input shapes and optional values (see " - "ParseToVariantListFromFile in vm_util.h for details)"); +static iree_status_t parse_function_input(iree_string_view_t flag_name, + void* storage, + iree_string_view_t value) { + auto* list = (std::vector<std::string>*)storage; + list->push_back(std::string(value.data, value.size)); + return iree_ok_status(); +} +static void print_function_input(iree_string_view_t flag_name, void* storage, + FILE* file) { + auto* list = (std::vector<std::string>*)storage; + if (list->empty()) { + fprintf(file, "# --%.*s=\n", (int)flag_name.size, flag_name.data); + } else { + for (size_t i = 0; i < list->size(); ++i) { + fprintf(file, "--%.*s=\"%s\"\n", (int)flag_name.size, flag_name.data, + list->at(i).c_str()); + } + } +} +static std::vector<std::string> FLAG_function_inputs; +IREE_FLAG_CALLBACK( + parse_function_input, print_function_input, &FLAG_function_inputs, + function_input, + "An input value or buffer of the format:\n" + " [shape]xtype=[value]\n" + " 2x2xi32=1 2 3 4\n" + "Optionally, brackets may be used to separate the element values:\n" + " 2x2xi32=[[1 2][3 4]]\n" + "Each occurrence of the flag indicates an input in the order they were\n" + "specified on the command line."); static VkAllocationCallbacks* g_Allocator = NULL; static VkInstance g_Instance = VK_NULL_HANDLE; @@ -89,7 +103,7 @@ } Status GetModuleContentsFromFlags(std::string* out_contents) { - auto module_file = absl::GetFlag(FLAGS_module_file); + auto module_file = std::string(FLAG_module_file); if (module_file == "-") { *out_contents = std::string{std::istreambuf_iterator<char>(std::cin), std::istreambuf_iterator<char>()}; @@ -144,7 +158,7 @@ } // namespace extern "C" int iree_main(int argc, char** argv) { - iree_flags_parse_checked(&argc, &argv); + 
iree_flags_parse_checked(IREE_FLAGS_PARSE_MODE_DEFAULT, &argc, &argv); IREE_CHECK_OK(iree_hal_vulkan_driver_module_register( iree_hal_driver_registry_default())); @@ -322,7 +336,7 @@ IREE_LOG(INFO) << "Context with modules is ready for use"; // Lookup the entry point function. - std::string entry_function = absl::GetFlag(FLAGS_entry_function); + std::string entry_function = FLAG_entry_function; iree_vm_function_t main_function; IREE_CHECK_OK(bytecode_module->lookup_function( bytecode_module->self, IREE_VM_FUNCTION_LINKAGE_EXPORT, @@ -339,25 +353,15 @@ std::vector<RawSignatureParser::Description> main_function_input_descs; IREE_CHECK_OK(ParseInputSignature(main_function, &main_function_input_descs)); vm::ref<iree_vm_list_t> main_function_inputs; - if (!absl::GetFlag(FLAGS_function_inputs_file).empty()) { - if (!absl::GetFlag(FLAGS_function_inputs).empty()) { - IREE_LOG(FATAL) << "Expected only one of function_inputs and " - "function_inputs_file to be set"; - } - IREE_CHECK_OK(ParseToVariantListFromFile( - main_function_input_descs, iree_hal_device_allocator(iree_vk_device), - absl::GetFlag(FLAGS_function_inputs_file), &main_function_inputs)); - } else { - IREE_CHECK_OK(ParseToVariantList( - main_function_input_descs, iree_hal_device_allocator(iree_vk_device), - absl::GetFlag(FLAGS_function_inputs), &main_function_inputs)); - } + IREE_CHECK_OK(ParseToVariantList( + main_function_input_descs, iree_hal_device_allocator(iree_vk_device), + FLAG_function_inputs, &main_function_inputs)); std::vector<RawSignatureParser::Description> main_function_output_descs; IREE_CHECK_OK( ParseOutputSignature(main_function, &main_function_output_descs)); - const std::string& window_title = absl::GetFlag(FLAGS_module_file); + const std::string window_title = std::string(FLAG_module_file); // -------------------------------------------------------------------------- // --------------------------------------------------------------------------
diff --git a/iree/tools/BUILD b/iree/tools/BUILD index c01c6f6..c6592e3 100644 --- a/iree/tools/BUILD +++ b/iree/tools/BUILD
@@ -39,9 +39,6 @@ "//iree/tools/utils:vm_util", "//iree/vm", "//iree/vm:bytecode_module", - "@com_google_absl//absl/flags:flag", - "@com_google_absl//absl/flags:parse", - "@com_google_absl//absl/flags:usage", "@com_google_absl//absl/strings", "@com_google_benchmark//:benchmark", ], @@ -64,7 +61,6 @@ "//iree/testing:gtest", "//iree/tools/utils:vm_util", "//iree/vm:bytecode_module", - "@com_google_absl//absl/flags:flag", "@com_google_absl//absl/strings", ], ) @@ -254,7 +250,6 @@ "//iree/tools/utils:vm_util", "//iree/vm", "//iree/vm:bytecode_module", - "@com_google_absl//absl/flags:flag", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@llvm-project//llvm:Support", @@ -280,7 +275,6 @@ "//iree/tools/utils:vm_util", "//iree/vm", "//iree/vm:bytecode_module", - "@com_google_absl//absl/flags:flag", "@com_google_absl//absl/strings", ], )
diff --git a/iree/tools/CMakeLists.txt b/iree/tools/CMakeLists.txt index 90c2a90..084989f 100644 --- a/iree/tools/CMakeLists.txt +++ b/iree/tools/CMakeLists.txt
@@ -63,9 +63,6 @@ SRCS "iree-benchmark-module-main.cc" DEPS - absl::flags - absl::flags_parse - absl::flags_usage absl::strings benchmark iree::base::internal::file_io @@ -89,7 +86,6 @@ "iree-check-module-main.cc" DEPS iree::modules::check::native_module - absl::flags absl::strings iree::base::api iree::base::core_headers @@ -127,7 +123,6 @@ SRCS "iree-run-module-main.cc" DEPS - absl::flags absl::strings iree::base::internal::file_io iree::base::internal::flags @@ -365,7 +360,6 @@ MLIRPass MLIRSupport MLIRTargetLLVMIRExport - absl::flags absl::span absl::strings iree::base::api
diff --git a/iree/tools/iree-benchmark-module-main.cc b/iree/tools/iree-benchmark-module-main.cc index a225339..f5f9611 100644 --- a/iree/tools/iree-benchmark-module-main.cc +++ b/iree/tools/iree-benchmark-module-main.cc
@@ -12,10 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "absl/flags/flag.h" -#include "absl/flags/internal/parse.h" -#include "absl/flags/usage.h" -#include "absl/strings/string_view.h" #include "benchmark/benchmark.h" #include "iree/base/internal/file_io.h" #include "iree/base/internal/flags.h" @@ -27,39 +23,54 @@ #include "iree/vm/api.h" #include "iree/vm/bytecode_module.h" -ABSL_FLAG(std::string, module_file, "-", +IREE_FLAG(string, module_file, "-", "File containing the module to load that contains the entry " "function. Defaults to stdin."); // TODO(hanchung): Extract the batch size using // iree_vm_function_reflection_attr. -ABSL_FLAG( - int, batch_size, 1, +IREE_FLAG( + int32_t, batch_size, 1, "The number of batch size, which is expected to match " "iree-hal-benchmark-dispatch-repeat-count when translating the module"); -ABSL_FLAG(std::string, entry_function, "", +IREE_FLAG(string, entry_function, "", "Name of a function contained in the module specified by module_file " "to run. If this is not set, all the exported functions will be " "benchmarked and they are expected to not have input arguments."); -ABSL_FLAG(std::string, driver, "vmla", "Backend driver to use."); +IREE_FLAG(string, driver, "vmla", "Backend driver to use."); -ABSL_FLAG(std::vector<std::string>, function_inputs, {}, - "A comma-separated list of of input buffers of the format:" - "[shape]xtype=[value]\n" - "2x2xi32=1 2 3 4\n" - "Optionally, brackets may be used to separate the element values. " - "They are ignored by the parser.\n" - "2x2xi32=[[1 2][3 4]]\n" - "Due to the absence of repeated flags in absl, commas should not be " - "used to separate elements. 
They are reserved for separating input " - "values:\n" - "2x2xi32=[[1 2][3 4]], 1x2xf32=[[1 2]]"); - -ABSL_FLAG(std::string, function_inputs_file, "", - "Provides a file for input shapes and optional values (see " - "ParseToVariantListFromFile in vm_util.h for details)"); +static iree_status_t parse_function_input(iree_string_view_t flag_name, + void* storage, + iree_string_view_t value) { + auto* list = (std::vector<std::string>*)storage; + list->push_back(std::string(value.data, value.size)); + return iree_ok_status(); +} +static void print_function_input(iree_string_view_t flag_name, void* storage, + FILE* file) { + auto* list = (std::vector<std::string>*)storage; + if (list->empty()) { + fprintf(file, "# --%.*s=\n", (int)flag_name.size, flag_name.data); + } else { + for (size_t i = 0; i < list->size(); ++i) { + fprintf(file, "--%.*s=\"%s\"\n", (int)flag_name.size, flag_name.data, + list->at(i).c_str()); + } + } +} +static std::vector<std::string> FLAG_function_inputs; +IREE_FLAG_CALLBACK( + parse_function_input, print_function_input, &FLAG_function_inputs, + function_input, + "An input value or buffer of the format:\n" + " [shape]xtype=[value]\n" + " 2x2xi32=1 2 3 4\n" + "Optionally, brackets may be used to separate the element values:\n" + " 2x2xi32=[[1 2][3 4]]\n" + "Each occurrence of the flag indicates an input in the order they were\n" + "specified on the command line."); namespace iree { namespace { @@ -91,7 +102,7 @@ iree_vm_function_t function, iree_vm_list_t* inputs, const std::vector<RawSignatureParser::Description>& output_descs) { auto benchmark_name = "BM_" + function_name; - int batch_size = absl::GetFlag(FLAGS_batch_size); + int batch_size = FLAG_batch_size; benchmark::RegisterBenchmark( benchmark_name.c_str(), [benchmark_name, batch_size, context, function, inputs, @@ -116,7 +127,7 @@ Status GetModuleContentsFromFlags(std::string* out_contents) { IREE_TRACE_SCOPE0("GetModuleContentsFromFlags"); - auto module_file = 
absl::GetFlag(FLAGS_module_file); + auto module_file = std::string(FLAG_module_file); if (module_file == "-") { *out_contents = std::string{std::istreambuf_iterator<char>(std::cin), std::istreambuf_iterator<char>()}; @@ -156,7 +167,7 @@ IREE_RETURN_IF_ERROR(Init()); } - auto function_name = absl::GetFlag(FLAGS_entry_function); + auto function_name = std::string(FLAG_entry_function); if (!function_name.empty()) { IREE_RETURN_IF_ERROR(RegisterSpecificFunction(function_name)); } else { @@ -178,7 +189,7 @@ // Create IREE's device and module. IREE_RETURN_IF_ERROR( - iree::CreateDevice(absl::GetFlag(FLAGS_driver), &device_)); + iree::CreateDevice(std::string(FLAG_driver), &device_)); IREE_RETURN_IF_ERROR(CreateHalModule(device_, &hal_module_)); IREE_RETURN_IF_ERROR(LoadBytecodeModule(module_data_, &input_module_)); @@ -206,15 +217,9 @@ // Construct inputs. std::vector<RawSignatureParser::Description> input_descs; IREE_RETURN_IF_ERROR(ParseInputSignature(function, &input_descs)); - if (!absl::GetFlag(FLAGS_function_inputs_file).empty()) { - IREE_RETURN_IF_ERROR(ParseToVariantListFromFile( - input_descs, iree_hal_device_allocator(device_), - absl::GetFlag(FLAGS_function_inputs_file), &inputs_)); - } else { - IREE_RETURN_IF_ERROR( - ParseToVariantList(input_descs, iree_hal_device_allocator(device_), - absl::GetFlag(FLAGS_function_inputs), &inputs_)); - } + IREE_RETURN_IF_ERROR(ParseToVariantList(input_descs, + iree_hal_device_allocator(device_), + FLAG_function_inputs, &inputs_)); // Creates output signature. std::vector<RawSignatureParser::Description> output_descs; @@ -266,44 +271,12 @@ int main(int argc, char** argv) { IREE_TRACE_SCOPE0("main"); - // We have to contend with two flag parsing libraries here: absl's and - // benchmark's. To make matters worse, both define the `--help` flag. To - // ensure that each is able to parse its own flags, we use an absl "internal" - // function (still with public visibility) to parse while ignoring undefined - // flags.
If it sees `--help` it will exit here, so we include the benchmark - // library usage information in the manually-set help output. Then we let - // benchmark parse its flags. Finally we call the normal initialization - // function to do other IREE initialization including flag parsing with - // normal options. Any remaining flags will be unknown and result in an error. - absl::SetProgramUsageMessage( - "iree-benchmark-module \n" - " --module_file=module.vmfb\n" - " --entry_function=exported_function_to_benchmark\n" - " If this is not set, all the exported functions will be \n" - " benchmarked and they are expected to not have input arguments\n" - " [--function_inputs=2xi32=1 2,1x2xf32=2 1 | \n" - " --function_inputs_file=file_with_function_inputs]\n" - " [--driver=vmla]\n" - "\n\n" - " Optional flags from third_party/benchmark/src/benchmark.cc:\n" - " [--benchmark_list_tests={true|false}]\n" - " [--benchmark_filter=<regex>]\n" - " [--benchmark_min_time=<min_time>]\n" - " [--benchmark_repetitions=<num_repetitions>]\n" - " [--benchmark_report_aggregates_only={true|false}]\n" - " [--benchmark_display_aggregates_only={true|false}]\n" - " [--benchmark_format=<console|json|csv>]\n" - " [--benchmark_out=<filename>]\n" - " [--benchmark_out_format=<json|console|csv>]\n" - " [--benchmark_color={auto|true|false}]\n" - " [--benchmark_counters_tabular={true|false}]\n" - " [--v=<verbosity>]\n"); - absl::flags_internal::ParseCommandLineImpl( - argc, argv, absl::flags_internal::ArgvListAction::kRemoveParsedArgs, - absl::flags_internal::UsageFlagsAction::kHandleUsage, - absl::flags_internal::OnUndefinedFlag::kIgnoreUndefined); + // Pass through flags to benchmark (allowing --help to fall through). 
+ iree_flags_parse_checked(IREE_FLAGS_PARSE_MODE_UNDEFINED_OK | + IREE_FLAGS_PARSE_MODE_CONTINUE_AFTER_HELP, + &argc, &argv); ::benchmark::Initialize(&argc, argv); - iree_flags_parse_checked(&argc, &argv); + IREE_CHECK_OK(iree_hal_register_all_available_drivers( iree_hal_driver_registry_default()));
diff --git a/iree/tools/iree-check-module-main.cc b/iree/tools/iree-check-module-main.cc index fe0e00e..aee13e1 100644 --- a/iree/tools/iree-check-module-main.cc +++ b/iree/tools/iree-check-module-main.cc
@@ -14,7 +14,6 @@ #include <iostream> -#include "absl/flags/flag.h" #include "absl/strings/match.h" #include "absl/strings/string_view.h" #include "iree/base/api.h" @@ -41,9 +40,9 @@ #define IREE_FORCE_BINARY_STDIN() #endif // IREE_PLATFORM_WINDOWS -ABSL_FLAG(std::string, driver, "vmla", "Backend driver to use."); +IREE_FLAG(string, driver, "vmla", "Backend driver to use."); -ABSL_FLAG( +IREE_FLAG( bool, expect_failure, false, "Whether running module is expected to fail. If set, failing " "statuses from function evaluation are logged and ignored and all " @@ -104,7 +103,7 @@ IREE_RETURN_IF_ERROR(LoadBytecodeModule(module_data, &input_module)); iree_hal_device_t* device = nullptr; - IREE_RETURN_IF_ERROR(CreateDevice(absl::GetFlag(FLAGS_driver), &device)); + IREE_RETURN_IF_ERROR(CreateDevice(std::string(FLAG_driver), &device)); iree_vm_module_t* hal_module = nullptr; IREE_RETURN_IF_ERROR(CreateHalModule(device, &hal_module)); iree_vm_module_t* check_module = nullptr; @@ -179,7 +178,10 @@ } // namespace extern "C" int main(int argc, char** argv) { - iree_flags_parse_checked(&argc, &argv); + // Pass through flags to gtest (allowing --help to fall through). + iree_flags_parse_checked(IREE_FLAGS_PARSE_MODE_UNDEFINED_OK | + IREE_FLAGS_PARSE_MODE_CONTINUE_AFTER_HELP, + &argc, &argv); IREE_CHECK_OK(iree_hal_register_all_available_drivers( iree_hal_driver_registry_default())); ::testing::InitGoogleTest(&argc, argv); @@ -195,7 +197,7 @@ int exit_code = 1; auto status = Run(std::move(module_file_path), &exit_code); int ret = status.ok() ? exit_code : 1; - if (absl::GetFlag(FLAGS_expect_failure)) { + if (FLAG_expect_failure) { if (ret == 0) { std::cout << "Test passed but expected failure\n"; std::cout << status;
diff --git a/iree/tools/iree-run-mlir-main.cc b/iree/tools/iree-run-mlir-main.cc index 099b91c..8746c9a 100644 --- a/iree/tools/iree-run-mlir-main.cc +++ b/iree/tools/iree-run-mlir-main.cc
@@ -39,7 +39,6 @@ #include <iostream> #include <utility> -#include "absl/flags/flag.h" #include "absl/strings/match.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" @@ -120,13 +119,6 @@ llvm::cl::ZeroOrMore, }; -static llvm::cl::opt<std::string> function_inputs_file_flag{ - "function-input-file", - llvm::cl::desc("Provides a file for input shapes and optional values (see " - "ParseToVariantListFromFile in vm_util.h for details)"), - llvm::cl::init(""), -}; - static llvm::cl::opt<bool> run_flag{ "run", llvm::cl::desc("Runs the module (vs. just compiling and verifing)"), @@ -274,21 +266,11 @@ std::vector<RawSignatureParser::Description> input_descs; IREE_RETURN_IF_ERROR(ParseInputSignature(function, &input_descs)); vm::ref<iree_vm_list_t> inputs; - if (!function_inputs_file_flag.empty()) { - if (!function_inputs_flag.empty()) { - return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, - "expected only one of function_inputs and " - "function_inputs_file to be set"); - } - IREE_RETURN_IF_ERROR(ParseToVariantListFromFile( - input_descs, allocator, function_inputs_file_flag, &inputs)); - } else { - auto function_inputs_list = absl::MakeConstSpan( - function_inputs_flag.empty() ? nullptr : &function_inputs_flag.front(), - function_inputs_flag.size()); - IREE_RETURN_IF_ERROR(ParseToVariantList(input_descs, allocator, - function_inputs_list, &inputs)); - } + auto function_inputs_list = absl::MakeConstSpan( + function_inputs_flag.empty() ? 
nullptr : &function_inputs_flag.front(), + function_inputs_flag.size()); + IREE_RETURN_IF_ERROR(ParseToVariantList(input_descs, allocator, + function_inputs_list, &inputs)); std::vector<RawSignatureParser::Description> output_descs; IREE_RETURN_IF_ERROR(ParseOutputSignature(function, &output_descs)); @@ -520,7 +502,8 @@ } argc_absl += run_args_flag.size(); char** argv_absl_ptr = argv_absl.data(); - iree_flags_parse_checked(&argc_absl, &argv_absl_ptr); + iree_flags_parse_checked(IREE_FLAGS_PARSE_MODE_DEFAULT, &argc_absl, + &argv_absl_ptr); IREE_CHECK_OK(iree_hal_register_all_available_drivers( iree_hal_driver_registry_default()));
diff --git a/iree/tools/iree-run-module-main.cc b/iree/tools/iree-run-module-main.cc index d4015a3..24d4934 100644 --- a/iree/tools/iree-run-module-main.cc +++ b/iree/tools/iree-run-module-main.cc
@@ -14,7 +14,6 @@ #include <iostream> -#include "absl/flags/flag.h" #include "absl/strings/string_view.h" #include "iree/base/internal/file_io.h" #include "iree/base/internal/flags.h" @@ -26,38 +25,53 @@ #include "iree/vm/api.h" #include "iree/vm/bytecode_module.h" -ABSL_FLAG(std::string, module_file, "-", +IREE_FLAG(string, module_file, "-", "File containing the module to load that contains the entry " "function. Defaults to stdin."); -ABSL_FLAG(std::string, entry_function, "", +IREE_FLAG(string, entry_function, "", "Name of a function contained in the module specified by module_file " "to run."); -ABSL_FLAG(std::string, driver, "vmla", "Backend driver to use."); +IREE_FLAG(string, driver, "vmla", "Backend driver to use."); -ABSL_FLAG(std::vector<std::string>, function_inputs, {}, - "A comma-separated list of of input buffers of the format:" - "[shape]xtype=[value]\n" - "2x2xi32=1 2 3 4\n" - "Optionally, brackets may be used to separate the element values. " - "They are ignored by the parser.\n" - "2x2xi32=[[1 2][3 4]]\n" - "Due to the absence of repeated flags in absl, commas should not be " - "used to separate elements. 
They are reserved for separating input " - "values:\n" - "2x2xi32=[[1 2][3 4]], 1x2xf32=[[1 2]]"); - -ABSL_FLAG(std::string, function_inputs_file, "", - "Provides a file for input shapes and optional values (see " - "ParseToVariantListFromFile in vm_util.h for details)"); +static iree_status_t parse_function_input(iree_string_view_t flag_name, + void* storage, + iree_string_view_t value) { + auto* list = (std::vector<std::string>*)storage; + list->push_back(std::string(value.data, value.size)); + return iree_ok_status(); +} +static void print_function_input(iree_string_view_t flag_name, void* storage, + FILE* file) { + auto* list = (std::vector<std::string>*)storage; + if (list->empty()) { + fprintf(file, "# --%.*s=\n", (int)flag_name.size, flag_name.data); + } else { + for (size_t i = 0; i < list->size(); ++i) { + fprintf(file, "--%.*s=\"%s\"\n", (int)flag_name.size, flag_name.data, + list->at(i).c_str()); + } + } +} +static std::vector<std::string> FLAG_function_inputs; +IREE_FLAG_CALLBACK( + parse_function_input, print_function_input, &FLAG_function_inputs, + function_input, + "An input value or buffer of the format:\n" + " [shape]xtype=[value]\n" + " 2x2xi32=1 2 3 4\n" + "Optionally, brackets may be used to separate the element values:\n" + " 2x2xi32=[[1 2][3 4]]\n" + "Each occurrence of the flag indicates an input in the order they were\n" + "specified on the command line."); namespace iree { namespace { Status GetModuleContentsFromFlags(std::string* out_contents) { IREE_TRACE_SCOPE0("GetModuleContentsFromFlags"); - auto module_file = absl::GetFlag(FLAGS_module_file); + auto module_file = std::string(FLAG_module_file); if (module_file == "-") { *out_contents = std::string{std::istreambuf_iterator<char>(std::cin), std::istreambuf_iterator<char>()}; @@ -84,7 +98,7 @@ IREE_RETURN_IF_ERROR(LoadBytecodeModule(module_data, &input_module)); iree_hal_device_t* device = nullptr; - IREE_RETURN_IF_ERROR(CreateDevice(absl::GetFlag(FLAGS_driver), &device)); + 
IREE_RETURN_IF_ERROR(CreateDevice(std::string(FLAG_driver), &device)); iree_vm_module_t* hal_module = nullptr; IREE_RETURN_IF_ERROR(CreateHalModule(device, &hal_module)); @@ -96,7 +110,7 @@ iree_allocator_system(), &context), "creating context"); - std::string function_name = absl::GetFlag(FLAGS_entry_function); + std::string function_name = std::string(FLAG_entry_function); iree_vm_function_t function; if (function_name.empty()) { return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, @@ -115,20 +129,9 @@ IREE_RETURN_IF_ERROR(ParseInputSignature(function, &input_descs)); vm::ref<iree_vm_list_t> inputs; - if (!absl::GetFlag(FLAGS_function_inputs_file).empty()) { - if (!absl::GetFlag(FLAGS_function_inputs).empty()) { - return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, - "expected only one of function_inputs and " - "function_inputs_file to be set"); - } - IREE_RETURN_IF_ERROR(ParseToVariantListFromFile( - input_descs, iree_hal_device_allocator(device), - absl::GetFlag(FLAGS_function_inputs_file), &inputs)); - } else { - IREE_RETURN_IF_ERROR(ParseToVariantList( - input_descs, iree_hal_device_allocator(device), - absl::MakeConstSpan(absl::GetFlag(FLAGS_function_inputs)), &inputs)); - } + IREE_CHECK_OK(ParseToVariantList(input_descs, + iree_hal_device_allocator(device), + FLAG_function_inputs, &inputs)); std::vector<RawSignatureParser::Description> output_descs; IREE_RETURN_IF_ERROR(ParseOutputSignature(function, &output_descs)); @@ -159,7 +162,7 @@ } // namespace extern "C" int main(int argc, char** argv) { - iree_flags_parse_checked(&argc, &argv); + iree_flags_parse_checked(IREE_FLAGS_PARSE_MODE_DEFAULT, &argc, &argv); IREE_CHECK_OK(iree_hal_register_all_available_drivers( iree_hal_driver_registry_default())); IREE_CHECK_OK(Run());
diff --git a/iree/tools/test/benchmark_flags.txt b/iree/tools/test/benchmark_flags.txt index c1ce1fc..7a114bc 100644 --- a/iree/tools/test/benchmark_flags.txt +++ b/iree/tools/test/benchmark_flags.txt
@@ -1,22 +1,16 @@ -// HELP: iree-benchmark-module +// RUN: ( iree-benchmark-module --help || [[ $? == 1 ]] ) | IreeFileCheck --check-prefix=HELP %s // HELP: --module_file // HELP: --benchmark_list_tests -// RUN: ( iree-benchmark-module --help || [[ $? == 1 ]] ) | IreeFileCheck --check-prefix=HELP %s -// RUN: ( iree-benchmark-module --helpshort || [[ $? == 1 ]] ) | IreeFileCheck --check-prefix=HELP %s -// UNKNOWN: unknown-flag -// RUN: ( iree-benchmark-module --unknown-flag 2>&1 || [[ $? == 1 ]] ) | IreeFileCheck --check-prefix=UNKNOWN %s -// RUN: ( iree-benchmark-module --driver=vmla --unknown-flag --benchmark_list_tests 2>&1 || [[ $? == 1 ]] ) | IreeFileCheck --check-prefix=UNKNOWN %s - -// LIST-BENCHMARKS: BM_foo1 -// LIST-BENCHMARKS: BM_foo2 // RUN: ( iree-translate --iree-hal-target-backends=vmla -iree-mlir-to-vm-bytecode-module %s | iree-benchmark-module --benchmark_list_tests --driver=vmla --benchmark_list_tests ) | IreeFileCheck --check-prefix=LIST-BENCHMARKS %s module { + // LIST-BENCHMARKS: BM_foo1 func @foo1() -> tensor<4xf32> attributes { iree.module.export } { %input = iree.unfoldable_constant dense<[0.0, 1.0, 2.0, 4.0]> : tensor<4xf32> %result = "mhlo.exponential"(%input) : (tensor<4xf32>) -> tensor<4xf32> return %result : tensor<4xf32> } + // LIST-BENCHMARKS: BM_foo2 func @foo2() -> tensor<4xf32> attributes { iree.module.export } { %input = iree.unfoldable_constant dense<[0.0, 1.0, 2.0, 4.0]> : tensor<4xf32> %result = "mhlo.abs"(%input) : (tensor<4xf32>) -> tensor<4xf32>
diff --git a/iree/tools/test/iree-benchmark-module.mlir b/iree/tools/test/iree-benchmark-module.mlir index a75b639..c4a4e32 100644 --- a/iree/tools/test/iree-benchmark-module.mlir +++ b/iree/tools/test/iree-benchmark-module.mlir
@@ -1,6 +1,6 @@ -// RUN: iree-translate --iree-hal-target-backends=vmla -iree-mlir-to-vm-bytecode-module %s | iree-benchmark-module --driver=vmla --entry_function=abs --function_inputs="i32=-2" | IreeFileCheck %s -// RUN: [[ $IREE_VULKAN_DISABLE == 1 ]] || (iree-translate --iree-hal-target-backends=vulkan-spirv -iree-mlir-to-vm-bytecode-module %s | iree-benchmark-module --driver=vulkan --entry_function=abs --function_inputs="i32=-2" | IreeFileCheck %s) -// RUN: [[ $IREE_LLVMAOT_DISABLE == 1 ]] || (iree-translate --iree-hal-target-backends=dylib-llvm-aot -iree-mlir-to-vm-bytecode-module %s | iree-benchmark-module --driver=dylib --entry_function=abs --function_inputs="i32=-2" | IreeFileCheck %s) +// RUN: iree-translate --iree-hal-target-backends=vmla -iree-mlir-to-vm-bytecode-module %s | iree-benchmark-module --driver=vmla --entry_function=abs --function_input=i32=-2 | IreeFileCheck %s +// RUN: [[ $IREE_VULKAN_DISABLE == 1 ]] || (iree-translate --iree-hal-target-backends=vulkan-spirv -iree-mlir-to-vm-bytecode-module %s | iree-benchmark-module --driver=vulkan --entry_function=abs --function_input=i32=-2 | IreeFileCheck %s) +// RUN: [[ $IREE_LLVMAOT_DISABLE == 1 ]] || (iree-translate --iree-hal-target-backends=dylib-llvm-aot -iree-mlir-to-vm-bytecode-module %s | iree-benchmark-module --driver=dylib --entry_function=abs --function_input=i32=-2 | IreeFileCheck %s) // CHECK-LABEL: BM_abs func @abs(%input : tensor<i32>) -> (tensor<i32>) attributes { iree.module.export } {
diff --git a/iree/tools/test/iree-run-module.mlir b/iree/tools/test/iree-run-module.mlir index a97daf9..763195a 100644 --- a/iree/tools/test/iree-run-module.mlir +++ b/iree/tools/test/iree-run-module.mlir
@@ -1,6 +1,6 @@ -// RUN: (iree-translate --iree-hal-target-backends=vmla -iree-mlir-to-vm-bytecode-module %s | iree-run-module --driver=vmla --entry_function=abs --function_inputs="i32=-2") | IreeFileCheck %s -// RUN: [[ $IREE_VULKAN_DISABLE == 1 ]] || ((iree-translate --iree-hal-target-backends=vulkan-spirv -iree-mlir-to-vm-bytecode-module %s | iree-run-module --driver=vulkan --entry_function=abs --function_inputs="i32=-2") | IreeFileCheck %s) -// RUN: [[ $IREE_LLVMAOT_DISABLE == 1 ]] || ((iree-translate --iree-hal-target-backends=dylib-llvm-aot -iree-mlir-to-vm-bytecode-module %s | iree-run-module --driver=dylib --entry_function=abs --function_inputs="i32=-2") | IreeFileCheck %s) +// RUN: (iree-translate --iree-hal-target-backends=vmla -iree-mlir-to-vm-bytecode-module %s | iree-run-module --driver=vmla --entry_function=abs --function_input=i32=-2) | IreeFileCheck %s +// RUN: [[ $IREE_VULKAN_DISABLE == 1 ]] || ((iree-translate --iree-hal-target-backends=vulkan-spirv -iree-mlir-to-vm-bytecode-module %s | iree-run-module --driver=vulkan --entry_function=abs --function_input=i32=-2) | IreeFileCheck %s) +// RUN: [[ $IREE_LLVMAOT_DISABLE == 1 ]] || ((iree-translate --iree-hal-target-backends=dylib-llvm-aot -iree-mlir-to-vm-bytecode-module %s | iree-run-module --driver=dylib --entry_function=abs --function_input=i32=-2) | IreeFileCheck %s) // CHECK-LABEL: EXEC @abs func @abs(%input : tensor<i32>) -> (tensor<i32>) attributes { iree.module.export } {
diff --git a/iree/tools/test/multiple_args.mlir b/iree/tools/test/multiple_args.mlir index f28afa7..cb725ce 100644 --- a/iree/tools/test/multiple_args.mlir +++ b/iree/tools/test/multiple_args.mlir
@@ -1,6 +1,6 @@ -// RUN: iree-translate --iree-hal-target-backends=vmla -iree-mlir-to-vm-bytecode-module %s | iree-run-module --entry_function=multi_input --function_inputs='2xi32=[1 2], 2xi32=[3 4]' | IreeFileCheck %s +// RUN: iree-translate --iree-hal-target-backends=vmla -iree-mlir-to-vm-bytecode-module %s | iree-run-module --entry_function=multi_input --function_input="2xi32=[1 2]" --function_input="2xi32=[3 4]" | IreeFileCheck %s // RUN: iree-run-mlir --iree-hal-target-backends=vmla --function-input='2xi32=[1 2]' --function-input='2xi32=[3 4]' %s | IreeFileCheck %s -// RUN: iree-translate --iree-hal-target-backends=vmla -iree-mlir-to-vm-bytecode-module %s | iree-benchmark-module --driver=vmla --entry_function=multi_input --function_inputs='2xi32=[1 2], 2xi32=[3 4]' | IreeFileCheck --check-prefix=BENCHMARK %s +// RUN: iree-translate --iree-hal-target-backends=vmla -iree-mlir-to-vm-bytecode-module %s | iree-benchmark-module --driver=vmla --entry_function=multi_input --function_input="2xi32=[1 2]" --function_input="2xi32=[3 4]" | IreeFileCheck --check-prefix=BENCHMARK %s // BENCHMARK-LABEL: BM_multi_input // CHECK-LABEL: EXEC @multi_input
diff --git a/iree/tools/test/scalars.mlir b/iree/tools/test/scalars.mlir index daf1872..d6a5546 100644 --- a/iree/tools/test/scalars.mlir +++ b/iree/tools/test/scalars.mlir
@@ -1,5 +1,5 @@ -// RUN: (iree-translate --iree-hal-target-backends=vmla -iree-mlir-to-vm-bytecode-module %s | iree-run-module --entry_function=scalar --function_inputs='i32=42') | IreeFileCheck %s -// RUN: iree-translate --iree-hal-target-backends=vmla -iree-mlir-to-vm-bytecode-module %s | iree-benchmark-module --driver=vmla --entry_function=scalar --function_inputs='i32=42' | IreeFileCheck --check-prefix=BENCHMARK %s +// RUN: (iree-translate --iree-hal-target-backends=vmla -iree-mlir-to-vm-bytecode-module %s | iree-run-module --entry_function=scalar --function_input=i32=42) | IreeFileCheck %s +// RUN: iree-translate --iree-hal-target-backends=vmla -iree-mlir-to-vm-bytecode-module %s | iree-benchmark-module --driver=vmla --entry_function=scalar --function_input=i32=42 | IreeFileCheck --check-prefix=BENCHMARK %s // RUN: (iree-run-mlir --iree-hal-target-backends=vmla --function-input=i32=42 %s) | IreeFileCheck %s // BENCHMARK-LABEL: BM_scalar
diff --git a/iree/tools/utils/vm_util.cc b/iree/tools/utils/vm_util.cc index 27f0cd2..3b3e06d 100644 --- a/iree/tools/utils/vm_util.cc +++ b/iree/tools/utils/vm_util.cc
@@ -177,17 +177,6 @@ return ParseToVariantList(descs, allocator, input_views, out_list); } -Status ParseToVariantListFromFile( - absl::Span<const RawSignatureParser::Description> descs, - iree_hal_allocator_t* allocator, const std::string& filename, - iree_vm_list_t** out_list) { - std::string contents; - IREE_RETURN_IF_ERROR(file_io::GetFileContents(filename.c_str(), &contents)); - std::vector<absl::string_view> input_views( - absl::StrSplit(contents, '\n', absl::SkipEmpty())); - return ParseToVariantList(descs, allocator, input_views, out_list); -} - Status PrintVariantList(absl::Span<const RawSignatureParser::Description> descs, iree_vm_list_t* variant_list, std::ostream* os) { for (int i = 0; i < iree_vm_list_size(variant_list); ++i) {
diff --git a/iree/tools/utils/vm_util.h b/iree/tools/utils/vm_util.h index 19cf92a..006961a 100644 --- a/iree/tools/utils/vm_util.h +++ b/iree/tools/utils/vm_util.h
@@ -62,14 +62,6 @@ iree_hal_allocator_t* allocator, absl::Span<const std::string> input_strings, iree_vm_list_t** out_list); -// Parses the content in |filename| into a variant list of VM scalars and -// buffers. See ParseToVariantList for the format of scalars and buffers. The -// inputs are expected to be newline-separated. -Status ParseToVariantListFromFile( - absl::Span<const RawSignatureParser::Description> descs, - iree_hal_allocator_t* allocator, const std::string& filename, - iree_vm_list_t** out_list); - // Prints a variant list of VM scalars and buffers to |os|. // Prints scalars in the format: // type=value
diff --git a/iree/vm/bytecode_dispatch.c b/iree/vm/bytecode_dispatch.c index c741706..caf9c4a 100644 --- a/iree/vm/bytecode_dispatch.c +++ b/iree/vm/bytecode_dispatch.c
@@ -889,9 +889,18 @@ return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, "list is null"); } uint32_t index = VM_DecOperandRegI32("index"); + const iree_vm_type_def_t* type_def = VM_DecTypeOf("result"); bool result_is_move; iree_vm_ref_t* result = VM_DecResultRegRef("result", &result_is_move); - return iree_vm_list_get_ref_retain(list, index, result); + // TODO(benvanik): use result_is_move with a _retain_or_move. + IREE_RETURN_IF_ERROR(iree_vm_list_get_ref_retain(list, index, result)); + if (result->type != IREE_VM_REF_TYPE_NULL && + (iree_vm_type_def_is_value(type_def) || + result->type != type_def->ref_type)) { + // Type mismatch; put null in the register instead. + // TODO(benvanik): return an error here and make a query type method? + iree_vm_ref_release(result); + } }); DISPATCH_OP(CORE, ListSetRef, { @@ -905,9 +914,9 @@ bool operand_is_move; iree_vm_ref_t* operand = VM_DecOperandRegRef("value", &operand_is_move); if (operand_is_move) { - return iree_vm_list_set_ref_move(list, index, operand); + IREE_RETURN_IF_ERROR(iree_vm_list_set_ref_move(list, index, operand)); } else { - return iree_vm_list_set_ref_retain(list, index, operand); + IREE_RETURN_IF_ERROR(iree_vm_list_set_ref_retain(list, index, operand)); } });
diff --git a/iree/vm/native_module.c b/iree/vm/native_module.c index 5ad605d..ca5a63f 100644 --- a/iree/vm/native_module.c +++ b/iree/vm/native_module.c
@@ -74,19 +74,14 @@ static void IREE_API_PTR iree_vm_native_module_destroy(void* self) { iree_vm_native_module_t* module = (iree_vm_native_module_t*)self; + iree_allocator_t allocator = module->allocator; // Destroy the optional user-provided self. - if (module->self == module) { - iree_allocator_t allocator = module->allocator; - if (module->user_interface.destroy) { - module->user_interface.destroy(module->self); - } - iree_allocator_free(allocator, module); - } else { - if (module->user_interface.destroy) { - module->user_interface.destroy(module->self); - } + if (module->user_interface.destroy) { + module->user_interface.destroy(module->self); } + + iree_allocator_free(allocator, module); } static iree_string_view_t IREE_API_PTR iree_vm_native_module_name(void* self) {
diff --git a/iree/vm/ref.c b/iree/vm/ref.c index d89d292..6fb80c5 100644 --- a/iree/vm/ref.c +++ b/iree/vm/ref.c
@@ -151,6 +151,8 @@ IREE_API_EXPORT void IREE_API_CALL iree_vm_ref_retain(iree_vm_ref_t* ref, iree_vm_ref_t* out_ref) { + // NOTE: ref and out_ref may alias. + iree_vm_ref_t temp_ref = *ref; if (ref != out_ref && ref->ptr != out_ref->ptr) { // Output ref contains a value that should be released first. // Note that we check for it being the same as the new value so we don't @@ -159,7 +161,7 @@ } // Assign ref to out_ref and increment the counter. - memmove(out_ref, ref, sizeof(*out_ref)); + *out_ref = temp_ref; if (out_ref->ptr) { volatile iree_atomic_ref_count_t* counter = iree_vm_get_ref_counter_ptr(out_ref); @@ -179,13 +181,15 @@ IREE_API_EXPORT void IREE_API_CALL iree_vm_ref_retain_or_move( int is_move, iree_vm_ref_t* ref, iree_vm_ref_t* out_ref) { + // NOTE: ref and out_ref may alias. + iree_vm_ref_t temp_ref = *ref; if (ref != out_ref) { // Output ref contains a value that should be released first. iree_vm_ref_release(out_ref); } // Assign ref to out_ref and increment the counter if not moving. - memmove(out_ref, ref, sizeof(*out_ref)); + *out_ref = temp_ref; if (out_ref->ptr && !is_move) { // Retain by incrementing counter and preserving the source ref. volatile iree_atomic_ref_count_t* counter = @@ -228,6 +232,8 @@ IREE_API_EXPORT void IREE_API_CALL iree_vm_ref_assign(iree_vm_ref_t* ref, iree_vm_ref_t* out_ref) { + // NOTE: ref and out_ref may alias. + iree_vm_ref_t temp_ref = *ref; if (ref == out_ref) { // Source == target; ignore. return; @@ -237,11 +243,13 @@ } // Assign ref to out_ref (without incrementing counter). - memcpy(out_ref, ref, sizeof(*out_ref)); + *out_ref = temp_ref; } IREE_API_EXPORT void IREE_API_CALL iree_vm_ref_move(iree_vm_ref_t* ref, iree_vm_ref_t* out_ref) { + // NOTE: ref and out_ref may alias. + iree_vm_ref_t temp_ref = *ref; if (ref == out_ref) { // Source == target; ignore. return; @@ -251,7 +259,7 @@ } // Assign ref to out_ref (without incrementing counter). 
- memcpy(out_ref, ref, sizeof(*out_ref)); + *out_ref = temp_ref; // Reset input ref so it points at nothing. memset(ref, 0, sizeof(*ref));
diff --git a/iree/vm/ref.h b/iree/vm/ref.h index 5e05f67..2ae0512 100644 --- a/iree/vm/ref.h +++ b/iree/vm/ref.h
@@ -95,9 +95,6 @@ typedef void(IREE_API_PTR* iree_vm_ref_destroy_t)(void* ptr); -#define IREE_VM_REF_DESTROY_FREE free -#define IREE_VM_REF_DESTROY_CC_DELETE +[](void* ptr) { delete ptr; } - // Describes a type for the VM. typedef struct { // Function called when references of this type reach 0 and should be
diff --git a/iree/vm/test/list_ops.mlir b/iree/vm/test/list_ops.mlir index 45bb147..514de69 100644 --- a/iree/vm/test/list_ops.mlir +++ b/iree/vm/test/list_ops.mlir
@@ -81,6 +81,32 @@ } //===--------------------------------------------------------------------===// + // Multiple lists within the same block + //===--------------------------------------------------------------------===// + + vm.export @test_multiple_lists + vm.func @test_multiple_lists() { + %c0 = vm.const.i32 0 : i32 + %c1 = vm.const.i32 1 : i32 + %c27 = vm.const.i32 27 : i32 + %c42 = vm.const.i32 42 : i32 + + // These allocs shouldn't be CSE'd. + %list0 = vm.list.alloc %c1 : (i32) -> !vm.list<i8> + %list1 = vm.list.alloc %c1 : (i32) -> !vm.list<i8> + vm.list.resize %list0, %c1 : (!vm.list<i8>, i32) + vm.list.resize %list1, %c1 : (!vm.list<i8>, i32) + vm.list.set.i32 %list0, %c0, %c27 : (!vm.list<i8>, i32, i32) + vm.list.set.i32 %list1, %c0, %c42 : (!vm.list<i8>, i32, i32) + %res0 = vm.list.get.i32 %list0, %c0 : (!vm.list<i8>, i32) -> i32 + %res1 = vm.list.get.i32 %list1, %c0 : (!vm.list<i8>, i32) -> i32 + vm.check.eq %res0, %c27, "list0.get(0)=27" : i32 + vm.check.eq %res1, %c42, "list1.get(0)=42" : i32 + + vm.return + } + + //===--------------------------------------------------------------------===// // Failure tests //===--------------------------------------------------------------------===//
diff --git a/iree/vm/test/list_variant_ops.mlir b/iree/vm/test/list_variant_ops.mlir index 5f50d03..e26c218 100644 --- a/iree/vm/test/list_variant_ops.mlir +++ b/iree/vm/test/list_variant_ops.mlir
@@ -40,7 +40,7 @@ %inner0_e2 = vm.list.get.i32 %inner0_ret, %c2 : (!vm.list<i32>, i32) -> i32 vm.check.eq %inner0_e2, %c102 : i32 - %inner1_ret = vm.list.get.ref %outer, %c0 : (!vm.list<!vm.list<i32>>, i32) -> !vm.list<i32> + %inner1_ret = vm.list.get.ref %outer, %c1 : (!vm.list<!vm.list<i32>>, i32) -> !vm.list<i32> vm.check.eq %inner1_ret, %inner1 : !vm.list<i32> %inner1_e2 = vm.list.get.i32 %inner1_ret, %c2 : (!vm.list<i32>, i32) -> i32 vm.check.eq %inner1_e2, %c100 : i32 @@ -89,8 +89,8 @@ vm.return } - vm.export @test_variant_slot_change - vm.func @test_variant_slot_change() { + vm.export @fail_variant_slot_change + vm.func @fail_variant_slot_change() { %capacity = vm.const.i32 42 : i32 %list = vm.list.alloc %capacity : (i32) -> !vm.list<?> vm.list.resize %list, %capacity : (!vm.list<?>, i32) @@ -109,9 +109,10 @@ %e10_buf = vm.list.get.ref %list, %c10 : (!vm.list<?>, i32) -> !vm.ref<!iree.byte_buffer> vm.check.eq %e10_buf, %v10_buf : !vm.ref<!iree.byte_buffer> - // Accessing it as an i32 now that it stores the ref should return a - // default (until we support type queries). + // Accessing it as an i32 now that it stores the ref should fail at runtime. + // TODO(benvanik): support type queries and/or make this silently return 0. %e10_any = vm.list.get.i32 %list, %c10 : (!vm.list<?>, i32) -> i32 + // -- FAILURE HERE -- %zero = vm.const.i32.zero : i32 vm.check.eq %e10_any, %zero : i32