Synchronize submodules with LLVM at llvm/llvm-project@6a7a2ee8161d

Updates LLVM dependencies to match
[6a7a2ee8161d](https://github.com/llvm/llvm-project/commit/6a7a2ee8161d).

- TensorFlow to [4f8db85aa9ab](https://github.com/tensorflow/tensorflow/commit/4f8db85aa9ab)
- MLIR-HLO to [c65dc7a455dc](https://github.com/tensorflow/mlir-hlo/commit/c65dc7a455dc)

`./scripts/git/update_to_llvm_syncpoint.py`

PiperOrigin-RevId: 387433820
diff --git a/.gitignore b/.gitignore index 4f52c93..6fe9bf2 100644 --- a/.gitignore +++ b/.gitignore
@@ -58,3 +58,6 @@ # Generated documentation files mkdocs/site/ docs/website/site/ + +# Temporary files +iree/compiler/Dialect/HAL/Target/LLVM/librt/bin/librt.ll
diff --git a/CMakeLists.txt b/CMakeLists.txt index c466088..e4df9d5 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt
@@ -322,6 +322,23 @@ endif() #------------------------------------------------------------------------------- +# Check if git submodules have been initialized. +# This will only run if python3 is available. +#------------------------------------------------------------------------------- + +find_package(Python3 COMPONENTS Interpreter QUIET) +if(Python3_FOUND) + execute_process( + COMMAND ${Python3_EXECUTABLE} scripts/git/check_submodule_init.py + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} + RESULT_VARIABLE ret + ) + if(NOT ret EQUAL "0") + message(FATAL_ERROR "check_submodule_init.py failed, see the logs above") + endif() +endif() + +#------------------------------------------------------------------------------- # MLIR/LLVM Dependency # We treat the LLVM dependency specially because we support several different # ways to use it:
diff --git a/bindings/python/iree/compiler/xla.py b/bindings/python/iree/compiler/xla.py index eb04231..a0b0578 100644 --- a/bindings/python/iree/compiler/xla.py +++ b/bindings/python/iree/compiler/xla.py
@@ -74,6 +74,7 @@ import_format: Union[ImportFormat, str] = ImportFormat.BINARY_PROTO, import_extra_args: Sequence[str] = (), + save_temp_mhlo_input: Optional[str] = None, save_temp_iree_input: Optional[str] = None, **kwargs): """Initialize options from keywords. @@ -87,6 +88,7 @@ self.import_only = import_only self.import_format = ImportFormat.parse(import_format) self.import_extra_args = import_extra_args + self.save_temp_mhlo_input = save_temp_mhlo_input self.save_temp_iree_input = save_temp_iree_input @@ -121,6 +123,10 @@ cl.append("--mlir-print-op-generic") # Save temps flags. + save_mhlo_input = tfs.alloc_optional("tf-mhlo.mlir", + export_as=options.save_temp_mhlo_input) + if save_mhlo_input: + cl.append(f"--save-temp-mhlo-input={save_mhlo_input}") iree_input = tfs.alloc_optional("xla-iree-input.mlir", export_as=options.save_temp_iree_input) if iree_input:
diff --git a/build_tools/buildkite/cmake/build_configurations.yml b/build_tools/buildkite/cmake/build_configurations.yml index fb79ee9..426e5fb 100644 --- a/build_tools/buildkite/cmake/build_configurations.yml +++ b/build_tools/buildkite/cmake/build_configurations.yml
@@ -5,7 +5,7 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception steps: - - label: "Build with tracing enabled" + - label: ":zap: Build with tracing enabled" commands: - "./scripts/git/submodule_versions.py init" - "docker run --user=$(id -u):$(id -g) --volume=\\$PWD:\\$IREE_DOCKER_WORKDIR --workdir=\\$IREE_DOCKER_WORKDIR --rm gcr.io/iree-oss/cmake@sha256:9d9953acf5ca0cf1ff3e8de32f10f24dfab1c4e8ec5d1fc047f556024ee4bed6 ./build_tools/cmake/build_tracing.sh" @@ -14,7 +14,7 @@ agents: - "queue=build" - - label: "Build the runtime only" + - label: ":hammer_and_wrench: Build the runtime only" commands: - "./scripts/git/submodule_versions.py init" - "docker run --user=$(id -u):$(id -g) --volume=\\$PWD:\\$IREE_DOCKER_WORKDIR --workdir=\\$IREE_DOCKER_WORKDIR --rm gcr.io/iree-oss/cmake@sha256:9d9953acf5ca0cf1ff3e8de32f10f24dfab1c4e8ec5d1fc047f556024ee4bed6 ./build_tools/cmake/build_runtime.sh" @@ -22,3 +22,27 @@ IREE_DOCKER_WORKDIR: "/usr/src/github/iree" agents: - "queue=build" + + - label: ":linux: Build host install" + key: "build-host-install" + commands: + - "./scripts/git/submodule_versions.py init" + - "docker run --user=$(id -u):$(id -g) --volume=\\$PWD:\\$IREE_DOCKER_WORKDIR --workdir=\\$IREE_DOCKER_WORKDIR --rm gcr.io/iree-oss/cmake@sha256:9d9953acf5ca0cf1ff3e8de32f10f24dfab1c4e8ec5d1fc047f556024ee4bed6 ./build_tools/cmake/build_host_install.sh" + - "tar -czvf build-artifacts.tgz build-host/install" + artifact_paths: "build-artifacts.tgz" + env: + IREE_DOCKER_WORKDIR: "/usr/src/github/iree" + agents: + - "queue=build" + + - label: ":webassembly: Build WebAssembly runtime with Emscripten" + depends_on: "build-host-install" + commands: + - "buildkite-agent artifact download --step build-host-install build-artifacts.tgz ./" + - "tar xzf build-artifacts.tgz" + - "./scripts/git/submodule_versions.py init" + - "docker run --user=$(id -u):$(id -g) --volume=\\$PWD:\\$IREE_DOCKER_WORKDIR --workdir=\\$IREE_DOCKER_WORKDIR --rm gcr.io/iree-oss/cmake-emscripten@sha256:8acad361d23cb586187c2ea29df3a1ab301b5283c3648beb328681d69ecd0ab0 ./build_tools/cmake/build_runtime_emscripten.sh" + env: + IREE_DOCKER_WORKDIR: "/usr/src/github/iree" + agents: + - "queue=build"
diff --git a/build_tools/cmake/build_host_install.sh b/build_tools/cmake/build_host_install.sh new file mode 100755 index 0000000..b5f8130 --- /dev/null +++ b/build_tools/cmake/build_host_install.sh
@@ -0,0 +1,36 @@ +#!/bin/bash +# Copyright 2021 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +# Builds host binaries (compiler tools and other utilities) and installs them +# in build-host/install. Designed for CI, but can be run manually. This uses +# previously cached build results and does not clear build directories. + +set -x +set -e + +CMAKE_BIN=${CMAKE_BIN:-$(which cmake)} +"${CMAKE_BIN?}" --version +ninja --version + +ROOT_DIR=$(git rev-parse --show-toplevel) +cd ${ROOT_DIR?} + +if [ -d "build-host" ] +then + echo "build-host directory already exists. Will use cached results there." +else + echo "build-host directory does not already exist. Creating a new one." + mkdir build-host +fi +cd build-host + +# Configure, build, install. +"${CMAKE_BIN?}" -G Ninja .. \ + -DCMAKE_INSTALL_PREFIX=./install \ + -DIREE_BUILD_TESTS=OFF \ + -DIREE_BUILD_SAMPLES=OFF +"${CMAKE_BIN?}" --build . --target install
diff --git a/build_tools/cmake/build_runtime_emscripten.sh b/build_tools/cmake/build_runtime_emscripten.sh new file mode 100755 index 0000000..ca2abc3 --- /dev/null +++ b/build_tools/cmake/build_runtime_emscripten.sh
@@ -0,0 +1,51 @@ +#!/bin/bash +# Copyright 2021 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +# Cross-compile IREE's runtime through Emscripten to WebAssembly with CMake. +# Designed for CI, but can be run manually. This uses previously cached build +# results and does not clear build directories. +# +# Host binaries (e.g. compiler tools) should already be built at +# ./build-host/install. Emscripten binaries (e.g. .wasm and .js files) will be +# built in ./build-emscripten/. + +set -x +set -e + +if ! command -v emcmake &> /dev/null +then + echo "'emcmake' not found, set up your environment according to https://emscripten.org/docs/getting_started/downloads.html" + exit 1 +fi + +CMAKE_BIN=${CMAKE_BIN:-$(which cmake)} +"${CMAKE_BIN?}" --version +ninja --version + +ROOT_DIR=$(git rev-parse --show-toplevel) +cd ${ROOT_DIR?} + +if [ -d "build-emscripten" ] +then + echo "build-emscripten directory already exists. Will use cached results there." +else + echo "build-emscripten directory does not already exist. Creating a new one." + mkdir build-emscripten +fi +cd build-emscripten + +# Configure using Emscripten's CMake wrapper, then build. +emcmake "${CMAKE_BIN?}" -G Ninja .. \ + -DIREE_HOST_BINARY_ROOT=$PWD/../build-host/install \ + -DIREE_HAL_DRIVERS_TO_BUILD=VMVX\;DyLib \ + -DIREE_BUILD_COMPILER=OFF \ + -DIREE_ENABLE_MLIR=OFF \ + -DIREE_BUILD_TESTS=OFF \ + -DIREE_BUILD_SAMPLES=ON + +# TODO(scotttodd): expand this list of targets +"${CMAKE_BIN?}" --build . --target iree_samples_simple_embedding_simple_embedding_vmvx_sync
diff --git a/build_tools/docker/cmake-emscripten/Dockerfile b/build_tools/docker/cmake-emscripten/Dockerfile new file mode 100644 index 0000000..4f3f600 --- /dev/null +++ b/build_tools/docker/cmake-emscripten/Dockerfile
@@ -0,0 +1,44 @@ +# Copyright 2021 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +# An image for building IREE through Emscripten using CMake. + +FROM gcr.io/iree-oss/cmake@sha256:9d9953acf5ca0cf1ff3e8de32f10f24dfab1c4e8ec5d1fc047f556024ee4bed6 AS final + +# See also +# * https://github.com/emscripten-core/emsdk/blob/main/docker/Dockerfile +# * https://hub.docker.com/r/emscripten/emsdk + +ARG EMSDK_COMMIT=5c0e31a03a72136ccaa761c408baf5a640f942ec +ARG SDK_VERSION=2.0.25 + +# Follow https://emscripten.org/docs/getting_started/downloads.html. +RUN git clone https://github.com/emscripten-core/emsdk +RUN cd emsdk && git checkout "${EMSDK_COMMIT?}" && \ + ./emsdk install ${SDK_VERSION?} && \ + ./emsdk activate ${SDK_VERSION?} + +# Set some environment variables for Emscripten to use. +ENV EMSDK=/emsdk +ENV EM_DATA=${EMSDK}/.data +ENV EM_CONFIG=${EMSDK}/.emscripten +ENV EM_CACHE=${EM_DATA}/cache +ENV EM_PORTS=${EM_DATA}/ports +# Emscripten writes into its cache location (outside of the CMake build +# directory). +# We can either +# (A) Grant broad write permissions to the cache directory to be able to run +# our scripts under different users. +# (B) Mount a user home directory when using the image. +# Since (A) requires less configuration, we'll do that. If multiple tools would +# want a user directory (like Bazel), we should switch to (B). +# See https://github.com/emscripten-core/emsdk/issues/535 +RUN mkdir -p ${EM_CACHE} && chmod -R 777 ${EM_CACHE} + +# Normally we'd run `source emsdk_env.sh`, but that doesn't integrate with +# Docker's environment properties model. Instead, we directly extend the path +# to include the directories suggested by `emsdk activate`. +ENV PATH="$PWD/emsdk:$PWD/emsdk/node/14.15.5_64bit/bin:$PWD/emsdk/upstream/emscripten:$PATH"
diff --git a/build_tools/docker/manage_images.py b/build_tools/docker/manage_images.py index 9e9bfd8..ce61e2f 100755 --- a/build_tools/docker/manage_images.py +++ b/build_tools/docker/manage_images.py
@@ -46,6 +46,7 @@ 'bazel': ['base', 'util'], 'cmake': ['base', 'util'], 'cmake-android': ['cmake-python', 'util'], + 'cmake-emscripten': ['cmake'], 'cmake-python': ['cmake'], 'cmake-python-vulkan': ['cmake-python', 'vulkan'], 'cmake-python-swiftshader': ['cmake-python-vulkan', 'swiftshader'],
diff --git a/build_tools/docker/prod_digests.txt b/build_tools/docker/prod_digests.txt index f5e1f59..dd0a5b7 100644 --- a/build_tools/docker/prod_digests.txt +++ b/build_tools/docker/prod_digests.txt
@@ -17,3 +17,4 @@ gcr.io/iree-oss/cmake-riscv@sha256:95489593bc9b0cd325ce9c1a32b47389c01b174a5b8190a16d937d2e8828d384 gcr.io/iree-oss/cmake-bazel-frontends-android@sha256:1392e3a27cddbdc597817168fb61e125bbdcbfd9076eff9d70bd8012b0a0c5ba gcr.io/iree-oss/samples@sha256:be5465585706b620d6c722caa6237eafdfaa8dd11ce20db0981b979f2d3387b3 +gcr.io/iree-oss/cmake-emscripten@sha256:8acad361d23cb586187c2ea29df3a1ab301b5283c3648beb328681d69ecd0ab0
diff --git a/build_tools/kokoro/gcp_ubuntu/cmake/linux/x86-swiftshader-asan/build.sh b/build_tools/kokoro/gcp_ubuntu/cmake/linux/x86-swiftshader-asan/build.sh index bd48c3f..1f5b793 100755 --- a/build_tools/kokoro/gcp_ubuntu/cmake/linux/x86-swiftshader-asan/build.sh +++ b/build_tools/kokoro/gcp_ubuntu/cmake/linux/x86-swiftshader-asan/build.sh
@@ -88,6 +88,7 @@ fi if [[ "${IREE_CUDA_DISABLE?}" == 1 ]]; then label_exclude_args+=("^driver=cuda$") + label_exclude_args+=("^uses_cuda_runtime$") fi if [[ "${IREE_VULKAN_F16_DISABLE?}" == 1 ]]; then label_exclude_args+=("^vulkan_uses_vk_khr_shader_float16_int8$")
diff --git a/docs/developers/best_practices.md b/docs/developers/best_practices.md new file mode 100644 index 0000000..d3d9c91 --- /dev/null +++ b/docs/developers/best_practices.md
@@ -0,0 +1,63 @@ +# IREE Best Practices + +This page contains a list of best practices for getting the most out of IREE, +spanning model authoring, ahead-of-time compilation, and runtime use. Treat +these as a collection of ideas to consider or areas to start benchmarking when +working on your own applications. + +## Introduction + +Common themes include: + +* Give the compiler as much information as possible +* Give the compiler opportunities to batch work together or defer computation +* Keep compute devices saturated with work through pipelining +* Use dense math where possible, particularly for inner loop bodies +* Limit synchronization points between devices like CPUs and GPUs +* Profile early and often, using the right tools for each level of granularity + +## Practices for model authoring + +### Track state within your model when possible + +If your model is stateful, prefer to store that state directly within your +program rather than externalizing it through arguments and return values. By +keeping state inside your program, the compiler is better able to reason about +it and function calls will have lower overhead. + +If you do externalize state, try to pack that state into a limited number of +arguments. + +See the +[variables and state](https://github.com/google/iree/tree/main/iree/samples/variables_and_state) +sample for further guidance on tracking and using state. + +### Limit uses of dynamic shapes + +While IREE aims to support general use of dynamic shapes, it is better able to +optimize parts of programs where shapes are static. Slowly varying dimensions +like batch index or timestamp are safer uses of dynamic shapes than quickly +varying dimensions like the x/y/channel dimensions of images. + +See the +[dynamic shapes](https://github.com/google/iree/tree/main/iree/samples/dynamic_shapes) +sample for further guidance on using dynamic shapes. + +## Practices for compilation settings + +TODO: mention parameters to tune + +TODO: which compiler targets to use (try both CUDA and Vulkan?) + +TODO: use the most specific LLVM target triple you can? + +## Practices for runtime use + +### Do the minimum amount of work: cache queries and reuse buffers + +Try to front-load queries, particularly queries using strings that look up into +maps like `iree_runtime_session_call_by_name`, so that hot sections of code are +doing the minimum amount of work: routing inputs through buffers, scheduling +runtime calls, and routing outputs through other buffers (see the sketch +below). + +TODO: sample code, profile numbers
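The runtime guidance above is easiest to see in code. Below is a minimal sketch of front-loading the string-based lookup, assuming the `iree/runtime/` C API as used in IREE's runtime samples; the function names (`iree_runtime_call_initialize_by_name` and friends) and the `module.predict` entry point are illustrative and may differ in your build.

```c
#include "iree/runtime/api.h"

// Sketch: hoist the name-based lookup out of the hot loop so each iteration
// only routes buffers and invokes. Error-path cleanup elided for brevity.
iree_status_t run_many(iree_runtime_session_t* session, int steps) {
  iree_runtime_call_t call;
  // One string lookup, performed up front.
  IREE_RETURN_IF_ERROR(iree_runtime_call_initialize_by_name(
      session, iree_make_cstring_view("module.predict"), &call));
  for (int i = 0; i < steps; ++i) {
    // Push inputs here (e.g. iree_runtime_call_inputs_push_back_buffer_view),
    // ideally reusing previously allocated buffers.
    IREE_RETURN_IF_ERROR(iree_runtime_call_invoke(&call, /*flags=*/0));
    // Pop outputs here, then reset the argument lists for the next iteration.
    iree_runtime_call_reset(&call);
  }
  iree_runtime_call_deinitialize(&call);
  return iree_ok_status();
}
```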
diff --git a/docs/developers/get_started/getting_started_emscripten.md b/docs/developers/get_started/getting_started_emscripten.md new file mode 100644 index 0000000..c5aefb7 --- /dev/null +++ b/docs/developers/get_started/getting_started_emscripten.md
@@ -0,0 +1,67 @@ +# Getting Started With Emscripten + +[Emscripten](https://emscripten.org/index.html) is a complete compiler +toolchain to WebAssembly, using LLVM, with a special focus on speed, size, and +the Web platform. Emscripten can be used to compile parts of IREE to +[WebAssembly](https://webassembly.org/) for execution within web browsers or +other Wasm runtimes. + +## Status + +IREE's _runtime_ can be compiled through Emscripten in some limited +configurations. More of the runtime will be supported over time. + +IREE's _compiler_ can be compiled through Emscripten with local changes. More +work is needed for this to be generally supported. + +## Prerequisites + +Read https://emscripten.org/docs/getting_started/downloads.html and run: + +``` +./emsdk install latest +./emsdk activate latest +source ./emsdk_env.sh +``` + +## Building IREE's Runtime with Emscripten + +### Host Configuration + +Build and install at least the compiler tools on your host machine, or install +them from a binary distribution: + +```shell +$ cmake -G Ninja -B ../iree-build-host/ \ + -DCMAKE_C_COMPILER=clang \ + -DCMAKE_CXX_COMPILER=clang++ \ + -DCMAKE_INSTALL_PREFIX=../iree-build-host/install \ + . +$ cmake --build ../iree-build-host/ --target install +``` + +### Target Configuration + +```shell +$ emcmake cmake -G Ninja -B ../iree-build-emscripten/ \ + -DCMAKE_BUILD_TYPE=Release \ + -DIREE_HOST_BINARY_ROOT=$(realpath ../iree-build-host/install) \ + -DIREE_BUILD_TESTS=OFF \ + -DIREE_BUILD_COMPILER=OFF \ + -DIREE_ENABLE_MLIR=OFF \ + . +``` + +Build: + +``` +cmake --build ../iree-build-emscripten/ \ + --target iree_samples_simple_embedding_simple_embedding_vmvx_sync +``` + +### Load into a WebAssembly Environment + +Copy the outputs from the build process (e.g. `simple_embedding_vmvx_sync.js` +and `simple_embedding_vmvx_sync.wasm`) into your application and follow +instructions at either https://webassembly.org/getting-started/developers-guide/ +or https://developer.mozilla.org/en-US/docs/WebAssembly/Loading_and_running.
diff --git a/docs/website/docs/bindings/python.md b/docs/website/docs/bindings/python.md index b6faade..f279ad6 100644 --- a/docs/website/docs/bindings/python.md +++ b/docs/website/docs/bindings/python.md
@@ -10,11 +10,16 @@ | `iree-tools-tf-snapshot` | Tools for importing from [TensorFlow](https://www.tensorflow.org/) | | `iree-tools-tflite-snapshot` | Tools for importing from [TensorFlow Lite](https://www.tensorflow.org/lite) | | `iree-tools-xla-snapshot` | Tools for importing from [XLA](https://www.tensorflow.org/xla) | +| `iree-jax-snapshot` | Tools for importing from [JAX](https://github.com/google/jax) | Collectively, these packages allow for importing from frontends, compiling towards various targets, and executing compiled code on IREE's backends. -<!-- TODO(??): Which package for JAX? --> +!!! warning + The TensorFlow, TensorFlow Lite, and XLA packages are currently only + available on Linux and macOS. They are not available on Windows yet (see + [this issue](https://github.com/google/iree/issues/6417)). + <!-- TODO(??): API references for packages/modules --> <!-- TODO(??): at least link to source code and sample Colab notebooks for now --> <!-- TODO(??): link to frontend docs -->
diff --git a/docs/website/docs/blog/2021-07-19-tflite-tosa-compilation-diagram.png b/docs/website/docs/blog/2021-07-19-tflite-tosa-compilation-diagram.png new file mode 100644 index 0000000..11b5986 --- /dev/null +++ b/docs/website/docs/blog/2021-07-19-tflite-tosa-compilation-diagram.png Binary files differ
diff --git a/docs/website/docs/blog/2021-07-19-tflite-tosa.md b/docs/website/docs/blog/2021-07-19-tflite-tosa.md new file mode 100644 index 0000000..90929b6 --- /dev/null +++ b/docs/website/docs/blog/2021-07-19-tflite-tosa.md
@@ -0,0 +1,44 @@ + Monday, July 19, 2021<br> + By Rob Suderman and Jenni Kilduff + +## TFLite Support via TOSA + +IREE can now execute [TensorFlow Lite](https://www.tensorflow.org/lite) +(TFLite) models through the use of +[TOSA](https://developer.mlplatform.org/w/tosa/), an open standard of common +tensor operations, and a part of [MLIR](https://mlir.llvm.org/) core. TOSA’s +high-level representation of tensor operations provides a common front-end for +ingesting models from different frameworks. In this case we ingest a TFLite +flatbuffer and compile it to TOSA IR, which IREE takes as an input format to +compile to its various backends. + +{ align=left } + +Using TFLite as a frontend for IREE provides an alternative ingestion method for +existing models that could benefit from IREE’s design. This enables +models already designed for on-device inference to have an alternative path for +execution without requiring any additional porting, while benefiting from +IREE’s improvements in buffer management, work dispatch system, and compact +binary format. With continued improvements to IREE/MLIR’s compilation +performance, more optimized versions can be compiled and distributed to target +devices without an update to the client-side environment. + +Today, we have validated floating point support for a variety of models, +including +[mobilenet](https://tfhub.dev/s?deployment-format=lite&network-architecture=mobilenet,mobilenet-v2,mobilenet-v3,mobilenet-v1&q=mobilenet) +(v1, v2, and v3) and +[mobilebert](https://tfhub.dev/tensorflow/lite-model/mobilebert/1/default/1). +Work is in progress to support fully quantized models, TFLite’s hybrid +quantization, and dynamic shapes. + +## Examples + +TFLite with IREE is available in Python and Java. We have a +[Colab notebook](https://colab.sandbox.google.com/github/google/iree/blob/main/colab/tflite_text_classification.ipynb) +that shows how to use IREE’s Python bindings and TFLite compiler tools to +compile a pre-trained TFLite model from a flatbuffer and run it using IREE. We +also have an +[Android Java app](https://github.com/not-jenni/iree-android-tflite-demo) that +was forked from an existing TFLite demo app, swapping out the TFLite library +for our own AAR. More information on IREE’s TFLite frontend is available +[here](../ml-frameworks/tensorflow-lite.md).
diff --git a/docs/website/docs/deployment-configurations/cpu-dylib.md b/docs/website/docs/deployment-configurations/cpu-dylib.md index 706d039..8bad3be 100644 --- a/docs/website/docs/deployment-configurations/cpu-dylib.md +++ b/docs/website/docs/deployment-configurations/cpu-dylib.md
@@ -131,11 +131,11 @@ <!-- TODO(??): troubleshooting --> -[android-cc]: /building-from-source/android/ -[get-started]: /building-from-source/getting-started/ +[android-cc]: ../building-from-source/android.md +[get-started]: ../building-from-source/getting-started.md [iree-releases]: https://github.com/google/iree/releases/ [llvm]: https://llvm.org/ [mlir]: https://mlir.llvm.org/ [tf-hub-mobilenetv2]: https://tfhub.dev/google/tf2-preview/mobilenet_v2/classification -[tf-import]: /frontends/tensorflow/ -[tflite-import]: /frontends/tensorflow-lite/ +[tf-import]: ../ml-frameworks/tensorflow.md +[tflite-import]: ../ml-frameworks/tensorflow-lite.md
diff --git a/docs/website/docs/deployment-configurations/gpu-vulkan.md b/docs/website/docs/deployment-configurations/gpu-vulkan.md index d86de5c..a973bc3 100644 --- a/docs/website/docs/deployment-configurations/gpu-vulkan.md +++ b/docs/website/docs/deployment-configurations/gpu-vulkan.md
@@ -177,13 +177,13 @@ <!-- TODO(??): troubleshooting --> -[android-cc]: /building-from-source/android/ -[get-started]: /building-from-source/getting-started/ +[android-cc]: ../building-from-source/android.md +[get-started]: ../building-from-source/getting-started.md [iree-releases]: https://github.com/google/iree/releases/ [mlir]: https://mlir.llvm.org/ [spirv]: https://www.khronos.org/registry/spir-v/ [tf-hub-mobilenetv2]: https://tfhub.dev/google/tf2-preview/mobilenet_v2/classification -[tf-import]: /frontends/tensorflow/ -[tflite-import]: /frontends/tensorflow-lite/ +[tf-import]: ../ml-frameworks/tensorflow.md +[tflite-import]: ../ml-frameworks/tensorflow-lite.md [vulkan]: https://www.khronos.org/vulkan/ [vulkan-sdk]: https://vulkan.lunarg.com/sdk/home/
diff --git a/docs/website/docs/ml-frameworks/tensorflow-lite.md b/docs/website/docs/ml-frameworks/tensorflow-lite.md index b91ac3f..e807555 100644 --- a/docs/website/docs/ml-frameworks/tensorflow-lite.md +++ b/docs/website/docs/ml-frameworks/tensorflow-lite.md
@@ -24,6 +24,11 @@ -f https://github.com/google/iree/releases ``` +!!! warning + The TensorFlow Lite package is currently only available on Linux and macOS. + It is not available on Windows yet (see + [this issue](https://github.com/google/iree/issues/6417)). + ## Importing models First, import the TFLite model to TOSA MLIR:
diff --git a/docs/website/docs/ml-frameworks/tensorflow.md b/docs/website/docs/ml-frameworks/tensorflow.md index 837f8ec..9e5df96 100644 --- a/docs/website/docs/ml-frameworks/tensorflow.md +++ b/docs/website/docs/ml-frameworks/tensorflow.md
@@ -27,6 +27,11 @@ -f https://github.com/google/iree/releases ``` +!!! warning + The TensorFlow package is currently only available on Linux and macOS. It + is not available on Windows yet (see + [this issue](https://github.com/google/iree/issues/6417)). + ## Importing models IREE compilers transform a model into its final deployable format in several
diff --git a/docs/website/mkdocs.yml b/docs/website/mkdocs.yml index a6dca81..c9253f3 100644 --- a/docs/website/mkdocs.yml +++ b/docs/website/mkdocs.yml
@@ -73,6 +73,7 @@ markdown_extensions: - abbr - admonition + - attr_list - footnotes - meta - pymdownx.details @@ -115,3 +116,5 @@ - TensorFlow Lite: 'bindings/tensorflow-lite.md' - 'Community': - Projects: 'community/projects.md' + - 'Blog': + - TFLite Support via TOSA: 'blog/2021-07-19-tflite-tosa.md'
diff --git a/experimental/rocm/descriptor_set_layout.c b/experimental/rocm/descriptor_set_layout.c index 69fb4af..c9e985e 100644 --- a/experimental/rocm/descriptor_set_layout.c +++ b/experimental/rocm/descriptor_set_layout.c
@@ -13,49 +13,58 @@ typedef struct iree_hal_rocm_descriptor_set_layout_t { iree_hal_resource_t resource; - iree_hal_rocm_context_wrapper_t *context; + iree_hal_rocm_context_wrapper_t* context; + iree_host_size_t binding_count; } iree_hal_rocm_descriptor_set_layout_t; extern const iree_hal_descriptor_set_layout_vtable_t iree_hal_rocm_descriptor_set_layout_vtable; -static iree_hal_rocm_descriptor_set_layout_t * +static iree_hal_rocm_descriptor_set_layout_t* iree_hal_rocm_descriptor_set_layout_cast( - iree_hal_descriptor_set_layout_t *base_value) { + iree_hal_descriptor_set_layout_t* base_value) { IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_rocm_descriptor_set_layout_vtable); - return (iree_hal_rocm_descriptor_set_layout_t *)base_value; + return (iree_hal_rocm_descriptor_set_layout_t*)base_value; } iree_status_t iree_hal_rocm_descriptor_set_layout_create( - iree_hal_rocm_context_wrapper_t *context, + iree_hal_rocm_context_wrapper_t* context, iree_hal_descriptor_set_layout_usage_type_t usage_type, iree_host_size_t binding_count, - const iree_hal_descriptor_set_layout_binding_t *bindings, - iree_hal_descriptor_set_layout_t **out_descriptor_set_layout) { + const iree_hal_descriptor_set_layout_binding_t* bindings, + iree_hal_descriptor_set_layout_t** out_descriptor_set_layout) { IREE_ASSERT_ARGUMENT(context); IREE_ASSERT_ARGUMENT(!binding_count || bindings); IREE_ASSERT_ARGUMENT(out_descriptor_set_layout); *out_descriptor_set_layout = NULL; IREE_TRACE_ZONE_BEGIN(z0); - iree_hal_rocm_descriptor_set_layout_t *descriptor_set_layout = NULL; + iree_hal_rocm_descriptor_set_layout_t* descriptor_set_layout = NULL; iree_status_t status = iree_allocator_malloc(context->host_allocator, sizeof(*descriptor_set_layout), - (void **)&descriptor_set_layout); + (void**)&descriptor_set_layout); if (iree_status_is_ok(status)) { iree_hal_resource_initialize(&iree_hal_rocm_descriptor_set_layout_vtable, &descriptor_set_layout->resource); descriptor_set_layout->context = context; + descriptor_set_layout->binding_count = binding_count; *out_descriptor_set_layout = - (iree_hal_descriptor_set_layout_t *)descriptor_set_layout; + (iree_hal_descriptor_set_layout_t*)descriptor_set_layout; } IREE_TRACE_ZONE_END(z0); return status; } +iree_host_size_t iree_hal_rocm_descriptor_set_layout_binding_count( + iree_hal_descriptor_set_layout_t* base_descriptor_set_layout) { + iree_hal_rocm_descriptor_set_layout_t* descriptor_set_layout = + iree_hal_rocm_descriptor_set_layout_cast(base_descriptor_set_layout); + return descriptor_set_layout->binding_count; +} + static void iree_hal_rocm_descriptor_set_layout_destroy( - iree_hal_descriptor_set_layout_t *base_descriptor_set_layout) { - iree_hal_rocm_descriptor_set_layout_t *descriptor_set_layout = + iree_hal_descriptor_set_layout_t* base_descriptor_set_layout) { + iree_hal_rocm_descriptor_set_layout_t* descriptor_set_layout = iree_hal_rocm_descriptor_set_layout_cast(base_descriptor_set_layout); iree_allocator_t host_allocator = descriptor_set_layout->context->host_allocator;
diff --git a/experimental/rocm/descriptor_set_layout.h b/experimental/rocm/descriptor_set_layout.h index fb07b76..9d58acd 100644 --- a/experimental/rocm/descriptor_set_layout.h +++ b/experimental/rocm/descriptor_set_layout.h
@@ -16,11 +16,15 @@ #endif // __cplusplus iree_status_t iree_hal_rocm_descriptor_set_layout_create( - iree_hal_rocm_context_wrapper_t *context, + iree_hal_rocm_context_wrapper_t* context, iree_hal_descriptor_set_layout_usage_type_t usage_type, iree_host_size_t binding_count, - const iree_hal_descriptor_set_layout_binding_t *bindings, - iree_hal_descriptor_set_layout_t **out_descriptor_set_layout); + const iree_hal_descriptor_set_layout_binding_t* bindings, + iree_hal_descriptor_set_layout_t** out_descriptor_set_layout); + +// Return the binding count for the given descriptor set layout. +iree_host_size_t iree_hal_rocm_descriptor_set_layout_binding_count( + iree_hal_descriptor_set_layout_t* descriptor_set_layout); #ifdef __cplusplus } // extern "C"
diff --git a/experimental/rocm/direct_command_buffer.c b/experimental/rocm/direct_command_buffer.c index 3d8706a..db28d91 100644 --- a/experimental/rocm/direct_command_buffer.c +++ b/experimental/rocm/direct_command_buffer.c
@@ -11,6 +11,7 @@ #include <stdint.h> #include "experimental/rocm/dynamic_symbols.h" +#include "experimental/rocm/executable_layout.h" #include "experimental/rocm/native_executable.h" #include "experimental/rocm/rocm_buffer.h" #include "experimental/rocm/status_util.h" @@ -44,7 +45,7 @@ return (iree_hal_rocm_direct_command_buffer_t*)base_value; } -iree_status_t iree_hal_rocm_direct_command_buffer_allocate( +iree_status_t iree_hal_rocm_direct_command_buffer_create( iree_hal_rocm_context_wrapper_t* context, iree_hal_command_buffer_mode_t mode, iree_hal_command_category_t command_categories, @@ -283,6 +284,8 @@ const iree_hal_descriptor_set_binding_t* bindings) { iree_hal_rocm_direct_command_buffer_t* command_buffer = iree_hal_rocm_direct_command_buffer_cast(base_command_buffer); + iree_host_size_t base_binding = + iree_hal_rocm_base_binding_index(executable_layout, set); // Convention with the compiler side: we map bindings to kernel arguments. // We compact the bindings to get a dense set of arguments and keep them // ordered based on the binding index. @@ -303,7 +306,8 @@ iree_hal_rocm_buffer_device_pointer( iree_hal_buffer_allocated_buffer(binding.buffer)) + iree_hal_buffer_byte_offset(binding.buffer) + binding.offset; - *((hipDeviceptr_t*)command_buffer->current_descriptor[i]) = device_ptr; + *((hipDeviceptr_t*)command_buffer->current_descriptor[i + base_binding]) = + device_ptr; } return iree_ok_status(); }
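For readers of this hunk: the convention is that kernel arguments are the dense concatenation of every descriptor set's bindings, so binding `i` of set `s` lands at argument index `sum(binding_count(0..s-1)) + i`, which is what `iree_hal_rocm_base_binding_index` computes. A standalone sketch of that arithmetic, using hypothetical binding counts rather than the HAL types:

```c
#include <stddef.h>
#include <stdio.h>

// Sum the binding counts of all sets preceding |set| to get the offset of
// that set's first kernel argument.
static size_t base_binding_index(const size_t* set_binding_counts, size_t set) {
  size_t base = 0;
  for (size_t i = 0; i < set; ++i) base += set_binding_counts[i];
  return base;
}

int main(void) {
  const size_t counts[] = {2, 3, 1};  // bindings per descriptor set
  printf("%zu\n", base_binding_index(counts, 1));  // set 1 starts at arg 2
  printf("%zu\n", base_binding_index(counts, 2));  // set 2 starts at arg 5
  return 0;
}
```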
diff --git a/experimental/rocm/direct_command_buffer.h b/experimental/rocm/direct_command_buffer.h index 0145b15..bd665bf 100644 --- a/experimental/rocm/direct_command_buffer.h +++ b/experimental/rocm/direct_command_buffer.h
@@ -26,16 +26,16 @@ unsigned int blockDimX; unsigned int blockDimY; unsigned int blockDimZ; - void **kernelParams; + void** kernelParams; } hip_launch_params; // Creates a rocm direct command buffer. -iree_status_t iree_hal_rocm_direct_command_buffer_allocate( - iree_hal_rocm_context_wrapper_t *context, +iree_status_t iree_hal_rocm_direct_command_buffer_create( + iree_hal_rocm_context_wrapper_t* context, iree_hal_command_buffer_mode_t mode, iree_hal_command_category_t command_categories, iree_hal_queue_affinity_t queue_affinity, - iree_hal_command_buffer_t **out_command_buffer); + iree_hal_command_buffer_t** out_command_buffer); #ifdef __cplusplus } // extern "C"
diff --git a/experimental/rocm/dynamic_symbol_tables.h b/experimental/rocm/dynamic_symbol_tables.h index fe6cfd5..aa78c5a 100644 --- a/experimental/rocm/dynamic_symbol_tables.h +++ b/experimental/rocm/dynamic_symbol_tables.h
@@ -30,6 +30,7 @@ RC_PFN_DECL(hipMemcpyAsync, void *, const void *, size_t, hipMemcpyKind, hipStream_t) RC_PFN_DECL(hipMalloc, void **, size_t) +RC_PFN_DECL(hipMallocManaged, hipDeviceptr_t *, size_t, unsigned int) RC_PFN_DECL(hipFree, void *) RC_PFN_DECL(hipHostFree, void *) RC_PFN_DECL(hipMemAllocHost, void **, size_t, unsigned int)
diff --git a/experimental/rocm/dynamic_symbols.c b/experimental/rocm/dynamic_symbols.c index d676838..ad7ca6d 100644 --- a/experimental/rocm/dynamic_symbols.c +++ b/experimental/rocm/dynamic_symbols.c
@@ -12,7 +12,7 @@ #include "iree/base/target_platform.h" #include "iree/base/tracing.h" -static const char *kROCMLoaderSearchNames[] = { +static const char* kROCMLoaderSearchNames[] = { #if defined(IREE_PLATFORM_WINDOWS) "amdhip64.dll", #else @@ -21,12 +21,12 @@ }; static iree_status_t iree_hal_rocm_dynamic_symbols_resolve_all( - iree_hal_rocm_dynamic_symbols_t *syms) { -#define RC_PFN_DECL(rocmSymbolName, ...) \ - { \ - static const char *kName = #rocmSymbolName; \ - IREE_RETURN_IF_ERROR(iree_dynamic_library_lookup_symbol( \ - syms->loader_library, kName, (void **)&syms->rocmSymbolName)); \ + iree_hal_rocm_dynamic_symbols_t* syms) { +#define RC_PFN_DECL(rocmSymbolName, ...) \ + { \ + static const char* kName = #rocmSymbolName; \ + IREE_RETURN_IF_ERROR(iree_dynamic_library_lookup_symbol( \ + syms->loader_library, kName, (void**)&syms->rocmSymbolName)); \ } #define RC_PFN_STR_DECL(rocmSymbolName, ...) RC_PFN_DECL(rocmSymbolName, ...) #include "experimental/rocm/dynamic_symbol_tables.h" // IWYU pragma: keep @@ -36,7 +36,7 @@ } iree_status_t iree_hal_rocm_dynamic_symbols_initialize( - iree_allocator_t allocator, iree_hal_rocm_dynamic_symbols_t *out_syms) { + iree_allocator_t allocator, iree_hal_rocm_dynamic_symbols_t* out_syms) { IREE_TRACE_ZONE_BEGIN(z0); memset(out_syms, 0, sizeof(*out_syms)); iree_status_t status = iree_dynamic_library_load_from_files( @@ -59,7 +59,7 @@ } void iree_hal_rocm_dynamic_symbols_deinitialize( - iree_hal_rocm_dynamic_symbols_t *syms) { + iree_hal_rocm_dynamic_symbols_t* syms) { IREE_TRACE_ZONE_BEGIN(z0); iree_dynamic_library_release(syms->loader_library); memset(syms, 0, sizeof(*syms));
diff --git a/experimental/rocm/dynamic_symbols.h b/experimental/rocm/dynamic_symbols.h index 00a58b0..9cbd774 100644 --- a/experimental/rocm/dynamic_symbols.h +++ b/experimental/rocm/dynamic_symbols.h
@@ -20,12 +20,12 @@ // any of the symbols is not available. The function signatures match // the declarations in `hipruntime.h`. typedef struct iree_hal_rocm_dynamic_symbols_t { - iree_dynamic_library_t *loader_library; + iree_dynamic_library_t* loader_library; #define RC_PFN_DECL(rocmSymbolName, ...) \ hipError_t (*rocmSymbolName)(__VA_ARGS__); #define RC_PFN_STR_DECL(rocmSymbolName, ...) \ - const char *(*rocmSymbolName)(__VA_ARGS__); + const char* (*rocmSymbolName)(__VA_ARGS__); #include "experimental/rocm/dynamic_symbol_tables.h" // IWYU pragma: export #undef RC_PFN_DECL #undef RC_PFN_STR_DECL @@ -35,13 +35,13 @@ // iree_hal_rocm_dynamic_symbols_deinitialize must be used to release the // library resources. iree_status_t iree_hal_rocm_dynamic_symbols_initialize( - iree_allocator_t allocator, iree_hal_rocm_dynamic_symbols_t *out_syms); + iree_allocator_t allocator, iree_hal_rocm_dynamic_symbols_t* out_syms); // Deinitializes |syms| by unloading the backing library. All function pointers // will be invalidated. They _may_ still work if there are other reasons the // library remains loaded, so be careful. void iree_hal_rocm_dynamic_symbols_deinitialize( - iree_hal_rocm_dynamic_symbols_t *syms); + iree_hal_rocm_dynamic_symbols_t* syms); #ifdef __cplusplus } // extern "C"
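Background on the mechanism these files share: `dynamic_symbol_tables.h` is an X-macro table that gets `#include`d twice with different definitions of `RC_PFN_DECL`, once (in `dynamic_symbols.h`) to declare one function-pointer member per symbol and once (in `dynamic_symbols.c`) to resolve each symbol by its stringized name. A self-contained sketch of the pattern, with hypothetical symbols and a simplified lookup signature; the real code re-includes the table header rather than using a macro-of-macros:

```c
#define SYMBOL_TABLE(X) \
  X(alpha)              \
  X(beta)

// Expansion 1: declare one function-pointer member per symbol.
typedef struct {
#define DECLARE_MEMBER(name) int (*name)(void);
  SYMBOL_TABLE(DECLARE_MEMBER)
#undef DECLARE_MEMBER
} symbols_t;

// Expansion 2: resolve each symbol by its stringized name via |lookup|,
// mirroring iree_hal_rocm_dynamic_symbols_resolve_all. Returns 1 on success.
int resolve_all(symbols_t* syms,
                int (*lookup)(const char* name, void** out_fn)) {
#define RESOLVE_MEMBER(name) \
  if (!lookup(#name, (void**)&syms->name)) return 0;
  SYMBOL_TABLE(RESOLVE_MEMBER)
#undef RESOLVE_MEMBER
  return 1;
}
```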
diff --git a/experimental/rocm/event_semaphore.c b/experimental/rocm/event_semaphore.c index 7fe6860..d3085fd 100644 --- a/experimental/rocm/event_semaphore.c +++ b/experimental/rocm/event_semaphore.c
@@ -13,34 +13,34 @@ typedef struct iree_hal_rocm_semaphore_t { iree_hal_resource_t resource; - iree_hal_rocm_context_wrapper_t *context; + iree_hal_rocm_context_wrapper_t* context; uint64_t initial_value; } iree_hal_rocm_semaphore_t; extern const iree_hal_semaphore_vtable_t iree_hal_rocm_semaphore_vtable; -static iree_hal_rocm_semaphore_t *iree_hal_rocm_semaphore_cast( - iree_hal_semaphore_t *base_value) { +static iree_hal_rocm_semaphore_t* iree_hal_rocm_semaphore_cast( + iree_hal_semaphore_t* base_value) { IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_rocm_semaphore_vtable); - return (iree_hal_rocm_semaphore_t *)base_value; + return (iree_hal_rocm_semaphore_t*)base_value; } iree_status_t iree_hal_rocm_semaphore_create( - iree_hal_rocm_context_wrapper_t *context, uint64_t initial_value, - iree_hal_semaphore_t **out_semaphore) { + iree_hal_rocm_context_wrapper_t* context, uint64_t initial_value, + iree_hal_semaphore_t** out_semaphore) { IREE_ASSERT_ARGUMENT(context); IREE_ASSERT_ARGUMENT(out_semaphore); IREE_TRACE_ZONE_BEGIN(z0); - iree_hal_rocm_semaphore_t *semaphore = NULL; + iree_hal_rocm_semaphore_t* semaphore = NULL; iree_status_t status = iree_allocator_malloc( - context->host_allocator, sizeof(*semaphore), (void **)&semaphore); + context->host_allocator, sizeof(*semaphore), (void**)&semaphore); if (iree_status_is_ok(status)) { iree_hal_resource_initialize(&iree_hal_rocm_semaphore_vtable, &semaphore->resource); semaphore->context = context; semaphore->initial_value = initial_value; - *out_semaphore = (iree_hal_semaphore_t *)semaphore; + *out_semaphore = (iree_hal_semaphore_t*)semaphore; } IREE_TRACE_ZONE_END(z0); @@ -48,8 +48,8 @@ } static void iree_hal_rocm_semaphore_destroy( - iree_hal_semaphore_t *base_semaphore) { - iree_hal_rocm_semaphore_t *semaphore = + iree_hal_semaphore_t* base_semaphore) { + iree_hal_rocm_semaphore_t* semaphore = iree_hal_rocm_semaphore_cast(base_semaphore); iree_allocator_t host_allocator = semaphore->context->host_allocator; IREE_TRACE_ZONE_BEGIN(z0); @@ -60,24 +60,24 @@ } static iree_status_t iree_hal_rocm_semaphore_query( - iree_hal_semaphore_t *base_semaphore, uint64_t *out_value) { + iree_hal_semaphore_t* base_semaphore, uint64_t* out_value) { // TODO: Support semaphores completely. *out_value = 0; return iree_make_status(IREE_STATUS_UNIMPLEMENTED, "Not implemented on rocm"); } static iree_status_t iree_hal_rocm_semaphore_signal( - iree_hal_semaphore_t *base_semaphore, uint64_t new_value) { + iree_hal_semaphore_t* base_semaphore, uint64_t new_value) { // TODO: Support semaphores completely. Return OK currently as everything is // synchronized for each submit to allow things to run. return iree_ok_status(); } -static void iree_hal_rocm_semaphore_fail(iree_hal_semaphore_t *base_semaphore, +static void iree_hal_rocm_semaphore_fail(iree_hal_semaphore_t* base_semaphore, iree_status_t status) {} static iree_status_t iree_hal_rocm_semaphore_wait( - iree_hal_semaphore_t *base_semaphore, uint64_t value, + iree_hal_semaphore_t* base_semaphore, uint64_t value, iree_timeout_t timeout) { // TODO: Support semaphores completely. Return OK currently as everything is // synchronized for each submit to allow things to run.
diff --git a/experimental/rocm/event_semaphore.h b/experimental/rocm/event_semaphore.h index 5f32492..9e79aa8 100644 --- a/experimental/rocm/event_semaphore.h +++ b/experimental/rocm/event_semaphore.h
@@ -20,8 +20,8 @@ // Create a rocm allocator. iree_status_t iree_hal_rocm_semaphore_create( - iree_hal_rocm_context_wrapper_t *context, uint64_t initial_value, - iree_hal_semaphore_t **out_semaphore); + iree_hal_rocm_context_wrapper_t* context, uint64_t initial_value, + iree_hal_semaphore_t** out_semaphore); #ifdef __cplusplus } // extern "C"
diff --git a/experimental/rocm/executable_layout.c b/experimental/rocm/executable_layout.c index 6960cd2..f293921 100644 --- a/experimental/rocm/executable_layout.c +++ b/experimental/rocm/executable_layout.c
@@ -8,30 +8,31 @@ #include <stddef.h> +#include "experimental/rocm/descriptor_set_layout.h" #include "iree/base/api.h" #include "iree/base/tracing.h" typedef struct iree_hal_rocm_executable_layout_t { iree_hal_resource_t resource; - iree_hal_rocm_context_wrapper_t *context; + iree_hal_rocm_context_wrapper_t* context; iree_host_size_t set_layout_count; - iree_hal_descriptor_set_layout_t *set_layouts[]; + iree_hal_descriptor_set_layout_t* set_layouts[]; } iree_hal_rocm_executable_layout_t; extern const iree_hal_executable_layout_vtable_t iree_hal_rocm_executable_layout_vtable; -static iree_hal_rocm_executable_layout_t *iree_hal_rocm_executable_layout_cast( - iree_hal_executable_layout_t *base_value) { +static iree_hal_rocm_executable_layout_t* iree_hal_rocm_executable_layout_cast( + iree_hal_executable_layout_t* base_value) { IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_rocm_executable_layout_vtable); - return (iree_hal_rocm_executable_layout_t *)base_value; + return (iree_hal_rocm_executable_layout_t*)base_value; } iree_status_t iree_hal_rocm_executable_layout_create( - iree_hal_rocm_context_wrapper_t *context, iree_host_size_t set_layout_count, - iree_hal_descriptor_set_layout_t **set_layouts, + iree_hal_rocm_context_wrapper_t* context, iree_host_size_t set_layout_count, + iree_hal_descriptor_set_layout_t** set_layouts, iree_host_size_t push_constant_count, - iree_hal_executable_layout_t **out_executable_layout) { + iree_hal_executable_layout_t** out_executable_layout) { IREE_ASSERT_ARGUMENT(context); IREE_ASSERT_ARGUMENT(!set_layout_count || set_layouts); IREE_ASSERT_ARGUMENT(out_executable_layout); @@ -40,12 +41,12 @@ // Currently the executable layout doesn't do anything. // TODO: Handle creating the argument layout at that time, handling both // push constants and buffers. 
- iree_hal_rocm_executable_layout_t *executable_layout = NULL; + iree_hal_rocm_executable_layout_t* executable_layout = NULL; iree_host_size_t total_size = sizeof(*executable_layout) + set_layout_count * sizeof(*executable_layout->set_layouts); iree_status_t status = iree_allocator_malloc( - context->host_allocator, total_size, (void **)&executable_layout); + context->host_allocator, total_size, (void**)&executable_layout); if (iree_status_is_ok(status)) { iree_hal_resource_initialize(&iree_hal_rocm_executable_layout_vtable, &executable_layout->resource); @@ -55,15 +56,15 @@ executable_layout->set_layouts[i] = set_layouts[i]; iree_hal_descriptor_set_layout_retain(set_layouts[i]); } - *out_executable_layout = (iree_hal_executable_layout_t *)executable_layout; + *out_executable_layout = (iree_hal_executable_layout_t*)executable_layout; } IREE_TRACE_ZONE_END(z0); return status; } static void iree_hal_rocm_executable_layout_destroy( - iree_hal_executable_layout_t *base_executable_layout) { - iree_hal_rocm_executable_layout_t *executable_layout = + iree_hal_executable_layout_t* base_executable_layout) { + iree_hal_rocm_executable_layout_t* executable_layout = iree_hal_rocm_executable_layout_cast(base_executable_layout); iree_allocator_t host_allocator = executable_layout->context->host_allocator; IREE_TRACE_ZONE_BEGIN(z0); @@ -76,6 +77,20 @@ IREE_TRACE_ZONE_END(z0); } +iree_host_size_t iree_hal_rocm_base_binding_index( + iree_hal_executable_layout_t* base_executable_layout, uint32_t set) { + iree_hal_rocm_executable_layout_t* executable_layout = + iree_hal_rocm_executable_layout_cast(base_executable_layout); + iree_host_size_t base_binding = 0; + for (iree_host_size_t i = 0; i < set; ++i) { + iree_host_size_t binding_count = + iree_hal_rocm_descriptor_set_layout_binding_count( + executable_layout->set_layouts[i]); + base_binding += binding_count; + } + return base_binding; +} + const iree_hal_executable_layout_vtable_t iree_hal_rocm_executable_layout_vtable = { .destroy = iree_hal_rocm_executable_layout_destroy,
diff --git a/experimental/rocm/executable_layout.h b/experimental/rocm/executable_layout.h index cd0fccf..8287cb8 100644 --- a/experimental/rocm/executable_layout.h +++ b/experimental/rocm/executable_layout.h
@@ -17,10 +17,14 @@ // Creates the kernel arguments. iree_status_t iree_hal_rocm_executable_layout_create( - iree_hal_rocm_context_wrapper_t *context, iree_host_size_t set_layout_count, - iree_hal_descriptor_set_layout_t **set_layouts, + iree_hal_rocm_context_wrapper_t* context, iree_host_size_t set_layout_count, + iree_hal_descriptor_set_layout_t** set_layouts, iree_host_size_t push_constant_count, - iree_hal_executable_layout_t **out_executable_layout); + iree_hal_executable_layout_t** out_executable_layout); + +// Return the base binding index for the given set. +iree_host_size_t iree_hal_rocm_base_binding_index( + iree_hal_executable_layout_t* executable_layout, uint32_t set); #ifdef __cplusplus } // extern "C"
diff --git a/experimental/rocm/native_executable.c b/experimental/rocm/native_executable.c index f72bd4f..680f07f 100644 --- a/experimental/rocm/native_executable.c +++ b/experimental/rocm/native_executable.c
@@ -27,7 +27,7 @@ typedef struct iree_hal_rocm_native_executable_t { iree_hal_resource_t resource; - iree_hal_rocm_context_wrapper_t *context; + iree_hal_rocm_context_wrapper_t* context; iree_host_size_t entry_count; hipModule_t module; iree_hal_rocm_native_executable_function_t entry_functions[]; @@ -36,23 +36,23 @@ extern const iree_hal_executable_vtable_t iree_hal_rocm_native_executable_vtable; -static iree_hal_rocm_native_executable_t *iree_hal_rocm_native_executable_cast( - iree_hal_executable_t *base_value) { +static iree_hal_rocm_native_executable_t* iree_hal_rocm_native_executable_cast( + iree_hal_executable_t* base_value) { IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_rocm_native_executable_vtable); - return (iree_hal_rocm_native_executable_t *)base_value; + return (iree_hal_rocm_native_executable_t*)base_value; } iree_status_t iree_hal_rocm_native_executable_create( - iree_hal_rocm_context_wrapper_t *context, - const iree_hal_executable_spec_t *executable_spec, - iree_hal_executable_t **out_executable) { + iree_hal_rocm_context_wrapper_t* context, + const iree_hal_executable_spec_t* executable_spec, + iree_hal_executable_t** out_executable) { IREE_ASSERT_ARGUMENT(context); IREE_ASSERT_ARGUMENT(executable_spec); IREE_ASSERT_ARGUMENT(out_executable); *out_executable = NULL; IREE_TRACE_ZONE_BEGIN(z0); - iree_hal_rocm_native_executable_t *executable = NULL; + iree_hal_rocm_native_executable_t* executable = NULL; // TODO: Verify the flat buffer. iree_ROCMExecutableDef_table_t executable_def = @@ -69,8 +69,8 @@ iree_host_size_t total_size = sizeof(*executable) + entry_count * sizeof(iree_hal_rocm_native_executable_function_t); - iree_status_t status = iree_allocator_malloc( - context->host_allocator, total_size, (void **)&executable); + iree_status_t status = iree_allocator_malloc(context->host_allocator, + total_size, (void**)&executable); hipModule_t module = NULL; ROCM_RETURN_IF_ERROR(context->syms, hipModuleLoadDataEx(&module, hsaco_image, 0, NULL, NULL), @@ -78,7 +78,7 @@ for (iree_host_size_t i = 0; i < entry_count; i++) { hipFunction_t function = NULL; - const char *entry_name = flatbuffers_string_vec_at(entry_points_vec, i); + const char* entry_name = flatbuffers_string_vec_at(entry_points_vec, i); ROCM_RETURN_IF_ERROR(context->syms, hipModuleGetFunction(&function, module, entry_name), "hipModuleGetFunction"); @@ -92,22 +92,22 @@ &executable->resource); executable->module = module; executable->context = context; - *out_executable = (iree_hal_executable_t *)executable; + *out_executable = (iree_hal_executable_t*)executable; IREE_TRACE_ZONE_END(z0); return iree_ok_status(); } hipFunction_t iree_hal_rocm_native_executable_for_entry_point( - iree_hal_executable_t *base_executable, int32_t entry_point) { - iree_hal_rocm_native_executable_t *executable = + iree_hal_executable_t* base_executable, int32_t entry_point) { + iree_hal_rocm_native_executable_t* executable = iree_hal_rocm_native_executable_cast(base_executable); return executable->entry_functions[entry_point].rocm_function; } iree_status_t iree_hal_rocm_native_executable_block_size( - iree_hal_executable_t *base_executable, int32_t entry_point, uint32_t *x, - uint32_t *y, uint32_t *z) { - iree_hal_rocm_native_executable_t *executable = + iree_hal_executable_t* base_executable, int32_t entry_point, uint32_t* x, + uint32_t* y, uint32_t* z) { + iree_hal_rocm_native_executable_t* executable = iree_hal_rocm_native_executable_cast(base_executable); *x = executable->entry_functions[entry_point].block_size_x; *y = 
executable->entry_functions[entry_point].block_size_y; @@ -116,8 +116,8 @@ } static void iree_hal_rocm_native_executable_destroy( - iree_hal_executable_t *base_executable) { - iree_hal_rocm_native_executable_t *executable = + iree_hal_executable_t* base_executable) { + iree_hal_rocm_native_executable_t* executable = iree_hal_rocm_native_executable_cast(base_executable); iree_allocator_t host_allocator = executable->context->host_allocator; IREE_TRACE_ZONE_BEGIN(z0);
diff --git a/experimental/rocm/native_executable.h b/experimental/rocm/native_executable.h index 7a9229a..bc671bf 100644 --- a/experimental/rocm/native_executable.h +++ b/experimental/rocm/native_executable.h
@@ -21,17 +21,17 @@ // Creates an executable from a HSACO module. The module may contain several // kernels that can be extracted along with the associated block size. iree_status_t iree_hal_rocm_native_executable_create( - iree_hal_rocm_context_wrapper_t *context, - const iree_hal_executable_spec_t *executable_spec, - iree_hal_executable_t **out_executable); + iree_hal_rocm_context_wrapper_t* context, + const iree_hal_executable_spec_t* executable_spec, + iree_hal_executable_t** out_executable); hipFunction_t iree_hal_rocm_native_executable_for_entry_point( - iree_hal_executable_t *executable, int32_t entry_point); + iree_hal_executable_t* executable, int32_t entry_point); // Return the block size of the given |entry_point| within the executable. iree_status_t iree_hal_rocm_native_executable_block_size( - iree_hal_executable_t *executable, int32_t entry_point, uint32_t *x, - uint32_t *y, uint32_t *z); + iree_hal_executable_t* executable, int32_t entry_point, uint32_t* x, + uint32_t* y, uint32_t* z); #ifdef __cplusplus } // extern "C"
diff --git a/experimental/rocm/nop_executable_cache.c b/experimental/rocm/nop_executable_cache.c index 6592727..54bde49 100644 --- a/experimental/rocm/nop_executable_cache.c +++ b/experimental/rocm/nop_executable_cache.c
@@ -15,44 +15,44 @@ typedef struct iree_hal_rocm_nop_executable_cache_t { iree_hal_resource_t resource; - iree_hal_rocm_context_wrapper_t *context; + iree_hal_rocm_context_wrapper_t* context; } iree_hal_rocm_nop_executable_cache_t; extern const iree_hal_executable_cache_vtable_t iree_hal_rocm_nop_executable_cache_vtable; -static iree_hal_rocm_nop_executable_cache_t * +static iree_hal_rocm_nop_executable_cache_t* iree_hal_rocm_nop_executable_cache_cast( - iree_hal_executable_cache_t *base_value) { + iree_hal_executable_cache_t* base_value) { IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_rocm_nop_executable_cache_vtable); - return (iree_hal_rocm_nop_executable_cache_t *)base_value; + return (iree_hal_rocm_nop_executable_cache_t*)base_value; } iree_status_t iree_hal_rocm_nop_executable_cache_create( - iree_hal_rocm_context_wrapper_t *context, iree_string_view_t identifier, - iree_hal_executable_cache_t **out_executable_cache) { + iree_hal_rocm_context_wrapper_t* context, iree_string_view_t identifier, + iree_hal_executable_cache_t** out_executable_cache) { IREE_ASSERT_ARGUMENT(out_executable_cache); *out_executable_cache = NULL; IREE_TRACE_ZONE_BEGIN(z0); - iree_hal_rocm_nop_executable_cache_t *executable_cache = NULL; + iree_hal_rocm_nop_executable_cache_t* executable_cache = NULL; iree_status_t status = iree_allocator_malloc(context->host_allocator, sizeof(*executable_cache), - (void **)&executable_cache); + (void**)&executable_cache); if (iree_status_is_ok(status)) { iree_hal_resource_initialize(&iree_hal_rocm_nop_executable_cache_vtable, &executable_cache->resource); executable_cache->context = context; - *out_executable_cache = (iree_hal_executable_cache_t *)executable_cache; + *out_executable_cache = (iree_hal_executable_cache_t*)executable_cache; } IREE_TRACE_ZONE_END(z0); return status; } static void iree_hal_rocm_nop_executable_cache_destroy( - iree_hal_executable_cache_t *base_executable_cache) { - iree_hal_rocm_nop_executable_cache_t *executable_cache = + iree_hal_executable_cache_t* base_executable_cache) { + iree_hal_rocm_nop_executable_cache_t* executable_cache = iree_hal_rocm_nop_executable_cache_cast(base_executable_cache); iree_allocator_t host_allocator = executable_cache->context->host_allocator; IREE_TRACE_ZONE_BEGIN(z0); @@ -63,7 +63,7 @@ } static bool iree_hal_rocm_nop_executable_cache_can_prepare_format( - iree_hal_executable_cache_t *base_executable_cache, + iree_hal_executable_cache_t* base_executable_cache, iree_hal_executable_caching_mode_t caching_mode, iree_string_view_t executable_format) { return iree_string_view_equal(executable_format, @@ -71,10 +71,10 @@ } static iree_status_t iree_hal_rocm_nop_executable_cache_prepare_executable( - iree_hal_executable_cache_t *base_executable_cache, - const iree_hal_executable_spec_t *executable_spec, - iree_hal_executable_t **out_executable) { - iree_hal_rocm_nop_executable_cache_t *executable_cache = + iree_hal_executable_cache_t* base_executable_cache, + const iree_hal_executable_spec_t* executable_spec, + iree_hal_executable_t** out_executable) { + iree_hal_rocm_nop_executable_cache_t* executable_cache = iree_hal_rocm_nop_executable_cache_cast(base_executable_cache); return iree_hal_rocm_native_executable_create( executable_cache->context, executable_spec, out_executable);
diff --git a/experimental/rocm/nop_executable_cache.h b/experimental/rocm/nop_executable_cache.h index d1b2fc1..a057466 100644 --- a/experimental/rocm/nop_executable_cache.h +++ b/experimental/rocm/nop_executable_cache.h
@@ -19,8 +19,8 @@ // This is useful to isolate pipeline caching behavior and verify compilation // behavior. iree_status_t iree_hal_rocm_nop_executable_cache_create( - iree_hal_rocm_context_wrapper_t *context, iree_string_view_t identifier, - iree_hal_executable_cache_t **out_executable_cache); + iree_hal_rocm_context_wrapper_t* context, iree_string_view_t identifier, + iree_hal_executable_cache_t** out_executable_cache); #ifdef __cplusplus } // extern "C"
diff --git a/experimental/rocm/rocm_allocator.c b/experimental/rocm/rocm_allocator.c index cccdc9f..4f9742f 100644 --- a/experimental/rocm/rocm_allocator.c +++ b/experimental/rocm/rocm_allocator.c
@@ -16,30 +16,30 @@ typedef struct iree_hal_rocm_allocator_t { iree_hal_resource_t resource; - iree_hal_rocm_context_wrapper_t *context; + iree_hal_rocm_context_wrapper_t* context; } iree_hal_rocm_allocator_t; extern const iree_hal_allocator_vtable_t iree_hal_rocm_allocator_vtable; -static iree_hal_rocm_allocator_t *iree_hal_rocm_allocator_cast( - iree_hal_allocator_t *base_value) { +static iree_hal_rocm_allocator_t* iree_hal_rocm_allocator_cast( + iree_hal_allocator_t* base_value) { IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_rocm_allocator_vtable); - return (iree_hal_rocm_allocator_t *)base_value; + return (iree_hal_rocm_allocator_t*)base_value; } iree_status_t iree_hal_rocm_allocator_create( - iree_hal_rocm_context_wrapper_t *context, - iree_hal_allocator_t **out_allocator) { + iree_hal_rocm_context_wrapper_t* context, + iree_hal_allocator_t** out_allocator) { IREE_ASSERT_ARGUMENT(context); IREE_TRACE_ZONE_BEGIN(z0); - iree_hal_rocm_allocator_t *allocator = NULL; + iree_hal_rocm_allocator_t* allocator = NULL; iree_status_t status = iree_allocator_malloc( - context->host_allocator, sizeof(*allocator), (void **)&allocator); + context->host_allocator, sizeof(*allocator), (void**)&allocator); if (iree_status_is_ok(status)) { iree_hal_resource_initialize(&iree_hal_rocm_allocator_vtable, &allocator->resource); allocator->context = context; - *out_allocator = (iree_hal_allocator_t *)allocator; + *out_allocator = (iree_hal_allocator_t*)allocator; } IREE_TRACE_ZONE_END(z0); @@ -47,8 +47,8 @@ } static void iree_hal_rocm_allocator_destroy( - iree_hal_allocator_t *base_allocator) { - iree_hal_rocm_allocator_t *allocator = + iree_hal_allocator_t* base_allocator) { + iree_hal_rocm_allocator_t* allocator = iree_hal_rocm_allocator_cast(base_allocator); iree_allocator_t host_allocator = allocator->context->host_allocator; IREE_TRACE_ZONE_BEGIN(z0); @@ -59,15 +59,15 @@ } static iree_allocator_t iree_hal_rocm_allocator_host_allocator( - const iree_hal_allocator_t *base_allocator) { - iree_hal_rocm_allocator_t *allocator = - (iree_hal_rocm_allocator_t *)base_allocator; + const iree_hal_allocator_t* base_allocator) { + iree_hal_rocm_allocator_t* allocator = + (iree_hal_rocm_allocator_t*)base_allocator; return allocator->context->host_allocator; } static iree_hal_buffer_compatibility_t iree_hal_rocm_allocator_query_buffer_compatibility( - iree_hal_allocator_t *base_allocator, iree_hal_memory_type_t memory_type, + iree_hal_allocator_t* base_allocator, iree_hal_memory_type_t memory_type, iree_hal_buffer_usage_t allowed_usage, iree_hal_buffer_usage_t intended_usage, iree_device_size_t allocation_size) { @@ -96,19 +96,31 @@ } static iree_status_t iree_hal_rocm_allocator_allocate_buffer( - iree_hal_allocator_t *base_allocator, iree_hal_memory_type_t memory_type, + iree_hal_allocator_t* base_allocator, iree_hal_memory_type_t memory_type, iree_hal_buffer_usage_t allowed_usage, iree_host_size_t allocation_size, - iree_hal_buffer_t **out_buffer) { - iree_hal_rocm_allocator_t *allocator = + iree_hal_buffer_t** out_buffer) { + iree_hal_rocm_allocator_t* allocator = iree_hal_rocm_allocator_cast(base_allocator); // Guard against the corner case where the requested buffer size is 0. The // application is unlikely to do anything when requesting a 0-byte buffer, but // it can happen in real-world use cases. So we should at least not crash. 
if (allocation_size == 0) allocation_size = 4; iree_status_t status; - void *host_ptr = NULL; + void* host_ptr = NULL; hipDeviceptr_t device_ptr = 0; - if (iree_all_bits_set(memory_type, IREE_HAL_MEMORY_TYPE_HOST_VISIBLE)) { + if (iree_all_bits_set(memory_type, IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL)) { + // Device local case. + if (iree_all_bits_set(memory_type, IREE_HAL_MEMORY_TYPE_HOST_VISIBLE)) { + status = ROCM_RESULT_TO_STATUS( + allocator->context->syms, + hipMallocManaged(&device_ptr, allocation_size, hipMemAttachGlobal)); + host_ptr = (void*)device_ptr; + } else { + // Device only. + status = ROCM_RESULT_TO_STATUS(allocator->context->syms, + hipMalloc(&device_ptr, allocation_size)); + } + } else { unsigned int flags = hipHostMallocMapped; if (!iree_all_bits_set(memory_type, IREE_HAL_MEMORY_TYPE_HOST_CACHED)) { flags |= hipHostMallocWriteCombined; @@ -121,14 +133,11 @@ allocator->context->syms, hipHostGetDevicePointer(&device_ptr, host_ptr, /*flags=*/0)); } - } else { - status = ROCM_RESULT_TO_STATUS(allocator->context->syms, - hipMalloc(&device_ptr, allocation_size)); } if (iree_status_is_ok(status)) { status = iree_hal_rocm_buffer_wrap( - (iree_hal_allocator_t *)allocator, memory_type, + (iree_hal_allocator_t*)allocator, memory_type, IREE_HAL_MEMORY_ACCESS_ALL, allowed_usage, allocation_size, /*byte_offset=*/0, /*byte_length=*/allocation_size, device_ptr, host_ptr, out_buffer); @@ -140,23 +149,24 @@ return status; } -void iree_hal_rocm_allocator_free(iree_hal_allocator_t *base_allocator, - hipDeviceptr_t device_ptr, void *host_ptr, +void iree_hal_rocm_allocator_free(iree_hal_allocator_t* base_allocator, + hipDeviceptr_t device_ptr, void* host_ptr, iree_hal_memory_type_t memory_type) { - iree_hal_rocm_allocator_t *allocator = + iree_hal_rocm_allocator_t* allocator = iree_hal_rocm_allocator_cast(base_allocator); - if (iree_all_bits_set(memory_type, IREE_HAL_MEMORY_TYPE_HOST_VISIBLE)) { - ROCM_IGNORE_ERROR(allocator->context->syms, hipHostFree(host_ptr)); - } else { + if (iree_all_bits_set(memory_type, IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL)) { ROCM_IGNORE_ERROR(allocator->context->syms, hipFree(device_ptr)); + } else { + // Host local. + ROCM_IGNORE_ERROR(allocator->context->syms, hipHostFree(host_ptr)); } } static iree_status_t iree_hal_rocm_allocator_wrap_buffer( - iree_hal_allocator_t *base_allocator, iree_hal_memory_type_t memory_type, + iree_hal_allocator_t* base_allocator, iree_hal_memory_type_t memory_type, iree_hal_memory_access_t allowed_access, iree_hal_buffer_usage_t allowed_usage, iree_byte_span_t data, - iree_allocator_t data_allocator, iree_hal_buffer_t **out_buffer) { + iree_allocator_t data_allocator, iree_hal_buffer_t** out_buffer) { return iree_make_status(IREE_STATUS_UNAVAILABLE, "wrapping of external buffers not supported"); }
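The allocation change above reorders the branch logic so that IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL is tested first: device-local plus host-visible memory now maps to hipMallocManaged, device-only memory to hipMalloc, and everything else falls through to the host-visible hipHostMalloc path. Below is a minimal standalone sketch of that decision table; the enum values and helper names are hypothetical stand-ins for illustration, not the real IREE definitions.

#include <stdbool.h>
#include <stdio.h>

/* Hypothetical stand-ins for the IREE HAL memory type bits. */
enum {
  MEMORY_TYPE_DEVICE_LOCAL = 1 << 0,
  MEMORY_TYPE_HOST_VISIBLE = 1 << 1,
};

static bool all_bits_set(int value, int bits) {
  return (value & bits) == bits;
}

/* Mirrors the branch order in iree_hal_rocm_allocator_allocate_buffer after
 * this change: device-local is tested first, then host visibility selects
 * managed memory; anything else takes the host-local path. */
static const char* select_hip_allocation(int memory_type) {
  if (all_bits_set(memory_type, MEMORY_TYPE_DEVICE_LOCAL)) {
    if (all_bits_set(memory_type, MEMORY_TYPE_HOST_VISIBLE)) {
      return "hipMallocManaged";  /* device-local + host-visible */
    }
    return "hipMalloc";  /* device-only */
  }
  return "hipHostMalloc";  /* host-local, optionally write-combined */
}

int main(void) {
  printf("%s\n", select_hip_allocation(MEMORY_TYPE_DEVICE_LOCAL));
  printf("%s\n", select_hip_allocation(MEMORY_TYPE_DEVICE_LOCAL |
                                       MEMORY_TYPE_HOST_VISIBLE));
  printf("%s\n", select_hip_allocation(MEMORY_TYPE_HOST_VISIBLE));
  return 0;
}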
diff --git a/experimental/rocm/rocm_allocator.h b/experimental/rocm/rocm_allocator.h index a47480c..37d6710 100644 --- a/experimental/rocm/rocm_allocator.h +++ b/experimental/rocm/rocm_allocator.h
@@ -18,12 +18,12 @@ // Creates a ROCM allocator. iree_status_t iree_hal_rocm_allocator_create( - iree_hal_rocm_context_wrapper_t *context, - iree_hal_allocator_t **out_allocator); + iree_hal_rocm_context_wrapper_t* context, + iree_hal_allocator_t** out_allocator); // Frees an allocation represented by the given device or host pointer. -void iree_hal_rocm_allocator_free(iree_hal_allocator_t *allocator, - hipDeviceptr_t device_ptr, void *host_ptr, +void iree_hal_rocm_allocator_free(iree_hal_allocator_t* allocator, + hipDeviceptr_t device_ptr, void* host_ptr, iree_hal_memory_type_t memory_type); #ifdef __cplusplus
diff --git a/experimental/rocm/rocm_buffer.c b/experimental/rocm/rocm_buffer.c index 195e897..8822cb5 100644 --- a/experimental/rocm/rocm_buffer.c +++ b/experimental/rocm/rocm_buffer.c
@@ -16,32 +16,32 @@ typedef struct iree_hal_rocm_buffer_t { iree_hal_buffer_t base; - void *host_ptr; + void* host_ptr; hipDeviceptr_t device_ptr; } iree_hal_rocm_buffer_t; extern const iree_hal_buffer_vtable_t iree_hal_rocm_buffer_vtable; -static iree_hal_rocm_buffer_t *iree_hal_rocm_buffer_cast( - iree_hal_buffer_t *base_value) { +static iree_hal_rocm_buffer_t* iree_hal_rocm_buffer_cast( + iree_hal_buffer_t* base_value) { IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_rocm_buffer_vtable); - return (iree_hal_rocm_buffer_t *)base_value; + return (iree_hal_rocm_buffer_t*)base_value; } iree_status_t iree_hal_rocm_buffer_wrap( - iree_hal_allocator_t *allocator, iree_hal_memory_type_t memory_type, + iree_hal_allocator_t* allocator, iree_hal_memory_type_t memory_type, iree_hal_memory_access_t allowed_access, iree_hal_buffer_usage_t allowed_usage, iree_device_size_t allocation_size, iree_device_size_t byte_offset, iree_device_size_t byte_length, - hipDeviceptr_t device_ptr, void *host_ptr, iree_hal_buffer_t **out_buffer) { + hipDeviceptr_t device_ptr, void* host_ptr, iree_hal_buffer_t** out_buffer) { IREE_ASSERT_ARGUMENT(allocator); IREE_ASSERT_ARGUMENT(out_buffer); IREE_TRACE_ZONE_BEGIN(z0); - iree_hal_rocm_buffer_t *buffer = NULL; + iree_hal_rocm_buffer_t* buffer = NULL; iree_status_t status = iree_allocator_malloc(iree_hal_allocator_host_allocator(allocator), - sizeof(*buffer), (void **)&buffer); + sizeof(*buffer), (void**)&buffer); if (iree_status_is_ok(status)) { iree_hal_resource_initialize(&iree_hal_rocm_buffer_vtable, &buffer->base.resource); @@ -59,11 +59,11 @@ } IREE_TRACE_ZONE_END(z0); - return iree_ok_status(); + return status; } -static void iree_hal_rocm_buffer_destroy(iree_hal_buffer_t *base_buffer) { - iree_hal_rocm_buffer_t *buffer = iree_hal_rocm_buffer_cast(base_buffer); +static void iree_hal_rocm_buffer_destroy(iree_hal_buffer_t* base_buffer) { + iree_hal_rocm_buffer_t* buffer = iree_hal_rocm_buffer_cast(base_buffer); iree_allocator_t host_allocator = iree_hal_allocator_host_allocator(iree_hal_buffer_allocator(base_buffer)); IREE_TRACE_ZONE_BEGIN(z0); @@ -76,11 +76,11 @@ } static iree_status_t iree_hal_rocm_buffer_map_range( - iree_hal_buffer_t *base_buffer, iree_hal_mapping_mode_t mapping_mode, + iree_hal_buffer_t* base_buffer, iree_hal_mapping_mode_t mapping_mode, iree_hal_memory_access_t memory_access, iree_device_size_t local_byte_offset, iree_device_size_t local_byte_length, - void **out_data_ptr) { - iree_hal_rocm_buffer_t *buffer = iree_hal_rocm_buffer_cast(base_buffer); + void** out_data_ptr) { + iree_hal_rocm_buffer_t* buffer = iree_hal_rocm_buffer_cast(base_buffer); if (!iree_all_bits_set(buffer->base.memory_type, IREE_HAL_MEMORY_TYPE_HOST_VISIBLE)) { @@ -88,14 +88,14 @@ "trying to map memory not host visible"); } - uint8_t *data_ptr = (uint8_t *)(buffer->host_ptr) + local_byte_offset; + uint8_t* data_ptr = (uint8_t*)(buffer->host_ptr) + local_byte_offset; // If we mapped for discard scribble over the bytes. This is not a mandated // behavior but it will make debugging issues easier. Alternatively for // heap buffers we could reallocate them such that ASAN yells, but that // would only work if the entire buffer was discarded. 
#ifndef NDEBUG if (iree_any_bit_set(memory_access, IREE_HAL_MEMORY_ACCESS_DISCARD)) { - memset(data_ptr + local_byte_offset, 0xCD, local_byte_length); + memset(data_ptr, 0xCD, local_byte_length); } #endif // !NDEBUG *out_data_ptr = data_ptr; @@ -103,28 +103,28 @@ } static void iree_hal_rocm_buffer_unmap_range( - iree_hal_buffer_t *base_buffer, iree_device_size_t local_byte_offset, - iree_device_size_t local_byte_length, void *data_ptr) { + iree_hal_buffer_t* base_buffer, iree_device_size_t local_byte_offset, + iree_device_size_t local_byte_length, void* data_ptr) { // nothing to do. } static iree_status_t iree_hal_rocm_buffer_invalidate_range( - iree_hal_buffer_t *base_buffer, iree_device_size_t local_byte_offset, + iree_hal_buffer_t* base_buffer, iree_device_size_t local_byte_offset, iree_device_size_t local_byte_length) { // Nothing to do. return iree_ok_status(); } static iree_status_t iree_hal_rocm_buffer_flush_range( - iree_hal_buffer_t *base_buffer, iree_device_size_t local_byte_offset, + iree_hal_buffer_t* base_buffer, iree_device_size_t local_byte_offset, iree_device_size_t local_byte_length) { // Nothing to do. return iree_ok_status(); } hipDeviceptr_t iree_hal_rocm_buffer_device_pointer( - iree_hal_buffer_t *base_buffer) { - iree_hal_rocm_buffer_t *buffer = iree_hal_rocm_buffer_cast(base_buffer); + iree_hal_buffer_t* base_buffer) { + iree_hal_rocm_buffer_t* buffer = iree_hal_rocm_buffer_cast(base_buffer); return buffer->device_ptr; }
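The map_range hunk above also fixes a double-offset bug: data_ptr has already been advanced by local_byte_offset, so offsetting it again inside the memset scribbled past the mapped range. A toy sketch of the difference, with arbitrary buffer sizes:

#include <stdint.h>
#include <stdio.h>
#include <string.h>

int main(void) {
  uint8_t storage[16] = {0};
  size_t local_byte_offset = 4;
  size_t local_byte_length = 8;
  /* data_ptr already includes the offset, exactly as in map_range. */
  uint8_t* data_ptr = storage + local_byte_offset;
  /* Correct (post-fix): scribbles bytes [4, 12) of the storage. */
  memset(data_ptr, 0xCD, local_byte_length);
  /* The pre-fix code offset a second time:
   *   memset(data_ptr + local_byte_offset, 0xCD, local_byte_length);
   * which would have scribbled bytes [8, 16) instead. */
  for (size_t i = 0; i < sizeof(storage); ++i) printf("%02X ", storage[i]);
  printf("\n");
  return 0;
}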
diff --git a/experimental/rocm/rocm_buffer.h b/experimental/rocm/rocm_buffer.h index e898630..85cf294 100644 --- a/experimental/rocm/rocm_buffer.h +++ b/experimental/rocm/rocm_buffer.h
@@ -17,16 +17,16 @@ // Wraps a rocm allocation in an iree_hal_buffer_t. iree_status_t iree_hal_rocm_buffer_wrap( - iree_hal_allocator_t *allocator, iree_hal_memory_type_t memory_type, + iree_hal_allocator_t* allocator, iree_hal_memory_type_t memory_type, iree_hal_memory_access_t allowed_access, iree_hal_buffer_usage_t allowed_usage, iree_device_size_t allocation_size, iree_device_size_t byte_offset, iree_device_size_t byte_length, - hipDeviceptr_t device_ptr, void *host_ptr, iree_hal_buffer_t **out_buffer); + hipDeviceptr_t device_ptr, void* host_ptr, iree_hal_buffer_t** out_buffer); // Returns the rocm base pointer for the given |buffer|. // This is the entire allocated_buffer and must be offset by the buffer // byte_offset and byte_length when used. -hipDeviceptr_t iree_hal_rocm_buffer_device_pointer(iree_hal_buffer_t *buffer); +hipDeviceptr_t iree_hal_rocm_buffer_device_pointer(iree_hal_buffer_t* buffer); #ifdef __cplusplus } // extern "C"
diff --git a/experimental/rocm/rocm_device.c b/experimental/rocm/rocm_device.c index cff4d89..fa5a75e 100644 --- a/experimental/rocm/rocm_device.c +++ b/experimental/rocm/rocm_device.c
@@ -32,27 +32,27 @@ // Optional driver that owns the ROCM symbols. We retain it for our lifetime // to ensure the symbols remain valid. - iree_hal_driver_t *driver; + iree_hal_driver_t* driver; hipDevice_t device; // TODO: support multiple streams. hipStream_t stream; iree_hal_rocm_context_wrapper_t context_wrapper; - iree_hal_allocator_t *device_allocator; + iree_hal_allocator_t* device_allocator; } iree_hal_rocm_device_t; extern const iree_hal_device_vtable_t iree_hal_rocm_device_vtable; -static iree_hal_rocm_device_t *iree_hal_rocm_device_cast( - iree_hal_device_t *base_value) { +static iree_hal_rocm_device_t* iree_hal_rocm_device_cast( + iree_hal_device_t* base_value) { IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_rocm_device_vtable); - return (iree_hal_rocm_device_t *)base_value; + return (iree_hal_rocm_device_t*)base_value; } -static void iree_hal_rocm_device_destroy(iree_hal_device_t *base_device) { - iree_hal_rocm_device_t *device = iree_hal_rocm_device_cast(base_device); +static void iree_hal_rocm_device_destroy(iree_hal_device_t* base_device) { + iree_hal_rocm_device_t* device = iree_hal_rocm_device_cast(base_device); iree_allocator_t host_allocator = iree_hal_device_host_allocator(base_device); IREE_TRACE_ZONE_BEGIN(z0); @@ -70,21 +70,21 @@ } static iree_status_t iree_hal_rocm_device_create_internal( - iree_hal_driver_t *driver, iree_string_view_t identifier, + iree_hal_driver_t* driver, iree_string_view_t identifier, hipDevice_t rocm_device, hipStream_t stream, hipCtx_t context, - iree_hal_rocm_dynamic_symbols_t *syms, iree_allocator_t host_allocator, - iree_hal_device_t **out_device) { - iree_hal_rocm_device_t *device = NULL; + iree_hal_rocm_dynamic_symbols_t* syms, iree_allocator_t host_allocator, + iree_hal_device_t** out_device) { + iree_hal_rocm_device_t* device = NULL; iree_host_size_t total_size = sizeof(*device) + identifier.size; IREE_RETURN_IF_ERROR( - iree_allocator_malloc(host_allocator, total_size, (void **)&device)); + iree_allocator_malloc(host_allocator, total_size, (void**)&device)); memset(device, 0, total_size); iree_hal_resource_initialize(&iree_hal_rocm_device_vtable, &device->resource); device->driver = driver; iree_hal_driver_retain(device->driver); - uint8_t *buffer_ptr = (uint8_t *)device + sizeof(*device); + uint8_t* buffer_ptr = (uint8_t*)device + sizeof(*device); buffer_ptr += iree_string_view_append_to_buffer( - identifier, &device->identifier, (char *)buffer_ptr); + identifier, &device->identifier, (char*)buffer_ptr); device->device = rocm_device; device->stream = stream; device->context_wrapper.rocm_context = context; @@ -93,19 +93,19 @@ iree_status_t status = iree_hal_rocm_allocator_create( &device->context_wrapper, &device->device_allocator); if (iree_status_is_ok(status)) { - *out_device = (iree_hal_device_t *)device; + *out_device = (iree_hal_device_t*)device; } else { - iree_hal_device_release((iree_hal_device_t *)device); + iree_hal_device_release((iree_hal_device_t*)device); } return status; } -iree_status_t iree_hal_rocm_device_create(iree_hal_driver_t *driver, +iree_status_t iree_hal_rocm_device_create(iree_hal_driver_t* driver, iree_string_view_t identifier, - iree_hal_rocm_dynamic_symbols_t *syms, + iree_hal_rocm_dynamic_symbols_t* syms, hipDevice_t device, iree_allocator_t host_allocator, - iree_hal_device_t **out_device) { + iree_hal_device_t** out_device) { IREE_TRACE_ZONE_BEGIN(z0); hipCtx_t context; IREE_RETURN_AND_END_ZONE_IF_ERROR( @@ -130,26 +130,26 @@ } static iree_string_view_t iree_hal_rocm_device_id( - iree_hal_device_t
*base_device) { - iree_hal_rocm_device_t *device = iree_hal_rocm_device_cast(base_device); + iree_hal_device_t* base_device) { + iree_hal_rocm_device_t* device = iree_hal_rocm_device_cast(base_device); return device->identifier; } static iree_allocator_t iree_hal_rocm_device_host_allocator( - iree_hal_device_t *base_device) { - iree_hal_rocm_device_t *device = iree_hal_rocm_device_cast(base_device); + iree_hal_device_t* base_device) { + iree_hal_rocm_device_t* device = iree_hal_rocm_device_cast(base_device); return device->context_wrapper.host_allocator; } -static iree_hal_allocator_t *iree_hal_rocm_device_allocator( - iree_hal_device_t *base_device) { - iree_hal_rocm_device_t *device = iree_hal_rocm_device_cast(base_device); +static iree_hal_allocator_t* iree_hal_rocm_device_allocator( + iree_hal_device_t* base_device) { + iree_hal_rocm_device_t* device = iree_hal_rocm_device_cast(base_device); return device->device_allocator; } static iree_status_t iree_hal_rocm_device_query_i32( - iree_hal_device_t *base_device, iree_string_view_t category, - iree_string_view_t key, int32_t *out_value) { + iree_hal_device_t* base_device, iree_string_view_t category, + iree_string_view_t key, int32_t* out_value) { // iree_hal_rocm_device_t* device = iree_hal_rocm_device_cast(base_device); *out_value = 0; @@ -169,77 +169,77 @@ } static iree_status_t iree_hal_rocm_device_create_command_buffer( - iree_hal_device_t *base_device, iree_hal_command_buffer_mode_t mode, + iree_hal_device_t* base_device, iree_hal_command_buffer_mode_t mode, iree_hal_command_category_t command_categories, iree_hal_queue_affinity_t queue_affinity, - iree_hal_command_buffer_t **out_command_buffer) { - iree_hal_rocm_device_t *device = iree_hal_rocm_device_cast(base_device); - return iree_hal_rocm_direct_command_buffer_allocate( + iree_hal_command_buffer_t** out_command_buffer) { + iree_hal_rocm_device_t* device = iree_hal_rocm_device_cast(base_device); + return iree_hal_rocm_direct_command_buffer_create( &device->context_wrapper, mode, command_categories, queue_affinity, out_command_buffer); } static iree_status_t iree_hal_rocm_device_create_descriptor_set( - iree_hal_device_t *base_device, - iree_hal_descriptor_set_layout_t *set_layout, + iree_hal_device_t* base_device, + iree_hal_descriptor_set_layout_t* set_layout, iree_host_size_t binding_count, - const iree_hal_descriptor_set_binding_t *bindings, - iree_hal_descriptor_set_t **out_descriptor_set) { + const iree_hal_descriptor_set_binding_t* bindings, + iree_hal_descriptor_set_t** out_descriptor_set) { return iree_make_status(IREE_STATUS_UNIMPLEMENTED, "non-push descriptor sets still need work"); } static iree_status_t iree_hal_rocm_device_create_descriptor_set_layout( - iree_hal_device_t *base_device, + iree_hal_device_t* base_device, iree_hal_descriptor_set_layout_usage_type_t usage_type, iree_host_size_t binding_count, - const iree_hal_descriptor_set_layout_binding_t *bindings, - iree_hal_descriptor_set_layout_t **out_descriptor_set_layout) { - iree_hal_rocm_device_t *device = iree_hal_rocm_device_cast(base_device); + const iree_hal_descriptor_set_layout_binding_t* bindings, + iree_hal_descriptor_set_layout_t** out_descriptor_set_layout) { + iree_hal_rocm_device_t* device = iree_hal_rocm_device_cast(base_device); return iree_hal_rocm_descriptor_set_layout_create( &device->context_wrapper, usage_type, binding_count, bindings, out_descriptor_set_layout); } static iree_status_t iree_hal_rocm_device_create_event( - iree_hal_device_t *base_device, iree_hal_event_t **out_event) { - 
iree_hal_rocm_device_t *device = iree_hal_rocm_device_cast(base_device); + iree_hal_device_t* base_device, iree_hal_event_t** out_event) { + iree_hal_rocm_device_t* device = iree_hal_rocm_device_cast(base_device); return iree_hal_rocm_event_create(&device->context_wrapper, out_event); } static iree_status_t iree_hal_rocm_device_create_executable_cache( - iree_hal_device_t *base_device, iree_string_view_t identifier, - iree_hal_executable_cache_t **out_executable_cache) { - iree_hal_rocm_device_t *device = iree_hal_rocm_device_cast(base_device); + iree_hal_device_t* base_device, iree_string_view_t identifier, + iree_hal_executable_cache_t** out_executable_cache) { + iree_hal_rocm_device_t* device = iree_hal_rocm_device_cast(base_device); return iree_hal_rocm_nop_executable_cache_create( &device->context_wrapper, identifier, out_executable_cache); } static iree_status_t iree_hal_rocm_device_create_executable_layout( - iree_hal_device_t *base_device, iree_host_size_t push_constants, + iree_hal_device_t* base_device, iree_host_size_t push_constants, iree_host_size_t set_layout_count, - iree_hal_descriptor_set_layout_t **set_layouts, - iree_hal_executable_layout_t **out_executable_layout) { - iree_hal_rocm_device_t *device = iree_hal_rocm_device_cast(base_device); + iree_hal_descriptor_set_layout_t** set_layouts, + iree_hal_executable_layout_t** out_executable_layout) { + iree_hal_rocm_device_t* device = iree_hal_rocm_device_cast(base_device); return iree_hal_rocm_executable_layout_create( &device->context_wrapper, set_layout_count, set_layouts, push_constants, out_executable_layout); } static iree_status_t iree_hal_rocm_device_create_semaphore( - iree_hal_device_t *base_device, uint64_t initial_value, - iree_hal_semaphore_t **out_semaphore) { - iree_hal_rocm_device_t *device = iree_hal_rocm_device_cast(base_device); + iree_hal_device_t* base_device, uint64_t initial_value, + iree_hal_semaphore_t** out_semaphore) { + iree_hal_rocm_device_t* device = iree_hal_rocm_device_cast(base_device); return iree_hal_rocm_semaphore_create(&device->context_wrapper, initial_value, out_semaphore); } static iree_status_t iree_hal_rocm_device_queue_submit( - iree_hal_device_t *base_device, + iree_hal_device_t* base_device, iree_hal_command_category_t command_categories, iree_hal_queue_affinity_t queue_affinity, iree_host_size_t batch_count, - const iree_hal_submission_batch_t *batches) { - iree_hal_rocm_device_t *device = iree_hal_rocm_device_cast(base_device); + const iree_hal_submission_batch_t* batches) { + iree_hal_rocm_device_t* device = iree_hal_rocm_device_cast(base_device); // TODO(raikonenfnu): Once semaphores are implemented, wait on them here. // TODO(thomasraoux): Conservatively synchronize after every submit until we // support semaphores. @@ -251,11 +251,11 @@ } static iree_status_t iree_hal_rocm_device_submit_and_wait( - iree_hal_device_t *base_device, + iree_hal_device_t* base_device, iree_hal_command_category_t command_categories, iree_hal_queue_affinity_t queue_affinity, iree_host_size_t batch_count, - const iree_hal_submission_batch_t *batches, - iree_hal_semaphore_t *wait_semaphore, uint64_t wait_value, + const iree_hal_submission_batch_t* batches, + iree_hal_semaphore_t* wait_semaphore, uint64_t wait_value, iree_timeout_t timeout) { // Submit...
IREE_RETURN_IF_ERROR(iree_hal_rocm_device_queue_submit( @@ -266,15 +266,15 @@ } static iree_status_t iree_hal_rocm_device_wait_semaphores( - iree_hal_device_t *base_device, iree_hal_wait_mode_t wait_mode, - const iree_hal_semaphore_list_t *semaphore_list, iree_timeout_t timeout) { + iree_hal_device_t* base_device, iree_hal_wait_mode_t wait_mode, + const iree_hal_semaphore_list_t* semaphore_list, iree_timeout_t timeout) { return iree_make_status(IREE_STATUS_UNIMPLEMENTED, "semaphore not implemented"); } static iree_status_t iree_hal_rocm_device_wait_idle( - iree_hal_device_t *base_device, iree_timeout_t timeout) { - iree_hal_rocm_device_t *device = iree_hal_rocm_device_cast(base_device); + iree_hal_device_t* base_device, iree_timeout_t timeout) { + iree_hal_rocm_device_t* device = iree_hal_rocm_device_cast(base_device); // Wait until the stream is done. // TODO(thomasraoux): HIP doesn't support a deadline for wait, figure out how // to handle it better.
diff --git a/experimental/rocm/rocm_device.h b/experimental/rocm/rocm_device.h index 67e2f9b..083f4c7 100644 --- a/experimental/rocm/rocm_device.h +++ b/experimental/rocm/rocm_device.h
@@ -17,12 +17,12 @@ #endif // __cplusplus // Creates a device that owns and manages its own hipContext. -iree_status_t iree_hal_rocm_device_create(iree_hal_driver_t *driver, +iree_status_t iree_hal_rocm_device_create(iree_hal_driver_t* driver, iree_string_view_t identifier, - iree_hal_rocm_dynamic_symbols_t *syms, + iree_hal_rocm_dynamic_symbols_t* syms, hipDevice_t device, iree_allocator_t host_allocator, - iree_hal_device_t **out_device); + iree_hal_device_t** out_device); #ifdef __cplusplus } // extern "C"
diff --git a/experimental/rocm/rocm_driver.c b/experimental/rocm/rocm_driver.c index 8bd8ae6..808642f 100644 --- a/experimental/rocm/rocm_driver.c +++ b/experimental/rocm/rocm_driver.c
@@ -32,44 +32,44 @@ extern const iree_hal_driver_vtable_t iree_hal_rocm_driver_vtable; -static iree_hal_rocm_driver_t *iree_hal_rocm_driver_cast( - iree_hal_driver_t *base_value) { +static iree_hal_rocm_driver_t* iree_hal_rocm_driver_cast( + iree_hal_driver_t* base_value) { IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_rocm_driver_vtable); - return (iree_hal_rocm_driver_t *)base_value; + return (iree_hal_rocm_driver_t*)base_value; } IREE_API_EXPORT void iree_hal_rocm_driver_options_initialize( - iree_hal_rocm_driver_options_t *out_options) { + iree_hal_rocm_driver_options_t* out_options) { memset(out_options, 0, sizeof(*out_options)); out_options->default_device_index = 0; } static iree_status_t iree_hal_rocm_driver_create_internal( iree_string_view_t identifier, - const iree_hal_rocm_driver_options_t *options, - iree_allocator_t host_allocator, iree_hal_driver_t **out_driver) { - iree_hal_rocm_driver_t *driver = NULL; + const iree_hal_rocm_driver_options_t* options, + iree_allocator_t host_allocator, iree_hal_driver_t** out_driver) { + iree_hal_rocm_driver_t* driver = NULL; iree_host_size_t total_size = sizeof(*driver) + identifier.size; IREE_RETURN_IF_ERROR( - iree_allocator_malloc(host_allocator, total_size, (void **)&driver)); + iree_allocator_malloc(host_allocator, total_size, (void**)&driver)); iree_hal_resource_initialize(&iree_hal_rocm_driver_vtable, &driver->resource); driver->host_allocator = host_allocator; iree_string_view_append_to_buffer( identifier, &driver->identifier, - (char *)driver + total_size - identifier.size); + (char*)driver + total_size - identifier.size); driver->default_device_index = options->default_device_index; iree_status_t status = iree_hal_rocm_dynamic_symbols_initialize(host_allocator, &driver->syms); if (iree_status_is_ok(status)) { - *out_driver = (iree_hal_driver_t *)driver; + *out_driver = (iree_hal_driver_t*)driver; } else { - iree_hal_driver_release((iree_hal_driver_t *)driver); + iree_hal_driver_release((iree_hal_driver_t*)driver); } return status; } -static void iree_hal_rocm_driver_destroy(iree_hal_driver_t *base_driver) { - iree_hal_rocm_driver_t *driver = iree_hal_rocm_driver_cast(base_driver); +static void iree_hal_rocm_driver_destroy(iree_hal_driver_t* base_driver) { + iree_hal_rocm_driver_t* driver = iree_hal_rocm_driver_cast(base_driver); iree_allocator_t host_allocator = driver->host_allocator; IREE_TRACE_ZONE_BEGIN(z0); @@ -81,8 +81,8 @@ IREE_API_EXPORT iree_status_t iree_hal_rocm_driver_create( iree_string_view_t identifier, - const iree_hal_rocm_driver_options_t *options, - iree_allocator_t host_allocator, iree_hal_driver_t **out_driver) { + const iree_hal_rocm_driver_options_t* options, + iree_allocator_t host_allocator, iree_hal_driver_t** out_driver) { IREE_ASSERT_ARGUMENT(options); IREE_ASSERT_ARGUMENT(out_driver); IREE_TRACE_ZONE_BEGIN(z0); @@ -97,9 +97,9 @@ // Populates device information from the given ROCM physical device handle. // |out_device_info| must point to valid memory and additional data will be // appended to |buffer_ptr| and the new pointer is returned. 
-static uint8_t *iree_hal_rocm_populate_device_info( - hipDevice_t device, iree_hal_rocm_dynamic_symbols_t *syms, - uint8_t *buffer_ptr, iree_hal_device_info_t *out_device_info) { +static uint8_t* iree_hal_rocm_populate_device_info( + hipDevice_t device, iree_hal_rocm_dynamic_symbols_t* syms, + uint8_t* buffer_ptr, iree_hal_device_info_t* out_device_info) { char device_name[IREE_MAX_ROCM_DEVICE_NAME_LENGTH]; ROCM_IGNORE_ERROR(syms, hipDeviceGetName(device_name, sizeof(device_name), device)); @@ -109,31 +109,31 @@ iree_string_view_t device_name_string = iree_make_string_view(device_name, strlen(device_name)); buffer_ptr += iree_string_view_append_to_buffer( - device_name_string, &out_device_info->name, (char *)buffer_ptr); + device_name_string, &out_device_info->name, (char*)buffer_ptr); return buffer_ptr; } static iree_status_t iree_hal_rocm_driver_query_available_devices( - iree_hal_driver_t *base_driver, iree_allocator_t host_allocator, - iree_hal_device_info_t **out_device_infos, - iree_host_size_t *out_device_info_count) { - iree_hal_rocm_driver_t *driver = iree_hal_rocm_driver_cast(base_driver); + iree_hal_driver_t* base_driver, iree_allocator_t host_allocator, + iree_hal_device_info_t** out_device_infos, + iree_host_size_t* out_device_info_count) { + iree_hal_rocm_driver_t* driver = iree_hal_rocm_driver_cast(base_driver); // Query the number of available ROCM devices. int device_count = 0; ROCM_RETURN_IF_ERROR(&driver->syms, hipGetDeviceCount(&device_count), "hipGetDeviceCount"); // Allocate the return infos and populate with the devices. - iree_hal_device_info_t *device_infos = NULL; + iree_hal_device_info_t* device_infos = NULL; iree_host_size_t total_size = device_count * sizeof(iree_hal_device_info_t); for (iree_host_size_t i = 0; i < device_count; ++i) { total_size += IREE_MAX_ROCM_DEVICE_NAME_LENGTH * sizeof(char); } iree_status_t status = - iree_allocator_malloc(host_allocator, total_size, (void **)&device_infos); + iree_allocator_malloc(host_allocator, total_size, (void**)&device_infos); if (iree_status_is_ok(status)) { - uint8_t *buffer_ptr = - (uint8_t *)device_infos + device_count * sizeof(iree_hal_device_info_t); + uint8_t* buffer_ptr = + (uint8_t*)device_infos + device_count * sizeof(iree_hal_device_info_t); for (iree_host_size_t i = 0; i < device_count; ++i) { hipDevice_t device; iree_status_t status = ROCM_RESULT_TO_STATUS( @@ -153,8 +153,8 @@ } static iree_status_t iree_hal_rocm_driver_select_default_device( - iree_hal_rocm_dynamic_symbols_t *syms, int default_device_index, - iree_allocator_t host_allocator, hipDevice_t *out_device) { + iree_hal_rocm_dynamic_symbols_t* syms, int default_device_index, + iree_allocator_t host_allocator, hipDevice_t* out_device) { int device_count = 0; ROCM_RETURN_IF_ERROR(syms, hipGetDeviceCount(&device_count), "hipGetDeviceCount"); @@ -173,9 +173,9 @@ } static iree_status_t iree_hal_rocm_driver_create_device( - iree_hal_driver_t *base_driver, iree_hal_device_id_t device_id, - iree_allocator_t host_allocator, iree_hal_device_t **out_device) { - iree_hal_rocm_driver_t *driver = iree_hal_rocm_driver_cast(base_driver); + iree_hal_driver_t* base_driver, iree_hal_device_id_t device_id, + iree_allocator_t host_allocator, iree_hal_device_t** out_device) { + iree_hal_rocm_driver_t* driver = iree_hal_rocm_driver_cast(base_driver); IREE_TRACE_ZONE_BEGIN(z0); IREE_RETURN_AND_END_ZONE_IF_ERROR(
diff --git a/experimental/rocm/rocm_event.c b/experimental/rocm/rocm_event.c index 48b20c3..945a5f7 100644 --- a/experimental/rocm/rocm_event.c +++ b/experimental/rocm/rocm_event.c
@@ -14,40 +14,40 @@ // Dummy events for now, don't do anything. typedef struct iree_hal_rocm_event_t { iree_hal_resource_t resource; - iree_hal_rocm_context_wrapper_t *context_wrapper; + iree_hal_rocm_context_wrapper_t* context_wrapper; } iree_hal_rocm_event_t; extern const iree_hal_event_vtable_t iree_hal_rocm_event_vtable; -static iree_hal_rocm_event_t *iree_hal_rocm_event_cast( - iree_hal_event_t *base_value) { +static iree_hal_rocm_event_t* iree_hal_rocm_event_cast( + iree_hal_event_t* base_value) { IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_rocm_event_vtable); - return (iree_hal_rocm_event_t *)base_value; + return (iree_hal_rocm_event_t*)base_value; } iree_status_t iree_hal_rocm_event_create( - iree_hal_rocm_context_wrapper_t *context_wrapper, - iree_hal_event_t **out_event) { + iree_hal_rocm_context_wrapper_t* context_wrapper, + iree_hal_event_t** out_event) { IREE_ASSERT_ARGUMENT(context_wrapper); IREE_ASSERT_ARGUMENT(out_event); *out_event = NULL; IREE_TRACE_ZONE_BEGIN(z0); - iree_hal_rocm_event_t *event = NULL; + iree_hal_rocm_event_t* event = NULL; iree_status_t status = iree_allocator_malloc(context_wrapper->host_allocator, - sizeof(*event), (void **)&event); + sizeof(*event), (void**)&event); if (iree_status_is_ok(status)) { iree_hal_resource_initialize(&iree_hal_rocm_event_vtable, &event->resource); event->context_wrapper = context_wrapper; - *out_event = (iree_hal_event_t *)event; + *out_event = (iree_hal_event_t*)event; } IREE_TRACE_ZONE_END(z0); return status; } -static void iree_hal_rocm_event_destroy(iree_hal_event_t *base_event) { - iree_hal_rocm_event_t *event = iree_hal_rocm_event_cast(base_event); +static void iree_hal_rocm_event_destroy(iree_hal_event_t* base_event) { + iree_hal_rocm_event_t* event = iree_hal_rocm_event_cast(base_event); iree_allocator_t host_allocator = event->context_wrapper->host_allocator; IREE_TRACE_ZONE_BEGIN(z0);
diff --git a/experimental/rocm/rocm_event.h b/experimental/rocm/rocm_event.h index 73041a4..0bac1a2 100644 --- a/experimental/rocm/rocm_event.h +++ b/experimental/rocm/rocm_event.h
@@ -21,8 +21,8 @@ // command buffer we will add the appropriate edges to enforce the right // synchronization. iree_status_t iree_hal_rocm_event_create( - iree_hal_rocm_context_wrapper_t *context_wrapper, - iree_hal_event_t **out_event); + iree_hal_rocm_context_wrapper_t* context_wrapper, + iree_hal_event_t** out_event); #ifdef __cplusplus } // extern "C"
diff --git a/experimental/rocm/status_util.h b/experimental/rocm/status_util.h index b459591..0f6fcc5 100644 --- a/experimental/rocm/status_util.h +++ b/experimental/rocm/status_util.h
@@ -44,7 +44,7 @@ // Converts a hipError_t to a Status object. iree_status_t iree_hal_rocm_result_to_status( - iree_hal_rocm_dynamic_symbols_t *syms, hipError_t result, const char *file, + iree_hal_rocm_dynamic_symbols_t* syms, hipError_t result, const char* file, uint32_t line); #ifdef __cplusplus
diff --git a/integrations/tensorflow/iree_tf_compiler/iree-import-tf-main.cpp b/integrations/tensorflow/iree_tf_compiler/iree-import-tf-main.cpp index feec3ff..ee95f1e 100644 --- a/integrations/tensorflow/iree_tf_compiler/iree-import-tf-main.cpp +++ b/integrations/tensorflow/iree_tf_compiler/iree-import-tf-main.cpp
@@ -224,8 +224,8 @@ iree_integrations::TF::buildTFImportPassPipeline(pm); if (failed(pm.run(*module))) { - llvm::errs() - << "Running iree-import-tf pass pipeline failed (see diagnostics)\n"; + llvm::errs() << "Running iree-import-tf TF import pass pipeline failed " "(see diagnostics)\n"; return 2; } if (!saveTempMidLevelImport.empty()) { @@ -237,8 +237,8 @@ applyPassManagerCLOptions(pm); iree_integrations::MHLO::buildMHLOImportPassPipeline(pm); if (failed(pm.run(*module))) { - llvm::errs() - << "Running iree-import-tf pass pipeline failed (see diagnostics)\n"; + llvm::errs() << "Running iree-import-tf MHLO import pass pipeline failed " "(see diagnostics)\n"; return 2; } }
diff --git a/integrations/tensorflow/iree_tf_compiler/iree-import-xla-main.cpp b/integrations/tensorflow/iree_tf_compiler/iree-import-xla-main.cpp index 2e11d7f..08f270a 100644 --- a/integrations/tensorflow/iree_tf_compiler/iree-import-xla-main.cpp +++ b/integrations/tensorflow/iree_tf_compiler/iree-import-xla-main.cpp
@@ -100,6 +100,10 @@ static cl::opt<std::string> outputFilename("o", cl::desc("Output filename"), cl::value_desc("filename"), cl::init("-")); + static llvm::cl::opt<std::string> saveTempMhloInput( + "save-temp-mhlo-input", + llvm::cl::desc("Save the MHLO pipeline input IR to this file"), + llvm::cl::init("")); static llvm::cl::opt<std::string> saveTempIreeImport( "save-temp-iree-input", llvm::cl::desc("Save the resultant IR to this file (useful for saving an " @@ -251,6 +255,11 @@ return success(); }; + // Save the MHLO pipeline input if requested. + if (!saveTempMhloInput.empty()) { + if (failed(saveToFile(saveTempMhloInput))) return 10; + } + // Run passes. PassManager pm(&context, PassManager::Nesting::Implicit); applyPassManagerCLOptions(pm); @@ -264,8 +273,8 @@ iree_integrations::MHLO::createEmitDefaultIREEABIPass()); if (failed(pm.run(*module))) { - llvm::errs() - << "Running iree-xla-import pass pipeline failed (see diagnostics)\n"; + llvm::errs() << "Running iree-xla-import MHLO import pass pipeline failed " "(see diagnostics)\n"; return 2; }
diff --git a/iree/base/alignment.h b/iree/base/alignment.h index 0f5953a..5229d97 100644 --- a/iree/base/alignment.h +++ b/iree/base/alignment.h
@@ -11,6 +11,8 @@ #define IREE_BASE_ALIGNMENT_H_ #include <stddef.h> +#include <stdint.h> +#include <string.h> #include "iree/base/config.h" #include "iree/base/target_platform.h" @@ -55,6 +57,178 @@ return (value + (alignment - 1)) & ~(alignment - 1); } +// Returns the size of a struct padded out to iree_max_align_t. +// This must be used when performing manual trailing allocation packing to +// ensure the alignment requirements of the trailing data are satisfied. +// +// NOTE: do not use this if using VLAs (`struct { int trailing[]; }`) - those +// must precisely follow the normal sizeof(t) as the compiler does the padding +// for you. +// +// Example: +// some_buffer_ptr_t* buffer = NULL; +// iree_host_size_t total_size = iree_sizeof_struct(*buffer) + extra_data_size; +// IREE_CHECK_OK(iree_allocator_malloc(allocator, total_size, (void**)&buffer)); +#define iree_sizeof_struct(t) iree_host_align(sizeof(t), iree_max_align_t) + +//===----------------------------------------------------------------------===// +// Alignment-safe memory accesses +//===----------------------------------------------------------------------===// + +// Map little-endian byte indices in memory to the host memory order indices. +#if defined(IREE_ENDIANNESS_LITTLE) +#define IREE_LE_IDX_1(i) (i) +#define IREE_LE_IDX_2(i) (i) +#define IREE_LE_IDX_4(i) (i) +#define IREE_LE_IDX_8(i) (i) +#else +#define IREE_LE_IDX_1(i) (i) +#define IREE_LE_IDX_2(i) (1 - (i)) +#define IREE_LE_IDX_4(i) (3 - (i)) +#define IREE_LE_IDX_8(i) (7 - (i)) +#endif // IREE_ENDIANNESS_* + +#if IREE_MEMORY_ACCESS_ALIGNMENT_REQUIRED + +static inline uint8_t iree_unaligned_load_le_u8(const uint8_t* ptr) { + return *ptr; +} +static inline uint16_t iree_unaligned_load_le_u16(const uint16_t* ptr) { + const uint8_t* p = (const uint8_t*)ptr; + return ((uint16_t)p[IREE_LE_IDX_2(0)]) | ((uint16_t)p[IREE_LE_IDX_2(1)] << 8); +} +static inline uint32_t iree_unaligned_load_le_u32(const uint32_t* ptr) { + const uint8_t* p = (const uint8_t*)ptr; + return ((uint32_t)p[IREE_LE_IDX_4(0)]) | + ((uint32_t)p[IREE_LE_IDX_4(1)] << 8) | + ((uint32_t)p[IREE_LE_IDX_4(2)] << 16) | + ((uint32_t)p[IREE_LE_IDX_4(3)] << 24); +} +static inline uint64_t iree_unaligned_load_le_u64(const uint64_t* ptr) { + const uint8_t* p = (const uint8_t*)ptr; + return ((uint64_t)p[IREE_LE_IDX_8(0)]) | + ((uint64_t)p[IREE_LE_IDX_8(1)] << 8) | + ((uint64_t)p[IREE_LE_IDX_8(2)] << 16) | + ((uint64_t)p[IREE_LE_IDX_8(3)] << 24) | + ((uint64_t)p[IREE_LE_IDX_8(4)] << 32) | + ((uint64_t)p[IREE_LE_IDX_8(5)] << 40) | + ((uint64_t)p[IREE_LE_IDX_8(6)] << 48) | + ((uint64_t)p[IREE_LE_IDX_8(7)] << 56); +} +static inline float iree_unaligned_load_le_f32(const float* ptr) { + uint32_t uint_value = iree_unaligned_load_le_u32((const uint32_t*)ptr); + float value; + memcpy(&value, &uint_value, sizeof(value)); + return value; +} +static inline double iree_unaligned_load_le_f64(const double* ptr) { + uint64_t uint_value = iree_unaligned_load_le_u64((const uint64_t*)ptr); + double value; + memcpy(&value, &uint_value, sizeof(value)); + return value; +} + +static inline void iree_unaligned_store_le_u8(uint8_t* ptr, uint8_t value) { + *ptr = value; +} +static inline void iree_unaligned_store_le_u16(uint16_t* ptr, uint16_t value) { + uint8_t* p = (uint8_t*)ptr; + p[IREE_LE_IDX_2(0)] = value; + p[IREE_LE_IDX_2(1)] = value >> 8; +} +static inline void iree_unaligned_store_le_u32(uint32_t* ptr, uint32_t value) { + uint8_t* p = (uint8_t*)ptr; + p[IREE_LE_IDX_4(0)] = value; + p[IREE_LE_IDX_4(1)] = value >> 8; + p[IREE_LE_IDX_4(2)] = value
>> 16; + p[IREE_LE_IDX_4(3)] = value >> 24; +} +static inline void iree_unaligned_store_le_u64(uint64_t* ptr, uint64_t value) { + uint8_t* p = (uint8_t*)ptr; + p[IREE_LE_IDX_8(0)] = value; + p[IREE_LE_IDX_8(1)] = value >> 8; + p[IREE_LE_IDX_8(2)] = value >> 16; + p[IREE_LE_IDX_8(3)] = value >> 24; + p[IREE_LE_IDX_8(4)] = value >> 32; + p[IREE_LE_IDX_8(5)] = value >> 40; + p[IREE_LE_IDX_8(6)] = value >> 48; + p[IREE_LE_IDX_8(7)] = value >> 56; +} +static inline void iree_unaligned_store_le_f32(float* ptr, float value) { + uint32_t uint_value; + memcpy(&uint_value, &value, sizeof(value)); + iree_unaligned_store_le_u32((uint32_t*)ptr, uint_value); +} +static inline void iree_unaligned_store_le_f64(double* ptr, double value) { + uint64_t uint_value; + memcpy(&uint_value, &value, sizeof(value)); + iree_unaligned_store_le_u64((uint64_t*)ptr, uint_value); +} + +#else + +#if defined(IREE_ENDIANNESS_LITTLE) + +#define iree_unaligned_load_le_u8(ptr) *(ptr) +#define iree_unaligned_load_le_u16(ptr) *(ptr) +#define iree_unaligned_load_le_u32(ptr) *(ptr) +#define iree_unaligned_load_le_u64(ptr) *(ptr) +#define iree_unaligned_load_le_f32(ptr) *(ptr) +#define iree_unaligned_load_le_f64(ptr) *(ptr) + +#define iree_unaligned_store_le_u8(ptr, value) *(ptr) = (value) +#define iree_unaligned_store_le_u16(ptr, value) *(ptr) = (value) +#define iree_unaligned_store_le_u32(ptr, value) *(ptr) = (value) +#define iree_unaligned_store_le_u64(ptr, value) *(ptr) = (value) +#define iree_unaligned_store_le_f32(ptr, value) *(ptr) = (value) +#define iree_unaligned_store_le_f64(ptr, value) *(ptr) = (value) + +#else + +#error "TODO(benvanik): little-endian load/store for big-endian archs" + +#endif // IREE_ENDIANNESS_* + +#endif // IREE_MEMORY_ACCESS_ALIGNMENT_REQUIRED + +// clang-format off + +// Dereferences |ptr| and returns the value. +// Automatically handles unaligned accesses on architectures that may not +// support them natively (or efficiently). Memory is treated as little-endian. +#define iree_unaligned_load_le(ptr) \ + _Generic((ptr), \ + int8_t*: iree_unaligned_load_le_u8((const uint8_t*)(ptr)), \ + uint8_t*: iree_unaligned_load_le_u8((const uint8_t*)(ptr)), \ + int16_t*: iree_unaligned_load_le_u16((const uint16_t*)(ptr)), \ + uint16_t*: iree_unaligned_load_le_u16((const uint16_t*)(ptr)), \ + int32_t*: iree_unaligned_load_le_u32((const uint32_t*)(ptr)), \ + uint32_t*: iree_unaligned_load_le_u32((const uint32_t*)(ptr)), \ + int64_t*: iree_unaligned_load_le_u64((const uint64_t*)(ptr)), \ + uint64_t*: iree_unaligned_load_le_u64((const uint64_t*)(ptr)), \ + float*: iree_unaligned_load_le_f32((const float*)(ptr)), \ + double*: iree_unaligned_load_le_f64((const double*)(ptr)) \ + ) + +// Dereferences |ptr| and writes the given |value|. +// Automatically handles unaligned accesses on architectures that may not +// support them natively (or efficiently). Memory is treated as little-endian. 
+#define iree_unaligned_store(ptr, value) \ + _Generic((ptr), \ + int8_t*: iree_unaligned_store_le_u8((uint8_t*)(ptr), value), \ + uint8_t*: iree_unaligned_store_le_u8((uint8_t*)(ptr), value), \ + int16_t*: iree_unaligned_store_le_u16((uint16_t*)(ptr), value), \ + uint16_t*: iree_unaligned_store_le_u16((uint16_t*)(ptr), value), \ + int32_t*: iree_unaligned_store_le_u32((uint32_t*)(ptr), value), \ + uint32_t*: iree_unaligned_store_le_u32((uint32_t*)(ptr), value), \ + int64_t*: iree_unaligned_store_le_u64((uint64_t*)(ptr), value), \ + uint64_t*: iree_unaligned_store_le_u64((uint64_t*)(ptr), value), \ + float*: iree_unaligned_store_le_f32((float*)(ptr), value), \ + double*: iree_unaligned_store_le_f64((double*)(ptr), value) \ + ) + +// clang-format on + #ifdef __cplusplus } // extern "C" #endif
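For reference, here is a standalone sketch of the pattern the new header combines: byte-wise little-endian loads that are safe regardless of pointer alignment, dispatched by pointer type through C11 _Generic. The names below are illustrative stand-ins; the real helpers are the iree_unaligned_* ones in the diff above.

#include <stdint.h>
#include <stdio.h>

/* Byte-wise little-endian loads; safe on alignment-strict targets. */
static uint16_t load_le_u16(const void* ptr) {
  const uint8_t* p = (const uint8_t*)ptr;
  return (uint16_t)(p[0] | (p[1] << 8));
}
static uint32_t load_le_u32(const void* ptr) {
  const uint8_t* p = (const uint8_t*)ptr;
  return (uint32_t)p[0] | ((uint32_t)p[1] << 8) | ((uint32_t)p[2] << 16) |
         ((uint32_t)p[3] << 24);
}

/* Type-based dispatch, the same shape as iree_unaligned_load_le above. */
#define load_le(ptr)          \
  _Generic((ptr),             \
      uint16_t*: load_le_u16, \
      uint32_t*: load_le_u32)(ptr)

int main(void) {
  /* Bytes of 0x11223344 placed at offset 1, i.e. an unaligned address. */
  uint8_t buffer[8] = {0, 0x44, 0x33, 0x22, 0x11, 0xBB, 0xAA, 0};
  printf("0x%08X\n", load_le((uint32_t*)(buffer + 1)));  /* 0x11223344 */
  printf("0x%04X\n", load_le((uint16_t*)(buffer + 5)));  /* 0xAABB */
  return 0;
}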
diff --git a/iree/base/internal/threading_darwin.c b/iree/base/internal/threading_darwin.c index 0d9f4fd..7f3bd00 100644 --- a/iree/base/internal/threading_darwin.c +++ b/iree/base/internal/threading_darwin.c
@@ -82,7 +82,7 @@ // (including the user-specified entry_arg). iree_thread_t* thread = NULL; iree_status_t status = - iree_allocator_malloc(allocator, sizeof(iree_thread_t), (void**)&thread); + iree_allocator_malloc(allocator, sizeof(*thread), (void**)&thread); if (!iree_status_is_ok(status)) { IREE_TRACE_ZONE_END(z0); return status;
diff --git a/iree/base/internal/threading_pthreads.c b/iree/base/internal/threading_pthreads.c index ed7c59f..f3d66a8 100644 --- a/iree/base/internal/threading_pthreads.c +++ b/iree/base/internal/threading_pthreads.c
@@ -124,7 +124,7 @@ // (including the user-specified entry_arg). iree_thread_t* thread = NULL; iree_status_t status = - iree_allocator_malloc(allocator, sizeof(iree_thread_t), (void**)&thread); + iree_allocator_malloc(allocator, sizeof(*thread), (void**)&thread); if (!iree_status_is_ok(status)) { IREE_TRACE_ZONE_END(z0); return status;
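The sizeof(iree_thread_t) to sizeof(*thread) change above (repeated in the Win32 backend below) is the usual hardening pattern: sizing the allocation from the pointee expression keeps it correct even if the pointer's type is later renamed or changed. A minimal sketch with a hypothetical struct:

#include <stdio.h>
#include <stdlib.h>

/* Hypothetical type; the pattern, not the fields, is the point. */
typedef struct {
  int id;
  char name[32];
} thread_t;

int main(void) {
  thread_t* thread = NULL;
  /* sizeof(*thread) is an unevaluated operand, so this is safe even though
   * thread is NULL here, and it tracks the pointee type automatically. */
  thread = malloc(sizeof(*thread));
  if (!thread) return 1;
  printf("allocated %zu bytes\n", sizeof(*thread));
  free(thread);
  return 0;
}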
diff --git a/iree/base/internal/threading_win32.c b/iree/base/internal/threading_win32.c index f6e5d24..6e550e3 100644 --- a/iree/base/internal/threading_win32.c +++ b/iree/base/internal/threading_win32.c
@@ -139,7 +139,7 @@ // (including the user-specified entry_arg). iree_thread_t* thread = NULL; iree_status_t status = - iree_allocator_malloc(allocator, sizeof(iree_thread_t), (void**)&thread); + iree_allocator_malloc(allocator, sizeof(*thread), (void**)&thread); if (!iree_status_is_ok(status)) { IREE_TRACE_ZONE_END(z0); return status; @@ -200,7 +200,11 @@ iree_thread_resume(thread); - WaitForSingleObject(thread->handle, INFINITE); + if (thread->id != GetCurrentThreadId()) { + // Join with the thread. Since threads can delete themselves, we must ensure + // they don't try to join with themselves and deadlock. + WaitForSingleObject(thread->handle, INFINITE); + } CloseHandle(thread->handle); iree_thread_override_list_deinitialize(&thread->qos_override_list); iree_allocator_free(thread->allocator, thread);
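The Win32 join guard deserves a note: a thread that drops the last reference to its own iree_thread_t would otherwise call WaitForSingleObject on itself and deadlock. A portable sketch of the same guard, using pthreads purely for illustration (not the IREE API):

#include <pthread.h>
#include <stdio.h>

static pthread_t g_worker;

/* Mirrors the guard above: only join when the caller is a different thread;
 * a thread releasing itself must skip the wait or it deadlocks. */
static void worker_release(void) {
  if (pthread_equal(g_worker, pthread_self())) {
    pthread_detach(g_worker);  /* self-release: never join yourself */
  } else {
    pthread_join(g_worker, NULL);  /* normal path: wait for completion */
  }
}

static void* worker_entry(void* arg) {
  (void)arg;
  printf("worker running\n");
  return NULL;
}

int main(void) {
  pthread_create(&g_worker, NULL, worker_entry, NULL);
  worker_release();  /* called from main, so this joins */
  return 0;
}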
diff --git a/iree/base/internal/wait_handle_poll.c b/iree/base/internal/wait_handle_poll.c index 805a19d..5dd51fc 100644 --- a/iree/base/internal/wait_handle_poll.c +++ b/iree/base/internal/wait_handle_poll.c
@@ -148,10 +148,10 @@ } iree_host_size_t user_handle_list_size = - capacity * sizeof(iree_wait_handle_t); + capacity * iree_sizeof_struct(iree_wait_handle_t); iree_host_size_t poll_fd_list_size = capacity * sizeof(struct pollfd); - iree_host_size_t total_size = - sizeof(iree_wait_set_t) + user_handle_list_size + poll_fd_list_size; + iree_host_size_t total_size = iree_sizeof_struct(iree_wait_set_t) + + user_handle_list_size + poll_fd_list_size; iree_wait_set_t* set = NULL; IREE_RETURN_IF_ERROR( @@ -161,7 +161,8 @@ iree_wait_set_clear(set); set->user_handles = - (iree_wait_handle_t*)((uint8_t*)set + sizeof(iree_wait_set_t)); + (iree_wait_handle_t*)((uint8_t*)set + + iree_sizeof_struct(iree_wait_set_t)); set->poll_fds = (struct pollfd*)((uint8_t*)set->user_handles + user_handle_list_size);
diff --git a/iree/base/internal/wait_handle_win32.c b/iree/base/internal/wait_handle_win32.c index 0a5bb56..af17e4a 100644 --- a/iree/base/internal/wait_handle_win32.c +++ b/iree/base/internal/wait_handle_win32.c
@@ -147,8 +147,8 @@ iree_host_size_t user_handle_list_size = capacity * sizeof(iree_wait_handle_t); iree_host_size_t native_handle_list_size = capacity * sizeof(HANDLE); - iree_host_size_t total_size = - sizeof(iree_wait_set_t) + user_handle_list_size + native_handle_list_size; + iree_host_size_t total_size = iree_sizeof_struct(iree_wait_set_t) + + user_handle_list_size + native_handle_list_size; iree_wait_set_t* set = NULL; IREE_RETURN_IF_ERROR( @@ -158,7 +158,8 @@ iree_wait_set_clear(set); set->user_handles = - (iree_wait_handle_t*)((uint8_t*)set + sizeof(iree_wait_set_t)); + (iree_wait_handle_t*)((uint8_t*)set + + iree_sizeof_struct(iree_wait_set_t)); set->native_handles = (HANDLE*)((uint8_t*)set->user_handles + user_handle_list_size);
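Both wait-handle implementations now size the header with iree_sizeof_struct so the arrays packed after it start at a max_align_t boundary. A standalone sketch of that trailing-allocation layout; the macro and struct names here are illustrative, not the IREE APIs:

#include <stdalign.h>
#include <stddef.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>

/* Round a size up to the strictest scalar alignment, mirroring
 * iree_host_align(sizeof(t), iree_max_align_t). */
#define align_up(value, alignment) \
  (((value) + ((alignment)-1)) & ~((size_t)(alignment)-1))
#define sizeof_struct(t) align_up(sizeof(t), alignof(max_align_t))

typedef struct {
  size_t capacity;
  uint64_t* handles;  /* points into the trailing storage */
} wait_set_t;

int main(void) {
  size_t capacity = 8;
  size_t total_size = sizeof_struct(wait_set_t) + capacity * sizeof(uint64_t);
  wait_set_t* set = calloc(1, total_size);
  if (!set) return 1;
  set->capacity = capacity;
  /* The padded header size guarantees this array lands suitably aligned. */
  set->handles = (uint64_t*)((uint8_t*)set + sizeof_struct(wait_set_t));
  set->handles[0] = 42;
  printf("header padded to %zu bytes\n", sizeof_struct(wait_set_t));
  free(set);
  return 0;
}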
diff --git a/iree/base/target_platform.h b/iree/base/target_platform.h index 4277d71..8052967 100644 --- a/iree/base/target_platform.h +++ b/iree/base/target_platform.h
@@ -30,6 +30,8 @@ // IREE_ENDIANNESS_LITTLE // IREE_ENDIANNESS_BIG // +// IREE_MEMORY_ACCESS_ALIGNMENT_REQUIRED (0/1) +// // IREE_COMPILER_CLANG // IREE_COMPILER_GCC // IREE_COMPILER_GCC_COMPAT @@ -138,6 +140,39 @@ #endif // __BYTE_ORDER__ //============================================================================== +// IREE_MEMORY_ACCESS_* +//============================================================================== +// Certain architectures have specific memory access requirements that require +// user-mode code changes to work at all or to perform reasonably. + +#if !defined(IREE_MEMORY_ACCESS_ALIGNMENT_REQUIRED) + +#if defined(IREE_ARCH_ARM_32) || defined(IREE_ARCH_ARM_64) + +// Armv6-M and Armv8-M (w/o the main extension) do not support unaligned access. +// The -munaligned-access and -mno-unaligned-access flags control this. +// https://www.keil.com/support/man/docs/armclang_ref/armclang_ref_sam1444138667173.htm +#if !defined(__ARM_FEATURE_UNALIGNED) +#define IREE_MEMORY_ACCESS_ALIGNMENT_REQUIRED 1 +#else +#define IREE_MEMORY_ACCESS_ALIGNMENT_REQUIRED 0 +#endif // !__ARM_FEATURE_UNALIGNED + +#elif defined(IREE_ARCH_RISCV_32) || defined(IREE_ARCH_RISCV_64) + +// Though unaligned access is part of the base spec, it is allowed to be +// implemented with trap handlers. Bare-metal systems likely won't have these +// handlers and even on systems that do (Linux) we don't want to be trapping for +// every load/store. +#define IREE_MEMORY_ACCESS_ALIGNMENT_REQUIRED 1 + +#else +#define IREE_MEMORY_ACCESS_ALIGNMENT_REQUIRED 0 +#endif // IREE_ARCH_* + +#endif // !IREE_MEMORY_ACCESS_ALIGNMENT_REQUIRED + +//============================================================================== // IREE_COMPILER_* //==============================================================================
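Note: the default-to-0 branch now sits inside the architecture chain so that a user-provided definition of the macro is never clobbered. A sketch of how code can gate on such a macro; ALIGNMENT_REQUIRED below is a stand-in for IREE_MEMORY_ACCESS_ALIGNMENT_REQUIRED, pinned to 1 just for the demo:

#include <stdint.h>
#include <stdio.h>
#include <string.h>

#define ALIGNMENT_REQUIRED 1  /* stand-in; normally set per-architecture */

static uint32_t read_u32(const void* ptr) {
#if ALIGNMENT_REQUIRED
  uint32_t value;
  memcpy(&value, ptr, sizeof(value));  /* compiler emits alignment-safe loads */
  return value;
#else
  return *(const uint32_t*)ptr;  /* direct load when unaligned access is ok */
#endif
}

int main(void) {
  uint8_t buffer[4] = {0x44, 0x33, 0x22, 0x11};
  printf("0x%08X\n", read_u32(buffer));  /* 0x11223344 on little-endian hosts */
  return 0;
}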
diff --git a/iree/compiler/Codegen/SPIRV/ConvertToSPIRVPass.cpp b/iree/compiler/Codegen/SPIRV/ConvertToSPIRVPass.cpp index b79eac9..7c3cbb7 100644 --- a/iree/compiler/Codegen/SPIRV/ConvertToSPIRVPass.cpp +++ b/iree/compiler/Codegen/SPIRV/ConvertToSPIRVPass.cpp
@@ -13,6 +13,8 @@ // //===----------------------------------------------------------------------===// +#include <tuple> + #include "iree/compiler/Codegen/PassDetail.h" #include "iree/compiler/Codegen/Passes.h" #include "iree/compiler/Codegen/Utils/MarkerUtils.h" @@ -51,15 +53,19 @@ // Resource utilities //===----------------------------------------------------------------------===// -/// Inserts a resource variable of the given `type` at the beginning of +/// Map from hal.interface.binding.subspan ops to their corresponding +/// spv.GlobalVariable ops. +using InterfaceResourceMap = + llvm::DenseMap<Operation *, spirv::GlobalVariableOp>; + +/// Creates a resource variable of the given `type` at the beginning of /// `moduleOp`'s block via `symbolTable` and binds it to `set` and `binding`. -spirv::GlobalVariableOp insertResourceVariable(Location loc, Type type, +spirv::GlobalVariableOp createResourceVariable(Location loc, Type type, unsigned set, unsigned binding, bool alias, ModuleOp moduleOp, - SymbolTable *symbolTable, - OpBuilder::Listener *listener) { - OpBuilder builder(moduleOp.getContext(), listener); + SymbolTable *symbolTable) { std::string name = llvm::formatv("__resource_var_{0}_{1}_", set, binding); + OpBuilder builder(moduleOp.getContext()); auto variable = builder.create<spirv::GlobalVariableOp>(loc, type, name, set, binding); if (alias) variable->setAttr("aliased", builder.getUnitAttr()); @@ -74,35 +80,66 @@ return {bindingOp.set().getSExtValue(), bindingOp.binding().getSExtValue()}; } -/// Returns the set of resources that should be marked as aliased in SPIR-V. -llvm::DenseSet<Operation *> getAliasedResources(ModuleOp module) { - llvm::DenseSet<Operation *> aliasedResources; +/// Scans all hal.interface.binding.subspan ops in `module`, creates their +/// corresponding spv.GlobalVariables when needed, and returns the map. +/// The created variables need to have their types fixed later. +InterfaceResourceMap createResourceVariables(mlir::ModuleOp module) { + SymbolTable symbolTable(module); + InterfaceResourceMap interfaceToResourceVars; - for (FuncOp func : module.getOps<FuncOp>()) { + auto fns = llvm::to_vector<1>(module.getOps<FuncOp>()); + for (FuncOp func : llvm::reverse(fns)) { // Collect all interface ops and their (set, binding) pairs in this - // function. - SmallVector<Operation *, 4> interfaceOps; - SmallVector<std::pair<uint32_t, uint32_t>, 4> setBindings; - llvm::DenseMap<std::pair<uint32_t, uint32_t>, unsigned> setBindingCount; + // function. Use SmallVector here for a deterministic order. + SmallVector<IREE::HAL::InterfaceBindingSubspanOp, 8> interfaceOps; + SmallVector<std::pair<uint32_t, uint32_t>, 8> setBindings; + + // Use a map to see if we have different types for one (set, binding) pair, + // which will require creating multiple SPIR-V global variables. + llvm::DenseMap<std::pair<uint32_t, uint32_t>, llvm::DenseSet<Type>> + setBindingTypes; + func.walk([&](Operation *op) { - if (isa<IREE::HAL::InterfaceBindingSubspanOp>(op)) { - interfaceOps.emplace_back(op); - setBindings.emplace_back(getInterfaceSetAndBinding(op)); - ++setBindingCount[setBindings.back()]; - } + auto interfaceOp = dyn_cast<IREE::HAL::InterfaceBindingSubspanOp>(op); + if (!interfaceOp || interfaceOp.use_empty()) return; + interfaceOps.emplace_back(interfaceOp); + setBindings.emplace_back(getInterfaceSetAndBinding(interfaceOp)); + setBindingTypes[setBindings.back()].insert(interfaceOp.getType()); }); - // Perform analysis to determine whether we need to mark the resource as - // alias.
This should happen when we have multiple resources binding to the - // same (set, binding) pair and they are used in the same function. - for (unsigned i = 0; i < interfaceOps.size(); ++i) { - if (setBindingCount[setBindings[i]] > 1) { - aliasedResources.insert(interfaceOps[i]); + // Keep track of created SPIR-V global variables. This allows us to + // deduplicate when possible to reduce generated SPIR-V blob size. + llvm::DenseMap<std::tuple<uint32_t, uint32_t, Type>, + spirv::GlobalVariableOp> + resourceVars; + + for (int i = interfaceOps.size() - 1; i >= 0; --i) { + auto interfaceOp = interfaceOps[i]; + const auto &setBinding = setBindings[i]; + + auto key = std::make_tuple(setBinding.first, setBinding.second, + interfaceOp.getType()); + auto var = resourceVars.lookup(key); + if (!var) { + // If we have multiple SPIR-V global variables bound to the same (set, + // binding) pair and they are used in the same function, those variables + // need to carry the aliased decoration. + bool alias = setBindingTypes[setBindings[i]].size() > 1; + + // We use the interface op's (unconverted) type when creating the global + // variable here; it is only a placeholder that gets fixed up to the + // converted SPIR-V type during conversion, so it does not leak out of + // this pass. + var = createResourceVariable( + interfaceOp.getLoc(), interfaceOp.getType(), setBinding.first, + setBinding.second, alias, module, &symbolTable); + resourceVars[key] = var; } + + interfaceToResourceVars[interfaceOp] = var; } } - return aliasedResources; + return interfaceToResourceVars; } } // namespace @@ -169,17 +206,19 @@ : public OpConversionPattern<IREE::HAL::InterfaceBindingSubspanOp> { HALInterfaceBindingSubspanConverter( TypeConverter &typeConverter, MLIRContext *context, - const llvm::DenseSet<Operation *> &aliasedResources, - SymbolTable *symbolTable, PatternBenefit benefit = 1) + const InterfaceResourceMap &interfaceToResourceVars, + PatternBenefit benefit = 1) : OpConversionPattern(typeConverter, context, benefit), - aliasedResources(aliasedResources), - symbolTable(symbolTable) {} + interfaceToResourceVars(interfaceToResourceVars) {} LogicalResult matchAndRewrite( IREE::HAL::InterfaceBindingSubspanOp interfaceOp, ArrayRef<Value> operands, ConversionPatternRewriter &rewriter) const override { - auto moduleOp = interfaceOp->template getParentOfType<ModuleOp>(); + if (interfaceOp.use_empty()) { + rewriter.eraseOp(interfaceOp); + return success(); + } Type resultType = interfaceOp.getOperation()->getResult(0).getType(); Type convertedType = this->getTypeConverter()->convertType(resultType); @@ -187,22 +226,18 @@ return interfaceOp.emitError() << "failed to convert SPIR-V type: " << resultType; } - auto setAndBinding = getInterfaceSetAndBinding(interfaceOp.getOperation()); - // We always create a new resource variable for the interface. - spirv::GlobalVariableOp varOp = insertResourceVariable( - interfaceOp.getLoc(), convertedType, setAndBinding.first, - setAndBinding.second, - aliasedResources.contains(interfaceOp.getOperation()), moduleOp, - symbolTable, rewriter.getListener()); + auto varOp = interfaceToResourceVars.lookup(interfaceOp); + // Fix up the variable's type. + varOp.typeAttr(TypeAttr::get(convertedType)); rewriter.replaceOpWithNewOp<spirv::AddressOfOp>(interfaceOp, varOp); + return success(); } private: - const llvm::DenseSet<Operation *> &aliasedResources; - SymbolTable *symbolTable; + const InterfaceResourceMap &interfaceToResourceVars; }; /// Pattern to lower operations that become no-ops at this level.
@@ -307,18 +342,12 @@ IREE::HAL::InterfaceWorkgroupCountOp, spirv::BuiltIn::NumWorkgroups>>( typeConverter, context); - // Create a symbol table for the current module op. This must be done before - // starting conversion. The conversion framework has a deferred nature. - // Actions like replacing/deleting ops are just recorded instead of directly - // applied. So for function conversion, we will see both the original function - // and the new replacement function during conversion process. That causes an - // issue for creating symbol tables, which scans all symbols and errors out - // when duplicates are found. We should be fine here though because we only use - // this symbol table to insert new SPIR-V global variables. - SymbolTable symbolTable(moduleOp); - auto aliasedResources = getAliasedResources(moduleOp); - patterns.insert<HALInterfaceBindingSubspanConverter>( - typeConverter, context, aliasedResources, &symbolTable); + // Perform a preliminary step to analyze all hal.interface.binding.subspan ops + // and create their corresponding spv.GlobalVariables. + auto interfaceToResourceVars = createResourceVariables(moduleOp); + // Then use them in the conversion patterns. + patterns.insert<HALInterfaceBindingSubspanConverter>(typeConverter, context, + interfaceToResourceVars); /// Fold certain operations as no-ops: /// - linalg.reshape becomes a no-op since all memrefs are linearized in
diff --git a/iree/compiler/Codegen/SPIRV/test/convert_to_spirv.mlir b/iree/compiler/Codegen/SPIRV/test/convert_to_spirv.mlir index abc817f..a7d78af 100644 --- a/iree/compiler/Codegen/SPIRV/test/convert_to_spirv.mlir +++ b/iree/compiler/Codegen/SPIRV/test/convert_to_spirv.mlir
@@ -24,20 +24,43 @@ module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], []>, {}>} { // CHECK-LABEL: spv.module - // CHECK: spv.GlobalVariable @__resource_var_3_4_ bind(3, 4) : !spv.ptr<!spv.struct<(!spv.array<16 x f32, stride=4> [0])>, StorageBuffer> - // CHECK: spv.GlobalVariable @__resource_var_1_2__0 bind(1, 2) {aliased} : !spv.ptr<!spv.struct<(!spv.array<16 x f32, stride=4> [0])>, StorageBuffer> - // CHECK: spv.GlobalVariable @__resource_var_1_2_ bind(1, 2) {aliased} : !spv.ptr<!spv.struct<(!spv.array<16 x f32, stride=4> [0])>, StorageBuffer> + // CHECK: spv.GlobalVariable @[[ARG0:.+]] bind(1, 2) : !spv.ptr<!spv.struct<(!spv.array<16 x f32, stride=4> [0])>, StorageBuffer> + // CHECK: spv.GlobalVariable @[[ARG1_0:.+]] bind(1, 3) {aliased} : !spv.ptr<!spv.struct<(!spv.array<16 x f32, stride=4> [0])>, StorageBuffer> + // CHECK: spv.GlobalVariable @[[ARG1_1:.+]] bind(1, 3) {aliased} : !spv.ptr<!spv.struct<(!spv.array<4 x vector<4xf32>, stride=16> [0])>, StorageBuffer> + // CHECK: spv.GlobalVariable @[[RET0:.+]] bind(3, 4) : !spv.ptr<!spv.struct<(!spv.array<16 x f32, stride=4> [0])>, StorageBuffer> // CHECK: spv.func @resource_bindings_in_same_entry_func() func @resource_bindings_in_same_entry_func() { %c0 = constant 0 : index + + // Same type + // CHECK: spv.mlir.addressof @[[ARG0]] + // CHECK: spv.mlir.addressof @[[ARG0]] %0 = hal.interface.binding.subspan @io::@arg0[%c0] : memref<4x4xf32> %1 = hal.interface.binding.subspan @io::@arg0[%c0] : memref<4x4xf32> - %2 = hal.interface.binding.subspan @io::@ret0[%c0] : memref<4x4xf32> + + // Different type + // CHECK: spv.mlir.addressof @[[ARG1_0]] + // CHECK: spv.mlir.addressof @[[ARG1_1]] + %2 = hal.interface.binding.subspan @io::@arg1[%c0] : memref<4x4xf32> + %3 = hal.interface.binding.subspan @io::@arg1[%c0] : memref<4xvector<4xf32>> + + // CHECK: spv.mlir.addressof @[[RET0]] + %4 = hal.interface.binding.subspan @io::@ret0[%c0] : memref<4x4xf32> + + %5 = memref.load %0[%c0, %c0] : memref<4x4xf32> + %6 = memref.load %1[%c0, %c0] : memref<4x4xf32> + + %7 = memref.load %2[%c0, %c0] : memref<4x4xf32> + %8 = memref.load %3[%c0] : memref<4xvector<4xf32>> + + %9 = memref.load %4[%c0, %c0] : memref<4x4xf32> + return } hal.interface @io attributes {push_constants = 5 : index, sym_visibility = "private"} { hal.interface.binding @arg0, set=1, binding=2, type="StorageBuffer", access="Read" + hal.interface.binding @arg1, set=1, binding=3, type="StorageBuffer", access="Read" hal.interface.binding @ret0, set=3, binding=4, type="StorageBuffer", access="Write" } } @@ -46,18 +69,22 @@ module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], []>, {}>} { // CHECK-LABEL: spv.module - // CHECK: spv.GlobalVariable @[[FUNC2_RET:.+]] bind(3, 4) : !spv.ptr<!spv.struct<(!spv.array<16 x f32, stride=4> [0])>, StorageBuffer> - // CHECK: spv.GlobalVariable @[[FUNC2_ARG:.+]] bind(1, 2) : !spv.ptr<!spv.struct<(!spv.array<16 x f32, stride=4> [0])>, StorageBuffer> - // CHECK: spv.GlobalVariable @[[FUNC1_RET:.+]] bind(3, 4) : !spv.ptr<!spv.struct<(!spv.array<4 x vector<4xf32>, stride=16> [0])>, StorageBuffer> // CHECK: spv.GlobalVariable @[[FUNC1_ARG:.+]] bind(1, 2) : !spv.ptr<!spv.struct<(!spv.array<16 x f32, stride=4> [0])>, StorageBuffer> + // CHECK: spv.GlobalVariable @[[FUNC1_RET:.+]] bind(3, 4) : !spv.ptr<!spv.struct<(!spv.array<4 x vector<4xf32>, stride=16> [0])>, StorageBuffer> + // CHECK: spv.GlobalVariable @[[FUNC2_ARG:.+]] bind(1, 2) : !spv.ptr<!spv.struct<(!spv.array<16 x f32, stride=4> [0])>, StorageBuffer> + // CHECK: 
spv.GlobalVariable @[[FUNC2_RET:.+]] bind(3, 4) : !spv.ptr<!spv.struct<(!spv.array<16 x f32, stride=4> [0])>, StorageBuffer> // CHECK: spv.func @resource_bindings_in_entry_func1() func @resource_bindings_in_entry_func1() { - // CHECK: spv.mlir.addressof @[[FUNC1_ARG:.+]] - // CHECK: spv.mlir.addressof @[[FUNC1_RET:.+]] + // CHECK: spv.mlir.addressof @[[FUNC1_ARG]] + // CHECK: spv.mlir.addressof @[[FUNC1_RET]] %c0 = constant 0 : index %0 = hal.interface.binding.subspan @io::@arg0[%c0] : memref<4x4xf32> %1 = hal.interface.binding.subspan @io::@ret0[%c0] : memref<4xvector<4xf32>> + + %2 = memref.load %0[%c0, %c0] : memref<4x4xf32> + %3 = memref.load %1[%c0] : memref<4xvector<4xf32>> + return } @@ -66,8 +93,12 @@ // CHECK: spv.mlir.addressof @[[FUNC2_ARG]] // CHECK: spv.mlir.addressof @[[FUNC2_RET]] %c0 = constant 0 : index - %0 = hal.interface.binding.subspan @io::@arg0[%c0] : memref<4x4xf32> - %1 = hal.interface.binding.subspan @io::@ret0[%c0] : memref<4x4xf32> + %0 = hal.interface.binding.subspan @io::@arg0[%c0] : memref<4x4xf32> // Same type as the previous function + %1 = hal.interface.binding.subspan @io::@ret0[%c0] : memref<4x4xf32> // Different type from the previous function + + %2 = memref.load %0[%c0, %c0] : memref<4x4xf32> + %3 = memref.load %1[%c0, %c0] : memref<4x4xf32> + return } @@ -85,6 +116,11 @@ %0 = hal.interface.binding.subspan @io::@arg0[%c0] : memref<8x5xf32> %1 = hal.interface.binding.subspan @io::@arg1[%c0] : memref<5xf32> %2 = hal.interface.binding.subspan @io::@ret0[%c0] : memref<8x5xf32> + + %3 = memref.load %0[%c0, %c0] : memref<8x5xf32> + %4 = memref.load %1[%c0] : memref<5xf32> + %5 = memref.load %2[%c0, %c0] : memref<8x5xf32> + return } hal.interface @io attributes {sym_visibility = "private"} { @@ -93,14 +129,17 @@ hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard" } } + +// Explicitly check the variable symbols. + // CHECK-LABEL: spv.module -// CHECK-DAG: spv.GlobalVariable @[[RET0:.+]] bind(0, 2) -// CHECK-DAG: spv.GlobalVariable @[[ARG1:.+]] bind(0, 1) -// CHECK-DAG: spv.GlobalVariable @[[ARG0:.+]] bind(0, 0) +// CHECK: spv.GlobalVariable @__resource_var_0_0_ bind(0, 0) +// CHECK: spv.GlobalVariable @__resource_var_0_1_ bind(0, 1) +// CHECK: spv.GlobalVariable @__resource_var_0_2_ bind(0, 2) // CHECK: spv.func -// CHECK-DAG: %{{.+}} = spv.mlir.addressof @[[RET0]] -// CHECK-DAG: %{{.+}} = spv.mlir.addressof @[[ARG0]] -// CHECK-DAG: %{{.+}} = spv.mlir.addressof @[[ARG1]] +// CHECK: %{{.+}} = spv.mlir.addressof @__resource_var_0_0_ +// CHECK: %{{.+}} = spv.mlir.addressof @__resource_var_0_1_ +// CHECK: %{{.+}} = spv.mlir.addressof @__resource_var_0_2_ // -----
diff --git a/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp b/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp index 8bf33ff..ce2f8cb 100644 --- a/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp +++ b/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
@@ -695,6 +695,43 @@ } }; +// Replace `flow.tensor.splat`-`flow.tensor.load` op-pairs by the input +// primitive value for the splat op. +struct FoldSplatLoadIntoPrimitive : public OpRewritePattern<TensorLoadOp> { + using OpRewritePattern<TensorLoadOp>::OpRewritePattern; + + LogicalResult matchAndRewrite(TensorLoadOp loadOp, + PatternRewriter &rewriter) const override { + auto sourceOp = + dyn_cast_or_null<TensorSplatOp>(loadOp.source().getDefiningOp()); + + if (!sourceOp) return failure(); + + rewriter.replaceOp(loadOp, sourceOp.value()); + return success(); + } +}; + +struct FoldSplatReshapeIntoSplat : public OpRewritePattern<TensorSplatOp> { + using OpRewritePattern<TensorSplatOp>::OpRewritePattern; + + LogicalResult matchAndRewrite(TensorSplatOp splatOp, + PatternRewriter &rewriter) const override { + if (!splatOp.result().hasOneUse()) return failure(); + + auto reshapeOp = dyn_cast_or_null<TensorReshapeOp>( + splatOp.result().use_begin()->getOwner()); + if (!reshapeOp) return failure(); + + rewriter.replaceOpWithNewOp<TensorSplatOp>( + reshapeOp, reshapeOp.result().getType(), splatOp.value(), + reshapeOp.result_dims()); + rewriter.eraseOp(splatOp); + + return success(); + } +}; + } // namespace void TensorReshapeOp::getCanonicalizationPatterns( @@ -702,6 +739,11 @@ results.insert<FlattenTensorReshapeChain>(context); } +void TensorLoadOp::getCanonicalizationPatterns( + OwningRewritePatternList &results, MLIRContext *context) { + results.insert<FoldSplatLoadIntoPrimitive>(context); +} + OpFoldResult TensorLoadOp::fold(ArrayRef<Attribute> operands) { if (auto source = operands[0].dyn_cast_or_null<ElementsAttr>()) { // Load directly from the constant source tensor. @@ -739,6 +781,12 @@ return {}; } +void TensorSplatOp::getCanonicalizationPatterns( + OwningRewritePatternList &results, MLIRContext *context) { + // TODO(benvanik): canonicalize splat+slice to smaller splat. + results.insert<FoldSplatReshapeIntoSplat>(context); +} + OpFoldResult TensorSplatOp::fold(ArrayRef<Attribute> operands) { if (operands.size() == 1 && operands.front()) { // Splat value is constant and we can fold the operation.
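The two new patterns are easiest to see as source-level rewrites. A minimal sketch (value names hypothetical; the lit tests added to tensor_folding.mlir below exercise the same folds):

```mlir
// FoldSplatLoadIntoPrimitive: loading any element of a splat yields the
// splatted primitive, so %1 folds to %value directly.
%0 = flow.tensor.splat %value : tensor<4x4xf32>
%1 = flow.tensor.load %0[%i, %j] : tensor<4x4xf32>

// FoldSplatReshapeIntoSplat: a reshape of a splat is a splat of the new
// shape, so the pair folds to a single flow.tensor.splat of tensor<16xf32>.
%2 = flow.tensor.splat %value : tensor<4x4xf32>
%3 = flow.tensor.reshape %2 : tensor<4x4xf32> -> tensor<16xf32>
```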
diff --git a/iree/compiler/Dialect/Flow/IR/FlowOps.td b/iree/compiler/Dialect/Flow/IR/FlowOps.td index 3b14102..cd1b2f4 100644 --- a/iree/compiler/Dialect/Flow/IR/FlowOps.td +++ b/iree/compiler/Dialect/Flow/IR/FlowOps.td
@@ -85,7 +85,7 @@ }]; let arguments = (ins - FLOW_VariableRefAttr:$variable + Arg<FLOW_VariableRefAttr, "", [MemRead]>:$variable ); let results = (outs AnyType:$result @@ -135,7 +135,7 @@ let arguments = (ins AnyType:$value, - FLOW_VariableRefAttr:$variable + Arg<FLOW_VariableRefAttr, "", [MemWrite]>:$variable ); let assemblyFormat = "$value `,` $variable attr-dict `:` type($value)"; @@ -822,6 +822,7 @@ // TODO(benvanik): canonicalize to slice+load if dims are known. let hasFolder = 1; + let hasCanonicalizer = 1; } def FLOW_TensorStoreOp : FLOW_PureOp<"tensor.store", [ @@ -899,7 +900,7 @@ bool isTransfer() { return true; } }]; - // TODO(benvanik): canonicalize splat+slice to smaller splat. + let hasCanonicalizer = 1; let hasFolder = 1; }
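Attaching MemRead/MemWrite to the variable reference gives generic passes an accurate effects model instead of an opaque symbol attribute. For example (assuming a hypothetical mutable variable @v), CSE must now keep these two loads distinct because the intervening store is a modeled write:

```mlir
%0 = flow.variable.load @v : tensor<2xi32>
flow.variable.store %new, @v : tensor<2xi32>
// Cannot be CSE'd with %0: the store above is now a visible MemWrite on @v.
%1 = flow.variable.load @v : tensor<2xi32>
```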
diff --git a/iree/compiler/Dialect/Flow/IR/test/tensor_folding.mlir b/iree/compiler/Dialect/Flow/IR/test/tensor_folding.mlir index 62cab58..724b8e5 100644 --- a/iree/compiler/Dialect/Flow/IR/test/tensor_folding.mlir +++ b/iree/compiler/Dialect/Flow/IR/test/tensor_folding.mlir
@@ -411,3 +411,36 @@ // CHECK: return %[[RESULT]] return %1 : tensor<?x?xf32> } + +// ----- + +// CHECK-LABEL: @foldSplatLoadIntoPrimitive +// CHECK-SAME: (%[[arg0:.+]]: f32, %[[arg1:.+]]: index, %[[arg2:.+]]: index) +func @foldSplatLoadIntoPrimitive(%arg0 : f32, %arg1 : index, %arg2 : index) -> f32 { + // CHECK-NEXT: return %[[arg0]] : f32 + %0 = flow.tensor.splat %arg0 : tensor<4x4xf32> + %1 = flow.tensor.load %0[%arg1, %arg2] : tensor<4x4xf32> + return %1 : f32 +} + +// ----- + +// CHECK-LABEL: @foldSplatReshapeIntoSplat +func @foldSplatReshapeIntoSplat(%arg0 : f32) -> tensor<16xf32> { + // CHECK-NEXT: %0 = flow.tensor.splat %arg0 : tensor<16xf32> + // CHECK-NEXT: return %0 : tensor<16xf32> + %0 = flow.tensor.splat %arg0 : tensor<4x4xf32> + %1 = flow.tensor.reshape %0 : tensor<4x4xf32> -> tensor<16xf32> + return %1 : tensor<16xf32> +} + +// ----- + +// CHECK-LABEL: @foldSplatReshapeIntoSplatDynamic +func @foldSplatReshapeIntoSplatDynamic(%arg0 : f32, %arg1 : index, %arg2 : index, %arg3 : index) -> tensor<?x?xf32> { + // CHECK-NEXT: %0 = flow.tensor.splat %arg0 : tensor<?x?xf32>{%arg2, %arg3} + // CHECK-NEXT: return %0 : tensor<?x?xf32> + %0 = flow.tensor.splat %arg0 : tensor<?x4xf32>{%arg1} + %1 = flow.tensor.reshape %0 : tensor<?x4xf32>{%arg1} -> tensor<?x?xf32>{%arg2, %arg3} + return %1 : tensor<?x?xf32> +}
diff --git a/iree/compiler/Dialect/Flow/Transforms/BUILD b/iree/compiler/Dialect/Flow/Transforms/BUILD index 167a506..8b26a4b 100644 --- a/iree/compiler/Dialect/Flow/Transforms/BUILD +++ b/iree/compiler/Dialect/Flow/Transforms/BUILD
@@ -53,6 +53,7 @@ "Passes.cpp", "PromoteI1ToI8Pass.cpp", "PromoteTensorLoads.cpp", + "SimplifyVariableAccesses.cpp", "StripAndSplatConstantVariables.cpp", "TypeConverter.cpp", ],
diff --git a/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt b/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt index 57484a1..e8c00f3 100644 --- a/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt +++ b/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
@@ -50,6 +50,7 @@ "Passes.cpp" "PromoteI1ToI8Pass.cpp" "PromoteTensorLoads.cpp" + "SimplifyVariableAccesses.cpp" "StripAndSplatConstantVariables.cpp" "TypeConverter.cpp" DEPS
diff --git a/iree/compiler/Dialect/Flow/Transforms/DestructiveUpdateUtils.cpp b/iree/compiler/Dialect/Flow/Transforms/DestructiveUpdateUtils.cpp index 69249fc..81483a4 100644 --- a/iree/compiler/Dialect/Flow/Transforms/DestructiveUpdateUtils.cpp +++ b/iree/compiler/Dialect/Flow/Transforms/DestructiveUpdateUtils.cpp
@@ -14,6 +14,7 @@ #include "iree/compiler/Dialect/Flow/IR/FlowOps.h" #include "iree/compiler/Dialect/HAL/IR/HALOps.h" #include "iree/compiler/Dialect/IREE/IR/IREEDialect.h" +#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Debug.h" #include "mlir/Dialect/Linalg/IR/LinalgOps.h" @@ -55,17 +56,28 @@ }; // TODO(nicolasvasilache): Use some interface instead of op names directly. -static bool hasDestructiveUpdateSubTensorUses( - BlockArgument arg, SpecialTerminatorOpCapture &capture) { +static bool hasDestructiveUpdateUses(BlockArgument arg, + SpecialTerminatorOpCapture &capture) { SmallVector<Operation *> reads; - SmallVector<tensor::InsertSliceOp> writes; + SmallVector<Operation *> writes; for (OpOperand &u : arg.getUses()) { - if (auto subTensorInsertOp = - dyn_cast<tensor::InsertSliceOp>(u.getOwner())) { - writes.push_back(subTensorInsertOp); - } else { - reads.push_back(u.getOwner()); - } + TypeSwitch<Operation *, void>(u.getOwner()) + .Case<linalg::LinalgOp, linalg_ext::LinalgExtOp>( + [&](auto linalgLikeOp) { + if (linalgLikeOp.isOutputTensor(&u)) { + writes.push_back(linalgLikeOp); + } else { + reads.push_back(linalgLikeOp); + } + }) + .Case<tensor::InsertSliceOp>([&](tensor::InsertSliceOp sliceOp) { + if (sliceOp.dest() == u.get()) { + writes.push_back(sliceOp); + } else { + reads.push_back(sliceOp); + } + }) + .Default([&](Operation *op) { reads.push_back(op); }); } // For now, only allow exactly a single SubTensorInsertOp that must be // dominated by all SubTensorOp. @@ -77,7 +89,7 @@ if (!domInfo.properlyDominates(read, writes.front())) { LLVM_DEBUG(llvm::dbgs() << "non-destructive use-def: " << *read << " does not properly dominate " - << *(writes.front().getOperation()) << "\n"); + << *(writes.front()) << "\n"); return false; } } @@ -133,8 +145,7 @@ // Case 2: multiple uses from an scf::ForOp then this must be used only by // SubTensorOp / SubTensorInsertOp with proper dominance. if (!regionArg.hasOneUse()) { - if (!hasDestructiveUpdateSubTensorUses(regionArg, capture)) - return nullptr; + if (!hasDestructiveUpdateUses(regionArg, capture)) return nullptr; return returnValue; } @@ -145,8 +156,7 @@ // Case 3a: Single use which is not an scf::ForOp, it may still be a // single SubTensor / SubTensorInsertOp. 
if (!innerForOp) { - if (!hasDestructiveUpdateSubTensorUses(regionArg, capture)) - return nullptr; + if (!hasDestructiveUpdateUses(regionArg, capture)) return nullptr; return returnValue; } @@ -216,31 +226,67 @@ return success(); } -static LogicalResult rewriteSubTensorInsertInPlace(OpBuilder &b, - tensor::InsertSliceOp op, - Value target) { - LLVM_DEBUG(llvm::dbgs() << "RewriteSubTensorInsertInPlace: " - << *(op.getOperation()) << "\n"); - if (!op.getResult().hasOneUse()) { - return op.emitError("not a single use operation"); +template <typename OpTy> +static LogicalResult rewriteDestructiveUpdateInPlace(OpBuilder &b, + OpTy linalgLikeOp, + Value target) { + LLVM_DEBUG(llvm::dbgs() << "RewriteDestructiveUpdateInPlace: " + << *linalgLikeOp.getOperation() << "\n"); + if (!linalgLikeOp->hasOneUse()) { + return linalgLikeOp.emitError("not a single use operation"); } - Operation *user = *(op.getResult().getUsers().begin()); - if (isa<scf::YieldOp>(user)) { - auto dest = op.dest(); - if (!dest.isa<BlockArgument>()) { - return op.emitError("dest is not a argument to the loop"); + OpOperand &use = *(linalgLikeOp->use_begin()); + if (isa<scf::YieldOp>(use.getOwner())) { + OpResult usedResult = use.get().cast<OpResult>(); + Value dest = + linalgLikeOp.getOutputOperand(usedResult.getResultNumber())->get(); + if (!dest || !dest.isa<BlockArgument>()) { + return linalgLikeOp.emitError("dest is not an argument to the loop"); } OpBuilder::InsertionGuard g(b); - b.setInsertionPoint(op); + b.setInsertionPointAfter(linalgLikeOp); // Kills the SSA use-def chain. - op.replaceAllUsesWith(dest); + usedResult.replaceAllUsesWith(dest); + + b.create<IREE::Flow::DispatchTensorStoreOp>(linalgLikeOp.getLoc(), + usedResult, target); + + return success(); + } + return failure(); +} + +/// Rewrites destructive in-place updates whose update operation is a +/// tensor.insert_slice. +template <> +LogicalResult rewriteDestructiveUpdateInPlace<tensor::InsertSliceOp>( + OpBuilder &b, tensor::InsertSliceOp insertSliceOp, Value target) { + LLVM_DEBUG(llvm::dbgs() << "RewriteDestructiveUpdateInPlace: " + << *insertSliceOp.getOperation() << "\n"); + if (!insertSliceOp->hasOneUse()) { + return insertSliceOp.emitError("not a single use operation"); + } + + OpOperand &use = *(insertSliceOp->use_begin()); + if (isa<scf::YieldOp>(use.getOwner())) { + OpResult usedResult = use.get().cast<OpResult>(); + Value dest = insertSliceOp.dest(); + if (!dest || !dest.isa<BlockArgument>()) { + return insertSliceOp.emitError("dest is not an argument to the loop"); + } + OpBuilder::InsertionGuard g(b); + b.setInsertionPoint(insertSliceOp); + + // Kills the SSA use-def chain. + usedResult.replaceAllUsesWith(dest); b.create<IREE::Flow::DispatchTensorStoreOp>( - op.getLoc(), op.source(), target, op.offsets(), op.sizes(), - op.strides(), op.static_offsets(), op.static_sizes(), - op.static_strides()); + insertSliceOp->getLoc(), insertSliceOp.source(), target, + insertSliceOp.offsets(), insertSliceOp.sizes(), insertSliceOp.strides(), + insertSliceOp.static_offsets(), insertSliceOp.static_sizes(), + insertSliceOp.static_strides()); return success(); } @@ -366,11 +412,17 @@ << "\n"); // Try to rewrite inplace. 
- if (failed(rewriteSubTensorInsertInPlace( - b, cast<tensor::InsertSliceOp>(capture.rootDestructiveUpdate), - target))) { - return failure(); - } + auto status = + TypeSwitch<Operation *, LogicalResult>(capture.rootDestructiveUpdate) + .Case<linalg::LinalgOp, linalg_ext::LinalgExtOp, + tensor::InsertSliceOp>([&](auto op) { + if (failed(rewriteDestructiveUpdateInPlace(b, op, target))) { + return failure(); + } + return success(); + }) + .Default([&](Operation *) { return failure(); }); + if (failed(status)) return failure(); if (scf::ForOp loopOp = dyn_cast<scf::ForOp>(outermostProducingOp)) loopOp.walk( @@ -415,9 +467,9 @@ capture.initValue = op.value(); Value sourceValue = isADestructiveUpdatePattern(capture.initValue, capture); - if (!sourceValue || !isa_and_nonnull<tensor::InsertSliceOp>( - capture.rootDestructiveUpdate)) + if (!sourceValue) { return WalkResult::advance(); + } if (failed(rewriteDestructiveUpdateInPlace(b, capture, op.target()))) { return WalkResult::interrupt();
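For reference, the shape of the destructive-update pattern this utility matches, sketched with a hypothetical payload op; the new TypeSwitch extends the same handling from tensor.insert_slice to linalg/linalg_ext ops that write a loop-carried tensor through an output operand:

```mlir
// %t is the loop-carried tensor: read only through slices and written exactly
// once per iteration by the insert_slice, whose result is yielded.
%r = scf.for %i = %c0 to %n step %c1 iter_args(%t = %init) -> (tensor<8xf32>) {
  %u = "test.compute"() : () -> tensor<1xf32>
  %w = tensor.insert_slice %u into %t[%i] [1] [1]
      : tensor<1xf32> into tensor<8xf32>
  scf.yield %w : tensor<8xf32>
}
// rewriteDestructiveUpdateInPlace kills the SSA chain through %t and instead
// emits a flow.dispatch.tensor.store of the update into the target buffer.
```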
diff --git a/iree/compiler/Dialect/Flow/Transforms/Passes.cpp b/iree/compiler/Dialect/Flow/Transforms/Passes.cpp index 6185eac..5fca973 100644 --- a/iree/compiler/Dialect/Flow/Transforms/Passes.cpp +++ b/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
@@ -65,9 +65,14 @@ namespace Flow { void buildFlowTransformPassPipeline(OpPassManager &passManager) { - // Perform initial cleanup. - // NOTE: There is no principled reason to be doing this here. But also ensures - // some consistency at the tool boundary. + // Simplify flow.variable accesses early on; this can help with dispatch + // region formation as redundant store-loads are removed. + passManager.addNestedPass<FuncOp>( + IREE::Flow::createSimplifyVariableAccessesPass()); + + // Perform cleanup after variable simplification as more canonicalizers may be + // able to kick in. + passManager.addNestedPass<FuncOp>(mlir::createCanonicalizerPass()); passManager.addNestedPass<FuncOp>(mlir::createCSEPass()); // Replaces variables with !shapex.ranked_shape types with individual @@ -169,9 +174,11 @@ // Reorder blocks to increase the grouping of streamable ops. passManager.addNestedPass<FuncOp>( IREE::Flow::createHoistUnstreamableOpsPass()); + // The hoisting pass does some reordering. Canonicalize to avoid unnecessary // arbitrary ordering. passManager.addNestedPass<FuncOp>(mlir::createCanonicalizerPass()); + passManager.addNestedPass<FuncOp>(mlir::createCSEPass()); // Clone constants that escape basic blocks until we have better analysis. passManager.addNestedPass<FuncOp>(
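Running the simplification before dispatch-region formation matters because a store/load round-trip through a variable otherwise acts as an ordering barrier between dispatches. A hedged before/after sketch (op syntax from this diff; value names, shapes, and symbols hypothetical):

```mlir
// Before: the dispatch reads @v and so is ordered after the store.
flow.variable.store %d0, @v : tensor<4xf32>
%r = flow.variable.load @v : tensor<4xf32>
%d1 = flow.dispatch @ex::@entry[%w](%r) : (tensor<4xf32>) -> tensor<4xf32>

// After RAW forwarding: the dispatch consumes %d0 directly, leaving the store
// free to sink (or to fold with a later store to @v).
flow.variable.store %d0, @v : tensor<4xf32>
%d2 = flow.dispatch @ex::@entry[%w](%d0) : (tensor<4xf32>) -> tensor<4xf32>
```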
diff --git a/iree/compiler/Dialect/Flow/Transforms/Passes.h b/iree/compiler/Dialect/Flow/Transforms/Passes.h index 5f4403a..445190e 100644 --- a/iree/compiler/Dialect/Flow/Transforms/Passes.h +++ b/iree/compiler/Dialect/Flow/Transforms/Passes.h
@@ -131,6 +131,10 @@ // Reorders blocks to hoist ops that cannot be put into streams. std::unique_ptr<OperationPass<FuncOp>> createHoistUnstreamableOpsPass(); +// Hoists loads and sinks stores to variables to decrease data dependency +// regions. +std::unique_ptr<OperationPass<FuncOp>> createSimplifyVariableAccessesPass(); + // TODO(benvanik): cross-function stream flows. // Inserts clones of constant values where they may be required.
diff --git a/iree/compiler/Dialect/Flow/Transforms/Passes.td b/iree/compiler/Dialect/Flow/Transforms/Passes.td index c8fa51c..4dfc814 100644 --- a/iree/compiler/Dialect/Flow/Transforms/Passes.td +++ b/iree/compiler/Dialect/Flow/Transforms/Passes.td
@@ -57,7 +57,7 @@ def FormStreams : Pass<"iree-flow-form-streams", "FuncOp"> { - let summary = "Identifies dispatches that can be grouped into streams within functions"; + let summary = "Identifies dispatches that can be grouped into streams within functions."; let constructor = "mlir::iree_compiler::IREE::Flow::createFormStreamsPass()"; } @@ -75,7 +75,7 @@ def InjectDispatchTracing : Pass<"iree-flow-inject-dispatch-tracing", "FuncOp"> { - let summary = "Injects dispatch region tracing"; + let summary = "Injects dispatch region tracing."; let constructor = "mlir::iree_compiler::IREE::Flow::createInjectDispatchTracingPass()"; } @@ -128,6 +128,12 @@ let constructor = "mlir::iree_compiler::IREE::Flow::createPromoteTensorLoadsPass()"; } +def SimplifyVariableAccesses : + Pass<"iree-flow-simplify-variable-accesses", "FuncOp"> { + let summary = "Hoists loads and sinks stores to variables to decrease data dependency regions."; + let constructor = "mlir::iree_compiler::IREE::Flow::createSimplifyVariableAccessesPass()"; +} + def StripAndSplatConstantVariables : Pass<"iree-flow-strip-and-splat-constant-variables", "ModuleOp"> { let summary = "Strips constant flow.variables and replaces them with splats.";
diff --git a/iree/compiler/Dialect/Flow/Transforms/SimplifyVariableAccesses.cpp b/iree/compiler/Dialect/Flow/Transforms/SimplifyVariableAccesses.cpp new file mode 100644 index 0000000..f190956 --- /dev/null +++ b/iree/compiler/Dialect/Flow/Transforms/SimplifyVariableAccesses.cpp
@@ -0,0 +1,270 @@ +// Copyright 2021 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include <algorithm> +#include <iterator> + +#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h" +#include "iree/compiler/Dialect/Flow/IR/FlowOps.h" +#include "iree/compiler/Dialect/Flow/Transforms/PassDetail.h" +#include "iree/compiler/Dialect/Flow/Transforms/Passes.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Debug.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassRegistry.h" + +#define DEBUG_TYPE "iree-flow-simplify-variable-accesses" + +namespace mlir { +namespace iree_compiler { +namespace IREE { +namespace Flow { + +// Builds symbol ref set for all immutable variables in |moduleOp|. +static DenseSet<StringRef> gatherImmutableVariables(ModuleOp moduleOp) { + DenseSet<StringRef> set; + moduleOp.walk([&](IREE::Flow::VariableOp variableOp) { + if (!variableOp.is_mutable()) { + set.insert(variableOp.sym_name()); + } + }); + return set; +} + +// Hoists all loads of immutable variables in |funcOp| to the entry block. +// |immutableVariables| is used for lookups of which variables are immutable. +static void hoistImmutableLoads(FuncOp funcOp, + DenseSet<StringRef> &immutableVariables) { + // Since CSE of loads isn't a thing yet we perform a basic deduping here by + // folding all subsequent loads into the first one found. This works only for + // immutable variables as otherwise we'd have to ensure stores and + // side-effects were properly observed. + DenseMap<Attribute, Operation *> loadOps; + auto *entryBlock = &funcOp.getBlocks().front(); + Operation *lastEntryOp = nullptr; + for (auto &block : funcOp) { + for (auto op : llvm::make_early_inc_range( + block.getOps<IREE::Flow::VariableLoadOp>())) { + if (immutableVariables.contains(op.variable())) { + auto variableRef = op.variableAttr().cast<Attribute>(); + auto it = loadOps.find(variableRef); + if (it == loadOps.end()) { + // Move to entry block; even if it's already there (so loads are + // hoisted at the same time). + LLVM_DEBUG(llvm::dbgs() + << "moving immutable variable " << op.variable() + << " load to the entry block\n"); + if (lastEntryOp) { + op->moveAfter(lastEntryOp); + } else { + op->moveBefore(entryBlock, entryBlock->begin()); + } + loadOps[variableRef] = op; + lastEntryOp = op; + } else { + LLVM_DEBUG(llvm::dbgs() + << "CSE'ing immutable variable " << op.variable() << "\n"); + op->replaceAllUsesWith(it->getSecond()); + op->erase(); + } + } + } + } +} + +static bool doesOpBlockMotion(Operation *op) { + return isa<mlir::CallOpInterface>(op) || + op->hasTrait<OpTrait::IREE::YieldPoint>(); +} + +static void moveOpUpInBlock(Block &block, Operation *op) { + while (op->getPrevNode()) { + if (doesOpBlockMotion(op->getPrevNode())) break; + op->moveBefore(op->getPrevNode()); + } +} + +static void moveOpDownInBlock(Block &block, Operation *op) { + while (op->getNextNode() != block.getTerminator()) { + if (doesOpBlockMotion(op->getNextNode())) break; + op->moveAfter(op->getNextNode()); + } +} + +// Optimizes the load/store ops for each given bucket. +// Returns true if any op was removed. +static bool optimizeBuckets( + Block &block, std::map<StringRef, SmallVector<Operation *>> &buckets) { + bool didRemoveAny = false; + for (auto &bucket : buckets) { + // First perform basic load-store forwarding and such. 
+ auto &ops = bucket.second; + for (int i = ops.size() - 1; i >= 1; --i) { + auto previous = ops[i - 1]; + auto current = ops[i]; + if (isa<IREE::Flow::VariableStoreOp>(previous) && + isa<IREE::Flow::VariableLoadOp>(current)) { + // RAW - forward the stored variable to the following use. + auto storedValue = previous->getOperand(0); + LLVM_DEBUG({ + llvm::dbgs() << "RAW: replacing load with previous store value:\n"; + current->dump(); + llvm::dbgs() << "->\n"; + storedValue.dump(); + }); + current->replaceAllUsesWith(ValueRange{storedValue}); + ops.erase(ops.begin() + i); + current->erase(); + didRemoveAny = true; + } else if (isa<IREE::Flow::VariableLoadOp>(previous) && + isa<IREE::Flow::VariableLoadOp>(current)) { + // RAR - forward the loaded variable to the following use. + LLVM_DEBUG({ + llvm::dbgs() << "RAR: replacing subsequent load with op:\n"; + current->dump(); + llvm::dbgs() << "->\n"; + previous->dump(); + }); + current->replaceAllUsesWith(previous); + ops.erase(ops.begin() + i); + current->erase(); + didRemoveAny = true; + } else if (isa<IREE::Flow::VariableStoreOp>(previous) && + isa<IREE::Flow::VariableStoreOp>(current)) { + // WAW - remove the first store. + LLVM_DEBUG({ + llvm::dbgs() << "WAW: erasing source op:\n"; + previous->dump(); + llvm::dbgs() << "\nand keeping subsequent op:\n"; + current->dump(); + }); + ops.erase(ops.begin() + i - 1); + previous->erase(); + didRemoveAny = true; + } + } + assert(!ops.empty()); + + if (auto loadOp = dyn_cast<IREE::Flow::VariableLoadOp>(ops.front())) { + // If the head op is a load we can move that to the top of the block. + LLVM_DEBUG(llvm::dbgs() << "moving mutable variable " << loadOp.variable() + << " load upward\n"); + moveOpUpInBlock(block, ops.front()); + } + if (auto storeOp = dyn_cast<IREE::Flow::VariableStoreOp>(ops.back())) { + // If the tail op is a store we can move that to the bottom of the block. + LLVM_DEBUG(llvm::dbgs() << "moving mutable variable " + << storeOp.variable() << " store downward\n"); + moveOpDownInBlock(block, ops.back()); + } + } + return didRemoveAny; +} + +// Hoists loads and sinks stores to the boundary of |block| when safe. +// |immutableVariables| is used for lookups of which variables are immutable. +// +// Basic algorithm (repeat until no op removals): +// for each op: +// if immutable: skip +// add to load/store buckets (sorted vector) +// for each bucket (symbol): +// walk ops in reverse: +// if (prev == store && this == load) // RAW +// replace load with store source +// if (prev == load && this == load) // RAR +// replace with first load +// if (prev == store && this == store) // WAW +// remove first store +// if (head == load) move load to front +// if (tail == store) move store to back +// +// Returns true if there were any removals and the block should be reprocessed. +static bool rearrangeBlockVariableAccesses( + Block &block, DenseSet<StringRef> &immutableVariables) { + // Produce [symbol_name, [op, op, op, ...]] buckets. + // NOTE: we use a map here so that we are deterministically ordered. This may + // not be needed but the variable count is low and it's nice to not care about + // op order issues. + // + // Because there may be ops that we can't optimize across (calls/etc) we + // handle flushing buckets on demand. 
+ std::map<StringRef, SmallVector<Operation *>> buckets; + bool didRemoveAny = false; + for (auto &op : block) { + if (auto loadOp = dyn_cast<IREE::Flow::VariableLoadOp>(op)) { + if (immutableVariables.contains(loadOp.variable())) continue; + buckets[loadOp.variable()].push_back(&op); + } else if (auto storeOp = dyn_cast<IREE::Flow::VariableStoreOp>(op)) { + buckets[storeOp.variable()].push_back(&op); + } else if (doesOpBlockMotion(&op)) { + // Split point - all accesses after this point must not assume anything + // about accesses before it. + didRemoveAny = optimizeBuckets(block, buckets) || didRemoveAny; + buckets.clear(); + } + } + didRemoveAny = optimizeBuckets(block, buckets) || didRemoveAny; + return didRemoveAny; +} + +namespace { + +class SimplifyVariableAccessesPass + : public SimplifyVariableAccessesBase<SimplifyVariableAccessesPass> { + public: + void runOnOperation() override { + auto funcOp = getOperation(); + if (funcOp.empty()) return; + + auto moduleOp = funcOp->getParentOfType<mlir::ModuleOp>(); + assert(moduleOp && "func not in a module"); + + // Build a set of all immutable variables for fast lookup. + auto immutableVariables = gatherImmutableVariables(moduleOp); + + // Hoist immutable variables first. These have no hazards and don't care + // about control flow - like `constant` - so getting them handled first + // avoids the need for us to do the full analysis. + hoistImmutableLoads(funcOp, immutableVariables); + + // We can't optimize the function if there are indirect loads/stores. + // Note that constant loads are still ok above. + for (auto &block : funcOp) { + for (auto &op : block) { + if (isa<IREE::Flow::VariableLoadIndirectOp>(op) || + isa<IREE::Flow::VariableStoreIndirectOp>(op)) { + LLVM_DEBUG(llvm::dbgs() + << "bailing on variable access simplification: indirect " + "accesses present in function\n"); + return; + } + } + } + + // For each block in the function hoist loads and sink stores. + // This does no cross-block movement, though it really should. Maybe when a + // real compiler engineer sees this they'll be inspired to do this properly. + for (auto &block : funcOp) { + LLVM_DEBUG(llvm::dbgs() << "==== REARRANGING BLOCK ACCESSES ====\n"); + while (rearrangeBlockVariableAccesses(block, immutableVariables)) { + // NOTE: block is processed until no more ops are removed. Will always + // end in a fixed amount of time as ops are only removed from the block. + } + } + } +}; + +} // namespace + +std::unique_ptr<OperationPass<FuncOp>> createSimplifyVariableAccessesPass() { + return std::make_unique<SimplifyVariableAccessesPass>(); +} + +} // namespace Flow +} // namespace IREE +} // namespace iree_compiler +} // namespace mlir
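The motion step is intentionally conservative: loads only bubble upward and stores downward within a block, and nothing crosses a call or yield point (doesOpBlockMotion). A sketch of the resulting shape, assuming a hypothetical mutable @v; the new lit test below (run via `iree-opt -iree-flow-simplify-variable-accesses`) covers the same cases:

```mlir
// After rearrangement: the load has been hoisted to the block top and the
// store sunk toward the bottom, but neither moves past the call, which
// flushes the per-variable buckets.
%0 = flow.variable.load @v : tensor<2xi32>
%1 = flow.dispatch @ex::@entry[%w](%0) : (tensor<2xi32>) -> tensor<2xi32>
flow.variable.store %1, @v : tensor<2xi32>
call @other_fn() : () -> ()
// Accesses below the call form a fresh bucket and cannot be combined with
// the ones above it.
%2 = flow.variable.load @v : tensor<2xi32>
```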
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/BUILD b/iree/compiler/Dialect/Flow/Transforms/test/BUILD index b8e2e4f..b16ee25 100644 --- a/iree/compiler/Dialect/Flow/Transforms/test/BUILD +++ b/iree/compiler/Dialect/Flow/Transforms/test/BUILD
@@ -38,6 +38,7 @@ "pad_tensor_to_tensor.mlir", "promote_i1_to_i8.mlir", "promote_tensor_loads.mlir", + "simplify_variable_accesses.mlir", "strip_and_splat_constant_variables.mlir", "transformation.mlir", ],
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt b/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt index 1d3a4d1..01a4294 100644 --- a/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt +++ b/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt
@@ -35,6 +35,7 @@ "pad_tensor_to_tensor.mlir" "promote_i1_to_i8.mlir" "promote_tensor_loads.mlir" + "simplify_variable_accesses.mlir" "strip_and_splat_constant_variables.mlir" "transformation.mlir" DATA
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir b/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir index 5e30bed..3a7999d 100644 --- a/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir +++ b/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir
@@ -910,7 +910,7 @@ // CHECK-SAME: {__internal_linalg_transform__ = "workgroup"} // CHECK-SAME: ins(%[[UPDATE_TILE]], %[[INDICES_TILE]] : tensor<?x?xf32>, tensor<?x1xi32>) // CHECK-SAME: outs(%[[ORIGINAL]] : tensor<?x?xf32>) -// CHECK: flow.dispatch.tensor.store %[[RESULT_TILE]], %[[ARG5]], offsets = [0, 0] +// CHECK: flow.dispatch.tensor.store %[[RESULT_TILE]], %[[ARG5]], offsets = [], sizes = [], strides = [] // CHECK: return %[[RESULT]] : tensor<?x?xf32> // ----- @@ -1007,3 +1007,33 @@ // CHECK: flow.return // CHECK: } // CHECK: return %[[RESULT]]#0, %[[RESULT]]#1 + +// ----- + +func @scatter_static(%arg0 : tensor<4xi32>, %arg1 : tensor<4x1xi32>, %arg2 : tensor<8xi32>) + -> tensor<8xi32>{ + %cst = constant dense<[0, 9, 0, 10, 11, 0, 0, 12]> : tensor<8xi32> + %cst_0 = constant dense<[9, 10, 11, 12]> : tensor<4xi32> + %cst_1 = constant dense<[[1], [3], [4], [7]]> : tensor<4x1xi32> + %cst_2 = constant dense<0> : tensor<8xi32> + %0 = linalg_ext.scatter + ins(%arg0, %arg1 : tensor<4xi32>, tensor<4x1xi32>) + outs(%arg2 : tensor<8xi32>) { + ^bb0(%arg3: i32, %arg4: i32): // no predecessors + linalg_ext.yield %arg3 : i32 + } -> tensor<8xi32> + return %0 : tensor<8xi32> +} +// CHECK: func @scatter_static +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<4xi32> +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<4x1xi32> +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<8xi32> +// CHECK: %[[RESULT:.+]] = flow.dispatch.workgroups +// CHECK-NEXT: %[[ARG3:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor<readonly:4xi32> +// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor<readonly:4x1xi32> +// CHECK-SAME: %[[ARG5:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor<readwrite:8xi32> +// CHECK: scf.for %[[IV:.+]] = %{{.+}} to %{{.+}} step %{{.+}} { +// CHECK: %[[SCATTER_TILE:.+]] = linalg_ext.scatter +// CHECK: flow.dispatch.tensor.store %[[SCATTER_TILE]], %[[ARG5]], offsets = [], sizes = [], strides = [] +// CHECK-NEXT: } +// CHECK: return %[[RESULT]]
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/simplify_variable_accesses.mlir b/iree/compiler/Dialect/Flow/Transforms/test/simplify_variable_accesses.mlir new file mode 100644 index 0000000..927d193 --- /dev/null +++ b/iree/compiler/Dialect/Flow/Transforms/test/simplify_variable_accesses.mlir
@@ -0,0 +1,136 @@ +// RUN: iree-opt -split-input-file -iree-flow-simplify-variable-accesses %s | IreeFileCheck %s + +flow.variable @varA dense<1> : tensor<2xi32> attributes {sym_visibility = "private"} +flow.variable @varB dense<3> : tensor<2x4xi32> attributes {sym_visibility = "private"} + +// CHECK-LABEL: @constants() +func @constants() { + // CHECK-NEXT: %[[VAR_A:.+]] = flow.variable.load @varA : tensor<2xi32> + // CHECK-NEXT: %[[VAR_B:.+]] = flow.variable.load @varB : tensor<2x4xi32> + // CHECK-NEXT: constant 10 + %w = constant 10 : index + %varA = flow.variable.load @varA : tensor<2xi32> + // CHECK-NEXT: %[[T:.+]] = flow.dispatch @ex::@dispatch0{{.+}}(%[[VAR_A]]) + %d0 = flow.dispatch @ex::@dispatch0[%w](%varA) : (tensor<2xi32>) -> tensor<2xi32> + %varB = flow.variable.load @varB : tensor<2x4xi32> + // CHECK-NEXT: flow.dispatch @ex::@dispatch1{{.+}}(%[[T]], %[[VAR_B]]) + %d1 = flow.dispatch @ex::@dispatch1[%w](%d0, %varB) : (tensor<2xi32>, tensor<2x4xi32>) -> tensor<2xi32> + return +} + +// ----- + +flow.variable @varA 1 : i32 attributes {sym_visibility = "private"} +flow.variable @varB 2 : i32 attributes {sym_visibility = "private"} + +// CHECK-LABEL: @constants_in_cfg +func @constants_in_cfg(%start: i32, %bound: i32) -> i32 { + // CHECK-NEXT: %[[VAR_A:.+]] = flow.variable.load @varA : i32 + // CHECK-NEXT: %[[VAR_B:.+]] = flow.variable.load @varB : i32 + // CHECK-NEXT: br ^bb1 + br ^bb1(%start : i32) +// CHECK: ^bb1(%[[BB1_ARG:.+]]: i32): +^bb1(%2: i32): + %cmp = cmpi slt, %2, %bound : i32 + cond_br %cmp, ^bb2(%2 : i32), ^bb3(%2 : i32) +// CHECK: ^bb2(%[[BB2_ARG:.+]]: i32): +^bb2(%5: i32): + %6 = flow.variable.load @varA : i32 + // CHECK-NEXT: = addi %[[BB2_ARG]], %[[VAR_A]] : i32 + %7 = addi %5, %6 : i32 + br ^bb1(%7 : i32) +// CHECK: ^bb3(%[[BB3_ARG:.+]]: i32): +^bb3(%8: i32): + %9 = flow.variable.load @varA : i32 + // CHECK-NEXT: %[[T0:.+]] = muli %[[BB3_ARG]], %[[VAR_A]] : i32 + %10 = muli %8, %9 : i32 + %11 = flow.variable.load @varB : i32 + // CHECK-NEXT: %[[T1:.+]] = subi %[[T0]], %[[VAR_B]] + %12 = subi %10, %11 : i32 + // CHECK-NEXT: return %[[T1]] + return %12 : i32 +} + +// ----- + +flow.variable @varA mutable dense<1> : tensor<2xi32> attributes {sym_visibility = "private"} +flow.variable @varB dense<3> : tensor<2x4xi32> attributes {sym_visibility = "private"} + +// CHECK-LABEL: @mixed_mutability +func @mixed_mutability() { + // CHECK-DAG: %[[VAR_A:.+]] = flow.variable.load @varA : tensor<2xi32> + // CHECK-DAG: %[[VAR_B:.+]] = flow.variable.load @varB : tensor<2x4xi32> + // CHECK-NEXT: constant 10 + %w = constant 10 : index + %varA = flow.variable.load @varA : tensor<2xi32> + // CHECK-NEXT: %[[T0:.+]] = flow.dispatch @ex::@dispatch0{{.+}}(%[[VAR_A]]) + %d0 = flow.dispatch @ex::@dispatch0[%w](%varA) : (tensor<2xi32>) -> tensor<2xi32> + %varB = flow.variable.load @varB : tensor<2x4xi32> + // CHECK-NEXT: %[[T1:.+]] = flow.dispatch @ex::@dispatch1{{.+}}(%[[T0]], %[[VAR_B]]) + %d1 = flow.dispatch @ex::@dispatch1[%w](%d0, %varB) : (tensor<2xi32>, tensor<2x4xi32>) -> tensor<2xi32> + // CHECK-NEXT: flow.variable.store %[[T1]], @varA : tensor<2xi32> + flow.variable.store %d1, @varA : tensor<2xi32> + return +} + +// ----- + +flow.variable @varA mutable dense<1> : tensor<2xi32> attributes {sym_visibility = "private"} + +// CHECK-LABEL: @raw +func @raw() { + // CHECK: %[[T:.+]] = flow.variable.load @varA {id = 0 + %varA_0 = flow.variable.load @varA {id = 0} : tensor<2xi32> + flow.variable.store %varA_0, @varA {id = 0} : tensor<2xi32> + %varA_1 = flow.variable.load @varA {id = 1} : 
tensor<2xi32> + // CHECK-NEXT: flow.variable.store %[[T]], @varA {id = 1 + flow.variable.store %varA_1, @varA {id = 1} : tensor<2xi32> + return +} + +// ----- + +flow.variable @varA mutable dense<1> : tensor<2xi32> attributes {sym_visibility = "private"} + +// CHECK-LABEL: @rar +func @rar() -> (tensor<2xi32>, tensor<2xi32>) { + // CHECK: %[[T:.+]] = flow.variable.load @varA {id = 0 + %varA_0 = flow.variable.load @varA {id = 0} : tensor<2xi32> + %varA_1 = flow.variable.load @varA {id = 1} : tensor<2xi32> + // CHECK-NEXT: return %[[T]], %[[T]] + return %varA_0, %varA_1 : tensor<2xi32>, tensor<2xi32> +} + +// ----- + +flow.variable @varA mutable dense<1> : tensor<2xi32> attributes {sym_visibility = "private"} + +// CHECK-LABEL: @waw +// CHECK-SAME: (%[[ARG0:.+]]: tensor<2xi32>, %[[ARG1:.+]]: tensor<2xi32>) +func @waw(%varA_0: tensor<2xi32>, %varA_1: tensor<2xi32>) { + flow.variable.store %varA_0, @varA : tensor<2xi32> + // CHECK-NEXT: flow.variable.store %[[ARG1]], @varA + flow.variable.store %varA_1, @varA : tensor<2xi32> + return +} + +// ----- + +flow.variable @varA mutable dense<1> : tensor<2xi32> attributes {sym_visibility = "private"} + +// CHECK-LABEL: @side_effects( +func @side_effects() { + // CHECK-NEXT: %[[T0:.+]] = flow.variable.load @varA + %varA_0 = flow.variable.load @varA : tensor<2xi32> + // CHECK-NEXT: flow.variable.store %[[T0]], @varA + flow.variable.store %varA_0, @varA : tensor<2xi32> + // CHECK-NEXT: call @other_fn() + call @other_fn() : () -> () + // CHECK-NEXT: %[[T1:.+]] = flow.variable.load @varA + %varA_1 = flow.variable.load @varA : tensor<2xi32> + // CHECK-NEXT: flow.variable.store %[[T1]], @varA + flow.variable.store %varA_1, @varA : tensor<2xi32> + return +} + +func private @other_fn()
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/BUILD b/iree/compiler/Dialect/HAL/Target/LLVM/BUILD index 01c1c13..3359b98 100644 --- a/iree/compiler/Dialect/HAL/Target/LLVM/BUILD +++ b/iree/compiler/Dialect/HAL/Target/LLVM/BUILD
@@ -41,12 +41,14 @@ "//iree/compiler/Codegen/LLVMCPU", "//iree/compiler/Codegen/Utils", "//iree/compiler/Dialect/HAL/Target", + "//iree/compiler/Dialect/HAL/Target/LLVM/librt", "//iree/compiler/Utils", "//iree/schemas:dylib_executable_def_c_fbs", "@llvm-project//llvm:AArch64AsmParser", "@llvm-project//llvm:AArch64CodeGen", "@llvm-project//llvm:ARMAsmParser", "@llvm-project//llvm:ARMCodeGen", + "@llvm-project//llvm:BitReader", "@llvm-project//llvm:BitWriter", "@llvm-project//llvm:Core", "@llvm-project//llvm:RISCVAsmParser",
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/CMakeLists.txt b/iree/compiler/Dialect/HAL/Target/LLVM/CMakeLists.txt index f12301b..29c9514 100644 --- a/iree/compiler/Dialect/HAL/Target/LLVM/CMakeLists.txt +++ b/iree/compiler/Dialect/HAL/Target/LLVM/CMakeLists.txt
@@ -32,6 +32,7 @@ LLVMAArch64CodeGen LLVMARMAsmParser LLVMARMCodeGen + LLVMBitReader LLVMBitWriter LLVMCore LLVMRISCVAsmParser @@ -50,6 +51,7 @@ iree::compiler::Codegen::PassHeaders iree::compiler::Codegen::Utils iree::compiler::Dialect::HAL::Target + iree::compiler::Dialect::HAL::Target::LLVM::librt iree::compiler::Utils iree::schemas::dylib_executable_def_c_fbs PUBLIC
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/LLVMAOTTarget.cpp b/iree/compiler/Dialect/HAL/Target/LLVM/LLVMAOTTarget.cpp index f80e47c..fdefab2 100644 --- a/iree/compiler/Dialect/HAL/Target/LLVM/LLVMAOTTarget.cpp +++ b/iree/compiler/Dialect/HAL/Target/LLVM/LLVMAOTTarget.cpp
@@ -13,9 +13,11 @@ #include "iree/compiler/Dialect/HAL/Target/LLVM/LibraryBuilder.h" #include "iree/compiler/Dialect/HAL/Target/LLVM/LinkerTool.h" #include "iree/compiler/Dialect/HAL/Target/LLVM/StaticLibraryGenerator.h" +#include "iree/compiler/Dialect/HAL/Target/LLVM/librt/librt.h" #include "iree/compiler/Dialect/HAL/Target/TargetRegistry.h" #include "iree/compiler/Utils/FlatbufferUtils.h" #include "iree/schemas/dylib_executable_def_builder.h" +#include "llvm/Bitcode/BitcodeReader.h" #include "llvm/Bitcode/BitcodeWriter.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" @@ -30,11 +32,10 @@ namespace IREE { namespace HAL { -namespace { +static constexpr char kQueryFunctionName[] = + "iree_hal_executable_library_query"; -constexpr char kQueryFunctionName[] = "iree_hal_executable_library_query"; - -llvm::Optional<FileLineColLoc> findFirstFileLoc(Location baseLoc) { +static llvm::Optional<FileLineColLoc> findFirstFileLoc(Location baseLoc) { if (auto loc = baseLoc.dyn_cast<FusedLoc>()) { for (auto &childLoc : loc.getLocations()) { auto childResult = findFirstFileLoc(childLoc); @@ -46,7 +47,7 @@ return llvm::None; } -std::string guessModuleName(mlir::ModuleOp moduleOp) { +static std::string guessModuleName(mlir::ModuleOp moduleOp) { std::string moduleName = moduleOp.getName().hasValue() ? moduleOp.getName().getValue().str() : ""; if (!moduleName.empty()) return moduleName; @@ -58,8 +59,6 @@ } } -} // namespace - class LLVMAOTTargetBackend final : public TargetBackend { public: explicit LLVMAOTTargetBackend(LLVMTargetOptions options) @@ -284,8 +283,10 @@ << options_.targetTriple << "'"; } - // Emit object files. - SmallVector<Artifact, 4> objectFiles; + SmallVector<Artifact> objectFiles; + + // Emit the base object file containing the bulk of our code. + // This must come first such that we have the proper library linking order. { // NOTE: today we just use a single object file, however if we wanted to // scale code generation and linking we'd want to generate one per @@ -306,6 +307,16 @@ objectFiles.push_back(std::move(objectFile)); } + // Optionally append additional object files that provide functionality that + // may otherwise have been runtime-dynamic (like libc/libm calls). + // For now we only do this for embedded uses. + if (options_.linkEmbedded) { + if (failed(buildLibraryObjects(variantOp.getLoc(), targetMachine.get(), + objectFiles, context))) { + return variantOp.emitError() << "failed generating library objects"; + } + } + // If we are keeping artifacts then let's also add the bitcode for easier // debugging (vs just the binary object file). if (options_.keepLinkerArtifacts) { @@ -329,7 +340,6 @@ // Copy the static object file to the specified output along with // generated header file. const std::string &libraryPath = options_.staticLibraryOutput; - const auto library_name = objectFiles[0].path; if (!outputStaticLibrary(libraryName, queryFunctionName, libraryPath, objectFiles[0].path)) { return variantOp.emitError() << "static library generation failed"; @@ -506,6 +516,76 @@ return IREE::HAL::ExecutableTargetAttr::get(context, "llvm", format); } + static void overridePlatformGlobal(llvm::Module &module, StringRef globalName, + uint32_t newValue) { + // NOTE: the global will not be defined if it is not used in the module. 
+ auto *globalValue = module.getNamedGlobal(globalName); + if (!globalValue) return; + globalValue->setLinkage(llvm::GlobalValue::PrivateLinkage); + globalValue->setDSOLocal(true); + globalValue->setConstant(true); + globalValue->setInitializer(llvm::ConstantInt::get( + globalValue->getValueType(), APInt(32, newValue))); + } + + // Builds an object file for the librt embedded runtime library. + // This is done per link operation so that we can match the precise target + // configuration. Since we (mostly) link once per user-level compilation + // this is fine today. If in the future we invoke the compiler for thousands + // of modules we'd want to (carefully) cache this. + LogicalResult buildLibraryObjects(Location loc, + llvm::TargetMachine *targetMachine, + SmallVector<Artifact> &objectFiles, + llvm::LLVMContext &context) { + assert(!objectFiles.empty() && "libraries must come after the base object"); + + // Load the generic bitcode file contents. + llvm::MemoryBufferRef bitcodeBufferRef( + llvm::StringRef(iree_compiler_librt_create()->data, + iree_compiler_librt_create()->size), + "librt.bc"); + auto bitcodeModuleValue = llvm::parseBitcodeFile(bitcodeBufferRef, context); + if (!bitcodeModuleValue) { + return mlir::emitError(loc) + << "failed to parse librt bitcode: " + << llvm::toString(bitcodeModuleValue.takeError()); + } + auto bitcodeModule = std::move(bitcodeModuleValue.get()); + bitcodeModule->setDataLayout(targetMachine->createDataLayout()); + bitcodeModule->setTargetTriple(targetMachine->getTargetTriple().str()); + + // Inject target-specific flags. + // TODO(benvanik): move this entire function to another file that can do + // more complex logic cleanly. This is just an example. + overridePlatformGlobal(*bitcodeModule, "librt_platform_example_flag", 0u); + + // Run the LLVM passes to optimize it for the current target. + if (failed(runLLVMIRPasses(options_, targetMachine, bitcodeModule.get()))) { + return mlir::emitError(loc) + << "failed to run librt LLVM-IR opt passes targeting '" + << options_.targetTriple << "'"; + } + + // Emit an object file we can pass to the linker. + std::string objectData; + if (failed(runEmitObjFilePasses(targetMachine, bitcodeModule.get(), + &objectData))) { + return mlir::emitError(loc) + << "failed to compile librt LLVM-IR module to an object file"; + } + + // Write the object file to disk with a similar name to the base file. + auto objectFile = + Artifact::createVariant(objectFiles.front().path, ".librt.o"); + auto &os = objectFile.outputFile->os(); + os << objectData; + os.flush(); + os.close(); + objectFiles.push_back(std::move(objectFile)); + + return success(); + } + LLVMTargetOptions options_; };
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/internal/EmbeddedLinkerTool.cpp b/iree/compiler/Dialect/HAL/Target/LLVM/internal/EmbeddedLinkerTool.cpp index f2847c2..cda219e 100644 --- a/iree/compiler/Dialect/HAL/Target/LLVM/internal/EmbeddedLinkerTool.cpp +++ b/iree/compiler/Dialect/HAL/Target/LLVM/internal/EmbeddedLinkerTool.cpp
@@ -80,7 +80,7 @@ // Create the shared object name; if we have at least one input object we // can just derive it from the first one's path. - if (objectFiles.size() == 1) { + if (!objectFiles.empty()) { artifacts.libraryFile = Artifact::createVariant(objectFiles.front().path, "so"); } else {
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/librt/BUILD b/iree/compiler/Dialect/HAL/Target/LLVM/librt/BUILD new file mode 100644 index 0000000..f9f6feb --- /dev/null +++ b/iree/compiler/Dialect/HAL/Target/LLVM/librt/BUILD
@@ -0,0 +1,31 @@ +# Copyright 2021 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +load("//build_tools/embed_data:build_defs.bzl", "c_embed_data") +load("//iree:build_defs.oss.bzl", "iree_cmake_extra_content") + +package( + default_visibility = ["//visibility:public"], + features = ["layering_check"], + licenses = ["notice"], # Apache 2.0 +) + +iree_cmake_extra_content( + content = """ +if(NOT "${IREE_TARGET_BACKEND_DYLIB-LLVM-AOT}" AND NOT "${IREE_TARGET_BACKEND_WASM-LLVM-AOT}") + return() +endif() +""", +) + +c_embed_data( + name = "librt", + srcs = ["bin/librt.bc"], + c_file_output = "librt.c", + flatten = True, + h_file_output = "librt.h", + identifier = "iree_compiler_librt", +)
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/librt/CMakeLists.txt b/iree/compiler/Dialect/HAL/Target/LLVM/librt/CMakeLists.txt new file mode 100644 index 0000000..c6dabe0 --- /dev/null +++ b/iree/compiler/Dialect/HAL/Target/LLVM/librt/CMakeLists.txt
@@ -0,0 +1,32 @@ +################################################################################ +# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from # +# iree/compiler/Dialect/HAL/Target/LLVM/librt/BUILD # +# # +# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary # +# CMake-only content. # +# # +# To disable autogeneration for this file entirely, delete this header. # +################################################################################ + +if(NOT "${IREE_TARGET_BACKEND_DYLIB-LLVM-AOT}" AND NOT "${IREE_TARGET_BACKEND_WASM-LLVM-AOT}") + return() +endif() + +iree_add_all_subdirs() + +iree_c_embed_data( + NAME + librt + SRCS + "bin/librt.bc" + C_FILE_OUTPUT + "librt.c" + H_FILE_OUTPUT + "librt.h" + IDENTIFIER + "iree_compiler_librt" + FLATTEN + PUBLIC +) + +### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/librt/bin/librt.bc b/iree/compiler/Dialect/HAL/Target/LLVM/librt/bin/librt.bc new file mode 100644 index 0000000..a484f31 --- /dev/null +++ b/iree/compiler/Dialect/HAL/Target/LLVM/librt/bin/librt.bc Binary files differ
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/librt/build.sh b/iree/compiler/Dialect/HAL/Target/LLVM/librt/build.sh new file mode 100644 index 0000000..c671e4d --- /dev/null +++ b/iree/compiler/Dialect/HAL/Target/LLVM/librt/build.sh
@@ -0,0 +1,46 @@ +#!/bin/bash +# Copyright 2021 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +set -e + +SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )" +OUT="${SCRIPT_DIR}/bin" +SRC="${SCRIPT_DIR}/src" +LL_FILE="${OUT}/librt.ll" +BC_FILE="${OUT}/librt.bc" + +# Generate an LLVM IR assembly listing so we can easily read the file. +# This is not checked in or used by the compiler. +clang \ + -target wasm32 \ + -std=c17 \ + -O2 \ + -Xclang -disable-llvm-passes \ + -fno-ident \ + -fvisibility=hidden \ + -nostdinc \ + -g0 \ + -S \ + -emit-llvm \ + -fno-verbose-asm \ + -fdiscard-value-names \ + -o "${LL_FILE}" \ + -c \ + "${SRC}/libm.c" + +# Clang adds a bunch of bad attributes and host-specific information that we +# don't want (so we get at least somewhat deterministic builds). +sed -i 's/^;.*$//' ${LL_FILE} +sed -i 's/^source_filename.*$//' ${LL_FILE} +sed -i 's/^target datalayout.*$//' ${LL_FILE} +sed -i 's/^target triple.*$//' ${LL_FILE} +sed -i 's/^\(attributes #[0-9]* = {\).*$/\1 inlinehint }/' ${LL_FILE} + +# Generate a binary bitcode file embedded into the compiler binary. +# NOTE: we do this from stdin so that the filename on the user's system is not +# embedded in the bitcode file (making it non-deterministic). +cat ${LL_FILE} | llvm-as -o=${BC_FILE}
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/librt/src/libm.c b/iree/compiler/Dialect/HAL/Target/LLVM/librt/src/libm.c new file mode 100644 index 0000000..bdbe0ae --- /dev/null +++ b/iree/compiler/Dialect/HAL/Target/LLVM/librt/src/libm.c
@@ -0,0 +1,13 @@ +// Copyright 2021 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "libm.h" + +// https://en.cppreference.com/w/c/numeric/math/fma +LIBRT_EXPORT float fmaf(float x, float y, float z) { + // TODO(*): a real implementation :) + return (x * y) + z; +}
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/librt/src/libm.h b/iree/compiler/Dialect/HAL/Target/LLVM/librt/src/libm.h new file mode 100644 index 0000000..10dfe76 --- /dev/null +++ b/iree/compiler/Dialect/HAL/Target/LLVM/librt/src/libm.h
@@ -0,0 +1,15 @@ +// Copyright 2021 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_COMPILER_DIALECT_HAL_TARGET_LLVM_LIBRT_SRC_LIBM_H_ +#define IREE_COMPILER_DIALECT_HAL_TARGET_LLVM_LIBRT_SRC_LIBM_H_ + +#include "librt.h" + +// https://en.cppreference.com/w/c/numeric/math/fma +LIBRT_EXPORT float fmaf(float x, float y, float z); + +#endif // IREE_COMPILER_DIALECT_HAL_TARGET_LLVM_LIBRT_SRC_LIBM_H_
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/librt/src/librt.h b/iree/compiler/Dialect/HAL/Target/LLVM/librt/src/librt.h new file mode 100644 index 0000000..0af8b21 --- /dev/null +++ b/iree/compiler/Dialect/HAL/Target/LLVM/librt/src/librt.h
@@ -0,0 +1,75 @@ +// Copyright 2021 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +//===----------------------------------------------------------------------===// +// A simplified libc/libm-alike that is designed to compile to portable LLVM IR. +//===----------------------------------------------------------------------===// +// This library is focused on supporting the subset of LLVM's RuntimeLibcalls +// that we need in our embedded executable binaries. This means that things like +// printf, malloc, etc. are excluded. +// +// See the full list of possible functions here: +// third_party/llvm-project/llvm/include/llvm/IR/RuntimeLibcalls.def +// +// Code here must not use any system headers - as almost all pull in bits/ and +// various other target-dependent definitions that make the resulting IR +// non-portable. This means there is no size_t, etc. Any definitions that may +// come from an std* file must be redefined here with care. +// +// Code must also not use any mutable global or thread-local state (errno, +// rounding modes, etc). Each of the functions in the library will be called +// concurrently from multiple threads and from multiple source modules. There +// must be no mutable static values anywhere. +// +// Avoid #ifdef entirely: they indicate a leakage of host build configuration +// into what is supposed to be a portable module. Anything that requires +// target-specific conditional logic must be implemented via an extern that +// can be substituted by the IREE compiler when producing the final +// target-specific module. + +//===----------------------------------------------------------------------===// +// Attributes and metadata +//===----------------------------------------------------------------------===// + +// Tagged on functions that are part of the public API. +#define LIBRT_EXPORT __attribute__((visibility("hidden"))) + +//===----------------------------------------------------------------------===// +// stdint.h +//===----------------------------------------------------------------------===// +// https://pubs.opengroup.org/onlinepubs/009604599/basedefs/stdint.h.html +// NOTE: no size_t/ptrdiff_t/etc (as they are target dependent). + +typedef signed char int8_t; +typedef short int16_t; +typedef int int32_t; +typedef long long int64_t; +typedef unsigned char uint8_t; +typedef unsigned short uint16_t; +typedef unsigned int uint32_t; +typedef unsigned long long uint64_t; + +#define INT8_MIN (-127 - 1) +#define INT16_MIN (-32767 - 1) +#define INT32_MIN (-2147483647 - 1) +#define INT64_MIN (-9223372036854775807LL - 1) +#define INT8_MAX 127 +#define INT16_MAX 32767 +#define INT32_MAX 2147483647 +#define INT64_MAX 9223372036854775807LL +#define UINT8_MAX 0xff +#define UINT16_MAX 0xffff +#define UINT32_MAX 0xffffffffu +#define UINT64_MAX 0xffffffffffffffffull + +//===----------------------------------------------------------------------===// +// Target-specific queries +//===----------------------------------------------------------------------===// +// These are substituted with values from the compiler and must not be specified +// here in C before we generate the IR. + +// Do not use: here as an example. Remove once we have any other flag. +extern int librt_platform_example_flag;
diff --git a/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp b/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp index c7d61c8..3e47a35 100644 --- a/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp +++ b/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
@@ -203,11 +203,11 @@ return {Range{zero, ub, one}}; } -Operation *ScatterOp::getTiledImplementation( - OpBuilder &builder, ValueRange outputs, ArrayRef<OpFoldResult> offsets, - ArrayRef<OpFoldResult> sizes, - SmallVectorImpl<SmallVector<OpFoldResult, 4>> &resultOffsets, - SmallVectorImpl<SmallVector<OpFoldResult, 4>> &resultSizes) { +Operation *ScatterOp::getTiledImplementation(OpBuilder &builder, + ValueRange outputs, + ArrayRef<OpFoldResult> offsets, + ArrayRef<OpFoldResult> sizes, + SmallVectorImpl<Value> &results) { assert(outputs.size() == 1 && offsets.size() == 1 && sizes.size() == 1); Location loc = getLoc(); auto zeroAttr = builder.getI64IntegerAttr(0); @@ -241,20 +241,18 @@ indicesSizes, indicesStrides); assert(tiledIndices && "failed to get slice of indices"); - resultOffsets.resize(1); - resultOffsets[0].resize(updateRank, zeroAttr); - resultSizes.resize(1); - resultSizes[0].resize(updateRank); - for (auto dim : llvm::seq<int64_t>(0, updateRank)) { - resultSizes[0][dim] = getDim(builder, loc, original(), dim); - } SmallVector<Type> resultTypes; if (getNumResults()) { resultTypes.push_back(getResultTypes()[0]); } - return cast<LinalgExtOp>(getOperation()) - .clone(builder, loc, resultTypes, - ValueRange{tiledUpdate, tiledIndices, outputs[0]}); + Operation *tiledScatterOp = + cast<LinalgExtOp>(getOperation()) + .clone(builder, loc, resultTypes, + ValueRange{tiledUpdate, tiledIndices, outputs[0]}); + if (getNumResults()) { + results.push_back(tiledScatterOp->getResult(0)); + } + return tiledScatterOp; } void ScatterOp::generateScalarUpdateLoopBody(OpBuilder &b, Location loc, @@ -466,11 +464,11 @@ return partitionableLoops; } -Operation *SortOp::getTiledImplementation( - OpBuilder &builder, ValueRange outputs, ArrayRef<OpFoldResult> offsets, - ArrayRef<OpFoldResult> sizes, - SmallVectorImpl<SmallVector<OpFoldResult, 4>> &resultOffsets, - SmallVectorImpl<SmallVector<OpFoldResult, 4>> &resultSizes) { +Operation *SortOp::getTiledImplementation(OpBuilder &builder, + ValueRange outputs, + ArrayRef<OpFoldResult> offsets, + ArrayRef<OpFoldResult> sizes, + SmallVectorImpl<Value> &results) { assert(outputs.size() == this->outputs().size()); int64_t rank = getOperandRank(); assert(offsets.size() == static_cast<size_t>(rank) && @@ -479,22 +477,24 @@ SmallVector<OpFoldResult> strides(rank, oneAttr); Location loc = getLoc(); SmallVector<Value> tiledOperands(outputs.size()); - resultOffsets.resize(outputs.size()); - resultSizes.resize(outputs.size()); for (auto en : llvm::enumerate(outputs)) { tiledOperands[en.index()] = getSlice(builder, getLoc(), en.value(), offsets, sizes, strides); assert(tiledOperands[en.index()] && "failed to get slice of operand"); - resultOffsets[en.index()].assign(offsets.begin(), offsets.end()); - resultSizes[en.index()].assign(sizes.begin(), sizes.end()); } SmallVector<Type, 4> resultTypes; if (getNumResults()) { resultTypes = llvm::to_vector<4>( llvm::map_range(tiledOperands, [&](Value v) { return v.getType(); })); } - return cast<LinalgExtOp>(getOperation()) - .clone(builder, loc, resultTypes, tiledOperands); + Operation *tiledSortOp = cast<LinalgExtOp>(getOperation()) + .clone(builder, loc, resultTypes, tiledOperands); + for (auto result : llvm::enumerate(tiledSortOp->getResults())) { + auto insertSliceOp = builder.create<tensor::InsertSliceOp>( + loc, result.value(), outputs[result.index()], offsets, sizes, strides); + results.push_back(insertSliceOp.getResult()); + } + return tiledSortOp; } LogicalResult SortOp::generateScalarImplementation(OpBuilder &b, Location 
loc,
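Note that the two ops satisfy the new `results` contract differently, which is easy to miss in the hunk above: ScatterOp's tiled clone still consumes the whole `original` tensor (only `updates` and `indices` are sliced), so its result is already full-sized and is pushed into `results` unchanged, while SortOp tiles the outputs themselves and must first re-insert each tiled result at the tile offsets. The SortOp half of the contract, extracted from the hunk with comments added:

// For each result of the tiled sort, insert the computed tile back into the
// corresponding full-sized output at the current loop offsets; the values
// produced by tensor.insert_slice become the loop-carried results.
for (auto result : llvm::enumerate(tiledSortOp->getResults())) {
  auto insertSliceOp = builder.create<tensor::InsertSliceOp>(
      loc, result.value(), outputs[result.index()], offsets, sizes, strides);
  results.push_back(insertSliceOp.getResult());
}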
diff --git a/iree/compiler/Dialect/LinalgExt/IR/TiledOpInterface.td b/iree/compiler/Dialect/LinalgExt/IR/TiledOpInterface.td index 232c6aa..337b273 100644 --- a/iree/compiler/Dialect/LinalgExt/IR/TiledOpInterface.td +++ b/iree/compiler/Dialect/LinalgExt/IR/TiledOpInterface.td
@@ -85,6 +85,10 @@ /*desc=*/[{ Generates a tiled version of the operation given the tile size for the loops. + + Returns the tiled operation that was generated. If the operation has + tensor semantics, then the results of the tiled operation are to + be inserted into `outputs` and returned in `results`. }], /*retType=*/"Operation *", /*methodName=*/"getTiledImplementation", @@ -93,8 +97,7 @@ "ValueRange ":$outputs, "ArrayRef<OpFoldResult> ":$offsets, "ArrayRef<OpFoldResult> ":$sizes, - "SmallVectorImpl<SmallVector<OpFoldResult, 4>> &":$resultOffsets, - "SmallVectorImpl<SmallVector<OpFoldResult, 4>> &":$resultSizes), + "SmallVectorImpl<Value> &":$results), /*methodBody=*/"", /*defaultImplementation=*/[{ return nullptr;
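To make the revised contract concrete, here is a minimal sketch of an implementation for a hypothetical single-input, single-output LinalgExt op (`MyOp` and its `input()` accessor are illustrative, not part of this change; `getSlice` is the utility the ops above already use):

Operation *MyOp::getTiledImplementation(OpBuilder &builder, ValueRange outputs,
                                        ArrayRef<OpFoldResult> offsets,
                                        ArrayRef<OpFoldResult> sizes,
                                        SmallVectorImpl<Value> &results) {
  Location loc = getLoc();
  auto oneAttr = builder.getI64IntegerAttr(1);
  SmallVector<OpFoldResult> strides(offsets.size(), oneAttr);
  // Slice the input and the output down to the requested tile.
  Value inputTile = getSlice(builder, loc, input(), offsets, sizes, strides);
  Value outputTile = getSlice(builder, loc, outputs[0], offsets, sizes, strides);
  // Clone the op onto the tiles.
  SmallVector<Type> resultTypes;
  if (getNumResults()) resultTypes.push_back(outputTile.getType());
  Operation *tiledOp =
      cast<LinalgExtOp>(getOperation())
          .clone(builder, loc, resultTypes, ValueRange{inputTile, outputTile});
  // With tensor semantics, insert the tiled result back into `outputs` and
  // report the inserted value through `results`; the caller yields it as-is.
  if (getNumResults()) {
    results.push_back(builder
                          .create<tensor::InsertSliceOp>(
                              loc, tiledOp->getResult(0), outputs[0], offsets,
                              sizes, strides)
                          .getResult());
  }
  return tiledOp;
}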
diff --git a/iree/compiler/Dialect/LinalgExt/Transforms/Tiling.cpp b/iree/compiler/Dialect/LinalgExt/Transforms/Tiling.cpp index 829c306..bf6f1be 100644 --- a/iree/compiler/Dialect/LinalgExt/Transforms/Tiling.cpp +++ b/iree/compiler/Dialect/LinalgExt/Transforms/Tiling.cpp
@@ -132,36 +132,13 @@ // If this is the innermost loop, then generated the tiled implementation of // the op by invoking the TiledOpInterface methods. if (loopDepth == tileSizes.size()) { - SmallVector<SmallVector<OpFoldResult, 4>> resultOffsets; - SmallVector<SmallVector<OpFoldResult, 4>> resultSizes; - Operation *tiledOp = tilableOp.getTiledImplementation( - builder, outputs, offsets, tileSizes, resultOffsets, resultSizes); - if (!tiledOp) { + TiledOp ret; + ret.op = tilableOp.getTiledImplementation(builder, outputs, offsets, + tileSizes, ret.results); + if (!ret.op) { return static_cast<LogicalResult>( tilableOp.emitOpError("failed to get tiled implementation")); } - assert(tiledOp->getNumResults() == 0 || - (resultOffsets.size() == tiledOp->getNumResults())); - TiledOp ret; - ret.op = tiledOp; - - // If the operation has results, then the result of the tiled operation is - // to be inserted into the `initValues` and returned. - if (tiledOp->getNumResults()) { - SmallVector<Value> results; - auto oneAttr = builder.getI64IntegerAttr(1); - results.reserve(tiledOp->getNumResults()); - for (auto en : llvm::enumerate(tiledOp->getResults())) { - Value result = en.value(); - ArrayRef<OpFoldResult> offsets(resultOffsets[en.index()]); - ArrayRef<OpFoldResult> sizes(resultSizes[en.index()]); - SmallVector<OpFoldResult> strides(offsets.size(), oneAttr); - Value insert = builder.create<tensor::InsertSliceOp>( - loc, result, outputs[en.index()], offsets, sizes, strides); - results.push_back(insert); - } - std::swap(ret.results, results); - } return ret; } @@ -318,15 +295,13 @@ return loopBounds; } - Operation *getTiledImplementation( - Operation *op, OpBuilder &b, ValueRange outputs, - ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes, - SmallVectorImpl<SmallVector<OpFoldResult, 4>> &resultOffsets, - SmallVectorImpl<SmallVector<OpFoldResult, 4>> &resultSizes) const { - // Compute a subtensor of the source based on the offsets. + Operation *getTiledImplementation(Operation *op, OpBuilder &b, + ValueRange outputs, + ArrayRef<OpFoldResult> offsets, + ArrayRef<OpFoldResult> sizes, + SmallVectorImpl<Value> &results) const { auto insertOp = cast<tensor::InsertSliceOp>(op); - auto opOffsets = insertOp.getMixedOffsets(); - auto opSizes = insertOp.getMixedSizes(); + // Compute a subtensor of the source based on the offsets. auto opStrides = insertOp.getMixedStrides(); if (!llvm::all_of(opStrides, [&](OpFoldResult valueOrAttr) { return isValue(valueOrAttr, 1); @@ -334,23 +309,22 @@ op->emitOpError("unable to tile operation with non-unit stride"); return nullptr; } - // The operation returned is just a tensor.extract_slice of the source with - // the given offsets, sizes and strides. Setting the correct result offset - // will make the sure the tiling algorithm will insert this slice into the - // correct place in the destination. - // The result offset is just the offset passed in plus the offset specified - // in the op (since all strides are checked to be 1). + Location loc = insertOp.getLoc(); + auto oneAttr = b.getI64IntegerAttr(1); + SmallVector<OpFoldResult> strides(offsets.size(), oneAttr); + auto extractSliceOp = b.create<tensor::ExtractSliceOp>( + loc, insertOp.source(), offsets, sizes, strides); + + // The offsets for the insert are based on the op offsets plus the offsets of + // the loops passed in.
+ auto opOffsets = insertOp.getMixedOffsets(); + auto opSizes = insertOp.getMixedSizes(); unsigned offsetIndex = 0; ArrayRef<int64_t> sourceShape = insertOp.getSourceType().getShape(); int64_t destRank = insertOp.getType().getRank(); - resultOffsets.resize(1); - resultOffsets[0].resize(destRank); - resultSizes.resize(1); - resultSizes[0].resize(destRank); - Location loc = insertOp.getLoc(); + SmallVector<OpFoldResult> resultOffsets(destRank); + SmallVector<OpFoldResult> resultSizes(destRank); auto zeroAttr = b.getI64IntegerAttr(0); - auto oneAttr = b.getI64IntegerAttr(1); - SmallVector<OpFoldResult> strides(offsets.size(), oneAttr); for (auto opOffset : llvm::enumerate(opOffsets)) { // Check for rank-reducing by checking that // 1) The corresponding opSize value is 1 @@ -358,29 +332,33 @@ // Then the opOffset is for the rank-reduced dimension. Skip. unsigned opOffsetIndex = opOffset.index(); if (isValue(opSizes[opOffsetIndex], 1) && sourceShape[offsetIndex] != 1) { - resultOffsets[0][opOffsetIndex] = zeroAttr; - resultSizes[0][opOffsetIndex] = oneAttr; + resultOffsets[opOffsetIndex] = zeroAttr; + resultSizes[opOffsetIndex] = oneAttr; continue; } OpFoldResult opOffsetVal = opOffset.value(); OpFoldResult offset = offsets[offsetIndex]; if (opOffsetVal.is<Attribute>() && offset.is<Attribute>()) { - resultOffsets[0][opOffsetIndex] = b.getI64IntegerAttr( + resultOffsets[opOffsetIndex] = b.getI64IntegerAttr( *getConstantValue(opOffsetVal) + *getConstantValue(offset)); } else { AffineMap map = AffineMap::get( 1, 1, {b.getAffineDimExpr(0) + b.getAffineSymbolExpr(0)}); - resultOffsets[0][opOffsetIndex] = + resultOffsets[opOffsetIndex] = b.create<AffineApplyOp>(loc, map, ValueRange{getValue(b, loc, offset), getValue(b, loc, opOffsetVal)}) .getResult(); } - resultSizes[0][opOffsetIndex] = sizes[offsetIndex]; + resultSizes[opOffsetIndex] = sizes[offsetIndex]; offsetIndex++; } - return b.create<tensor::ExtractSliceOp>(loc, insertOp.source(), offsets, - sizes, strides); + SmallVector<OpFoldResult> resultStrides(destRank, oneAttr); + auto tiledInsertOp = b.create<tensor::InsertSliceOp>( + loc, extractSliceOp.result(), outputs[0], resultOffsets, resultSizes, + resultStrides); + results.push_back(tiledInsertOp.result()); + return extractSliceOp; } }; } // namespace
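The static-vs-dynamic offset arithmetic above is the crux of the InsertSliceOp tiling: loop offsets and op offsets are added per dimension, staying as attributes when both are constant so later canonicalization can fold the extract/insert pair. A hedged sketch of that rule as a standalone helper, reusing the file's existing `getValue`/`getConstantValue` utilities (the helper itself is not part of this change):

// Add two OpFoldResult offsets: constant-fold when both are attributes,
// otherwise materialize `d0 + s0` with an affine.apply.
static OpFoldResult addOffsets(OpBuilder &b, Location loc, OpFoldResult lhs,
                               OpFoldResult rhs) {
  if (lhs.is<Attribute>() && rhs.is<Attribute>()) {
    return b.getI64IntegerAttr(*getConstantValue(lhs) +
                               *getConstantValue(rhs));
  }
  AffineMap map =
      AffineMap::get(1, 1, {b.getAffineDimExpr(0) + b.getAffineSymbolExpr(0)});
  return b
      .create<AffineApplyOp>(
          loc, map, ValueRange{getValue(b, loc, lhs), getValue(b, loc, rhs)})
      .getResult();
}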
diff --git a/iree/compiler/Dialect/LinalgExt/Transforms/test/tiling.mlir b/iree/compiler/Dialect/LinalgExt/Transforms/test/tiling.mlir index 5818c01..41a1cd0 100644 --- a/iree/compiler/Dialect/LinalgExt/Transforms/test/tiling.mlir +++ b/iree/compiler/Dialect/LinalgExt/Transforms/test/tiling.mlir
@@ -30,15 +30,11 @@ // CHECK-SAME: [%[[USED_TILESIZE]], %[[D1]]] // CHECK: %[[INDEX_SLICE:.+]] = tensor.extract_slice %[[INDICES]][%[[IV]], 0] // CHECK-SAME: [%[[USED_TILESIZE]], 1] -// CHECK-DAG: %[[SLICE_D0:.+]] = tensor.dim %[[ORIGINAL]], %[[C0]] -// CHECK-DAG: %[[SLICE_D1:.+]] = tensor.dim %[[ORIGINAL]], %[[C1]] // CHECK: %[[SCATTER_TILE:.+]] = linalg_ext.scatter // CHECK-SAME: __internal_linalg_transform__ = "tiling_output" // CHECK-SAME: ins(%[[UPDATE_SLICE]], %[[INDEX_SLICE]] // CHECK-SAME: outs(%[[INIT]] -// CHECK: %[[YIELD:.+]] = tensor.insert_slice %[[SCATTER_TILE]] into %[[INIT]][0, 0] -// CHECK-SAME: [%[[SLICE_D0]], %[[SLICE_D1]]] -// CHECK: scf.yield %[[YIELD]] +// CHECK: scf.yield %[[SCATTER_TILE]] // CHECK: return %[[RESULT]] // ----- @@ -114,15 +110,11 @@ // CHECK-SAME: [%[[USED_TILESIZE]], %[[D1]]] // CHECK: %[[INDEX_SLICE:.+]] = tensor.extract_slice %[[INDICES]][%[[IV]], 0] // CHECK-SAME: [%[[USED_TILESIZE]], 1] -// CHECK-DAG: %[[SLICE_D0:.+]] = tensor.dim %[[ORIGINAL]], %[[C0]] -// CHECK-DAG: %[[SLICE_D1:.+]] = tensor.dim %[[ORIGINAL]], %[[C1]] // CHECK: %[[SCATTER_TILE:.+]] = linalg_ext.scatter // CHECK-SAME: __internal_linalg_transform__ = "distribute_output" // CHECK-SAME: ins(%[[UPDATE_SLICE]], %[[INDEX_SLICE]] // CHECK-SAME: outs(%[[INIT]] -// CHECK: %[[YIELD:.+]] = tensor.insert_slice %[[SCATTER_TILE]] into %[[INIT]][0, 0] -// CHECK-SAME: [%[[SLICE_D0]], %[[SLICE_D1]]] -// CHECK: scf.yield %[[YIELD]] +// CHECK: scf.yield %[[SCATTER_TILE]] // CHECK: return %[[RESULT]] // ----- @@ -435,9 +427,9 @@ // CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: index // CHECK: %[[RESULT:.+]] = scf.for %[[IV0:[a-zA-Z0-9]+]] = // CHECK: %[[YIELD1:.+]] = scf.for %[[IV1:[a-zA-Z0-9]+]] = +// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], %[[IV1]]] // CHECK: %[[OFFSET0:.+]] = affine.apply #[[MAP2]](%[[IV0]])[%[[ARG2]]] // CHECK: %[[OFFSET1:.+]] = affine.apply #[[MAP2]](%[[IV1]])[%[[ARG3]]] -// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], %[[IV1]]] // CHECK: %[[UPDATE:.+]] = tensor.insert_slice %[[SLICE]] // CHECK-SAME: into %{{.+}}[%[[OFFSET0]], %[[OFFSET1]]] // CHECK: scf.yield %[[UPDATE]] @@ -466,9 +458,9 @@ // CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: index // CHECK: %[[RESULT:.+]] = scf.for %[[IV0:[a-zA-Z0-9]+]] = // CHECK: %[[YIELD1:.+]] = scf.for %[[IV1:[a-zA-Z0-9]+]] = +// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], %[[IV1]]] // CHECK: %[[OFFSET0:.+]] = affine.apply #[[MAP2]](%[[IV0]])[%[[ARG2]]] // CHECK: %[[OFFSET1:.+]] = affine.apply #[[MAP2]](%[[IV1]])[%[[ARG3]]] -// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], %[[IV1]]] // CHECK: %[[UPDATE:.+]] = tensor.insert_slice %[[SLICE]] // CHECK-SAME: into %{{.+}}[%[[OFFSET0]], 0, %[[OFFSET1]]] // CHECK: scf.yield %[[UPDATE]]
diff --git a/iree/compiler/InputConversion/MHLO/BUILD b/iree/compiler/InputConversion/MHLO/BUILD index c1dc464..aaf1475 100644 --- a/iree/compiler/InputConversion/MHLO/BUILD +++ b/iree/compiler/InputConversion/MHLO/BUILD
@@ -46,10 +46,10 @@ name = "MHLO", srcs = [ "BroadcastingToLinalgPatterns.cpp", - "ConvertAndDistributeMHLOToLinalgExt.cpp", "ConvertComplexToReal.cpp", "ConvertMHLOToFlow.cpp", "ConvertMHLOToFlow.h", + "ConvertMHLOToLinalgExt.cpp", "LegalizeInputTypes.cpp", "MHLOToLinalgOnTensors.cpp", "MHLOToMHLOPreprocessing.cpp",
diff --git a/iree/compiler/InputConversion/MHLO/CMakeLists.txt b/iree/compiler/InputConversion/MHLO/CMakeLists.txt index 36d2a4c..66276fc 100644 --- a/iree/compiler/InputConversion/MHLO/CMakeLists.txt +++ b/iree/compiler/InputConversion/MHLO/CMakeLists.txt
@@ -41,10 +41,10 @@ "Passes.h" SRCS "BroadcastingToLinalgPatterns.cpp" - "ConvertAndDistributeMHLOToLinalgExt.cpp" "ConvertComplexToReal.cpp" "ConvertMHLOToFlow.cpp" "ConvertMHLOToFlow.h" + "ConvertMHLOToLinalgExt.cpp" "LegalizeInputTypes.cpp" "MHLOToLinalgOnTensors.cpp" "MHLOToMHLOPreprocessing.cpp"
diff --git a/iree/compiler/InputConversion/MHLO/ConvertAndDistributeMHLOToLinalgExt.cpp b/iree/compiler/InputConversion/MHLO/ConvertMHLOToLinalgExt.cpp similarity index 68% rename from iree/compiler/InputConversion/MHLO/ConvertAndDistributeMHLOToLinalgExt.cpp rename to iree/compiler/InputConversion/MHLO/ConvertMHLOToLinalgExt.cpp index b24c76d..571b854 100644 --- a/iree/compiler/InputConversion/MHLO/ConvertAndDistributeMHLOToLinalgExt.cpp +++ b/iree/compiler/InputConversion/MHLO/ConvertMHLOToLinalgExt.cpp
@@ -47,51 +47,6 @@ namespace { //===----------------------------------------------------------------------===// -// Base classes. -//===----------------------------------------------------------------------===// - -template <typename Derived, typename OpTy> -struct ConvertToLinalgExtPattern : public OpConversionPattern<OpTy> { - using OpConversionPattern<OpTy>::OpConversionPattern; - - LogicalResult matchAndRewrite( - OpTy op, ArrayRef<Value> args, - ConversionPatternRewriter &rewriter) const final { - Value one = rewriter.create<ConstantIndexOp>(op.getLoc(), 1); - SmallVector<Value> workload(3, one); - - // Gather the dynamic dimensions for all operands. - SmallVector<Value> operandDynamicDims; - for (Value arg : args) { - if (auto rt = arg.getType().dyn_cast<RankedTensorType>()) { - for (unsigned i = 0; i < rt.getRank(); ++i) { - if (!rt.isDynamicDim(i)) continue; - auto dim = rewriter.createOrFold<tensor::DimOp>(op.getLoc(), arg, i); - operandDynamicDims.push_back(dim); - } - } - } - - auto dispatchOp = rewriter.create<IREE::Flow::DispatchWorkgroupsOp>( - op.getLoc(), workload, op->getResultTypes(), - /*result_dims=*/ValueRange{}, - /*operands=*/args, - /*operand_dims=*/operandDynamicDims, - /*tied_operands=*/Derived::getTiedResultOperandIndices(args)); - { - OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPointToStart(&dispatchOp.getRegion().front()); - if (failed(Derived::lowerMHLOOp(dispatchOp, op, args, rewriter))) { - return failure(); - } - rewriter.create<IREE::Flow::ReturnOp>(op.getLoc()); - } - rewriter.replaceOp(op, dispatchOp.getResults()); - return success(); - } -}; - -//===----------------------------------------------------------------------===// // Region operations lowering. //===----------------------------------------------------------------------===// @@ -131,32 +86,15 @@ // SortOp //===----------------------------------------------------------------------===// -struct SortOpConversion - : public ConvertToLinalgExtPattern<SortOpConversion, mhlo::SortOp> { - using ConvertToLinalgExtPattern<SortOpConversion, - mhlo::SortOp>::ConvertToLinalgExtPattern; +struct SortOpConversion : public OpConversionPattern<mhlo::SortOp> { + using OpConversionPattern<mhlo::SortOp>::OpConversionPattern; - static SmallVector<int64_t> getTiedResultOperandIndices( - ArrayRef<Value> args) { - return llvm::to_vector<4>(llvm::seq<int64_t>(0, args.size())); - } - - static LogicalResult lowerMHLOOp(IREE::Flow::DispatchWorkgroupsOp dispatchOp, - mhlo::SortOp op, ArrayRef<Value> args, - ConversionPatternRewriter &rewriter) { - auto blockArgs = dispatchOp.getClosureBodyRegion().getArguments(); - SmallVector<Value> initValues; - ImplicitLocOpBuilder b(op.getLoc(), rewriter); - for (auto it : llvm::zip(args, blockArgs)) { - auto argTy = std::get<0>(it).getType().cast<RankedTensorType>(); - auto blockArg = std::get<1>(it); - initValues.push_back( - b.create<IREE::Flow::DispatchTensorLoadOp>(argTy, blockArg)); - } - - auto sortOp = b.create<linalg_ext::SortOp>(op.getResultTypes(), - /*inputs=*/ValueRange{}, - initValues, op.dimensionAttr()); + LogicalResult matchAndRewrite( + mhlo::SortOp op, ArrayRef<Value> args, + ConversionPatternRewriter &rewriter) const final { + auto sortOp = rewriter.create<linalg_ext::SortOp>( + op.getLoc(), op.getResultTypes(), + /*inputs=*/ValueRange{}, args, op.dimensionAttr()); rewriter.inlineRegionBefore(op.comparator(), sortOp.region(), sortOp.region().begin()); Region ®ion = sortOp.region(); @@ -169,12 +107,7 @@ } rewriter.applySignatureConversion(®ion, 
signature_converter); - for (auto it : llvm::zip(sortOp.getResults(), blockArgs)) { - auto value = std::get<0>(it); - auto target = std::get<1>(it); - b.create<IREE::Flow::DispatchTensorStoreOp>(value, target); - } - + rewriter.replaceOp(op, sortOp->getResults()); return success(); } }; @@ -183,10 +116,8 @@ // ScatterOp //===----------------------------------------------------------------------===// -struct ScatterOpConversion - : public ConvertToLinalgExtPattern<ScatterOpConversion, mhlo::ScatterOp> { - using ConvertToLinalgExtPattern<ScatterOpConversion, - mhlo::ScatterOp>::ConvertToLinalgExtPattern; +struct ScatterOpConversion : public OpConversionPattern<mhlo::ScatterOp> { + using OpConversionPattern<mhlo::ScatterOp>::OpConversionPattern; /// Returns true if the `dimensionNumbers` from the mhlo.scatter op follows a /// canonical form: @@ -275,28 +206,23 @@ return success(); } - static LogicalResult lowerMHLOOp(IREE::Flow::DispatchWorkgroupsOp dispatchOp, - mhlo::ScatterOp op, ArrayRef<Value> args, - ConversionPatternRewriter &rewriter) { + LogicalResult matchAndRewrite( + mhlo::ScatterOp op, ArrayRef<Value> args, + ConversionPatternRewriter &rewriter) const final { if (!hasCanonicalDimensionNumbers(op)) return failure(); ImplicitLocOpBuilder b(op.getLoc(), rewriter); mhlo::ScatterOpAdaptor adaptor(args); - auto blockArgs = dispatchOp.getClosureBodyRegion().getArguments(); - Value original = b.create<IREE::Flow::DispatchTensorLoadOp>( - adaptor.operand().getType().cast<RankedTensorType>(), blockArgs[0]); - Value indices = b.create<IREE::Flow::DispatchTensorLoadOp>( - adaptor.scatter_indices().getType().cast<RankedTensorType>(), - blockArgs[1]); - Value updates = b.create<IREE::Flow::DispatchTensorLoadOp>( - adaptor.updates().getType().cast<RankedTensorType>(), blockArgs[2]); + Value original = adaptor.operand(); + Value indices = adaptor.scatter_indices(); + Value updates = adaptor.updates(); if (failed(collapseBatchDimsIfNeeded(indices, updates, b))) { return failure(); } - auto scatterOp = b.create<linalg_ext::ScatterOp>( - op->getResultTypes(), ValueRange{updates, indices}, + auto scatterOp = rewriter.create<linalg_ext::ScatterOp>( + op.getLoc(), op->getResultTypes(), ValueRange{updates, indices}, ValueRange{original}); rewriter.inlineRegionBefore(op.update_computation(), scatterOp.region(), @@ -311,8 +237,7 @@ signatureConverter.addInputs(0, argType); rewriter.applySignatureConversion(®ion, signatureConverter); - b.create<IREE::Flow::DispatchTensorStoreOp>(scatterOp.getResult(0), - blockArgs[0]); + rewriter.replaceOp(op, scatterOp->getResults()); return success(); } }; @@ -321,9 +246,8 @@ // Pass //===----------------------------------------------------------------------===// -struct ConvertAndDistributeMHLOToLinalgExtPass - : public ConvertAndDistributeMHLOToLinalgExtBase< - ConvertAndDistributeMHLOToLinalgExtPass> { +struct ConvertMHLOToLinalgExtPass + : public ConvertMHLOToLinalgExtBase<ConvertMHLOToLinalgExtPass> { void getDependentDialects(DialectRegistry ®istry) const override { registry.insert<linalg_ext::LinalgExtDialect, linalg::LinalgDialect, IREE::Flow::FlowDialect, StandardOpsDialect, @@ -354,9 +278,8 @@ }; } // namespace -std::unique_ptr<OperationPass<FuncOp>> -createConvertAndDistributeMHLOToLinalgExtPass() { - return std::make_unique<ConvertAndDistributeMHLOToLinalgExtPass>(); +std::unique_ptr<OperationPass<FuncOp>> createConvertMHLOToLinalgExtPass() { + return std::make_unique<ConvertMHLOToLinalgExtPass>(); } } // namespace iree_compiler
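With the dispatch-region machinery gone, the patterns above are plain dialect-conversion patterns, and the pass body (elided between the hunks) reduces to a standard partial conversion. A representative sketch of that wiring; the legality set here is illustrative, inferred from getDependentDialects, not copied from the file:

void runOnOperation() override {
  MLIRContext *context = &getContext();
  RewritePatternSet patterns(context);
  patterns.insert<SortOpConversion, ScatterOpConversion>(context);
  ConversionTarget target(*context);
  target.addLegalDialect<linalg_ext::LinalgExtDialect, linalg::LinalgDialect,
                         StandardOpsDialect, tensor::TensorDialect>();
  target.addIllegalOp<mhlo::SortOp, mhlo::ScatterOp>();
  // Leave all other MHLO ops alone; later passes lower them to Linalg.
  if (failed(applyPartialConversion(getOperation(), target,
                                    std::move(patterns)))) {
    signalPassFailure();
  }
}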
diff --git a/iree/compiler/InputConversion/MHLO/Passes.cpp b/iree/compiler/InputConversion/MHLO/Passes.cpp index ebf3147..2b589cb 100644 --- a/iree/compiler/InputConversion/MHLO/Passes.cpp +++ b/iree/compiler/InputConversion/MHLO/Passes.cpp
@@ -66,7 +66,7 @@ // Convert to Linalg. After this point, MHLO will be eliminated. passManager.addNestedPass<FuncOp>( - mlir::iree_compiler::createConvertAndDistributeMHLOToLinalgExtPass()); + mlir::iree_compiler::createConvertMHLOToLinalgExtPass()); passManager.addNestedPass<FuncOp>( mlir::iree_compiler::createMHLOToLinalgOnTensorsPass());
diff --git a/iree/compiler/InputConversion/MHLO/Passes.h b/iree/compiler/InputConversion/MHLO/Passes.h index 5761952..8bc4222 100644 --- a/iree/compiler/InputConversion/MHLO/Passes.h +++ b/iree/compiler/InputConversion/MHLO/Passes.h
@@ -34,9 +34,8 @@ /// Creates XLA-HLO to Linalg on tensors transformation pass. std::unique_ptr<OperationPass<FuncOp>> createMHLOToLinalgOnTensorsPass(); -/// Creates XLA-HLO to LinalgExt and Flow transformation pass. -std::unique_ptr<OperationPass<FuncOp>> -createConvertAndDistributeMHLOToLinalgExtPass(); +/// Creates XLA-HLO to LinalgExt pass. +std::unique_ptr<OperationPass<FuncOp>> createConvertMHLOToLinalgExtPass(); /// Creates XLA-HLO preprocessing transformation pass. In this pass we should /// have all mhlo -> mhlo transformations that are shared between all
diff --git a/iree/compiler/InputConversion/MHLO/Passes.td b/iree/compiler/InputConversion/MHLO/Passes.td index a08fa08..b0082b2 100644 --- a/iree/compiler/InputConversion/MHLO/Passes.td +++ b/iree/compiler/InputConversion/MHLO/Passes.td
@@ -15,12 +15,12 @@ let constructor = "mlir::iree_compiler::createMHLOToLinalgOnTensorsPass()"; } -def ConvertAndDistributeMHLOToLinalgExt +def ConvertMHLOToLinalgExt : Pass<"iree-mhlo-to-linalg-ext", "FuncOp"> { let summary = "Convert from XLA-HLO ops to LinalgExt ops and distribute to Flow ops"; let constructor = - "mlir::iree_compiler::createConvertAndDistributeMHLOToLinalgExtPass()"; + "mlir::iree_compiler::createConvertMHLOToLinalgExtPass()"; } def LegalizeInputTypes :
diff --git a/iree/compiler/InputConversion/MHLO/test/BUILD b/iree/compiler/InputConversion/MHLO/test/BUILD index c3f4485..cef5ec5 100644 --- a/iree/compiler/InputConversion/MHLO/test/BUILD +++ b/iree/compiler/InputConversion/MHLO/test/BUILD
@@ -20,7 +20,7 @@ srcs = enforce_glob( [ "broadcasting.mlir", - "convert_and_distribute_mhlo_to_linalg_ext.mlir", + "convert_mhlo_to_linalg_ext.mlir", "convert_complex_to_real.mlir", "dynamic_shape.mlir", "fft.mlir",
diff --git a/iree/compiler/InputConversion/MHLO/test/CMakeLists.txt b/iree/compiler/InputConversion/MHLO/test/CMakeLists.txt index 9f3a453..d7405b8 100644 --- a/iree/compiler/InputConversion/MHLO/test/CMakeLists.txt +++ b/iree/compiler/InputConversion/MHLO/test/CMakeLists.txt
@@ -15,8 +15,8 @@ lit SRCS "broadcasting.mlir" - "convert_and_distribute_mhlo_to_linalg_ext.mlir" "convert_complex_to_real.mlir" + "convert_mhlo_to_linalg_ext.mlir" "dynamic_shape.mlir" "fft.mlir" "legalize_input_types.mlir"
diff --git a/iree/compiler/InputConversion/MHLO/test/convert_and_distribute_mhlo_to_linalg_ext.mlir b/iree/compiler/InputConversion/MHLO/test/convert_and_distribute_mhlo_to_linalg_ext.mlir deleted file mode 100644 index 8b80c34..0000000 --- a/iree/compiler/InputConversion/MHLO/test/convert_and_distribute_mhlo_to_linalg_ext.mlir +++ /dev/null
@@ -1,342 +0,0 @@ -// RUN: iree-opt -split-input-file -iree-mhlo-to-linalg-ext %s | IreeFileCheck %s - -func @sort_1d(%arg0: tensor<128xi32>) -> (tensor<128xi32>) { - %0 = "mhlo.sort"(%arg0) ( { - ^bb0(%arg2: tensor<i32>, %arg3: tensor<i32>): // no predecessors - %1 = "mhlo.compare"(%arg2, %arg3) {comparison_direction = "GT"} : (tensor<i32>, tensor<i32>) -> tensor<i1> - "mhlo.return"(%1) : (tensor<i1>) -> () - }) {dimension = 0 : i64, is_stable = false} : (tensor<128xi32>) -> (tensor<128xi32>) - return %0 : tensor<128xi32> -} -// CHECK-LABEL: func @sort_1d -// CHECK: %[[ARG0:[a-zA-Z0-9]+]] -// CHECK-DAG: %[[C1:.+]] = constant 1 : index -// CHECK: %[[RES:.+]] = flow.dispatch.workgroups -// CHECK-SAME: [%[[C1]], %[[C1]], %[[C1]]](%[[ARG0]]) : (tensor<128xi32>) -> %[[ARG0]] -// CHECK: %[[ARG1:.+]]: !flow.dispatch.tensor<readwrite:128xi32> -// CHECK: %[[IN:.+]] = flow.dispatch.tensor.load %[[ARG1]] -// CHECK: %[[SORT:.+]] = linalg_ext.sort -// CHECK-SAME: dimension(0) -// CHECK-SAME: outs(%[[IN]] : tensor<128xi32>) -// CHECK: ^bb0(%[[ARG2:.+]]: i32, %[[ARG3:.+]]: i32) -// CHECK: %[[CMP:.+]] = cmpi sgt, %[[ARG2]], %[[ARG3]] -// CHECK: linalg_ext.yield %[[CMP]] -// CHECK: flow.dispatch.tensor.store %[[SORT]], %[[ARG1]] -// CHECK: return %[[RES]] - -// ----- - -func @sort_2d(%arg0: tensor<16x32xi32>) -> (tensor<16x32xi32>) { - %0 = "mhlo.sort"(%arg0) ( { - ^bb0(%arg2: tensor<i32>, %arg3: tensor<i32>): // no predecessors - %1 = "mhlo.compare"(%arg2, %arg3) {comparison_direction = "GT"} : (tensor<i32>, tensor<i32>) -> tensor<i1> - "mhlo.return"(%1) : (tensor<i1>) -> () - }) {dimension = 0 : i64, is_stable = false} : (tensor<16x32xi32>) -> (tensor<16x32xi32>) - return %0 : tensor<16x32xi32> -} -// CHECK-LABEL: func @sort_2d -// CHECK: %[[ARG0:[a-zA-Z0-9]+]] -// CHECK-DAG: %[[C1:.+]] = constant 1 : index -// CHECK: %[[RES:.+]] = flow.dispatch.workgroups -// CHECK-SAME: [%[[C1]], %[[C1]], %[[C1]]](%[[ARG0]]) : (tensor<16x32xi32>) -> %[[ARG0]] -// CHECK: %[[ARG1:.+]]: !flow.dispatch.tensor<readwrite:16x32xi32> -// CHECK: %[[IN:.+]] = flow.dispatch.tensor.load %[[ARG1]] -// CHECK: %[[SORT:.+]] = linalg_ext.sort -// CHECK-SAME: dimension(0) -// CHECK-SAME: outs(%[[IN]] : tensor<16x32xi32>) -// CHECK: ^bb0(%[[ARG2:.+]]: i32, %[[ARG3:.+]]: i32) -// CHECK: %[[CMP:.+]] = cmpi sgt, %[[ARG2]], %[[ARG3]] -// CHECK: linalg_ext.yield %[[CMP]] -// CHECK: flow.dispatch.tensor.store %[[SORT]], %[[ARG1]] -// CHECK: return %[[RES]] - -// ----- - -func @topk(%arg0: tensor<128xi32>, %arg1: tensor<128xi32>) -> (tensor<128xi32>) { - %0:2 = "mhlo.sort"(%arg0, %arg1) ( { - ^bb0(%arg2: tensor<i32>, %arg3: tensor<i32>, %arg4: tensor<i32>, %arg5: tensor<i32>): // no predecessors - %1 = "mhlo.compare"(%arg2, %arg3) {comparison_direction = "GT"} : (tensor<i32>, tensor<i32>) -> tensor<i1> - "mhlo.return"(%1) : (tensor<i1>) -> () - }) {dimension = 0 : i64, is_stable = false} : (tensor<128xi32>, tensor<128xi32>) -> (tensor<128xi32>, tensor<128xi32>) - return %0#0 : tensor<128xi32> -} -// CHECK-LABEL: func @topk -// CHECK: %[[ARG0:[a-zA-Z0-9]+]] -// CHECK: %[[ARG1:[a-zA-Z0-9]+]] -// CHECK-DAG: %[[C1:.+]] = constant 1 : index -// CHECK: %[[RES:.+]]:2 = flow.dispatch.workgroups -// CHECK-SAME: [%[[C1]], %[[C1]], %[[C1]]](%[[ARG0]], %[[ARG1]]) -// CHECK-SAME: : (tensor<128xi32>, tensor<128xi32>) -> (%[[ARG0]], %[[ARG1]]) -// CHECK: %[[ARG2:[a-zA-Z0-9]+]]: !flow.dispatch.tensor<readwrite:128xi32> -// CHECK: %[[ARG3:[a-zA-Z0-9]+]]: !flow.dispatch.tensor<readwrite:128xi32> -// CHECK: %[[IN1:.+]] = flow.dispatch.tensor.load %[[ARG2]] -// 
CHECK: %[[IN2:.+]] = flow.dispatch.tensor.load %[[ARG3]] -// CHECK: %[[SORT:.+]]:2 = linalg_ext.sort -// CHECK-SAME: dimension(0) -// CHECK-SAME: outs(%[[IN1]], %[[IN2]] : tensor<128xi32>, tensor<128xi32>) -// CHECK: ^bb0(%[[ARG4:.+]]: i32, %[[ARG5:.+]]: i32, %{{.*}}: i32, %{{.*}}: i32) -// CHECK: %[[CMP:.+]] = cmpi sgt, %[[ARG4]], %[[ARG5]] -// CHECK: linalg_ext.yield %[[CMP]] -// CHECK: flow.dispatch.tensor.store %[[SORT]]#0, %[[ARG2]] -// CHECK: flow.dispatch.tensor.store %[[SORT]]#1, %[[ARG3]] -// CHECK: return %[[RES]]#0 - -// ----- - -func @scatter_update_scalar_1D(%arg0: tensor<8xi32>, %arg1: tensor<4x1xi32>, - %arg2: tensor<4xi32>) -> tensor<8xi32> { - %0 = "mhlo.scatter"(%arg0, %arg1, %arg2) ( { - ^bb0(%arg3: tensor<i32>, %arg4: tensor<i32>): // no predecessors - "mhlo.return"(%arg4) : (tensor<i32>) -> () - }) { - indices_are_sorted = false, - scatter_dimension_numbers = { - index_vector_dim = 1 : i64, - inserted_window_dims = dense<0> : tensor<1xi64>, - scatter_dims_to_operand_dims = dense<0> : tensor<1xi64>, - update_window_dims = dense<> : tensor<0xi64> - }, - unique_indices = false - } : (tensor<8xi32>, tensor<4x1xi32>, tensor<4xi32>) -> tensor<8xi32> - return %0 : tensor<8xi32> -} -// CHECK-LABEL: func @scatter_update_scalar_1D -// CHECK: %[[ARG0:[a-zA-Z0-9]+]] -// CHECK: %[[ARG1:[a-zA-Z0-9]+]] -// CHECK: %[[ARG2:[a-zA-Z0-9]+]] -// CHECK-DAG: %[[C1:.+]] = constant 1 : index -// CHECK: %[[RES:.+]] = flow.dispatch.workgroups -// CHECK-SAME: [%[[C1]], %[[C1]], %[[C1]]](%[[ARG0]], %[[ARG1]], %[[ARG2]]) -// CHECK-SAME: : (tensor<8xi32>, tensor<4x1xi32>, tensor<4xi32>) -> %[[ARG0]] -// CHECK: %[[ARG3:[a-zA-Z0-9]+]]: !flow.dispatch.tensor<readwrite:8xi32> -// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: !flow.dispatch.tensor<readonly:4x1xi32> -// CHECK-SAME: %[[ARG5:[a-zA-Z0-9]+]]: !flow.dispatch.tensor<readonly:4xi32> -// CHECK: %[[ORIGINAL:.+]] = flow.dispatch.tensor.load %[[ARG3]] -// CHECK: %[[INDICES:.+]] = flow.dispatch.tensor.load %[[ARG4]] -// CHECK: %[[UPDATES:.+]] = flow.dispatch.tensor.load %[[ARG5]] -// CHECK: %[[SCATTER:.+]] = linalg_ext.scatter -// CHECK-SAME: ins(%[[UPDATES]], %[[INDICES]] : tensor<4xi32>, tensor<4x1xi32>) -// CHECK-SAME: outs(%[[ORIGINAL]] : tensor<8xi32>) -// CHECK: ^bb0(%[[V1:.+]]: i32, %[[V2:.+]]: i32): // no predecessors -// CEECK: linalg.yield %[[V1]] -// CHECK: flow.dispatch.tensor.store %[[SCATTER]], %[[ARG3]] -// CHECK: return %[[RES]] - -// ----- - -func @scatter_update_scalar_2D(%arg0: tensor<4x3xi32>, %arg1: tensor<3x2xi32>, - %arg2: tensor<3xi32>) -> tensor<4x3xi32> { - %0 = "mhlo.scatter"(%arg0, %arg1, %arg2) ( { - ^bb0(%arg3: tensor<i32>, %arg4: tensor<i32>): // no predecessors - "mhlo.return"(%arg4) : (tensor<i32>) -> () - }) {indices_are_sorted = false, - scatter_dimension_numbers = { - index_vector_dim = 1 : i64, - inserted_window_dims = dense<[0, 1]> : tensor<2xi64>, - scatter_dims_to_operand_dims = dense<[0, 1]> : tensor<2xi64>, - update_window_dims = dense<> : tensor<0xi64> - }, - unique_indices = false - } : (tensor<4x3xi32>, tensor<3x2xi32>, tensor<3xi32>) -> tensor<4x3xi32> - return %0 : tensor<4x3xi32> -} -// CHECK-LABEL: func @scatter_update_scalar_2D -// CHECK: %[[ARG0:[a-zA-Z0-9]+]] -// CHECK: %[[ARG1:[a-zA-Z0-9]+]] -// CHECK: %[[ARG2:[a-zA-Z0-9]+]] -// CHECK-DAG: %[[C1:.+]] = constant 1 : index -// CHECK: %[[RES:.+]] = flow.dispatch.workgroups -// CHECK-SAME: [%[[C1]], %[[C1]], %[[C1]]](%[[ARG0]], %[[ARG1]], %[[ARG2]]) -// CHECK-SAME: : (tensor<4x3xi32>, tensor<3x2xi32>, tensor<3xi32>) -> %[[ARG0]] -// CHECK: 
%[[ARG3:[a-zA-Z0-9]+]]: !flow.dispatch.tensor<readwrite:4x3xi32> -// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: !flow.dispatch.tensor<readonly:3x2xi32> -// CHECK-SAME: %[[ARG5:[a-zA-Z0-9]+]]: !flow.dispatch.tensor<readonly:3xi32> -// CHECK: %[[ORIGINAL:.+]] = flow.dispatch.tensor.load %[[ARG3]] -// CHECK: %[[INDICES:.+]] = flow.dispatch.tensor.load %[[ARG4]] -// CHECK: %[[UPDATES:.+]] = flow.dispatch.tensor.load %[[ARG5]] -// CHECK: %[[SCATTER:.+]] = linalg_ext.scatter -// CHECK-SAME: ins(%[[UPDATES]], %[[INDICES]] : tensor<3xi32>, tensor<3x2xi32>) -// CHECK-SAME: outs(%[[ORIGINAL]] : tensor<4x3xi32>) -// CHECK: ^bb0(%[[V1:.+]]: i32, %[[V2:.+]]: i32): // no predecessors -// CEECK: linalg.yield %[[V1]] -// CHECK: flow.dispatch.tensor.store %[[SCATTER]], %[[ARG3]] -// CHECK: return %[[RES]] - -// ----- - -func @scatter_update_slice_2D(%arg0: tensor<6x3xi32>, %arg1: tensor<2x1xi32>, - %arg2: tensor<2x3xi32>) -> tensor<6x3xi32> { - %0 = "mhlo.scatter"(%arg0, %arg1, %arg2) ( { - ^bb0(%arg3: tensor<i32>, %arg4: tensor<i32>): // no predecessors - "mhlo.return"(%arg4) : (tensor<i32>) -> () - }) { - indices_are_sorted = false, - scatter_dimension_numbers = { - index_vector_dim = 1 : i64, - inserted_window_dims = dense<0> : tensor<1xi64>, - scatter_dims_to_operand_dims = dense<0> : tensor<1xi64>, - update_window_dims = dense<1> : tensor<1xi64> - }, - unique_indices = false - } : (tensor<6x3xi32>, tensor<2x1xi32>, tensor<2x3xi32>) -> tensor<6x3xi32> - return %0 : tensor<6x3xi32> -} -// CHECK-LABEL: func @scatter_update_slice_2D -// CHECK: %[[ARG0:[a-zA-Z0-9]+]] -// CHECK: %[[ARG1:[a-zA-Z0-9]+]] -// CHECK: %[[ARG2:[a-zA-Z0-9]+]] -// CHECK-DAG: %[[C1:.+]] = constant 1 : index -// CHECK: %[[RES:.+]] = flow.dispatch.workgroups -// CHECK-SAME: [%[[C1]], %[[C1]], %[[C1]]](%[[ARG0]], %[[ARG1]], %[[ARG2]]) -// CHECK-SAME: : (tensor<6x3xi32>, tensor<2x1xi32>, tensor<2x3xi32>) -> %[[ARG0]] -// CHECK: %[[ARG3:[a-zA-Z0-9]+]]: !flow.dispatch.tensor<readwrite:6x3xi32> -// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: !flow.dispatch.tensor<readonly:2x1xi32> -// CHECK-SAME: %[[ARG5:[a-zA-Z0-9]+]]: !flow.dispatch.tensor<readonly:2x3xi32> -// CHECK: %[[ORIGINAL:.+]] = flow.dispatch.tensor.load %[[ARG3]] -// CHECK: %[[INDICES:.+]] = flow.dispatch.tensor.load %[[ARG4]] -// CHECK: %[[UPDATES:.+]] = flow.dispatch.tensor.load %[[ARG5]] -// CHECK: %[[SCATTER:.+]] = linalg_ext.scatter -// CHECK-SAME: ins(%[[UPDATES]], %[[INDICES]] : tensor<2x3xi32>, tensor<2x1xi32>) -// CHECK-SAME: outs(%[[ORIGINAL]] : tensor<6x3xi32>) -// CHECK: ^bb0(%[[V1:.+]]: i32, %[[V2:.+]]: i32): // no predecessors -// CEECK: linalg.yield %[[V1]] -// CHECK: flow.dispatch.tensor.store %[[SCATTER]], %[[ARG3]] -// CHECK: return %[[RES]] - -// ----- - -func @scatter_add_slice_2D(%arg0: tensor<6x3xi32>, %arg1: tensor<2x1xi32>, - %arg2: tensor<2x3xi32>) -> tensor<6x3xi32> { - %0 = "mhlo.scatter"(%arg0, %arg1, %arg2) ( { - ^bb0(%arg3: tensor<i32>, %arg4: tensor<i32>): // no predecessors - %1 = mhlo.add %arg3, %arg4 : tensor<i32> - "mhlo.return"(%1) : (tensor<i32>) -> () - }) { - indices_are_sorted = false, - scatter_dimension_numbers = { - index_vector_dim = 1 : i64, - inserted_window_dims = dense<0> : tensor<1xi64>, - scatter_dims_to_operand_dims = dense<0> : tensor<1xi64>, - update_window_dims = dense<1> : tensor<1xi64> - }, - unique_indices = false - } : (tensor<6x3xi32>, tensor<2x1xi32>, tensor<2x3xi32>) -> tensor<6x3xi32> - return %0 : tensor<6x3xi32> -} -// CHECK-LABEL: func @scatter_add_slice_2D -// CHECK: %[[ARG0:[a-zA-Z0-9]+]] -// CHECK: %[[ARG1:[a-zA-Z0-9]+]] -// 
CHECK: %[[ARG2:[a-zA-Z0-9]+]] -// CHECK-DAG: %[[C1:.+]] = constant 1 : index -// CHECK: %[[RES:.+]] = flow.dispatch.workgroups -// CHECK-SAME: [%[[C1]], %[[C1]], %[[C1]]](%[[ARG0]], %[[ARG1]], %[[ARG2]]) -// CHECK-SAME: : (tensor<6x3xi32>, tensor<2x1xi32>, tensor<2x3xi32>) -> %[[ARG0]] -// CHECK: %[[ARG3:[a-zA-Z0-9]+]]: !flow.dispatch.tensor<readwrite:6x3xi32> -// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: !flow.dispatch.tensor<readonly:2x1xi32> -// CHECK-SAME: %[[ARG5:[a-zA-Z0-9]+]]: !flow.dispatch.tensor<readonly:2x3xi32> -// CHECK: %[[ORIGINAL:.+]] = flow.dispatch.tensor.load %[[ARG3]] -// CHECK: %[[INDICES:.+]] = flow.dispatch.tensor.load %[[ARG4]] -// CHECK: %[[UPDATES:.+]] = flow.dispatch.tensor.load %[[ARG5]] -// CHECK: %[[SCATTER:.+]] = linalg_ext.scatter -// CHECK-SAME: ins(%[[UPDATES]], %[[INDICES]] : tensor<2x3xi32>, tensor<2x1xi32>) -// CHECK-SAME: outs(%[[ORIGINAL]] : tensor<6x3xi32>) -// CHECK: ^bb0(%[[V1:.+]]: i32, %[[V2:.+]]: i32): // no predecessors -// -// The order is reverse. -// CHECK: %[[V3:.+]] = addi %[[V2]], %[[V1]] -// CEECK: linalg.yield %[[V3]] -// CHECK: flow.dispatch.tensor.store %[[SCATTER]], %[[ARG3]] -// CHECK: return %[[RES]] - -// ----- - -func @scatter_update_batch_scalar_1D(%arg0: tensor<8xi32>, - %arg1: tensor<3x4x1xi32>, %arg2: tensor<3x4xi32>) -> tensor<8xi32> { - %0 = "mhlo.scatter"(%arg0, %arg1, %arg2) ( { - ^bb0(%arg3: tensor<i32>, %arg4: tensor<i32>): // no predecessors - "mhlo.return"(%arg4) : (tensor<i32>) -> () - }) { - indices_are_sorted = false, - scatter_dimension_numbers = { - index_vector_dim = 2 : i64, - inserted_window_dims = dense<0> : tensor<i64>, - scatter_dims_to_operand_dims = dense<0> : tensor<1xi64>, - update_window_dims = dense<> : tensor<0xi64> - }, - unique_indices = false - } : (tensor<8xi32>, tensor<3x4x1xi32>, tensor<3x4xi32>) -> tensor<8xi32> - return %0 : tensor<8xi32> -} -// CHECK-LABEL: func @scatter_update_batch_scalar_1D -// CHECK: %[[ARG0:[a-zA-Z0-9]+]] -// CHECK: %[[ARG1:[a-zA-Z0-9]+]] -// CHECK: %[[ARG2:[a-zA-Z0-9]+]] -// CHECK-DAG: %[[C1:.+]] = constant 1 : index -// CHECK: %[[RES:.+]] = flow.dispatch.workgroups -// CHECK-SAME: [%[[C1]], %[[C1]], %[[C1]]](%[[ARG0]], %[[ARG1]], %[[ARG2]]) -// CHECK-SAME: : (tensor<8xi32>, tensor<3x4x1xi32>, tensor<3x4xi32>) -> %[[ARG0]] -// CHECK: %[[ARG3:[a-zA-Z0-9]+]]: !flow.dispatch.tensor<readwrite:8xi32> -// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: !flow.dispatch.tensor<readonly:3x4x1xi32> -// CHECK-SAME: %[[ARG5:[a-zA-Z0-9]+]]: !flow.dispatch.tensor<readonly:3x4xi32> -// CHECK: %[[ORIGINAL:.+]] = flow.dispatch.tensor.load %[[ARG3]] -// CHECK: %[[INDICES:.+]] = flow.dispatch.tensor.load %[[ARG4]] -// CHECK: %[[UPDATES:.+]] = flow.dispatch.tensor.load %[[ARG5]] -// CHECK: %[[COLLAPSED_INDICES:.+]] = linalg.tensor_collapse_shape -// CHECK-SAME: %[[INDICES]] {{\[}}[0, 1], [2]] : tensor<3x4x1xi32> into tensor<12x1xi32> -// CHECK: %[[COLLAPSED_UPDATES:.+]] = linalg.tensor_collapse_shape -// CHECK-SAME: %[[UPDATES]] {{\[}}[0, 1]] : tensor<3x4xi32> into tensor<12xi32> -// CHECK: %[[SCATTER:.+]] = linalg_ext.scatter -// CHECK-SAME: ins(%[[COLLAPSED_UPDATES]], %[[COLLAPSED_INDICES]] : tensor<12xi32>, tensor<12x1xi32>) -// CHECK-SAME: outs(%[[ORIGINAL]] : tensor<8xi32>) -// CHECK: ^bb0(%[[V1:.+]]: i32, %[[V2:.+]]: i32): // no predecessors -// CEECK: linalg.yield %[[V1]] -// CHECK: flow.dispatch.tensor.store %[[SCATTER]], %[[ARG3]] -// CHECK: return %[[RES]] - -// ----- - -func @scatter_update_batch_slice_3D_dynamic(%arg0: tensor<1x24x512xi32>, - %arg1: tensor<?x3x2xi32>, %arg2: tensor<?x3x512xi32>) 
-> tensor<1x24x512xi32> { - %0 = "mhlo.scatter"(%arg0, %arg1, %arg2) ( { - ^bb0(%arg3: tensor<i32>, %arg4: tensor<i32>): // no predecessors - "mhlo.return"(%arg4) : (tensor<i32>) -> () - }) {indices_are_sorted = false, - scatter_dimension_numbers = { - index_vector_dim = 2 : i64, - inserted_window_dims = dense<[0, 1]> : tensor<2xi64>, - scatter_dims_to_operand_dims = dense<[0, 1]> : tensor<2xi64>, - update_window_dims = dense<2> : tensor<1xi64> - }, - unique_indices = false - } : (tensor<1x24x512xi32>, tensor<?x3x2xi32>, tensor<?x3x512xi32>) -> tensor<1x24x512xi32> - return %0 : tensor<1x24x512xi32> -} -// CHECK-LABEL: func @scatter_update_batch_slice_3D_dynamic -// CHECK: %[[ARG0:[a-zA-Z0-9]+]] -// CHECK: %[[ARG1:[a-zA-Z0-9]+]] -// CHECK: %[[ARG2:[a-zA-Z0-9]+]] -// CHECK-DAG: %[[C1:.+]] = constant 1 : index -// CHECK-DAG: %[[C0:.+]] = constant 0 : index -// CHECK-DAG: %[[DIM1:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x3x2xi32> -// CHECK-DAG: %[[C0:.+]] = constant 0 : index -// CHECK-DAG: %[[DIM2:.+]] = tensor.dim %[[ARG2]], %[[C0]] : tensor<?x3x512xi32> -// CHECK: %[[RES:.+]] = flow.dispatch.workgroups -// CHECK-SAME: [%[[C1]], %[[C1]], %[[C1]]](%[[ARG0]], %[[ARG1]], %[[ARG2]]) -// CHECK-SAME: : (tensor<1x24x512xi32>, tensor<?x3x2xi32>{%[[DIM1]]}, tensor<?x3x512xi32>{%[[DIM2]]}) -> %[[ARG0]] -// CHECK: %[[ARG3:[a-zA-Z0-9]+]]: !flow.dispatch.tensor<readwrite:1x24x512xi32> -// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: !flow.dispatch.tensor<readonly:?x3x2xi32> -// CHECK-SAME: %[[ARG5:[a-zA-Z0-9]+]]: !flow.dispatch.tensor<readonly:?x3x512xi32> -// CHECK: %[[ORIGINAL:.+]] = flow.dispatch.tensor.load %[[ARG3]] -// CHECK: %[[INDICES:.+]] = flow.dispatch.tensor.load %[[ARG4]] -// CHECK: %[[UPDATES:.+]] = flow.dispatch.tensor.load %[[ARG5]] -// CHECK: %[[COLLAPSED_INDICES:.+]] = linalg.tensor_collapse_shape -// CHECK-SAME: %[[INDICES]] {{\[}}[0, 1], [2]] : tensor<?x3x2xi32> into tensor<?x2xi32> -// CHECK: %[[COLLAPSED_UPDATES:.+]] = linalg.tensor_collapse_shape -// CHECK-SAME: %[[UPDATES]] {{\[}}[0, 1], [2]] : tensor<?x3x512xi32> into tensor<?x512xi32> -// CHECK: %[[SCATTER:.+]] = linalg_ext.scatter -// CHECK-SAME: ins(%[[COLLAPSED_UPDATES]], %[[COLLAPSED_INDICES]] : tensor<?x512xi32>, tensor<?x2xi32>) -// CHECK-SAME: outs(%[[ORIGINAL]] : tensor<1x24x512xi32>) -// CHECK: ^bb0(%[[V1:.+]]: i32, %[[V2:.+]]: i32): // no predecessors -// CEECK: linalg.yield %[[V1]] -// CHECK: flow.dispatch.tensor.store %[[SCATTER]], %[[ARG3]] -// CHECK: return %[[RES]]
diff --git a/iree/compiler/InputConversion/MHLO/test/convert_mhlo_to_linalg_ext.mlir b/iree/compiler/InputConversion/MHLO/test/convert_mhlo_to_linalg_ext.mlir new file mode 100644 index 0000000..e25865b --- /dev/null +++ b/iree/compiler/InputConversion/MHLO/test/convert_mhlo_to_linalg_ext.mlir
@@ -0,0 +1,250 @@ +// RUN: iree-opt -split-input-file -iree-mhlo-to-linalg-ext %s | IreeFileCheck %s + +func @sort_1d(%arg0: tensor<128xi32>) -> (tensor<128xi32>) { + %0 = "mhlo.sort"(%arg0) ( { + ^bb0(%arg2: tensor<i32>, %arg3: tensor<i32>): // no predecessors + %1 = "mhlo.compare"(%arg2, %arg3) {comparison_direction = "GT"} : (tensor<i32>, tensor<i32>) -> tensor<i1> + "mhlo.return"(%1) : (tensor<i1>) -> () + }) {dimension = 0 : i64, is_stable = false} : (tensor<128xi32>) -> (tensor<128xi32>) + return %0 : tensor<128xi32> +} +// CHECK-LABEL: func @sort_1d +// CHECK: %[[ARG0:[a-zA-Z0-9]+]] +// CHECK: %[[SORT:.+]] = linalg_ext.sort +// CHECK-SAME: dimension(0) +// CHECK-SAME: outs(%[[ARG0]] : tensor<128xi32>) +// CHECK: ^bb0(%[[ARG1:.+]]: i32, %[[ARG2:.+]]: i32) +// CHECK: %[[CMP:.+]] = cmpi sgt, %[[ARG1]], %[[ARG2]] +// CHECK: linalg_ext.yield %[[CMP]] +// CHECK: return %[[SORT]] + +// ----- + +func @sort_2d(%arg0: tensor<16x32xi32>) -> (tensor<16x32xi32>) { + %0 = "mhlo.sort"(%arg0) ( { + ^bb0(%arg2: tensor<i32>, %arg3: tensor<i32>): // no predecessors + %1 = "mhlo.compare"(%arg2, %arg3) {comparison_direction = "GT"} : (tensor<i32>, tensor<i32>) -> tensor<i1> + "mhlo.return"(%1) : (tensor<i1>) -> () + }) {dimension = 0 : i64, is_stable = false} : (tensor<16x32xi32>) -> (tensor<16x32xi32>) + return %0 : tensor<16x32xi32> +} +// CHECK-LABEL: func @sort_2d +// CHECK: %[[ARG0:[a-zA-Z0-9]+]] +// CHECK: %[[SORT:.+]] = linalg_ext.sort +// CHECK-SAME: dimension(0) +// CHECK-SAME: outs(%[[ARG0]] : tensor<16x32xi32>) +// CHECK: ^bb0(%[[ARG1:.+]]: i32, %[[ARG2:.+]]: i32) +// CHECK: %[[CMP:.+]] = cmpi sgt, %[[ARG1]], %[[ARG2]] +// CHECK: linalg_ext.yield %[[CMP]] +// CHECK: return %[[SORT]] + +// ----- + +func @topk(%arg0: tensor<128xi32>, %arg1: tensor<128xi32>) -> (tensor<128xi32>) { + %0:2 = "mhlo.sort"(%arg0, %arg1) ( { + ^bb0(%arg2: tensor<i32>, %arg3: tensor<i32>, %arg4: tensor<i32>, %arg5: tensor<i32>): // no predecessors + %1 = "mhlo.compare"(%arg2, %arg3) {comparison_direction = "GT"} : (tensor<i32>, tensor<i32>) -> tensor<i1> + "mhlo.return"(%1) : (tensor<i1>) -> () + }) {dimension = 0 : i64, is_stable = false} : (tensor<128xi32>, tensor<128xi32>) -> (tensor<128xi32>, tensor<128xi32>) + return %0#0 : tensor<128xi32> +} +// CHECK-LABEL: func @topk +// CHECK: %[[ARG0:[a-zA-Z0-9]+]] +// CHECK: %[[ARG1:[a-zA-Z0-9]+]] +// CHECK: %[[SORT:.+]]:2 = linalg_ext.sort +// CHECK-SAME: dimension(0) +// CHECK-SAME: outs(%[[ARG0]], %[[ARG1]] : tensor<128xi32>, tensor<128xi32>) +// CHECK: ^bb0(%[[ARG2:.+]]: i32, %[[ARG3:.+]]: i32, %{{.*}}: i32, %{{.*}}: i32) +// CHECK: %[[CMP:.+]] = cmpi sgt, %[[ARG2]], %[[ARG3]] +// CHECK: linalg_ext.yield %[[CMP]] +// CHECK: return %[[SORT]]#0 + +// ----- + +func @scatter_update_scalar_1D(%arg0: tensor<8xi32>, %arg1: tensor<4x1xi32>, + %arg2: tensor<4xi32>) -> tensor<8xi32> { + %0 = "mhlo.scatter"(%arg0, %arg1, %arg2) ( { + ^bb0(%arg3: tensor<i32>, %arg4: tensor<i32>): // no predecessors + "mhlo.return"(%arg4) : (tensor<i32>) -> () + }) { + indices_are_sorted = false, + scatter_dimension_numbers = { + index_vector_dim = 1 : i64, + inserted_window_dims = dense<0> : tensor<1xi64>, + scatter_dims_to_operand_dims = dense<0> : tensor<1xi64>, + update_window_dims = dense<> : tensor<0xi64> + }, + unique_indices = false + } : (tensor<8xi32>, tensor<4x1xi32>, tensor<4xi32>) -> tensor<8xi32> + return %0 : tensor<8xi32> +} +// CHECK-LABEL: func @scatter_update_scalar_1D +// CHECK: %[[ARG0:[a-zA-Z0-9]+]] +// CHECK: %[[ARG1:[a-zA-Z0-9]+]] +// CHECK: %[[ARG2:[a-zA-Z0-9]+]] +// CHECK: 
%[[SCATTER:.+]] = linalg_ext.scatter +// CHECK-SAME: ins(%[[ARG2]], %[[ARG1]] : tensor<4xi32>, tensor<4x1xi32>) +// CHECK-SAME: outs(%[[ARG0]] : tensor<8xi32>) +// CHECK: ^bb0(%[[V1:.+]]: i32, %[[V2:.+]]: i32): // no predecessors +// CHECK: linalg_ext.yield %[[V1]] +// CHECK: return %[[SCATTER]] + +// ----- + +func @scatter_update_scalar_2D(%arg0: tensor<4x3xi32>, %arg1: tensor<3x2xi32>, + %arg2: tensor<3xi32>) -> tensor<4x3xi32> { + %0 = "mhlo.scatter"(%arg0, %arg1, %arg2) ( { + ^bb0(%arg3: tensor<i32>, %arg4: tensor<i32>): // no predecessors + "mhlo.return"(%arg4) : (tensor<i32>) -> () + }) {indices_are_sorted = false, + scatter_dimension_numbers = { + index_vector_dim = 1 : i64, + inserted_window_dims = dense<[0, 1]> : tensor<2xi64>, + scatter_dims_to_operand_dims = dense<[0, 1]> : tensor<2xi64>, + update_window_dims = dense<> : tensor<0xi64> + }, + unique_indices = false + } : (tensor<4x3xi32>, tensor<3x2xi32>, tensor<3xi32>) -> tensor<4x3xi32> + return %0 : tensor<4x3xi32> +} +// CHECK-LABEL: func @scatter_update_scalar_2D +// CHECK: %[[ARG0:[a-zA-Z0-9]+]] +// CHECK: %[[ARG1:[a-zA-Z0-9]+]] +// CHECK: %[[ARG2:[a-zA-Z0-9]+]] +// CHECK: %[[SCATTER:.+]] = linalg_ext.scatter +// CHECK-SAME: ins(%[[ARG2]], %[[ARG1]] : tensor<3xi32>, tensor<3x2xi32>) +// CHECK-SAME: outs(%[[ARG0]] : tensor<4x3xi32>) +// CHECK: ^bb0(%[[V1:.+]]: i32, %[[V2:.+]]: i32): // no predecessors +// CHECK: linalg_ext.yield %[[V1]] +// CHECK: return %[[SCATTER]] + +// ----- + +func @scatter_update_slice_2D(%arg0: tensor<6x3xi32>, %arg1: tensor<2x1xi32>, + %arg2: tensor<2x3xi32>) -> tensor<6x3xi32> { + %0 = "mhlo.scatter"(%arg0, %arg1, %arg2) ( { + ^bb0(%arg3: tensor<i32>, %arg4: tensor<i32>): // no predecessors + "mhlo.return"(%arg4) : (tensor<i32>) -> () + }) { + indices_are_sorted = false, + scatter_dimension_numbers = { + index_vector_dim = 1 : i64, + inserted_window_dims = dense<0> : tensor<1xi64>, + scatter_dims_to_operand_dims = dense<0> : tensor<1xi64>, + update_window_dims = dense<1> : tensor<1xi64> + }, + unique_indices = false + } : (tensor<6x3xi32>, tensor<2x1xi32>, tensor<2x3xi32>) -> tensor<6x3xi32> + return %0 : tensor<6x3xi32> +} +// CHECK-LABEL: func @scatter_update_slice_2D +// CHECK: %[[ARG0:[a-zA-Z0-9]+]] +// CHECK: %[[ARG1:[a-zA-Z0-9]+]] +// CHECK: %[[ARG2:[a-zA-Z0-9]+]] +// CHECK: %[[SCATTER:.+]] = linalg_ext.scatter +// CHECK-SAME: ins(%[[ARG2]], %[[ARG1]] : tensor<2x3xi32>, tensor<2x1xi32>) +// CHECK-SAME: outs(%[[ARG0]] : tensor<6x3xi32>) +// CHECK: ^bb0(%[[V1:.+]]: i32, %[[V2:.+]]: i32): // no predecessors +// CHECK: linalg_ext.yield %[[V1]] +// CHECK: return %[[SCATTER]] + +// ----- + +func @scatter_add_slice_2D(%arg0: tensor<6x3xi32>, %arg1: tensor<2x1xi32>, + %arg2: tensor<2x3xi32>) -> tensor<6x3xi32> { + %0 = "mhlo.scatter"(%arg0, %arg1, %arg2) ( { + ^bb0(%arg3: tensor<i32>, %arg4: tensor<i32>): // no predecessors + %1 = mhlo.add %arg3, %arg4 : tensor<i32> + "mhlo.return"(%1) : (tensor<i32>) -> () + }) { + indices_are_sorted = false, + scatter_dimension_numbers = { + index_vector_dim = 1 : i64, + inserted_window_dims = dense<0> : tensor<1xi64>, + scatter_dims_to_operand_dims = dense<0> : tensor<1xi64>, + update_window_dims = dense<1> : tensor<1xi64> + }, + unique_indices = false + } : (tensor<6x3xi32>, tensor<2x1xi32>, tensor<2x3xi32>) -> tensor<6x3xi32> + return %0 : tensor<6x3xi32> +} +// CHECK-LABEL: func @scatter_add_slice_2D +// CHECK: %[[ARG0:[a-zA-Z0-9]+]] +// CHECK: %[[ARG1:[a-zA-Z0-9]+]] +// CHECK: %[[ARG2:[a-zA-Z0-9]+]] +// CHECK: %[[SCATTER:.+]] = linalg_ext.scatter +// CHECK-SAME: 
ins(%[[ARG2]], %[[ARG1]] : tensor<2x3xi32>, tensor<2x1xi32>) +// CHECK-SAME: outs(%[[ARG0]] : tensor<6x3xi32>) +// CHECK: ^bb0(%[[V1:.+]]: i32, %[[V2:.+]]: i32): // no predecessors +// +// The operand order is reversed. +// CHECK: %[[V3:.+]] = addi %[[V2]], %[[V1]] +// CHECK: linalg_ext.yield %[[V3]] +// CHECK: return %[[SCATTER]] + +// ----- + +func @scatter_update_batch_scalar_1D(%arg0: tensor<8xi32>, + %arg1: tensor<3x4x1xi32>, %arg2: tensor<3x4xi32>) -> tensor<8xi32> { + %0 = "mhlo.scatter"(%arg0, %arg1, %arg2) ( { + ^bb0(%arg3: tensor<i32>, %arg4: tensor<i32>): // no predecessors + "mhlo.return"(%arg4) : (tensor<i32>) -> () + }) { + indices_are_sorted = false, + scatter_dimension_numbers = { + index_vector_dim = 2 : i64, + inserted_window_dims = dense<0> : tensor<i64>, + scatter_dims_to_operand_dims = dense<0> : tensor<1xi64>, + update_window_dims = dense<> : tensor<0xi64> + }, + unique_indices = false + } : (tensor<8xi32>, tensor<3x4x1xi32>, tensor<3x4xi32>) -> tensor<8xi32> + return %0 : tensor<8xi32> +} +// CHECK-LABEL: func @scatter_update_batch_scalar_1D +// CHECK: %[[ARG0:[a-zA-Z0-9]+]] +// CHECK: %[[ARG1:[a-zA-Z0-9]+]] +// CHECK: %[[ARG2:[a-zA-Z0-9]+]] +// CHECK: %[[COLLAPSED_INDICES:.+]] = linalg.tensor_collapse_shape +// CHECK-SAME: %[[ARG1]] {{\[}}[0, 1], [2]] : tensor<3x4x1xi32> into tensor<12x1xi32> +// CHECK: %[[COLLAPSED_UPDATES:.+]] = linalg.tensor_collapse_shape +// CHECK-SAME: %[[ARG2]] {{\[}}[0, 1]] : tensor<3x4xi32> into tensor<12xi32> +// CHECK: %[[SCATTER:.+]] = linalg_ext.scatter +// CHECK-SAME: ins(%[[COLLAPSED_UPDATES]], %[[COLLAPSED_INDICES]] : tensor<12xi32>, tensor<12x1xi32>) +// CHECK-SAME: outs(%[[ARG0]] : tensor<8xi32>) +// CHECK: ^bb0(%[[V1:.+]]: i32, %[[V2:.+]]: i32): // no predecessors +// CHECK: linalg_ext.yield %[[V1]] +// CHECK: return %[[SCATTER]] + +// ----- + +func @scatter_update_batch_slice_3D_dynamic(%arg0: tensor<1x24x512xi32>, + %arg1: tensor<?x3x2xi32>, %arg2: tensor<?x3x512xi32>) -> tensor<1x24x512xi32> { + %0 = "mhlo.scatter"(%arg0, %arg1, %arg2) ( { + ^bb0(%arg3: tensor<i32>, %arg4: tensor<i32>): // no predecessors + "mhlo.return"(%arg4) : (tensor<i32>) -> () + }) {indices_are_sorted = false, + scatter_dimension_numbers = { + index_vector_dim = 2 : i64, + inserted_window_dims = dense<[0, 1]> : tensor<2xi64>, + scatter_dims_to_operand_dims = dense<[0, 1]> : tensor<2xi64>, + update_window_dims = dense<2> : tensor<1xi64> + }, + unique_indices = false + } : (tensor<1x24x512xi32>, tensor<?x3x2xi32>, tensor<?x3x512xi32>) -> tensor<1x24x512xi32> + return %0 : tensor<1x24x512xi32> +} +// CHECK-LABEL: func @scatter_update_batch_slice_3D_dynamic +// CHECK: %[[ARG0:[a-zA-Z0-9]+]] +// CHECK: %[[ARG1:[a-zA-Z0-9]+]] +// CHECK: %[[ARG2:[a-zA-Z0-9]+]] +// CHECK: %[[COLLAPSED_INDICES:.+]] = linalg.tensor_collapse_shape +// CHECK-SAME: %[[ARG1]] {{\[}}[0, 1], [2]] : tensor<?x3x2xi32> into tensor<?x2xi32> +// CHECK: %[[COLLAPSED_UPDATES:.+]] = linalg.tensor_collapse_shape +// CHECK-SAME: %[[ARG2]] {{\[}}[0, 1], [2]] : tensor<?x3x512xi32> into tensor<?x512xi32> +// CHECK: %[[SCATTER:.+]] = linalg_ext.scatter +// CHECK-SAME: ins(%[[COLLAPSED_UPDATES]], %[[COLLAPSED_INDICES]] : tensor<?x512xi32>, tensor<?x2xi32>) +// CHECK-SAME: outs(%[[ARG0]] : tensor<1x24x512xi32>) +// CHECK: ^bb0(%[[V1:.+]]: i32, %[[V2:.+]]: i32): // no predecessors +// CHECK: linalg_ext.yield %[[V1]] +// CHECK: return %[[SCATTER]]
diff --git a/iree/hal/allocator_heap.c b/iree/hal/allocator_heap.c index 7951398..21b7f3e 100644 --- a/iree/hal/allocator_heap.c +++ b/iree/hal/allocator_heap.c
@@ -28,7 +28,8 @@ IREE_TRACE_ZONE_BEGIN(z0); iree_hal_heap_allocator_t* allocator = NULL; - iree_host_size_t total_size = sizeof(*allocator) + identifier.size; + iree_host_size_t total_size = + iree_sizeof_struct(*allocator) + identifier.size; iree_status_t status = iree_allocator_malloc(host_allocator, total_size, (void**)&allocator); if (iree_status_is_ok(status)) { @@ -37,7 +38,7 @@ allocator->host_allocator = host_allocator; iree_string_view_append_to_buffer( identifier, &allocator->identifier, - (char*)allocator + total_size - identifier.size); + (char*)allocator + iree_sizeof_struct(*allocator)); *out_allocator = (iree_hal_allocator_t*)allocator; }
diff --git a/iree/hal/buffer_heap.c b/iree/hal/buffer_heap.c index 7621667..9d74cc1 100644 --- a/iree/hal/buffer_heap.c +++ b/iree/hal/buffer_heap.c
@@ -32,8 +32,10 @@ IREE_ASSERT_ARGUMENT(out_buffer); IREE_TRACE_ZONE_BEGIN(z0); + // NOTE: we want the buffer data to always be 16-byte aligned. iree_hal_heap_buffer_t* buffer = NULL; - iree_host_size_t header_size = iree_host_align(sizeof(*buffer), 16); + iree_host_size_t header_size = + iree_host_align(iree_sizeof_struct(*buffer), 16); iree_host_size_t total_size = header_size + allocation_size; iree_status_t status = iree_allocator_malloc(host_allocator, total_size, (void**)&buffer);
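Both heap changes apply the same inline-trailing-storage pattern: the header struct is padded to a known boundary (iree_sizeof_struct is presumably sizeof rounded up to the maximum alignment, and iree_host_align rounds to an explicit one) so that the variable-length tail — the identifier string in the allocator, or buffer contents that must be 16-byte aligned here — starts at a stable, aligned offset. The invariant is that the size term and the pointer term use the same padded expression, rather than deriving the pointer back from total_size. A self-contained C++ illustration, with a hypothetical align_up standing in for the IREE helpers:

#include <cstddef>
#include <cstdlib>
#include <cstring>

// Round n up to a power-of-two alignment a.
static size_t align_up(size_t n, size_t a) { return (n + a - 1) & ~(a - 1); }

struct header_t {
  size_t id_size;  // stand-in for the real resource/identifier fields
};

// Allocate a header with its identifier stored inline after the padded
// struct; the same `header_size` appears in both the size and pointer math.
header_t* alloc_with_trailing_id(const char* id, size_t id_size) {
  size_t header_size = align_up(sizeof(header_t), 16);  // 16-aligned tail
  header_t* h = (header_t*)std::malloc(header_size + id_size);
  if (!h) return nullptr;
  h->id_size = id_size;
  std::memcpy((char*)h + header_size, id, id_size);
  return h;
}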
diff --git a/iree/hal/cuda/BUILD b/iree/hal/cuda/BUILD index 78685b7..6ebe4c9 100644 --- a/iree/hal/cuda/BUILD +++ b/iree/hal/cuda/BUILD
@@ -48,6 +48,8 @@ "nop_executable_cache.h", "status_util.c", "status_util.h", + "stream_command_buffer.c", + "stream_command_buffer.h", ], hdrs = [ "api.h", @@ -59,9 +61,11 @@ "//iree/base:core_headers", "//iree/base:tracing", "//iree/base/internal", + "//iree/base/internal:arena", "//iree/base/internal:flatcc", "//iree/base/internal:synchronization", "//iree/hal", + "//iree/hal/utils:deferred_command_buffer", "//iree/schemas:cuda_executable_def_c_fbs", ], )
diff --git a/iree/hal/cuda/CMakeLists.txt b/iree/hal/cuda/CMakeLists.txt index 997a146..829c381 100644 --- a/iree/hal/cuda/CMakeLists.txt +++ b/iree/hal/cuda/CMakeLists.txt
@@ -45,15 +45,19 @@ "nop_executable_cache.h" "status_util.c" "status_util.h" + "stream_command_buffer.c" + "stream_command_buffer.h" DEPS ::dynamic_symbols iree::base iree::base::core_headers iree::base::internal + iree::base::internal::arena iree::base::internal::flatcc iree::base::internal::synchronization iree::base::tracing iree::hal + iree::hal::utils::deferred_command_buffer iree::schemas::cuda_executable_def_c_fbs PUBLIC )
diff --git a/iree/hal/cuda/api.h b/iree/hal/cuda/api.h index 0a3cd1d..0a7ec62 100644 --- a/iree/hal/cuda/api.h +++ b/iree/hal/cuda/api.h
@@ -16,6 +16,27 @@ extern "C" { #endif // __cplusplus +// Parameters configuring an iree_hal_cuda_device_t. +// Must be initialized with iree_hal_cuda_device_params_initialize prior to use. +typedef struct iree_hal_cuda_device_params_t { + // Number of queues exposed on the device. + // Each queue acts as a separate synchronization scope where all work executes + // concurrently unless prohibited by semaphores. + iree_host_size_t queue_count; + + // Total size of each block in the device shared block pool. + // Larger sizes will lower overhead and ensure the heap isn't hit for + // transient allocations while also increasing memory consumption. + iree_host_size_t arena_block_size; + + // Switch for using deferred command buffer or default graph command buffer + bool use_deferred_submission; +} iree_hal_cuda_device_params_t; + +// Initializes |out_params| to default values. +void iree_hal_cuda_device_params_initialize( + iree_hal_cuda_device_params_t* out_params); + //===----------------------------------------------------------------------===// // iree_hal_cuda_driver_t //===----------------------------------------------------------------------===// @@ -35,6 +56,7 @@ // |out_driver| must be released by the caller (see |iree_hal_driver_release|). IREE_API_EXPORT iree_status_t iree_hal_cuda_driver_create( iree_string_view_t identifier, + const iree_hal_cuda_device_params_t* default_params, const iree_hal_cuda_driver_options_t* options, iree_allocator_t host_allocator, iree_hal_driver_t** out_driver);
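Per the device code later in this change, the initialized defaults are queue_count=8, arena_block_size=32KiB, and deferred submission off, with creation rejecting blocks under 4096 bytes or zero queues. A usage sketch of the new API (the options-initialize helper and the chosen values are assumptions for illustration, not guaranteed by this change):

// Inside a function returning iree_status_t:
iree_hal_cuda_device_params_t params;
iree_hal_cuda_device_params_initialize(&params);
params.arena_block_size = 64 * 1024;    // trade memory for fewer heap hits
params.use_deferred_submission = true;  // record host-side, replay on submit

iree_hal_cuda_driver_options_t options;
iree_hal_cuda_driver_options_initialize(&options);  // assumed existing helper

iree_hal_driver_t* driver = NULL;
IREE_RETURN_IF_ERROR(iree_hal_cuda_driver_create(
    iree_make_cstring_view("cuda"), &params, &options,
    iree_allocator_system(), &driver));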
diff --git a/iree/hal/cuda/cuda_buffer.c b/iree/hal/cuda/cuda_buffer.c index a87eb24..63334a5 100644 --- a/iree/hal/cuda/cuda_buffer.c +++ b/iree/hal/cuda/cuda_buffer.c
@@ -59,7 +59,7 @@ } IREE_TRACE_ZONE_END(z0); - return iree_ok_status(); + return status; } static void iree_hal_cuda_buffer_destroy(iree_hal_buffer_t* base_buffer) { @@ -95,7 +95,7 @@ // would only work if the entire buffer was discarded. #ifndef NDEBUG if (iree_any_bit_set(memory_access, IREE_HAL_MEMORY_ACCESS_DISCARD)) { - memset(data_ptr + local_byte_offset, 0xCD, local_byte_length); + memset(data_ptr, 0xCD, local_byte_length); } #endif // !NDEBUG *out_data_ptr = data_ptr;
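Two small but real fixes here: the first hunk stops an early-return path from masking failures by returning the computed `status` instead of an unconditional `iree_ok_status()`, and the second removes what appears to be a double-applied offset in the debug poison fill. A sketch of that hazard, with names inferred from context rather than quoted from the file:

// If the mapping path has already applied the subrange offset:
//   uint8_t* data_ptr = mapped_base + local_byte_offset;
// then the old fill
//   memset(data_ptr + local_byte_offset, 0xCD, local_byte_length);
// poisons [2*offset, 2*offset + length) -- past the requested range whenever
// offset != 0. The fix poisons exactly the mapped subrange:
memset(data_ptr, 0xCD, local_byte_length);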
diff --git a/iree/hal/cuda/cuda_device.c b/iree/hal/cuda/cuda_device.c index d78d21f..0239b96 100644 --- a/iree/hal/cuda/cuda_device.c +++ b/iree/hal/cuda/cuda_device.c
@@ -10,6 +10,7 @@ #include <stdint.h> #include <string.h> +#include "iree/base/internal/arena.h" #include "iree/base/tracing.h" #include "iree/hal/cuda/context_wrapper.h" #include "iree/hal/cuda/cuda_allocator.h" @@ -21,6 +22,8 @@ #include "iree/hal/cuda/graph_command_buffer.h" #include "iree/hal/cuda/nop_executable_cache.h" #include "iree/hal/cuda/status_util.h" +#include "iree/hal/cuda/stream_command_buffer.h" +#include "iree/hal/utils/deferred_command_buffer.h" //===----------------------------------------------------------------------===// // iree_hal_cuda_device_t @@ -30,6 +33,10 @@ iree_hal_resource_t resource; iree_string_view_t identifier; + // Block pool used for command buffers with a larger block size (as command + // buffers can contain inlined data uploads). + iree_arena_block_pool_t block_pool; + // Optional driver that owns the CUDA symbols. We retain it for our lifetime // to ensure the symbols remains valid. iree_hal_driver_t* driver; @@ -41,6 +48,8 @@ iree_hal_cuda_context_wrapper_t context_wrapper; iree_hal_allocator_t* device_allocator; + // Switch for using deferred command buffer or default graph command buffer + bool use_deferred_submission; } iree_hal_cuda_device_t; extern const iree_hal_device_vtable_t iree_hal_cuda_device_vtable; @@ -51,6 +60,26 @@ return (iree_hal_cuda_device_t*)base_value; } +void iree_hal_cuda_device_params_initialize( + iree_hal_cuda_device_params_t* out_params) { + out_params->arena_block_size = 32 * 1024; + out_params->queue_count = 8; + out_params->use_deferred_submission = false; +} + +static iree_status_t iree_hal_cuda_device_check_params( + const iree_hal_cuda_device_params_t* params) { + if (params->arena_block_size < 4096) { + return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "arena block size too small (< 4096 bytes)"); + } + if (params->queue_count == 0) { + return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "at least one queue is required"); + } + return iree_ok_status(); +} + static void iree_hal_cuda_device_destroy(iree_hal_device_t* base_device) { iree_hal_cuda_device_t* device = iree_hal_cuda_device_cast(base_device); iree_allocator_t host_allocator = iree_hal_device_host_allocator(base_device); @@ -61,6 +90,7 @@ CUDA_IGNORE_ERROR(device->context_wrapper.syms, cuStreamDestroy(device->stream)); + iree_arena_block_pool_deinitialize(&device->block_pool); // Finally, destroy the device. 
iree_hal_driver_release(device->driver); @@ -71,25 +101,28 @@ static iree_status_t iree_hal_cuda_device_create_internal( iree_hal_driver_t* driver, iree_string_view_t identifier, - CUdevice cu_device, CUstream stream, CUcontext context, - iree_hal_cuda_dynamic_symbols_t* syms, iree_allocator_t host_allocator, - iree_hal_device_t** out_device) { + const iree_hal_cuda_device_params_t* params, CUdevice cu_device, + CUstream stream, CUcontext context, iree_hal_cuda_dynamic_symbols_t* syms, + iree_allocator_t host_allocator, iree_hal_device_t** out_device) { iree_hal_cuda_device_t* device = NULL; - iree_host_size_t total_size = sizeof(*device) + identifier.size; + iree_host_size_t total_size = iree_sizeof_struct(*device) + identifier.size; IREE_RETURN_IF_ERROR( iree_allocator_malloc(host_allocator, total_size, (void**)&device)); memset(device, 0, total_size); iree_hal_resource_initialize(&iree_hal_cuda_device_vtable, &device->resource); device->driver = driver; iree_hal_driver_retain(device->driver); - uint8_t* buffer_ptr = (uint8_t*)device + sizeof(*device); - buffer_ptr += iree_string_view_append_to_buffer( - identifier, &device->identifier, (char*)buffer_ptr); + iree_string_view_append_to_buffer( + identifier, &device->identifier, + (char*)device + iree_sizeof_struct(*device)); device->device = cu_device; device->stream = stream; device->context_wrapper.cu_context = context; device->context_wrapper.host_allocator = host_allocator; + iree_arena_block_pool_initialize(params->arena_block_size, host_allocator, + &device->block_pool); device->context_wrapper.syms = syms; + device->use_deferred_submission = params->use_deferred_submission; iree_status_t status = iree_hal_cuda_allocator_create( &device->context_wrapper, &device->device_allocator); if (iree_status_is_ok(status)) { @@ -100,13 +133,15 @@ return status; } -iree_status_t iree_hal_cuda_device_create(iree_hal_driver_t* driver, - iree_string_view_t identifier, - iree_hal_cuda_dynamic_symbols_t* syms, - CUdevice device, - iree_allocator_t host_allocator, - iree_hal_device_t** out_device) { +iree_status_t iree_hal_cuda_device_create( + iree_hal_driver_t* driver, iree_string_view_t identifier, + const iree_hal_cuda_device_params_t* params, + iree_hal_cuda_dynamic_symbols_t* syms, CUdevice device, + iree_allocator_t host_allocator, iree_hal_device_t** out_device) { + IREE_ASSERT_ARGUMENT(params); IREE_TRACE_ZONE_BEGIN(z0); + IREE_RETURN_AND_END_ZONE_IF_ERROR(z0, + iree_hal_cuda_device_check_params(params)); CUcontext context; IREE_RETURN_AND_END_ZONE_IF_ERROR( z0, CU_RESULT_TO_STATUS(syms, cuCtxCreate(&context, 0, device))); @@ -115,8 +150,8 @@ syms, cuStreamCreate(&stream, CU_STREAM_NON_BLOCKING)); if (iree_status_is_ok(status)) { - status = iree_hal_cuda_device_create_internal(driver, identifier, device, - stream, context, syms, + status = iree_hal_cuda_device_create_internal(driver, identifier, params, + device, stream, context, syms, host_allocator, out_device); } if (!iree_status_is_ok(status)) { @@ -174,7 +209,12 @@ iree_hal_queue_affinity_t queue_affinity, iree_hal_command_buffer_t** out_command_buffer) { iree_hal_cuda_device_t* device = iree_hal_cuda_device_cast(base_device); - return iree_hal_cuda_graph_command_buffer_allocate( + if (device->use_deferred_submission) { + return iree_hal_deferred_command_buffer_create( + mode, command_categories, &device->block_pool, + iree_hal_device_host_allocator(base_device), out_command_buffer); + } + return iree_hal_cuda_graph_command_buffer_create( &device->context_wrapper, mode, 
command_categories, queue_affinity, out_command_buffer); } @@ -240,13 +280,30 @@ iree_hal_queue_affinity_t queue_affinity, iree_host_size_t batch_count, const iree_hal_submission_batch_t* batches) { iree_hal_cuda_device_t* device = iree_hal_cuda_device_cast(base_device); - for (int i = 0; i < batch_count; i++) { - for (int j = 0; j < batches[i].command_buffer_count; j++) { - CUgraphExec exec = iree_hal_cuda_graph_command_buffer_exec( - batches[i].command_buffers[j]); - CUDA_RETURN_IF_ERROR(device->context_wrapper.syms, - cuGraphLaunch(exec, device->stream), - "cuGraphLaunch"); + if (device->use_deferred_submission) { + iree_hal_command_buffer_t* stream_command_buffer = NULL; + iree_status_t status = iree_hal_cuda_stream_command_buffer_create( + &device->context_wrapper, + IREE_HAL_COMMAND_BUFFER_MODE_ALLOW_INLINE_EXECUTION, command_categories, + device->stream, &stream_command_buffer); + for (iree_host_size_t i = 0; iree_status_is_ok(status) && i < batch_count; + i++) { + for (iree_host_size_t j = 0; + iree_status_is_ok(status) && j < batches[i].command_buffer_count; + j++) { + status = iree_hal_deferred_command_buffer_apply( + batches[i].command_buffers[j], stream_command_buffer); + } + } + iree_hal_command_buffer_release(stream_command_buffer); + IREE_RETURN_IF_ERROR(status); + } else { + for (iree_host_size_t i = 0; i < batch_count; i++) { + for (iree_host_size_t j = 0; j < batches[i].command_buffer_count; j++) { + CUgraphExec exec = iree_hal_cuda_graph_command_buffer_exec( + batches[i].command_buffers[j]); + CUDA_RETURN_IF_ERROR(device->context_wrapper.syms, + cuGraphLaunch(exec, device->stream), + "cuGraphLaunch"); + } } } // TODO(thomasraoux): Conservatively synchronize after every submit until we
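Roughly, the deferred path records once into a command buffer backed by the device block pool and replays it into a one-shot stream command buffer at submit time. A hedged sketch of that lifecycle, using only names visible in this patch (recording calls elided):

```c
// Sketch of the deferred record/replay flow; error handling abbreviated.
static iree_status_t record_and_replay(
    iree_hal_cuda_device_t* device, iree_hal_command_buffer_mode_t mode,
    iree_hal_command_category_t categories) {
  iree_hal_command_buffer_t* deferred = NULL;
  IREE_RETURN_IF_ERROR(iree_hal_deferred_command_buffer_create(
      mode, categories, &device->block_pool,
      device->context_wrapper.host_allocator, &deferred));
  // ... iree_hal_command_buffer_* recording calls are captured, not executed ...

  iree_hal_command_buffer_t* stream_cb = NULL;
  iree_status_t status = iree_hal_cuda_stream_command_buffer_create(
      &device->context_wrapper,
      IREE_HAL_COMMAND_BUFFER_MODE_ALLOW_INLINE_EXECUTION, categories,
      device->stream, &stream_cb);
  if (iree_status_is_ok(status)) {
    // Replays the captured commands directly onto the CUstream.
    status = iree_hal_deferred_command_buffer_apply(deferred, stream_cb);
  }
  iree_hal_command_buffer_release(stream_cb);
  iree_hal_command_buffer_release(deferred);
  return status;
}
```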
diff --git a/iree/hal/cuda/cuda_device.h b/iree/hal/cuda/cuda_device.h index b6b47af..d7b5790 100644 --- a/iree/hal/cuda/cuda_device.h +++ b/iree/hal/cuda/cuda_device.h
@@ -17,12 +17,11 @@ #endif // __cplusplus // Creates a device that owns and manages its own CUcontext. -iree_status_t iree_hal_cuda_device_create(iree_hal_driver_t* driver, - iree_string_view_t identifier, - iree_hal_cuda_dynamic_symbols_t* syms, - CUdevice device, - iree_allocator_t host_allocator, - iree_hal_device_t** out_device); +iree_status_t iree_hal_cuda_device_create( + iree_hal_driver_t* driver, iree_string_view_t identifier, + const iree_hal_cuda_device_params_t* params, + iree_hal_cuda_dynamic_symbols_t* syms, CUdevice device, + iree_allocator_t host_allocator, iree_hal_device_t** out_device); #ifdef __cplusplus } // extern "C"
diff --git a/iree/hal/cuda/cuda_driver.c b/iree/hal/cuda/cuda_driver.c index 6e5617d..8903775 100644 --- a/iree/hal/cuda/cuda_driver.c +++ b/iree/hal/cuda/cuda_driver.c
@@ -22,6 +22,7 @@ // We allow overriding so that multiple CUDA versions can be exposed in the // same process. iree_string_view_t identifier; + iree_hal_cuda_device_params_t default_params; int default_device_index; // CUDA symbols. iree_hal_cuda_dynamic_symbols_t syms; @@ -46,18 +47,23 @@ static iree_status_t iree_hal_cuda_driver_create_internal( iree_string_view_t identifier, + const iree_hal_cuda_device_params_t* default_params, const iree_hal_cuda_driver_options_t* options, iree_allocator_t host_allocator, iree_hal_driver_t** out_driver) { iree_hal_cuda_driver_t* driver = NULL; - iree_host_size_t total_size = sizeof(*driver) + identifier.size; + iree_host_size_t total_size = iree_sizeof_struct(*driver) + identifier.size; IREE_RETURN_IF_ERROR( iree_allocator_malloc(host_allocator, total_size, (void**)&driver)); + iree_hal_resource_initialize(&iree_hal_cuda_driver_vtable, &driver->resource); driver->host_allocator = host_allocator; iree_string_view_append_to_buffer( identifier, &driver->identifier, - (char*)driver + total_size - identifier.size); + (char*)driver + iree_sizeof_struct(*driver)); + memcpy(&driver->default_params, default_params, + sizeof(driver->default_params)); driver->default_device_index = options->default_device_index; + iree_status_t status = iree_hal_cuda_dynamic_symbols_initialize(host_allocator, &driver->syms); if (iree_status_is_ok(status)) { @@ -81,14 +87,16 @@ IREE_API_EXPORT iree_status_t iree_hal_cuda_driver_create( iree_string_view_t identifier, + const iree_hal_cuda_device_params_t* default_params, const iree_hal_cuda_driver_options_t* options, iree_allocator_t host_allocator, iree_hal_driver_t** out_driver) { + IREE_ASSERT_ARGUMENT(default_params); IREE_ASSERT_ARGUMENT(options); IREE_ASSERT_ARGUMENT(out_driver); IREE_TRACE_ZONE_BEGIN(z0); iree_status_t status = iree_hal_cuda_driver_create_internal( - identifier, options, host_allocator, out_driver); + identifier, default_params, options, host_allocator, out_driver); IREE_TRACE_ZONE_END(z0); return status; @@ -193,9 +201,9 @@ iree_string_view_t device_name = iree_make_cstring_view("cuda"); // Attempt to create the device. - iree_status_t status = - iree_hal_cuda_device_create(base_driver, device_name, &driver->syms, - device, host_allocator, out_device); + iree_status_t status = iree_hal_cuda_device_create( + base_driver, device_name, &driver->default_params, &driver->syms, device, + host_allocator, out_device); IREE_TRACE_ZONE_END(z0); return status;
diff --git a/iree/hal/cuda/dynamic_symbol_tables.h b/iree/hal/cuda/dynamic_symbol_tables.h index f131e14..6222219 100644 --- a/iree/hal/cuda/dynamic_symbol_tables.h +++ b/iree/hal/cuda/dynamic_symbol_tables.h
@@ -39,3 +39,8 @@ CU_PFN_DECL(cuStreamDestroy, CUstream) CU_PFN_DECL(cuStreamSynchronize, CUstream) CU_PFN_DECL(cuStreamWaitEvent, CUstream, CUevent, unsigned int) +CU_PFN_DECL(cuMemsetD32Async, unsigned long long, int, size_t, CUstream) +CU_PFN_DECL(cuMemcpyAsync, CUdeviceptr, CUdeviceptr, size_t, CUstream) +CU_PFN_DECL(cuLaunchKernel, CUfunction, unsigned int, unsigned int, + unsigned int, unsigned int, unsigned int, unsigned int, + unsigned int, CUstream, void **, void **)
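The new CU_PFN_DECL entries extend the dynamically loaded symbol table. The macro definition is not part of this diff, but a plausible expansion (an assumption, not the actual IREE macro) is a function-pointer field in iree_hal_cuda_dynamic_symbols_t that the loader resolves at startup:

```c
// Hypothetical expansion of CU_PFN_DECL(cuMemcpyAsync, CUdeviceptr,
// CUdeviceptr, size_t, CUstream); the real macro may differ.
CUresult (*cuMemcpyAsync)(CUdeviceptr dst, CUdeviceptr src, size_t byte_count,
                          CUstream stream);
// Call sites then route through the table:
//   syms->cuMemcpyAsync(dst, src, byte_count, stream);
```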
diff --git a/iree/hal/cuda/graph_command_buffer.c b/iree/hal/cuda/graph_command_buffer.c index 90126e8..f5667fa 100644 --- a/iree/hal/cuda/graph_command_buffer.c +++ b/iree/hal/cuda/graph_command_buffer.c
@@ -47,7 +47,7 @@ return (iree_hal_cuda_graph_command_buffer_t*)base_value; } -iree_status_t iree_hal_cuda_graph_command_buffer_allocate( +iree_status_t iree_hal_cuda_graph_command_buffer_create( iree_hal_cuda_context_wrapper_t* context, iree_hal_command_buffer_mode_t mode, iree_hal_command_category_t command_categories, @@ -136,7 +136,7 @@ static iree_status_t iree_hal_cuda_graph_command_buffer_begin( iree_hal_command_buffer_t* base_command_buffer) { - // nothing to do. + // Nothing to do. return iree_ok_status(); }
diff --git a/iree/hal/cuda/graph_command_buffer.h b/iree/hal/cuda/graph_command_buffer.h index b38b251..c50ccf9 100644 --- a/iree/hal/cuda/graph_command_buffer.h +++ b/iree/hal/cuda/graph_command_buffer.h
@@ -18,7 +18,7 @@ #endif // __cplusplus // Creates a command buffer that records into a CUDA graph. -iree_status_t iree_hal_cuda_graph_command_buffer_allocate( +iree_status_t iree_hal_cuda_graph_command_buffer_create( iree_hal_cuda_context_wrapper_t* context, iree_hal_command_buffer_mode_t mode, iree_hal_command_category_t command_categories,
diff --git a/iree/hal/cuda/registration/driver_module.c b/iree/hal/cuda/registration/driver_module.c index f8fb736..0a51ed9 100644 --- a/iree/hal/cuda/registration/driver_module.c +++ b/iree/hal/cuda/registration/driver_module.c
@@ -41,6 +41,9 @@ driver_id); } IREE_TRACE_ZONE_BEGIN(z0); + iree_hal_cuda_device_params_t default_params; + iree_hal_cuda_device_params_initialize(&default_params); + // TODO(jinchen62): set default_params.use_deferred_submission via a flag // When we expose more than one driver (different cuda versions, etc) we // can name them here: iree_string_view_t identifier = iree_make_cstring_view("cuda"); @@ -48,7 +51,7 @@ iree_hal_cuda_driver_options_t driver_options; iree_hal_cuda_driver_options_initialize(&driver_options); iree_status_t status = iree_hal_cuda_driver_create( - identifier, &driver_options, allocator, out_driver); + identifier, &default_params, &driver_options, allocator, out_driver); IREE_TRACE_ZONE_END(z0); return status; }
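One hedged way the TODO could be resolved, assuming IREE's flag facility in iree/base/internal/flags.h; the flag name itself is purely illustrative:

```c
// Hypothetical flag wiring for use_deferred_submission (name is made up).
IREE_FLAG(bool, cuda_use_deferred_submission, false,
          "Replays command buffers on a CUDA stream instead of CUDA graphs.");

// ...then, before iree_hal_cuda_driver_create():
// default_params.use_deferred_submission = FLAG_cuda_use_deferred_submission;
```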
diff --git a/iree/hal/cuda/stream_command_buffer.c b/iree/hal/cuda/stream_command_buffer.c new file mode 100644 index 0000000..9ab153d --- /dev/null +++ b/iree/hal/cuda/stream_command_buffer.c
@@ -0,0 +1,353 @@ +// Copyright 2021 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/hal/cuda/stream_command_buffer.h" + +#include <assert.h> +#include <stdlib.h> + +#include "iree/base/tracing.h" +#include "iree/hal/cuda/cuda_buffer.h" +#include "iree/hal/cuda/cuda_event.h" +#include "iree/hal/cuda/executable_layout.h" +#include "iree/hal/cuda/native_executable.h" +#include "iree/hal/cuda/status_util.h" + +#define IREE_HAL_CUDA_MAX_BINDING_COUNT 64 + +// Records commands directly on the calling thread, without additional +// threading indirection. +typedef struct { + iree_hal_resource_t resource; + iree_hal_cuda_context_wrapper_t* context; + iree_hal_command_buffer_mode_t mode; + iree_hal_command_category_t allowed_categories; + CUstream stream; + // Keep track of the current set of kernel arguments. + void* current_descriptor[IREE_HAL_CUDA_MAX_BINDING_COUNT]; + CUdeviceptr* device_ptrs[IREE_HAL_CUDA_MAX_BINDING_COUNT]; +} iree_hal_cuda_stream_command_buffer_t; + +extern const iree_hal_command_buffer_vtable_t + iree_hal_cuda_stream_command_buffer_vtable; + +static iree_hal_cuda_stream_command_buffer_t* +iree_hal_cuda_stream_command_buffer_cast( + iree_hal_command_buffer_t* base_value) { + IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_cuda_stream_command_buffer_vtable); + return (iree_hal_cuda_stream_command_buffer_t*)base_value; +} + +iree_status_t iree_hal_cuda_stream_command_buffer_create( + iree_hal_cuda_context_wrapper_t* context, + iree_hal_command_buffer_mode_t mode, + iree_hal_command_category_t command_categories, CUstream stream, + iree_hal_command_buffer_t** out_command_buffer) { + IREE_ASSERT_ARGUMENT(context); + IREE_ASSERT_ARGUMENT(out_command_buffer); + *out_command_buffer = NULL; + IREE_TRACE_ZONE_BEGIN(z0); + + iree_hal_cuda_stream_command_buffer_t* command_buffer = NULL; + size_t total_size = sizeof(*command_buffer) + + IREE_HAL_CUDA_MAX_BINDING_COUNT * sizeof(void*) + + IREE_HAL_CUDA_MAX_BINDING_COUNT * sizeof(CUdeviceptr); + iree_status_t status = iree_allocator_malloc( + context->host_allocator, total_size, (void**)&command_buffer); + if (iree_status_is_ok(status)) { + iree_hal_resource_initialize(&iree_hal_cuda_stream_command_buffer_vtable, + &command_buffer->resource); + command_buffer->context = context; + command_buffer->mode = mode; + command_buffer->allowed_categories = command_categories; + command_buffer->stream = stream; + for (size_t i = 0; i < IREE_HAL_CUDA_MAX_BINDING_COUNT; i++) { + command_buffer->current_descriptor[i] = &command_buffer->device_ptrs[i]; + } + } + + *out_command_buffer = (iree_hal_command_buffer_t*)command_buffer; + IREE_TRACE_ZONE_END(z0); + return status; +} + +static void iree_hal_cuda_stream_command_buffer_destroy( + iree_hal_command_buffer_t* base_command_buffer) { + iree_hal_cuda_stream_command_buffer_t* command_buffer = + iree_hal_cuda_stream_command_buffer_cast(base_command_buffer); + IREE_TRACE_ZONE_BEGIN(z0); + + iree_allocator_free(command_buffer->context->host_allocator, command_buffer); + + IREE_TRACE_ZONE_END(z0); +} + +static iree_hal_command_buffer_mode_t iree_hal_cuda_stream_command_buffer_mode( + const iree_hal_command_buffer_t* base_command_buffer) { + const iree_hal_cuda_stream_command_buffer_t* command_buffer = + (const iree_hal_cuda_stream_command_buffer_t*)(base_command_buffer); + return command_buffer->mode; +} + +static iree_hal_command_category_t 
+iree_hal_cuda_stream_command_buffer_allowed_categories( + const iree_hal_command_buffer_t* base_command_buffer) { + const iree_hal_cuda_stream_command_buffer_t* command_buffer = + (const iree_hal_cuda_stream_command_buffer_t*)(base_command_buffer); + return command_buffer->allowed_categories; +} + +static iree_status_t iree_hal_cuda_stream_command_buffer_begin( + iree_hal_command_buffer_t* base_command_buffer) { + return iree_ok_status(); +} + +static iree_status_t iree_hal_cuda_stream_command_buffer_end( + iree_hal_command_buffer_t* base_command_buffer) { + return iree_ok_status(); +} + +static iree_status_t iree_hal_cuda_stream_command_buffer_execution_barrier( + iree_hal_command_buffer_t* base_command_buffer, + iree_hal_execution_stage_t source_stage_mask, + iree_hal_execution_stage_t target_stage_mask, + iree_hal_execution_barrier_flags_t flags, + iree_host_size_t memory_barrier_count, + const iree_hal_memory_barrier_t* memory_barriers, + iree_host_size_t buffer_barrier_count, + const iree_hal_buffer_barrier_t* buffer_barriers) { + // TODO(jinchen62): implement CUDA barrier + return iree_ok_status(); +} + +static iree_status_t iree_hal_cuda_stream_command_buffer_signal_event( + iree_hal_command_buffer_t* base_command_buffer, iree_hal_event_t* event, + iree_hal_execution_stage_t source_stage_mask) { + // TODO(jinchen62): implement CUDA barrier + return iree_ok_status(); +} + +static iree_status_t iree_hal_cuda_stream_command_buffer_reset_event( + iree_hal_command_buffer_t* base_command_buffer, iree_hal_event_t* event, + iree_hal_execution_stage_t source_stage_mask) { + // TODO(jinchen62): implement CUDA barrier + return iree_ok_status(); +} + +static iree_status_t iree_hal_cuda_stream_command_buffer_wait_events( + iree_hal_command_buffer_t* base_command_buffer, + iree_host_size_t event_count, const iree_hal_event_t** events, + iree_hal_execution_stage_t source_stage_mask, + iree_hal_execution_stage_t target_stage_mask, + iree_host_size_t memory_barrier_count, + const iree_hal_memory_barrier_t* memory_barriers, + iree_host_size_t buffer_barrier_count, + const iree_hal_buffer_barrier_t* buffer_barriers) { + // TODO(jinchen62): implement CUDA barrier + return iree_ok_status(); +} + +static iree_status_t iree_hal_cuda_stream_command_buffer_discard_buffer( + iree_hal_command_buffer_t* base_command_buffer, iree_hal_buffer_t* buffer) { + // Nothing to do. + return iree_ok_status(); +} + +// Splats a pattern value of 1, 2, or 4 bytes out to a 4 byte value. +static uint32_t iree_hal_cuda_splat_pattern(const void* pattern, + size_t pattern_length) { + switch (pattern_length) { + case 1: { + uint32_t pattern_value = *(const uint8_t*)(pattern); + return (pattern_value << 24) | (pattern_value << 16) | + (pattern_value << 8) | pattern_value; + } + case 2: { + uint32_t pattern_value = *(const uint16_t*)(pattern); + return (pattern_value << 16) | pattern_value; + } + case 4: { + uint32_t pattern_value = *(const uint32_t*)(pattern); + return pattern_value; + } + default: + return 0; // Already verified that this should not be possible. 
+ } +} + +static iree_status_t iree_hal_cuda_stream_command_buffer_fill_buffer( + iree_hal_command_buffer_t* base_command_buffer, + iree_hal_buffer_t* target_buffer, iree_device_size_t target_offset, + iree_device_size_t length, const void* pattern, + iree_host_size_t pattern_length) { + iree_hal_cuda_stream_command_buffer_t* command_buffer = + iree_hal_cuda_stream_command_buffer_cast(base_command_buffer); + + CUdeviceptr target_device_buffer = iree_hal_cuda_buffer_device_pointer( + iree_hal_buffer_allocated_buffer(target_buffer)); + target_offset += iree_hal_buffer_byte_offset(target_buffer); + uint32_t dword_pattern = iree_hal_cuda_splat_pattern(pattern, pattern_length); + CUdeviceptr dst = target_device_buffer + target_offset; + // cuMemsetD32Async takes a count of 32-bit elements, not a byte count. + size_t element_count = length / sizeof(uint32_t); + CUDA_RETURN_IF_ERROR( + command_buffer->context->syms, + cuMemsetD32Async(dst, dword_pattern, element_count, + command_buffer->stream), + "cuMemsetD32Async"); + return iree_ok_status(); +} + +static iree_status_t iree_hal_cuda_stream_command_buffer_update_buffer( + iree_hal_command_buffer_t* base_command_buffer, const void* source_buffer, + iree_host_size_t source_offset, iree_hal_buffer_t* target_buffer, + iree_device_size_t target_offset, iree_device_size_t length) { + return iree_make_status(IREE_STATUS_UNIMPLEMENTED, + "need cuda implementation of update buffer"); +} + +static iree_status_t iree_hal_cuda_stream_command_buffer_copy_buffer( + iree_hal_command_buffer_t* base_command_buffer, + iree_hal_buffer_t* source_buffer, iree_device_size_t source_offset, + iree_hal_buffer_t* target_buffer, iree_device_size_t target_offset, + iree_device_size_t length) { + iree_hal_cuda_stream_command_buffer_t* command_buffer = + iree_hal_cuda_stream_command_buffer_cast(base_command_buffer); + + CUdeviceptr target_device_buffer = iree_hal_cuda_buffer_device_pointer( + iree_hal_buffer_allocated_buffer(target_buffer)); + target_offset += iree_hal_buffer_byte_offset(target_buffer); + CUdeviceptr source_device_buffer = iree_hal_cuda_buffer_device_pointer( + iree_hal_buffer_allocated_buffer(source_buffer)); + source_offset += iree_hal_buffer_byte_offset(source_buffer); + CUDA_RETURN_IF_ERROR(command_buffer->context->syms, + cuMemcpyAsync(target_device_buffer + target_offset, + source_device_buffer + source_offset, + length, command_buffer->stream), + "cuMemcpyAsync"); + return iree_ok_status(); +} + +static iree_status_t iree_hal_cuda_stream_command_buffer_push_constants( + iree_hal_command_buffer_t* base_command_buffer, + iree_hal_executable_layout_t* executable_layout, iree_host_size_t offset, + const void* values, iree_host_size_t values_length) { + return iree_make_status(IREE_STATUS_UNIMPLEMENTED, + "need cuda implementation of push constants"); +} + +// Tie together the binding index and its index in the |bindings| array. +typedef struct { + uint32_t index; + uint32_t binding; +} iree_hal_cuda_binding_mapping_t; + +// Helper to sort the bindings by binding index. +static int compare_binding_index(const void* a, const void* b) { + const iree_hal_cuda_binding_mapping_t buffer_a = + *(const iree_hal_cuda_binding_mapping_t*)a; + const iree_hal_cuda_binding_mapping_t buffer_b = + *(const iree_hal_cuda_binding_mapping_t*)b; + return buffer_a.binding < buffer_b.binding ? 
-1 : 1; +} + +static iree_status_t iree_hal_cuda_stream_command_buffer_push_descriptor_set( + iree_hal_command_buffer_t* base_command_buffer, + iree_hal_executable_layout_t* executable_layout, uint32_t set, + iree_host_size_t binding_count, + const iree_hal_descriptor_set_binding_t* bindings) { + iree_hal_cuda_stream_command_buffer_t* command_buffer = + iree_hal_cuda_stream_command_buffer_cast(base_command_buffer); + iree_host_size_t base_binding = + iree_hal_cuda_base_binding_index(executable_layout, set); + // Convention with the compiler side: bindings are mapped to kernel + // arguments. The bindings are compacted into a dense argument list ordered + // by binding index, so sort them by binding index and map each array index + // to its argument index. + assert(binding_count <= IREE_HAL_CUDA_MAX_BINDING_COUNT && + "binding count larger than the max expected."); + iree_hal_cuda_binding_mapping_t binding_used[IREE_HAL_CUDA_MAX_BINDING_COUNT]; + for (iree_host_size_t i = 0; i < binding_count; i++) { + iree_hal_cuda_binding_mapping_t buffer = {i, bindings[i].binding}; + binding_used[i] = buffer; + } + qsort(binding_used, binding_count, sizeof(iree_hal_cuda_binding_mapping_t), + compare_binding_index); + for (iree_host_size_t i = 0; i < binding_count; i++) { + iree_hal_descriptor_set_binding_t binding = bindings[binding_used[i].index]; + CUdeviceptr device_ptr = + iree_hal_cuda_buffer_device_pointer( + iree_hal_buffer_allocated_buffer(binding.buffer)) + + iree_hal_buffer_byte_offset(binding.buffer) + binding.offset; + *((CUdeviceptr*)command_buffer->current_descriptor[i + base_binding]) = + device_ptr; + } + return iree_ok_status(); +} + +static iree_status_t iree_hal_cuda_stream_command_buffer_bind_descriptor_set( + iree_hal_command_buffer_t* base_command_buffer, + iree_hal_executable_layout_t* executable_layout, uint32_t set, + iree_hal_descriptor_set_t* descriptor_set, + iree_host_size_t dynamic_offset_count, + const iree_device_size_t* dynamic_offsets) { + return iree_make_status(IREE_STATUS_UNIMPLEMENTED, + "need cuda implementation of bind descriptor set"); +} + +static iree_status_t iree_hal_cuda_stream_command_buffer_dispatch( + iree_hal_command_buffer_t* base_command_buffer, + iree_hal_executable_t* executable, int32_t entry_point, + uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z) { + iree_hal_cuda_stream_command_buffer_t* command_buffer = + iree_hal_cuda_stream_command_buffer_cast(base_command_buffer); + + int32_t block_size_x, block_size_y, block_size_z; + IREE_RETURN_IF_ERROR(iree_hal_cuda_native_executable_block_size( + executable, entry_point, &block_size_x, &block_size_y, &block_size_z)); + CUfunction func = + iree_hal_cuda_native_executable_for_entry_point(executable, entry_point); + CUDA_RETURN_IF_ERROR( + command_buffer->context->syms, + cuLaunchKernel(func, workgroup_x, workgroup_y, workgroup_z, block_size_x, + block_size_y, block_size_z, 0, command_buffer->stream, + command_buffer->current_descriptor, NULL), + "cuLaunchKernel"); + return iree_ok_status(); +} + +static iree_status_t iree_hal_cuda_stream_command_buffer_dispatch_indirect( + iree_hal_command_buffer_t* base_command_buffer, + iree_hal_executable_t* executable, int32_t entry_point, + iree_hal_buffer_t* workgroups_buffer, + iree_device_size_t workgroups_offset) { + return iree_make_status(IREE_STATUS_UNIMPLEMENTED, + "need cuda implementation of dispatch indirect"); +} + +const iree_hal_command_buffer_vtable_t + iree_hal_cuda_stream_command_buffer_vtable = { 
+ .destroy = iree_hal_cuda_stream_command_buffer_destroy, + .mode = iree_hal_cuda_stream_command_buffer_mode, + .allowed_categories = + iree_hal_cuda_stream_command_buffer_allowed_categories, + .begin = iree_hal_cuda_stream_command_buffer_begin, + .end = iree_hal_cuda_stream_command_buffer_end, + .execution_barrier = + iree_hal_cuda_stream_command_buffer_execution_barrier, + .signal_event = iree_hal_cuda_stream_command_buffer_signal_event, + .reset_event = iree_hal_cuda_stream_command_buffer_reset_event, + .wait_events = iree_hal_cuda_stream_command_buffer_wait_events, + .discard_buffer = iree_hal_cuda_stream_command_buffer_discard_buffer, + .fill_buffer = iree_hal_cuda_stream_command_buffer_fill_buffer, + .update_buffer = iree_hal_cuda_stream_command_buffer_update_buffer, + .copy_buffer = iree_hal_cuda_stream_command_buffer_copy_buffer, + .push_constants = iree_hal_cuda_stream_command_buffer_push_constants, + .push_descriptor_set = + iree_hal_cuda_stream_command_buffer_push_descriptor_set, + .bind_descriptor_set = + iree_hal_cuda_stream_command_buffer_bind_descriptor_set, + .dispatch = iree_hal_cuda_stream_command_buffer_dispatch, + .dispatch_indirect = + iree_hal_cuda_stream_command_buffer_dispatch_indirect, +};
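A quick worked example of the fill path: splatting a 1-byte pattern and converting a byte length into the element count that cuMemsetD32Async expects (the values are illustrative):

```c
// Worked example of the splat + fill arithmetic above.
uint8_t pattern = 0xAB;
uint32_t dword = iree_hal_cuda_splat_pattern(&pattern, 1);
// dword == 0xABABABABu: the byte is replicated into all four lanes.
// cuMemsetD32Async takes a count of 32-bit elements, so a 16-byte fill
// issues 16 / sizeof(uint32_t) == 4 elements.
size_t length = 16;
size_t element_count = length / sizeof(uint32_t);  // == 4
```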
diff --git a/iree/hal/cuda/stream_command_buffer.h b/iree/hal/cuda/stream_command_buffer.h new file mode 100644 index 0000000..b4b901a --- /dev/null +++ b/iree/hal/cuda/stream_command_buffer.h
@@ -0,0 +1,35 @@ +// Copyright 2021 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_HAL_CUDA_STREAM_COMMAND_BUFFER_H_ +#define IREE_HAL_CUDA_STREAM_COMMAND_BUFFER_H_ + +#include "iree/base/internal/arena.h" +#include "iree/hal/api.h" +#include "iree/hal/cuda/context_wrapper.h" +#include "iree/hal/cuda/cuda_headers.h" +#include "iree/hal/cuda/dynamic_symbols.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +// Creates a CUDA stream command buffer that immediately issues commands +// against the given |stream|. Access to |stream| must be synchronized by the +// caller. +// Used for replaying commands in special situations and never returned to +// users from device_create_command_buffer. +iree_status_t iree_hal_cuda_stream_command_buffer_create( + iree_hal_cuda_context_wrapper_t* context, + iree_hal_command_buffer_mode_t mode, + iree_hal_command_category_t command_categories, CUstream stream, + iree_hal_command_buffer_t** out_command_buffer); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus + +#endif // IREE_HAL_CUDA_STREAM_COMMAND_BUFFER_H_
diff --git a/iree/hal/driver_registry.c b/iree/hal/driver_registry.c index 972bff9..8594023 100644 --- a/iree/hal/driver_registry.c +++ b/iree/hal/driver_registry.c
@@ -188,8 +188,8 @@ // Allocate the required memory for both the driver infos and the string // storage in a single block. - iree_host_size_t total_driver_infos_size = iree_host_align( - total_driver_info_count * sizeof(iree_hal_driver_info_t), 8); + iree_host_size_t total_driver_infos_size = + total_driver_info_count * sizeof(iree_hal_driver_info_t); if (iree_status_is_ok(status)) { status = iree_allocator_malloc(allocator, total_driver_infos_size + total_storage_size,
diff --git a/iree/hal/local/elf/arch/x86_32.c b/iree/hal/local/elf/arch/x86_32.c index 6a7c103..05d08d7 100644 --- a/iree/hal/local/elf/arch/x86_32.c +++ b/iree/hal/local/elf/arch/x86_32.c
@@ -120,10 +120,10 @@ // Cross-ABI function calls //============================================================================== -// System V 386 ABI (used in IREE): +// System V i386 ABI (used in IREE): // https://uclibc.org/docs/psABI-i386.pdf // Arguments: -// +// (reverse order on the stack; last arg furthest from stack pointer) // // Results: // EAX @@ -136,7 +136,7 @@ #if defined(IREE_PLATFORM_WINDOWS) -#error "TODO" +#error "TODO(#6554): need cdecl -> sysv ABI shims in x86_32_msvc.asm" #else
diff --git a/iree/hal/local/loaders/embedded_library_loader.c b/iree/hal/local/loaders/embedded_library_loader.c index ddaafe2..74d48fe 100644 --- a/iree/hal/local/loaders/embedded_library_loader.c +++ b/iree/hal/local/loaders/embedded_library_loader.c
@@ -35,6 +35,8 @@ const iree_hal_executable_library_header_t** header; const iree_hal_executable_library_v0_t* v0; } library; + + iree_hal_local_executable_layout_t* layouts[]; } iree_hal_elf_executable_t; extern const iree_hal_local_executable_vtable_t iree_hal_elf_executable_vtable; @@ -140,16 +142,13 @@ iree_hal_elf_executable_t* executable = NULL; iree_host_size_t total_size = sizeof(*executable) + - executable_layout_count * sizeof(iree_hal_local_executable_layout_t); + executable_layout_count * sizeof(*executable->layouts); iree_status_t status = iree_allocator_malloc(host_allocator, total_size, (void**)&executable); if (iree_status_is_ok(status)) { - iree_hal_local_executable_layout_t** executable_layouts_ptr = - (iree_hal_local_executable_layout_t**)(((uint8_t*)executable) + - sizeof(*executable)); iree_hal_local_executable_initialize( &iree_hal_elf_executable_vtable, executable_layout_count, - executable_layouts, executable_layouts_ptr, host_allocator, + executable_layouts, &executable->layouts[0], host_allocator, &executable->base); } if (iree_status_is_ok(status)) {
diff --git a/iree/hal/local/loaders/legacy_library_loader.c b/iree/hal/local/loaders/legacy_library_loader.c index f370b94..fa9fa1f 100644 --- a/iree/hal/local/loaders/legacy_library_loader.c +++ b/iree/hal/local/loaders/legacy_library_loader.c
@@ -81,6 +81,8 @@ const iree_hal_executable_library_header_t** header; const iree_hal_executable_library_v0_t* v0; } library; + + iree_hal_local_executable_layout_t* layouts[]; } iree_hal_legacy_executable_t; extern const iree_hal_local_executable_vtable_t @@ -224,16 +226,13 @@ iree_hal_legacy_executable_t* executable = NULL; iree_host_size_t total_size = sizeof(*executable) + - executable_layout_count * sizeof(iree_hal_local_executable_layout_t); + executable_layout_count * sizeof(*executable->layouts); iree_status_t status = iree_allocator_malloc(host_allocator, total_size, (void**)&executable); if (iree_status_is_ok(status)) { - iree_hal_local_executable_layout_t** executable_layouts_ptr = - (iree_hal_local_executable_layout_t**)(((uint8_t*)executable) + - sizeof(*executable)); iree_hal_local_executable_initialize( &iree_hal_legacy_executable_vtable, executable_layout_count, - executable_layouts, executable_layouts_ptr, host_allocator, + executable_layouts, &executable->layouts[0], host_allocator, &executable->base); executable->def = executable_def; }
diff --git a/iree/hal/local/loaders/static_library_loader.c b/iree/hal/local/loaders/static_library_loader.c index a7ff63b..1e7819c 100644 --- a/iree/hal/local/loaders/static_library_loader.c +++ b/iree/hal/local/loaders/static_library_loader.c
@@ -30,6 +30,8 @@ const iree_hal_executable_library_header_t** header; const iree_hal_executable_library_v0_t* v0; } library; + + iree_hal_local_executable_layout_t* layouts[]; } iree_hal_static_executable_t; static const iree_hal_local_executable_vtable_t @@ -50,16 +52,13 @@ iree_hal_static_executable_t* executable = NULL; iree_host_size_t total_size = sizeof(*executable) + - executable_layout_count * sizeof(iree_hal_local_executable_layout_t); + executable_layout_count * sizeof(*executable->layouts); iree_status_t status = iree_allocator_malloc(host_allocator, total_size, (void**)&executable); if (iree_status_is_ok(status)) { - iree_hal_local_executable_layout_t** executable_layouts_ptr = - (iree_hal_local_executable_layout_t**)(((uint8_t*)executable) + - sizeof(*executable)); iree_hal_local_executable_initialize( &iree_hal_static_executable_vtable, executable_layout_count, - executable_layouts, executable_layouts_ptr, host_allocator, + executable_layouts, &executable->layouts[0], host_allocator, &executable->base); executable->library.header = library_header; executable->identifier = iree_make_cstring_view((*library_header)->name);
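All three loader diffs above switch to the same C99 flexible-array-member idiom: the trailing `layouts[]` contributes nothing to `sizeof` and occupies memory allocated past the end of the struct, replacing the hand-rolled pointer arithmetic. A generic sketch of the pattern (names are hypothetical):

```c
// Generic flexible-array-member allocation pattern (hypothetical names).
typedef struct {
  iree_host_size_t count;
  iree_hal_local_executable_layout_t* layouts[];  // adds 0 to sizeof
} example_executable_t;

example_executable_t* executable = NULL;
iree_host_size_t total_size =
    sizeof(*executable) + layout_count * sizeof(*executable->layouts);
IREE_RETURN_IF_ERROR(iree_allocator_malloc(host_allocator, total_size,
                                           (void**)&executable));
executable->count = layout_count;  // layouts[0..layout_count-1] now valid
```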
diff --git a/iree/hal/local/sync_device.c b/iree/hal/local/sync_device.c index b6083ab..924da83 100644 --- a/iree/hal/local/sync_device.c +++ b/iree/hal/local/sync_device.c
@@ -23,13 +23,13 @@ iree_hal_resource_t resource; iree_string_view_t identifier; - iree_host_size_t loader_count; - iree_hal_executable_loader_t** loaders; - iree_allocator_t host_allocator; iree_hal_allocator_t* device_allocator; iree_hal_sync_semaphore_state_t semaphore_state; + + iree_host_size_t loader_count; + iree_hal_executable_loader_t* loaders[]; } iree_hal_sync_device_t; static const iree_hal_device_vtable_t iree_hal_sync_device_vtable; @@ -64,8 +64,9 @@ iree_hal_sync_device_check_params(params)); iree_hal_sync_device_t* device = NULL; - iree_host_size_t total_size = sizeof(*device) + identifier.size + - loader_count * sizeof(*device->loaders); + iree_host_size_t struct_size = + sizeof(*device) + loader_count * sizeof(*device->loaders); + iree_host_size_t total_size = struct_size + identifier.size; iree_status_t status = iree_allocator_malloc(host_allocator, total_size, (void**)&device); if (iree_status_is_ok(status)) { @@ -73,13 +74,10 @@ iree_hal_resource_initialize(&iree_hal_sync_device_vtable, &device->resource); iree_string_view_append_to_buffer(identifier, &device->identifier, - (char*)device + sizeof(*device)); + (char*)device + struct_size); device->host_allocator = host_allocator; device->loader_count = loader_count; - device->loaders = - (iree_hal_executable_loader_t**)((uint8_t*)device->identifier.data + - identifier.size); for (iree_host_size_t i = 0; i < device->loader_count; ++i) { device->loaders[i] = loaders[i]; iree_hal_executable_loader_retain(device->loaders[i]);
diff --git a/iree/hal/local/task_device.c b/iree/hal/local/task_device.c index 158b9e4..236c83e 100644 --- a/iree/hal/local/task_device.c +++ b/iree/hal/local/task_device.c
@@ -92,19 +92,18 @@ iree_hal_task_device_check_params(params)); iree_hal_task_device_t* device = NULL; - iree_host_size_t total_size = - sizeof(*device) + params->queue_count * sizeof(*device->queues) + - identifier.size + loader_count * sizeof(*device->loaders); + iree_host_size_t struct_size = sizeof(*device) + + params->queue_count * sizeof(*device->queues) + + loader_count * sizeof(*device->loaders); + iree_host_size_t total_size = struct_size + identifier.size; iree_status_t status = iree_allocator_malloc(host_allocator, total_size, (void**)&device); if (iree_status_is_ok(status)) { memset(device, 0, total_size); iree_hal_resource_initialize(&iree_hal_task_device_vtable, &device->resource); - iree_string_view_append_to_buffer( - identifier, &device->identifier, - (char*)device + sizeof(*device) + - params->queue_count * sizeof(*device->queues)); + iree_string_view_append_to_buffer(identifier, &device->identifier, + (char*)device + struct_size); device->host_allocator = host_allocator; iree_arena_block_pool_initialize(4096, host_allocator, &device->small_block_pool); @@ -117,8 +116,9 @@ device->loader_count = loader_count; device->loaders = - (iree_hal_executable_loader_t**)((uint8_t*)device->identifier.data + - identifier.size); + (iree_hal_executable_loader_t**)((uint8_t*)device + sizeof(*device) + + params->queue_count * + sizeof(*device->queues)); for (iree_host_size_t i = 0; i < device->loader_count; ++i) { device->loaders[i] = loaders[i]; iree_hal_executable_loader_retain(device->loaders[i]);
diff --git a/iree/hal/local/task_driver.c b/iree/hal/local/task_driver.c index e3f10ed..7711520 100644 --- a/iree/hal/local/task_driver.c +++ b/iree/hal/local/task_driver.c
@@ -47,9 +47,9 @@ IREE_TRACE_ZONE_BEGIN(z0); iree_hal_task_driver_t* driver = NULL; - iree_host_size_t total_size = sizeof(*driver) + - loader_count * sizeof(*driver->loaders) + - identifier.size; + iree_host_size_t struct_size = + sizeof(*driver) + loader_count * sizeof(*driver->loaders); + iree_host_size_t total_size = struct_size + identifier.size; iree_status_t status = iree_allocator_malloc(host_allocator, total_size, (void**)&driver); if (iree_status_is_ok(status)) { @@ -57,9 +57,8 @@ &driver->resource); driver->host_allocator = host_allocator; - iree_string_view_append_to_buffer( - identifier, &driver->identifier, - (char*)driver + total_size - identifier.size); + iree_string_view_append_to_buffer(identifier, &driver->identifier, + (char*)driver + struct_size); memcpy(&driver->default_params, default_params, sizeof(driver->default_params));
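The sync/task device and task driver constructors now share a single-allocation layout: `struct_size` covers the struct plus any trailing arrays, and the identifier string is appended after that. Putting the variable-length string last keeps the pointer arrays at naturally aligned offsets. Restating the pattern from the hunks above:

```c
// Layout: [ struct | trailing loader array | identifier characters ]
iree_host_size_t struct_size =
    sizeof(*driver) + loader_count * sizeof(*driver->loaders);
iree_host_size_t total_size = struct_size + identifier.size;
// After allocating total_size bytes, the string lands just past the arrays:
iree_string_view_append_to_buffer(identifier, &driver->identifier,
                                  (char*)driver + struct_size);
```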
diff --git a/iree/hal/vulkan/native_descriptor_set_layout.cc b/iree/hal/vulkan/native_descriptor_set_layout.cc index 1d3c4a3..dff1661 100644 --- a/iree/hal/vulkan/native_descriptor_set_layout.cc +++ b/iree/hal/vulkan/native_descriptor_set_layout.cc
@@ -56,6 +56,7 @@ VkDescriptorSetLayoutBinding* native_bindings = NULL; if (binding_count > 0) { + // TODO(benvanik): avoid this allocation if possible (inline_array). IREE_RETURN_IF_ERROR(iree_allocator_malloc( logical_device->host_allocator(), binding_count * sizeof(VkDescriptorSetLayoutBinding),
diff --git a/iree/hal/vulkan/vma_buffer.cc b/iree/hal/vulkan/vma_buffer.cc index 9030775..ffdb8ff 100644 --- a/iree/hal/vulkan/vma_buffer.cc +++ b/iree/hal/vulkan/vma_buffer.cc
@@ -123,7 +123,7 @@ // would only work if the entire buffer was discarded. #ifndef NDEBUG if (iree_any_bit_set(memory_access, IREE_HAL_MEMORY_ACCESS_DISCARD)) { - memset(data_ptr + local_byte_offset, 0xCD, local_byte_length); + memset(*out_data_ptr, 0xCD, local_byte_length); } #endif // !NDEBUG
diff --git a/iree/samples/custom_modules/module.cc b/iree/samples/custom_modules/module.cc index 0477ef5..1759d82 100644 --- a/iree/samples/custom_modules/module.cc +++ b/iree/samples/custom_modules/module.cc
@@ -56,7 +56,7 @@ // Note that we allocate the message and the string value together. iree_custom_message_t* message = NULL; IREE_RETURN_IF_ERROR(iree_allocator_malloc( - allocator, sizeof(iree_custom_message_t) + value.size, (void**)&message)); + allocator, sizeof(*message) + value.size, (void**)&message)); message->ref_object.counter = IREE_ATOMIC_VAR_INIT(1); message->allocator = allocator; message->value.data = ((const char*)message) + sizeof(iree_custom_message_t); @@ -71,8 +71,8 @@ iree_custom_message_t** out_message) { IREE_ASSERT_ARGUMENT(out_message); iree_custom_message_t* message = NULL; - IREE_RETURN_IF_ERROR(iree_allocator_malloc( - allocator, sizeof(iree_custom_message_t), (void**)&message)); + IREE_RETURN_IF_ERROR( + iree_allocator_malloc(allocator, sizeof(*message), (void**)&message)); message->ref_object.counter = IREE_ATOMIC_VAR_INIT(1); message->allocator = allocator; message->value = value; // Unowned.
diff --git a/iree/samples/simple_embedding/BUILD b/iree/samples/simple_embedding/BUILD index 94c29c2..d2fd399 100644 --- a/iree/samples/simple_embedding/BUILD +++ b/iree/samples/simple_embedding/BUILD
@@ -265,6 +265,53 @@ iree_cmake_extra_content( content = """ endif() + +if(${IREE_HAL_DRIVER_CUDA} AND (${IREE_TARGET_BACKEND_CUDA} OR DEFINED IREE_HOST_BINARY_ROOT)) +""", + inline = True, +) + +cc_binary( + name = "simple_embedding_cuda", + srcs = [ + "device_cuda.c", + "simple_embedding.c", + ], + deps = [ + ":simple_embedding_test_bytecode_module_cuda_c", + "//iree/base", + "//iree/hal", + "//iree/hal/cuda/registration", + "//iree/modules/hal", + "//iree/vm", + "//iree/vm:bytecode_module", + ], +) + +iree_bytecode_module( + name = "simple_embedding_test_bytecode_module_cuda", + src = "simple_embedding_test.mlir", + c_identifier = "iree_samples_simple_embedding_test_module_cuda", + flags = [ + "-iree-input-type=mhlo", + "-iree-mlir-to-vm-bytecode-module", + "-iree-hal-target-backends=cuda", + "-iree-llvm-debug-symbols=false", + ], +) + +# Simple embedding is failing in the CI. +# run_binary_test( +# name = "simple_embedding_cuda_test", +# tags = [ +# "driver=cuda", +# ], +# test_binary = ":simple_embedding_cuda", +# ) + +iree_cmake_extra_content( + content = """ +endif() """, inline = True, )
diff --git a/iree/samples/simple_embedding/CMakeLists.txt b/iree/samples/simple_embedding/CMakeLists.txt index 4623869..aaee677 100644 --- a/iree/samples/simple_embedding/CMakeLists.txt +++ b/iree/samples/simple_embedding/CMakeLists.txt
@@ -245,4 +245,48 @@ endif() +if(${IREE_HAL_DRIVER_CUDA} AND (${IREE_TARGET_BACKEND_CUDA} OR DEFINED IREE_HOST_BINARY_ROOT)) + +iree_cc_binary( + NAME + simple_embedding_cuda + SRCS + "device_cuda.c" + "simple_embedding.c" + DEPS + ::simple_embedding_test_bytecode_module_cuda_c + iree::base + iree::hal + iree::hal::cuda::registration + iree::modules::hal + iree::vm + iree::vm::bytecode_module +) + +iree_bytecode_module( + NAME + simple_embedding_test_bytecode_module_cuda + SRC + "simple_embedding_test.mlir" + C_IDENTIFIER + "iree_samples_simple_embedding_test_module_cuda" + FLAGS + "-iree-input-type=mhlo" + "-iree-mlir-to-vm-bytecode-module" + "-iree-hal-target-backends=cuda" + "-iree-llvm-debug-symbols=false" + PUBLIC +) + +iree_run_binary_test( + NAME + "simple_embedding_cuda_test" + LABELS + "driver=cuda" + TEST_BINARY + ::simple_embedding_cuda +) + +endif() + ### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/iree/samples/simple_embedding/device_cuda.c b/iree/samples/simple_embedding/device_cuda.c new file mode 100644 index 0000000..e8cf2f4 --- /dev/null +++ b/iree/samples/simple_embedding/device_cuda.c
@@ -0,0 +1,39 @@ +// Copyright 2021 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +// An example of setting up the CUDA driver. + +#include <stddef.h> + +#include "iree/base/api.h" +#include "iree/hal/api.h" +#include "iree/hal/cuda/registration/driver_module.h" + +// Compiled module embedded here to avoid file IO: +#include "iree/samples/simple_embedding/simple_embedding_test_bytecode_module_cuda_c.h" + +iree_status_t create_sample_device(iree_hal_device_t** device) { + // Only register the CUDA HAL driver. + IREE_RETURN_IF_ERROR( + iree_hal_cuda_driver_module_register(iree_hal_driver_registry_default())); + // Create the HAL driver from the name. + iree_hal_driver_t* driver = NULL; + iree_string_view_t identifier = iree_make_cstring_view("cuda"); + IREE_RETURN_IF_ERROR(iree_hal_driver_registry_try_create_by_name( + iree_hal_driver_registry_default(), identifier, iree_allocator_system(), + &driver)); + // Release the driver even if device creation fails to avoid leaking it. + iree_status_t status = iree_hal_driver_create_default_device( + driver, iree_allocator_system(), device); + iree_hal_driver_release(driver); + return status; +} + +const iree_const_byte_span_t load_bytecode_module_data() { + const struct iree_file_toc_t* module_file_toc = + iree_samples_simple_embedding_test_module_cuda_create(); + return iree_make_const_byte_span(module_file_toc->data, + module_file_toc->size); +}
diff --git a/iree/tools/run_lit.ps1 b/iree/tools/run_lit.ps1 index 7f19601..e88dcb4 100644 --- a/iree/tools/run_lit.ps1 +++ b/iree/tools/run_lit.ps1
@@ -63,6 +63,7 @@ } $test_line = $test_line.Substring("// RUN: ".Length) $test_line = $test_line -replace "%s", $test_file + $test_line = $test_line -replace "`"", "\`"" Write-Host -ForegroundColor Blue "Running test command:" Write-Host -ForegroundColor Yellow "$test_line" & $bashExe -c $test_line | Out-Default
diff --git a/iree/vm/buffer.c b/iree/vm/buffer.c index b6b2117..d433a9f 100644 --- a/iree/vm/buffer.c +++ b/iree/vm/buffer.c
@@ -95,8 +95,7 @@ // The actual buffer payload is prefixed with the buffer type so we need only // a single allocation. - iree_host_size_t prefix_size = - iree_host_align(sizeof(iree_vm_buffer_t), iree_max_align_t); + iree_host_size_t prefix_size = iree_sizeof_struct(**out_buffer); iree_host_size_t total_size = prefix_size + length; // Allocate combined [prefix | buffer] memory.
diff --git a/iree/vm/bytecode_dispatch_util.h b/iree/vm/bytecode_dispatch_util.h index 2e692be..596968e 100644 --- a/iree/vm/bytecode_dispatch_util.h +++ b/iree/vm/bytecode_dispatch_util.h
@@ -153,34 +153,12 @@ // Bytecode data access macros for reading values of a given type from a byte // offset within the current function. -#if defined(IREE_ENDIANNESS_LITTLE) -#define OP_I8(i) bytecode_data[pc + (i)] -#define OP_I16(i) *((uint16_t*)&bytecode_data[pc + (i)]) -#define OP_I32(i) *((uint32_t*)&bytecode_data[pc + (i)]) -#define OP_I64(i) *((uint64_t*)&bytecode_data[pc + (i)]) -#define OP_F32(i) *((float*)&bytecode_data[pc + (i)]) -#define OP_F64(i) *((double*)&bytecode_data[pc + (i)]) -#else -#define OP_I8(i) bytecode_data[pc + (i)] -#define OP_I16(i) \ - ((uint16_t)bytecode_data[pc + 0 + (i)]) | \ - ((uint16_t)bytecode_data[pc + 1 + (i)] << 8) -#define OP_I32(i) \ - ((uint32_t)bytecode_data[pc + 0 + (i)]) | \ - ((uint32_t)bytecode_data[pc + 1 + (i)] << 8) | \ - ((uint32_t)bytecode_data[pc + 2 + (i)] << 16) | \ - ((uint32_t)bytecode_data[pc + 3 + (i)] << 24) -#define OP_I64(i) \ - ((uint64_t)bytecode_data[pc + 0 + (i)]) | \ - ((uint64_t)bytecode_data[pc + 1 + (i)] << 8) | \ - ((uint64_t)bytecode_data[pc + 2 + (i)] << 16) | \ - ((uint64_t)bytecode_data[pc + 3 + (i)] << 24) | \ - ((uint64_t)bytecode_data[pc + 4 + (i)] << 32) | \ - ((uint64_t)bytecode_data[pc + 5 + (i)] << 40) | \ - ((uint64_t)bytecode_data[pc + 6 + (i)] << 48) | \ - ((uint64_t)bytecode_data[pc + 7 + (i)] << 56) -#error "TODO: OP_F32 and OP_F64 for big endian systems" -#endif // IREE_ENDIANNESS_LITTLE +#define OP_I8(i) iree_unaligned_load_le((uint8_t*)&bytecode_data[pc + (i)]) +#define OP_I16(i) iree_unaligned_load_le((uint16_t*)&bytecode_data[pc + (i)]) +#define OP_I32(i) iree_unaligned_load_le((uint32_t*)&bytecode_data[pc + (i)]) +#define OP_I64(i) iree_unaligned_load_le((uint64_t*)&bytecode_data[pc + (i)]) +#define OP_F32(i) iree_unaligned_load_le((float*)&bytecode_data[pc + (i)]) +#define OP_F64(i) iree_unaligned_load_le((double*)&bytecode_data[pc + (i)]) //===----------------------------------------------------------------------===// // Utilities matching the tablegen op encoding scheme
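The replacement macros lean on `iree_unaligned_load_le`, which presumably compiles to a plain load on little-endian hosts and byte assembly elsewhere; that single helper subsumes both halves of the deleted `#if` block. A portable sketch of the technique for the 32-bit case (not IREE's actual implementation):

```c
#include <stdint.h>
#include <string.h>

// Portable unaligned little-endian load: memcpy sidesteps alignment traps,
// and assembling bytes explicitly makes the result endian-independent.
static inline uint32_t example_unaligned_load_le_u32(const uint8_t* p) {
  uint8_t b[4];
  memcpy(b, p, sizeof(b));
  return (uint32_t)b[0] | ((uint32_t)b[1] << 8) | ((uint32_t)b[2] << 16) |
         ((uint32_t)b[3] << 24);
}
```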
diff --git a/iree/vm/bytecode_module.c b/iree/vm/bytecode_module.c index 0feedd8..e88972d 100644 --- a/iree/vm/bytecode_module.c +++ b/iree/vm/bytecode_module.c
@@ -761,9 +761,8 @@ iree_vm_bytecode_module_t* module = NULL; IREE_RETURN_AND_END_ZONE_IF_ERROR( - z0, iree_allocator_malloc( - allocator, sizeof(iree_vm_bytecode_module_t) + type_table_size, - (void**)&module)); + z0, iree_allocator_malloc(allocator, sizeof(*module) + type_table_size, + (void**)&module)); module->allocator = allocator; iree_vm_FunctionDescriptor_vec_t function_descriptors = @@ -782,8 +781,6 @@ module->def = module_def; module->type_count = iree_vm_TypeDef_vec_len(type_defs); - module->type_table = (iree_vm_type_def_t*)((uint8_t*)module + - sizeof(iree_vm_bytecode_module_t)); iree_status_t resolve_status = iree_vm_bytecode_module_resolve_types(type_defs, module->type_table); if (!iree_status_is_ok(resolve_status)) {
diff --git a/iree/vm/bytecode_module_impl.h b/iree/vm/bytecode_module_impl.h index d602194..a63cbf8 100644 --- a/iree/vm/bytecode_module_impl.h +++ b/iree/vm/bytecode_module_impl.h
@@ -67,7 +67,7 @@ // Type table mapping module type IDs to registered VM types. iree_host_size_t type_count; - iree_vm_type_def_t* type_table; + iree_vm_type_def_t type_table[]; } iree_vm_bytecode_module_t; // A resolved and split import in the module state table.
diff --git a/iree/vm/context.c b/iree/vm/context.c index fa56039..2652fbf 100644 --- a/iree/vm/context.c +++ b/iree/vm/context.c
@@ -321,12 +321,12 @@ // TODO(benvanik): tune list growth for module count >> 4. new_capacity = context->list.capacity * 2; } - iree_vm_module_t** new_module_list; + iree_vm_module_t** new_module_list = NULL; IREE_RETURN_AND_END_ZONE_IF_ERROR( z0, iree_allocator_malloc(context->allocator, sizeof(iree_vm_module_t*) * new_capacity, (void**)&new_module_list)); - iree_vm_module_state_t** new_module_state_list; + iree_vm_module_state_t** new_module_state_list = NULL; IREE_RETURN_AND_END_ZONE_IF_ERROR( z0, iree_allocator_malloc(context->allocator,
diff --git a/iree/vm/instance.c b/iree/vm/instance.c index 34fee0c..9d1f4c6 100644 --- a/iree/vm/instance.c +++ b/iree/vm/instance.c
@@ -27,8 +27,8 @@ iree_vm_instance_t* instance = NULL; IREE_RETURN_AND_END_ZONE_IF_ERROR( - z0, iree_allocator_malloc(allocator, sizeof(iree_vm_instance_t), - (void**)&instance)); + z0, + iree_allocator_malloc(allocator, sizeof(*instance), (void**)&instance)); instance->allocator = allocator; iree_atomic_ref_count_init(&instance->ref_count);
diff --git a/iree/vm/list.c b/iree/vm/list.c index e6e784d..d668969 100644 --- a/iree/vm/list.c +++ b/iree/vm/list.c
@@ -166,7 +166,7 @@ iree_allocator_t allocator, iree_vm_list_t** out_list) { iree_vm_list_t* list = NULL; IREE_RETURN_IF_ERROR( - iree_allocator_malloc(allocator, sizeof(iree_vm_list_t), (void**)&list)); + iree_allocator_malloc(allocator, sizeof(*list), (void**)&list)); memset(list, 0, sizeof(*list)); iree_atomic_ref_count_init(&list->ref_object.counter); list->allocator = allocator;
diff --git a/iree/vm/native_module.c b/iree/vm/native_module.c index 1926ab2..7e90813 100644 --- a/iree/vm/native_module.c +++ b/iree/vm/native_module.c
@@ -360,8 +360,8 @@ // to expose this via a query_size function so that we could adjust the size // of our storage independent of the definition of the user module. iree_vm_native_module_t* module = NULL; - IREE_RETURN_IF_ERROR(iree_allocator_malloc( - allocator, sizeof(iree_vm_native_module_t), (void**)&module)); + IREE_RETURN_IF_ERROR( + iree_allocator_malloc(allocator, sizeof(*module), (void**)&module)); iree_status_t status = iree_vm_native_module_initialize( interface, module_descriptor, allocator, (iree_vm_module_t*)module);
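A note on the recurring `sizeof(type)` to `sizeof(*pointer)` change in this and the preceding VM files: deriving the size from the pointee keeps the expression correct even if the variable's type is later renamed, as in:

```c
// sizeof(*module) tracks whatever 'module' points at, so the allocation
// size cannot silently drift from the variable's declared type.
iree_vm_native_module_t* module = NULL;
IREE_RETURN_IF_ERROR(
    iree_allocator_malloc(allocator, sizeof(*module), (void**)&module));
```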
diff --git a/scripts/git/check_submodule_init.py b/scripts/git/check_submodule_init.py new file mode 100644 index 0000000..0523c9d --- /dev/null +++ b/scripts/git/check_submodule_init.py
@@ -0,0 +1,30 @@ +# Copyright 2021 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import subprocess +import sys + + +def run(): + # No-op if we're not in a git repository (or git is unavailable). + try: + subprocess.check_call(['git', 'rev-parse', '--is-inside-work-tree'], + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL) + except (subprocess.CalledProcessError, OSError): + return + + output = subprocess.check_output(['git', 'submodule', 'status'], text=True) + + for submodule in output.splitlines(): + if submodule.strip().startswith("-"): + print("The git submodule '%s' is not initialized. " + "Please run `git submodule update --init`" % submodule.split()[1]) + sys.exit(1) + + +if __name__ == "__main__": + run()