Merge main -> google (#4048)
* abf54564 Merge google -> main (#4045)
* 1e4349ec Merge branch 'main' into google-to-main
* 1f86f0ef Add new tf.keras.applications models (#3958)
* 41004a91 [spirv] Add a better tiled and vectorized convolution path (#3990)
* c4129446 Bump Tracy to get Ben's dynamic zone colors and to get past (#4042)
* 2a25f5f1 Remove input shape hack from kws tests (#3959)
* 359ffa01 Allow fusing batch_matmul and fill (#4038)
* 35e837ed Merge pull request #4008 from MaheshRavishankar/matmul_fusion
* ffdd9d19 Enable e2e execution of fusion of matmul with producer.
* de116a06 Use tile+fuse on CPU side.
* caf91acb Use fusion of sequence of LinalgOps in IREE SPIR-V codegen path.
* 551f254a Add convert f32 to f16 pass (#4010)
* f1b1e581 Use vector transfer forwarding and dead store elimination in IREE (#4032)
* fac8fd1f Add file to track docker :prod digests (#3955)
* 37f0ae80 Tensorlist refactor to remove shape input and support empty reserve behavior (..
* 7585e532 Add relative_artifacts_dir (#3960)
* da538897 Use subprocess.run and types in manage_images.py (#3947)
* fbc63a4f Upgrade PR action to v3 (#3946)
* 83ea4a3c [llvm] Add a PlanConvLoopOrder pass (#3920)
* 1f1c70a0 Merge pull request #4012
* bab1468c Submodule fix
* bab0308d Merge branch 'main' into google-to-main
* 8df7e5cd Sort add_subdirectory and remove todo given it's fixed now (#4029)
* 1bb131cb Add missing mkdir to fix building and publishing documentation (#4030)
* 46949bad Fix VectorizeMemref out-of-bound access (#4031)
* b32de092 Update Linux docs to include an LLVM AOT example (#3165)
* 40da4547 Update docs to recommend enabling/disabling assertions globally. (#4014)
* fc8da035 Update Tracy to get clean shutdown on Android (#4015)
* 12e20efc support running as root on android. (#4013)
* 1a4f469e Merge branch 'main' into google-to-main
* aff3e9cb Add vectorization support for batch_matmul op (#3997)
* 33d4a6be Changing iree/base/math.h to C and adding some PRNG routines. (#4005)
* 73bdbe7b Fixup EmitC include directories (#3983)
* af7b4c39 Refactor finding python3 to use standard cmake find_package (#3778)
* f92602b7 Always emitting framepointers in generated ELFs (#3987)
* 89ef6328 bump Tracy to get Android fixes (#3988)
* 7f989eb2 Disable MLIR crash reproducer on CI in python tests. (#3943)
* 4db3d08c Adding a demonstration of using the flow dialect post-partitioning. (#3701)
* 589dfa7b Remove no-longer-functional flag (#3961)
* 64543172 Fix MacOS builds after hack-and-slash PR (#3962)
* 8fb887e4 Update links to coverage tables (#3956)
* c93facb4 Adding iree_atomic_slist_t type. (#3917)
* bd082ca6 Merge pull request #3874 from google/benvanik-hack-and-slash
* f4f4ea26 Use UnitTestSpec in tf.keras.layers tests (#3935)
* 7b8e9f75 Reverting flatcc to use our own cmake file for cross-compilation.
* 323e1fde Simplify dylib driver switch.
* 55f3de0d Only register VMLA driver in bindings/java/.
* f429916f Fix warning flag on Windows and HAL CTS driver registration. (#3911)
* 7951e228 Drop IREE_DRIVER_MODULES and iree/base:initializer from ModelBuilder.
* 4e111e2f Disable layering_check in iree/hal/drivers/BUILD.
* 4773736d Add package to iree/base/testing/BUILD.
* 7ca321c1 Skipping dylib driver in simple_embedding_test as a hack.
* 692deb59 Overriding the default MLIR -> LLVM module name.
* 513f40e8 Speculative removing nowindows tags (#3615). If there's something that still d..
* cc47813a Removing the broken forward declarations entirely from some codegen code. http..
* fbcad44d Removing _GNU_SOURCE copt for wait_handle.
* c9b10a01 Fixing bad type in hal/api.h (been there for ages!).
* e0c532ec Changing iree::InitializeEnvironment to iree_flags_parse. Preparation for #3814.
* 8886ac07 Removing iree_api_init from the API.
* 132d747c Removing ALWAYSLINK support from cmake.
* 0ed81f6b Removing iree/base/initializer.h.
* 0135343c Changing to an explicit driver registration mechanism. This is required for ha..
* bf091d3c Removing ALWAYSLINK support from external_cc_library.
* d4bb871d Changing iree-tblgen to not require alwayslink.
* 6bc6f90c Removing IREE_COMMON_INCLUDE_DIRS and uses by LLVM/MLIR.
* 3c082aab Removing IREE_COMMON_INCLUDE_DIRS mhlo pollution.
* 2395cb99 Removing emitc usage of IREE_COMMON_INCLUDE_DIRS for now.
* 036bd966 TODOs on future library layout changes.
* c3a13e62 Rearranging iree/vm/ to reduce a public + cc target.
* 493b0e2b Rearranging iree/base build rules. By moving the dynamic_library_test out of t..
* 8fd38bf9 Replacing uses of some absl utilities now in C++14.
* 67863190 Removing unused absl/types/variant.h include.
* 99bd1af5 Replace absl::exchange with iree::exchange to reduce absl/utility dep.
* c1d0ee10 Removing unused PLATFORM_VULKAN_DEPS. It may be needed internally but it shoul..
* 15437f4b Simplifying iree/hal/dylib/ build config.
* e6984a5a Simplifying iree/hal/ build config.
* 10062814 Simplifying iree/hal/vulkan/ build config.
* 827e51b0 Simplifying iree/hal/llvmjit/ build config.
* 9a72f5d1 Simplifying iree/hal/metal/ build config.
* c7a7d726 Simplifying iree/hal/vmla/ build config.
* 90faf21f Adding IREE_TARGET_GUI_LINKOPTS to remove custom linkopts use.
* e5774c30 Remove unused args from flatbuffers_c_library macro.
* 22d16b4d Adding iree/base/flatcc.h to make flatcc easier to include.
* e44dee56 Switching from -DVK_NO_PROTOTYPES to iree/hal/vulkan/vulkan_headers.h.
* 48ca2fe6 Removing build-config setting of _GNU_SOURCE.
* eeb7dde0 Goodbye flatbuffers (well, the C++ ones anyway).
* 9c676a86 Removing all build config/utils related to flatbuffers.
* 49c61213 byte->ubyte in flatbuffer defs.
* c99000a8 Replacing compiler use of VM bytecode module def flatbuffers->flatcc.
* 1bf1e8d7 Replacing runtime use of metal executable flatbuffers->flatcc. Maybe it works?..
* 48aafb89 Replacing runtime use of spirv executable flatbuffers->flatcc.
* 011e9a2d Replacing runtime use of llvmjit executable flatbuffers->flatcc.
* a021062f Replacing runtime use of dylib executable flatbuffers->flatcc.
* 53a05d73 Replacing runtime use of VMLA executable flatbuffers->flatc.
* 99d30a99 Replacing compiler use of HAL executable flatbuffers->flatc.
* 6ebd1b0c Removing unused tag field in metal/spirv.
* bc685ed7 Adding flatcc json support and making iree-dump-module use it.
* 94b11c35 Adding include for flatcc to flat_c_libraries.
* 1172cf1f Removing unused iree::schemas::reflection_data.
* c86281af Removing unneeded flatbuffers copts.
* 7f3a7e3a Fixing various type warnings. We don't today have these warnings enabled in ou..
* c17659fc Refining MSVC warning set to the minimum required and documenting.
* b7c92bf4 Cleaning up MSVC warnings and syncing with bazel warnings.
* 94356d3b Removing legacy repo_utils.bzl.
* 0f0d9c82 Prevent bazel-to-cmake from running on iree/base/CMakeLists.txt for now.
* 36225a4d Centralizing -ldl/-lpthread linkopts (as they were in bazel already).
* 31c4dbb9 Documenting iree_copts with a nice big warning.
* e4740a57 Pass android libraries as actual linkopts.
* 85cdd868 Fixing cmake style issues - prefer `if(` not `if (` please.
* bf4069e3 Sorting copts/linkopts so we can override things.
* 28040cd8 Simplifying VMA build integration.
* 479ef30f Replacing use of PROJECT_SOURCE_DIR/PROJECT_BINARY_DIR. Those use the previous..
PiperOrigin-RevId: 345252359
diff --git a/.github/workflows/update_llvm_dependent_submodules.yml b/.github/workflows/update_llvm_dependent_submodules.yml
index bd0319a..2d66b0a 100644
--- a/.github/workflows/update_llvm_dependent_submodules.yml
+++ b/.github/workflows/update_llvm_dependent_submodules.yml
@@ -42,7 +42,7 @@
echo "TF_SHA=$(git submodule status third_party/tensorflow | awk '{print $1}' | cut -c -12)" >> $GITHUB_ENV
echo "LLVM_BAZEL_SHA=$(git submodule status third_party/llvm-bazel | awk '{print $1}' | cut -c -12)" >> $GITHUB_ENV
- name: Creating Pull Request
- uses: peter-evans/create-pull-request@v2
+ uses: peter-evans/create-pull-request@v3
with:
# Personal token is required to trigger additional automation (e.g. presubmits).
token: ${{ secrets.GITHUB_WRITE_ACCESS_TOKEN }}
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 3903d75..b37f0e0 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -374,18 +374,17 @@
# Non-LLVM Dependencies
#-------------------------------------------------------------------------------
-# Use the (deprecated) FindPythonInterp/FindPythonLibs functions before
-# any of our dependencies do. See
+# Use the FindPython functions before any of our dependencies do. See
# https://pybind11.readthedocs.io/en/stable/faq.html#inconsistent-detection-of-python-version-in-cmake-and-pybind11
# If one dependency finds Python 2 (the default),
# any others that try to find Python 3 will fail.
# (Also come on, it's $CURRENT_YEAR - please just use Python 3 already.)
if(${IREE_BUILD_COMPILER} OR ${IREE_BUILD_PYTHON_BINDINGS})
- find_package(PythonInterp 3 REQUIRED)
+ find_package(Python3 COMPONENTS Interpreter REQUIRED)
endif()
if(${IREE_BUILD_PYTHON_BINDINGS})
# Note: Optional because python libs can be manually specified.
- find_package(PythonLibs 3)
+ find_package(Python3 COMPONENTS Interpreter Development REQUIRED)
endif()
include(external_cc_library)
diff --git a/SUBMODULE_VERSIONS b/SUBMODULE_VERSIONS
index 331c7b0..517435a 100644
--- a/SUBMODULE_VERSIONS
+++ b/SUBMODULE_VERSIONS
@@ -3,7 +3,7 @@
63b254577ed77a8004a9be6ac707f3dccc4e1fd9 third_party/cpuinfo
4c13807b7d43ff0946b7ffea0ae3aee9e611d778 third_party/dear_imgui
4fb0ff7069bd88ee85902f4d0bb62794e5f6d021 third_party/flatcc
-f2fb48c3b3d79a75a88a99fba6576b25d42ec528 third_party/googletest
+b1fbd33c06cdb0024c67733c6fdec2009d17b384 third_party/googletest
ff0c2096ee35d3a14eaf425e6bbd3a36ab3a5a4a third_party/llvm-bazel
c266c56d545dfecf767b312771f716b394c5d5eb third_party/llvm-project
55801f03f9cc69abfcf8b508a873f702c11b3b5f third_party/mlir-emitc
@@ -14,6 +14,6 @@
685f86471e9d26b3eb7676695a2e2cefb4551ae9 third_party/spirv_cross
f8bf11a0253a32375c32cad92c841237b96696c0 third_party/spirv_headers
431fc9d043c52eb0990d7efdb19546fa31c1bf1d third_party/tensorflow
-d7059eca6351546d1f51e248fc75e49dfeee709e third_party/tracy
+9c3dac3ed2bd647b8d63f197fed058fee97a7e1e third_party/tracy
9bd3f561bcee3f01d22912de10bb07ce4e23d378 third_party/vulkan_headers
3528e2aed3e8808f33e1e7d63eeb1560456a605a third_party/vulkan_memory_allocator
diff --git a/build_tools/cmake/build_docs.sh b/build_tools/cmake/build_docs.sh
index 57b4aaa..24d3f22 100755
--- a/build_tools/cmake/build_docs.sh
+++ b/build_tools/cmake/build_docs.sh
@@ -49,6 +49,7 @@
ninja iree-doc iree_tools_iree-opt
cd ${ROOT_DIR}
+mkdir -p ${BUILD_DIR}/doc/
# Copy docs in source tree over
cp README.md ${BUILD_DIR}/doc/index.md
cp -rf docs/* ${BUILD_DIR}/doc/
diff --git a/build_tools/cmake/iree_multipy.cmake b/build_tools/cmake/iree_multipy.cmake
index f773925..1ac1028 100644
--- a/build_tools/cmake/iree_multipy.cmake
+++ b/build_tools/cmake/iree_multipy.cmake
@@ -22,7 +22,7 @@
# Configure the defaults.
# Note that this is using the pybind11 configuration vars, which creates
# a fragile dependency. It would be better to derive these locally.
- if(PYTHONLIBS_FOUND)
+ if(Python3_FOUND)
set(IREE_MULTIPY_DEFAULT_EXECUTABLE "${PYTHON_EXECUTABLE}" CACHE INTERNAL "Python executable" )
set(IREE_MULTIPY_DEFAULT_INCLUDE_DIRS "${PYTHON_INCLUDE_DIRS}" CACHE INTERNAL "Python include dirs" )
set(IREE_MULTIPY_DEFAULT_LIBRARIES "${PYTHON_LIBRARIES}" CACHE INTERNAL "Python libraries")
diff --git a/build_tools/docker/README.md b/build_tools/docker/README.md
index 3172c86..3fd42d6 100644
--- a/build_tools/docker/README.md
+++ b/build_tools/docker/README.md
@@ -82,7 +82,7 @@
2. Build the image, push the image to GCR and update all references to the image
with the new GCR digest:
- ```shell
+ ```shell
python3 build_tools/docker/manage_images.py \
--image "${IMAGE?}" --build \
--tag latest \
@@ -99,7 +99,7 @@
digest references to test the new images.
5. Merge your PR after it is approved and all CI tests pass. **Please remember to
- complete the rest of the steps below**.
+ complete the step below**.
### Part 3. Updating the `:prod` tag
@@ -108,38 +108,13 @@
in GCR. This also makes development significantly easier for others who need to
modify the `docker` images.
-6. On the `main` branch, build (but don't push) the images and locally tag them
- with the `:prod` tag:
-
- ```shell
- python3 build_tools/docker/manage_images.py \
- --image "${IMAGE?}" --build \
- --tag prod \
- --update_references
- ```
-
- This build should be entirely cache hits.
-7. We include `--update_references` in the command above so that we can check
- that none of the images or references to them have been changed. Check that
- the following command produces no output before continuing:
-
- ```shell
- git status --porcelain
- ```
-
- If the output is not empty then you'll need to find the source of the
- discrepancy (e.g. a locally modified `Dockerfile`) and remove it, and repeat
- steps 5 and 6 before continuing. (This relies on you keeping your local copy
- of the Docker images. If you didn't, you'll have to manually pull the missing
- images by their digest).
-8. Now that we've confirmed that none of the images were changed, we can push
- them to GCR with the `:prod` tag.
+6. We use `build_tools/docker/prod_digests.txt` as a source of truth for which
+ versions of the images on GCR should have the `:prod` tag. The following
+ command will ensure that you are at upstream HEAD on the `main` branch before
+ it updates the tags.
```shell
- python3 build_tools/docker/manage_images.py \
- --image "${IMAGE?}" \
- --tag prod \
- --push
+ python3 build_tools/docker/manage_prod.py
```
## Debugging
@@ -150,13 +125,13 @@
GCR).
```shell
-# Pull all :prod images
-python3 build_tools/docker/manage_images.py --images all --pull --tag prod
+# Pull all images that should have :prod tags. (They won't if someone ignores
+# step 6 above, but the images that this command pulls are correct regardless).
+python3 build_tools/docker/manage_prod.py --pull_only
+
# Update the :latest images to match the :prod images.
-# If you have a clean workspace this _shouldn't_ require building anything as
-# everything should be cache hits from the :prod images downloaded above, but if
-# the :prod images are behind then that will not be the case and this may take
-# several hours (depending on your machine).
+# If you have a clean workspace this shouldn't require building anything as
+# everything should be cache hits from the :prod images downloaded above.
python3 build_tools/docker/manage_images.py \
--images all --build \
--tag latest \
diff --git a/build_tools/docker/__init__.py b/build_tools/docker/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/build_tools/docker/__init__.py
diff --git a/build_tools/docker/manage_images.py b/build_tools/docker/manage_images.py
index ad5c3a5..4093c68 100755
--- a/build_tools/docker/manage_images.py
+++ b/build_tools/docker/manage_images.py
@@ -29,9 +29,6 @@
transitively depend on it, but don't take side-effecting actions:
python3 build_tools/docker/manage_images.py --build --image cmake --dry-run
-Push all `prod` images to GCR:
- python3 build_tools/docker/manage_images.py --push --tag prod --images all
-
Rebuild and push all images and update references to them in the repository:
python3 build_tools/docker/manage_images.py --push --images all
--update-references
@@ -44,9 +41,13 @@
import re
import subprocess
import sys
+from typing import List, Sequence, Union
+
+import utils
IREE_GCR_URL = 'gcr.io/iree-oss/'
-DOCKER_DIR = 'build_tools/docker/'
+DIGEST_REGEX = r'sha256:[a-zA-Z0-9]+'
+DOCKER_DIR = 'build_tools/docker/'.replace('/', os.sep)
# Map from image names to images that they depend on.
IMAGES_TO_DEPENDENCIES = {
@@ -88,13 +89,6 @@
required=True,
action='append',
help=f'Name of the image to build: {IMAGES_HELP}.')
- parser.add_argument(
- '--tag',
- type=str,
- default='latest',
- help='Tag for the images to build. Defaults to `latest` (which is good '
- 'for testing changes in a PR). Use `prod` to update the images that the '
- 'CI caches.')
parser.add_argument('--pull',
action='store_true',
help='Pull the specified image before building.')
@@ -130,144 +124,148 @@
return args
-def get_ordered_images_to_process(images):
- unmarked_images = list(images)
- # Python doesn't have a builtin OrderedSet
- marked_images = set()
- order = []
+def get_ordered_images_to_process(images: Sequence[str]) -> List[str]:
+ # Python doesn't have a builtin OrderedSet, so we mimic one to the extent
+ # that we need by using 'in' before adding any elements.
+ processing_order = []
- def visit(image):
- if image in marked_images:
- return
- for dependent_images in IMAGES_TO_DEPENDENT_IMAGES[image]:
- visit(dependent_images)
- marked_images.add(image)
- order.append(image)
+ def add_dependent_images(image: str):
+ if image not in processing_order:
+ for dependent_image in IMAGES_TO_DEPENDENT_IMAGES[image]:
+ add_dependent_images(dependent_image)
+ processing_order.append(image)
- while unmarked_images:
- visit(unmarked_images.pop())
+ for image in images:
+ add_dependent_images(image)
- order.reverse()
- return order
+ processing_order.reverse()
+ return processing_order
-def stream_command(command, dry_run=False):
+def run_command(command: Sequence[str],
+ dry_run: bool = False,
+ check: bool = True,
+ capture_output: bool = False,
+ universal_newlines: bool = True,
+ **run_kwargs) -> subprocess.CompletedProcess:
+ """Thin wrapper around subprocess.run"""
print(f'Running: `{" ".join(command)}`')
if dry_run:
- return 0
- process = subprocess.Popen(command,
- bufsize=1,
- stderr=subprocess.STDOUT,
- stdout=subprocess.PIPE,
- universal_newlines=True)
- for line in process.stdout:
- print(line, end='')
+    # Dummy CompletedProcess with successful returncode.
+ return subprocess.CompletedProcess(command, returncode=0)
- if process.poll() is None:
- raise RuntimeError('Unexpected end of output while process is not finished')
- return process.poll()
+ if capture_output:
+ # Hardcode support for python <= 3.6.
+ run_kwargs['stdout'] = subprocess.PIPE
+ run_kwargs['stderr'] = subprocess.PIPE
+ return subprocess.run(command,
+ universal_newlines=universal_newlines,
+ check=check,
+ **run_kwargs)
-def check_stream_command(command, dry_run=False):
- exit_code = stream_command(command, dry_run=dry_run)
- if exit_code != 0:
- print(f'Command failed with exit code {exit_code}: `{" ".join(command)}`')
- sys.exit(exit_code)
-
-
-def get_repo_digest(image):
+def get_repo_digest(tagged_image_url: str) -> str:
inspect_command = [
'docker',
'image',
'inspect',
- f'{image}',
+ tagged_image_url,
'-f',
'{{index .RepoDigests 0}}',
]
- inspect_process = subprocess.run(inspect_command,
- universal_newlines=True,
- stdout=subprocess.PIPE,
- stderr=subprocess.PIPE,
- timeout=10)
- if inspect_process.returncode != 0:
- print(f'Computing the repository digest for {image} failed.'
- ' Has it been pushed to GCR?')
- print(f'Output from `{" ".join(inspect_command)}`:')
- print(inspect_process.stdout, end='')
- print(inspect_process.stderr, end='')
- sys.exit(inspect_process.returncode)
- _, repo_digest = inspect_process.stdout.strip().split('@')
+ try:
+ completed_process = utils.run_command(
+ inspect_command,
+ dry_run=False, # Run even if --dry_run is True.
+ capture_output=True,
+ timeout=10)
+ except subprocess.CalledProcessError as error:
+ raise RuntimeError(f'Computing the repository digest for {tagged_image_url}'
+ ' failed. Has it been pushed to GCR?') from error
+ _, repo_digest = completed_process.stdout.strip().split('@')
return repo_digest
-def update_rbe_reference(digest, dry_run=False):
+def update_rbe_reference(digest: str, dry_run: bool = False):
print('Updating WORKSPACE file for rbe-toolchain')
+ digest_updates = 0
for line in fileinput.input(files=['WORKSPACE'], inplace=(not dry_run)):
if line.strip().startswith('digest ='):
- print(re.sub('sha256:[a-zA-Z0-9]+', digest, line), end='')
+ digest_updates += 1
+ print(re.sub(DIGEST_REGEX, digest, line), end='')
else:
print(line, end='')
+ if digest_updates > 1:
+ raise RuntimeError(
+ "There is more than one instance of 'digest =' in the WORKSPACE file. "
+ "This means that more than just the 'rbe_toolchain' digest was "
+ "overwritten, and the file should be restored.")
-def update_references(image_name, digest, dry_run=False):
- print(f'Updating references to {image_name}')
- grep_command = ['git', 'grep', '-l', f'{image_name}@sha256']
- grep_process = subprocess.run(grep_command,
- stdout=subprocess.PIPE,
- stderr=subprocess.PIPE,
- timeout=5,
- universal_newlines=True)
- if grep_process.returncode > 1:
- print(f'{" ".join(grep_command)} '
- f'failed with exit code {grep_process.returncode}')
- sys.exit(grep_process.returncode)
- if grep_process.returncode == 1:
- print(f'Found no references to {image_name}')
- return
+def update_references(image_url: str, digest: str, dry_run: bool = False):
+ """Updates all references to 'image_url' with a sha256 digest."""
+ print(f'Updating references to {image_url}')
- files = grep_process.stdout.split()
+ grep_command = ['git', 'grep', '-l', f'{image_url}@sha256']
+ try:
+ completed_process = run_command(grep_command,
+ capture_output=True,
+ timeout=5)
+ except subprocess.CalledProcessError as error:
+ if error.returncode == 1:
+ print(f'Found no references to {image_url}')
+ return
+ raise error
+
+ # Update references in all grepped files.
+ files = completed_process.stdout.split()
print(f'Updating references in {len(files)} files: {files}')
for line in fileinput.input(files=files, inplace=(not dry_run)):
- print(re.sub(f'{image_name}@sha256:[a-zA-Z0-9]+', f'{image_name}@{digest}',
- line),
+ print(re.sub(f'{image_url}@{DIGEST_REGEX}', f'{image_url}@{digest}', line),
end='')
if __name__ == '__main__':
args = parse_arguments()
- # Ensure the user has the correct authorization if they try to push to GCR.
if args.push:
- if stream_command(['which', 'gcloud']) != 0:
- print('gcloud not found.'
- ' See https://cloud.google.com/sdk/install for installation.')
- sys.exit(1)
- check_stream_command(['gcloud', 'auth', 'configure-docker'],
- dry_run=args.dry_run)
+ # Ensure the user has the correct authorization if they try to push to GCR.
+ utils.check_gcloud_auth(dry_run=args.dry_run)
images_to_process = get_ordered_images_to_process(args.images)
print(f'Also processing dependent images. Will process: {images_to_process}')
for image in images_to_process:
print(f'Processing image {image}')
- image_name = posixpath.join(IREE_GCR_URL, image)
- image_tag = f'{image_name}:{args.tag}'
+ image_url = posixpath.join(IREE_GCR_URL, image)
+ tagged_image_url = f'{image_url}:latest'
image_path = os.path.join(DOCKER_DIR, image)
if args.pull:
- check_stream_command(['docker', 'pull', image_tag], dry_run=args.dry_run)
+ utils.run_command(['docker', 'pull', tagged_image_url], args.dry_run)
if args.build:
- check_stream_command(['docker', 'build', '--tag', image_tag, image_path],
- dry_run=args.dry_run)
+ utils.run_command(
+ ['docker', 'build', '--tag', tagged_image_url, image_path],
+ args.dry_run)
if args.push:
- check_stream_command(['docker', 'push', image_tag], dry_run=args.dry_run)
+ utils.run_command(['docker', 'push', tagged_image_url], args.dry_run)
if args.update_references:
- digest = get_repo_digest(image_tag)
+ digest = get_repo_digest(tagged_image_url)
+
+ # Check that the image is in 'prod_digests.txt' and append it to the list
+ # in the file if it isn't. We know that the GCR digest exists at this
+ # point because 'get_repo_digest' confirms that the image has been pushed.
+ with open(utils.PROD_DIGESTS_PATH, 'r') as f:
+ in_prod_digests = f'{image_url}@' in f.read()
+ if not in_prod_digests:
+ with open(utils.PROD_DIGESTS_PATH, 'a') as f:
+ f.write(f'{image_url}@{digest}\n')
+
# Just hardcode this oddity
if image == 'rbe-toolchain':
update_rbe_reference(digest, dry_run=args.dry_run)
- update_references(image_name, digest, dry_run=args.dry_run)
+ update_references(image_url, digest, dry_run=args.dry_run)
diff --git a/build_tools/docker/manage_prod.py b/build_tools/docker/manage_prod.py
new file mode 100644
index 0000000..3f26b00
--- /dev/null
+++ b/build_tools/docker/manage_prod.py
@@ -0,0 +1,65 @@
+#!/usr/bin/env python3
+
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Uses prod_digests.txt to update GCR's :prod tags.
+
+Usage:
+  Pull all images that should have :prod tags, tag them with :prod and push
+  them to GCR. This will make sure that you are at upstream HEAD on the main
+ branch before pushing:
+ python3 build_tools/docker/manage_prod.py
+
+ Pull all images that should have :prod tags:
+ python3 build_tools/docker/manage_prod.py --pull_only
+"""
+
+import argparse
+import os
+import utils
+
+
+def parse_arguments():
+ """Parses command-line options."""
+ parser = argparse.ArgumentParser(
+ description="Pull and push the images in prod_digests.txt to GCR.")
+ parser.add_argument("--pull_only",
+ "--pull-only",
+ action="store_true",
+ help="Pull but do not tag or push the images.")
+ return parser.parse_args()
+
+
+if __name__ == "__main__":
+ args = parse_arguments()
+
+ if not args.pull_only:
+ # Ensure the user has the correct authorization if they try to push to GCR.
+ utils.check_gcloud_auth()
+
+ # Only allow the :prod tag to be pushed from the version of
+ # `prod_digests.txt` at upstream HEAD on the main branch.
+ utils.run_command([os.path.normpath("scripts/git/git_update.sh"), "main"])
+
+ with open(utils.PROD_DIGESTS_PATH, "r") as f:
+ images_with_digests = [line.strip() for line in f.readlines()]
+
+ for image_with_digest in images_with_digests:
+ image_url, _ = image_with_digest.split("@")
+ tagged_image_url = f"{image_url}:prod"
+
+ utils.run_command(["docker", "pull", image_with_digest])
+ if not args.pull_only:
+ utils.run_command(["docker", "tag", image_with_digest, tagged_image_url])
+ utils.run_command(["docker", "push", tagged_image_url])
diff --git a/build_tools/docker/prod_digests.txt b/build_tools/docker/prod_digests.txt
new file mode 100644
index 0000000..4ee200a
--- /dev/null
+++ b/build_tools/docker/prod_digests.txt
@@ -0,0 +1,17 @@
+gcr.io/iree-oss/base@sha256:392b2f865f000c6fb558d01a372446f3ab81120db34249f03efa999669647230
+gcr.io/iree-oss/util@sha256:ec9198493cea4f5d9ac7097e8a64b94b7a43628cb995b91e6e89a95cff4a1982
+gcr.io/iree-oss/cmake@sha256:ceaff365ca0cd3d770daf5fad370e29783e30b654f56780761a6d0a040da45e5
+gcr.io/iree-oss/swiftshader@sha256:3ed32e7c74da71b6db1904b583827e760ea845d5d2876b38c036cf72ca6e5623
+gcr.io/iree-oss/cmake-python@sha256:5fa42743c458a7df680175542269067fd89d2072b776b43e48169a7d0b43ebc3
+gcr.io/iree-oss/cmake-android@sha256:7accda0b84e2ae337740f2ee71801ee30f2155900abf1cf7b73ea47c15dc694f
+gcr.io/iree-oss/bazel@sha256:59da17e5cc8176890a6e1bda369b1f3d398e27af3d47e02e1ffd5b76729c215b
+gcr.io/iree-oss/bazel-python@sha256:473b7e294136bc38abc1941042f0c0404199de5827f141520f0b6757305b7a95
+gcr.io/iree-oss/bazel-tensorflow@sha256:6ec501edcbaaf817941c5be5060cafc47616ca4e2a875bbb62944ffbc396ceb0
+gcr.io/iree-oss/vulkan@sha256:c2e21657a231f3e39c50c01c3cbae3355f5b03ff52033b41ad322a0c792099dd
+gcr.io/iree-oss/rbe-toolchain@sha256:d6d895294076b5289e81489f664656211c41656cffe7c448ecb5c6f54f045974
+gcr.io/iree-oss/cmake-python-vulkan@sha256:63db8f65485e73af8a16603729bf39b4e616b6cb90216a1589ba2919489a6483
+gcr.io/iree-oss/cmake-python-swiftshader@sha256:3e3d3427f3a58b32fa3ed578b610e411e0b81fd0e1984ac9b0fceae8bf8343dc
+gcr.io/iree-oss/cmake-python-nvidia@sha256:310e3b399717905bb2b485f3ebed32222915c7dc4dc075aa4e1b8551101fe607
+gcr.io/iree-oss/bazel-tensorflow-vulkan@sha256:61522fcfcd11cd9c067e991b72419d6decf70dae35b8ee3efa71e55ca31b8866
+gcr.io/iree-oss/bazel-tensorflow-swiftshader@sha256:39c0e43c503bddfacd69758a50f02450ad2322d35324e2f56997aebb33a1b20a
+gcr.io/iree-oss/bazel-tensorflow-nvidia@sha256:e5e96ec1709e83355ee2264c97c26fa5c3d40f749a62734f4787b17a83f2c3b8
diff --git a/build_tools/docker/utils.py b/build_tools/docker/utils.py
new file mode 100644
index 0000000..f9a041b
--- /dev/null
+++ b/build_tools/docker/utils.py
@@ -0,0 +1,55 @@
+#!/usr/bin/env python3
+
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import subprocess
+from typing import Sequence
+
+PROD_DIGESTS_PATH = "build_tools/docker/prod_digests.txt".replace("/", os.sep)
+
+
+def run_command(command: Sequence[str],
+ dry_run: bool = False,
+ check: bool = True,
+ capture_output: bool = False,
+ universal_newlines: bool = True,
+ **run_kwargs) -> subprocess.CompletedProcess:
+ """Thin wrapper around subprocess.run"""
+ print(f"Running: `{' '.join(command)}`")
+ if not dry_run:
+ if capture_output:
+ # Hardcode support for python <= 3.6.
+ run_kwargs["stdout"] = subprocess.PIPE
+ run_kwargs["stderr"] = subprocess.PIPE
+
+ completed_process = subprocess.run(command,
+ universal_newlines=universal_newlines,
+ check=check,
+ **run_kwargs)
+ return completed_process
+  # Dummy CompletedProcess with successful returncode.
+ return subprocess.CompletedProcess(command, returncode=0)
+
+
+def check_gcloud_auth(dry_run: bool = False):
+ # Ensure the user has the correct authorization if they try to push to GCR.
+ try:
+ run_command(['which', 'gcloud'])
+ except subprocess.CalledProcessError as error:
+ raise RuntimeError(
+ 'gcloud not found. See https://cloud.google.com/sdk/install for '
+ 'installation.') from error
+ run_command(["gcloud", "auth", "configure-docker"], dry_run)
diff --git a/docs/get_started/getting_started_linux_bazel.md b/docs/get_started/getting_started_linux_bazel.md
index a05ebb6..77f7c6c 100644
--- a/docs/get_started/getting_started_linux_bazel.md
+++ b/docs/get_started/getting_started_linux_bazel.md
@@ -90,8 +90,9 @@
# and with assertions enabled.
build:debug --config=asserts --compilation_mode=opt '--per_file_copt=iree|llvm@-O0' --strip=never
-# Use --config=asserts to enable assertions in IREE and LLVM.
-build:asserts --compilation_mode=opt '--per_file_copt=iree|llvm@-UNDEBUG'
+# Use --config=asserts to enable assertions. This has to be done globally:
+# Code compiled with and without assertions can't be linked together (ODR violation).
+build:asserts --compilation_mode=opt '--copt=-UNDEBUG'
```
## What's next?
diff --git a/docs/get_started/getting_started_linux_cmake.md b/docs/get_started/getting_started_linux_cmake.md
index 04f499f..83d0fb2 100644
--- a/docs/get_started/getting_started_linux_cmake.md
+++ b/docs/get_started/getting_started_linux_cmake.md
@@ -107,6 +107,39 @@
-function-input="i32=-2" -iree-hal-target-backends=vmla -print-mlir
```
+### LLVM Ahead-of-Time (AOT) backend
+
+To compile an IREE LLVM AOT (vs JIT) module, we need to set the AOT linker path environment variable:
+
+```shell
+$ export IREE_LLVMAOT_LINKER_PATH=ld.lld-10
+```
+
+Translate a source MLIR into an IREE module:
+
+```shell
+# Assuming in IREE source root
+$ ./build/iree/tools/iree-translate \
+ -iree-mlir-to-vm-bytecode-module \
+ -iree-llvm-target-triple=x86_64-linux-gnu \
+ -iree-hal-target-backends=dylib-llvm-aot \
+ iree/tools/test/simple.mlir \
+ -o /tmp/simple-llvm_aot.vmfb
+```
+
+Then run the compiled module using the `dylib` HAL driver:
+
+```shell
+$ ./build/iree/tools/iree-run-module -driver=dylib \
+ -input_file=/tmp/simple-llvm_aot.vmfb \
+ -entry_function=abs \
+ -inputs="i32=-5"
+
+EXEC @abs
+i32=5
+```
+
+
### Further Reading
* For an introduction to IREE's project structure and developer tools, see
diff --git a/docs/get_started/getting_started_macos_bazel.md b/docs/get_started/getting_started_macos_bazel.md
index 71b6818..0087166 100644
--- a/docs/get_started/getting_started_macos_bazel.md
+++ b/docs/get_started/getting_started_macos_bazel.md
@@ -93,8 +93,9 @@
# and with assertions enabled.
build:debug --config=asserts --compilation_mode=opt '--per_file_copt=iree|llvm@-O0' --strip=never
-# Use --config=asserts to enable assertions in IREE and LLVM.
-build:asserts --compilation_mode=opt '--per_file_copt=iree|llvm@-UNDEBUG'
+# Use --config=asserts to enable assertions. This has to be done globally:
+# Code compiled with and without assertions can't be linked together (ODR violation).
+build:asserts --compilation_mode=opt '--copt=-UNDEBUG'
```
## What's next?
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
index 313be88..1961a9a 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
@@ -55,7 +55,7 @@
DEFAULT_INPUT_GENERATOR = tf_utils.uniform
-def _setup_artifacts_dir(module_name: str) -> str:
+def _setup_artifacts_dir(relative_artifacts_dir: str) -> str:
parent_dirs = [
FLAGS.artifacts_dir,
os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR'),
@@ -65,7 +65,7 @@
# Use the most preferred path in parent_dirs that isn't None.
parent_dir = next(parent for parent in parent_dirs if parent is not None)
- artifacts_dir = os.path.join(parent_dir, module_name)
+ artifacts_dir = os.path.join(parent_dir, relative_artifacts_dir)
logging.info("Saving compilation artifacts and traces to '%s'", artifacts_dir)
os.makedirs(artifacts_dir, exist_ok=True)
return artifacts_dir
@@ -125,9 +125,9 @@
_global_modules = None
-def compile_tf_module(
- module_class: Type[tf.Module],
- exported_names: Sequence[str] = ()) -> Modules:
+def compile_tf_module(module_class: Type[tf.Module],
+ exported_names: Sequence[str] = (),
+ relative_artifacts_dir: str = None) -> Modules:
"""Compiles module_class to each backend that we test.
Args:
@@ -135,6 +135,9 @@
exported_names: optional iterable of strings representing which of
module_class's functions to compile. If exported_names is empty all
functions will be compiled.
+ relative_artifacts_dir: optional string specifying where to save compilation
+ artifacts within the artifacts_dir. If it is not specified then
+ module_class.__name__ will be used.
Returns:
A 'Modules' namedtuple containing the reference module, target modules and
@@ -145,7 +148,9 @@
return _global_modules
# Setup the directory for saving compilation artifacts and traces.
- artifacts_dir = _setup_artifacts_dir(module_class.__name__)
+ if relative_artifacts_dir is None:
+ relative_artifacts_dir = module_class.__name__
+ artifacts_dir = _setup_artifacts_dir(relative_artifacts_dir)
# Get the backend information for this test.
ref_backend_info = module_utils.BackendInfo(FLAGS.reference_backend,
diff --git a/integrations/tensorflow/compiler/dialect/tf_tensorlist/conversion/BUILD b/integrations/tensorflow/compiler/dialect/tf_tensorlist/conversion/BUILD
index 8b74864..1a957b7 100644
--- a/integrations/tensorflow/compiler/dialect/tf_tensorlist/conversion/BUILD
+++ b/integrations/tensorflow/compiler/dialect/tf_tensorlist/conversion/BUILD
@@ -36,6 +36,7 @@
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Transforms",
"@org_tensorflow//tensorflow/compiler/mlir/tensorflow",
+ "@org_tensorflow//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
],
alwayslink = 1,
)
@@ -70,6 +71,8 @@
deps = [
"//integrations/tensorflow/compiler/dialect/tf_tensorlist/ir",
"//iree/compiler/Dialect/HAL/Conversion",
+ "//iree/compiler/Dialect/HAL/IR",
+ "//iree/compiler/Dialect/HAL/Utils",
"//iree/compiler/Dialect/Modules/TensorList/IR",
"//iree/compiler/Dialect/Modules/TensorList/IR:TensorListDialect",
"@llvm-project//mlir:IR",
diff --git a/integrations/tensorflow/compiler/dialect/tf_tensorlist/conversion/convert_flow_to_hal.cc b/integrations/tensorflow/compiler/dialect/tf_tensorlist/conversion/convert_flow_to_hal.cc
index 501fbe2..fa38d04 100644
--- a/integrations/tensorflow/compiler/dialect/tf_tensorlist/conversion/convert_flow_to_hal.cc
+++ b/integrations/tensorflow/compiler/dialect/tf_tensorlist/conversion/convert_flow_to_hal.cc
@@ -16,11 +16,123 @@
#include "integrations/tensorflow/compiler/dialect/tf_tensorlist/ir/tf_tensorlist_ops.h"
#include "iree/compiler/Dialect/HAL/Conversion/ConversionTarget.h"
+#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
+#include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
+#include "iree/compiler/Dialect/HAL/Utils/TypeUtils.h"
#include "iree/compiler/Dialect/Modules/TensorList/IR/TensorListOps.h"
+#include "iree/compiler/Dialect/Modules/TensorList/IR/TensorListTypes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/StandardTypes.h"
namespace mlir {
namespace iree_compiler {
+namespace {
+
+Value getBufferView(Operation *srcOp, Value srcOperand, Value dstOperand,
+ ConversionPatternRewriter &rewriter) {
+ auto operand = IREE::HAL::TensorRewriteAdaptor::getChecked(
+ srcOp->getLoc(), srcOperand, dstOperand, rewriter);
+ if (!operand.hasValue()) {
+ srcOp->emitOpError() << "unable to create adaptor for operand";
+ return nullptr;
+ }
+ auto bufferView = operand->getBufferView();
+ if (!bufferView) {
+ srcOp->emitOpError() << "unable to get buffer view for operand";
+ return nullptr;
+ }
+
+ return bufferView;
+}
+
+class ReserveOpConversion : public OpConversionPattern<tf_tensorlist::Reserve> {
+ public:
+ ReserveOpConversion(MLIRContext *ctx, TypeConverter &converter)
+ : OpConversionPattern(ctx) {}
+
+ LogicalResult matchAndRewrite(
+ tf_tensorlist::Reserve reserveOp, llvm::ArrayRef<Value> newOperands,
+ ConversionPatternRewriter &rewriter) const override {
+ auto elementTy = reserveOp.element_type();
+ auto element_value = IREE::HAL::getElementTypeValue(elementTy).getValue();
+
+ auto operand0 = getBufferView(reserveOp, reserveOp.getOperand(0),
+ newOperands[0], rewriter);
+ auto operand1 = getBufferView(reserveOp, reserveOp.getOperand(1),
+ newOperands[1], rewriter);
+
+ if (!operand0 || !operand1) {
+ return failure();
+ }
+
+ rewriter.replaceOpWithNewOp<IREE::TensorList::Reserve>(
+ reserveOp,
+ IREE::TensorList::TensorListType::get(reserveOp.getContext()), operand0,
+ operand1, rewriter.getI32IntegerAttr(element_value));
+ return success();
+ }
+};
+
+class ConcatOpConversion : public OpConversionPattern<tf_tensorlist::Concat> {
+ public:
+ ConcatOpConversion(MLIRContext *ctx, TypeConverter &converter)
+ : OpConversionPattern(ctx) {}
+
+ LogicalResult matchAndRewrite(
+ tf_tensorlist::Concat concatOp, llvm::ArrayRef<Value> newOperands,
+ ConversionPatternRewriter &rewriter) const override {
+ auto device =
+ rewriter.createOrFold<IREE::HAL::ExSharedDeviceOp>(concatOp.getLoc());
+ auto allocator =
+ rewriter.create<IREE::HAL::DeviceAllocatorOp>(concatOp.getLoc(), device)
+ .getResult();
+
+ auto newConcatOp = rewriter.createOrFold<IREE::TensorList::Concat>(
+ concatOp.getLoc(),
+ IREE::HAL::BufferViewType::get(rewriter.getContext()), allocator,
+ newOperands[0]);
+
+ auto bufferOp = rewriter.createOrFold<IREE::HAL::BufferViewBufferOp>(
+ newConcatOp.getLoc(), newConcatOp);
+
+ rewriter.replaceOp(concatOp, bufferOp);
+ return success();
+ }
+};
+
+class StackOpConversion : public OpConversionPattern<tf_tensorlist::Stack> {
+ public:
+ StackOpConversion(MLIRContext *ctx, TypeConverter &converter)
+ : OpConversionPattern(ctx) {}
+
+ LogicalResult matchAndRewrite(
+ tf_tensorlist::Stack stackOp, llvm::ArrayRef<Value> newOperands,
+ ConversionPatternRewriter &rewriter) const override {
+ auto device =
+ rewriter.createOrFold<IREE::HAL::ExSharedDeviceOp>(stackOp.getLoc());
+ auto allocator =
+ rewriter.create<IREE::HAL::DeviceAllocatorOp>(stackOp.getLoc(), device)
+ .getResult();
+
+ auto operand1 =
+ getBufferView(stackOp, stackOp.getOperand(1), newOperands[1], rewriter);
+ if (!operand1) return failure();
+
+ auto newStackOp = rewriter.createOrFold<IREE::TensorList::Stack>(
+ stackOp.getLoc(), IREE::HAL::BufferViewType::get(rewriter.getContext()),
+ allocator, newOperands[0], operand1);
+
+ auto bufferOp = rewriter.createOrFold<IREE::HAL::BufferViewBufferOp>(
+ stackOp.getLoc(), newStackOp);
+
+ rewriter.replaceOp(stackOp, bufferOp);
+ return success();
+ }
+};
+
+} // namespace
+
void populateTensorListToHALPatterns(MLIRContext *context,
OwningRewritePatternList &patterns,
TypeConverter &typeConverter) {
@@ -29,9 +141,6 @@
// verification or have a specific use case (such as a place where only the
// buffer is required and the shape is not) we could add our own.
patterns.insert<
- HALOpConversion<tf_tensorlist::Reserve, IREE::TensorList::Reserve>>(
- context, typeConverter);
- patterns.insert<
HALOpConversion<tf_tensorlist::GetItem, IREE::TensorList::GetItem>>(
context, typeConverter);
patterns.insert<
@@ -40,12 +149,10 @@
patterns.insert<
HALOpConversion<tf_tensorlist::FromTensor, IREE::TensorList::FromTensor>>(
context, typeConverter);
- patterns
- .insert<HALOpConversion<tf_tensorlist::Concat, IREE::TensorList::Concat>>(
- context, typeConverter);
- patterns
- .insert<HALOpConversion<tf_tensorlist::Stack, IREE::TensorList::Stack>>(
- context, typeConverter);
+
+ patterns.insert<ConcatOpConversion>(context, typeConverter);
+ patterns.insert<ReserveOpConversion>(context, typeConverter);
+ patterns.insert<StackOpConversion>(context, typeConverter);
}
} // namespace iree_compiler
diff --git a/integrations/tensorflow/compiler/dialect/tf_tensorlist/conversion/convert_tf_to_tf_tensorlist.cc b/integrations/tensorflow/compiler/dialect/tf_tensorlist/conversion/convert_tf_to_tf_tensorlist.cc
index 71e42f6..2570e2c 100644
--- a/integrations/tensorflow/compiler/dialect/tf_tensorlist/conversion/convert_tf_to_tf_tensorlist.cc
+++ b/integrations/tensorflow/compiler/dialect/tf_tensorlist/conversion/convert_tf_to_tf_tensorlist.cc
@@ -19,10 +19,26 @@
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
namespace mlir {
namespace tf_tensorlist {
+namespace {
+TypeAttr GetVariantElementTypeAttr(Type type) {
+ if (auto variantTy = type.dyn_cast<TF::VariantType>()) {
+ return GetVariantElementTypeAttr(variantTy.getSubtypes().front());
+ }
+
+ if (auto shapedTy = type.dyn_cast<ShapedType>()) {
+ return GetVariantElementTypeAttr(shapedTy.getElementType());
+ }
+
+ return TypeAttr::get(type);
+}
+
+} // namespace
+
#include "integrations/tensorflow/compiler/dialect/tf_tensorlist/conversion/convert_tf_to_tf_tensorlist.inc"
class ConvertTfToTfTensorList
diff --git a/integrations/tensorflow/compiler/dialect/tf_tensorlist/conversion/convert_tf_to_tf_tensorlist.td b/integrations/tensorflow/compiler/dialect/tf_tensorlist/conversion/convert_tf_to_tf_tensorlist.td
index 0887605..3968c00 100644
--- a/integrations/tensorflow/compiler/dialect/tf_tensorlist/conversion/convert_tf_to_tf_tensorlist.td
+++ b/integrations/tensorflow/compiler/dialect/tf_tensorlist/conversion/convert_tf_to_tf_tensorlist.td
@@ -17,26 +17,27 @@
include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"
include "integrations/tensorflow/compiler/dialect/tf_tensorlist/ir/tf_tensorlist_ops.td"
-def : Pat<(TF_TensorListReserveOp $element_shape, $num_elements),
- (TfTensorList_Reserve $element_shape, $num_elements)>;
+// GetElementTypeAttr
+def GetVariantElementTypeAttr : NativeCodeCall<
+ "GetVariantElementTypeAttr($0.getType())">;
+
+def : Pat<(TF_TensorListReserveOp:$result $element_shape, $num_elements),
+ (TfTensorList_Reserve $element_shape, $num_elements,
+ (GetVariantElementTypeAttr $result))>;
def : Pat<(TF_TensorListGetItemOp $input_handle, $index, $element_shape),
- (TfTensorList_GetItem
- $input_handle,
- $index,
- $element_shape)>;
+ (TfTensorList_GetItem $input_handle, $index)>;
def : Pat<(TF_TensorListSetItemOp $input_handle, $index, $item),
(TfTensorList_SetItem $input_handle, $index, $item)>;
def : Pat<(TF_TensorListFromTensorOp $tensor, $element_shape),
- (TfTensorList_FromTensor $tensor, $element_shape)>;
+ (TfTensorList_FromTensor $tensor)>;
def WrapScalarI64InTensor : NativeCodeCall<
"DenseElementsAttr::get(RankedTensorType::get({}, $_builder.getIntegerType(64)), {$_self.getValue()})">;
def : Pat<(TF_TensorListStackOp $input_handle, $element_shape, $num_elements),
(TfTensorList_Stack
$input_handle,
- $element_shape,
(ConstantOp WrapScalarI64InTensor:$num_elements))>;
diff --git a/integrations/tensorflow/compiler/dialect/tf_tensorlist/conversion/test/convert_flow_to_hal.mlir b/integrations/tensorflow/compiler/dialect/tf_tensorlist/conversion/test/convert_flow_to_hal.mlir
index a1fc6ad..4cbf904 100644
--- a/integrations/tensorflow/compiler/dialect/tf_tensorlist/conversion/test/convert_flow_to_hal.mlir
+++ b/integrations/tensorflow/compiler/dialect/tf_tensorlist/conversion/test/convert_flow_to_hal.mlir
@@ -5,7 +5,7 @@
// CHECK: [[VIEW0:%.+]] = hal.buffer_view.create %arg0{{.*}}
// CHECK: [[VIEW1:%.+]] = hal.buffer_view.create %arg1{{.*}}
// CHECK: "tensorlist.Reserve"([[VIEW0]], [[VIEW1]])
- %0 = "tf_tensorlist.Reserve"(%arg0, %arg1) : (tensor<0xi32>, tensor<i32>) -> !tf_tensorlist.list
+ %0 = "tf_tensorlist.Reserve"(%arg0, %arg1) {element_type = f32} : (tensor<0xi32>, tensor<i32>) -> !tf_tensorlist.list
return %0 : !tf_tensorlist.list
}
@@ -18,20 +18,26 @@
return %0 : !tf_tensorlist.list
}
-// CHECK-LABEL: func @GetItem(%arg0: !tensorlist.list, %arg1: !hal.buffer, %arg2: !hal.buffer) -> !hal.buffer {
-func @GetItem(%arg0: !tf_tensorlist.list, %arg1: tensor<i32>, %arg2: tensor<0xi32>) -> tensor<f32> {
+// CHECK-LABEL: func @GetItem(%arg0: !tensorlist.list, %arg1: !hal.buffer) -> !hal.buffer {
+func @GetItem(%arg0: !tf_tensorlist.list, %arg1: tensor<i32>) -> tensor<f32> {
// CHECK: [[VIEW1:%.+]] = hal.buffer_view.create %arg1{{.*}}
-// CHECK: [[VIEW2:%.+]] = hal.buffer_view.create %arg2{{.*}}
-// CHECK: "tensorlist.GetItem"(%arg0, [[VIEW1]], [[VIEW2]])
- %0 = "tf_tensorlist.GetItem"(%arg0, %arg1, %arg2) : (!tf_tensorlist.list, tensor<i32>, tensor<0xi32>) -> tensor<f32>
+// CHECK: "tensorlist.GetItem"(%arg0, [[VIEW1]])
+ %0 = "tf_tensorlist.GetItem"(%arg0, %arg1) : (!tf_tensorlist.list, tensor<i32>) -> tensor<f32>
return %0 : tensor<f32>
}
-// CHECK-LABEL: func @Stack(%arg0: !tensorlist.list, %arg1: !hal.buffer, %arg2: !hal.buffer) -> !hal.buffer {
-func @Stack(%arg0: !tf_tensorlist.list, %arg1: tensor<1xi32>, %arg2: tensor<i32>) -> tensor<1xf32> {
+// CHECK-LABEL: func @Stack(%arg0: !tensorlist.list, %arg1: !hal.buffer) -> !hal.buffer {
+func @Stack(%arg0: !tf_tensorlist.list, %arg1: tensor<i32>) -> tensor<1xf32> {
// CHECK: [[VIEW1:%.+]] = hal.buffer_view.create %arg1{{.*}}
-// CHECK: [[VIEW2:%.+]] = hal.buffer_view.create %arg2{{.*}}
-// CHECK: "tensorlist.Stack"(%arg0, [[VIEW1]], [[VIEW2]])
- %0 = "tf_tensorlist.Stack"(%arg0, %arg1, %arg2) : (!tf_tensorlist.list, tensor<1xi32>, tensor<i32>) -> tensor<1xf32>
+// CHECK: "tensorlist.Stack"(%allocator, %arg0, [[VIEW1]])
+ %0 = "tf_tensorlist.Stack"(%arg0, %arg1) : (!tf_tensorlist.list, tensor<i32>) -> tensor<1xf32>
return %0 : tensor<1xf32>
}
+
+// CHECK-LABEL: func @Concat(%arg0: !tensorlist.list) -> !hal.buffer {
+func @Concat(%arg0: !tf_tensorlist.list) -> (tensor<1xf32>) {
+// CHECK: "tensorlist.Concat"(%allocator, %arg0)
+ %0 = "tf_tensorlist.Concat"(%arg0) : (!tf_tensorlist.list) -> tensor<1xf32>
+ return %0 : tensor<1xf32>
+}
+
diff --git a/integrations/tensorflow/compiler/dialect/tf_tensorlist/conversion/test/convert_tf_to_tf_tensorlist.mlir b/integrations/tensorflow/compiler/dialect/tf_tensorlist/conversion/test/convert_tf_to_tf_tensorlist.mlir
index 1f08ff4..2d96974 100644
--- a/integrations/tensorflow/compiler/dialect/tf_tensorlist/conversion/test/convert_tf_to_tf_tensorlist.mlir
+++ b/integrations/tensorflow/compiler/dialect/tf_tensorlist/conversion/test/convert_tf_to_tf_tensorlist.mlir
@@ -4,9 +4,9 @@
// CHECK-LABEL: func @basic
func @basic(%arg0: tensor<f32>, %num_elements: tensor<i32>, %element_shape: tensor<0xi32>, %index: tensor<i32>, %item: tensor<f32>) -> tensor<f32> {
- // CHECK-NEXT: [[LIST0:%.+]] = "tf_tensorlist.Reserve"(%arg2, %arg1) : (tensor<0xi32>, tensor<i32>) -> !tf_tensorlist.list
+ // CHECK-NEXT: [[LIST0:%.+]] = "tf_tensorlist.Reserve"(%arg2, %arg1) {element_type = f32} : (tensor<0xi32>, tensor<i32>) -> !tf_tensorlist.list
// CHECK-NEXT: [[LIST1:%.+]] = "tf_tensorlist.SetItem"([[LIST0]], %arg3, %arg4) : (!tf_tensorlist.list, tensor<i32>, tensor<f32>) -> !tf_tensorlist.list
- // CHECK-NEXT: [[T:%.+]] = "tf_tensorlist.GetItem"([[LIST1]], %arg3, %arg2) : (!tf_tensorlist.list, tensor<i32>, tensor<0xi32>) -> tensor<f32>
+ // CHECK-NEXT: [[T:%.+]] = "tf_tensorlist.GetItem"([[LIST1]], %arg3) : (!tf_tensorlist.list, tensor<i32>) -> tensor<f32>
// CHECK-NEXT: return [[T]] : tensor<f32>
%list0 = "tf.TensorListReserve"(%element_shape, %num_elements) : (tensor<0xi32>, tensor<i32>) -> tensor<!tf.variant<tensor<f32>>>
%list1 = "tf.TensorListSetItem"(%list0, %index, %item) : (tensor<!tf.variant<tensor<f32>>>, tensor<i32>, tensor<f32>) -> tensor<!tf.variant<tensor<f32>>>
@@ -16,9 +16,9 @@
// CHECK-LABEL: func @stack
func @stack(%arg0: tensor<?xf32>, %element_shape: tensor<0xi32>) -> tensor<?xf32> {
- // CHECK-NEXT: [[LIST0:%.+]] = "tf_tensorlist.FromTensor"(%arg0, %arg1) : (tensor<?xf32>, tensor<0xi32>) -> !tf_tensorlist.list
+ // CHECK-NEXT: [[LIST0:%.+]] = "tf_tensorlist.FromTensor"(%arg0) : (tensor<?xf32>) -> !tf_tensorlist.list
// CHECK-NEXT: [[CONST:%.+]] = constant dense<-1> : tensor<i64>
- // CHECK-NEXT: [[T:%.+]] = "tf_tensorlist.Stack"([[LIST0]], %arg1, [[CONST]]) : (!tf_tensorlist.list, tensor<0xi32>, tensor<i64>) -> tensor<?xf32>
+ // CHECK-NEXT: [[T:%.+]] = "tf_tensorlist.Stack"([[LIST0]], [[CONST]]) : (!tf_tensorlist.list, tensor<i64>) -> tensor<?xf32>
// CHECK-NEXT: return [[T]]
%list0 = "tf.TensorListFromTensor"(%arg0, %element_shape) : (tensor<?xf32>, tensor<0xi32>) -> tensor<!tf.variant<tensor<f32>>>
@@ -28,7 +28,7 @@
// CHECK-LABEL: func @concat
func @concat(%arg0: tensor<?xf32>, %element_shape: tensor<0xi32>, %lead: tensor<i64>) -> (tensor<?xf32>, tensor<0xi64>) {
- // CHECK-DAG: [[LIST0:%.+]] = "tf_tensorlist.FromTensor"(%arg0, %arg1)
+ // CHECK-DAG: [[LIST0:%.+]] = "tf_tensorlist.FromTensor"(%arg0)
// CHECK-DAG: [[T:%.+]] = "tf_tensorlist.Concat"([[LIST0]])
// CHECK-DAG: [[L:%.+]] = "tf_tensorlist.GetDim0"([[LIST0]])
// CHECK: return [[T]], [[L]]
@@ -38,6 +38,7 @@
return %t#0, %t#1 : tensor<?xf32>, tensor<0xi64>
}
+
// CHECK-LABEL: func @control_flow_simple
func @control_flow_simple(%arg0: tensor<f32>, %num_elements: tensor<i32>, %element_shape: tensor<0xi32>, %index: tensor<i32>, %item: tensor<f32>) {
// CHECK-NEXT: tf_tensorlist.Reserve
diff --git a/integrations/tensorflow/compiler/dialect/tf_tensorlist/ir/test/ops.mlir b/integrations/tensorflow/compiler/dialect/tf_tensorlist/ir/test/ops.mlir
index 3bf2630..9e2c13d 100644
--- a/integrations/tensorflow/compiler/dialect/tf_tensorlist/ir/test/ops.mlir
+++ b/integrations/tensorflow/compiler/dialect/tf_tensorlist/ir/test/ops.mlir
@@ -9,13 +9,13 @@
%item: tensor<?xf32>
) {
// CHECK: tf_tensorlist.Reserve
- %2 = "tf_tensorlist.Reserve"(%element_shape, %num_elements) : (tensor<1xi32>, tensor<i32>) -> !tf_tensorlist.list
+ %2 = "tf_tensorlist.Reserve"(%element_shape, %num_elements) { element_type = f32 } : (tensor<1xi32>, tensor<i32>) -> !tf_tensorlist.list
// CHECK: tf_tensorlist.GetItem
- %3 = "tf_tensorlist.GetItem"(%list, %index, %element_shape) : (!tf_tensorlist.list, tensor<i32>, tensor<1xi32>) -> tensor<?xf32>
+ %3 = "tf_tensorlist.GetItem"(%list, %index) : (!tf_tensorlist.list, tensor<i32>) -> tensor<?xf32>
// CHECK: tf_tensorlist.SetItem
%4 = "tf_tensorlist.SetItem"(%list, %index, %item) : (!tf_tensorlist.list, tensor<i32>, tensor<?xf32>) -> !tf_tensorlist.list
// CHECK: tf_tensorlist.Stack
- %5 = "tf_tensorlist.Stack"(%list, %element_shape, %index) : (!tf_tensorlist.list, tensor<1xi32>, tensor<i32>) -> tensor<1x2xf32>
+ %5 = "tf_tensorlist.Stack"(%list, %index) : (!tf_tensorlist.list, tensor<i32>) -> tensor<1x2xf32>
// CHECK: tf_tensorlist.Concat
%6 = "tf_tensorlist.Concat"(%list) : (!tf_tensorlist.list) -> tensor<1x2xf32>
// CHECK: tf_tensorlist.GetDim0
diff --git a/integrations/tensorflow/compiler/dialect/tf_tensorlist/ir/tf_tensorlist_ops.td b/integrations/tensorflow/compiler/dialect/tf_tensorlist/ir/tf_tensorlist_ops.td
index f125e7a..3f35bff 100644
--- a/integrations/tensorflow/compiler/dialect/tf_tensorlist/ir/tf_tensorlist_ops.td
+++ b/integrations/tensorflow/compiler/dialect/tf_tensorlist/ir/tf_tensorlist_ops.td
@@ -34,7 +34,8 @@
let arguments = (ins
TF_I32OrI64Tensor:$element_shape,
- I32Tensor:$num_elements
+ I32Tensor:$num_elements,
+ TypeAttr:$element_type
);
let results = (outs
@@ -62,8 +63,7 @@
let arguments = (ins
TfTensorList_TensorList:$list,
- I32Tensor:$index,
- I32Tensor:$element_shape
+ I32Tensor:$index
);
let results = (outs
@@ -103,8 +103,7 @@
}];
let arguments = (ins
- TF_Tensor:$tensor,
- I32Tensor:$element_shape
+ TF_Tensor:$tensor
);
let results = (outs
@@ -125,7 +124,6 @@
let arguments = (ins
TfTensorList_TensorList:$list,
- I32Tensor:$element_shape,
// TODO(silvasean): Properly handle IREE's blind truncation to 32-bit.
// This is logically `index` type, but coming from TensorFlow it
// comes in as i64. IREE then proceeds to blindly truncate it to I32
diff --git a/integrations/tensorflow/e2e/keras/BUILD b/integrations/tensorflow/e2e/keras/BUILD
index 6ecd6b5..89f92ad 100644
--- a/integrations/tensorflow/e2e/keras/BUILD
+++ b/integrations/tensorflow/e2e/keras/BUILD
@@ -28,10 +28,6 @@
"//integrations/tensorflow/e2e:iree_e2e_cartesian_product_test_suite.bzl",
"iree_e2e_cartesian_product_test_suite",
)
-load(
- "//integrations/tensorflow/e2e:iree_e2e_test_suite.bzl",
- "iree_e2e_test_suite",
-)
package(
default_visibility = ["//visibility:public"],
@@ -39,33 +35,6 @@
licenses = ["notice"], # Apache 2.0
)
-# @unused
-DOC = """
-vision_model_test_manual is for manual testing of all keras vision models.
-Test will run only manually with all parameters specified manually, for example:
-bazel run -c opt integrations/tensorflow/e2e/keras:vision_model_test_manual -- \
---target_backends=tf,iree_vmla \
---data=imagenet \
---url=https://storage.googleapis.com/iree_models/ \
---model=ResNet50
-
-Command arguments description:
---target_backends: can be combination of these: tf,iree_vmla
---data: can be 'imagenet' or 'cifar10'.
- imagenet - input image size (1, 224, 224, 3)
- cifar10 - input image size (1, 32, 32, 3) - it is used for quick tests
- and needs pretrained weights, we pretrained models: ResNet50, MobileNet, MobileNetV2
---include_top: Whether or not to include the final (top) layers of the model.
---url: we need it only for cifar10 models to load weights from https://storage.googleapis.com/iree_models/
- imagenet pretrained weights url is specified by keras
---model: supports ResNet50, MobileNet, MobileNetV2, ResNet101, ResNet152,
- ResNet50V2, ResNet101V2, ResNet152V2, VGG16, VGG19, Xception,
- InceptionV3, InceptionResNetV2, DenseNet121, DenseNet169,
- DenseNet201, NASNetMobile, NASNetLarge
- All above models works with 'imagenet' data sets.
- ResNet50, MobileNet, MobileNetV2 work with both 'imagenet' and 'cifar10' data sets.
-"""
-
[
iree_py_binary(
name = src.replace(".py", "_manual"),
@@ -82,308 +51,6 @@
)
]
-SPECIAL_CASES = [
- "keyword_spotting_streaming_test.py",
- "vision_model_test.py",
-]
-
-TFLITE_FAILING = []
-
-VMLA_FAILING = []
-
-LLVM_FAILING = []
-
-VULKAN_FAILING = []
-
-TF_PASSING = glob(
- ["*_test.py"],
- exclude = SPECIAL_CASES,
-)
-
-TFLITE_PASSING = glob(
- ["*_test.py"],
- exclude = TFLITE_FAILING + SPECIAL_CASES,
-)
-
-VMLA_PASSING = glob(
- ["*_test.py"],
- exclude = VMLA_FAILING + SPECIAL_CASES,
-)
-
-LLVM_PASSING = glob(
- ["*_test.py"],
- exclude = LLVM_FAILING + SPECIAL_CASES,
-)
-
-VULKAN_PASSING = glob(
- ["*_test.py"],
- exclude = VULKAN_FAILING + SPECIAL_CASES,
-)
-
-iree_e2e_test_suite(
- name = "keras_tests",
- backends_to_srcs = {
- "tf": TF_PASSING,
- "tflite": TFLITE_PASSING,
- "iree_vmla": VMLA_PASSING,
- "iree_llvmjit": LLVM_PASSING,
- "iree_vulkan": VULKAN_PASSING,
- },
- reference_backend = "tf",
- deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
- "//integrations/tensorflow/bindings/python/pyiree/tf/support",
- ],
-)
-
-iree_e2e_test_suite(
- name = "keras_tests_failing",
- backends_to_srcs = {
- "tflite": TFLITE_FAILING,
- "iree_vmla": VMLA_FAILING,
- "iree_llvmjit": LLVM_FAILING,
- "iree_vulkan": VULKAN_FAILING,
- },
- reference_backend = "tf",
- tags = [
- "failing",
- "manual",
- "nokokoro",
- "notap",
- ],
- deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
- "//integrations/tensorflow/bindings/python/pyiree/tf/support",
- ],
-)
-
-iree_e2e_cartesian_product_test_suite(
- name = "large_cifar10_tests",
- size = "large",
- srcs = ["vision_model_test.py"],
- flags_to_values = {
- "reference_backend": "tf",
- "data": "cifar10",
- "model": [
- # All models with runtime shorter than ResNet50.
- "MobileNet", # Max: Vulkan 61.0s
- "MobileNetV2", # Max: LLVM 96.3s
- "ResNet50", # Max: LLVM 145.6s
- "VGG16", # Max: LLVM 89.5s
- "VGG19", # Max: LLVM 94.7s
- ],
- "target_backends": [
- "tf",
- "tflite",
- "iree_vmla",
- "iree_llvmjit",
- "iree_vulkan",
- ],
- },
- main = "vision_model_test.py",
- tags = ["manual"],
- deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
- "//integrations/tensorflow/bindings/python/pyiree/tf/support",
- ],
-)
-
-iree_e2e_cartesian_product_test_suite(
- name = "enormous_cifar10_tests",
- size = "enormous",
- srcs = ["vision_model_test.py"],
- failing_configurations = [
- {
- # Failing on vmla with negative inputs.
- "model": [
- "NASNetLarge",
- "NASNetMobile",
- ],
- "target_backends": "iree_vmla",
- },
- {
- # Failing on llvm and vulkan:
- "model": [
- "NASNetLarge",
- "NASNetMobile",
- "ResNet50V2",
- "ResNet101V2",
- "ResNet152V2",
- ],
- "target_backends": [
- "iree_llvmjit",
- "iree_vulkan",
- ],
- },
- ],
- flags_to_values = {
- "reference_backend": "tf",
- "data": "cifar10",
- "model": [
- "DenseNet121",
- "DenseNet169",
- "DenseNet201",
- "NASNetLarge",
- "NASNetMobile",
- "ResNet50V2",
- "ResNet101",
- "ResNet101V2",
- "ResNet152",
- "ResNet152V2",
- ],
- "target_backends": [
- "tf",
- "tflite",
- "iree_vmla",
- "iree_llvmjit",
- "iree_vulkan",
- ],
- },
- main = "vision_model_test.py",
- tags = [
- "guitar",
- "manual",
- "nokokoro",
- "notap",
- ],
- deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
- "//integrations/tensorflow/bindings/python/pyiree/tf/support",
- ],
-)
-
-# 'non_hermetic' tests use real model weights to test numerical correctness.
-iree_e2e_cartesian_product_test_suite(
- name = "cifar10_non_hermetic_tests",
- size = "large",
- srcs = ["vision_model_test.py"],
- flags_to_values = {
- "reference_backend": "tf",
- "data": "cifar10",
- "url": "https://storage.googleapis.com/iree_models/",
- "use_external_weights": True,
- "model": [
- "MobileNet",
- "MobileNetV2",
- "ResNet50",
- ],
- "target_backends": [
- "tf",
- "tflite",
- "iree_vmla",
- "iree_llvmjit",
- "iree_vulkan",
- ],
- },
- main = "vision_model_test.py",
- tags = [
- "external",
- "guitar",
- "manual",
- "no-remote",
- "nokokoro",
- "notap",
- ],
- deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
- "//integrations/tensorflow/bindings/python/pyiree/tf/support",
- ],
-)
-
-# 'non_hermetic' tests use real model weights to test numerical correctness.
-iree_e2e_cartesian_product_test_suite(
- name = "imagenet_non_hermetic_tests",
- size = "enormous",
- srcs = ["vision_model_test.py"],
- failing_configurations = [
- {
- # Failing on vmla with negative inputs.
- "model": [
- "NASNetLarge",
- "NASNetMobile",
- ],
- "target_backends": "iree_vmla",
- },
- {
- # Failing vulkan:
- "model": [
- "InceptionResNetV2",
- "InceptionV3",
- ],
- "target_backends": [
- "iree_vulkan",
- ],
- },
- {
- # Failing llvm and vulkan:
- "model": [
- "NASNetLarge",
- "NASNetMobile",
- "ResNet50V2",
- "ResNet101V2",
- "ResNet152V2",
- "Xception",
- ],
- "target_backends": [
- "iree_llvmjit",
- "iree_vulkan",
- ],
- },
- ],
- flags_to_values = {
- "reference_backend": "tf",
- "data": "imagenet",
- "use_external_weights": True,
- "model": [
- "DenseNet121",
- "DenseNet169",
- "DenseNet201",
- "InceptionResNetV2",
- "InceptionV3",
- "MobileNet",
- "MobileNetV2",
- "NASNetLarge",
- "NASNetMobile",
- "ResNet50",
- "ResNet50V2",
- "ResNet101",
- "ResNet101V2",
- "ResNet152",
- "ResNet152V2",
- "VGG16",
- "VGG19",
- "Xception",
- ],
- "target_backends": [
- "tf",
- "tflite",
- "iree_vmla",
- "iree_llvmjit",
- "iree_vulkan",
- ],
- },
- main = "vision_model_test.py",
- tags = [
- "external",
- "guitar",
- "manual",
- "nokokoro",
- "notap",
- ],
- deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
- "//integrations/tensorflow/bindings/python/pyiree/tf/support",
- ],
-)
-
-# It is used to produce weights for keras vision models with input image size
-# 32x32. These models are not optimized for accuracy or latency (they are for
-# debugging only). They have the same neural net topology with keras vision
-# models trained on imagenet data sets
-iree_py_binary(
- name = "train_vision_models_on_cifar",
- srcs = ["train_vision_models_on_cifar.py"],
- python_version = "PY3",
- srcs_version = "PY2AND3",
- deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
- "//integrations/tensorflow/bindings/python/pyiree/tf/support",
- ],
-)
-
# Keyword Spotting Tests:
KEYWORD_SPOTTING_MODELS = [
"svdf",
diff --git a/integrations/tensorflow/e2e/keras/applications/BUILD b/integrations/tensorflow/e2e/keras/applications/BUILD
new file mode 100644
index 0000000..a9df330
--- /dev/null
+++ b/integrations/tensorflow/e2e/keras/applications/BUILD
@@ -0,0 +1,318 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Test coverage across backends for e2e tests is defined directly in the BUILD
+# files. Coverage tables generated from this file can be viewed here:
+# https://google.github.io/iree/tensorflow-coverage/vision-coverage
+# Updates made to test suite names should also be reflected here:
+# https://github.com/google/iree/blob/main/scripts/update_e2e_coverage.py
+
+load(
+ "//bindings/python:build_defs.oss.bzl",
+ "INTREE_TENSORFLOW_PY_DEPS",
+ "NUMPY_DEPS",
+ "iree_py_binary",
+)
+load(
+ "//integrations/tensorflow/e2e:iree_e2e_cartesian_product_test_suite.bzl",
+ "iree_e2e_cartesian_product_test_suite",
+)
+
+package(
+ default_visibility = ["//visibility:public"],
+ features = ["layering_check"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+# @unused
+DOC = """
+applications_test_manual is for manual testing of all keras vision models.
+The test only runs manually, with all parameters specified explicitly, for example:
+bazel run -c opt integrations/tensorflow/e2e/keras/applications:applications_test_manual -- \
+--target_backends=tf,iree_vmla \
+--data=imagenet \
+--url=https://storage.googleapis.com/iree_models/ \
+--model=ResNet50
+
+Command arguments description:
+--target_backends: can be a combination of: tf,iree_vmla
+--data: can be 'imagenet' or 'cifar10'.
+    imagenet - input image size (1, 224, 224, 3)
+    cifar10 - input image size (1, 32, 32, 3); used for quick tests
+              and needs pretrained weights; we provide pretrained weights for: ResNet50, MobileNet, MobileNetV2
+--include_top: Whether or not to include the final (top) layers of the model.
+--url: only needed for cifar10 models, to load weights from https://storage.googleapis.com/iree_models/
+       (the imagenet pretrained weights url is specified by keras)
+--model: supports ResNet50, MobileNet, MobileNetV2, ResNet101, ResNet152,
+         ResNet50V2, ResNet101V2, ResNet152V2, VGG16, VGG19, Xception,
+         InceptionV3, InceptionResNetV2, DenseNet121, DenseNet169,
+         DenseNet201, NASNetMobile, NASNetLarge
+         All of the above models work with the 'imagenet' data set.
+         ResNet50, MobileNet, MobileNetV2 work with both the 'imagenet' and 'cifar10' data sets.
+"""
+
+[
+ iree_py_binary(
+ name = src.replace(".py", "_manual"),
+ srcs = [src],
+ main = src,
+ python_version = "PY3",
+ deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
+ "//integrations/tensorflow/bindings/python/pyiree/tf/support",
+ ],
+ )
+ for src in glob(
+ ["*_test.py"],
+ )
+]
+
+KERAS_APPLICATIONS_MODELS = [
+ "DenseNet121",
+ "DenseNet169",
+ "DenseNet201",
+ "EfficientNetB0",
+ "EfficientNetB1",
+ "EfficientNetB2",
+ "EfficientNetB3",
+ "EfficientNetB4",
+ "EfficientNetB5",
+ "EfficientNetB6",
+ "EfficientNetB7",
+ "InceptionResNetV2",
+ "InceptionV3",
+ "MobileNet",
+ "MobileNetV2",
+ "MobileNetV3Large",
+ "MobileNetV3Small",
+ "NASNetLarge",
+ "NASNetMobile",
+ "ResNet101",
+ "ResNet101V2",
+ "ResNet152",
+ "ResNet152V2",
+ "ResNet50",
+ "ResNet50V2",
+ "VGG16",
+ "VGG19",
+]
+
+iree_e2e_cartesian_product_test_suite(
+ name = "large_cifar10_tests",
+ size = "large",
+ srcs = ["applications_test.py"],
+ flags_to_values = {
+ "reference_backend": "tf",
+ "data": "cifar10",
+ "model": [
+ # All models with runtime shorter than ResNet50.
+ "MobileNet", # Max: Vulkan 61.0s
+ "MobileNetV2", # Max: LLVM 96.3s
+ "ResNet50", # Max: LLVM 145.6s
+ "VGG16", # Max: LLVM 89.5s
+ "VGG19", # Max: LLVM 94.7s
+ ],
+ "target_backends": [
+ "tf",
+ "tflite",
+ "iree_vmla",
+ "iree_llvmjit",
+ "iree_vulkan",
+ ],
+ },
+ main = "applications_test.py",
+ tags = ["manual"],
+ deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
+ "//integrations/tensorflow/bindings/python/pyiree/tf/support",
+ ],
+)
+
+iree_e2e_cartesian_product_test_suite(
+ name = "enormous_cifar10_tests",
+ size = "enormous",
+ srcs = ["applications_test.py"],
+ failing_configurations = [
+ {
+ # Failing on vmla with negative inputs.
+ "model": [
+ "NASNetLarge",
+ "NASNetMobile",
+ ],
+ "target_backends": "iree_vmla",
+ },
+ {
+ # Failing on llvm and vulkan:
+ "model": [
+ "NASNetLarge",
+ "NASNetMobile",
+ "ResNet50V2",
+ "ResNet101V2",
+ "ResNet152V2",
+ ],
+ "target_backends": [
+ "iree_llvmjit",
+ "iree_vulkan",
+ ],
+ },
+ ],
+ flags_to_values = {
+ "reference_backend": "tf",
+ "data": "cifar10",
+ "model": [
+ "DenseNet121",
+ "DenseNet169",
+ "DenseNet201",
+ "NASNetLarge",
+ "NASNetMobile",
+ "ResNet50V2",
+ "ResNet101",
+ "ResNet101V2",
+ "ResNet152",
+ "ResNet152V2",
+ ],
+ "target_backends": [
+ "tf",
+ "tflite",
+ "iree_vmla",
+ "iree_llvmjit",
+ "iree_vulkan",
+ ],
+ },
+ main = "applications_test.py",
+ tags = [
+ "guitar",
+ "manual",
+ "nokokoro",
+ "notap",
+ ],
+ deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
+ "//integrations/tensorflow/bindings/python/pyiree/tf/support",
+ ],
+)
+
+# 'non_hermetic' tests use real model weights to test numerical correctness.
+iree_e2e_cartesian_product_test_suite(
+ name = "cifar10_non_hermetic_tests",
+ size = "large",
+ srcs = ["applications_test.py"],
+ flags_to_values = {
+ "reference_backend": "tf",
+ "data": "cifar10",
+ "url": "https://storage.googleapis.com/iree_models/",
+ "use_external_weights": True,
+ "model": [
+ "MobileNet",
+ "MobileNetV2",
+ "ResNet50",
+ ],
+ "target_backends": [
+ "tf",
+ "tflite",
+ "iree_vmla",
+ "iree_llvmjit",
+ "iree_vulkan",
+ ],
+ },
+ main = "applications_test.py",
+ tags = [
+ "external",
+ "guitar",
+ "manual",
+ "no-remote",
+ "nokokoro",
+ "notap",
+ ],
+ deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
+ "//integrations/tensorflow/bindings/python/pyiree/tf/support",
+ ],
+)
+
+# 'non_hermetic' tests use real model weights to test numerical correctness.
+iree_e2e_cartesian_product_test_suite(
+ name = "imagenet_non_hermetic_tests",
+ size = "enormous",
+ srcs = ["applications_test.py"],
+ failing_configurations = [
+ {
+ # Failing on vmla with negative inputs.
+ "model": [
+ "NASNetLarge",
+ "NASNetMobile",
+ ],
+ "target_backends": "iree_vmla",
+ },
+ {
+ # Failing vulkan:
+ "model": [
+ "InceptionResNetV2",
+ "InceptionV3",
+ ],
+ "target_backends": [
+ "iree_vulkan",
+ ],
+ },
+ {
+ # Failing llvm and vulkan:
+ "model": [
+ "NASNetLarge",
+ "NASNetMobile",
+ "ResNet50V2",
+ "ResNet101V2",
+ "ResNet152V2",
+ "Xception",
+ ],
+ "target_backends": [
+ "iree_llvmjit",
+ "iree_vulkan",
+ ],
+ },
+ ],
+ flags_to_values = {
+ "reference_backend": "tf",
+ "data": "imagenet",
+ "use_external_weights": True,
+ "model": KERAS_APPLICATIONS_MODELS,
+ "target_backends": [
+ "tf",
+ "tflite",
+ "iree_vmla",
+ "iree_llvmjit",
+ "iree_vulkan",
+ ],
+ },
+ main = "applications_test.py",
+ tags = [
+ "external",
+ "guitar",
+ "manual",
+ "nokokoro",
+ "notap",
+ ],
+ deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
+ "//integrations/tensorflow/bindings/python/pyiree/tf/support",
+ ],
+)
+
+# This binary produces weights for keras vision models with input image size
+# 32x32. These models are not optimized for accuracy or latency (they are for
+# debugging only). They have the same neural net topology as the keras vision
+# models trained on the imagenet data set.
+iree_py_binary(
+ name = "train_vision_models_on_cifar",
+ srcs = ["train_vision_models_on_cifar.py"],
+ python_version = "PY3",
+ srcs_version = "PY2AND3",
+ deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
+ "//integrations/tensorflow/bindings/python/pyiree/tf/support",
+ ],
+)
diff --git a/integrations/tensorflow/e2e/keras/applications/applications_test.py b/integrations/tensorflow/e2e/keras/applications/applications_test.py
new file mode 100644
index 0000000..8c0e77c
--- /dev/null
+++ b/integrations/tensorflow/e2e/keras/applications/applications_test.py
@@ -0,0 +1,121 @@
+# Lint as: python3
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Test all models in tf.keras.applications."""
+
+import os
+
+from absl import app
+from absl import flags
+import numpy as np
+from pyiree.tf.support import tf_test_utils
+from pyiree.tf.support import tf_utils
+import tensorflow.compat.v2 as tf
+
+FLAGS = flags.FLAGS
+
+# Testing all of the applications models in one run can take a long time,
+# so we test them one at a time via the --model flag, e.g. --model=MobileNet.
+flags.DEFINE_string("model", "ResNet50", "model name")
+flags.DEFINE_string(
+ "url", "", "url with model weights "
+ "for example https://storage.googleapis.com/iree_models/")
+flags.DEFINE_bool("use_external_weights", False,
+ "Whether or not to load external weights from the web")
+flags.DEFINE_enum("data", "cifar10", ["cifar10", "imagenet"],
+ "data sets on which model was trained: imagenet, cifar10")
+flags.DEFINE_bool(
+ "include_top", True,
+ "Whether or not to include the final (top) layers of the model.")
+
+BATCH_SIZE = 1
+IMAGE_DIM = 224
+
+
+def load_cifar10_weights(model):
+ file_name = "cifar10" + FLAGS.model
+ # get_file will download the model weights from a publicly available folder,
+ # save them to cache_dir=~/.keras/models/ and return a path to them.
+ url = os.path.join(
+ FLAGS.url, f"cifar10_include_top_{FLAGS.include_top:d}_{FLAGS.model}.h5")
+ weights_path = tf.keras.utils.get_file(file_name, url)
+ model.load_weights(weights_path)
+ return model
+
+
+def initialize_model():
+ # If weights == "imagenet", the model will load the appropriate weights from
+ # an external tf.keras URL.
+ weights = None
+ if FLAGS.use_external_weights and FLAGS.data == "imagenet":
+ weights = "imagenet"
+
+ model_class = getattr(tf.keras.applications, FLAGS.model)
+ model = model_class(weights=weights, include_top=FLAGS.include_top)
+
+ if FLAGS.use_external_weights and FLAGS.data == "cifar10":
+ if not FLAGS.url:
+ raise ValueError(
+ "cifar10 weights cannot be loaded without the `--url` flag.")
+ model = load_cifar10_weights(model)
+ return model
+
+
+class ApplicationsModule(tf_test_utils.TestModule):
+
+ def __init__(self):
+ super().__init__()
+ self.m = initialize_model()
+
+ input_shape = list([BATCH_SIZE] + self.m.inputs[0].shape[1:])
+
+ # Some models accept dynamic image dimensions by default, so we use
+ # IMAGE_DIM as a stand-in.
+ for i, dim in enumerate(input_shape):
+ if dim is None:
+ input_shape[i] = IMAGE_DIM
+
+ # Specify input shape with a static batch size.
+ # TODO(b/142948097): Add support for dynamic shapes in SPIR-V lowering.
+ self.call = tf_test_utils.tf_function_unit_test(
+ input_signature=[tf.TensorSpec(input_shape)],
+ name="call",
+ rtol=1e-5,
+ atol=1e-5)(lambda x: self.m(x, training=False))
+
+
+class ApplicationsTest(tf_test_utils.TracedModuleTestCase):
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self._modules = tf_test_utils.compile_tf_module(
+ ApplicationsModule,
+ exported_names=ApplicationsModule.get_tf_function_unit_tests(),
+ relative_artifacts_dir=os.path.join(FLAGS.model, FLAGS.data))
+
+
+def main(argv):
+ del argv # Unused.
+ if hasattr(tf, "enable_v2_behavior"):
+ tf.enable_v2_behavior()
+
+ if not hasattr(tf.keras.applications, FLAGS.model):
+ raise ValueError(f"Unsupported model: {FLAGS.model}")
+
+ ApplicationsTest.generate_unit_tests(ApplicationsModule)
+ tf.test.main()
+
+
+if __name__ == "__main__":
+ app.run(main)
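The new test resolves the model class by name via getattr instead of maintaining a hand-written APP_MODELS table. Below is a minimal standalone sketch of that lookup pattern (plain tf.keras only, no IREE pieces; the model name is chosen arbitrarily for illustration):

    import tensorflow as tf

    model_name = "MobileNetV2"  # any attribute of tf.keras.applications
    model_class = getattr(tf.keras.applications, model_name)
    model = model_class(weights=None, include_top=True)
    # The batch dimension is dynamic by default, e.g. (None, 224, 224, 3),
    # which is why applications_test.py pins it to BATCH_SIZE before compiling.
    print(model.inputs[0].shape)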
diff --git a/integrations/tensorflow/e2e/keras/train_vision_models_on_cifar.py b/integrations/tensorflow/e2e/keras/applications/train_vision_models_on_cifar.py
similarity index 63%
rename from integrations/tensorflow/e2e/keras/train_vision_models_on_cifar.py
rename to integrations/tensorflow/e2e/keras/applications/train_vision_models_on_cifar.py
index 6cfa854..28cd296 100644
--- a/integrations/tensorflow/e2e/keras/train_vision_models_on_cifar.py
+++ b/integrations/tensorflow/e2e/keras/applications/train_vision_models_on_cifar.py
@@ -28,50 +28,12 @@
'include_top', True,
'Whether or not to include the final (top) layers of the model.')
-APP_MODELS = {
- 'ResNet50':
- tf.keras.applications.resnet.ResNet50,
- 'ResNet101':
- tf.keras.applications.resnet.ResNet101,
- 'ResNet152':
- tf.keras.applications.resnet.ResNet152,
- 'ResNet50V2':
- tf.keras.applications.resnet_v2.ResNet50V2,
- 'ResNet101V2':
- tf.keras.applications.resnet_v2.ResNet101V2,
- 'ResNet152V2':
- tf.keras.applications.resnet_v2.ResNet152V2,
- 'VGG16':
- tf.keras.applications.vgg16.VGG16,
- 'VGG19':
- tf.keras.applications.vgg19.VGG19,
- 'Xception':
- tf.keras.applications.xception.Xception,
- 'InceptionV3':
- tf.keras.applications.inception_v3.InceptionV3,
- 'InceptionResNetV2':
- tf.keras.applications.inception_resnet_v2.InceptionResNetV2,
- 'MobileNet':
- tf.keras.applications.mobilenet.MobileNet,
- 'MobileNetV2':
- tf.keras.applications.mobilenet_v2.MobileNetV2,
- 'DenseNet121':
- tf.keras.applications.densenet.DenseNet121,
- 'DenseNet169':
- tf.keras.applications.densenet.DenseNet169,
- 'DenseNet201':
- tf.keras.applications.densenet.DenseNet201,
- 'NASNetMobile':
- tf.keras.applications.nasnet.NASNetMobile,
- 'NASNetLarge':
- tf.keras.applications.nasnet.NASNetLarge,
-}
-
# minimum size for keras vision models
INPUT_SHAPE = [1, 32, 32, 3]
-def main(_):
+def main(argv):
+ del argv # Unused.
# prepare training and testing data
(train_images,
@@ -89,10 +51,10 @@
train_labels = train_labels[:4000]
# It is a toy model for debugging (not optimized for accuracy or speed).
-
- model = APP_MODELS[FLAGS.model](weights=None,
- include_top=FLAGS.include_top,
- input_shape=INPUT_SHAPE[1:])
+ model_class = getattr(tf.keras.applications, FLAGS.model)
+ model = model_class(weights=None,
+ include_top=FLAGS.include_top,
+ input_shape=INPUT_SHAPE[1:])
model.summary()
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
diff --git a/integrations/tensorflow/e2e/keras/keyword_spotting_streaming_test.py b/integrations/tensorflow/e2e/keras/keyword_spotting_streaming_test.py
index a099f5b..e2da0c2 100644
--- a/integrations/tensorflow/e2e/keras/keyword_spotting_streaming_test.py
+++ b/integrations/tensorflow/e2e/keras/keyword_spotting_streaming_test.py
@@ -49,42 +49,28 @@
}
-class KeywordSpottingModule(tf.Module):
+class KeywordSpottingModule(tf_test_utils.TestModule):
def __init__(self):
super().__init__()
self.m = utils.get_model_with_default_params(FLAGS.model,
MODE_ENUM_TO_MODE[FLAGS.mode])
- self.write_input_shapes_to_cls(self.m)
- self.m.predict = lambda x: self.m.call(x, training=False)
- input_signature = [tf.TensorSpec(shape) for shape in self.input_shapes]
- self.predict = tf.function(input_signature=input_signature)(self.m.predict)
- @classmethod
- def write_input_shapes_to_cls(cls, model):
- # We store the input shapes on the cls because we need access to them to
- # generate the random inputs. Lists are not valid exported names, so we
- # cannot access them from the module instance itself, and instead we store
- # the input shapes on the test case below.
- cls.input_shapes = [tensor.shape for tensor in model.inputs]
+ call = lambda *args: self.m(*args, training=False)
+ input_signature = [tf.TensorSpec(tensor.shape) for tensor in self.m.inputs]
+ self.call = tf_test_utils.tf_function_unit_test(
+ input_signature=input_signature, name="call", atol=1e-5)(call)
class KeywordSpottingTest(tf_test_utils.TracedModuleTestCase):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
- self._modules = tf_test_utils.compile_tf_module(KeywordSpottingModule,
- exported_names=['predict'])
- self._input_shapes = KeywordSpottingModule.input_shapes
-
- def test_predict(self):
-
- def predict(module):
- inputs = [tf_utils.uniform(shape) for shape in self._input_shapes]
- inputs = inputs[0] if len(inputs) == 1 else inputs
- module.predict(inputs, atol=1e-5)
-
- self.compare_backends(predict, self._modules)
+ self._modules = tf_test_utils.compile_tf_module(
+ KeywordSpottingModule,
+ exported_names=['call'],
+ relative_artifacts_dir=os.path.join('kws_streaming', FLAGS.model,
+ FLAGS.mode))
def main(argv):
@@ -95,9 +81,8 @@
if FLAGS.model not in ALL_MODELS:
raise ValueError(f'Unsupported model: {FLAGS.model}.\n'
f'Expected one of {MODELS_HELP}.')
- KeywordSpottingModule.__name__ = os.path.join('keyword_spotting', FLAGS.model,
- FLAGS.mode)
+ KeywordSpottingTest.generate_unit_tests(KeywordSpottingModule)
tf.test.main()
diff --git a/integrations/tensorflow/e2e/keras/layers/layers_test.py b/integrations/tensorflow/e2e/keras/layers/layers_test.py
index 45a7d95..1734d68 100644
--- a/integrations/tensorflow/e2e/keras/layers/layers_test.py
+++ b/integrations/tensorflow/e2e/keras/layers/layers_test.py
@@ -480,11 +480,6 @@
return inputs[0] if len(inputs) == 1 else inputs
-def keras_arg_wrapper(*args):
- """Wrapper to convert multiple positional args into a list of values."""
- return list(args) if isinstance(args, tuple) else args
-
-
def create_wrapped_keras_layer(
layer: str, unit_test_spec: tf_test_utils.UnitTestSpec) -> tf.keras.Model:
"""Wraps a keras layer in a model for compilation."""
@@ -522,7 +517,7 @@
static_signature = [static_signature]
dynamic_signature = [dynamic_signature]
- call = lambda *args: model(keras_arg_wrapper(*args), training=FLAGS.training)
+ call = lambda *args: model(*args, training=FLAGS.training)
return tf_test_utils.tf_function_unit_test(
input_signature=dynamic_signature,
static_signature=static_signature,
@@ -557,13 +552,22 @@
setattr(self, unit_test_spec.unit_test_name, layer_unit_test)
+def get_relative_artifacts_dir() -> str:
+ dynamic_str = "dynamic" if FLAGS.dynamic_dims else "static"
+ training_str = "training" if FLAGS.training else "non_training"
+ full_api_str = "default_api" if FLAGS.test_default_kwargs_only else "full_api"
+ settings_str = f"{full_api_str}_{dynamic_str}_{training_str}"
+ return os.path.join("tf", "keras", "layers", FLAGS.layer, settings_str)
+
+
class KerasLayersTest(tf_test_utils.TracedModuleTestCase):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._modules = tf_test_utils.compile_tf_module(
KerasLayersModule,
- exported_names=KerasLayersModule.get_tf_function_unit_tests())
+ exported_names=KerasLayersModule.get_tf_function_unit_tests(),
+ relative_artifacts_dir=get_relative_artifacts_dir())
def main(argv):
@@ -580,17 +584,6 @@
if FLAGS.layer not in LAYERS_TO_UNIT_TEST_SPECS:
raise ValueError(f"Unrecognized layer: '{FLAGS.layer}'")
- # Set up name for saving artifacts.
- dynamic_str = "dynamic" if FLAGS.dynamic_dims else "static"
- training_str = "training" if FLAGS.training else "non_training"
- full_api_str = "default_api" if FLAGS.test_default_kwargs_only else "full_api"
- settings_str = f"{full_api_str}_{dynamic_str}_{training_str}"
- relative_artifacts_dir = os.path.join("tf", "keras", "layers", FLAGS.layer,
- settings_str)
- # The relative artifacts directory path is calculated from the module name
- # TODO(meadowlark): provide a better way of overridding this default.
- KerasLayersModule.__name__ = relative_artifacts_dir
-
unit_tests = KerasLayersModule.get_tf_function_unit_tests()
logging.info("Testing the following %s functions: %s", len(unit_tests),
unit_tests)
diff --git a/integrations/tensorflow/e2e/keras/train/classification_training_test.py b/integrations/tensorflow/e2e/keras/train/classification_training_test.py
index 0b3027a..a73205b 100644
--- a/integrations/tensorflow/e2e/keras/train/classification_training_test.py
+++ b/integrations/tensorflow/e2e/keras/train/classification_training_test.py
@@ -94,7 +94,10 @@
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._modules = tf_test_utils.compile_tf_module(
- ClassificationTrainingModule, exported_names=["train_on_batch"])
+ ClassificationTrainingModule,
+ exported_names=["train_on_batch"],
+ relative_artifacts_dir=os.path.join(
+ ClassificationTrainingModule.__name__, FLAGS.optimizer))
def test_train_on_batch(self):
@@ -109,8 +112,6 @@
del argv # Unused
if hasattr(tf, "enable_v2_behavior"):
tf.enable_v2_behavior()
- ClassificationTrainingModule.__name__ = os.path.join(
- ClassificationTrainingModule.__name__, FLAGS.optimizer)
tf.test.main()
diff --git a/integrations/tensorflow/e2e/keras/train/regression_training_test.py b/integrations/tensorflow/e2e/keras/train/regression_training_test.py
index cd064d3..0f9aa72 100644
--- a/integrations/tensorflow/e2e/keras/train/regression_training_test.py
+++ b/integrations/tensorflow/e2e/keras/train/regression_training_test.py
@@ -73,7 +73,10 @@
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._modules = tf_test_utils.compile_tf_module(
- RegressionTrainingModule, exported_names=["train_on_batch"])
+ RegressionTrainingModule,
+ exported_names=["train_on_batch"],
+ relative_artifacts_dir=os.path.join(RegressionTrainingModule.__name__,
+ FLAGS.optimizer))
def test_train_on_batch(self):
@@ -88,8 +91,6 @@
del argv # Unused
if hasattr(tf, "enable_v2_behavior"):
tf.enable_v2_behavior()
- RegressionTrainingModule.__name__ = os.path.join(
- RegressionTrainingModule.__name__, FLAGS.optimizer)
tf.test.main()
diff --git a/integrations/tensorflow/e2e/keras/vision_model_test.py b/integrations/tensorflow/e2e/keras/vision_model_test.py
deleted file mode 100644
index 114d407..0000000
--- a/integrations/tensorflow/e2e/keras/vision_model_test.py
+++ /dev/null
@@ -1,174 +0,0 @@
-# Lint as: python3
-# Copyright 2020 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Test all applications models in Keras."""
-
-import os
-
-from absl import app
-from absl import flags
-import numpy as np
-from pyiree.tf.support import tf_test_utils
-from pyiree.tf.support import tf_utils
-import tensorflow.compat.v2 as tf
-
-FLAGS = flags.FLAGS
-
-# Testing all applications models automatically can take time
-# so we test it one by one, with argument --model=MobileNet
-flags.DEFINE_string('model', 'ResNet50', 'model name')
-flags.DEFINE_string(
- 'url', '', 'url with model weights '
- 'for example https://storage.googleapis.com/iree_models/')
-flags.DEFINE_bool('use_external_weights', False,
- 'Whether or not to load external weights from the web')
-flags.DEFINE_enum('data', 'cifar10', ['cifar10', 'imagenet'],
- 'data sets on which model was trained: imagenet, cifar10')
-flags.DEFINE_bool(
- 'include_top', True,
- 'Whether or not to include the final (top) layers of the model.')
-
-APP_MODELS = {
- 'ResNet50':
- tf.keras.applications.resnet.ResNet50,
- 'ResNet101':
- tf.keras.applications.resnet.ResNet101,
- 'ResNet152':
- tf.keras.applications.resnet.ResNet152,
- 'ResNet50V2':
- tf.keras.applications.resnet_v2.ResNet50V2,
- 'ResNet101V2':
- tf.keras.applications.resnet_v2.ResNet101V2,
- 'ResNet152V2':
- tf.keras.applications.resnet_v2.ResNet152V2,
- 'VGG16':
- tf.keras.applications.vgg16.VGG16,
- 'VGG19':
- tf.keras.applications.vgg19.VGG19,
- 'Xception':
- tf.keras.applications.xception.Xception,
- 'InceptionV3':
- tf.keras.applications.inception_v3.InceptionV3,
- 'InceptionResNetV2':
- tf.keras.applications.inception_resnet_v2.InceptionResNetV2,
- 'MobileNet':
- tf.keras.applications.mobilenet.MobileNet,
- 'MobileNetV2':
- tf.keras.applications.mobilenet_v2.MobileNetV2,
- 'DenseNet121':
- tf.keras.applications.densenet.DenseNet121,
- 'DenseNet169':
- tf.keras.applications.densenet.DenseNet169,
- 'DenseNet201':
- tf.keras.applications.densenet.DenseNet201,
- 'NASNetMobile':
- tf.keras.applications.nasnet.NASNetMobile,
- 'NASNetLarge':
- tf.keras.applications.nasnet.NASNetLarge,
-}
-
-
-def get_input_shape():
- if FLAGS.data == 'imagenet':
- if FLAGS.model in ['InceptionV3', 'Xception', 'InceptionResNetV2']:
- return (1, 299, 299, 3)
- elif FLAGS.model == 'NASNetLarge':
- return (1, 331, 331, 3)
- else:
- return (1, 224, 224, 3)
- elif FLAGS.data == 'cifar10':
- return (1, 32, 32, 3)
- else:
- raise ValueError(f'Data not supported: {FLAGS.data}')
-
-
-def load_cifar10_weights(model):
- file_name = 'cifar10' + FLAGS.model
- # get_file will download the model weights from a publicly available folder,
- # save them to cache_dir=~/.keras/models/ and return a path to them.
- url = os.path.join(
- FLAGS.url, f'cifar10_include_top_{FLAGS.include_top:d}_{FLAGS.model}.h5')
- weights_path = tf.keras.utils.get_file(file_name, url)
- model.load_weights(weights_path)
- return model
-
-
-def initialize_model():
- tf_utils.set_random_seed()
-
- # Keras applications models receive input shapes without a batch dimension, as
- # the batch size is dynamic by default. This selects just the image size.
- input_shape = get_input_shape()[1:]
-
- # If weights == 'imagenet', the model will load the appropriate weights from
- # an external tf.keras URL.
- weights = None
- if FLAGS.use_external_weights and FLAGS.data == 'imagenet':
- weights = 'imagenet'
-
- model = APP_MODELS[FLAGS.model](weights=weights,
- include_top=FLAGS.include_top,
- input_shape=input_shape)
-
- if FLAGS.use_external_weights and FLAGS.data == 'cifar10':
- if not FLAGS.url:
- raise ValueError(
- 'cifar10 weights cannot be loaded without the `--url` flag.')
- model = load_cifar10_weights(model)
- return model
-
-
-class VisionModule(tf.Module):
-
- def __init__(self):
- super().__init__()
- self.m = initialize_model()
- self.m.predict = lambda x: self.m.call(x, training=False)
- # Specify input shape with a static batch size.
- # TODO(b/142948097): Add support for dynamic shapes in SPIR-V lowering.
- # Replace input_shape with m.input_shape to make the batch size dynamic.
- self.predict = tf.function(
- input_signature=[tf.TensorSpec(get_input_shape())])(self.m.predict)
-
-
-class AppTest(tf_test_utils.TracedModuleTestCase):
-
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
- self._modules = tf_test_utils.compile_tf_module(VisionModule,
- exported_names=['predict'])
-
- def test_predict(self):
-
- def predict(module):
- module.predict(tf_utils.uniform(get_input_shape()), atol=1e-5, rtol=1e-5)
-
- self.compare_backends(predict, self._modules)
-
-
-def main(argv):
- del argv # Unused
- if hasattr(tf, 'enable_v2_behavior'):
- tf.enable_v2_behavior()
-
- if FLAGS.model not in APP_MODELS:
- raise ValueError(f'Unsupported model: {FLAGS.model}')
- # Override VisionModule's __name__ to be more specific.
- VisionModule.__name__ = os.path.join(FLAGS.model, FLAGS.data)
-
- tf.test.main()
-
-
-if __name__ == '__main__':
- app.run(main)
diff --git a/integrations/tensorflow/e2e/math/math_test.py b/integrations/tensorflow/e2e/math/math_test.py
index 00020d0..6ec57b4 100644
--- a/integrations/tensorflow/e2e/math/math_test.py
+++ b/integrations/tensorflow/e2e/math/math_test.py
@@ -696,12 +696,34 @@
setattr(self, unit_test_spec.unit_test_name, function_unit_test)
+def get_relative_artifacts_dir() -> str:
+ if len(FLAGS.functions) > 1:
+ # We only allow testing multiple functions with a single target backend
+ # so that we can store the artifacts under:
+ # 'artifacts_dir/multiple_functions__backend/...'
+ # We specialize the 'multiple_functions' dir by backend to avoid overwriting
+ # tf_input.mlir and iree_input.mlir. These are typically identical across
+ # backends, but are not when the functions to compile change per-backend.
+ if len(FLAGS.target_backends) != 1:
+ raise flags.IllegalFlagValueError(
+ "Expected len(target_backends) == 1 when len(functions) > 1, but got "
+ f"the following values for target_backends: {FLAGS.target_backends}.")
+ function_str = f"multiple_functions__{FLAGS.target_backends[0]}"
+ else:
+ function_str = FLAGS.functions[0]
+ dim_str = "dynamic_dims" if FLAGS.dynamic_dims else "static_dims"
+ complex_str = "complex" if FLAGS.test_complex else "non_complex"
+ return os.path.join("tf", "math", function_str, f"{dim_str}_{complex_str}")
+
+
class TfMathTest(tf_test_utils.TracedModuleTestCase):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._modules = tf_test_utils.compile_tf_module(
- TfMathModule, exported_names=TfMathModule.get_tf_function_unit_tests())
+ TfMathModule,
+ exported_names=TfMathModule.get_tf_function_unit_tests(),
+ relative_artifacts_dir=get_relative_artifacts_dir())
def main(argv):
@@ -721,26 +743,6 @@
"'--functions' must be specified if "
"'--list_functions_with_complex_tests' isn't")
- if len(FLAGS.functions) > 1:
- # We only allow testing multiple functions with a single target backend
- # so that we can store the artifacts under:
- # 'artifacts_dir/multiple_functions__backend/...'
- # We specialize the 'multiple_functions' dir by backend to avoid overwriting
- # tf_input.mlir and iree_input.mlir. These are typically identical across
- # backends, but are not when the functions to compile change per-backend.
- if len(FLAGS.target_backends) != 1:
- raise flags.IllegalFlagValueError(
- "Expected len(target_backends) == 1 when len(functions) > 1, but got "
- f"the following values for target_backends: {FLAGS.target_backends}.")
- function_str = f"multiple_functions__{FLAGS.target_backends[0]}"
- else:
- function_str = FLAGS.functions[0]
- dim_str = "dynamic_dims" if FLAGS.dynamic_dims else "static_dims"
- settings_str = os.path.join(function_str, dim_str)
- # The relative artifacts directory path is calculated from the module name
- # TODO(meadowlark): provide a better way of overridding this default.
- TfMathModule.__name__ = os.path.join("tf", "math", settings_str)
-
TfMathTest.generate_unit_tests(TfMathModule)
tf.test.main()
diff --git a/integrations/tensorflow/e2e/slim_vision_models/slim_vision_model_test.py b/integrations/tensorflow/e2e/slim_vision_models/slim_vision_model_test.py
index 0580eaf..1d2fb5f 100644
--- a/integrations/tensorflow/e2e/slim_vision_models/slim_vision_model_test.py
+++ b/integrations/tensorflow/e2e/slim_vision_models/slim_vision_model_test.py
@@ -77,8 +77,10 @@
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
- self._modules = tf_test_utils.compile_tf_module(SlimVisionModule,
- exported_names=['predict'])
+ self._modules = tf_test_utils.compile_tf_module(
+ SlimVisionModule,
+ exported_names=['predict'],
+ relative_artifacts_dir=FLAGS.model)
def test_predict(self):
@@ -94,8 +96,6 @@
del argv # Unused.
if hasattr(tf, 'enable_v2_behavior'):
tf.enable_v2_behavior()
-
- SlimVisionModule.__name__ = FLAGS.model
tf.test.main()
diff --git a/integrations/tensorflow/e2e/tensorlist_test.py b/integrations/tensorflow/e2e/tensorlist_test.py
index e76790e..104ede9 100644
--- a/integrations/tensorflow/e2e/tensorlist_test.py
+++ b/integrations/tensorflow/e2e/tensorlist_test.py
@@ -67,6 +67,12 @@
ta = ta.write(1, b)
return ta.stack()
+ @tf.function(input_signature=[tf.TensorSpec([], tf.float32)])
+ def partially_empty_stack(self, x):
+ ta = tf.TensorArray(dtype=tf.float32, size=2, element_shape=[])
+ ta = ta.write(0, x)
+ return ta.stack()
+
class TensorListTest(tf_test_utils.TracedModuleTestCase):
@@ -104,6 +110,11 @@
module.concat_with_tensorlist_stack(np.array(42., dtype=np.float32),
np.array(43., dtype=np.float32))
self.compare_backends(concat_with_tensorlist_stack, self._modules)
+
+ def test_partially_empty_stack(self):
+ def partially_empty_stack(module):
+ module.partially_empty_stack(np.array(42., dtype=np.float32))
+ self.compare_backends(partially_empty_stack, self._modules)
# yapf: enable
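The new partially_empty_stack case stacks a TensorArray that was only half written. A standalone sketch of the TensorFlow-side behavior being exercised (assumption: with element_shape specified, the unwritten slot comes back zero-filled, which is what the test relies on):

    import tensorflow as tf

    ta = tf.TensorArray(dtype=tf.float32, size=2, element_shape=[])
    ta = ta.write(0, tf.constant(42.0))
    stacked = ta.stack()
    print(stacked.numpy())  # expected: [42., 0.] -- index 1 was never written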
diff --git a/iree/base/BUILD b/iree/base/BUILD
index e6e634d..e90ea64 100644
--- a/iree/base/BUILD
+++ b/iree/base/BUILD
@@ -67,6 +67,16 @@
],
)
+cc_test(
+ name = "math_test",
+ srcs = ["math_test.cc"],
+ deps = [
+ ":core_headers",
+ "//iree/testing:gtest",
+ "//iree/testing:gtest_main",
+ ],
+)
+
#===------------------------------------------------------------------------===#
# Internal IREE C++ wrappers and utilities
#===------------------------------------------------------------------------===#
diff --git a/iree/base/CMakeLists.txt b/iree/base/CMakeLists.txt
index c45233e..e68b8e0 100644
--- a/iree/base/CMakeLists.txt
+++ b/iree/base/CMakeLists.txt
@@ -309,6 +309,17 @@
PUBLIC
)
+iree_cc_test(
+ NAME
+ math_test
+ SRCS
+ "math_test.cc"
+ DEPS
+ ::core_headers
+ iree::testing::gtest
+ iree::testing::gtest_main
+)
+
iree_cc_library(
NAME
ref_ptr
diff --git a/iree/base/internal/file_io_posix.cc b/iree/base/internal/file_io_posix.cc
index c67403e..fd1a718 100644
--- a/iree/base/internal/file_io_posix.cc
+++ b/iree/base/internal/file_io_posix.cc
@@ -126,7 +126,13 @@
return tmpdir;
}
+#ifdef __ANDROID__
+ // Support running Android command-line programs both as regular shell user
+ // and as root. For the latter, TMPDIR is not defined by default.
+ return "/data/local/tmp";
+#else
return "/tmp";
+#endif
}
StatusOr<std::string> GetTempFile(absl::string_view base_name) {
diff --git a/iree/base/math.h b/iree/base/math.h
index 3e27781..7e64db1 100644
--- a/iree/base/math.h
+++ b/iree/base/math.h
@@ -15,9 +15,11 @@
#ifndef IREE_BASE_MATH_H_
#define IREE_BASE_MATH_H_
-#include <cstdint>
+#include <stdbool.h>
+#include <stdint.h>
+#include <stdlib.h>
-#include "absl/base/attributes.h"
+#include "iree/base/target_platform.h"
// Haswell or later, gcc compile time option: -mlzcnt
#if defined(__LZCNT__)
@@ -26,88 +28,98 @@
// Clang on Windows has __builtin_clzll; otherwise we need to use the
// windows intrinsic functions.
-#if defined(_MSC_VER)
+#if defined(IREE_COMPILER_MSVC)
#include <intrin.h>
-#if defined(_M_X64)
+#if defined(IREE_ARCH_ARM_64) || defined(IREE_ARCH_X86_64)
#pragma intrinsic(_BitScanReverse64)
#pragma intrinsic(_BitScanForward64)
#endif
#pragma intrinsic(_BitScanReverse)
#pragma intrinsic(_BitScanForward)
-#endif
+#endif // IREE_COMPILER_MSVC
-#if defined(_MSC_VER)
-// We can achieve something similar to attribute((always_inline)) with MSVC by
-// using the __forceinline keyword, however this is not perfect. MSVC is
-// much less aggressive about inlining, and even with the __forceinline keyword.
-#define IREE_FORCEINLINE __forceinline
-#else
-// Use default attribute inline.
-#define IREE_FORCEINLINE inline ABSL_ATTRIBUTE_ALWAYS_INLINE
-#endif
+//==============================================================================
+// Bitwise rotation (aka circular shifts)
+//==============================================================================
-namespace iree {
+// Unsigned rotate-left a 64-bit integer.
+// https://en.cppreference.com/w/cpp/numeric/rotl
+//
+//
+// NOTE: this exact form is confirmed to be recognized by the compilers we care
+// about; do not modify: https://godbolt.org/z/xzof9d
+static inline uint64_t iree_math_rotl_u64(const uint64_t n, uint32_t c) {
+ if (!c) return n;
+ const uint32_t mask = 8 * sizeof(n) - 1;
+ c &= mask;
+ return (n << c) | (n >> (64 - c));
+}
-IREE_FORCEINLINE int CountLeadingZeros32(uint32_t n) {
-#if defined(_MSC_VER)
+// Unsigned rotate-right a 64-bit integer.
+// https://en.cppreference.com/w/cpp/numeric/rotr
+//
+// NOTE: this exact form is confirmed to be recognized by the compilers we care
+// about **except MSVC**; do not modify: https://godbolt.org/z/xzof9d
+static inline uint64_t iree_math_rotr_u64(const uint64_t n, uint32_t c) {
+ if (!c) return n;
+ const uint32_t mask = 8 * sizeof(n) - 1;
+ c &= mask;
+ return (n >> c) | (n << ((-c) & mask));
+}
+
+//==============================================================================
+// Bit scanning/counting
+//==============================================================================
+
+static inline int iree_math_count_leading_zeros_u32(const uint32_t n) {
+#if defined(IREE_COMPILER_MSVC)
unsigned long result = 0; // NOLINT(runtime/int)
if (_BitScanReverse(&result, n)) {
- return 31 - result;
+ return (int)(31 - result);
}
return 32;
-#elif defined(__GNUC__)
- // Use __builtin_clz, which uses the following instructions:
- // x86: bsr
- // ARM64: clz
- // PPC: cntlzd
- static_assert(sizeof(int) == sizeof(n),
- "__builtin_clz does not take 32-bit arg");
-
+#elif defined(IREE_COMPILER_GCC_COMPAT)
#if defined(__LCZNT__)
// NOTE: LZCNT is a risky instruction; it is not supported on architectures
// before Haswell, yet it is encoded as 'rep bsr', which typically ignores
// invalid rep prefixes, and interprets it as the 'bsr' instruction, which
// returns the index of the value rather than the count, resulting in
// incorrect code.
- return __lzcnt32(n);
+ return (int)__lzcnt32(n);
#endif // defined(__LCZNT__)
// Handle 0 as a special case because __builtin_clz(0) is undefined.
- if (n == 0) {
- return 32;
- }
- return __builtin_clz(n);
-#else
-#error No clz for this arch.
-#endif
-}
-
-IREE_FORCEINLINE int CountLeadingZeros64(uint64_t n) {
-#if defined(_MSC_VER) && defined(_M_X64)
- // MSVC does not have __buitin_clzll. Use _BitScanReverse64.
- unsigned long result = 0; // NOLINT(runtime/int)
- if (_BitScanReverse64(&result, n)) {
- return 63 - result;
- }
- return 64;
-#elif defined(_MSC_VER)
- // MSVC does not have __buitin_clzll. Compose two calls to _BitScanReverse
- unsigned long result = 0; // NOLINT(runtime/int)
- if ((n >> 32) && _BitScanReverse(&result, n >> 32)) {
- return 31 - result;
- }
- if (_BitScanReverse(&result, n)) {
- return 63 - result;
- }
- return 64;
-#elif defined(__GNUC__)
- // Use __builtin_clzll, which uses the following instructions:
+ if (n == 0) return 32;
+ // Use __builtin_clz, which uses the following instructions:
// x86: bsr
// ARM64: clz
// PPC: cntlzd
- static_assert(sizeof(unsigned long long) == sizeof(n), // NOLINT(runtime/int)
- "__builtin_clzll does not take 64-bit arg");
+ return (int)__builtin_clz(n);
+#else
+#error No clz for this arch.
+#endif // IREE_COMPILER_MSVC / IREE_COMPILER_GCC_COMPAT
+}
+static inline int iree_math_count_leading_zeros_u64(uint64_t n) {
+#if defined(IREE_COMPILER_MSVC) && \
+ (defined(IREE_ARCH_ARM_64) || defined(IREE_ARCH_X86_64))
+  // MSVC does not have __builtin_clzll. Use _BitScanReverse64.
+ unsigned long result = 0; // NOLINT(runtime/int)
+ if (_BitScanReverse64(&result, n)) {
+ return (int)(63 - result);
+ }
+ return 64;
+#elif defined(IREE_COMPILER_MSVC)
+  // MSVC does not have __builtin_clzll. Compose two calls to _BitScanReverse.
+ unsigned long result = 0; // NOLINT(runtime/int)
+ if ((n >> 32) && _BitScanReverse(&result, n >> 32)) {
+ return (int)(31 - result);
+ }
+ if (_BitScanReverse(&result, n)) {
+ return (int)(63 - result);
+ }
+ return 64;
+#elif defined(IREE_COMPILER_GCC_COMPAT)
#if defined(__LCZNT__)
// NOTE: LZCNT is a risky instruction; it is not supported on architectures
// before Haswell, yet it is encoded as 'rep bsr', which typically ignores
@@ -117,99 +129,267 @@
return __lzcnt64(n);
#elif defined(__aarch64__) || defined(__powerpc64__)
// Empirically verified that __builtin_clzll(0) works as expected.
- return __builtin_clzll(n);
+ return (int)__builtin_clzll(n);
#endif
-
// Handle 0 as a special case because __builtin_clzll(0) is undefined.
- if (n == 0) {
- return 64;
- }
- return __builtin_clzll(n);
+ if (!n) return 64;
+ // Use __builtin_clzll, which uses the following instructions:
+ // x86: bsr
+ // PPC: cntlzd
+  //  WASM: i64.clz
+  //  RISC-V: __clzdi2 in GCC, splat out in clang
+ return (int)__builtin_clzll(n);
#else
#error No clz for this arch.
-#endif
+#endif // IREE_COMPILER_MSVC / IREE_COMPILER_GCC_COMPAT
}
-IREE_FORCEINLINE int CountTrailingZerosNonZero32(uint32_t n) {
-#if defined(_MSC_VER)
+static inline int iree_math_count_trailing_zeros_u32(uint32_t n) {
+#if defined(IREE_COMPILER_MSVC)
unsigned long result = 0; // NOLINT(runtime/int)
_BitScanForward(&result, n);
- return result;
-#elif defined(__GNUC__)
- static_assert(sizeof(int) == sizeof(n),
- "__builtin_ctz does not take 32-bit arg");
- return __builtin_ctz(n);
+ return (int)result;
+#elif defined(IREE_COMPILER_GCC_COMPAT)
+ return (int)__builtin_ctz(n);
#else
int c = 31;
n &= ~n + 1;
- if (n & 0x0000FFFF) c -= 16;
- if (n & 0x00FF00FF) c -= 8;
- if (n & 0x0F0F0F0F) c -= 4;
- if (n & 0x33333333) c -= 2;
- if (n & 0x55555555) c -= 1;
+ if (n & 0x0000FFFFu) c -= 16;
+ if (n & 0x00FF00FFu) c -= 8;
+ if (n & 0x0F0F0F0Fu) c -= 4;
+ if (n & 0x33333333u) c -= 2;
+ if (n & 0x55555555u) c -= 1;
return c;
-#endif
+#endif // IREE_COMPILER_MSVC / IREE_COMPILER_GCC_COMPAT
}
-IREE_FORCEINLINE int CountTrailingZerosNonZero64(uint64_t n) {
-#if defined(_MSC_VER) && defined(_M_X64)
+static inline int iree_math_count_trailing_zeros_u64(uint64_t n) {
+#if defined(IREE_COMPILER_MSVC) && defined(IREE_PTR_SIZE_64)
unsigned long result = 0; // NOLINT(runtime/int)
_BitScanForward64(&result, n);
- return result;
-#elif defined(_MSC_VER)
+ return (int)result;
+#elif defined(IREE_COMPILER_MSVC) && defined(IREE_PTR_SIZE_32)
unsigned long result = 0; // NOLINT(runtime/int)
- if (static_cast<uint32_t>(n) == 0) {
+ if ((uint32_t)(n) == 0) {
_BitScanForward(&result, n >> 32);
return result + 32;
}
_BitScanForward(&result, n);
- return result;
-#elif defined(__GNUC__)
- static_assert(sizeof(unsigned long long) == sizeof(n), // NOLINT(runtime/int)
- "__builtin_ctzll does not take 64-bit arg");
+ return (int)result;
+#elif defined(IREE_COMPILER_GCC_COMPAT)
+  // Use __builtin_ctzll, which uses the following instructions:
+  //  x86: bsf
+  //  ARM64: rbit + clz
+  //  WASM: i64.ctz
+  //  RISC-V: __ctzdi2 in GCC, splat out in clang
return __builtin_ctzll(n);
#else
int c = 63;
n &= ~n + 1;
- if (n & 0x00000000FFFFFFFF) c -= 32;
- if (n & 0x0000FFFF0000FFFF) c -= 16;
- if (n & 0x00FF00FF00FF00FF) c -= 8;
- if (n & 0x0F0F0F0F0F0F0F0F) c -= 4;
- if (n & 0x3333333333333333) c -= 2;
- if (n & 0x5555555555555555) c -= 1;
+ if (n & 0x00000000FFFFFFFFull) c -= 32;
+ if (n & 0x0000FFFF0000FFFFull) c -= 16;
+ if (n & 0x00FF00FF00FF00FFull) c -= 8;
+ if (n & 0x0F0F0F0F0F0F0F0Full) c -= 4;
+ if (n & 0x3333333333333333ull) c -= 2;
+ if (n & 0x5555555555555555ull) c -= 1;
return c;
-#endif
+#endif // IREE_COMPILER_MSVC / IREE_COMPILER_GCC_COMPAT
}
-template <typename T>
-IREE_FORCEINLINE int TrailingZeros(T x) {
- return sizeof(T) == 8 ? CountTrailingZerosNonZero64(static_cast<uint64_t>(x))
- : CountTrailingZerosNonZero32(static_cast<uint32_t>(x));
-}
-
-template <typename T>
-IREE_FORCEINLINE int LeadingZeros(T x) {
- return sizeof(T) == 8 ? CountLeadingZeros64(static_cast<uint64_t>(x))
- : CountLeadingZeros32(static_cast<uint32_t>(x));
-}
+//==============================================================================
+// Population count
+//==============================================================================
// Returns the number of 1 bits in a 32 bit value.
-IREE_FORCEINLINE int CountOnes32(uint32_t n) {
- n -= ((n >> 1) & 0x55555555);
- n = ((n >> 2) & 0x33333333) + (n & 0x33333333);
- return static_cast<int>((((n + (n >> 4)) & 0xF0F0F0F) * 0x1010101) >> 24);
+static inline int iree_math_count_ones_u32(uint32_t n) {
+ n -= ((n >> 1) & 0x55555555u);
+ n = ((n >> 2) & 0x33333333u) + (n & 0x33333333u);
+ return (int)((((n + (n >> 4)) & 0x0F0F0F0Fu) * 0x01010101u) >> 24);
}
// Returns the number of 1 bits in a 64 bit value.
-IREE_FORCEINLINE int CountOnes64(uint64_t n) {
- return CountOnes32(n >> 32) + CountOnes32(n & 0xffffffff);
+static inline int iree_math_count_ones_u64(uint64_t n) {
+ return iree_math_count_ones_u32(n >> 32) +
+ iree_math_count_ones_u32(n & 0xFFFFFFFFu);
+}
+
+//==============================================================================
+// Rounding and alignment
+//==============================================================================
+// There are certain platforms - mostly those with poorer quality compilers or
+// more restricted instruction sets - where we want to avoid the clz path as
+// it is emulated and instead we use some bit-twiddling hacks. On other
+// platforms it's the opposite - they may emulate clz but doing so saves
+// dozens of bytes that otherwise would have been the shift/or tree.
+//
+// Which to choose is entirely determined by fiddling on godbolt for the
+// target platform: https://godbolt.org/z/h4vPzo
+
+// Rounds up the value to the nearest power of 2 (if not already a power of 2).
+// For 32-bit numbers this only supports values <= 2^31; higher will wrap.
+static inline uint32_t iree_math_round_up_to_pow2_u32(uint32_t n) {
+#if 0 // golf required; can be bloated
+ const uint32_t i = (n != 1);
+ return (1 + i) << ((iree_math_count_leading_zeros_u32(n - i) ^ 31));
+#elif 0 // golf required; can be bloated
+ return n == 1 ? 1u : 2u << ((iree_math_count_leading_zeros_u32(n - 1) ^ 31));
+#else
+ // https://graphics.stanford.edu/~seander/bithacks.html#RoundUpPowerOf2
+ n--;
+ n |= n >> 1;
+ n |= n >> 2;
+ n |= n >> 4;
+ n |= n >> 8;
+ n |= n >> 16;
+ return n + 1;
+#endif // 1
}
// Rounds up the value to the nearest power of 2 (if not already a power of 2).
-IREE_FORCEINLINE int RoundUpToNearestPow2(int n) {
- return n ? ~0u >> LeadingZeros(n) : 1;
+// For 64-bit numbers this only supports values <= 2^63; higher will wrap.
+static inline uint64_t iree_math_round_up_to_pow2_u64(uint64_t n) {
+#if 0 // golf required; can be bloated
+ const uint64_t i = (n != 1);
+ return (1 + i) << ((iree_math_count_leading_zeros_u64(n - i) ^ 63));
+#elif 0 // golf required; can be bloated
+ return n == 1 ? 1ull
+ : 2ull << ((iree_math_count_leading_zeros_u64(n - 1) ^ 63));
+#else
+ // https://graphics.stanford.edu/~seander/bithacks.html#RoundUpPowerOf2
+ n--;
+ n |= n >> 1;
+ n |= n >> 2;
+ n |= n >> 4;
+ n |= n >> 8;
+ n |= n >> 16;
+ n |= n >> 32;
+ return n + 1;
+#endif // 1
}
-} // namespace iree
+//==============================================================================
+// Pseudo-random number generators (PRNGs): **NOT CRYPTOGRAPHICALLY SECURE**
+//==============================================================================
+
+// A fixed-increment version of Java 8's SplittableRandom generator
+// See http://dx.doi.org/10.1145/2714064.2660195 and
+// http://docs.oracle.com/javase/8/docs/api/java/util/SplittableRandom.html
+//
+// SplitMix64 as recommended for use with xoroshiro by the authors:
+// http://prng.di.unimi.it/splitmix64.c
+// http://rosettacode.org/wiki/Pseudo-random_numbers/Splitmix64
+typedef uint64_t iree_prng_splitmix64_state_t;
+
+// Initializes a SplitMix64 PRNG state vector; |out_state| is overwritten.
+// |seed| may be any 64-bit value.
+static inline void iree_prng_splitmix64_initialize(
+ uint64_t seed, iree_prng_splitmix64_state_t* out_state) {
+ *out_state = seed;
+}
+
+// Steps a SplitMix64 PRNG state vector and yields a value for use.
+static inline uint64_t iree_prng_splitmix64_next(
+ iree_prng_splitmix64_state_t* state) {
+ uint64_t z = (*state += 0x9E3779B97F4A7C15ull);
+ z = (z ^ (z >> 30)) * 0xBF58476D1CE4E5B9ull;
+ z = (z ^ (z >> 27)) * 0x94D049BB133111EBull;
+ return z ^ (z >> 31);
+}
+
+// A small **pseudorandom** number generator (named after the operations used).
+// http://prng.di.unimi.it/
+typedef struct {
+ uint64_t value[2];
+} iree_prng_xoroshiro128_state_t;
+
+// Initializes a xoroshiro128+ PRNG state vector; |out_state| is overwritten.
+// |seed| may be any 64-bit value.
+static inline void iree_prng_xoroshiro128_initialize(
+ uint64_t seed, iree_prng_xoroshiro128_state_t* out_state) {
+ // The authors recommend using SplitMix64 to go from a single int seed
+ // into the two state values we need. It's critical that we don't use a
+ // xoroshiro128 for this as seeding a PRNG with the results of itself is...
+ // unsound.
+ iree_prng_splitmix64_state_t init_state;
+ iree_prng_splitmix64_initialize(seed, &init_state);
+  out_state->value[0] = iree_prng_splitmix64_next(&init_state);
+  out_state->value[1] = iree_prng_splitmix64_next(&init_state);
+
+ // A state of 0 will never produce anything but zeros so ensure that doesn't
+ // happen; of course, after running splitmix that should be closer to the
+ // side of never than not.
+ if (!out_state->value[0] && !out_state->value[1]) {
+ out_state->value[0] = 1;
+ }
+}
+
+// Steps a xoroshiro128 state vector and yields a value for use.
+// xoroshiro128+ variant: produces a single 64-bit value with ~60 well-distributed bits.
+// This is the fastest variant but the lower 4 bits of the returned value may
+// not be sufficiently well-distributed. This is fine if the usage requires
+// fewer than 60 bits such as when sampling bools or array indices.
+// Note also that this works great for floating-point numbers where only 23 or
+// 53 bits are required to populate a mantissa and an additional step can be
+// used to generate the sign/exponent when required.
+//
+// footprint: 128-bits
+// period: 2^128 - 1
+// ns/64-bits: 0.72
+// cycles/byte: 0.29
+//
+// http://prng.di.unimi.it/xoroshiro128plus.c
+static inline uint64_t iree_prng_xoroshiro128plus_next_uint60(
+ iree_prng_xoroshiro128_state_t* state) {
+ uint64_t s0 = state->value[0];
+ uint64_t s1 = state->value[1];
+ const uint64_t result = s0 + s1;
+ s1 ^= s0;
+ state->value[0] = iree_math_rotl_u64(s0, 24) ^ s1 ^ (s1 << 16); // a, b
+ state->value[1] = iree_math_rotl_u64(s1, 37); // c
+ return result;
+}
+
+// Steps a xoroshiro128 state vector and yields a single boolean value for use.
+// See iree_prng_xoroshiro128plus_next_uint60 for details.
+static inline bool iree_prng_xoroshiro128plus_next_bool(
+ iree_prng_xoroshiro128_state_t* state) {
+ return (bool)(iree_prng_xoroshiro128plus_next_uint60(state) >> (64 - 1));
+}
+
+// Steps a xoroshiro128 state vector and yields a single uint8_t value for use.
+// See iree_prng_xoroshiro128plus_next_uint60 for details.
+static inline uint8_t iree_prng_xoroshiro128plus_next_uint8(
+ iree_prng_xoroshiro128_state_t* state) {
+ return (uint8_t)(iree_prng_xoroshiro128plus_next_uint60(state) >> (64 - 8));
+}
+
+// Steps a xoroshiro128 state vector and yields a single uint32_t value for use.
+// See iree_prng_xoroshiro128plus_next_uint60 for details.
+static inline uint32_t iree_prng_xoroshiro128plus_next_uint32(
+ iree_prng_xoroshiro128_state_t* state) {
+ return (uint32_t)(iree_prng_xoroshiro128plus_next_uint60(state) >> (64 - 32));
+}
+
+// Steps a xoroshiro128 state vector and yields a value for use.
+// xoroshiro128** variant: produces a single value with 64 bits of entropy.
+// Prefer this to xoroshiro128+ when good distribution over the integer range
+// is required; see xoroshiro128+ for details of its issues.
+//
+// footprint: 128-bits
+// period: 2^128 - 1
+// ns/64-bits: 0.93
+// cycles/byte: 0.42
+//
+// http://prng.di.unimi.it/xoroshiro128starstar.c
+static inline uint64_t iree_prng_xoroshiro128starstar_next_uint64(
+ iree_prng_xoroshiro128_state_t* state) {
+ uint64_t s0 = state->value[0];
+ uint64_t s1 = state->value[1];
+ const uint64_t result = iree_math_rotl_u64(s0 * 5, 7) * 9;
+ s1 ^= s0;
+ state->value[0] = iree_math_rotl_u64(s0, 24) ^ s1 ^ (s1 << 16); // a, b
+ state->value[1] = iree_math_rotl_u64(s1, 37); // c
+ return result;
+}
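+
+// Example usage (illustrative only; the [0, 1) conversion shown is the
+// standard 53-bit mantissa trick and is not part of this header):
+//   iree_prng_xoroshiro128_state_t prng;
+//   iree_prng_xoroshiro128_initialize(/*seed=*/42ull, &prng);
+//   uint64_t bits = iree_prng_xoroshiro128plus_next_uint60(&prng);
+//   double d = (bits >> 11) * (1.0 / 9007199254740992.0);  // uniform in [0, 1)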
#endif // IREE_BASE_MATH_H_
diff --git a/iree/base/math_test.cc b/iree/base/math_test.cc
new file mode 100644
index 0000000..7dad343
--- /dev/null
+++ b/iree/base/math_test.cc
@@ -0,0 +1,226 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iree/base/math.h"
+
+#include "iree/testing/gtest.h"
+#include "iree/testing/status_matchers.h"
+
+namespace {
+
+//==============================================================================
+// Bitwise rotation (aka circular shifts)
+//==============================================================================
+
+TEST(BitwiseRotationTest, ROTL64) {
+ EXPECT_EQ(0ull, iree_math_rotl_u64(0ull, 0u));
+ EXPECT_EQ(0ull, iree_math_rotl_u64(0ull, 0u));
+ EXPECT_EQ(1ull, iree_math_rotl_u64(1ull, 0u));
+ EXPECT_EQ(1ull, iree_math_rotl_u64(1ull, 0u));
+
+ EXPECT_EQ(2ull, iree_math_rotl_u64(1ull, 1u));
+ EXPECT_EQ(2ull, iree_math_rotl_u64(1ull, 1u));
+ EXPECT_EQ(UINT64_MAX, iree_math_rotl_u64(UINT64_MAX, 63u));
+ EXPECT_EQ(UINT64_MAX, iree_math_rotl_u64(UINT64_MAX, 64u));
+}
+
+TEST(BitwiseRotationTest, ROTR64) {
+ EXPECT_EQ(0ull, iree_math_rotr_u64(0ull, 0u));
+ EXPECT_EQ(0ull, iree_math_rotr_u64(0ull, 0u));
+ EXPECT_EQ(1ull, iree_math_rotr_u64(1ull, 0u));
+ EXPECT_EQ(1ull, iree_math_rotr_u64(1ull, 0u));
+
+ EXPECT_EQ(1ull, iree_math_rotr_u64(2ull, 1u));
+ EXPECT_EQ(0x8000000000000000ull, iree_math_rotr_u64(2ull, 2u));
+ EXPECT_EQ(0x8000000000000000ull, iree_math_rotr_u64(1ull, 1u));
+ EXPECT_EQ(0x4000000000000000ull, iree_math_rotr_u64(1ull, 2u));
+}
+
+//==============================================================================
+// Bit scanning/counting
+//==============================================================================
+
+TEST(BitwiseScansTest, CLZ32) {
+ EXPECT_EQ(32, iree_math_count_leading_zeros_u32(uint32_t{}));
+ EXPECT_EQ(0, iree_math_count_leading_zeros_u32(~uint32_t{}));
+ for (int index = 0; index < 32; index++) {
+ uint32_t x = 1u << index;
+ const int cnt = 31 - index;
+ ASSERT_EQ(cnt, iree_math_count_leading_zeros_u32(x)) << index;
+ ASSERT_EQ(cnt, iree_math_count_leading_zeros_u32(x + x - 1)) << index;
+ }
+}
+
+TEST(BitwiseScansTest, CLZ64) {
+ EXPECT_EQ(64, iree_math_count_leading_zeros_u64(uint64_t{}));
+ EXPECT_EQ(0, iree_math_count_leading_zeros_u64(~uint64_t{}));
+ for (int index = 0; index < 64; index++) {
+ uint64_t x = 1ull << index;
+ const int cnt = 63 - index;
+ ASSERT_EQ(cnt, iree_math_count_leading_zeros_u64(x)) << index;
+ ASSERT_EQ(cnt, iree_math_count_leading_zeros_u64(x + x - 1)) << index;
+ }
+}
+
+TEST(BitwiseScansTest, CTZ32) {
+ EXPECT_EQ(0, iree_math_count_trailing_zeros_u32(~uint32_t{}));
+ for (int index = 0; index < 32; index++) {
+ uint32_t x = static_cast<uint32_t>(1) << index;
+ const int cnt = index;
+ ASSERT_EQ(cnt, iree_math_count_trailing_zeros_u32(x)) << index;
+ ASSERT_EQ(cnt, iree_math_count_trailing_zeros_u32(~(x - 1))) << index;
+ }
+}
+
+TEST(BitwiseScansTest, CTZ64) {
+ EXPECT_EQ(0, iree_math_count_trailing_zeros_u64(~uint64_t{}));
+ for (int index = 0; index < 64; index++) {
+ uint64_t x = static_cast<uint64_t>(1) << index;
+ const int cnt = index;
+ ASSERT_EQ(cnt, iree_math_count_trailing_zeros_u64(x)) << index;
+ ASSERT_EQ(cnt, iree_math_count_trailing_zeros_u64(~(x - 1))) << index;
+ }
+}
+
+//==============================================================================
+// Population count
+//==============================================================================
+
+TEST(PopulationCountTest, Ones32) {
+ EXPECT_EQ(0, iree_math_count_ones_u32(0u));
+ EXPECT_EQ(1, iree_math_count_ones_u32(1u));
+ EXPECT_EQ(29, iree_math_count_ones_u32(-15u));
+ EXPECT_EQ(5, iree_math_count_ones_u32(341u));
+ EXPECT_EQ(32, iree_math_count_ones_u32(UINT32_MAX));
+ EXPECT_EQ(31, iree_math_count_ones_u32(UINT32_MAX - 1));
+}
+
+TEST(PopulationCountTest, Ones64) {
+ EXPECT_EQ(0, iree_math_count_ones_u64(0ull));
+ EXPECT_EQ(1, iree_math_count_ones_u64(1ull));
+ EXPECT_EQ(61, iree_math_count_ones_u64(-15ull));
+ EXPECT_EQ(5, iree_math_count_ones_u64(341ull));
+ EXPECT_EQ(64, iree_math_count_ones_u64(UINT64_MAX));
+ EXPECT_EQ(63, iree_math_count_ones_u64(UINT64_MAX - 1ull));
+}
+
+//==============================================================================
+// Rounding and alignment
+//==============================================================================
+
+TEST(RoundingTest, UpToNextPow232) {
+ constexpr uint32_t kUint16Max = UINT16_MAX;
+ constexpr uint32_t kUint32Max = UINT32_MAX;
+ EXPECT_EQ(0u, iree_math_round_up_to_pow2_u32(0u));
+ EXPECT_EQ(1u, iree_math_round_up_to_pow2_u32(1u));
+ EXPECT_EQ(2u, iree_math_round_up_to_pow2_u32(2u));
+ EXPECT_EQ(4u, iree_math_round_up_to_pow2_u32(3u));
+ EXPECT_EQ(8u, iree_math_round_up_to_pow2_u32(8u));
+ EXPECT_EQ(16u, iree_math_round_up_to_pow2_u32(9u));
+ EXPECT_EQ(kUint16Max + 1u, iree_math_round_up_to_pow2_u32(kUint16Max - 1u));
+ EXPECT_EQ(kUint16Max + 1u, iree_math_round_up_to_pow2_u32(kUint16Max));
+ EXPECT_EQ(kUint16Max + 1u, iree_math_round_up_to_pow2_u32(kUint16Max + 1u));
+ EXPECT_EQ(131072u, iree_math_round_up_to_pow2_u32(kUint16Max + 2u));
+ EXPECT_EQ(262144u, iree_math_round_up_to_pow2_u32(262144u - 1u));
+ EXPECT_EQ(0x80000000u, iree_math_round_up_to_pow2_u32(0x7FFFFFFFu));
+ EXPECT_EQ(0x80000000u, iree_math_round_up_to_pow2_u32(0x80000000u));
+
+ // NOTE: wrap to 0.
+ EXPECT_EQ(0u, iree_math_round_up_to_pow2_u32(0x80000001u));
+ EXPECT_EQ(0u, iree_math_round_up_to_pow2_u32(kUint32Max - 1u));
+ EXPECT_EQ(0u, iree_math_round_up_to_pow2_u32(kUint32Max));
+}
+
+TEST(RoundingTest, UpToNextPow264) {
+ constexpr uint64_t kUint16Max = UINT16_MAX;
+ constexpr uint64_t kUint64Max = UINT64_MAX;
+ EXPECT_EQ(0ull, iree_math_round_up_to_pow2_u64(0ull));
+ EXPECT_EQ(1ull, iree_math_round_up_to_pow2_u64(1ull));
+ EXPECT_EQ(2ull, iree_math_round_up_to_pow2_u64(2ull));
+ EXPECT_EQ(4ull, iree_math_round_up_to_pow2_u64(3ull));
+ EXPECT_EQ(8ull, iree_math_round_up_to_pow2_u64(8ull));
+ EXPECT_EQ(16ull, iree_math_round_up_to_pow2_u64(9ull));
+ EXPECT_EQ(kUint16Max + 1ull,
+ iree_math_round_up_to_pow2_u64(kUint16Max - 1ull));
+ EXPECT_EQ(kUint16Max + 1ull, iree_math_round_up_to_pow2_u64(kUint16Max));
+ EXPECT_EQ(kUint16Max + 1ull,
+ iree_math_round_up_to_pow2_u64(kUint16Max + 1ull));
+ EXPECT_EQ(131072ull, iree_math_round_up_to_pow2_u64(kUint16Max + 2ull));
+ EXPECT_EQ(0x100000000ull, iree_math_round_up_to_pow2_u64(0xFFFFFFFEull));
+ EXPECT_EQ(0x100000000ull, iree_math_round_up_to_pow2_u64(0xFFFFFFFFull));
+ EXPECT_EQ(0x80000000ull, iree_math_round_up_to_pow2_u64(0x7FFFFFFFull));
+ EXPECT_EQ(0x80000000ull, iree_math_round_up_to_pow2_u64(0x80000000ull));
+ EXPECT_EQ(0x100000000ull, iree_math_round_up_to_pow2_u64(0x80000001ull));
+
+ // NOTE: wrap to 0.
+ EXPECT_EQ(0ull, iree_math_round_up_to_pow2_u64(0x8000000000000001ull));
+ EXPECT_EQ(0ull, iree_math_round_up_to_pow2_u64(kUint64Max - 1ull));
+ EXPECT_EQ(0ull, iree_math_round_up_to_pow2_u64(kUint64Max));
+}
+
+//==============================================================================
+// Pseudo-random number generators (PRNGs): **NOT CRYPTOGRAPHICALLY SECURE**
+//==============================================================================
+// NOTE: we leave the real testing to the authors; this just ensures we aren't
+// `return 4;`ing it or ignoring the seed.
+
+TEST(PRNG, SplitMix64) {
+ iree_prng_splitmix64_state_t state;
+
+ iree_prng_splitmix64_initialize(/*seed=*/0ull, &state);
+ EXPECT_EQ(16294208416658607535ull, iree_prng_splitmix64_next(&state));
+ EXPECT_EQ(7960286522194355700ull, iree_prng_splitmix64_next(&state));
+
+ iree_prng_splitmix64_initialize(/*seed=*/1ull, &state);
+ EXPECT_EQ(10451216379200822465ull, iree_prng_splitmix64_next(&state));
+ EXPECT_EQ(13757245211066428519ull, iree_prng_splitmix64_next(&state));
+
+ iree_prng_splitmix64_initialize(/*seed=*/UINT64_MAX, &state);
+ EXPECT_EQ(16490336266968443936ull, iree_prng_splitmix64_next(&state));
+ EXPECT_EQ(16834447057089888969ull, iree_prng_splitmix64_next(&state));
+}
+
+TEST(PRNG, Xoroshiro128) {
+ iree_prng_xoroshiro128_state_t state;
+
+ iree_prng_xoroshiro128_initialize(/*seed=*/0ull, &state);
+ EXPECT_EQ(5807750865143411619ull,
+ iree_prng_xoroshiro128plus_next_uint60(&state));
+ EXPECT_TRUE(iree_prng_xoroshiro128plus_next_bool(&state));
+ EXPECT_EQ(218u, iree_prng_xoroshiro128plus_next_uint8(&state));
+ EXPECT_EQ(1647201753u, iree_prng_xoroshiro128plus_next_uint32(&state));
+ EXPECT_EQ(7260361800523965311ull,
+ iree_prng_xoroshiro128starstar_next_uint64(&state));
+
+ iree_prng_xoroshiro128_initialize(/*seed=*/1ull, &state);
+ EXPECT_EQ(5761717516557699368ull,
+ iree_prng_xoroshiro128plus_next_uint60(&state));
+ EXPECT_TRUE(iree_prng_xoroshiro128plus_next_bool(&state));
+ EXPECT_EQ(103u, iree_prng_xoroshiro128plus_next_uint8(&state));
+ EXPECT_EQ(2242241045u, iree_prng_xoroshiro128plus_next_uint32(&state));
+ EXPECT_EQ(661144386810419178ull,
+ iree_prng_xoroshiro128starstar_next_uint64(&state));
+
+ iree_prng_xoroshiro128_initialize(/*seed=*/UINT64_MAX, &state);
+ EXPECT_EQ(14878039250348781289ull,
+ iree_prng_xoroshiro128plus_next_uint60(&state));
+ EXPECT_FALSE(iree_prng_xoroshiro128plus_next_bool(&state));
+ EXPECT_EQ(137u, iree_prng_xoroshiro128plus_next_uint8(&state));
+ EXPECT_EQ(2111322015u, iree_prng_xoroshiro128plus_next_uint32(&state));
+ EXPECT_EQ(138107609852220106ull,
+ iree_prng_xoroshiro128starstar_next_uint64(&state));
+}
+
+} // namespace
diff --git a/iree/compiler/Conversion/CodegenUtils/FunctionUtils.cpp b/iree/compiler/Conversion/CodegenUtils/FunctionUtils.cpp
index 7025221..5930e1f 100644
--- a/iree/compiler/Conversion/CodegenUtils/FunctionUtils.cpp
+++ b/iree/compiler/Conversion/CodegenUtils/FunctionUtils.cpp
@@ -14,6 +14,7 @@
#include "iree/compiler/Conversion/CodegenUtils/FunctionUtils.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/IR/SymbolTable.h"
namespace mlir {
@@ -24,5 +25,14 @@
SymbolTable::Visibility::Public;
}
+unsigned getNumOuterParallelLoops(linalg::LinalgOp op) {
+ return op.iterator_types()
+ .getValue()
+ .take_while([](Attribute attr) -> bool {
+ return linalg::isParallelIteratorType(attr);
+ })
+ .size();
+}
+
} // namespace iree_compiler
} // namespace mlir
diff --git a/iree/compiler/Conversion/CodegenUtils/FunctionUtils.h b/iree/compiler/Conversion/CodegenUtils/FunctionUtils.h
index 66c85bd..b305cc7 100644
--- a/iree/compiler/Conversion/CodegenUtils/FunctionUtils.h
+++ b/iree/compiler/Conversion/CodegenUtils/FunctionUtils.h
@@ -15,6 +15,7 @@
#ifndef IREE_COMPILER_CONVERSION_CODEGENUTILS_FUNCTIONUTILS_H_
#define IREE_COMPILER_CONVERSION_CODEGENUTILS_FUNCTIONUTILS_H_
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/IR/Function.h"
namespace mlir {
@@ -33,6 +34,9 @@
return "operand_result_index";
}
+/// Returns the number of outer parallel loops of a linalgOp.
+unsigned getNumOuterParallelLoops(linalg::LinalgOp op);
+
} // namespace iree_compiler
} // namespace mlir
diff --git a/iree/compiler/Conversion/CodegenUtils/GetNumWorkgroups.cpp b/iree/compiler/Conversion/CodegenUtils/GetNumWorkgroups.cpp
index 452b662..be9760b 100644
--- a/iree/compiler/Conversion/CodegenUtils/GetNumWorkgroups.cpp
+++ b/iree/compiler/Conversion/CodegenUtils/GetNumWorkgroups.cpp
@@ -14,18 +14,24 @@
#include "iree/compiler/Conversion/CodegenUtils/GetNumWorkgroups.h"
+#include "iree/compiler/Conversion/CodegenUtils/FunctionUtils.h"
+#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
+#include "iree/compiler/Dialect/HAL/Utils/TypeUtils.h"
#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/Debug.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/Matchers.h"
#include "mlir/IR/Module.h"
#define DEBUG_TYPE "workgroup-calculation"
namespace mlir {
namespace iree_compiler {
+
FuncOp getNumWorkgroupsFn(FuncOp entryPointFn,
llvm::StringRef numWorkgroupsFnAttr) {
SymbolRefAttr attr =
@@ -43,9 +49,23 @@
return numWorkgroupsFn;
}
-/// Computes the bounds of the parallel loops partitioned across workgroups.
-static Optional<SmallVector<Value, 2>> getParallelLoopRange(
- PatternRewriter &rewriter, FuncOp numWorkgroupsFn, Location loc,
+// TODO: This method is templated on the builder type since the `OpBuilder`
+// doesn't have an erase method. Just erasing the op leads to segfaults when the
+// builder is `PatternRewriter` since the rewriter doesn't know the op was
+// deleted. This can be simplified a lot when this issue is fixed.
+template <typename BuilderTy>
+static void eraseOp(BuilderTy &builder, Operation *op) {
+ builder.eraseOp(op);
+}
+template <>
+void eraseOp(OpBuilder &builder, Operation *op) {
+ op->erase();
+}
+
+/// Computes the bounds of the loops of the `linalgOp`.
+template <typename BuilderTy>
+static Optional<SmallVector<Value, 4>> getLoopUpperBounds(
+ BuilderTy &builder, Location loc, FuncOp numWorkgroupsFn,
linalg::LinalgOp linalgOp) {
if (!numWorkgroupsFn.empty()) {
numWorkgroupsFn.emitError("num workgroups fn expected to be empty");
@@ -56,79 +76,105 @@
<< numWorkgroupsFn.getName();
});
- rewriter.setInsertionPointToEnd(numWorkgroupsFn.addEntryBlock());
+ builder.createBlock(&numWorkgroupsFn.getBody(), /*insertPt=*/{},
+ numWorkgroupsFn.getType().getInputs());
llvm::SetVector<Operation *> slice;
getBackwardSlice(linalgOp, &slice);
BlockAndValueMapping mapper;
for (Operation *op : slice) {
- rewriter.clone(*op, mapper);
+ builder.clone(*op, mapper);
}
// Clone the linalg operation just to compute the loop bounds.
linalg::LinalgOp clonedLinalgOp =
- rewriter.clone(*linalgOp.getOperation(), mapper);
- SmallVector<Range, 4> ranges = clonedLinalgOp.createLoopRanges(rewriter, loc);
- SmallVector<Value, 4> bounds;
- bounds.reserve(ranges.size());
- for (Range r : ranges) bounds.push_back(r.size);
- unsigned numParallelLoops = linalgOp.iterator_types()
- .getValue()
- .take_while([](Attribute attr) -> bool {
- return attr.cast<StringAttr>().getValue() ==
- getParallelIteratorTypeName();
- })
- .size();
- SmallVector<Value, 2> returnVals(bounds.begin(),
- std::next(bounds.begin(), numParallelLoops));
- rewriter.eraseOp(clonedLinalgOp);
- return returnVals;
+ builder.clone(*linalgOp.getOperation(), mapper);
+ auto loopRange = clonedLinalgOp.createLoopRanges(builder, loc);
+ if (llvm::any_of(loopRange, [](Range range) {
+ return !matchPattern(range.stride, m_One()) ||
+ !matchPattern(range.offset, m_Zero());
+ })) {
+ linalgOp.emitError("unhandled non-unit stride loop range");
+ return llvm::None;
+ }
+ SmallVector<Value, 4> bounds = llvm::to_vector<4>(
+ llvm::map_range(loopRange, [](Range range) { return range.size; }));
+ eraseOp<BuilderTy>(builder, clonedLinalgOp);
+ return bounds;
}
/// Utility method to build IR that computes ceil(`numerator` / `denominator`)
-static Value buildCeilDiv(PatternRewriter &rewriter, Location loc,
- Value numerator, Value denominator) {
- Value one = rewriter.create<ConstantIndexOp>(loc, 1);
- Value t = rewriter.create<AddIOp>(
- loc, numerator, rewriter.create<SubIOp>(loc, denominator, one));
- return rewriter.create<SignedDivIOp>(loc, t, denominator);
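+/// as (`numerator` + `denominator` - 1) / `denominator`; e.g. ceil(10 / 3) is
+/// computed as (10 + 2) / 3 = 4.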
+static Value buildCeilDiv(OpBuilder &builder, Location loc, Value numerator,
+ Value denominator) {
+ Value one = builder.create<ConstantIndexOp>(loc, 1);
+ Value t = builder.create<AddIOp>(
+ loc, numerator, builder.create<SubIOp>(loc, denominator, one));
+ return builder.create<SignedDivIOp>(loc, t, denominator);
}
/// Utility method to build IR that computes ceil(`numerator` / `denominator`)
/// when denominator is a constant.
-static Value buildCeilDivConstDenominator(PatternRewriter &rewriter,
- Location loc, Value numerator,
- int64_t denominator) {
- return buildCeilDiv(rewriter, loc, numerator,
- rewriter.create<ConstantIndexOp>(loc, denominator));
+static Value buildCeilDiv(OpBuilder &builder, Location loc, Value numerator,
+ int64_t denominator) {
+ return buildCeilDiv(
+ builder, loc, numerator,
+ builder.create<ConstantIndexOp>(loc, denominator).getResult());
}
-LogicalResult createNumWorkgroupsFromResultShape(
- PatternRewriter &rewriter, linalg::LinalgOp linalgOp, FuncOp entryPointFn,
- llvm::StringRef numWorkgroupsFnAttr, ArrayRef<int64_t> tileSizes) {
+template <class BuilderTy>
+static LogicalResult createNumWorkgroupsFromResultShapeImpl(
+ BuilderTy &builder, linalg::LinalgOp linalgOp, FuncOp entryPointFn,
+ llvm::StringRef numWorkgroupsFnAttr, ArrayRef<int64_t> tileSizes,
+ ArrayRef<unsigned> distributedLoops) {
FuncOp numWorkgroupsFn = getNumWorkgroupsFn(
linalgOp.getParentOfType<FuncOp>(), numWorkgroupsFnAttr);
if (!numWorkgroupsFn) return failure();
Location loc = linalgOp.getLoc();
- OpBuilder::InsertionGuard guard(rewriter);
- Optional<SmallVector<Value, 2>> parallelLoopRange =
- getParallelLoopRange(rewriter, numWorkgroupsFn, loc, linalgOp);
- if (!parallelLoopRange) return failure();
- Value one = rewriter.create<ConstantIndexOp>(loc, 1);
- SmallVector<Value, 3> returnValues(3, one);
- for (size_t i = 0, e = std::min<size_t>(parallelLoopRange->size(), 3); i != e;
- ++i) {
- if (tileSizes[e - i - 1] != 0) {
- returnValues[i] = buildCeilDivConstDenominator(
- rewriter, loc, (*parallelLoopRange)[e - i - 1], tileSizes[e - i - 1]);
+ OpBuilder::InsertionGuard guard(builder);
+ auto loopRange = getLoopUpperBounds(builder, loc, numWorkgroupsFn, linalgOp);
+ if (!loopRange) return failure();
+
+ SmallVector<Value, 4> numWorkgroups;
+ DenseSet<unsigned> distributedLoopsSet(distributedLoops.begin(),
+ distributedLoops.end());
+ for (auto size : enumerate(tileSizes)) {
+ if (size.value() && distributedLoopsSet.count(size.index())) {
+ Value num =
+ buildCeilDiv(builder, loc, (*loopRange)[size.index()], size.value());
+ numWorkgroups.push_back(num);
}
}
- rewriter.create<mlir::ReturnOp>(loc, returnValues);
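+ // Reverse the per-loop counts so the innermost distributed loop comes first,
+ // then pad the remaining (up to 3) dimensions with 1.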
+ SmallVector<Value, 4> resultValues =
+ llvm::to_vector<4>(llvm::reverse(numWorkgroups));
+ Value one = builder.template create<ConstantIndexOp>(loc, 1);
+ resultValues.resize(3, one);
+ builder.template create<mlir::ReturnOp>(loc, resultValues);
return success();
}
-LogicalResult createNumWorkgroupsFromLinearizedResultShape(
+LogicalResult createNumWorkgroupsFromResultShape(
+ OpBuilder &builder, linalg::LinalgOp linalgOp, FuncOp entryPointFn,
+ llvm::StringRef numWorkgroupsFnAttr, ArrayRef<int64_t> tileSizes,
+ ArrayRef<unsigned> distributedLoops) {
+ return createNumWorkgroupsFromResultShapeImpl<OpBuilder>(
+ builder, linalgOp, entryPointFn, numWorkgroupsFnAttr, tileSizes,
+ distributedLoops);
+}
+
+LogicalResult createNumWorkgroupsFromResultShape(
PatternRewriter &rewriter, linalg::LinalgOp linalgOp, FuncOp entryPointFn,
- llvm::StringRef numWorkgroupsFnAttr, int64_t workgroupSizeX) {
+ llvm::StringRef numWorkgroupsFnAttr, ArrayRef<int64_t> tileSizes) {
+ SmallVector<unsigned, 4> distributedLoops =
+ llvm::to_vector<4>(llvm::seq<unsigned>(
+ 0, std::min<unsigned>(3, getNumOuterParallelLoops(linalgOp))));
+ return createNumWorkgroupsFromResultShapeImpl<PatternRewriter>(
+ rewriter, linalgOp, entryPointFn, numWorkgroupsFnAttr, tileSizes,
+ distributedLoops);
+}
+
+LogicalResult createNumWorkgroupsFromLinearizedResultShape(
+ ConversionPatternRewriter &rewriter, linalg::LinalgOp linalgOp,
+ FuncOp entryPointFn, llvm::StringRef numWorkgroupsFnAttr,
+ int64_t workgroupSizeX) {
FuncOp numWorkgroupsFn = getNumWorkgroupsFn(
linalgOp.getParentOfType<FuncOp>(), numWorkgroupsFnAttr);
if (!numWorkgroupsFn) return failure();
@@ -144,16 +190,17 @@
Location loc = linalgOp.getLoc();
OpBuilder::InsertionGuard guard(rewriter);
- Optional<SmallVector<Value, 2>> parallelLoopRange =
- getParallelLoopRange(rewriter, numWorkgroupsFn, loc, linalgOp);
- if (!parallelLoopRange) return failure();
+ Optional<SmallVector<Value, 4>> loopRange =
+ getLoopUpperBounds(rewriter, loc, numWorkgroupsFn, linalgOp);
+ if (!loopRange) return failure();
+ unsigned numParallelLoops = getNumOuterParallelLoops(linalgOp);
Value one = rewriter.create<ConstantIndexOp>(loc, 1);
SmallVector<Value, 3> returnValues(3, one);
- for (auto range : *parallelLoopRange) {
+ for (auto range : ArrayRef<Value>(*loopRange).take_front(numParallelLoops)) {
returnValues[0] = rewriter.create<MulIOp>(loc, range, returnValues[0]);
}
- returnValues[0] = buildCeilDivConstDenominator(rewriter, loc, returnValues[0],
- workgroupSizeX);
+ returnValues[0] =
+ buildCeilDiv(rewriter, loc, returnValues[0], workgroupSizeX);
rewriter.create<mlir::ReturnOp>(loc, returnValues);
return success();
}
diff --git a/iree/compiler/Conversion/CodegenUtils/GetNumWorkgroups.h b/iree/compiler/Conversion/CodegenUtils/GetNumWorkgroups.h
index 4e28779..ec40ffe 100644
--- a/iree/compiler/Conversion/CodegenUtils/GetNumWorkgroups.h
+++ b/iree/compiler/Conversion/CodegenUtils/GetNumWorkgroups.h
@@ -32,33 +32,43 @@
namespace iree_compiler {
/// Generates a function that computes the number of workgroups as
-/// [ceil(`parallelLoopRange`[2] / `tileSizes`[2]),
-/// ceil(`parallelLoopRange`[1] / `tileSizes`[1]),
-/// ceil(`parallelLoopRange`[0] / `tileSizes`[0])]
-/// where `parallelLoopRange` is the ranges of the parallel loops of `linalgOp`
-/// distributed across workgroups.
+/// [ceil(`loopUpperBounds`[2] / `tileSizes`[2]),
+/// ceil(`loopUpperBounds`[1] / `tileSizes`[1]),
+/// ceil(`loopUpperBounds`[0] / `tileSizes`[0])]
+/// where `loopUpperBounds` are the ranges of the parallel loops of `linalgOp`
+/// distributed across workgroups. `distributedLoops` are the loop dimensions
+/// that are distributed.
+LogicalResult createNumWorkgroupsFromResultShape(
+ OpBuilder &builder, linalg::LinalgOp linalgOp, FuncOp entryPointFn,
+ llvm::StringRef numWorkgroupsFnAttr, llvm::ArrayRef<int64_t> tileSizes,
+ llvm::ArrayRef<unsigned> distributedLoops);
+
+/// Generates a function that computes the number of workgroups as
+/// [ceil(`loopUpperBounds`[2] / `tileSizes`[2]),
+/// ceil(`loopUpperBounds`[1] / `tileSizes`[1]),
+/// ceil(`loopUpperBounds`[0] / `tileSizes`[0])]
+/// where `loopUpperBounds` are the ranges of the parallel loops of `linalgOp`
+/// distributed across workgroups. Assumes that up to 3 outer parallel loops of
+/// the `linalgOp` are distributed.
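+/// For example (illustrative): with two outer parallel loops of extents
+/// [100, 60] and `tileSizes` [32, 8], the generated function would return
+/// [ceil(60 / 8), ceil(100 / 32), 1] = [8, 4, 1].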
LogicalResult createNumWorkgroupsFromResultShape(
PatternRewriter &rewriter, linalg::LinalgOp linalgOp, FuncOp entryPointFn,
llvm::StringRef numWorkgroupsFnAttr, llvm::ArrayRef<int64_t> tileSizes);
/// Generates a function that computes the number of workgroups as
-/// ceil(`parallelLoopRange`[0] * `parallelLoopRange`[1] * ... *
-/// `parallelLoopRange`[n-1] / `workgroupSizeX`)
-/// where `parallelLoopRange` is the ranges of the parallel loops of `linalgOp`
+/// ceil(`loopUpperBounds`[0] * `loopUpperBounds`[1] * ... *
+/// `loopUpperBounds`[n-1] / `workgroupSizeX`)
+/// where `loopUpperBounds` are the ranges of the parallel loops of `linalgOp`
/// distributed across workgroups.
LogicalResult createNumWorkgroupsFromLinearizedResultShape(
- PatternRewriter &rewriter, linalg::LinalgOp linalgOp, FuncOp entryPointFn,
- llvm::StringRef numWorkgroupsFnAttr, int64_t workgroupSizeX);
+ ConversionPatternRewriter &rewriter, linalg::LinalgOp linalgOp,
+ FuncOp entryPointFn, llvm::StringRef numWorkgroupsFnAttr,
+ int64_t workgroupSizeX);
/// For a given `entryPointFn` return the function that computes the number of
/// workgroups to use at launch time.
FuncOp getNumWorkgroupsFn(FuncOp entryPointFn,
llvm::StringRef numWorkgroupsFnAttr);
-LogicalResult createNumWorkgroupsFromLinearizedResultShape(
- PatternRewriter &rewriter, linalg::LinalgOp linalgOp, FuncOp entryPointFn,
- llvm::StringRef numWorkgroupsFnAttr, int64_t workgroupSizeX);
-
/// The codegeneration emits a function `numWorkgroupsFn` for each entry point
/// function. This function has arguments the !shapex.ranked_shape for all the
/// input and output shaped types. Using this the function returns the number of
diff --git a/iree/compiler/Conversion/CodegenUtils/MarkerUtils.cpp b/iree/compiler/Conversion/CodegenUtils/MarkerUtils.cpp
index ab2a9c1..a7af32d 100644
--- a/iree/compiler/Conversion/CodegenUtils/MarkerUtils.cpp
+++ b/iree/compiler/Conversion/CodegenUtils/MarkerUtils.cpp
@@ -47,6 +47,12 @@
return "copy_to_workgroup_memory";
}
+// This marker is needed because we tile a convolution op multiple times: 1) to
+// workgroups, 2) to invocations, and 3) along the filter's height/width and
+// input channels to generate loops for a single GPU invocation. This marker
+// identifies the third step.
+StringRef getConvFilterTileMarker() { return "tile_conv_filter"; }
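+// (For illustration: the marker is attached with setMarker(op,
+// getConvFilterTileMarker()) and later tiling patterns only match ops that
+// carry it.)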
+
StringRef getVectorizeMarker() { return "vectorize"; }
StringRef getDeleteMarker() { return "delete"; }
diff --git a/iree/compiler/Conversion/CodegenUtils/MarkerUtils.h b/iree/compiler/Conversion/CodegenUtils/MarkerUtils.h
index 5072293..5371aba 100644
--- a/iree/compiler/Conversion/CodegenUtils/MarkerUtils.h
+++ b/iree/compiler/Conversion/CodegenUtils/MarkerUtils.h
@@ -45,6 +45,9 @@
/// Workgroup memory.
StringRef getCopyToWorkgroupMemoryMarker();
+/// Marker for tiling along convolution filter dimensions.
+StringRef getConvFilterTileMarker();
+
/// Marker for operations that are going to be vectorized.
StringRef getVectorizeMarker();
diff --git a/iree/compiler/Conversion/Common/BUILD b/iree/compiler/Conversion/Common/BUILD
index b8c5b5a..5827238 100644
--- a/iree/compiler/Conversion/Common/BUILD
+++ b/iree/compiler/Conversion/Common/BUILD
@@ -22,22 +22,33 @@
name = "Common",
srcs = [
"DeclareNumWorkgroupsFnPass.cpp",
+ "LaunchConfig.cpp",
"LegalizeNumWorkgroupsFnPass.cpp",
+ "Transforms.cpp",
+ "VectorTransferOptimization.cpp",
],
hdrs = [
"Attributes.h",
+ "LaunchConfig.h",
"Passes.h",
+ "Transforms.h",
],
deps = [
"//iree/compiler/Conversion/CodegenUtils",
"//iree/compiler/Dialect/HAL/IR",
"//iree/compiler/Dialect/IREE/IR",
"//iree/compiler/Dialect/Shape/IR",
+ "@llvm-project//llvm:Support",
"@llvm-project//mlir:CFGTransforms",
+ "@llvm-project//mlir:GPUDialect",
"@llvm-project//mlir:IR",
+ "@llvm-project//mlir:LinalgOps",
+ "@llvm-project//mlir:LinalgTransforms",
"@llvm-project//mlir:Pass",
+ "@llvm-project//mlir:SPIRVLowering",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Transforms",
+ "@llvm-project//mlir:VectorOps",
"@org_tensorflow//tensorflow/compiler/mlir/hlo",
],
)
diff --git a/iree/compiler/Conversion/Common/CMakeLists.txt b/iree/compiler/Conversion/Common/CMakeLists.txt
index b233805..955f010 100644
--- a/iree/compiler/Conversion/Common/CMakeLists.txt
+++ b/iree/compiler/Conversion/Common/CMakeLists.txt
@@ -19,16 +19,28 @@
Common
HDRS
"Attributes.h"
+ "LaunchConfig.h"
"Passes.h"
+ "Transforms.h"
SRCS
"DeclareNumWorkgroupsFnPass.cpp"
+ "LaunchConfig.cpp"
"LegalizeNumWorkgroupsFnPass.cpp"
+ "Transforms.cpp"
+ "VectorTransferOptimization.cpp"
DEPS
+ LLVMSupport
+ MLIRGPU
MLIRIR
+ MLIRLinalg
+ MLIRLinalgTransforms
MLIRPass
MLIRSCFToStandard
+ MLIRSPIRV
+ MLIRSPIRVTransforms
MLIRStandard
MLIRTransforms
+ MLIRVector
iree::compiler::Conversion::CodegenUtils
iree::compiler::Dialect::HAL::IR
iree::compiler::Dialect::IREE::IR
diff --git a/iree/compiler/Conversion/Common/LaunchConfig.cpp b/iree/compiler/Conversion/Common/LaunchConfig.cpp
new file mode 100644
index 0000000..20c5863
--- /dev/null
+++ b/iree/compiler/Conversion/Common/LaunchConfig.cpp
@@ -0,0 +1,169 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//===- LaunchConfig.cpp - Configuration used to drive arch specific codegen -===//
+//
+// This file defines the data structure that is used by the codegeneration to
+// lower to target specific IR. The values of the parameters are architecture
+// specific. Once set the same transformations can be used to generate the
+// desired code. This allows sharing codegen infra between different backends.
+//
+//===----------------------------------------------------------------------===//
+
+#include "iree/compiler/Conversion/Common/LaunchConfig.h"
+
+#include "iree/compiler/Conversion/CodegenUtils/FunctionUtils.h"
+#include "iree/compiler/Conversion/Common/Attributes.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/SPIRV/TargetAndABI.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/StandardTypes.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+/// Name of the StrAttr that can be used to get the key to access the tile size
+/// information.
+static const char kLaunchInfoKey[] = "launch_info_key";
+
+static Optional<StringRef> getKey(Operation *op) {
+ StringAttr attr = op->getAttrOfType<StringAttr>(kLaunchInfoKey);
+ if (!attr) return {};
+ return attr.getValue();
+}
+
+static void setKey(Operation *op, StringRef key) {
+ MLIRContext *context = op->getContext();
+ op->setAttr(Identifier::get(kLaunchInfoKey, op->getContext()),
+ StringAttr::get(key, context));
+}
+
+static std::string getOrSetNewKey(Operation *op, int64_t suffix) {
+ Optional<StringRef> key = getKey(op);
+ if (key) return key->str();
+ std::string newKey = llvm::formatv("__op_num_{0}__", suffix).str();
+ setKey(op, StringRef(newKey));
+ return newKey;
+}
+
+void LaunchConfig::finalize(FuncOp funcOp) {
+ funcOp.walk([&](linalg::LinalgOp linalgOp) {
+ linalgOp.removeAttr(Identifier::get(kLaunchInfoKey, funcOp.getContext()));
+ });
+}
+
+TileSizesListTypeRef LaunchConfig::getTileSizes(Operation *op) const {
+ auto key = getKey(op);
+ if (!key) return {};
+ auto it = tileSizes.find(*key);
+ if (it == tileSizes.end()) return {};
+ return it->second;
+}
+
+ArrayRef<int64_t> LaunchConfig::getTileSizes(Operation *op,
+ size_t level) const {
+ auto t = getTileSizes(op);
+ if (level >= t.size()) return {};
+ return t[level];
+}
+
+void LaunchConfig::setTileSizes(Operation *op, TileSizesListType vTileSizes) {
+ tileSizes[getOrSetNewKey(op, tileSizes.size())] = vTileSizes;
+}
+
+void LaunchConfig::setTileSizes(Operation *op, ArrayRef<int64_t> vTileSizes,
+ size_t level) {
+ tileSizes[getOrSetNewKey(op, tileSizes.size())].emplace_back(
+ vTileSizes.begin(), vTileSizes.end());
+}
+
+static void setArrayVals(std::array<int64_t, 3> &array,
+ ArrayRef<int64_t> vals) {
+ if (vals.size() > 3) vals = vals.take_front(3);
+ for (auto size : enumerate(vals)) array[size.index()] = size.value();
+ for (unsigned i : llvm::seq<unsigned>(vals.size(), 3)) array[i] = 1;
+}
+
+void LaunchConfig::setWorkgroupSize(ArrayRef<int64_t> vWorkgroupSize) {
+ setArrayVals(workgroupSize, vWorkgroupSize);
+}
+
+void LaunchConfig::setNumSubgroups(ArrayRef<int64_t> vNumSubgroups) {
+ setArrayVals(numSubgroups, vNumSubgroups);
+}
+
+void LaunchConfig::setSameConfig(Operation *source, Operation *target) {
+ assert(getKey(source) && "missing configuration of source operation");
+ setKey(target, *getKey(source));
+}
+
+void LaunchConfig::setVectorize(bool enableVectorize) {
+ vectorize = enableVectorize;
+}
+
+LogicalResult propogateRootOperationLaunchConfig(
+ LaunchConfig &config, linalg::LinalgOp rootOperation,
+ const linalg::LinalgDependenceGraph &dependenceGraph) {
+ // Check the dependencies going into and out of the root operation. For now
+ // only the following dependencies are supported
+ // - WAW dependencies going into the root operation.
+ // - RAW dependencies going out of the root operation.
+ // - WAW dependencies going out of the root operation.
+ // i.e. there are no RAW dependences going into the root operation.
+ auto inRAWDependencies = dependenceGraph.getDependencesInto(
+ rootOperation, linalg::LinalgDependenceGraph::RAW);
+ if (!inRAWDependencies.empty()) {
+ return rootOperation.getOperation()->emitError(
+ "unhandled fusion of root operation with producer");
+ }
+ auto dependences = dependenceGraph.getDependentOperations(rootOperation);
+ unsigned numOuterParallel = getNumOuterParallelLoops(rootOperation);
+
+ // Check that for all dependences into and out of the root operation,
+ // - The result expressions of the indexing maps of the fused view in the
+ // producer and consumer must match for the parallel loops.
+ for (auto dependence :
+ dependenceGraph.getDependentOperations(rootOperation)) {
+ unsigned viewIndex = dependence.indexingOpView.operandIndex;
+ AffineMap indexingMap = rootOperation.getIndexingMap(viewIndex);
+ linalg::LinalgOp fusedOp =
+ cast<linalg::LinalgOp>(dependence.dependentOpView.op);
+ unsigned fusedViewIndex = dependence.dependentOpView.operandIndex;
+ AffineMap fusedIndexingMap = fusedOp.getIndexingMap(fusedViewIndex);
+ if (indexingMap.getNumResults() < numOuterParallel ||
+ fusedIndexingMap.getNumResults() < numOuterParallel ||
+ !llvm::all_of(
+ llvm::seq<unsigned>(0, numOuterParallel),
+ [&indexingMap, fusedIndexingMap](unsigned i) {
+ return indexingMap.getResult(i).isa<AffineDimExpr>() &&
+ fusedIndexingMap.getResult(i).isa<AffineDimExpr>() &&
+ indexingMap.getResult(i) == fusedIndexingMap.getResult(i);
+ })) {
+ return rootOperation.getOperation()->emitError(
+ "unhandled fusion of root operation with all operations in the "
+ "dispatch region");
+ }
+ }
+ // The dependent operations get the same tile size information as the root
+ // operation. To propagate that information, just use the same key as the root
+ // operation.
+ for (auto dependence : dependences) {
+ config.setSameConfig(rootOperation, dependence.dependentOpView.op);
+ }
+ return success();
+}
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/iree/compiler/Conversion/Common/LaunchConfig.h b/iree/compiler/Conversion/Common/LaunchConfig.h
new file mode 100644
index 0000000..dcb3b14
--- /dev/null
+++ b/iree/compiler/Conversion/Common/LaunchConfig.h
@@ -0,0 +1,140 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//===- LaunchConfig.h - Configuration used to drive arch specific codegen -===//
+//
+// This file declares the data structure that is used by the codegeneration to
+// lower to target specific IR. The values of the parameters are architecture
+// specific. Once set the same transformations can be used to generate the
+// desired code. This allows sharing codegen infra between different backends.
+//
+//===----------------------------------------------------------------------===//
+#ifndef IREE_COMPILER_CONVERSION_COMMON_LAUNCHCONFIG_H_
+#define IREE_COMPILER_CONVERSION_COMMON_LAUNCHCONFIG_H_
+#include <array>
+
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringMap.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Types.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+/// Stores the tile sizes to use at different levels of tiling as a vector of
+/// vectors.
+/// - First level tiling maps to workgroups.
+/// - Second level tiling maps to subgroups.
+/// - Third level tiling maps to invocations.
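+/// For example (illustrative): {{32, 32, 16}, {8, 8, 4}, {4, 4, 1}} would tile
+/// a 3-d loop nest with 32x32x16 tiles per workgroup, 8x8x4 per subgroup, and
+/// 4x4x1 per invocation.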
+using TileSizesListType = SmallVector<SmallVector<int64_t, 4>, 1>;
+using TileSizesListTypeRef = ArrayRef<SmallVector<int64_t, 4>>;
+
+/// Configurations for mapping Linalg ops to CPU/GPU parallel hierarchies.
+///
+/// Based on the linalg operations in a dispatch region, the number of levels of
+/// tiling, the tile sizes needed, the workgroup size, etc. need to be
+/// decided. These parameters are called `LaunchConfig`. This class implements
+/// one heuristic to compute these for the different linalg operations on
+/// buffers. This can be adapted later to support multiple configurations that
+/// can be picked based on device information/problem size information. It
+/// exposes the information needed by the codegenerators, and hides the
+/// implementation from the rest of the pipeline.
+class LaunchConfig {
+ public:
+ LaunchConfig() : workgroupSize({1, 1, 1}), numSubgroups({1, 1, 1}) {}
+
+ /// Removes attributes added to operations for retrieving tile size
+ /// information.
+ void finalize(FuncOp funcOp);
+
+ /// Gets the tile size computed for an operation at all levels.
+ TileSizesListTypeRef getTileSizes(Operation *op) const;
+
+ /// Gets the tile sizes computed for an operation at a given level.
+ ArrayRef<int64_t> getTileSizes(Operation *op, size_t level) const;
+
+ /// Returns the workgroup size to use based on the tile sizes.
+ ArrayRef<int64_t> getWorkgroupSize() const { return workgroupSize; }
+
+ /// Returns the number of subgroups to use.
+ ArrayRef<int64_t> getNumSubgroups() const { return numSubgroups; }
+
+ /// Returns true if tile sizes have been computed for the operation. If tile
+ /// sizes aren't set, it implies the operation is not to be tiled.
+ bool hasTileSizes(Operation *op, size_t level = 0) const {
+ return !getTileSizes(op, level).empty();
+ }
+
+ /// Use vectorize transformations.
+ bool useVectorize() const { return vectorize; }
+
+ /// Sets the tile sizes to use for all levels of tiling of `op`.
+ void setTileSizes(Operation *op, TileSizesListType vTileSizes);
+
+ /// Sets the tile sizes to use for a given `level` of tiling of `op`.
+ void setTileSizes(Operation *op, ArrayRef<int64_t> vTileSizes, size_t level);
+
+ /// Sets the workgroup size to use for the function.
+ void setWorkgroupSize(ArrayRef<int64_t> vWorkgroupSize);
+
+ /// Sets number of subgroups to use.
+ void setNumSubgroups(ArrayRef<int64_t> vNumSubgroups);
+
+ /// Sets the configuration of the `targetOp` to be the same as the configuration
+ /// of the `sourceOp`.
+ void setSameConfig(Operation *sourceOp, Operation *targetOp);
+
+ /// Sets flag to enable vectorization.
+ void setVectorize(bool enableVectorize);
+
+ protected:
+ /// Current tile size configuration per operation. The key used here to
+ /// retrieve the tile size information per operation is the value of a StrAttr
+ /// added to operations during `init`. When tiled this attribute is copied
+ /// over to the tiled operation, thereby the same key can be used to retrieve
+ /// the tile sizes for the next level of tiling. The `finalize` method removes
+ /// these attributes.
+ llvm::StringMap<TileSizesListType> tileSizes;
+
+ /// Workgroup size to use.
+ std::array<int64_t, 3> workgroupSize = {1, 1, 1};
+
+ /// Number of subgroups that are logically distributed along x, y & z.
+ std::array<int64_t, 3> numSubgroups = {1, 1, 1};
+
+ /// Use vectorization.
+ bool vectorize = false;
+};
+
+/// Propagates tile sizes from `rootOperation` to other linalg operations in the
+/// dispatch region. This assumes that each dispatch region has a single root
+/// operation (like matmul, conv, etc.) that determines the tile sizes to use
+/// for tile+fuse+distribute. These are then propagated to the other operations.
+/// Note: This is a temporary solution and might be defunct when the codegen
+/// becomes more sophisticated.
+LogicalResult propogateRootOperationLaunchConfig(
+ LaunchConfig &launchConfig, linalg::LinalgOp rootOperation,
+ const linalg::LinalgDependenceGraph &dependenceGraph);
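+//
+// Illustrative usage sketch (`rootOp`, `funcOp` and `dependenceGraph` are
+// assumed to be provided by a backend-specific pass):
+//   LaunchConfig launchConfig;
+//   launchConfig.setTileSizes(rootOp, {32, 32, 16}, /*level=*/0);
+//   launchConfig.setWorkgroupSize({8, 8, 1});
+//   if (failed(propogateRootOperationLaunchConfig(launchConfig, rootOp,
+//                                                 dependenceGraph)))
+//     return failure();
+//   // ... tile/distribute using the recorded sizes ...
+//   launchConfig.finalize(funcOp);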
+
+} // namespace iree_compiler
+} // namespace mlir
+#endif // IREE_COMPILER_CONVERSION_COMMON_LAUNCHCONFIG_H_
diff --git a/iree/compiler/Conversion/Common/Passes.h b/iree/compiler/Conversion/Common/Passes.h
index 00a61e0..1c9ca32 100644
--- a/iree/compiler/Conversion/Common/Passes.h
+++ b/iree/compiler/Conversion/Common/Passes.h
@@ -23,5 +23,8 @@
/// each entry point function. The function is defined, but is populated later.
std::unique_ptr<OperationPass<ModuleOp>> createDeclareNumWorkgroupsFnPass();
+/// Pass to optimize vector transfer_read and transfer_write.
+std::unique_ptr<FunctionPass> createVectorTransferOptimizationPass();
+
} // namespace iree_compiler
} // namespace mlir
diff --git a/iree/compiler/Conversion/Common/Transforms.cpp b/iree/compiler/Conversion/Common/Transforms.cpp
new file mode 100644
index 0000000..cf7bc2a
--- /dev/null
+++ b/iree/compiler/Conversion/Common/Transforms.cpp
@@ -0,0 +1,273 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//===- Transforms.cpp - Transformations common to all backends ------------===//
+//
+// Implements transformations that are common to all backends.
+//
+//===----------------------------------------------------------------------===//
+
+#include "iree/compiler/Conversion/Common/Transforms.h"
+
+#include "iree/compiler/Conversion/CodegenUtils/FunctionUtils.h"
+#include "iree/compiler/Conversion/CodegenUtils/GetNumWorkgroups.h"
+#include "iree/compiler/Conversion/CodegenUtils/MarkerUtils.h"
+#include "iree/compiler/Conversion/CodegenUtils/MatmulCodegenStrategy.h"
+#include "iree/compiler/Conversion/Common/Attributes.h"
+#include "mlir/Dialect/GPU/GPUDialect.h"
+#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#define DEBUG_TYPE "iree-linalg-tile-and-fuse"
+
+namespace mlir {
+namespace iree_compiler {
+
+/// Apply canonicalizations related to tiling to make promotion/vectorization
+/// easier.
+void applyCanonicalizationPatternsForTiling(MLIRContext *context,
+ Operation *op) {
+ OwningRewritePatternList canonicalizationPatterns;
+ canonicalizationPatterns.insert<AffineMinCanonicalizationPattern>(context);
+ AffineApplyOp::getCanonicalizationPatterns(canonicalizationPatterns, context);
+ AffineMinOp::getCanonicalizationPatterns(canonicalizationPatterns, context);
+ SubViewOp::getCanonicalizationPatterns(canonicalizationPatterns, context);
+ applyPatternsAndFoldGreedily(op, std::move(canonicalizationPatterns));
+}
+
+//===----------------------------------------------------------------------===//
+// Helper functions for tile and fuse.
+//===----------------------------------------------------------------------===//
+
+/// Promotes views used to chain Linalg ops in `fusedOps`
+/// into buffers using the allocation callback in `options`.
+///
+/// Once fused the fused views that are due to a RAW dependence can be promoted
+/// to workgroup memory. This will make the intermediate storage dead.
+static LogicalResult promoteFusedViews(OpBuilder &builder,
+ ArrayRef<linalg::LinalgOp> fusedOps,
+ const TileAndFuseOptions &options) {
+ linalg::Aliases aliases;
+ linalg::LinalgDependenceGraph dependenceGraph(aliases, fusedOps);
+ auto fusableDependences =
+ linalg::findAllFusableDependences(fusedOps, dependenceGraph);
+
+ DenseSet<Value> promotedViews;
+ OpBuilder::InsertionGuard guard(builder);
+ builder.setInsertionPoint(*fusedOps.begin());
+
+ // Scan the list of ops in reverse order. The fusion is performed by creating
+ // a tiled version of the ops within the tiled loops of the last operation in
+ // the sequence, and then proceeding up the sequence.
+ for (linalg::LinalgOp op : llvm::reverse(fusedOps)) {
+ auto dependences = fusableDependences.lookup(op);
+ if (dependences.empty()) continue;
+ if (!llvm::hasSingleElement(dependences)) {
+ return op.emitError(
+ "unable to promote ops with multiple fusable dependences");
+ }
+ auto dependence = dependences.front();
+ unsigned producerIdx = dependence.dependentOpView.operandIndex;
+ linalg::LinalgOp consumer =
+ cast<linalg::LinalgOp>(dependence.indexingOpView.op);
+ unsigned consumerIdx = dependence.indexingOpView.operandIndex;
+ Value consumerView = consumer.getShapedOperand(consumerIdx);
+ Value promotedView = nullptr;
+
+ // If the view is already promoted, reuse that. The assumption is that the
+ // view matches already.
+ if (promotedViews.count(consumerView)) {
+ promotedView = consumerView;
+ } else if (dependence.dependenceType ==
+ linalg::LinalgDependenceGraph::RAW) {
+ SubViewOp promotedViewProducer =
+ op.getShapedOperand(producerIdx).getDefiningOp<SubViewOp>();
+ assert(promotedViewProducer &&
+ "expected producer to be a subview op as well");
+ Optional<linalg::PromotionInfo> promotionInfo =
+ linalg::promoteSubviewAsNewBuffer(
+ builder, op.getLoc(), promotedViewProducer, options.allocationFn);
+ if (!promotionInfo) {
+ return op.emitError("unable to promote RAW dependence");
+ }
+ promotedView = promotionInfo->partialLocalView;
+ consumer.getOperation()->setOperand(consumerIdx, promotedView);
+ promotedViews.insert(promotedView);
+ }
+ if (!promotedView) continue;
+ op.getOperation()->setOperand(producerIdx, promotedView);
+ }
+ return success();
+}
+
+/// Tile+Fuse only tiles the loops that can be fused. Tile any of the unfused
+/// loops in the operation based on the configuration.
+static linalg::LinalgOp tileUnfusedLoops(
+ OpBuilder &builder, linalg::LinalgOp linalgOp,
+ const std::set<unsigned> &fusedLoopDims, ArrayRef<int64_t> tileSizesRef) {
+ SmallVector<int64_t, 4> tileSizes = llvm::to_vector<4>(tileSizesRef);
+ tileSizes.resize(linalgOp.getNumLoops(), 0);
+ // Linalg uses tile size = 0 for a loop to indicate not tiling that loop. Set
+ // the fused loops to be untiled (since they are already tiled during fusion).
+ for (unsigned loopNum : fusedLoopDims) {
+ tileSizes[loopNum] = 0;
+ }
+ if (llvm::all_of(tileSizes, [](int64_t v) { return !v; })) return linalgOp;
+ Optional<linalg::TiledLinalgOp> tiledOp = tileLinalgOp(
+ builder, linalgOp,
+ linalg::LinalgTilingOptions().setTileSizes(tileSizes).setLoopType(
+ linalg::LinalgTilingLoopType::ParallelLoops));
+ if (!tiledOp) return nullptr;
+ linalgOp.erase();
+ return tiledOp->op;
+}
+
+/// Tiles the last operation in `fusableOps` and fuses all other operations with
+/// it by creating tiled versions of each within the generated inter-tile loops.
+static Optional<linalg::TiledAndFusedLinalgOps> tileAndFuseLinalgOps(
+ OpBuilder &builder, FuncOp funcOp, ArrayRef<linalg::LinalgOp> fusableOps,
+ const linalg::LinalgDependenceGraph &dependenceGraph,
+ ArrayRef<int64_t> tileSizes, const TileAndFuseOptions &options) {
+ // Get the tile sizes to use from the last fusable op and then tile+fuse all
+ // ops.
+ linalg::LinalgTilingOptions tilingOptions;
+ tilingOptions.setDistributionOptions(options.distributionOptions)
+ .setTileSizes(tileSizes)
+ .setLoopType(linalg::LinalgTilingLoopType::ParallelLoops);
+
+ Optional<linalg::TiledAndFusedLinalgOps> tiledAndFusedOps = llvm::None;
+ if (fusableOps.size() == 1) {
+ linalg::LinalgOp linalgOp = fusableOps.front();
+ Optional<linalg::TiledLinalgOp> tiledOp =
+ tileLinalgOp(builder, linalgOp, tilingOptions);
+ if (!tiledOp) {
+ linalgOp.emitError("unable to tile operation");
+ return llvm::None;
+ }
+ tiledAndFusedOps = linalg::TiledAndFusedLinalgOps{tiledOp->op, {}, {}, {}};
+ auto seq = llvm::seq<unsigned>(0, tileSizes.size());
+ tiledAndFusedOps->fusedLoopDims.insert(seq.begin(), seq.end());
+ tiledAndFusedOps->fusedLoops.assign(tiledOp->loops.begin(),
+ tiledOp->loops.end());
+ } else {
+ tiledAndFusedOps = tileAndFuseLinalgOps(builder, fusableOps,
+ dependenceGraph, tilingOptions);
+ }
+ if (!tiledAndFusedOps) {
+ funcOp.emitError("tile and fuse of linalg operations failed");
+ return llvm::None;
+ }
+
+ // Update the launch configuration.
+ SmallVector<unsigned, 2> distributedLoops =
+ llvm::to_vector<2>(tiledAndFusedOps->fusedLoopDims);
+ if (funcOp.getAttr(getNumWorkgroupsFnAttrName()) &&
+ failed(createNumWorkgroupsFromResultShape(
+ builder, fusableOps.back(), funcOp, getNumWorkgroupsFnAttrName(),
+ tileSizes, distributedLoops))) {
+ funcOp.emitError("failed to update launch configuration");
+ return llvm::None;
+ }
+
+ // Delete all the original operations.
+ for (auto linalgOp : fusableOps) linalgOp.erase();
+
+ // Add workgroup markers to all the tiled and fused operations.
+ for (auto fusedProducer : tiledAndFusedOps->fusedProducers) {
+ setMarker(fusedProducer, getWorkgroupMarker());
+ }
+ setMarker(tiledAndFusedOps->op, getWorkgroupMarker());
+
+ return tiledAndFusedOps;
+}
+
+LogicalResult tileAndFuseLinalgBufferOps(
+ FuncOp funcOp, ArrayRef<linalg::LinalgOp> linalgOps,
+ const linalg::LinalgDependenceGraph &dependenceGraph,
+ const LaunchConfig &launchConfig, const TileAndFuseOptions &options) {
+ // Collect all operations that are to be tiled-and-fused.
+ MLIRContext *context = funcOp.getContext();
+ SmallVector<linalg::LinalgOp, 4> fusableOps;
+ for (Operation *operation : linalgOps) {
+ if (!launchConfig.hasTileSizes(operation)) continue;
+ fusableOps.push_back(cast<linalg::LinalgOp>(operation));
+ }
+ if (fusableOps.empty()) return success();
+
+ OpBuilder builder(context);
+ ArrayRef<int64_t> tileSizes = launchConfig.getTileSizes(fusableOps.back(), 0);
+ Optional<linalg::TiledAndFusedLinalgOps> tiledAndFusedOps =
+ tileAndFuseLinalgOps(builder, funcOp, fusableOps, dependenceGraph,
+ tileSizes, options);
+ if (!tiledAndFusedOps) {
+ return funcOp.emitError("failed to tile and fuse operations");
+ }
+
+ LLVM_DEBUG({
+ llvm::dbgs() << "--- After Fusion on buffers ---\n";
+ funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
+ llvm::dbgs() << "\n\n";
+ });
+
+ applyCanonicalizationPatternsForTiling(context, funcOp);
+
+ LLVM_DEBUG({
+ llvm::dbgs() << "--- After Canonicalization ---\n";
+ funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
+ llvm::dbgs() << "\n\n";
+ });
+
+ if (options.allocationFn) {
+ SmallVector<linalg::LinalgOp, 4> promoteFusedViewOps =
+ llvm::to_vector<4>(tiledAndFusedOps->fusedProducers);
+ promoteFusedViewOps.push_back(tiledAndFusedOps->op);
+
+ if (failed(promoteFusedViews(builder, promoteFusedViewOps, options))) {
+ return failure();
+ }
+ }
+
+ LLVM_DEBUG({
+ llvm::dbgs() << "--- After Promotion ---\n";
+ funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
+ llvm::dbgs() << "\n\n";
+ });
+
+ // Tile the unfused loops. Set the tile sizes for the fused loops to be zero
+ // to avoid tiling them again.
+ for (linalg::LinalgOp &fusedOp : tiledAndFusedOps->fusedProducers) {
+ ArrayRef<int64_t> fusedOpTileSizes = launchConfig.getTileSizes(fusedOp, 0);
+ linalg::LinalgOp tiledOp = tileUnfusedLoops(
+ builder, fusedOp, tiledAndFusedOps->fusedLoopDims, fusedOpTileSizes);
+ if (!tiledOp) {
+ return fusedOp.emitError("unable to tile unfused loops");
+ }
+ }
+ linalg::LinalgOp tiledOp =
+ tileUnfusedLoops(builder, tiledAndFusedOps->op,
+ tiledAndFusedOps->fusedLoopDims, tileSizes);
+ if (!tiledOp) {
+ return tiledAndFusedOps->op.emitError("unable to tile unfused loops");
+ }
+
+ applyCanonicalizationPatternsForTiling(context, funcOp);
+ return success();
+}
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/iree/compiler/Conversion/Common/Transforms.h b/iree/compiler/Conversion/Common/Transforms.h
new file mode 100644
index 0000000..d10eec7
--- /dev/null
+++ b/iree/compiler/Conversion/Common/Transforms.h
@@ -0,0 +1,58 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//===- Transforms.h - Transformations common to all backends --------------===//
+//
+// Defines transformations that are common to backends
+//
+//===----------------------------------------------------------------------===//
+#ifndef IREE_COMPILER_CONVERSION_COMMON_TRANSFORMS_H_
+#define IREE_COMPILER_CONVERSION_COMMON_TRANSFORMS_H_
+
+#include "iree/compiler/Conversion/Common/LaunchConfig.h"
+#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/IR/Operation.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+/// Apply canonicalizations related to tiling to make promotion/vectorization
+/// easier.
+void applyCanonicalizationPatternsForTiling(MLIRContext *context,
+ Operation *op);
+
+struct TileAndFuseOptions {
+ linalg::LinalgLoopDistributionOptions distributionOptions;
+ linalg::AllocBufferCallbackFn allocationFn = nullptr;
+};
+/// Method to tile and fuse a sequence of Linalg operations in `linalgOps`. Uses
+/// the tile sizes for the first level of tiling specified in
+/// `launchConfig`. Proceeds by
+/// 1) Finding the common loops around `linalgOps` that can be fused.
+/// 2) Tiling the fusable loops of the last operation in the sequence.
+/// 3) Creating tiled versions of the other ops within the inter-tile loops
+/// generated in step 2.
+/// 4) For all the tiled+fused ops, tiling the unfused loops as specified by
+/// `launchConfig`.
+LogicalResult tileAndFuseLinalgBufferOps(
+ FuncOp funcOp, ArrayRef<linalg::LinalgOp> linalgOps,
+ const linalg::LinalgDependenceGraph &dependenceGraph,
+ const LaunchConfig &launchConfig, const TileAndFuseOptions &options);
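+//
+// Illustrative sketch (the distribution options and allocation callback named
+// here are assumptions supplied by the target backend):
+//   TileAndFuseOptions options;
+//   options.distributionOptions = myDistributionOptions;
+//   options.allocationFn = myAllocateWorkgroupMemoryFn;
+//   if (failed(tileAndFuseLinalgBufferOps(funcOp, linalgOps, dependenceGraph,
+//                                         launchConfig, options)))
+//     signalPassFailure();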
+
+} // namespace iree_compiler
+} // namespace mlir
+
+#endif // IREE_COMPILER_CONVERSION_COMMON_TRANSFORMS_H_
diff --git a/iree/compiler/Conversion/Common/VectorTransferOptimization.cpp b/iree/compiler/Conversion/Common/VectorTransferOptimization.cpp
new file mode 100644
index 0000000..91ef5a0
--- /dev/null
+++ b/iree/compiler/Conversion/Common/VectorTransferOptimization.cpp
@@ -0,0 +1,40 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "mlir/Dialect/Vector/VectorTransforms.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+namespace {
+
+struct VectorTransferOptimizationPass
+ : public PassWrapper<VectorTransferOptimizationPass, FunctionPass> {
+ void runOnFunction() override { vector::transferOpflowOpt(getFunction()); }
+};
+
+} // namespace
+
+std::unique_ptr<FunctionPass> createVectorTransferOptimizationPass() {
+ return std::make_unique<VectorTransferOptimizationPass>();
+}
+
+static PassRegistration<VectorTransferOptimizationPass> pass(
+ "iree-codegen-optimize-vector-transfer",
+ "Run optimization transformations on vector transfer operations",
+ [] { return std::make_unique<VectorTransferOptimizationPass>(); });
+
+} // namespace iree_compiler
+} // namespace mlir
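For reference, vector::transferOpflowOpt forwards values from vector.transfer_write ops to subsequent vector.transfer_read ops on the same buffer and removes transfer writes that are overwritten before ever being read, so this pass amounts to store-to-load forwarding plus dead-store elimination specialized to vector transfer operations.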
diff --git a/iree/compiler/Conversion/HLOToHLO/BUILD b/iree/compiler/Conversion/HLOToHLO/BUILD
index fccd391..2dbe2b7 100644
--- a/iree/compiler/Conversion/HLOToHLO/BUILD
+++ b/iree/compiler/Conversion/HLOToHLO/BUILD
@@ -22,13 +22,18 @@
name = "HLOToHLO",
srcs = [
"DecomposeHLOClamp.cpp",
+ "DemoteF32ToF16.cpp",
],
hdrs = [
"Passes.h",
],
deps = [
+ "//iree/compiler/Dialect/Flow/IR",
+ "//iree/compiler/Dialect/IREE/IR",
+ "@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
+ "@llvm-project//mlir:Support",
"@llvm-project//mlir:TransformUtils",
"@org_tensorflow//tensorflow/compiler/mlir/hlo",
],
diff --git a/iree/compiler/Conversion/HLOToHLO/CMakeLists.txt b/iree/compiler/Conversion/HLOToHLO/CMakeLists.txt
index c067d1b..2b6f0ae 100644
--- a/iree/compiler/Conversion/HLOToHLO/CMakeLists.txt
+++ b/iree/compiler/Conversion/HLOToHLO/CMakeLists.txt
@@ -21,10 +21,15 @@
"Passes.h"
SRCS
"DecomposeHLOClamp.cpp"
+ "DemoteF32ToF16.cpp"
DEPS
+ LLVMSupport
MLIRIR
MLIRPass
+ MLIRSupport
MLIRTransformUtils
+ iree::compiler::Dialect::Flow::IR
+ iree::compiler::Dialect::IREE::IR
tensorflow::mlir_hlo
PUBLIC
)
diff --git a/iree/compiler/Conversion/HLOToHLO/DemoteF32ToF16.cpp b/iree/compiler/Conversion/HLOToHLO/DemoteF32ToF16.cpp
new file mode 100644
index 0000000..7e0f286
--- /dev/null
+++ b/iree/compiler/Conversion/HLOToHLO/DemoteF32ToF16.cpp
@@ -0,0 +1,196 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <memory>
+#include <utility>
+
+#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
+#include "iree/compiler/Dialect/IREE/IR/IREETypes.h"
+#include "llvm/ADT/APFloat.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace {
+struct ConvertF32ToF16Pass
+ : public PassWrapper<ConvertF32ToF16Pass, OperationPass<ModuleOp>> {
+ void runOnOperation() override;
+
+ private:
+};
+
+/// Any fp32 derived type is illegal.
+static bool isIllegalType(Type type) {
+ if (type.isF32()) return true;
+ if (auto ptrType = type.dyn_cast<IREE::PtrType>()) {
+ return isIllegalType(ptrType.getTargetType());
+ }
+ if (auto shapedType = type.dyn_cast<ShapedType>()) {
+ return isIllegalType(shapedType.getElementType());
+ }
+ return false;
+}
+
+class F32ToF16ConversionTarget : public ConversionTarget {
+ public:
+ using ConversionTarget::ConversionTarget;
+
+ protected:
+ // Operations are legal if they don't contain any illegal type.
+ bool isDynamicallyLegal(Operation *op) const override {
+ if (auto varOp = dyn_cast<IREE::Flow::VariableOp>(op)) {
+ return !isIllegalType(varOp.type());
+ }
+ if (auto funcOp = dyn_cast<FuncOp>(op)) {
+ for (Type type : funcOp.getType().getInputs()) {
+ if (isIllegalType(type)) return false;
+ }
+ for (Type type : funcOp.getType().getResults()) {
+ if (isIllegalType(type)) return false;
+ }
+ }
+ for (Type type : op->getResultTypes()) {
+ if (isIllegalType(type)) return false;
+ }
+ for (Type type : op->getOperandTypes()) {
+ if (isIllegalType(type)) return false;
+ }
+ return true;
+ }
+};
+
+class FloatTypeConverter : public TypeConverter {
+ public:
+ static Type convertTensor(RankedTensorType type) {
+ if (!type.getElementType().isF32()) return type;
+ auto newType = RankedTensorType::get(type.getShape(),
+ Float16Type::get(type.getContext()));
+ return newType;
+ }
+ explicit FloatTypeConverter() {
+ addConversion([](Type type) { return type; });
+ addConversion([&](FloatType type) {
+ if (type.isF32()) return FloatType::getF16(type.getContext());
+ return type;
+ });
+ addConversion(convertTensor);
+ addConversion([&](IREE::PtrType ptrType) {
+ if (auto tensorType =
+ ptrType.getTargetType().dyn_cast<RankedTensorType>()) {
+ return IREE::PtrType::get(convertTensor(tensorType));
+ }
+ return ptrType;
+ });
+ }
+};
+
+// Generic pattern to convert FP32 values and attributes to FP16.
+class GenericTypeConvert : public ConversionPattern {
+ public:
+ GenericTypeConvert(MLIRContext *context, TypeConverter &converter)
+ : ConversionPattern(0, converter, MatchAnyOpTypeTag()) {}
+ LogicalResult matchAndRewrite(
+ Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ llvm::SmallVector<NamedAttribute, 4> newAttr;
+ convertAttributes(op->getAttrs(), rewriter, newAttr);
+ llvm::SmallVector<Type, 4> newResults;
+ getTypeConverter()->convertTypes(op->getResultTypes(), newResults);
+ OperationState state(op->getLoc(), op->getName().getStringRef(), operands,
+ newResults, newAttr, op->getSuccessors());
+ for (Region &r : op->getRegions()) {
+ Region *newRegion = state.addRegion();
+ rewriter.inlineRegionBefore(r, *newRegion, newRegion->begin());
+ TypeConverter::SignatureConversion result(newRegion->getNumArguments());
+ getTypeConverter()->convertSignatureArgs(newRegion->getArgumentTypes(),
+ result);
+ rewriter.applySignatureConversion(newRegion, result);
+ }
+ Operation *newOp = rewriter.createOperation(state);
+ rewriter.replaceOp(op, newOp->getResults());
+ return success();
+ }
+
+ protected:
+ static void convertAttributes(ArrayRef<NamedAttribute> attrs,
+ ConversionPatternRewriter &rewriter,
+ SmallVectorImpl<NamedAttribute> &newAttrs) {
+ for (auto attr : attrs) {
+ if (auto fpAttr = attr.second.dyn_cast<DenseFPElementsAttr>()) {
+ std::vector<llvm::APFloat> args;
+ if (!fpAttr.getType().getElementType().isF32()) continue;
+ for (llvm::APFloat f : fpAttr.getFloatValues()) {
+ bool losesInfo;
+ f.convert(APFloat::IEEEhalf(), APFloat::rmTowardZero, &losesInfo);
+ args.push_back(f);
+ }
+ auto tensorType = RankedTensorType::get(fpAttr.getType().getShape(),
+ rewriter.getF16Type());
+ newAttrs.push_back(std::make_pair(
+ attr.first, DenseElementsAttr::get(tensorType, args)));
+ } else if (auto typeAttr = attr.second.dyn_cast<TypeAttr>()) {
+ if (isIllegalType(typeAttr.getValue())) {
+ if (auto tensorType =
+ typeAttr.getValue().dyn_cast<RankedTensorType>()) {
+ Type newType = RankedTensorType::get(tensorType.getShape(),
+ rewriter.getF16Type());
+ newAttrs.push_back(
+ std::make_pair(attr.first, TypeAttr::get(newType)));
+ }
+ }
+ } else {
+ newAttrs.push_back(attr);
+ }
+ }
+ }
+};
+
+void ConvertF32ToF16Pass::runOnOperation() {
+ MLIRContext *context = &getContext();
+ ModuleOp moduleOp = getOperation();
+
+ FloatTypeConverter converter;
+ OwningRewritePatternList patterns;
+ patterns.insert<GenericTypeConvert>(context, converter);
+ populateFuncOpTypeConversionPattern(patterns, context, converter);
+ F32ToF16ConversionTarget target(*context);
+ target.markUnknownOpDynamicallyLegal();
+ if (failed(applyFullConversion(moduleOp, target, std::move(patterns))))
+ return signalPassFailure();
+}
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// Pass entry point and registration
+//===----------------------------------------------------------------------===//
+std::unique_ptr<OperationPass<ModuleOp>> createDemoteF32ToF16Pass() {
+ return std::make_unique<ConvertF32ToF16Pass>();
+}
+
+static PassRegistration<ConvertF32ToF16Pass> pass(
+ "iree-convert-f32-to-f16",
+ "Convert f32 operations and values into equivalent f16 ones");
+} // namespace iree_compiler
+} // namespace mlir
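A minimal standalone sketch of the narrowing step performed by convertAttributes above, using only LLVM's APFloat API; the helper name demoteToF16 is illustrative.

static llvm::APFloat demoteToF16(llvm::APFloat value) {
  bool losesInfo = false;
  // Narrow from IEEE single to IEEE half precision. The pass rounds toward
  // zero; other rounding modes would be equally valid choices here.
  value.convert(llvm::APFloat::IEEEhalf(), llvm::APFloat::rmTowardZero,
                &losesInfo);
  return value;
}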
diff --git a/iree/compiler/Conversion/HLOToHLO/Passes.h b/iree/compiler/Conversion/HLOToHLO/Passes.h
index cf50702..df93c30 100644
--- a/iree/compiler/Conversion/HLOToHLO/Passes.h
+++ b/iree/compiler/Conversion/HLOToHLO/Passes.h
@@ -33,6 +33,10 @@
/// Creates a pass to decompose XLA-HLO clamp ops into primitive ops.
std::unique_ptr<OperationPass<FuncOp>> createDecomposeHLOClampPass();
+/// Creates a pass to convert a model using f32 types to the equivalent one
+/// using f16 types.
+std::unique_ptr<OperationPass<ModuleOp>> createDemoteF32ToF16Pass();
+
} // namespace iree_compiler
} // namespace mlir
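A minimal sketch of wiring the new pass into a pipeline, assuming a module-level mlir::PassManager; the boolean option demoteF32ToF16 is an illustrative placeholder rather than an existing flag, and the exact position in the pipeline is a design choice.

static void buildHLOPreprocessingPasses(PassManager &pm, bool demoteF32ToF16) {
  // Demotion rewrites types module-wide, so run it before lowering out of HLO.
  if (demoteF32ToF16) pm.addPass(createDemoteF32ToF16Pass());
  pm.addNestedPass<FuncOp>(createDecomposeHLOClampPass());
}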
diff --git a/iree/compiler/Conversion/HLOToHLO/test/BUILD b/iree/compiler/Conversion/HLOToHLO/test/BUILD
new file mode 100644
index 0000000..1e3b7bb
--- /dev/null
+++ b/iree/compiler/Conversion/HLOToHLO/test/BUILD
@@ -0,0 +1,32 @@
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Tests for common transforms.
+
+load("//iree:lit_test.bzl", "iree_lit_test_suite")
+
+package(
+ default_visibility = ["//visibility:public"],
+ features = ["layering_check"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+iree_lit_test_suite(
+ name = "lit",
+ srcs = glob(["*.mlir"]),
+ data = [
+ "//iree/tools:IreeFileCheck",
+ "//iree/tools:iree-opt",
+ ],
+)
diff --git a/iree/compiler/Conversion/HLOToHLO/test/CMakeLists.txt b/iree/compiler/Conversion/HLOToHLO/test/CMakeLists.txt
new file mode 100644
index 0000000..fcc538b
--- /dev/null
+++ b/iree/compiler/Conversion/HLOToHLO/test/CMakeLists.txt
@@ -0,0 +1,26 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+iree_add_all_subdirs()
+
+file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir)
+iree_lit_test_suite(
+ NAME
+ lit
+ SRCS
+ "${_GLOB_X_MLIR}"
+ DATA
+ iree::tools::IreeFileCheck
+ iree::tools::iree-opt
+)
diff --git a/iree/compiler/Conversion/HLOToHLO/test/f32Tof16.mlir b/iree/compiler/Conversion/HLOToHLO/test/f32Tof16.mlir
new file mode 100644
index 0000000..6cd8325
--- /dev/null
+++ b/iree/compiler/Conversion/HLOToHLO/test/f32Tof16.mlir
@@ -0,0 +1,38 @@
+// RUN: iree-opt -split-input-file -iree-convert-f32-to-f16 %s | IreeFileCheck %s
+
+// CHECK: flow.variable {{.*}} : tensor<4xf16>
+// CHECK-LABEL: func @simple_f32() -> tensor<4xf16>
+// CHECK-NEXT: %{{.*}} = flow.variable.address @__global : !iree.ptr<tensor<4xf16>>
+// CHECK-NEXT: %{{.*}} = flow.variable.load.indirect %{{.*}} : !iree.ptr<tensor<4xf16>> -> tensor<4xf16>
+// CHECK-NEXT: return %{{.*}} : tensor<4xf16>
+module {
+ flow.variable @"__global" dense<"0x000020410000A040000020410000A040"> : tensor<4xf32> attributes {sym_visibility = "private"}
+ func @simple_f32() -> (tensor<4xf32>) {
+ %0 = flow.variable.address @"__global" : !iree.ptr<tensor<4xf32>>
+ %1 = flow.variable.load.indirect %0 : !iree.ptr<tensor<4xf32>> -> tensor<4xf32>
+ return %1 : tensor<4xf32>
+ }
+}
+
+// -----
+
+// CHECK: flow.variable
+// CHECK-NOT: f32
+// CHECK-LABEL: func @nested_region_f32()
+// CHECK-NOT: f32
+// CHECK: return %{{.*}} : tensor<4xf16>
+module {
+ flow.variable @"__global" dense<"0x000020410000A040000020410000A040"> : tensor<4xf32> attributes {sym_visibility = "private"}
+ func @nested_region_f32() -> (tensor<4xf32>) {
+    %0 = flow.variable.address @"__global" : !iree.ptr<tensor<4xf32>>
+ %1 = flow.variable.load.indirect %0 : !iree.ptr<tensor<4xf32>> -> tensor<4xf32>
+ %2 = "mhlo.broadcast_in_dim"(%1) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<4x4xf32>
+ %4 = mhlo.constant dense<0xFF800000> : tensor<f32>
+ %3 = "mhlo.reduce"(%2, %4) ( {
+ ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>): // no predecessors
+ %5393 = mhlo.maximum %arg3, %arg4 : tensor<f32>
+ "mhlo.return"(%5393) : (tensor<f32>) -> ()
+ }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<4x4xf32>, tensor<f32>) -> tensor<4xf32>
+ return %3 : tensor<4xf32>
+ }
+}
diff --git a/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnBuffers.cpp b/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnBuffers.cpp
index d178038..90fe4dc 100644
--- a/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnBuffers.cpp
+++ b/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnBuffers.cpp
@@ -116,7 +116,7 @@
/// Returns the MemRefType to use for a given `tensorType`.
static MemRefType getMemrefTypeForTensor(
- RankedTensorType tensorType, ArrayRef<AffineMap> affineMapComposition = {},
+ ShapedType tensorType, ArrayRef<AffineMap> affineMapComposition = {},
unsigned memorySpace = 0) {
return MemRefType::get(tensorType.getShape(), tensorType.getElementType(),
affineMapComposition, memorySpace);
@@ -197,6 +197,14 @@
for (auto result : llvm::enumerate(op->getResults())) {
Value resultBuffer = resultTensorToBufferMap.lookup(result.value());
if (!resultBuffer) {
+ if (auto shapedType = result.value().getType().dyn_cast<ShapedType>()) {
+ if (shapedType.hasStaticShape()) {
+ resultBuffer = rewriter.create<AllocOp>(
+ op->getLoc(), getMemrefTypeForTensor(shapedType));
+ }
+ }
+ }
+ if (!resultBuffer) {
return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
diag << "failed to create buffer for result #" << result.index();
});
diff --git a/iree/compiler/Conversion/HLOToLinalg/test/linalg_tensor_to_buffer.mlir b/iree/compiler/Conversion/HLOToLinalg/test/linalg_tensor_to_buffer.mlir
index 057e909..1abc9c2 100644
--- a/iree/compiler/Conversion/HLOToLinalg/test/linalg_tensor_to_buffer.mlir
+++ b/iree/compiler/Conversion/HLOToLinalg/test/linalg_tensor_to_buffer.mlir
@@ -472,3 +472,54 @@
// CHECK: %[[RET0_RESHAPE2:.+]] = linalg.reshape %[[RET0]]
// CHECK-SAME: memref<1x1x1x1000xf32> into memref<1x1000xf32>
// CHECK: linalg.copy(%[[RET0_RESHAPE2]], %[[RET1]])
+
+// -----
+
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+module {
+ func @matmul_add() {
+ %c0 = constant 0 : index
+ %0 = hal.interface.load.tensor @legacy_io::@arg0, offset = %c0
+ : tensor<32x48xf32>
+ %1 = hal.interface.load.tensor @legacy_io::@arg1, offset = %c0
+ : tensor<48x64xf32>
+ %2 = hal.interface.load.tensor @legacy_io::@arg2, offset = %c0
+ : tensor<32x64xf32>
+ %3 = "mhlo.dot"(%0, %1)
+ : (tensor<32x48xf32>, tensor<48x64xf32>) -> tensor<32x64xf32>
+ %4 = linalg.generic {
+ indexing_maps = [#map0, #map0, #map0],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%2, %3 : tensor<32x64xf32>, tensor<32x64xf32>) {
+ ^bb0(%arg0: f32, %arg1: f32):
+ %5 = addf %arg0, %arg1 : f32
+ linalg.yield %5 : f32
+ } -> tensor<32x64xf32>
+ hal.interface.store.tensor %4, @legacy_io::@ret0, offset = %c0
+ : tensor<32x64xf32>
+ return
+ }
+ hal.interface @legacy_io attributes {sym_visiblity = "private"} {
+ hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+ hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+ hal.interface.binding @arg2, set=0, binding=2, type="StorageBuffer", access="Read"
+ hal.interface.binding @ret0, set=0, binding=3, type="StorageBuffer", access="Write|Discard"
+ }
+}
+
+// CHECK-LABEL: func @matmul_add
+// CHECK-DAG: %[[RET0:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0}
+// CHECK-DAG: %[[ARG0:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0}
+// CHECK-DAG: %[[ARG1:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg1}
+// CHECK-DAG: %[[ARG2:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg2}
+// CHECK-DAG: %[[TEMP:.+]] = alloc()
+// CHECK: linalg.fill(%[[TEMP]], %{{.+}})
+// CHECK: linalg.matmul ins(%[[ARG0]], %[[ARG1]]
+// CHECK-SAME: ) outs(%[[TEMP]]
+// CHECK-SAME: )
+// CHECK: linalg.generic
+// CHECK-SAME: ins(%[[ARG2]], %[[TEMP]]
+// CHECK-SAME: ) outs(%[[RET0]]
+// CHECK-SAME: )
+// CHECK: return
+
diff --git a/iree/compiler/Conversion/LinalgToLLVM/BUILD b/iree/compiler/Conversion/LinalgToLLVM/BUILD
index 5dbbadd..bb93b48 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/BUILD
+++ b/iree/compiler/Conversion/LinalgToLLVM/BUILD
@@ -27,6 +27,7 @@
"LinalgTileAndDistributePass.cpp",
"LinalgTileAndVectorizePass.cpp",
"Passes.cpp",
+ "PlanConvLoopOrder.cpp",
],
hdrs = [
"KernelDispatch.h",
diff --git a/iree/compiler/Conversion/LinalgToLLVM/CMakeLists.txt b/iree/compiler/Conversion/LinalgToLLVM/CMakeLists.txt
index 4018a29..df9c39a 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/CMakeLists.txt
+++ b/iree/compiler/Conversion/LinalgToLLVM/CMakeLists.txt
@@ -27,6 +27,7 @@
"LinalgTileAndDistributePass.cpp"
"LinalgTileAndVectorizePass.cpp"
"Passes.cpp"
+ "PlanConvLoopOrder.cpp"
DEPS
LLVMSupport
MLIRAffineToStandard
diff --git a/iree/compiler/Conversion/LinalgToLLVM/KernelDispatch.cpp b/iree/compiler/Conversion/LinalgToLLVM/KernelDispatch.cpp
index d37efed..0d81a0d 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/KernelDispatch.cpp
+++ b/iree/compiler/Conversion/LinalgToLLVM/KernelDispatch.cpp
@@ -23,6 +23,14 @@
namespace mlir {
namespace iree_compiler {
+// TODO(ravishankarm): This needs to be put in a common place for the CPU and
+// GPU backends to use.
+static llvm::cl::list<unsigned> clLLVMTileSizes(
+ "iree-llvm-tile-size",
+ llvm::cl::desc("Set tile sizes to use for tiling Linalg operations in "
+ "LLVM code generation"),
+ llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated);
+
static llvm::cl::opt<int> matmulWorkgroupTileSize(
"iree-codegen-linalg-to-llvm-kernel-dispatch-matmul-workgroup-tile-size",
llvm::cl::desc(
@@ -151,5 +159,48 @@
#undef DEFINE_TILE_SIZE_FN
+Optional<LaunchConfig> initCPULaunchConfig(
+ MLIRContext *context, const linalg::LinalgDependenceGraph &dependenceGraph,
+ ArrayRef<linalg::LinalgOp> linalgOps) {
+ LaunchConfig config;
+ if (!clLLVMTileSizes.empty()) {
+ SmallVector<int64_t, 3> tileSizes(clLLVMTileSizes.begin(),
+ clLLVMTileSizes.end());
+ for (linalg::LinalgOp linalgOp : linalgOps) {
+ config.setTileSizes(linalgOp.getOperation(), tileSizes, 0);
+ }
+ return config;
+ }
+
+ Optional<linalg::LinalgOp> rootOperation = llvm::None;
+ for (auto linalgOp : linalgOps) {
+#define DISPATCH(opType) \
+ if (opType op = dyn_cast<opType>(linalgOp.getOperation())) { \
+ if (rootOperation) { \
+ op.emitError("unhandled multiple root operations in dispatch region"); \
+ return llvm::None; \
+ } \
+ rootOperation = linalgOp; \
+ config.setTileSizes( \
+ op, \
+ TileOpParameters::getSizes<opType, TilingLevel::WorkGroupTiles>(op), \
+ 0); \
+ continue; \
+ }
+
+ DISPATCH(linalg::MatmulOp)
+ DISPATCH(linalg::BatchMatmulOp)
+
+#undef DISPATCH
+ }
+ if (!rootOperation) {
+ return config;
+ }
+ if (failed(propogateRootOperationLaunchConfig(config, *rootOperation,
+ dependenceGraph)))
+ return llvm::None;
+ return config;
+}
+
} // namespace iree_compiler
} // namespace mlir
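With this change the workgroup tile sizes for LLVM codegen come from the new global -iree-llvm-tile-size flag (a comma-separated list, one entry per loop) consumed by initCPULaunchConfig, rather than from a per-pass option; the updated RUN line in tile_and_distribute.mlir below exercises it as, for example, iree-opt --iree-codegen-llvm-linalg-tile-and-distribute -iree-llvm-tile-size=2,4,1. When the flag is empty, the single root operation (linalg.matmul or linalg.batch_matmul) takes its sizes from TileOpParameters and propagates them to dependent operations through the dependence graph.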
diff --git a/iree/compiler/Conversion/LinalgToLLVM/KernelDispatch.h b/iree/compiler/Conversion/LinalgToLLVM/KernelDispatch.h
index f2d7711..20972af 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/KernelDispatch.h
+++ b/iree/compiler/Conversion/LinalgToLLVM/KernelDispatch.h
@@ -14,6 +14,7 @@
#include <cstdint>
+#include "iree/compiler/Conversion/Common/LaunchConfig.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Operation.h"
@@ -28,7 +29,8 @@
// Tile linalg operation on workgroup thread into L1 block tiles.
Level1Tiles = 1,
// Tile linalg operations on L1 block tiles into vector tiles.
- Level2Tiles = 2
+ Level2Tiles = 2,
+ NumTileLevels = 3
};
class CPUKernelDispatch {
@@ -44,5 +46,9 @@
Operation *operation);
};
+Optional<LaunchConfig> initCPULaunchConfig(
+ MLIRContext *context, const linalg::LinalgDependenceGraph &dependenceGraph,
+ ArrayRef<linalg::LinalgOp> linalgOps);
+
} // namespace iree_compiler
} // namespace mlir
diff --git a/iree/compiler/Conversion/LinalgToLLVM/LinalgTileAndDistributePass.cpp b/iree/compiler/Conversion/LinalgToLLVM/LinalgTileAndDistributePass.cpp
index 6e8929c..cd0cdd2 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/LinalgTileAndDistributePass.cpp
+++ b/iree/compiler/Conversion/LinalgToLLVM/LinalgTileAndDistributePass.cpp
@@ -17,11 +17,13 @@
#include "iree/compiler/Conversion/CodegenUtils/MarkerUtils.h"
#include "iree/compiler/Conversion/CodegenUtils/MatmulCodegenStrategy.h"
#include "iree/compiler/Conversion/Common/Attributes.h"
+#include "iree/compiler/Conversion/Common/Transforms.h"
#include "iree/compiler/Conversion/LinalgToLLVM/KernelDispatch.h"
#include "iree/compiler/Dialect/IREE/IR/IREEDialect.h"
#include "iree/compiler/Dialect/IREE/IR/IREEOps.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/IR/Builders.h"
+#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -41,11 +43,6 @@
LinalgTileAndDistributePass() = default;
LinalgTileAndDistributePass(const LinalgTileAndDistributePass &pass) {}
void runOnOperation() override;
-
- private:
- ListOption<int64_t> tileSizes{
- *this, "tile-sizes", llvm::cl::desc("Set tile sizes to use"),
- llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
};
} // namespace
@@ -124,6 +121,34 @@
} // namespace
+Optional<Value> allocateThreadLocalMemory(OpBuilder &b, SubViewOp subview,
+ ArrayRef<Value> boundingSubViewSize,
+ OperationFolder *folder) {
+  // Allocate the memory in the entry block of the parent FuncOp. This better
+  // aligns with the semantics of this memory, which is available from the
+  // entry of the function.
+ OpBuilder::InsertionGuard guard(b);
+ FuncOp funcOp = subview.getParentOfType<FuncOp>();
+ if (!funcOp) {
+ subview.emitError("expected op to be within std.func");
+ return llvm::None;
+ }
+ b.setInsertionPointToStart(&(*funcOp.getBody().begin()));
+  // The bounding subview size is expected to be constant. It specifies the
+  // shape of the allocation.
+ SmallVector<int64_t, 2> shape = llvm::to_vector<2>(
+ llvm::map_range(boundingSubViewSize, [](Value v) -> int64_t {
+ APInt value;
+ if (matchPattern(v, m_ConstantInt(&value))) return value.getSExtValue();
+ return -1;
+ }));
+ if (llvm::any_of(shape, [](int64_t v) { return v == -1; })) return {};
+ MemRefType allocType =
+ MemRefType::get(shape, subview.getType().getElementType());
+ Value buffer = b.create<AllocaOp>(subview.getLoc(), allocType);
+ return buffer;
+}
+
void LinalgTileAndDistributePass::runOnOperation() {
MLIRContext *context = &getContext();
ModuleOp module = getOperation();
@@ -147,57 +172,57 @@
linalg::DistributionMethod::CyclicNumProcsEqNumIters,
linalg::DistributionMethod::CyclicNumProcsEqNumIters}};
- CPUKernelDispatch cpuKernelDispatch;
-
for (FuncOp funcOp : module.getOps<FuncOp>()) {
if (!isEntryPoint(funcOp)) continue;
- // Compute the Linalg Dependence Graph.
+ Region &body = funcOp.getBody();
+ if (!llvm::hasSingleElement(body.getBlocks())) {
+ funcOp.emitError("unhandled dispatch function with multiple blocks");
+ return signalPassFailure();
+ }
+ Block &block = body.front();
+ auto linalgOps = block.getOps<linalg::LinalgOp>();
+ if (linalgOps.empty()) continue;
+
+ SmallVector<linalg::LinalgOp, 4> linalgOpsVec =
+ llvm::to_vector<4>(llvm::map_range(linalgOps, [](Operation *op) {
+ return cast<linalg::LinalgOp>(op);
+ }));
linalg::Aliases aliases;
- linalg::LinalgDependenceGraph dependenceGraph =
- linalg::LinalgDependenceGraph::buildDependenceGraph(aliases, funcOp);
+ linalg::LinalgDependenceGraph dependenceGraph(aliases, linalgOpsVec);
+ Optional<LaunchConfig> launchConfigOpt =
+ initCPULaunchConfig(context, dependenceGraph, linalgOpsVec);
+ if (!launchConfigOpt) {
+ funcOp.emitError("unable to find launch configuration");
+ return signalPassFailure();
+ }
+ LaunchConfig &launchConfig = *launchConfigOpt;
- OwningRewritePatternList patterns;
-
- auto linalgTilingOptions =
- linalg::LinalgTilingOptions()
- .setDistributionOptions(workgroupDistributionOptions)
- .setLoopType(linalg::LinalgTilingLoopType::ParallelLoops);
- tileSizes.empty()
- ? linalgTilingOptions.setTileSizeComputationFunction(
- [&cpuKernelDispatch](
- OpBuilder &builder,
- Operation *operation) -> SmallVector<Value, 4> {
- return TileSizeFn::get<TilingLevel::WorkGroupTiles>(
- cpuKernelDispatch, builder, operation);
- })
- : linalgTilingOptions.setTileSizes(ArrayRef<int64_t>(tileSizes));
- patterns.insert<TileAndFuseToCPUThreads<linalg::MatmulOp>,
- TileAndFuseToCPUThreads<linalg::BatchMatmulOp>,
- TileToCPUThreads<linalg::MatmulOp>,
- TileToCPUThreads<linalg::BatchMatmulOp>>(
- context, dependenceGraph, cpuKernelDispatch, linalgTilingOptions,
- linalg::LinalgMarker(ArrayRef<Identifier>(),
- Identifier::get(getWorkgroupMarker(), context)));
-
- // Tile and distribute to CPU threads.
- applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
-
- // Apply canonicalization patterns.
- OwningRewritePatternList canonicalizationPatterns;
- canonicalizationPatterns.insert<AffineMinCanonicalizationPattern>(context);
- AffineApplyOp::getCanonicalizationPatterns(canonicalizationPatterns,
- context);
- AffineMinOp::getCanonicalizationPatterns(canonicalizationPatterns, context);
- SubViewOp::getCanonicalizationPatterns(canonicalizationPatterns, context);
-
- applyPatternsAndFoldGreedily(funcOp, std::move(canonicalizationPatterns));
-
- // Delete the ops that are marked for deletion.
- funcOp.walk([](linalg::LinalgOp linalgOp) {
- if (hasMarker(linalgOp.getOperation(), getDeleteMarker()))
- linalgOp.getOperation()->erase();
+ LLVM_DEBUG({
+ llvm::dbgs() << "@func " << funcOp.getName() << "\n";
+ for (auto op : linalgOps) {
+ llvm::dbgs() << "\t" << op.getOperation()->getName() << " : ";
+ TileSizesListTypeRef configTileSizes = launchConfig.getTileSizes(op);
+ llvm::dbgs() << "{";
+ std::string sep = "";
+ for (auto &level : enumerate(configTileSizes)) {
+ llvm::dbgs() << sep << level.index() << " : [";
+ sep = ", ";
+ interleaveComma(level.value(), llvm::dbgs());
+ llvm::dbgs() << "]";
+ }
+ llvm::dbgs() << "}\n";
+ }
});
+
+ TileAndFuseOptions tileAndFuseOptions = {workgroupDistributionOptions,
+ allocateThreadLocalMemory};
+ if (failed(tileAndFuseLinalgBufferOps(funcOp, linalgOpsVec, dependenceGraph,
+ launchConfig, tileAndFuseOptions))) {
+ return signalPassFailure();
+ }
+
+ launchConfig.finalize(funcOp);
}
}
diff --git a/iree/compiler/Conversion/LinalgToLLVM/Passes.cpp b/iree/compiler/Conversion/LinalgToLLVM/Passes.cpp
index 348b425..b3f2c9c 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/Passes.cpp
+++ b/iree/compiler/Conversion/LinalgToLLVM/Passes.cpp
@@ -54,6 +54,7 @@
if (convImg2ColConversion) {
passManager.addNestedPass<FuncOp>(createConvImg2ColMatmulConversionPass());
}
+ passManager.addNestedPass<FuncOp>(createPlanConvLoopOrderPass());
passManager.addNestedPass<FuncOp>(
createLinalgTileAndVectorizeWorkgroupsPass());
diff --git a/iree/compiler/Conversion/LinalgToLLVM/Passes.h b/iree/compiler/Conversion/LinalgToLLVM/Passes.h
index 33c74d4..0a325dc 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/Passes.h
+++ b/iree/compiler/Conversion/LinalgToLLVM/Passes.h
@@ -24,10 +24,14 @@
/// linalg::MatmulOp.
std::unique_ptr<FunctionPass> createConvImg2ColMatmulConversionPass();
-/// Distribute linalg ops among iree.workgroup logical threads.
+/// Converts linalg.conv into linalg.generic with a CPU-friendly iteration
+/// order.
+std::unique_ptr<FunctionPass> createPlanConvLoopOrderPass();
+
+/// Distributes linalg ops among iree.workgroup logical threads.
std::unique_ptr<OperationPass<ModuleOp>> createLinalgTileAndDistributePass();
-/// Vectorize linalg ops executed in the same iree.workgroup.
+/// Vectorizes linalg ops executed in the same iree.workgroup.
std::unique_ptr<FunctionPass> createLinalgTileAndVectorizeWorkgroupsPass();
std::unique_ptr<OperationPass<ModuleOp>>
@@ -38,7 +42,7 @@
void populateConvImg2ColMatmulConversionPatterns(
MLIRContext *context, OwningRewritePatternList &patterns);
-/// Pass to perform final conversion to LLVM dialect.
+/// Performs the final conversion to LLVM dialect.
std::unique_ptr<OperationPass<ModuleOp>> createConvertToLLVMPass();
/// Populates passes needed to lower a XLA HLO op to LLVM dialect via the
diff --git a/iree/compiler/Conversion/LinalgToLLVM/PlanConvLoopOrder.cpp b/iree/compiler/Conversion/LinalgToLLVM/PlanConvLoopOrder.cpp
new file mode 100644
index 0000000..db898a4
--- /dev/null
+++ b/iree/compiler/Conversion/LinalgToLLVM/PlanConvLoopOrder.cpp
@@ -0,0 +1,77 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iree/compiler/Conversion/LinalgToLLVM/Passes.h"
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+namespace {
+
+struct PlanConvLoopOrderPass
+ : PassWrapper<PlanConvLoopOrderPass, FunctionPass> {
+ void getDependentDialects(DialectRegistry ®istry) const override {
+ registry.insert<linalg::LinalgDialect>();
+ }
+ void runOnFunction() override;
+};
+
+} // namespace
+
+void PlanConvLoopOrderPass::runOnFunction() {
+ auto funcOp = getOperation();
+ auto context = funcOp.getContext();
+
+ auto marker = Identifier::get("generalized_from_conv", context);
+ linalg::LinalgMarker firstStepMarker(
+ /*matchDisjunction=*/ArrayRef<Identifier>(),
+ /*replacement=*/marker);
+ linalg::LinalgMarker secondStepMarker(
+ /*matchDisjunction=*/marker,
+ /*replacement=*/llvm::None);
+
+ SmallVector<unsigned, 8> loopOrder = {
+ /*batch=*/0,
+ /*output_height=*/1,
+ /*output_width=*/2,
+ /*filter_height=*/5,
+ /*filter_width=*/6,
+ /*input_channel=*/4,
+ /*output_channel=*/3,
+ };
+
+ OwningRewritePatternList patterns;
+ linalg::populateLinalgConvGeneralizationPatterns(context, patterns,
+ firstStepMarker);
+ patterns.insert<linalg::LinalgInterchangePattern<linalg::GenericOp>>(
+ context, loopOrder, secondStepMarker);
+
+ applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
+}
+
+std::unique_ptr<FunctionPass> createPlanConvLoopOrderPass() {
+ return std::make_unique<PlanConvLoopOrderPass>();
+}
+
+static PassRegistration<PlanConvLoopOrderPass> pass(
+ "iree-codegen-linalg-to-llvm-plan-conv-loop-order",
+ "Convert linalg.conv to linalg.generic with a CPU-friendly iterator order",
+ [] { return std::make_unique<PlanConvLoopOrderPass>(); });
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/iree/compiler/Conversion/LinalgToLLVM/test/plan-conv-loop-order.mlir b/iree/compiler/Conversion/LinalgToLLVM/test/plan-conv-loop-order.mlir
new file mode 100644
index 0000000..0cc8ace
--- /dev/null
+++ b/iree/compiler/Conversion/LinalgToLLVM/test/plan-conv-loop-order.mlir
@@ -0,0 +1,15 @@
+// RUN: iree-opt -iree-codegen-linalg-to-llvm-plan-conv-loop-order %s | IreeFileCheck %s
+
+func @conv(%filter: memref<3x3x3x32xf32>, %input: memref<1x225x225x3xf32>, %output: memref<1x112x112x32xf32>) {
+ linalg.conv(%filter, %input, %output) {dilations = [1, 1], strides = [2, 2]} : memref<3x3x3x32xf32>, memref<1x225x225x3xf32>, memref<1x112x112x32xf32>
+ return
+}
+
+// CHECK: #[[FILTER_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d3, d4, d5, d6)>
+// CHECK: #[[INPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 * 2 + d3, d2 * 2 + d4, d5)>
+// CHECK: #[[OUTPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d6)>
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[FILTER_MAP]], #[[INPUT_MAP]], #[[OUTPUT_MAP]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "window", "window", "reduction", "parallel"]
+
+
diff --git a/iree/compiler/Conversion/LinalgToLLVM/test/tile_and_distribute.mlir b/iree/compiler/Conversion/LinalgToLLVM/test/tile_and_distribute.mlir
index baffba1..35f816b 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/test/tile_and_distribute.mlir
+++ b/iree/compiler/Conversion/LinalgToLLVM/test/tile_and_distribute.mlir
@@ -1,35 +1,35 @@
-// RUN: iree-opt --iree-codegen-llvm-linalg-tile-and-distribute=tile-sizes=2,4,1 -cse -split-input-file %s | IreeFileCheck %s
+// RUN: iree-opt --iree-codegen-llvm-linalg-tile-and-distribute -iree-llvm-tile-size=2,4,1 -cse -split-input-file %s | IreeFileCheck %s
func @dynamic_matmul(%lhs: memref<?x?xf32>, %rhs: memref<?x?xf32>, %result: memref<?x?xf32>) {
linalg.matmul ins(%lhs, %rhs : memref<?x?xf32>, memref<?x?xf32>) outs(%result : memref<?x?xf32>)
return
}
-// CHECK: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 * 2)>
-// CHECK: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (2, s1 - s0 * 2)>
-// CHECK: #[[MAP2:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
-// CHECK: #[[MAP3:.+]] = affine_map<()[s0] -> (s0 * 4)>
-// CHECK: #[[MAP4:.+]] = affine_map<()[s0, s1] -> (4, s1 - s0 * 4)>
-// CHECK: func @dynamic_matmul(%[[LHS:.+]]: memref<?x?xf32>, %[[RHS:.+]]: memref<?x?xf32>, %[[RESULT:.+]]: memref<?x?xf32>)
-// CHECK: %[[CONST_0:.+]] = constant 0 : index
-// CHECK: %[[CONST_1:.+]] = constant 1 : index
-// CHECK: %[[DIM_K:.+]] = dim %[[LHS]], %[[CONST_1]]
-// CHECK: %[[THREAD_X_ID:.+]] = iree.workgroup_id {dimension = "x"} : index
-// CHECK: %[[THREAD_Y_ID:.+]] = iree.workgroup_id {dimension = "y"} : index
-// CHECK: scf.for %[[K:.+]] = %[[CONST_0]] to %[[DIM_K]]
-// CHECK: %[[I:.+]] = affine.apply #[[MAP0]]()[%[[THREAD_Y_ID]]]
-// CHECK: %[[DIM_I:.+]] = dim %[[LHS]], %[[CONST_0]]
-// CHECK: %[[I_OFFSET:.+]] = affine.min #[[MAP1]]()[%[[THREAD_Y_ID]], %[[DIM_I]]]
-// CHECK: %[[LHS_SUBVIEW:.+]] = subview %[[LHS]][%[[I]], %[[K]]] [%[[I_OFFSET]], 1] [1, 1]
-// CHECK: %[[J:.+]] = affine.apply #[[MAP3]]()[%[[THREAD_X_ID]]]
-// CHECK: %[[DIM_J:.+]] = dim %[[RHS]], %[[CONST_1]]
-// CHECK: %[[J_OFFSET:.+]] = affine.min #[[MAP4]]()[%[[THREAD_X_ID]], %[[DIM_J]]]
-// CHECK: %[[RHS_SUBVIEW:.+]] = subview %[[RHS]][%[[K]], %[[J]]] [1, %[[J_OFFSET]]] [1, 1]
-// CHECK: %[[DIM_I:.+]] = dim %[[RESULT]], %[[CONST_0]]
-// CHECK: %[[DIM_I_OFFSET:.+]] = affine.min #[[MAP1]]()[%[[THREAD_Y_ID]], %[[DIM_I]]]
-// CHECK: %[[DIM_J:.+]] = dim %[[RESULT]], %[[CONST_1]]
-// CHECK: %[[DIM_J_OFFSET:.+]] = affine.min #[[MAP4]]()[%[[THREAD_X_ID]], %[[DIM_J]]]
-// CHECK: %[[RESULT_SUBVIEW:.+]] = subview %[[RESULT]][%[[I]], %[[J]]] [%[[DIM_I_OFFSET]], %[[DIM_J_OFFSET]]] [1, 1]
-// CHECK: linalg.matmul {__internal_linalg_transform__ = "workgroup"} ins(%[[LHS_SUBVIEW]], %[[RHS_SUBVIEW]] : memref<?x1xf32, #[[MAP2]]>, memref<1x?xf32, #[[MAP2]]>) outs(%[[RESULT_SUBVIEW]] : memref<?x?xf32, #[[MAP2]]>)
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 * 2)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (2, s1 - s0 * 2)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<()[s0] -> (s0 * 4)>
+// CHECK-DAG: #[[MAP4:.+]] = affine_map<()[s0, s1] -> (4, s1 - s0 * 4)>
+// CHECK: func @dynamic_matmul(%[[LHS:.+]]: memref<?x?xf32>, %[[RHS:.+]]: memref<?x?xf32>, %[[RESULT:.+]]: memref<?x?xf32>)
+// CHECK-DAG: %[[CONST_0:.+]] = constant 0 : index
+// CHECK-DAG: %[[CONST_1:.+]] = constant 1 : index
+// CHECK-DAG: %[[DIM_K:.+]] = dim %[[LHS]], %[[CONST_1]]
+// CHECK-DAG: %[[THREAD_X_ID:.+]] = iree.workgroup_id {dimension = "x"} : index
+// CHECK-DAG: %[[THREAD_Y_ID:.+]] = iree.workgroup_id {dimension = "y"} : index
+// CHECK: scf.for %[[K:.+]] = %[[CONST_0]] to %[[DIM_K]]
+// CHECK: %[[I:.+]] = affine.apply #[[MAP0]]()[%[[THREAD_Y_ID]]]
+// CHECK: %[[DIM_I:.+]] = dim %[[LHS]], %[[CONST_0]]
+// CHECK: %[[I_OFFSET:.+]] = affine.min #[[MAP1]]()[%[[THREAD_Y_ID]], %[[DIM_I]]]
+// CHECK: %[[LHS_SUBVIEW:.+]] = subview %[[LHS]][%[[I]], %[[K]]] [%[[I_OFFSET]], 1] [1, 1]
+// CHECK: %[[J:.+]] = affine.apply #[[MAP3]]()[%[[THREAD_X_ID]]]
+// CHECK: %[[DIM_J:.+]] = dim %[[RHS]], %[[CONST_1]]
+// CHECK: %[[J_OFFSET:.+]] = affine.min #[[MAP4]]()[%[[THREAD_X_ID]], %[[DIM_J]]]
+// CHECK: %[[RHS_SUBVIEW:.+]] = subview %[[RHS]][%[[K]], %[[J]]] [1, %[[J_OFFSET]]] [1, 1]
+// CHECK: %[[DIM_I:.+]] = dim %[[RESULT]], %[[CONST_0]]
+// CHECK: %[[DIM_I_OFFSET:.+]] = affine.min #[[MAP1]]()[%[[THREAD_Y_ID]], %[[DIM_I]]]
+// CHECK: %[[DIM_J:.+]] = dim %[[RESULT]], %[[CONST_1]]
+// CHECK: %[[DIM_J_OFFSET:.+]] = affine.min #[[MAP4]]()[%[[THREAD_X_ID]], %[[DIM_J]]]
+// CHECK: %[[RESULT_SUBVIEW:.+]] = subview %[[RESULT]][%[[I]], %[[J]]] [%[[DIM_I_OFFSET]], %[[DIM_J_OFFSET]]] [1, 1]
+// CHECK: linalg.matmul {__internal_linalg_transform__ = "workgroup"} ins(%[[LHS_SUBVIEW]], %[[RHS_SUBVIEW]] : memref<?x1xf32, #[[MAP2]]>, memref<1x?xf32, #[[MAP2]]>) outs(%[[RESULT_SUBVIEW]] : memref<?x?xf32, #[[MAP2]]>)
// -----
@@ -37,20 +37,20 @@
linalg.matmul ins(%lhs, %rhs : memref<16x4xf32>, memref<4x8xf32>) outs(%result : memref<16x8xf32>)
return
}
-// CHECK: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 * 2)>
-// CHECK: #[[MAP1:.+]] = affine_map<(d0, d1)[s0] -> (d0 * 4 + s0 + d1)>
-// CHECK: #[[MAP2:.+]] = affine_map<()[s0] -> (s0 * 4)>
-// CHECK: #[[MAP3:.+]] = affine_map<(d0, d1)[s0] -> (d0 * 8 + s0 + d1)>
-// CHECK: func @static_matmul(%[[LHS:.+]]: memref<16x4xf32>, %[[RHS:.+]]: memref<4x8xf32>, %[[RESULT:.+]]: memref<16x8xf32>)
-// CHECK: %[[CONST_0:.+]] = constant 0 : index
-// CHECK: %[[CONST_4:.+]] = constant 4 : index
-// CHECK: %[[CONST_1:.+]] = constant 1 : index
-// CHECK: %[[THREAD_X_ID:.+]] = iree.workgroup_id {dimension = "x"} : index
-// CHECK: %[[THREAD_Y_ID:.+]] = iree.workgroup_id {dimension = "y"} : index
-// CHECK: scf.for %[[K:.+]] = %[[CONST_0]] to %[[CONST_4]] step %[[CONST_1]]
-// CHECK: %[[I:.+]] = affine.apply #[[MAP0]]()[%[[THREAD_Y_ID]]]
-// CHECK: %[[LHS_SUBVIEW:.+]] = subview %[[LHS]][%[[I]], %[[K]]] [2, 1] [1, 1] : memref<16x4xf32> to memref<2x1xf32, #[[MAP1]]>
-// CHECK: %[[J:.+]] = affine.apply #[[MAP2]]()[%[[THREAD_X_ID]]]
-// CHECK: %[[RHS_SUBVIEW:.+]] = subview %[[RHS]][%[[K]], %[[J]]] [1, 4] [1, 1] : memref<4x8xf32> to memref<1x4xf32, #[[MAP3]]>
-// CHECK: %[[RESULT_SUBVIEW:.+]] = subview %[[RESULT]][%[[I]], %[[J]]] [2, 4] [1, 1] : memref<16x8xf32> to memref<2x4xf32, #[[MAP3]]>
-// CHECK: linalg.matmul {__internal_linalg_transform__ = "workgroup"} ins(%[[LHS_SUBVIEW]], %[[RHS_SUBVIEW]] : memref<2x1xf32, #[[MAP1]]>, memref<1x4xf32, #[[MAP3]]>) outs(%6 : memref<2x4xf32, #[[MAP3]]>)
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 * 2)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1)[s0] -> (d0 * 4 + s0 + d1)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<()[s0] -> (s0 * 4)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1)[s0] -> (d0 * 8 + s0 + d1)>
+// CHECK: func @static_matmul(%[[LHS:.+]]: memref<16x4xf32>, %[[RHS:.+]]: memref<4x8xf32>, %[[RESULT:.+]]: memref<16x8xf32>)
+// CHECK-DAG: %[[CONST_0:.+]] = constant 0 : index
+// CHECK-DAG: %[[CONST_4:.+]] = constant 4 : index
+// CHECK-DAG: %[[CONST_1:.+]] = constant 1 : index
+// CHECK-DAG: %[[THREAD_X_ID:.+]] = iree.workgroup_id {dimension = "x"} : index
+// CHECK-DAG: %[[THREAD_Y_ID:.+]] = iree.workgroup_id {dimension = "y"} : index
+// CHECK: scf.for %[[K:.+]] = %[[CONST_0]] to %[[CONST_4]] step %[[CONST_1]]
+// CHECK: %[[I:.+]] = affine.apply #[[MAP0]]()[%[[THREAD_Y_ID]]]
+// CHECK: %[[LHS_SUBVIEW:.+]] = subview %[[LHS]][%[[I]], %[[K]]] [2, 1] [1, 1] : memref<16x4xf32> to memref<2x1xf32, #[[MAP1]]>
+// CHECK: %[[J:.+]] = affine.apply #[[MAP2]]()[%[[THREAD_X_ID]]]
+// CHECK: %[[RHS_SUBVIEW:.+]] = subview %[[RHS]][%[[K]], %[[J]]] [1, 4] [1, 1] : memref<4x8xf32> to memref<1x4xf32, #[[MAP3]]>
+// CHECK: %[[RESULT_SUBVIEW:.+]] = subview %[[RESULT]][%[[I]], %[[J]]] [2, 4] [1, 1] : memref<16x8xf32> to memref<2x4xf32, #[[MAP3]]>
+// CHECK: linalg.matmul {__internal_linalg_transform__ = "workgroup"} ins(%[[LHS_SUBVIEW]], %[[RHS_SUBVIEW]] : memref<2x1xf32, #[[MAP1]]>, memref<1x4xf32, #[[MAP3]]>) outs(%6 : memref<2x4xf32, #[[MAP3]]>)
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/ConvertToGPUPass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/ConvertToGPUPass.cpp
index d281d79..8771e7f 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/ConvertToGPUPass.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/ConvertToGPUPass.cpp
@@ -754,6 +754,8 @@
MapLinalgOpToLocalInvocationId<linalg::ConvOp>,
MapLinalgOpToLocalInvocationId<linalg::CopyOp>,
MapLinalgOpToLocalInvocationId<linalg::FillOp>,
+ MapLinalgOpToLocalInvocationId<linalg::GenericOp>,
+ MapLinalgOpToLocalInvocationId<linalg::IndexedGenericOp>,
MapLinalgOpToLocalInvocationId<linalg::MatmulOp>,
MapLinalgOpToLocalInvocationId<linalg::BatchMatmulOp>,
MapLinalgOpToLocalInvocationId<linalg::PoolingMaxOp>,
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp b/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp
index 7f0198a..4178fcc 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp
@@ -25,6 +25,7 @@
#include "iree/compiler/Conversion/CodegenUtils/FunctionUtils.h"
#include "iree/compiler/Conversion/Common/Attributes.h"
+#include "iree/compiler/Conversion/Common/LaunchConfig.h"
#include "iree/compiler/Conversion/LinalgToSPIRV/Passes.h"
#include "iree/compiler/Conversion/LinalgToSPIRV/Utils.h"
#include "iree/compiler/Dialect/IREE/IR/IREEOps.h"
@@ -49,10 +50,6 @@
namespace mlir {
namespace iree_compiler {
-/// Name of the StrAttr that can be used to get the key to access the tile size
-/// information.
-static const char kLaunchInfoKey[] = "launch_info_key";
-
/// Given `nprocs` try to distribute it evenly across 2 logical x and y.
static std::tuple<int64_t, int64_t> distributeProcs2D(int64_t nprocs) {
int64_t nprocs_x = std::max<int64_t>(
@@ -61,6 +58,13 @@
return std::make_tuple(nprocs_x, nprocs / nprocs_x);
}
+/// Returns the minimum of `shape` and `tileSize` if `shape` is static;
+/// otherwise returns `tileSize`.
+static int64_t getMinIfShapeStatic(int64_t shape, int64_t tileSize) {
+ if (shape == ShapedType::kDynamicSize) return tileSize;
+ return std::min(shape, tileSize);
+}
+
namespace {
struct LaunchConfigInfo {
std::array<int64_t, 3> workgroupSize = {1, 1, 1};
@@ -83,6 +87,51 @@
return op.emitError("undefined launch config for tiled operation");
}
+static void getMaliBestMatMulTileSizes(Type elementType,
+ SmallVectorImpl<int64_t> &tileSizes) {
+ if (elementType.isF16()) {
+ tileSizes.append({16, 64, 8});
+ } else {
+ tileSizes.append({8, 64, 4});
+ }
+}
+
+/// Launch configuration for linalg.batch_matmul on Mali GPUs.
+static LogicalResult getMaliSpecificConfig(
+ linalg::BatchMatmulOp op, const spirv::TargetEnv &targetEnv,
+ const SPIRVCodegenOptions &options, TileSizesListType &tileSizes,
+ std::array<int64_t, 3> &workgroupSize,
+ std::array<int64_t, 3> &numSubgroups) {
+ if (targetEnv.getVendorID() != spirv::Vendor::ARM) return failure();
+
+ auto lhsType = op.inputs()[0].getType().cast<MemRefType>();
+ auto rhsType = op.inputs()[1].getType().cast<MemRefType>();
+ assert(lhsType.getElementType() == rhsType.getElementType());
+ // Pick ideal tile size based on the type.
+ SmallVector<int64_t, 4> workgroupLevelTs(1, 1);
+ getMaliBestMatMulTileSizes(lhsType.getElementType(), workgroupLevelTs);
+  // Fall back to the non-vectorized path for cases we don't handle.
+ if (!lhsType.hasStaticShape() || !rhsType.hasStaticShape() ||
+ lhsType.getDimSize(1) % workgroupLevelTs[1] != 0 ||
+ rhsType.getDimSize(2) % workgroupLevelTs[2] != 0 ||
+ lhsType.getDimSize(2) % workgroupLevelTs[3] != 0) {
+ return failure();
+ }
+
+ workgroupSize[0] = targetEnv.getResourceLimits().subgroup_size().getInt();
+ workgroupSize[1] = 1;
+ workgroupSize[2] = 1;
+ tileSizes.emplace_back(workgroupLevelTs);
+  // No tiling at the subgroup level since this target doesn't use subgroup ops
+  // or shared memory.
+ tileSizes.emplace_back();
+ SmallVector<int64_t, 4> invocationLevelTs = {
+ workgroupLevelTs[0], workgroupLevelTs[1],
+ workgroupLevelTs[2] / workgroupSize[0], workgroupLevelTs[3]};
+ tileSizes.emplace_back(invocationLevelTs);
+ return success();
+}
+
/// Launch config for `linalg.batchmatmul`.
template <>
LogicalResult getOpLaunchConfig(linalg::BatchMatmulOp op,
@@ -90,6 +139,13 @@
const SPIRVCodegenOptions &options,
TileSizesListType &tileSizes,
LaunchConfigInfo &config) {
+ if (options.enableVectorization &&
+ succeeded(getMaliSpecificConfig(op, targetEnv, options, tileSizes,
+ config.workgroupSize,
+ config.numSubgroups))) {
+ config.vectorize = true;
+ return success();
+ }
unsigned maxWorkgroupSize = targetEnv.getResourceLimits()
.max_compute_workgroup_invocations()
.getInt();
@@ -220,16 +276,12 @@
assert(lhsType.getElementType() == rhsType.getElementType());
// Pick ideal tile size based on the type.
SmallVector<int64_t, 4> workgroupLevelTs;
- if (lhsType.getElementType().isF16()) {
- workgroupLevelTs.append({16, 64, 8});
- } else {
- workgroupLevelTs.append({8, 64, 4});
- }
+ getMaliBestMatMulTileSizes(lhsType.getElementType(), workgroupLevelTs);
// Fall back to the none vectorize path for cases we don't handle.
if (!lhsType.hasStaticShape() || !rhsType.hasStaticShape() ||
lhsType.getDimSize(0) % workgroupLevelTs[0] != 0 ||
- rhsType.getDimSize(0) % workgroupLevelTs[1] != 0 ||
+ rhsType.getDimSize(1) % workgroupLevelTs[1] != 0 ||
lhsType.getDimSize(1) % workgroupLevelTs[2] != 0) {
return failure();
}
@@ -283,19 +335,74 @@
tileSizeK = 32;
}
assert(tileSizes.empty());
- SmallVector<int64_t, 4> ts = {nRowsPerWorkitem * config.workgroupSize[1],
- nColsPerWorkitem * config.workgroupSize[0],
- tileSizeK};
+ int64_t M = op.inputs()[0].getType().cast<ShapedType>().getShape()[0];
+ int64_t N = op.inputs()[1].getType().cast<ShapedType>().getShape()[1];
+ int64_t K = op.inputs()[0].getType().cast<ShapedType>().getShape()[1];
+ SmallVector<int64_t, 4> ts = {
+ getMinIfShapeStatic(M, nRowsPerWorkitem * config.workgroupSize[1]),
+ getMinIfShapeStatic(N, nColsPerWorkitem * config.workgroupSize[0]),
+ getMinIfShapeStatic(K, tileSizeK)};
tileSizes.emplace_back(std::move(ts));
return success();
}
+static LogicalResult getMaliSpecificConfig(linalg::ConvOp op,
+ TileSizesListType &tileSizes,
+ LaunchConfigInfo &config) {
+ auto inputType = op.getInput(1).getType().cast<MemRefType>();
+ auto outputType = op.getOutputBufferType(0).cast<MemRefType>();
+ if (!inputType.hasStaticShape() || !outputType.hasStaticShape())
+ return failure();
+
+ const int tileWidth = 8;
+ const int tileChannel = 32;
+
+ auto outputShape = outputType.getShape();
+ bool isInputTilable = inputType.getDimSize(3) % 4 == 0;
+ bool isOutputTilable = outputShape[0] == 1 &&
+ outputShape[2] % tileWidth == 0 &&
+ outputShape[3] % tileChannel == 0;
+ if (!isInputTilable || !isOutputTilable) return failure();
+
+ config.workgroupSize = {8, 2, 1};
+
+ SmallVector<int64_t, 4> workgroupLevel = {/*batch=*/0, /*output_height=*/1,
+ /*output_width=*/tileWidth,
+ /*output_channel=*/tileChannel};
+ tileSizes.emplace_back(std::move(workgroupLevel));
+
+  // No tiling at the subgroup level given that we don't use subgroup-level
+  // synchronization or shared memory.
+ tileSizes.emplace_back();
+
+ SmallVector<int64_t, 4> invocationLevel = {
+ /*batch=*/0, /*output_height=*/1,
+ /*output_width=*/tileWidth / config.workgroupSize[1],
+ /*output_channel=*/tileChannel / config.workgroupSize[0]};
+ tileSizes.emplace_back(invocationLevel);
+
+  // Finally, for each invocation, tile again to generate loops over the
+  // filter's height (step 1), width (step 1), and input channel (step 4)
+  // dimensions.
+ SmallVector<int64_t, 4> fourthLevel = {0, 0, 0, 0, 4, 1, 1};
+ tileSizes.emplace_back(fourthLevel);
+
+ config.vectorize = true;
+
+ return success();
+}
+
template <>
LogicalResult getOpLaunchConfig(linalg::ConvOp op,
const spirv::TargetEnv &targetEnv,
const SPIRVCodegenOptions &options,
TileSizesListType &tileSizes,
LaunchConfigInfo &config) {
+ if (targetEnv.getVendorID() == spirv::Vendor::ARM &&
+ succeeded(getMaliSpecificConfig(op, tileSizes, config))) {
+ return success();
+ }
+
unsigned maxWorkgroupSize = targetEnv.getResourceLimits()
.max_compute_workgroup_invocations()
.getInt();
@@ -346,55 +453,43 @@
#undef DEFINE_POOLINGOP_CONFIG
-Optional<StringRef> LaunchConfig::getKey(Operation *op) const {
- StringAttr attr = op->getAttrOfType<StringAttr>(kLaunchInfoKey);
- if (!attr) return {};
- return attr.getValue();
-}
-
-LogicalResult LaunchConfig::init(
+Optional<LaunchConfig> initGPULaunchConfig(
MLIRContext *context, const linalg::LinalgDependenceGraph &dependenceGraph,
const SPIRVCodegenOptions &options, ArrayRef<linalg::LinalgOp> linalgOps) {
- unsigned numTiledOps = 0;
- auto setKey = [&](Operation *op) -> std::string {
- std::string key = llvm::formatv("__op_num_{0}__", numTiledOps++).str();
- op->setAttr(Identifier::get(kLaunchInfoKey, context),
- StringAttr::get(key, context));
- return key;
- };
-
+ LaunchConfig launchConfig;
if (!options.workgroupSize.empty()) {
- for (linalg::LinalgOp linalgOp : linalgOps)
- tileSizes[setKey(linalgOp)].emplace_back(options.tileSizes.begin(),
- options.tileSizes.end());
- workgroupSize = {1, 1, 1};
- for (unsigned i = 0,
- e = std::min<unsigned>(3, options.workgroupSize.size());
- i != e; ++i)
- workgroupSize[i] = options.workgroupSize[i];
- return success();
+ SmallVector<int64_t, 3> tileSizes(options.tileSizes.begin(),
+ options.tileSizes.end());
+ for (linalg::LinalgOp linalgOp : linalgOps) {
+ launchConfig.setTileSizes(linalgOp.getOperation(), tileSizes, 0);
+ }
+ SmallVector<int64_t, 3> workgroupSize(options.workgroupSize.begin(),
+ options.workgroupSize.end());
+ launchConfig.setWorkgroupSize(workgroupSize);
+ return launchConfig;
}
- if (linalgOps.empty()) return success();
+ if (linalgOps.empty()) return launchConfig;
spirv::TargetEnv targetEnv(spirv::lookupTargetEnv(*linalgOps.begin()));
Optional<linalg::LinalgOp> rootOperation = {};
LaunchConfigInfo config;
for (linalg::LinalgOp linalgOp : linalgOps) {
-#define DISPATCH(opName) \
- if (auto op = dyn_cast<opName>(linalgOp.getOperation())) { \
- if (rootOperation) { \
- return op.emitError( \
- "unhandled multiple root operations in dispatch region"); \
- } \
- rootOperation = linalgOp; \
- TileSizesListType &tileSizesInfo = tileSizes[setKey(*rootOperation)]; \
- if (failed(getOpLaunchConfig(op, targetEnv, options, tileSizesInfo, \
- config))) { \
- return failure(); \
- } \
- continue; \
+#define DISPATCH(opName) \
+ if (auto op = dyn_cast<opName>(linalgOp.getOperation())) { \
+ if (rootOperation) { \
+ op.emitError("unhandled multiple root operations in dispatch region"); \
+ return llvm::None; \
+ } \
+ rootOperation = linalgOp; \
+ TileSizesListType tileSizesInfo; \
+ if (failed(getOpLaunchConfig(op, targetEnv, options, tileSizesInfo, \
+ config))) { \
+ return llvm::None; \
+ } \
+ launchConfig.setTileSizes(op, tileSizesInfo); \
+ continue; \
}
DISPATCH(linalg::BatchMatmulOp)
@@ -406,73 +501,23 @@
#undef DISPATCH
}
- workgroupSize = config.workgroupSize;
- numSubgroups = config.numSubgroups;
- vectorize = config.vectorize;
+
+ launchConfig.setWorkgroupSize(config.workgroupSize);
+ launchConfig.setNumSubgroups(config.numSubgroups);
+ launchConfig.setVectorize(config.vectorize);
+
if (!rootOperation) {
// No root operations found. Dont need to do anything.
- return success();
+ return launchConfig;
}
- // Check the dependencies going into and out of the root operation. For now
- // only the following dependencies are supported
- // - WAW dependencies going into the root operation.
- // - RAW dependencies going out of the root operation.
- // - WAW dependencies going out of the root operation.
- // i.e. there are no RAW dependences going into the root operation.
- auto inRAWDependencies = dependenceGraph.getDependencesInto(
- *rootOperation, linalg::LinalgDependenceGraph::RAW);
- if (!inRAWDependencies.empty()) {
- return rootOperation->getOperation()->emitError(
- "unhandled fusion of root operation with producer");
- }
- auto dependences =
- dependenceGraph.getDependentOperations(rootOperation.getValue());
- unsigned numOuterParallel = getNumOuterParallelLoops(*rootOperation);
-
- // Check that for all dependences into and out of the root operation,
- // - The result expressions of the indexing maps of the fused view in the
- // producer and consumer must match for the parallel loops.
- for (auto dependence :
- dependenceGraph.getDependentOperations(rootOperation.getValue())) {
- unsigned viewIndex = dependence.indexingOpView.operandIndex;
- AffineMap indexingMap = rootOperation->getIndexingMap(viewIndex);
- linalg::LinalgOp fusedOp =
- cast<linalg::LinalgOp>(dependence.dependentOpView.op);
- unsigned fusedViewIndex = dependence.dependentOpView.operandIndex;
- AffineMap fusedIndexingMap = fusedOp.getIndexingMap(fusedViewIndex);
- if (indexingMap.getNumResults() < numOuterParallel ||
- fusedIndexingMap.getNumResults() < numOuterParallel ||
- !llvm::all_of(
- llvm::seq<unsigned>(0, numOuterParallel),
- [&indexingMap, fusedIndexingMap](unsigned i) {
- return indexingMap.getResult(i).isa<AffineDimExpr>() &&
- fusedIndexingMap.getResult(i).isa<AffineDimExpr>() &&
- indexingMap.getResult(i) == fusedIndexingMap.getResult(i);
- })) {
- return rootOperation->getOperation()->emitError(
- "unhandled fusion of root operation with all operations in the "
- "dispatch region");
- }
- }
- // The dependent operations get the same tile size information as the root
- // operation. To propogate that information, just use the same key as the root
- // operation.
- for (auto dependence : dependences) {
- dependence.dependentOpView.op->setAttr(
- Identifier::get(kLaunchInfoKey, context),
- StringAttr::get(getKey(*rootOperation).getValue(), context));
- }
+ if (failed(propogateRootOperationLaunchConfig(launchConfig, *rootOperation,
+ dependenceGraph)))
+ return llvm::None;
// TODO(ravishankarm): Verify that the set configurations is within the device
// limits.
- return success();
-}
-
-void LaunchConfig::finalize(FuncOp funcOp) {
- funcOp.walk([&](linalg::LinalgOp linalgOp) {
- linalgOp.removeAttr(Identifier::get(kLaunchInfoKey, funcOp.getContext()));
- });
+ return launchConfig;
}
template <typename OpTy>
@@ -493,8 +538,14 @@
op.getAccType().cast<VectorType>().getElementType(),
op.getResultType().cast<VectorType>().getElementType());
} else {
+ unsigned lastParalleldim = 0;
+ for (auto it : llvm::enumerate(op.iterator_types())) {
+ if (isParallelIterator(it.value())) lastParalleldim = it.index();
+ }
+ SmallVector<int64_t, 4> nativeSize(op.iterator_types().size(), 1);
+ nativeSize[lastParalleldim] = 4;
// Map to vec4 fma operations.
- return SmallVector<int64_t, 4>({1, 4, 1});
+ return nativeSize;
}
}
@@ -508,10 +559,13 @@
// the contract.
return SmallVector<int64_t, 4>(op.getVectorType().getDimSize(0),
op.getVectorType().getDimSize(1));
- } else {
- // Map to load4.
- return SmallVector<int64_t, 4>({1, 4});
}
+
+ // Map to load4.
+ auto rank = op.getVectorType().getRank();
+ SmallVector<int64_t, 4> nativeSize(rank, 1);
+ nativeSize.back() = 4;
+ return nativeSize;
}
template <>
@@ -524,10 +578,13 @@
// the contract.
return SmallVector<int64_t, 4>(op.getVectorType().getDimSize(0),
op.getVectorType().getDimSize(1));
- } else {
- // Map to store4.
- return SmallVector<int64_t, 4>({1, 4});
}
+
+ // Map to store4.
+ auto rank = op.getVectorType().getRank();
+ SmallVector<int64_t, 4> nativeSize(rank, 1);
+ nativeSize.back() = 4;
+ return nativeSize;
}
Optional<SmallVector<int64_t, 4>> getNativeVectorSize(Operation *op) {
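Note: the transfer_read and transfer_write specializations above now share the same rank-generic shape — all ones, with the innermost dimension widened to 4. A standalone sketch of that mapping (plain C++; the helper name is hypothetical and std::vector stands in for the MLIR types):

  #include <cstdint>
  #include <vector>

  // Every dimension gets a native size of 1 except the innermost, which maps
  // to a 4-wide load/store; e.g. rank 3 -> {1, 1, 4}, rank 2 -> {1, 4}.
  std::vector<int64_t> innermostVec4(int64_t rank) {
    std::vector<int64_t> nativeSize(rank, 1);
    if (rank > 0) nativeSize.back() = 4;
    return nativeSize;
  }
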
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.h b/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.h
index 72ee6d2..8f01683 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.h
+++ b/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.h
@@ -19,11 +19,13 @@
// the number of workgroups to use for launch, etc.
//
//===----------------------------------------------------------------------===//
+
#ifndef IREE_COMPILER_CONVERSION_LINALGTOSPIRV_KERNELDISPATCHUTILS_H_
#define IREE_COMPILER_CONVERSION_LINALGTOSPIRV_KERNELDISPATCHUTILS_H_
#include <array>
+#include "iree/compiler/Conversion/Common/LaunchConfig.h"
#include "iree/compiler/Conversion/LinalgToSPIRV/CodeGenOptionUtils.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringMap.h"
@@ -41,94 +43,9 @@
namespace mlir {
namespace iree_compiler {
-/// Store the tile sizes to use at different levels of tiling as a vector of
-/// vectors.
-/// - First level tiling maps to workgroups.
-/// - Second level tiling maps to subgroups.
-using TileSizesListType = SmallVector<SmallVector<int64_t, 4>, 1>;
-
-/// Based on the linalg operations in a dispatch region, the number of levels of
-/// tiling, the tile sizes needed, the workgroup size, etc. need to be
-/// decided. These parameters are called `LaunchConfig`. This class implements
-/// one heuristic to compute these for the different linalg operations on
-/// buffers. This can be adapted later to support multiple configurations that
-/// can be picked based on device information/problem size information. It
-/// exposes the information needed by the codegenerators, and hides the
-/// implementation from the rest of the pipeline.
-class LaunchConfig {
- public:
- LaunchConfig() : workgroupSize({1, 1, 1}), numSubgroups({1, 1, 1}) {}
-
- /// Given the sequence of `linalgOps` (and `options`), decide the launch
- /// configuration by deciding
- /// - the number of levels of tiling,
- /// - tile sizes for each level,
- /// - the workgroup size, and
- /// - number of subgroups to use.
- LogicalResult init(MLIRContext *context,
- const linalg::LinalgDependenceGraph &dependenceGraph,
- const SPIRVCodegenOptions &options,
- ArrayRef<linalg::LinalgOp> linalgOps);
-
- /// Remove attributed added to operations for retrieving tile size
- /// information.
- void finalize(FuncOp funcOp);
-
- /// Gets the tile size computed for an operation at all levels.
- TileSizesListType getTileSizes(Operation *op) const {
- auto key = getKey(op);
- if (!key) return {};
- auto it = tileSizes.find(*key);
- return it->second;
- }
-
- /// Gets the tile size computed for an operation for an level.
- ArrayRef<int64_t> getTileSizes(Operation *op, size_t level) const {
- auto key = getKey(op);
- if (!key) return {};
- auto it = tileSizes.find(*key);
- if (it == tileSizes.end() || level >= it->second.size()) return {};
- return it->second[level];
- }
-
- /// Returns the workgroup size to use based on the tile sizes.
- ArrayRef<int64_t> getWorkgroupSize() const { return workgroupSize; }
-
- /// Returns the number of subgroups to use.
- ArrayRef<int64_t> getNumSubgroups() const { return numSubgroups; }
-
- /// Returns true if tile sizes have been computed for the operation. If tile
- /// sizes arent set, it implies operation is not to be tiled.
- bool hasTileSizes(Operation *op, size_t level = 0) const {
- return !getTileSizes(op, level).empty();
- }
-
- /// Use vectorize transformations.
- bool useVectorize() const { return vectorize; }
-
- protected:
- /// Current tile size configuration per operation. They key used here to
- /// retrieve the tile size information per operation is the value of a StrAttr
- /// added to operations during `init`. When tiled this attribute is copied
- /// over to the tiled operation, thereby the same key can be used to retrieve
- /// the tile sizes for the next level of tiling. The `finalize` method removes
- /// these attributes.
- llvm::StringMap<TileSizesListType> tileSizes;
-
- /// Workgroup size to use.
- std::array<int64_t, 3> workgroupSize;
-
- /// Number of subgroups that are logically distributed along x, y & z.
- std::array<int64_t, 3> numSubgroups;
-
- /// Use vectorization.
- bool vectorize = false;
-
- private:
- /// Retrieves the key to use to get the `tileSizes` for a given
- /// `operation`. Returns llvm::None on failure.
- Optional<StringRef> getKey(Operation *op) const;
-};
+Optional<LaunchConfig> initGPULaunchConfig(
+ MLIRContext *context, const linalg::LinalgDependenceGraph &dependenceGraph,
+ const SPIRVCodegenOptions &options, ArrayRef<linalg::LinalgOp> linalgOps);
/// Returns the size of instruction in `vector` dialect that maps directly to
/// the hardware.
@@ -136,4 +53,5 @@
} // namespace iree_compiler
} // namespace mlir
+
#endif // IREE_COMPILER_CONVERSION_LINALGTOSPIRV_DISPATCHUTILS_H_
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp
index 34c6632..5be1613 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp
@@ -21,13 +21,14 @@
#include "iree/compiler/Conversion/CodegenUtils/FunctionUtils.h"
#include "iree/compiler/Conversion/CodegenUtils/GetNumWorkgroups.h"
#include "iree/compiler/Conversion/CodegenUtils/MarkerUtils.h"
-#include "iree/compiler/Conversion/CodegenUtils/MatmulCodegenStrategy.h"
#include "iree/compiler/Conversion/Common/Attributes.h"
+#include "iree/compiler/Conversion/Common/Transforms.h"
#include "iree/compiler/Conversion/LinalgToSPIRV/CodeGenOptionUtils.h"
#include "iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.h"
#include "iree/compiler/Conversion/LinalgToSPIRV/MemorySpace.h"
#include "iree/compiler/Conversion/LinalgToSPIRV/Passes.h"
#include "iree/compiler/Conversion/LinalgToSPIRV/Utils.h"
+#include "iree/compiler/Conversion/LinalgToVector/Passes.h"
#include "iree/compiler/Dialect/Shape/IR/ShapeDialect.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
@@ -35,6 +36,8 @@
#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Vector/VectorTransforms.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Identifier.h"
#include "mlir/IR/Matchers.h"
@@ -73,6 +76,32 @@
return linalg::LinalgMarker(markers, Identifier::get(replaceMarker, context));
}
+/// Returns the distribution options for operations when targeting workgroups.
+static linalg::LinalgLoopDistributionOptions getWorkgroupDistributionOptions() {
+ linalg::LinalgLoopDistributionOptions options;
+
+ options.procInfo = [](OpBuilder &builder, Location loc,
+ ArrayRef<Range> parallelLoopRanges) {
+ return getGPUProcessorIdsAndCounts<gpu::BlockIdOp, gpu::GridDimOp>(
+ builder, loc, parallelLoopRanges.size());
+ };
+ options.distributionMethod.assign(
+ 3, linalg::DistributionMethod::CyclicNumProcsEqNumIters);
+
+ return options;
+}
+
+/// Applies canonicalization over index calculation inside the given `funcOp`.
+static void applyIndexCalculationCanonicalization(FuncOp funcOp) {
+ MLIRContext *context = funcOp.getContext();
+ OwningRewritePatternList canonicalizationPatterns;
+ DimOp::getCanonicalizationPatterns(canonicalizationPatterns, context);
+ AddIOp::getCanonicalizationPatterns(canonicalizationPatterns, context);
+ SubIOp::getCanonicalizationPatterns(canonicalizationPatterns, context);
+ SignedDivIOp::getCanonicalizationPatterns(canonicalizationPatterns, context);
+ applyPatternsAndFoldGreedily(funcOp, std::move(canonicalizationPatterns));
+}
+
//===----------------------------------------------------------------------===//
// Main pass
//===----------------------------------------------------------------------===//
@@ -100,154 +129,6 @@
} // namespace
//===----------------------------------------------------------------------===//
-// Patterns to tile computation to map to workgroups
-//===----------------------------------------------------------------------===//
-
-/// Returns the distribution options for operations when targeting workgroups.
-static linalg::LinalgLoopDistributionOptions getWorkgroupDistributionOptions() {
- linalg::LinalgLoopDistributionOptions options;
-
- options.procInfo = [](OpBuilder &builder, Location loc,
- ArrayRef<Range> parallelLoopRanges) {
- return getGPUProcessorIdsAndCounts<gpu::BlockIdOp, gpu::GridDimOp>(
- builder, loc, parallelLoopRanges.size());
- };
- options.distributionMethod = {
- linalg::DistributionMethod::CyclicNumProcsEqNumIters,
- linalg::DistributionMethod::CyclicNumProcsEqNumIters,
- linalg::DistributionMethod::CyclicNumProcsEqNumIters};
-
- return options;
-}
-
-namespace {
-/// Pattern for tiling operations. Updates the workgroup size in the surrounding
-/// function operation if tiling succeeds, and generates the function that
-/// computes the number of workgroups for the launch.
-template <typename LinalgOpTy>
-class TileToWorkgroupsPattern : public linalg::LinalgBaseTilingPattern {
- public:
- TileToWorkgroupsPattern(MLIRContext *context,
- const linalg::LinalgDependenceGraph &dependenceGraph,
- linalg::LinalgTilingOptions options,
- linalg::LinalgMarker marker,
- const LaunchConfig &launchConfig,
- PatternBenefit benefit = 1)
- : Base(LinalgOpTy::getOperationName(), context, options, marker, benefit),
- dependenceGraph(dependenceGraph),
- launchConfig(launchConfig) {}
-
- LogicalResult matchAndRewrite(Operation *op,
- PatternRewriter &rewriter) const override {
- // Find the parent FuncOp before tiling. If tiling succeeds, the op will be
- // erased.
- FuncOp funcOp = op->getParentOfType<FuncOp>();
- SmallVector<Value, 4> tensorResults;
- linalg::LinalgOp linalgOp = cast<linalg::LinalgOp>(op);
- if (!funcOp || dependenceGraph.hasDependentOperations(linalgOp) ||
- failed(Base::matchAndRewriteBase(op, rewriter, tensorResults)) ||
- !tensorResults.empty() ||
- failed(updateWorkGroupSize(funcOp, launchConfig.getWorkgroupSize())) ||
- (funcOp.getAttr(getNumWorkgroupsFnAttrName()) &&
- failed(createNumWorkgroupsFromResultShape(
- rewriter, linalgOp, funcOp, getNumWorkgroupsFnAttrName(),
- launchConfig.getTileSizes(op, 0))))) {
- return failure();
- }
- setMarker(op, getDeleteMarker());
- return success();
- }
-
- private:
- using Base = linalg::LinalgBaseTilingPattern;
-
- const linalg::LinalgDependenceGraph &dependenceGraph;
- const LaunchConfig &launchConfig;
-};
-
-/// Pattern for tile + fuse of operations. Updates the workgroup size in the
-/// surrounding function operation if tiling succeeds, and generates the
-/// function that computes the number of workgroups for the launch..
-template <typename LinalgOpTy>
-class TileAndFuseToWorkgroupsPattern
- : public linalg::LinalgTileAndFusePattern<LinalgOpTy> {
- public:
- TileAndFuseToWorkgroupsPattern(
- MLIRContext *context,
- const linalg::LinalgDependenceGraph &dependenceGraph,
- linalg::LinalgTilingOptions tilingOptions, linalg::LinalgMarker marker,
- const LaunchConfig &launchConfig, PatternBenefit benefit = 1)
- : Base(context, dependenceGraph, tilingOptions,
- linalg::LinalgFusionOptions().setIndicesToFuse({2}), marker,
- marker, getLinalgReplaceMarker(getDeleteMarker(), context),
- benefit),
- dependenceGraph(dependenceGraph),
- launchConfig(launchConfig) {}
-
- virtual LogicalResult matchAndRewrite(Operation *op,
- PatternRewriter &rewriter) const {
- FuncOp funcOp = op->getParentOfType<FuncOp>();
- linalg::LinalgOp linalgOp = cast<linalg::LinalgOp>(op);
- if (!funcOp || !dependenceGraph.hasDependentOperations(linalgOp) ||
- failed(Base::matchAndRewrite(op, rewriter)) ||
- failed(updateWorkGroupSize(funcOp, launchConfig.getWorkgroupSize())) ||
- (funcOp.getAttr(getNumWorkgroupsFnAttrName()) &&
- failed(createNumWorkgroupsFromResultShape(
- rewriter, linalgOp, funcOp, getNumWorkgroupsFnAttrName(),
- launchConfig.getTileSizes(op, 0))))) {
- return failure();
- }
- return success();
- }
-
- private:
- using Base = linalg::LinalgTileAndFusePattern<LinalgOpTy>;
-
- const linalg::LinalgDependenceGraph &dependenceGraph;
- const LaunchConfig &launchConfig;
-};
-} // namespace
-
-/// Populate patterns for first-level tiling.
-static void populateTilingToWorkgroupPatterns(
- MLIRContext *context, const linalg::LinalgDependenceGraph &dependenceGraph,
- const LaunchConfig &launchConfig, OwningRewritePatternList &patterns) {
- // Function to compute first level tiling values.
- auto getOuterTileSizeFn = [&launchConfig](
- OpBuilder &builder,
- Operation *operation) -> SmallVector<Value, 4> {
- ArrayRef<int64_t> tileSizes = launchConfig.getTileSizes(operation, 0);
- if (tileSizes.empty()) return {};
- SmallVector<Value, 4> tileSizesVal;
- tileSizesVal.reserve(tileSizes.size());
- for (auto val : tileSizes) {
- tileSizesVal.push_back(
- builder.create<ConstantIndexOp>(operation->getLoc(), val));
- }
- return tileSizesVal;
- };
-
- patterns.insert<TileAndFuseToWorkgroupsPattern<linalg::BatchMatmulOp>,
- TileAndFuseToWorkgroupsPattern<linalg::ConvOp>,
- TileAndFuseToWorkgroupsPattern<linalg::MatmulOp>,
- TileAndFuseToWorkgroupsPattern<linalg::PoolingMaxOp>,
- TileAndFuseToWorkgroupsPattern<linalg::PoolingMinOp>,
- TileAndFuseToWorkgroupsPattern<linalg::PoolingSumOp>,
- TileToWorkgroupsPattern<linalg::BatchMatmulOp>,
- TileToWorkgroupsPattern<linalg::ConvOp>,
- TileToWorkgroupsPattern<linalg::MatmulOp>,
- TileToWorkgroupsPattern<linalg::PoolingMaxOp>,
- TileToWorkgroupsPattern<linalg::PoolingMinOp>,
- TileToWorkgroupsPattern<linalg::PoolingSumOp>>(
- context, dependenceGraph,
- linalg::LinalgTilingOptions()
- .setDistributionOptions(getWorkgroupDistributionOptions())
- .setTileSizeComputationFunction(getOuterTileSizeFn)
- .setLoopType(linalg::LinalgTilingLoopType::ParallelLoops),
- getLinalgReplaceMarker(getWorkgroupMarker(), context), launchConfig);
-}
-
-//===----------------------------------------------------------------------===//
// Patterns to promote subviews to workgroup memory
//===----------------------------------------------------------------------===//
@@ -404,35 +285,36 @@
return tileSizesVal;
};
- auto getThreadProcInfoFn = [&launchConfig](
- OpBuilder &builder, Location loc,
- ArrayRef<Range> parallelLoopRanges) {
- Type indexType = builder.getIndexType();
- SmallVector<linalg::ProcInfo, 2> procInfo(2);
- procInfo[1] = {builder.create<gpu::ThreadIdOp>(loc, indexType,
- builder.getStringAttr("x")),
- builder.create<ConstantIndexOp>(
- loc, launchConfig.getWorkgroupSize()[0])};
- procInfo[0] = {builder.create<gpu::ThreadIdOp>(loc, indexType,
- builder.getStringAttr("y")),
- builder.create<ConstantIndexOp>(
- loc, launchConfig.getWorkgroupSize()[1])};
- return procInfo;
+ auto getThreadProcInfoFn = [](OpBuilder &builder, Location loc,
+ ArrayRef<Range> parallelLoopRanges) {
+ return getGPUProcessorIdsAndCounts<gpu::ThreadIdOp, gpu::BlockDimOp>(
+ builder, loc, parallelLoopRanges.size());
};
- linalg::LinalgLoopDistributionOptions subgroupDistributionOptions = {
+ linalg::LinalgLoopDistributionOptions invocationDistributionOptions = {
getThreadProcInfoFn,
{linalg::DistributionMethod::CyclicNumProcsEqNumIters,
+ linalg::DistributionMethod::CyclicNumProcsEqNumIters,
linalg::DistributionMethod::CyclicNumProcsEqNumIters}};
- patterns.insert<linalg::LinalgTilingPattern<linalg::MatmulOp>,
- linalg::LinalgTilingPattern<linalg::FillOp>>(
- context,
+
+ auto tilingOptions =
linalg::LinalgTilingOptions()
.setLoopType(linalg::LinalgTilingLoopType::ParallelLoops)
.setTileSizeComputationFunction(getInnerTileSizeFn)
- .setDistributionOptions(subgroupDistributionOptions),
+ .setDistributionOptions(invocationDistributionOptions);
+
+ patterns.insert<linalg::LinalgTilingPattern<linalg::MatmulOp>,
+ linalg::LinalgTilingPattern<linalg::FillOp>,
+ linalg::LinalgTilingPattern<linalg::BatchMatmulOp>>(
+ context, tilingOptions,
getLinalgMatchAndReplaceMarker(
{getWorkgroupMemoryMarker(), getWorkgroupMarker()},
getVectorizeMarker(), context));
+
+ patterns.insert<linalg::LinalgTilingPattern<linalg::ConvOp>>(
+ context, tilingOptions,
+ getLinalgMatchAndReplaceMarker(
+ {getWorkgroupMemoryMarker(), getWorkgroupMarker()},
+ getConvFilterTileMarker(), context));
}
//====---------------------------------------------------------------------===//
@@ -443,22 +325,12 @@
const LaunchConfig &launchConfig,
OwningRewritePatternList &patterns) {
patterns.insert<linalg::LinalgVectorizationPattern<linalg::MatmulOp>,
+ linalg::LinalgVectorizationPattern<linalg::BatchMatmulOp>,
linalg::LinalgVectorizationPattern<linalg::FillOp>>(
context,
linalg::LinalgMarker(Identifier::get(getVectorizeMarker(), context)));
}
-/// Apply canonicalizations related to tiling to make promotion/vectorization
-/// easier.
-static void applyCanonicalizationPatterns(MLIRContext *context, Operation *op) {
- OwningRewritePatternList canonicalizationPatterns;
- canonicalizationPatterns.insert<AffineMinCanonicalizationPattern>(context);
- AffineApplyOp::getCanonicalizationPatterns(canonicalizationPatterns, context);
- AffineMinOp::getCanonicalizationPatterns(canonicalizationPatterns, context);
- SubViewOp::getCanonicalizationPatterns(canonicalizationPatterns, context);
- applyPatternsAndFoldGreedily(op, std::move(canonicalizationPatterns));
-}
-
//====---------------------------------------------------------------------===//
// Patterns for unrolling vectors
//====---------------------------------------------------------------------===//
@@ -507,6 +379,46 @@
}
//====---------------------------------------------------------------------===//
+// Patterns to tile convolution window dimensions
+//====---------------------------------------------------------------------===//
+
+static void populateTilingConvFilterPatterns(MLIRContext *context,
+ OwningRewritePatternList &patterns,
+ const LaunchConfig &launchConfig,
+ linalg::LinalgMarker marker) {
+ auto getTileSizeFn = [&launchConfig](OpBuilder &builder, Operation *op) {
+ SmallVector<Value, 4> tileSizes;
+ ArrayRef<int64_t> fourthLevel = launchConfig.getTileSizes(op, 3);
+ tileSizes.reserve(fourthLevel.size());
+
+ Location loc = op->getLoc();
+ for (int64_t size : fourthLevel) {
+ tileSizes.push_back(builder.create<ConstantIndexOp>(loc, size));
+ }
+ return tileSizes;
+ };
+
+ // TODO(antiagainst): move this to launch configuration.
+ SmallVector<unsigned, 8> loopOrder = {
+ /*batch=*/0,
+ /*output_height=*/1,
+ /*output_width=*/2,
+ /*output_channel=*/3,
+ /*filter_height=*/5,
+ /*filter_width=*/6,
+ /*input_channel=*/4,
+ };
+
+ auto tilingOptions = linalg::LinalgTilingOptions()
+ .setLoopType(linalg::LinalgTilingLoopType::Loops)
+ .setInterchange(loopOrder)
+ .setTileSizeComputationFunction(getTileSizeFn);
+
+ patterns.insert<linalg::LinalgTilingPattern<linalg::ConvOp>>(
+ context, tilingOptions, marker);
+}
+
+//====---------------------------------------------------------------------===//
// Main pass implementation
//====---------------------------------------------------------------------===//
@@ -528,18 +440,19 @@
auto linalgOps = block.getOps<linalg::LinalgOp>();
if (linalgOps.empty()) continue;
- LaunchConfig launchConfig;
SmallVector<linalg::LinalgOp, 4> linalgOpsVec =
llvm::to_vector<4>(llvm::map_range(linalgOps, [](Operation *op) {
return cast<linalg::LinalgOp>(op);
}));
linalg::Aliases aliases;
linalg::LinalgDependenceGraph dependenceGraph(aliases, linalgOpsVec);
- if (failed(launchConfig.init(context, dependenceGraph, options,
- linalgOpsVec))) {
+ Optional<LaunchConfig> launchConfigOpt =
+ initGPULaunchConfig(context, dependenceGraph, options, linalgOpsVec);
+ if (!launchConfigOpt) {
funcOp.emitError("unable to find launch configuration");
return signalPassFailure();
}
+ LaunchConfig &launchConfig = *launchConfigOpt;
LLVM_DEBUG({
llvm::dbgs() << "@func " << funcOp.getName() << ": # workgroup sizes: [";
@@ -547,7 +460,7 @@
llvm::dbgs() << "]\n";
for (auto op : linalgOps) {
llvm::dbgs() << "\t" << op.getOperation()->getName() << " : ";
- TileSizesListType const &tileSizes = launchConfig.getTileSizes(op);
+ TileSizesListTypeRef tileSizes = launchConfig.getTileSizes(op);
llvm::dbgs() << "{";
std::string sep = "";
for (auto &level : enumerate(tileSizes)) {
@@ -560,26 +473,38 @@
}
});
- {
- OwningRewritePatternList firstLevelTilingPatterns;
- populateTilingToWorkgroupPatterns(context, dependenceGraph, launchConfig,
- firstLevelTilingPatterns);
- applyPatternsAndFoldGreedily(funcOp, std::move(firstLevelTilingPatterns));
- applyCanonicalizationPatterns(context, funcOp);
-
- // Delete the ops that are marked for deletion.
- funcOp.walk([](linalg::LinalgOp linalgOp) {
- if (hasMarker(linalgOp.getOperation(), getDeleteMarker()))
- linalgOp.getOperation()->erase();
- });
+ TileAndFuseOptions tileAndFuseOptions = {getWorkgroupDistributionOptions(),
+ allocateWorkgroupMemory};
+ if (failed(tileAndFuseLinalgBufferOps(funcOp, linalgOpsVec, dependenceGraph,
+ launchConfig, tileAndFuseOptions)) ||
+ failed(updateWorkGroupSize(funcOp, launchConfig.getWorkgroupSize()))) {
+ return signalPassFailure();
}
LLVM_DEBUG({
- llvm::dbgs() << "--- After First level of tile+distribute ---\n";
+ llvm::dbgs() << "--- After first level of tiling and distribution ---\n";
funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
llvm::dbgs() << "\n\n";
});
+ // In the above we distributed ops to workgroup dimensions and populated a
+ // function for calculating the number of workgroups. In the following steps,
+ // we will need to query the workgroup count function to simplify GPU
+ // processor ID uses. It relies on constant upper bounds. So we need to
+ // canonicalize the workgroup count function first.
+ if (funcOp.getAttrOfType<SymbolRefAttr>(getNumWorkgroupsFnAttrName())) {
+ FuncOp numWorkGroupFunc =
+ getNumWorkgroupsFn(funcOp, getNumWorkgroupsFnAttrName());
+ applyIndexCalculationCanonicalization(numWorkGroupFunc);
+
+ LLVM_DEBUG({
+ llvm::dbgs()
+ << "--- After canonicalizing workgroup count function ---\n";
+ numWorkGroupFunc.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
+ llvm::dbgs() << "\n\n";
+ });
+ }
+
if (options.useWorkgroupMemory) {
// The promotion patterns are put separate from the tiling patterns to
// make sure that the allocated scratchspace memory is constant sizes
@@ -587,7 +512,7 @@
OwningRewritePatternList promotionPatterns;
populatePromotionPatterns(context, promotionPatterns);
applyPatternsAndFoldGreedily(funcOp, std::move(promotionPatterns));
- applyCanonicalizationPatterns(context, funcOp);
+ applyCanonicalizationPatternsForTiling(context, funcOp);
LLVM_DEBUG({
llvm::dbgs() << "--- After Promotion ---\n";
@@ -603,7 +528,7 @@
secondLevelTilingPatterns);
applyPatternsAndFoldGreedily(funcOp,
std::move(secondLevelTilingPatterns));
- applyCanonicalizationPatterns(context, funcOp);
+ applyCanonicalizationPatternsForTiling(context, funcOp);
promoteSingleIterationLoops(funcOp);
LLVM_DEBUG({
@@ -619,7 +544,7 @@
thirdLevelTilingPatterns);
applyPatternsAndFoldGreedily(funcOp,
std::move(thirdLevelTilingPatterns));
- applyCanonicalizationPatterns(context, funcOp);
+ applyCanonicalizationPatternsForTiling(context, funcOp);
promoteSingleIterationLoops(funcOp);
LLVM_DEBUG({
@@ -630,9 +555,29 @@
}
{
+ OwningRewritePatternList tilingPatterns;
+ auto marker = getLinalgMatchAndReplaceMarker(
+ getConvFilterTileMarker(), getVectorizeMarker(), context);
+ populateTilingConvFilterPatterns(context, tilingPatterns, launchConfig,
+ marker);
+ populateFoldGPUProcessorIDUsesPatterns(context, tilingPatterns);
+ tilingPatterns.insert<linalg::AffineMinSCFCanonicalizationPattern>(
+ context);
+ applyPatternsAndFoldGreedily(funcOp, std::move(tilingPatterns));
+ applyCanonicalizationPatternsForTiling(context, funcOp);
+
+ LLVM_DEBUG({
+ llvm::dbgs() << "--- After tiling linalg.conv ---\n";
+ funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
+ llvm::dbgs() << "\n\n";
+ });
+ }
+
+ {
OwningRewritePatternList vectorizationPatterns;
populateVectorizationPatterns(context, launchConfig,
vectorizationPatterns);
+ populateVectorizeLinalgConvPatterns(context, vectorizationPatterns);
applyPatternsAndFoldGreedily(funcOp, std::move(vectorizationPatterns));
LLVM_DEBUG({
llvm::dbgs() << "--- After Vectorization ---\n";
@@ -645,11 +590,6 @@
}
launchConfig.finalize(funcOp);
- SmallVector<linalg::LinalgOp, 1> toDelete;
- funcOp.walk([&](linalg::LinalgOp linalgOp) {
- if (hasMarker(linalgOp.getOperation(), getDeleteMarker()))
- linalgOp.erase();
- });
}
}
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/Passes.cpp b/iree/compiler/Conversion/LinalgToSPIRV/Passes.cpp
index b014e19..25a20af 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/Passes.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/Passes.cpp
@@ -69,14 +69,15 @@
// - All Linalg ops have buffer semantics.
//
// Post-conditions:
+ // - If there are multiple linalg operations in the dispatch region, they
+ // are fused using a tile+fuse approach.
+ // - The fused loops are distributed across workgroups.
// - The operations that cannot be fused at buffer levels are split into
// separate entry points.
- // - If the input Linalg ops are tilable:
- // - loop.parallel ops are generated for mapping to workgroups.
- // - Linalg ops are nested inside loop.parallel ops and ready for mapping
- // to workitems.
- // - If multiple linalg operations are present they get tiled and fused to
- // get outer loop.parallel ops which can be mapped to workitems.
+ // - If there is a single linalg operation in the dispatch region, it is
+ // tiled and the generated parallel loops are distributed.
+ // - The tiled linalg operation can be tiled again one or more times and
+ // then vectorized.
// - Otherwise:
// - The Linalg op is kept untouched.
//
@@ -147,6 +148,9 @@
// - Load/store on std.subview ops are converted into load/store on the
// original buffers.
//===--------------------------------------------------------------------===//
+ if (options.enableVectorization) {
+ pm.addNestedPass<FuncOp>(createVectorTransferOptimizationPass());
+ }
pm.addPass(createLegalizeStdOpsForSPIRVLoweringPass());
pm.addPass(createCanonicalizerPass());
pm.addPass(createCSEPass());
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/SplitDispatchFunctionPass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/SplitDispatchFunctionPass.cpp
index 09ceb74..f195f10 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/SplitDispatchFunctionPass.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/SplitDispatchFunctionPass.cpp
@@ -38,6 +38,7 @@
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h"
@@ -84,12 +85,16 @@
linalg::LinalgDependenceGraph::DependenceType::WAW)
ADD_FUSABLE_PAIR(linalg::FillOp, linalg::MatmulOp,
linalg::LinalgDependenceGraph::DependenceType::WAW)
+ ADD_FUSABLE_PAIR(linalg::FillOp, linalg::BatchMatmulOp,
+ linalg::LinalgDependenceGraph::DependenceType::WAW)
ADD_FUSABLE_PAIR(linalg::FillOp, linalg::PoolingMaxOp,
linalg::LinalgDependenceGraph::DependenceType::WAW)
ADD_FUSABLE_PAIR(linalg::FillOp, linalg::PoolingMinOp,
linalg::LinalgDependenceGraph::DependenceType::WAW)
ADD_FUSABLE_PAIR(linalg::FillOp, linalg::PoolingSumOp,
linalg::LinalgDependenceGraph::DependenceType::WAW)
+ ADD_FUSABLE_PAIR(linalg::MatmulOp, linalg::GenericOp,
+ linalg::LinalgDependenceGraph::DependenceType::RAW)
#undef ADD_FUSABLE_PAIR
}
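For context, the two new entries extend the fusability table with fill→batch_matmul (the fill initializes the same output buffer, hence a WAW dependence) and matmul→generic (the generic op reads the matmul result, hence RAW). A rough sketch of such a table, assuming a simple set of (producer, consumer, dependence-kind) name triples — the actual macro and storage in this file may differ:

  #include <set>
  #include <string>
  #include <tuple>

  enum class DependenceType { RAW, WAW };

  // Hypothetical fusability table keyed by producer/consumer op names.
  using FusablePairs =
      std::set<std::tuple<std::string, std::string, DependenceType>>;

  void addNewPairs(FusablePairs &pairs) {
    pairs.insert({"linalg.fill", "linalg.batch_matmul", DependenceType::WAW});
    pairs.insert({"linalg.matmul", "linalg.generic", DependenceType::RAW});
  }
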
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/Utils.cpp b/iree/compiler/Conversion/LinalgToSPIRV/Utils.cpp
index 6e01249..26ed11d 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/Utils.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/Utils.cpp
@@ -163,15 +163,5 @@
getGPUProcessorIdsAndCounts<GPUGlobalId, GPUGlobalCount>(OpBuilder &builder,
Location loc,
unsigned numDims);
-
-unsigned getNumOuterParallelLoops(linalg::LinalgOp op) {
- return op.iterator_types()
- .getValue()
- .take_while([](Attribute attr) -> bool {
- return attr.cast<StringAttr>().getValue() ==
- getParallelIteratorTypeName();
- })
- .size();
-}
} // namespace iree_compiler
} // namespace mlir
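The removed getNumOuterParallelLoops helper (its declaration is also dropped from Utils.h below) counted the leading parallel iterators of a linalg op. For reference, its behaviour in a standalone sketch (plain C++, with strings standing in for the iterator-type attributes):

  #include <string>
  #include <vector>

  // Count how many leading entries of iterator_types are "parallel"; the
  // first reduction/window iterator terminates the count.
  unsigned numOuterParallelLoops(const std::vector<std::string> &iteratorTypes) {
    unsigned n = 0;
    while (n < iteratorTypes.size() && iteratorTypes[n] == "parallel") ++n;
    return n;
  }
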
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/Utils.h b/iree/compiler/Conversion/LinalgToSPIRV/Utils.h
index ea1316c..0296226 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/Utils.h
+++ b/iree/compiler/Conversion/LinalgToSPIRV/Utils.h
@@ -56,10 +56,6 @@
SmallVector<linalg::ProcInfo, 2> getGPUProcessorIdsAndCounts(OpBuilder &builder,
Location loc,
unsigned numDims);
-
-/// Function to get number of outer parallel loops of a linalgOp
-unsigned getNumOuterParallelLoops(linalg::LinalgOp op);
-
/// Updates the workgroup size used for the dispatch region.
LogicalResult updateWorkGroupSize(FuncOp funcOp,
ArrayRef<int64_t> workGroupSize);
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/VectorToGPUPass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/VectorToGPUPass.cpp
index bf032b7..82a94f3 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/VectorToGPUPass.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/VectorToGPUPass.cpp
@@ -248,7 +248,7 @@
};
// Lower vector contract to a single scalar or vector mulf+addf. Insert casts to
-// convert from 2D vector to 1D vector or scalar.
+// convert from N-D vector to 1D vector or scalar.
class VectorContractLowering : public OpRewritePattern<vector::ContractionOp> {
public:
using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
@@ -256,19 +256,20 @@
LogicalResult matchAndRewrite(vector::ContractionOp op,
PatternRewriter &rewriter) const override {
auto iteratorTypes = op.iterator_types().getValue();
- if (iteratorTypes.size() != 3 || !isParallelIterator(iteratorTypes[0]) ||
- !isParallelIterator(iteratorTypes[1]) ||
- !isReductionIterator(iteratorTypes[2]) ||
- !isRowMajorMatmul(op.indexing_maps())) {
+ if (!isReductionIterator(iteratorTypes.back()) ||
+ op.getContractingDimMap().size() > 1)
return failure();
- }
if (op.getLhsType().getNumElements() != 1) return failure();
- unsigned vecSize = op.getAccType().cast<VectorType>().getNumElements();
- if (!(vecSize >= 1 && vecSize <= 4)) return failure();
+ auto accType = op.getAccType().cast<VectorType>();
+ auto rhsType = op.getRhsType();
+ unsigned vecSize = accType.getNumElements();
+ if (accType != rhsType || !(vecSize >= 1 && vecSize <= 4) ||
+ accType.getShape().back() != vecSize)
+ return failure();
auto loc = op.getLoc();
VectorType vecType = VectorType::get(
vecSize, op.getResultType().cast<VectorType>().getElementType());
- std::array<int64_t, 2> zero = {0, 0};
+ llvm::SmallVector<int64_t, 4> zero(iteratorTypes.size() - 1, 0);
Value lhs = rewriter.create<vector::ExtractOp>(loc, op.lhs(), zero);
Value rhs, acc;
if (vecSize == 1) {
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/VectorizeMemref.cpp b/iree/compiler/Conversion/LinalgToSPIRV/VectorizeMemref.cpp
index 5a610d3..2f3616d 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/VectorizeMemref.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/VectorizeMemref.cpp
@@ -84,6 +84,7 @@
if (memrefType && !memrefType.getElementType().isa<VectorType>() &&
(kMaxVectorizationSizeInBits % memrefType.getElementTypeBitWidth() ==
0) &&
+ memrefType.getRank() > 0 &&
!ShapedType::isDynamic(memrefType.getShape().back()) &&
getUsesIfAllTransferOp(v, uses)) {
return calculateMemrefVecSize(uses);
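The added rank check guards the getShape().back() call that follows it: for a 0-d memref the shape is empty and back() would read past the end. A minimal sketch of the guarded check (plain C++, with -1 standing in for a dynamic dimension):

  #include <cstdint>
  #include <vector>

  // Only a non-empty shape with a static innermost dimension is a candidate
  // for vectorization; an empty (0-d) shape must be rejected before back().
  bool hasStaticInnermostDim(const std::vector<int64_t> &shape) {
    if (shape.empty()) return false;
    return shape.back() != -1;
  }
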
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/batch_matmul_vectorization.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/batch_matmul_vectorization.mlir
new file mode 100644
index 0000000..f5aee9a
--- /dev/null
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/batch_matmul_vectorization.mlir
@@ -0,0 +1,333 @@
+// RUN: iree-opt -split-input-file -pass-pipeline="iree-codegen-linalg-tile-and-fuse,canonicalize,cse" -iree-spirv-enable-vectorization %s | IreeFileCheck %s
+
+module attributes {
+ spv.target_env =
+ #spv.target_env<#spv.vce<v1.3,
+ [Shader, Float64, Float16, Int64, Int16, Int8, StorageBuffer16BitAccess,
+ StorageUniform16, StoragePushConstant16, StorageBuffer8BitAccess,
+ UniformAndStorageBuffer8BitAccess, StoragePushConstant8, GroupNonUniform,
+ GroupNonUniformVote, GroupNonUniformArithmetic, GroupNonUniformBallot,
+ GroupNonUniformShuffle, GroupNonUniformShuffleRelative, VariablePointers,
+ VariablePointersStorageBuffer],
+ [SPV_KHR_16bit_storage, SPV_KHR_8bit_storage,
+ SPV_KHR_storage_buffer_storage_class, SPV_KHR_variable_pointers]>,
+ ARM:IntegratedGPU,
+ {max_compute_shared_memory_size = 32768 : i32,
+ max_compute_workgroup_invocations = 512 : i32,
+ max_compute_workgroup_size = dense<512> : vector<3xi32>,
+ subgroup_size = 16 : i32}>} {
+ func @batch_matmul_static_shape()
+ attributes {vkspv.num_workgroups_fn = @matmul_static_shape__num_workgroups__} {
+ %arg0 = iree.placeholder for "interface buffer"
+ {binding = @legacy_io::@arg0, operand_result_num = 0 : i32} : memref<4x1024x1024xf32>
+ %arg1 = iree.placeholder for "interface buffer"
+ {binding = @legacy_io::@arg1, operand_result_num = 1 : i32} : memref<4x1024x1024xf32>
+ %ret0 = iree.placeholder for "interface buffer"
+ {binding = @legacy_io::@ret0, operand_result_num = 2 : i32} : memref<4x1024x1024xf32>
+ linalg.batch_matmul ins(%arg0, %arg1 : memref<4x1024x1024xf32>, memref<4x1024x1024xf32>) outs(%ret0 : memref<4x1024x1024xf32>)
+ return
+ }
+ func @matmul_static_shape__num_workgroups__
+ (!shapex.ranked_shape<[4096, 4096]>, !shapex.ranked_shape<[4096, 4096]>,
+ !shapex.ranked_shape<[4096, 4096]>) -> (index, index, index)
+ attributes {sym_visibility = "private"}
+ hal.interface @legacy_io attributes {sym_visibility = "private"} {
+ hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+ hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+ hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write"
+ }
+}
+
+
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 * 8)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 64)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<()[s0] -> (s0 * 4)>
+// CHECK: func @batch_matmul_static_shape
+// CHECK-DAG: %[[ARG0:.+]] = iree.placeholder {{.*}} {binding = @legacy_io::@arg0
+// CHECK-DAG: %[[ARG1:.+]] = iree.placeholder {{.*}} {binding = @legacy_io::@arg1
+// CHECK-DAG: %[[RET0:.+]] = iree.placeholder {{.*}} {binding = @legacy_io::@ret0
+// CHECK-DAG: %[[C0:.+]] = constant 0 : index
+// CHECK-DAG: %[[CST:.+]] = constant 0.0
+// CHECK-DAG: %[[C1:.+]] = constant 1 : index
+// CHECK-DAG: %[[C2:.+]] = constant 2 : index
+// CHECK-DAG: %[[C3:.+]] = constant 3 : index
+// CHECK-DAG: %[[C4:.+]] = constant 4 : index
+// CHECK-DAG: %[[C5:.+]] = constant 5 : index
+// CHECK-DAG: %[[C6:.+]] = constant 6 : index
+// CHECK-DAG: %[[C7:.+]] = constant 7 : index
+// CHECK: %[[BIDX:.+]] = "gpu.block_id"() {dimension = "x"}
+// CHECK: %[[BIDY:.+]] = "gpu.block_id"() {dimension = "y"}
+// CHECK: %[[BIDZ:.+]] = "gpu.block_id"() {dimension = "z"}
+// CHECK: %[[BOFFSET_Y:.+]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
+// CHECK: %[[BOFFSET_X:.+]] = affine.apply #[[MAP1]]()[%[[BIDX]]]
+// CHECK: %[[SUBVIEW_RESULT:.+]] = subview %[[RET0]]
+// CHECK-SAME: [%[[BIDZ]], %[[BOFFSET_Y]], %[[BOFFSET_X]]] [1, 8, 64]
+// CHECK: %[[IIDX:.+]] = "gpu.thread_id"() {dimension = "x"}
+// CHECK: %[[IIDY:.+]] = "gpu.thread_id"() {dimension = "y"}
+// CHECK: %[[IIDZ:.+]] = "gpu.thread_id"() {dimension = "z"}
+// CHECK: %[[IOFFSET_Y:.+]] = affine.apply #[[MAP0]]()[%[[IIDY]]]
+// CHECK: %[[IOFFSET_X:.+]] = affine.apply #[[MAP2]]()[%[[IIDX]]]
+// CHECK: %[[SUBVIEW_RESULT_2:.+]] = subview %[[SUBVIEW_RESULT]]
+// CHECK-SAME: [%[[IIDZ]], %[[IOFFSET_Y]], %[[IOFFSET_X]]] [1, 8, 4]
+// CHECK-DAG: %[[READ_INIT_0:.+]] = vector.transfer_read
+// CHECK-SAME: %[[SUBVIEW_RESULT_2]][%[[C0]], %[[C0]], %[[C0]]]
+// CHECK-DAG: %[[READ_INIT_1:.+]] = vector.transfer_read
+// CHECK-SAME: %[[SUBVIEW_RESULT_2]][%[[C0]], %[[C1]], %[[C0]]]
+// CHECK-DAG: %[[READ_INIT_2:.+]] = vector.transfer_read
+// CHECK-SAME: %[[SUBVIEW_RESULT_2]][%[[C0]], %[[C2]], %[[C0]]]
+// CHECK-DAG: %[[READ_INIT_3:.+]] = vector.transfer_read
+// CHECK-SAME: %[[SUBVIEW_RESULT_2]][%[[C0]], %[[C3]], %[[C0]]]
+// CHECK-DAG: %[[READ_INIT_4:.+]] = vector.transfer_read
+// CHECK-SAME: %[[SUBVIEW_RESULT_2]][%[[C0]], %[[C4]], %[[C0]]]
+// CHECK-DAG: %[[READ_INIT_5:.+]] = vector.transfer_read
+// CHECK-SAME: %[[SUBVIEW_RESULT_2]][%[[C0]], %[[C5]], %[[C0]]]
+// CHECK-DAG: %[[READ_INIT_6:.+]] = vector.transfer_read
+// CHECK-SAME: %[[SUBVIEW_RESULT_2]][%[[C0]], %[[C6]], %[[C0]]]
+// CHECK-DAG: %[[READ_INIT_7:.+]] = vector.transfer_read
+// CHECK-SAME: %[[SUBVIEW_RESULT_2]][%[[C0]], %[[C7]], %[[C0]]]
+
+// CHECK: %[[FOR_RES:.+]]:8 = scf.for %[[IV0:.+]] = {{.*}} to
+// CHECK-SAME: iter_args(%[[ACC_0:.+]] = %[[READ_INIT_0]],
+// CHECK-SAME: %[[ACC_1:.+]] = %[[READ_INIT_1]],
+// CHECK-SAME: %[[ACC_2:.+]] = %[[READ_INIT_2]],
+// CHECK-SAME: %[[ACC_3:.+]] = %[[READ_INIT_3]],
+// CHECK-SAME: %[[ACC_4:.+]] = %[[READ_INIT_4]],
+// CHECK-SAME: %[[ACC_5:.+]] = %[[READ_INIT_5]],
+// CHECK-SAME: %[[ACC_6:.+]] = %[[READ_INIT_6]],
+// CHECK-SAME: %[[ACC_7:.+]] = %[[READ_INIT_7]])
+// CHECK: %[[SUBVIEW_LHS:.+]] = subview %[[ARG0]]
+// CHECK-SAME: [%[[BIDZ]], %[[BOFFSET_Y]], %[[IV0]]] [1, 8, 4]
+// CHECK: %[[SUBVIEW_RHS:.+]] = subview %[[ARG1]]
+// CHECK-SAME: [%[[BIDZ]], %[[IV0]], %[[BOFFSET_X]]] [1, 4, 64]
+// CHECK: %[[SUBVIEW_LHS_2:.+]] = subview %[[SUBVIEW_LHS]]
+// CHECK-SAME: [%[[IIDZ]], %[[IOFFSET_Y]], 0] [1, 8, 4] [1, 1, 1]
+// CHECK: %[[SUBVIEW_RHS_2:.+]] = subview %[[SUBVIEW_RHS]]
+// CHECK-SAME: [%[[IIDZ]], 0, %[[IOFFSET_X]]] [1, 4, 4] [1, 1, 1]
+
+// CHECK-DAG: %[[READ_LHS_0:.+]] = vector.transfer_read
+// CHECK-SAME: %[[SUBVIEW_LHS_2]][%[[C0]], %[[C0]], %[[C0]]]
+// CHECK-DAG: %[[READ_LHS_1:.+]] = vector.transfer_read
+// CHECK-SAME: %[[SUBVIEW_LHS_2]][%[[C0]], %[[C1]], %[[C0]]]
+// CHECK-DAG: %[[READ_LHS_2:.+]] = vector.transfer_read
+// CHECK-SAME: %[[SUBVIEW_LHS_2]][%[[C0]], %[[C2]], %[[C0]]]
+// CHECK-DAG: %[[READ_LHS_3:.+]] = vector.transfer_read
+// CHECK-SAME: %[[SUBVIEW_LHS_2]][%[[C0]], %[[C3]], %[[C0]]]
+// CHECK-DAG: %[[READ_LHS_4:.+]] = vector.transfer_read
+// CHECK-SAME: %[[SUBVIEW_LHS_2]][%[[C0]], %[[C4]], %[[C0]]]
+// CHECK-DAG: %[[READ_LHS_5:.+]] = vector.transfer_read
+// CHECK-SAME: %[[SUBVIEW_LHS_2]][%[[C0]], %[[C5]], %[[C0]]]
+// CHECK-DAG: %[[READ_LHS_6:.+]] = vector.transfer_read
+// CHECK-SAME: %[[SUBVIEW_LHS_2]][%[[C0]], %[[C6]], %[[C0]]]
+// CHECK-DAG: %[[READ_LHS_7:.+]] = vector.transfer_read
+// CHECK-SAME: %[[SUBVIEW_LHS_2]][%[[C0]], %[[C7]], %[[C0]]]
+
+// CHECK-DAG: %[[READ_RHS_0:.+]] = vector.transfer_read
+// CHECK-SAME: %[[SUBVIEW_RHS_2]][%[[C0]], %[[C0]], %[[C0]]]
+// CHECK-DAG: %[[READ_RHS_1:.+]] = vector.transfer_read
+// CHECK-SAME: %[[SUBVIEW_RHS_2]][%[[C0]], %[[C1]], %[[C0]]]
+// CHECK-DAG: %[[READ_RHS_2:.+]] = vector.transfer_read
+// CHECK-SAME: %[[SUBVIEW_RHS_2]][%[[C0]], %[[C2]], %[[C0]]]
+// CHECK-DAG: %[[READ_RHS_3:.+]] = vector.transfer_read
+// CHECK-SAME: %[[SUBVIEW_RHS_2]][%[[C0]], %[[C3]], %[[C0]]]
+
+// CHECK-DAG: %[[READ_LHS_0_0:.+]] = vector.extract_strided_slice
+// CHECK-SAME: %[[READ_LHS_0]] {offsets = [0, 0, 0]
+// CHECK-DAG: %[[READ_LHS_0_1:.+]] = vector.extract_strided_slice
+// CHECK-SAME: %[[READ_LHS_0]] {offsets = [0, 0, 1]
+// CHECK-DAG: %[[READ_LHS_0_2:.+]] = vector.extract_strided_slice
+// CHECK-SAME: %[[READ_LHS_0]] {offsets = [0, 0, 2]
+// CHECK-DAG: %[[READ_LHS_0_3:.+]] = vector.extract_strided_slice
+// CHECK-SAME: %[[READ_LHS_0]] {offsets = [0, 0, 3]
+// CHECK-DAG: %[[READ_LHS_1_0:.+]] = vector.extract_strided_slice
+// CHECK-SAME: %[[READ_LHS_1]] {offsets = [0, 0, 0]
+// CHECK-DAG: %[[READ_LHS_1_1:.+]] = vector.extract_strided_slice
+// CHECK-SAME: %[[READ_LHS_1]] {offsets = [0, 0, 1]
+// CHECK-DAG: %[[READ_LHS_1_2:.+]] = vector.extract_strided_slice
+// CHECK-SAME: %[[READ_LHS_1]] {offsets = [0, 0, 2]
+// CHECK-DAG: %[[READ_LHS_1_3:.+]] = vector.extract_strided_slice
+// CHECK-SAME: %[[READ_LHS_1]] {offsets = [0, 0, 3]
+// CHECK-DAG: %[[READ_LHS_2_0:.+]] = vector.extract_strided_slice
+// CHECK-SAME: %[[READ_LHS_2]] {offsets = [0, 0, 0]
+// CHECK-DAG: %[[READ_LHS_2_1:.+]] = vector.extract_strided_slice
+// CHECK-SAME: %[[READ_LHS_2]] {offsets = [0, 0, 1]
+// CHECK-DAG: %[[READ_LHS_2_2:.+]] = vector.extract_strided_slice
+// CHECK-SAME: %[[READ_LHS_2]] {offsets = [0, 0, 2]
+// CHECK-DAG: %[[READ_LHS_2_3:.+]] = vector.extract_strided_slice
+// CHECK-SAME: %[[READ_LHS_2]] {offsets = [0, 0, 3]
+// CHECK-DAG: %[[READ_LHS_3_0:.+]] = vector.extract_strided_slice
+// CHECK-SAME: %[[READ_LHS_3]] {offsets = [0, 0, 0]
+// CHECK-DAG: %[[READ_LHS_3_1:.+]] = vector.extract_strided_slice
+// CHECK-SAME: %[[READ_LHS_3]] {offsets = [0, 0, 1]
+// CHECK-DAG: %[[READ_LHS_3_2:.+]] = vector.extract_strided_slice
+// CHECK-SAME: %[[READ_LHS_3]] {offsets = [0, 0, 2]
+// CHECK-DAG: %[[READ_LHS_3_3:.+]] = vector.extract_strided_slice
+// CHECK-SAME: %[[READ_LHS_3]] {offsets = [0, 0, 3]
+// CHECK-DAG: %[[READ_LHS_4_0:.+]] = vector.extract_strided_slice
+// CHECK-SAME: %[[READ_LHS_4]] {offsets = [0, 0, 0]
+// CHECK-DAG: %[[READ_LHS_4_1:.+]] = vector.extract_strided_slice
+// CHECK-SAME: %[[READ_LHS_4]] {offsets = [0, 0, 1]
+// CHECK-DAG: %[[READ_LHS_4_2:.+]] = vector.extract_strided_slice
+// CHECK-SAME: %[[READ_LHS_4]] {offsets = [0, 0, 2]
+// CHECK-DAG: %[[READ_LHS_4_3:.+]] = vector.extract_strided_slice
+// CHECK-SAME: %[[READ_LHS_4]] {offsets = [0, 0, 3]
+// CHECK-DAG: %[[READ_LHS_5_0:.+]] = vector.extract_strided_slice
+// CHECK-SAME: %[[READ_LHS_5]] {offsets = [0, 0, 0]
+// CHECK-DAG: %[[READ_LHS_5_1:.+]] = vector.extract_strided_slice
+// CHECK-SAME: %[[READ_LHS_5]] {offsets = [0, 0, 1]
+// CHECK-DAG: %[[READ_LHS_5_2:.+]] = vector.extract_strided_slice
+// CHECK-SAME: %[[READ_LHS_5]] {offsets = [0, 0, 2]
+// CHECK-DAG: %[[READ_LHS_5_3:.+]] = vector.extract_strided_slice
+// CHECK-SAME: %[[READ_LHS_5]] {offsets = [0, 0, 3]
+// CHECK-DAG: %[[READ_LHS_6_0:.+]] = vector.extract_strided_slice
+// CHECK-SAME: %[[READ_LHS_6]] {offsets = [0, 0, 0]
+// CHECK-DAG: %[[READ_LHS_6_1:.+]] = vector.extract_strided_slice
+// CHECK-SAME: %[[READ_LHS_6]] {offsets = [0, 0, 1]
+// CHECK-DAG: %[[READ_LHS_6_2:.+]] = vector.extract_strided_slice
+// CHECK-SAME: %[[READ_LHS_6]] {offsets = [0, 0, 2]
+// CHECK-DAG: %[[READ_LHS_6_3:.+]] = vector.extract_strided_slice
+// CHECK-SAME: %[[READ_LHS_6]] {offsets = [0, 0, 3]
+// CHECK-DAG: %[[READ_LHS_7_0:.+]] = vector.extract_strided_slice
+// CHECK-SAME: %[[READ_LHS_7]] {offsets = [0, 0, 0]
+// CHECK-DAG: %[[READ_LHS_7_1:.+]] = vector.extract_strided_slice
+// CHECK-SAME: %[[READ_LHS_7]] {offsets = [0, 0, 1]
+// CHECK-DAG: %[[READ_LHS_7_2:.+]] = vector.extract_strided_slice
+// CHECK-SAME: %[[READ_LHS_7]] {offsets = [0, 0, 2]
+// CHECK-DAG: %[[READ_LHS_7_3:.+]] = vector.extract_strided_slice
+// CHECK-SAME: %[[READ_LHS_7]] {offsets = [0, 0, 3]
+
+// CHECK: %[[CONTRACT_0_0:.+]] = vector.contract
+// CHECK-SAME: %[[READ_LHS_0_0]], %[[READ_RHS_0]], %[[ACC_0]]
+// CHECK: %[[CONTRACT_0_1:.+]] = vector.contract
+// CHECK-SAME: %[[READ_LHS_0_1]], %[[READ_RHS_1]], %[[CONTRACT_0_0]]
+// CHECK: %[[CONTRACT_0_2:.+]] = vector.contract
+// CHECK-SAME: %[[READ_LHS_0_2]], %[[READ_RHS_2]], %[[CONTRACT_0_1]]
+// CHECK: %[[CONTRACT_0_3:.+]] = vector.contract
+// CHECK-SAME: %[[READ_LHS_0_3]], %[[READ_RHS_3]], %[[CONTRACT_0_2]]
+
+// CHECK: %[[CONTRACT_1_0:.+]] = vector.contract
+// CHECK-SAME: %[[READ_LHS_1_0]], %[[READ_RHS_0]], %[[ACC_1]]
+// CHECK: %[[CONTRACT_1_1:.+]] = vector.contract
+// CHECK-SAME: %[[READ_LHS_1_1]], %[[READ_RHS_1]], %[[CONTRACT_1_0]]
+// CHECK: %[[CONTRACT_1_2:.+]] = vector.contract
+// CHECK-SAME: %[[READ_LHS_1_2]], %[[READ_RHS_2]], %[[CONTRACT_1_1]]
+// CHECK: %[[CONTRACT_1_3:.+]] = vector.contract
+// CHECK-SAME: %[[READ_LHS_1_3]], %[[READ_RHS_3]], %[[CONTRACT_1_2]]
+
+// CHECK: %[[CONTRACT_2_0:.+]] = vector.contract
+// CHECK-SAME: %[[READ_LHS_2_0]], %[[READ_RHS_0]], %[[ACC_2]]
+// CHECK: %[[CONTRACT_2_1:.+]] = vector.contract
+// CHECK-SAME: %[[READ_LHS_2_1]], %[[READ_RHS_1]], %[[CONTRACT_2_0]]
+// CHECK: %[[CONTRACT_2_2:.+]] = vector.contract
+// CHECK-SAME: %[[READ_LHS_2_2]], %[[READ_RHS_2]], %[[CONTRACT_2_1]]
+// CHECK: %[[CONTRACT_2_3:.+]] = vector.contract
+// CHECK-SAME: %[[READ_LHS_2_3]], %[[READ_RHS_3]], %[[CONTRACT_2_2]]
+
+// CHECK: %[[CONTRACT_3_0:.+]] = vector.contract
+// CHECK-SAME: %[[READ_LHS_3_0]], %[[READ_RHS_0]], %[[ACC_3]]
+// CHECK: %[[CONTRACT_3_1:.+]] = vector.contract
+// CHECK-SAME: %[[READ_LHS_3_1]], %[[READ_RHS_1]], %[[CONTRACT_3_0]]
+// CHECK: %[[CONTRACT_3_2:.+]] = vector.contract
+// CHECK-SAME: %[[READ_LHS_3_2]], %[[READ_RHS_2]], %[[CONTRACT_3_1]]
+// CHECK: %[[CONTRACT_3_3:.+]] = vector.contract
+// CHECK-SAME: %[[READ_LHS_3_3]], %[[READ_RHS_3]], %[[CONTRACT_3_2]]
+
+// CHECK: %[[CONTRACT_4_0:.+]] = vector.contract
+// CHECK-SAME: %[[READ_LHS_4_0]], %[[READ_RHS_0]], %[[ACC_4]]
+// CHECK: %[[CONTRACT_4_1:.+]] = vector.contract
+// CHECK-SAME: %[[READ_LHS_4_1]], %[[READ_RHS_1]], %[[CONTRACT_4_0]]
+// CHECK: %[[CONTRACT_4_2:.+]] = vector.contract
+// CHECK-SAME: %[[READ_LHS_4_2]], %[[READ_RHS_2]], %[[CONTRACT_4_1]]
+// CHECK: %[[CONTRACT_4_3:.+]] = vector.contract
+// CHECK-SAME: %[[READ_LHS_4_3]], %[[READ_RHS_3]], %[[CONTRACT_4_2]]
+
+// CHECK: %[[CONTRACT_5_0:.+]] = vector.contract
+// CHECK-SAME: %[[READ_LHS_5_0]], %[[READ_RHS_0]], %[[ACC_5]]
+// CHECK: %[[CONTRACT_5_1:.+]] = vector.contract
+// CHECK-SAME: %[[READ_LHS_5_1]], %[[READ_RHS_1]], %[[CONTRACT_5_0]]
+// CHECK: %[[CONTRACT_5_2:.+]] = vector.contract
+// CHECK-SAME: %[[READ_LHS_5_2]], %[[READ_RHS_2]], %[[CONTRACT_5_1]]
+// CHECK: %[[CONTRACT_5_3:.+]] = vector.contract
+// CHECK-SAME: %[[READ_LHS_5_3]], %[[READ_RHS_3]], %[[CONTRACT_5_2]]
+
+// CHECK: %[[CONTRACT_6_0:.+]] = vector.contract
+// CHECK-SAME: %[[READ_LHS_6_0]], %[[READ_RHS_0]], %[[ACC_6]]
+// CHECK: %[[CONTRACT_6_1:.+]] = vector.contract
+// CHECK-SAME: %[[READ_LHS_6_1]], %[[READ_RHS_1]], %[[CONTRACT_6_0]]
+// CHECK: %[[CONTRACT_6_2:.+]] = vector.contract
+// CHECK-SAME: %[[READ_LHS_6_2]], %[[READ_RHS_2]], %[[CONTRACT_6_1]]
+// CHECK: %[[CONTRACT_6_3:.+]] = vector.contract
+// CHECK-SAME: %[[READ_LHS_6_3]], %[[READ_RHS_3]], %[[CONTRACT_6_2]]
+
+// CHECK: %[[CONTRACT_7_0:.+]] = vector.contract
+// CHECK-SAME: %[[READ_LHS_7_0]], %[[READ_RHS_0]], %[[ACC_7]]
+// CHECK: %[[CONTRACT_7_1:.+]] = vector.contract
+// CHECK-SAME: %[[READ_LHS_7_1]], %[[READ_RHS_1]], %[[CONTRACT_7_0]]
+// CHECK: %[[CONTRACT_7_2:.+]] = vector.contract
+// CHECK-SAME: %[[READ_LHS_7_2]], %[[READ_RHS_2]], %[[CONTRACT_7_1]]
+// CHECK: %[[CONTRACT_7_3:.+]] = vector.contract
+// CHECK-SAME: %[[READ_LHS_7_3]], %[[READ_RHS_3]], %[[CONTRACT_7_2]]
+
+// CHECK: scf.yield %[[CONTRACT_0_3]], %[[CONTRACT_1_3]],
+// CHECK-SAME: %[[CONTRACT_2_3]], %[[CONTRACT_3_3]], %[[CONTRACT_4_3]],
+// CHECK-SAME: %[[CONTRACT_5_3]], %[[CONTRACT_6_3]], %[[CONTRACT_7_3]]
+
+// CHECK-DAG: vector.transfer_write %[[FOR_RES]]#0, %[[SUBVIEW_RESULT_2]][%[[C0]], %[[C0]], %[[C0]]]
+// CHECK-DAG: vector.transfer_write %[[FOR_RES]]#1, %[[SUBVIEW_RESULT_2]][%[[C0]], %[[C1]], %[[C0]]]
+// CHECK-DAG: vector.transfer_write %[[FOR_RES]]#2, %[[SUBVIEW_RESULT_2]][%[[C0]], %[[C2]], %[[C0]]]
+// CHECK-DAG: vector.transfer_write %[[FOR_RES]]#3, %[[SUBVIEW_RESULT_2]][%[[C0]], %[[C3]], %[[C0]]]
+// CHECK-DAG: vector.transfer_write %[[FOR_RES]]#4, %[[SUBVIEW_RESULT_2]][%[[C0]], %[[C4]], %[[C0]]]
+// CHECK-DAG: vector.transfer_write %[[FOR_RES]]#5, %[[SUBVIEW_RESULT_2]][%[[C0]], %[[C5]], %[[C0]]]
+// CHECK-DAG: vector.transfer_write %[[FOR_RES]]#6, %[[SUBVIEW_RESULT_2]][%[[C0]], %[[C6]], %[[C0]]]
+// CHECK-DAG: vector.transfer_write %[[FOR_RES]]#7, %[[SUBVIEW_RESULT_2]][%[[C0]], %[[C7]], %[[C0]]]
+
+// -----
+
+module attributes {
+ spv.target_env =
+ #spv.target_env<#spv.vce<v1.3,
+ [Shader, Float64, Float16, Int64, Int16, Int8, StorageBuffer16BitAccess,
+ StorageUniform16, StoragePushConstant16, StorageBuffer8BitAccess,
+ UniformAndStorageBuffer8BitAccess, StoragePushConstant8, GroupNonUniform,
+ GroupNonUniformVote, GroupNonUniformArithmetic, GroupNonUniformBallot,
+ GroupNonUniformShuffle, GroupNonUniformShuffleRelative, VariablePointers,
+ VariablePointersStorageBuffer],
+ [SPV_KHR_16bit_storage, SPV_KHR_8bit_storage,
+ SPV_KHR_storage_buffer_storage_class, SPV_KHR_variable_pointers]>,
+ ARM:IntegratedGPU,
+ {max_compute_shared_memory_size = 32768 : i32,
+ max_compute_workgroup_invocations = 512 : i32,
+ max_compute_workgroup_size = dense<512> : vector<3xi32>,
+ subgroup_size = 16 : i32}>} {
+ func @batch_matmul_fused_fillop()
+ attributes {vkspv.num_workgroups_fn = @batch_matmul_fused_fillop__num_workgroups__} {
+ %cst = constant 0.000000e+00 : f32
+ %arg0 = iree.placeholder for "interface buffer"
+ {binding = @legacy_io::@arg0, operand_result_num = 0 : i32} : memref<4x1024x1024xf32>
+ %arg1 = iree.placeholder for "interface buffer"
+ {binding = @legacy_io::@arg1, operand_result_num = 1 : i32} : memref<4x1024x1024xf32>
+ %ret0 = iree.placeholder for "interface buffer"
+ {binding = @legacy_io::@ret0, operand_result_num = 2 : i32} : memref<4x1024x1024xf32>
+ linalg.fill(%ret0, %cst) : memref<4x1024x1024xf32>, f32
+ linalg.batch_matmul ins(%arg0, %arg1 : memref<4x1024x1024xf32>, memref<4x1024x1024xf32>) outs(%ret0 : memref<4x1024x1024xf32>)
+ return
+ }
+ func @batch_matmul_fused_fillop__num_workgroups__
+ (!shapex.ranked_shape<[4096, 4096]>, !shapex.ranked_shape<[4096, 4096]>,
+ !shapex.ranked_shape<[4096, 4096]>) -> (index, index, index)
+ attributes {sym_visibility = "private"}
+ hal.interface @legacy_io attributes {sym_visibility = "private"} {
+ hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+ hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+ hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write"
+ }
+}
+
+// CHECK-LABEL: func @batch_matmul_fused_fillop
+// CHECK-COUNT-8: vector.transfer_write
+// CHECK-COUNT-8: vector.transfer_read
+// CHECK: %[[FOR_RES:.+]]:8 = scf.for
+// CHECK-COUNT-12: vector.transfer_read
+// CHECK-COUNT-32: vector.contract
+// CHECK: scf.yield
+// CHECK-COUNT-8: vector.transfer_write %[[FOR_RES]]
+// CHECK: return
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/linalg_tile_and_fuse.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/linalg_tile_and_fuse.mlir
index 593a693..75cd37d 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/linalg_tile_and_fuse.mlir
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/linalg_tile_and_fuse.mlir
@@ -1,38 +1,4 @@
-// RUN: iree-opt -split-input-file -iree-codegen-linalg-tile-and-fuse %s | IreeFileCheck %s
-
-// Test to check that convolution with padding is not tiled.
-module attributes {
- spv.target_env =
- #spv.target_env<#spv.vce<v1.3,
- [Shader], [SPV_KHR_storage_buffer_storage_class]>,
- {max_compute_workgroup_invocations = 128 : i32,
- max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
- func @conv_padding() {
- %0 = iree.placeholder for "interace buffer" {binding = @legacy_io::@arg0} : memref<?x?x?x?xf32>
- %1 = iree.placeholder for "interace buffer" {binding = @legacy_io::@arg1} : memref<?x?x?x?xf32>
- %2 = iree.placeholder for "interace buffer" {binding = @legacy_io::@ret0} : memref<?x?x?x?xf32>
- linalg.conv(%0, %1, %2)
- {dilations = [1, 1],
- padding = dense<[[1, 1], [0, 1]]> : tensor<2x2xi64>, strides = [1, 1]} :
- memref<?x?x?x?xf32>, memref<?x?x?x?xf32>, memref<?x?x?x?xf32>
- return
- }
- hal.interface @legacy_io attributes {sym_visibility = "private"} {
- hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
- hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
- hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write"
- }
-}
-// CHECK: func @conv_padding()
-// CHECK-DAG: %[[ARG0:.+]] = iree.placeholder for "interace buffer" {binding = @legacy_io::@arg0}
-// CHECK-DAG: %[[ARG1:.+]] = iree.placeholder for "interace buffer" {binding = @legacy_io::@arg1}
-// CHECK-DAG: %[[RET0:.+]] = iree.placeholder for "interace buffer" {binding = @legacy_io::@ret0}
-// CHECK: linalg.conv
-// CHECK-SAME: %[[ARG0]]
-// CHECK-SAME: %[[ARG1]]
-// CHECK-SAME: %[[RET0]]
-
-// -----
+// RUN: iree-opt -split-input-file -iree-codegen-linalg-tile-and-fuse -iree-spirv-enable-vectorization -canonicalize -cse %s | IreeFileCheck %s
module attributes {
spv.target_env =
@@ -63,26 +29,24 @@
}
}
// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 * 4)>
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 32)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<()[s0] -> (s0 * 32)>
// CHECK: func @conv_no_padding()
// CHECK-SAME: hal.num_workgroups_fn = @[[NUM_WORKGROUPS_FN:[a-zA-Z0-9_]+]]
// CHECK-SAME: local_size = dense<[32, 4, 1]>
-// CHECK-DAG: %[[ARG0:.+]] = iree.placeholder {{.*}} {binding = @legacy_io::@arg0
-// CHECK-DAG: %[[ARG1:.+]] = iree.placeholder {{.*}} {binding = @legacy_io::@arg1
-// CHECK-DAG: %[[RET0:.+]] = iree.placeholder {{.*}} {binding = @legacy_io::@ret0
+// CHECK-DAG: %[[ARG0:.+]] = iree.placeholder {{.*}} @legacy_io::@arg0
+// CHECK-DAG: %[[ARG1:.+]] = iree.placeholder {{.*}} @legacy_io::@arg1
+// CHECK-DAG: %[[RET0:.+]] = iree.placeholder {{.*}} @legacy_io::@ret0
// CHECK-DAG: %[[BIDX:.+]] = "gpu.block_id"() {dimension = "x"}
// CHECK-DAG: %[[BIDY:.+]] = "gpu.block_id"() {dimension = "y"}
// CHECK-DAG: %[[BIDZ:.+]] = "gpu.block_id"() {dimension = "z"}
// CHECK: %[[LBY:.+]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
-// CHECK: %[[LBX:.+]] = affine.apply #[[MAP1]]()[%[[BIDX]]]
-// CHECK: %[[VIEW1:.+]] = subview %[[ARG1]]
+// CHECK: %[[LBX:.+]] = affine.apply #[[MAP2]]()[%[[BIDX]]]
+// CHECK: %[[SV_ARG1:.+]] = subview %[[ARG1]]
// CHECK-SAME: [%[[BIDZ]], %[[LBY]], %[[LBX]], 0]
-// CHECK: %[[LBY_2:.+]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
-// CHECK: %[[LBX_2:.+]] = affine.apply #[[MAP1]]()[%[[BIDX]]]
-// CHECK: %[[VIEW2:.+]] = subview %[[RET0]]
-// CHECK-SAME: [%[[BIDZ]], %[[LBY_2]], %[[LBX_2]], 0]
+// CHECK: %[[SV_RET0:.+]] = subview %[[RET0]]
+// CHECK-SAME: [%[[BIDZ]], %[[LBY]], %[[LBX]], 0]
// CHECK: linalg.conv
-// CHECK-SAME: %[[ARG0]], %[[VIEW1]], %[[VIEW2]]
+// CHECK-SAME: %[[ARG0]], %[[SV_ARG1]], %[[SV_RET0]]
// CHECK-SAME: "workgroup"
// CHECK: func private @[[NUM_WORKGROUPS_FN]]
// CHECK-DAG: %[[C0:.+]] = constant 0
@@ -98,10 +62,10 @@
// CHECK-DAG: %[[DIM0:.+]] = dim %[[ARG1]], %[[C0]]
// CHECK-DAG: %[[DIM1:.+]] = dim %[[RET0]], %[[C1]]
// CHECK-DAG: %[[DIM2:.+]] = dim %[[RET0]], %[[C2]]
-// CHECK: %[[T0:.+]] = addi %[[DIM2]], %[[C31]]
-// CHECK-DAG: %[[NBX:.+]] = divi_signed %[[T0]], %[[C32]]
// CHECK: %[[T1:.+]] = addi %[[DIM1]], %[[C3]]
// CHECK-DAG: %[[NBY:.+]] = divi_signed %[[T1]], %[[C4]]
+// CHECK: %[[T0:.+]] = addi %[[DIM2]], %[[C31]]
+// CHECK-DAG: %[[NBX:.+]] = divi_signed %[[T0]], %[[C32]]
// CHECK: return %[[NBX]], %[[NBY]], %[[DIM0]]
@@ -140,24 +104,24 @@
// CHECK: func @matmul()
// CHECK-SAME: hal.num_workgroups_fn = @[[NUM_WORKGROUPS_FN:[a-zA-Z0-9_]+]]
// CHECK-SAME: local_size = dense<[16, 8, 1]>
-// CHECK-DAG: %[[ARG0:.+]] = iree.placeholder {{.*}} {binding = @legacy_io::@arg0
-// CHECK-DAG: %[[ARG1:.+]] = iree.placeholder {{.*}} {binding = @legacy_io::@arg1
-// CHECK-DAG: %[[RET0:.+]] = iree.placeholder {{.*}} {binding = @legacy_io::@ret0
+// CHECK-DAG: %[[ARG0:.+]] = iree.placeholder {{.*}} @legacy_io::@arg0
+// CHECK-DAG: %[[ARG1:.+]] = iree.placeholder {{.*}} @legacy_io::@arg1
+// CHECK-DAG: %[[RET0:.+]] = iree.placeholder {{.*}} @legacy_io::@ret0
// CHECK-DAG: %[[BIDX:.+]] = "gpu.block_id"() {dimension = "x"}
// CHECK-DAG: %[[BIDY:.+]] = "gpu.block_id"() {dimension = "y"}
// CHECK-NOT: scf.parallel
// CHECK-NOT: scf.for
// CHECK: %[[LBY:.+]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
-// CHECK: %[[VIEW0:.+]] = subview %[[ARG0]][%[[LBY]], 0]
+// CHECK: %[[SV_ARG0:.+]] = subview %[[ARG0]][%[[LBY]], 0]
// CHECK: %[[LBX:.+]] = affine.apply #[[MAP3]]()[%[[BIDX]]]
-// CHECK: %[[VIEW1:.+]] = subview %[[ARG1]][0, %[[LBX]]]
-// CHECK: %[[LBY_2:.+]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
-// CHECK: %[[LBX_2:.+]] = affine.apply #[[MAP3]]()[%[[BIDX]]]
-// CHECK: %[[VIEW2:.+]] = subview %[[RET0]][%[[LBY_2]], %[[LBX_2]]]
+// CHECK: %[[SV_ARG1:.+]] = subview %[[ARG1]][0, %[[LBX]]]
+// CHECK: %[[SV_RET0:.+]] = subview %[[RET0]][%[[LBY]], %[[LBX]]]
// CHECK: linalg.matmul
// CHECK-SAME: "workgroup"
-// CHECK-SAME: ins(%[[VIEW0]], %[[VIEW1]]
-// CHECK-SAME: outs(%[[VIEW2]]
+// CHECK-SAME: ins(%[[SV_ARG0]], %[[SV_ARG1]]
+// CHECK-SAME: )
+// CHECK-SAME: outs(%[[SV_RET0]]
+// CHECK-SAME: )
// CHECK: func private @[[NUM_WORKGROUPS_FN]]
// CHECK-DAG: %[[C8:.+]] = constant 8 : index
// CHECK-DAG: %[[C7:.+]] = constant 7 : index
@@ -165,12 +129,14 @@
// CHECK-DAG: %[[C1:.+]] = constant 1 : index
// CHECK-DAG: %[[C16:.+]] = constant 16 : index
// CHECK-DAG: %[[C15:.+]] = constant 15 : index
-// CHECK: %[[DIM0:.+]] = dim %{{.*}}, %[[C0]]
-// CHECK: %[[DIM1:.+]] = dim %{{.*}}, %[[C1]]
-// CHECK: %[[T0:.+]] = addi %[[DIM1]], %[[C15]]
-// CHECK: %[[T1:.+]] = divi_signed %[[T0]], %[[C16]]
+// CHECK-DAG: %[[ARG0:.+]] = iree.placeholder {{.+}} {binding = @legacy_io::@arg0
+// CHECK-DAG: %[[ARG1:.+]] = iree.placeholder {{.+}} {binding = @legacy_io::@arg1
+// CHECK: %[[DIM0:.+]] = dim %[[ARG0]], %[[C0]]
+// CHECK: %[[DIM1:.+]] = dim %[[ARG1]], %[[C1]]
// CHECK: %[[T2:.+]] = addi %[[DIM0]], %[[C7]]
// CHECK: %[[T3:.+]] = divi_signed %[[T2]], %[[C8]]
+// CHECK: %[[T0:.+]] = addi %[[DIM1]], %[[C15]]
+// CHECK: %[[T1:.+]] = divi_signed %[[T0]], %[[C16]]
// CHECK: return %[[T1]], %[[T3]], %[[C1]]
// -----
@@ -215,12 +181,10 @@
// CHECK-DAG: %[[BIDY:.+]] = "gpu.block_id"() {dimension = "y"}
// CHECK: %[[LBY:.+]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
// CHECK: %[[LBX:.+]] = affine.apply #[[MAP1]]()[%[[BIDX]]]
-// CHECK: %[[VIEW0:.+]] = subview %[[ARG0]][%[[LBY]], %[[LBX]]]
-// CHECK: %[[LBY2:.+]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
-// CHECK: %[[LBX2:.+]] = affine.apply #[[MAP1]]()[%[[BIDX]]]
-// CHECK: %[[VIEW2:.+]] = subview %[[RET0]][%[[LBY2]], %[[LBX2]]]
+// CHECK: %[[SV_ARG0:.+]] = subview %[[ARG0]][%[[LBY]], %[[LBX]]]
+// CHECK: %[[SV_RET0:.+]] = subview %[[RET0]][%[[LBY]], %[[LBX]]]
// CHECK: linalg.pooling_max
-// CHECK-SAME: %[[VIEW0]], %[[ARG1]], %[[VIEW2]]
+// CHECK-SAME: %[[SV_ARG0]], %[[ARG1]], %[[SV_RET0]]
// CHECK-SAME: "workgroup"
// CHECK: func private @[[NUM_WORKGROUPS_FN]]
// CHECK-DAG: %[[C0:.+]] = constant 0
@@ -232,10 +196,10 @@
// CHECK-DAG: %[[RET0:.+]] = iree.placeholder {{.+}} {binding = @legacy_io::@ret0
// CHECK-DAG: %[[DIM0:.+]] = dim %[[RET0]], %[[C0]]
// CHECK-DAG: %[[DIM1:.+]] = dim %[[RET0]], %[[C1]]
-// CHECK: %[[T0:.+]] = addi %[[DIM1]], %[[C31]]
-// CHECK-DAG: %[[NBX:.+]] = divi_signed %[[T0]], %[[C32]]
// CHECK: %[[T1:.+]] = addi %[[DIM0]], %[[C3]]
// CHECK-DAG: %[[NBY:.+]] = divi_signed %[[T1]], %[[C4]]
+// CHECK: %[[T0:.+]] = addi %[[DIM1]], %[[C31]]
+// CHECK-DAG: %[[NBX:.+]] = divi_signed %[[T0]], %[[C32]]
// CHECK: return %[[NBX]], %[[NBY]], %[[C1]]
// -----
@@ -281,12 +245,10 @@
// CHECK-DAG: %[[BIDY:.+]] = "gpu.block_id"() {dimension = "y"}
// CHECK: %[[LBY:.+]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
// CHECK: %[[LBX:.+]] = affine.apply #[[MAP2]]()[%[[BIDX]]]
-// CHECK: %[[VIEW0:.+]] = subview %[[ARG0]][0, %[[LBY]], %[[LBX]], 0]
-// CHECK: %[[LBY2:.+]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
-// CHECK: %[[LBX2:.+]] = affine.apply #[[MAP2]]()[%[[BIDX]]]
-// CHECK: %[[VIEW2:.+]] = subview %[[RET0]][0, %[[LBY2]], %[[LBX2]], 0]
+// CHECK: %[[SV_ARG0:.+]] = subview %[[ARG0]][0, %[[LBY]], %[[LBX]], 0]
+// CHECK: %[[SV_RET0:.+]] = subview %[[RET0]][0, %[[LBY]], %[[LBX]], 0]
// CHECK: linalg.pooling_max
-// CHECK-SAME: %[[VIEW0]], %[[ARG1]], %[[VIEW2]]
+// CHECK-SAME: %[[SV_ARG0]], %[[ARG1]], %[[SV_RET0]]
// CHECK-SAME: "workgroup"
// CHECK: func private @[[NUM_WORKGROUPS_FN]]
// CHECK-DAG: %[[C1:.+]] = constant 1
@@ -297,14 +259,15 @@
// CHECK-DAG: %[[RET0:.+]] = iree.placeholder {{.+}} {binding = @legacy_io::@ret0
// CHECK-DAG: %[[DIM0:.+]] = dim %[[RET0]], %[[C1]]
// CHECK-DAG: %[[DIM1:.+]] = dim %[[RET0]], %[[C2]]
-// CHECK: %[[T0:.+]] = addi %[[DIM1]], %[[C31]]
-// CHECK-DAG: %[[NBX:.+]] = divi_signed %[[T0]], %[[C32]]
// CHECK: %[[T1:.+]] = addi %[[DIM0]], %[[C3]]
// CHECK-DAG: %[[NBY:.+]] = divi_signed %[[T1]], %[[C4]]
+// CHECK: %[[T0:.+]] = addi %[[DIM1]], %[[C31]]
+// CHECK-DAG: %[[NBX:.+]] = divi_signed %[[T0]], %[[C32]]
// CHECK: return %[[NBX]], %[[NBY]], %[[C1]]
// -----
+
module attributes {
spv.target_env =
#spv.target_env<#spv.vce<v1.3,
@@ -324,10 +287,9 @@
outs(%2 : memref<?x?xf32>)
return
}
- func @matmul_fusion__num_workgroups__
+ func private @matmul_fusion__num_workgroups__
(!shapex.ranked_shape<[?,?]>, !shapex.ranked_shape<[?,?]>,
!shapex.ranked_shape<[?,?]>) -> (index, index, index)
- attributes {sym_visibility = "private"}
hal.interface @legacy_io attributes {sym_visibility = "private"} {
hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
@@ -338,28 +300,41 @@
// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 * 8)>
// CHECK-DAG: #[[MAP3:.+]] = affine_map<()[s0] -> (s0 * 16)>
// CHECK: func @matmul_fusion()
+// CHECK-SAME: hal.num_workgroups_fn = @[[NWGFN:[a-zA-Z0-9_]+]]
// CHECK-SAME: local_size = dense<[16, 8, 1]>
-// CHECK-DAG: %[[ARG0:.+]] = iree.placeholder {{.*}} {binding = @legacy_io::@arg0
-// CHECK-DAG: %[[ARG1:.+]] = iree.placeholder {{.*}} {binding = @legacy_io::@arg1
-// CHECK-DAG: %[[RET0:.+]] = iree.placeholder {{.*}} {binding = @legacy_io::@ret0
+// CHECK-DAG: %[[ARG0:.+]] = iree.placeholder {{.+}} @legacy_io::@arg0
+// CHECK-DAG: %[[ARG1:.+]] = iree.placeholder {{.+}} @legacy_io::@arg1
+// CHECK-DAG: %[[RET0:.+]] = iree.placeholder {{.*}} @legacy_io::@ret0
// CHECK-DAG: %[[BIDX:.+]] = "gpu.block_id"() {dimension = "x"}
// CHECK-DAG: %[[BIDY:.+]] = "gpu.block_id"() {dimension = "y"}
// CHECK-NOT: scf.parallel
// CHECK-NOT: scf.for
// CHECK: %[[LBY:.+]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
-// CHECK: %[[VIEW0:.+]] = subview %[[ARG0]][%[[LBY]], 0]
+// CHECK: %[[SV_ARG0:.+]] = subview %[[ARG0]][%[[LBY]], 0]
// CHECK: %[[LBX:.+]] = affine.apply #[[MAP3]]()[%[[BIDX]]]
-// CHECK: %[[VIEW1:.+]] = subview %[[ARG1]][0, %[[LBX]]]
-// CHECK: %[[LBY_2:.+]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
-// CHECK: %[[LBX_2:.+]] = affine.apply #[[MAP3]]()[%[[BIDX]]]
-// CHECK: %[[VIEW2:.+]] = subview %[[RET0]][%[[LBY_2]], %[[LBX_2]]]
-// CHECK: %[[VIEW3:.+]] = subview %[[RET0]][%[[LBY]], %[[LBX]]]
-// CHECK: linalg.fill(%[[VIEW3]], %{{.+}})
+// CHECK: %[[SV_ARG1:.+]] = subview %[[ARG1]][0, %[[LBX]]]
+// CHECK: %[[SV_RET0_1:.+]] = subview %[[RET0]][%[[LBY]], %[[LBX]]]
+// CHECK: %[[SV_RET0_2:.+]] = subview %[[RET0]][%[[LBY]], %[[LBX]]]
+// CHECK: linalg.fill(%[[SV_RET0_2]], %{{.+}})
// CHECK-SAME: "workgroup"
// CHECK: linalg.matmul
// CHECK-SAME: "workgroup"
-// CHECK-SAME: ins(%[[VIEW0]], %[[VIEW1]]
-// CHECK-SAME: outs(%[[VIEW2]]
+// CHECK-SAME: ins(%[[SV_ARG0]], %[[SV_ARG1]]
+// CHECK-SAME: outs(%[[SV_RET0_1]]
+
+// CHECK: func private @[[NWGFN]]
+// CHECK-DAG: %[[C0:.+]] = constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = constant 1 : index
+// CHECK-DAG: %[[ARG0:.+]] = iree.placeholder {{.+}} @legacy_io::@arg0
+// CHECK-DAG: %[[ARG1:.+]] = iree.placeholder {{.+}} @legacy_io::@arg1
+// CHECK-DAG: %[[RET0:.+]] = iree.placeholder {{.*}} @legacy_io::@ret0
+// CHECK: %[[M:.+]] = dim %[[ARG0]], %[[C0]]
+// CHECK: %[[N:.+]] = dim %[[ARG1]], %[[C1]]
+// CHECK: %[[WGY_N:.+]] = addi %[[M]], %{{.+}}
+// CHECK: %[[WGY:.+]] = divi_signed %[[WGY_N]], %{{.+}}
+// CHECK: %[[WGX_N:.+]] = addi %[[N]], %{{.+}}
+// CHECK: %[[WGX:.+]] = divi_signed %[[WGX_N]], %{{.+}}
+// CHECK: return %[[WGX]], %[[WGY]], %[[C1]]
// -----
@@ -369,7 +344,9 @@
[Shader], [SPV_KHR_storage_buffer_storage_class]>,
{max_compute_workgroup_invocations = 128 : i32,
max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
- func @conv_no_padding_fusion() {
+ func @conv_no_padding_fusion()
+ attributes {
+ hal.num_workgroups_fn = @conv_no_padding_fusion__num_workgroups__} {
%0 = iree.placeholder for "interace buffer"
{binding = @legacy_io::@arg0, operand_result_index = 0 : i32} : memref<?x?x?x?xf32>
%1 = iree.placeholder for "interace buffer"
@@ -382,6 +359,9 @@
memref<?x?x?x?xf32>, memref<?x?x?x?xf32>, memref<?x?x?x?xf32>
return
}
+ func private @conv_no_padding_fusion__num_workgroups__
+ (!shapex.ranked_shape<[?,?,?,?]>, !shapex.ranked_shape<[?,?,?,?]>,
+ !shapex.ranked_shape<[?,?,?,?]>) -> (index, index, index)
hal.interface @legacy_io attributes {sym_visibility = "private"} {
hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
@@ -391,25 +371,185 @@
// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 * 4)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 32)>
// CHECK: func @conv_no_padding_fusion()
+// CHECK-SAME: hal.num_workgroups_fn = @[[NWGFN:[a-zA-Z0-9_]+]]
// CHECK-SAME: local_size = dense<[32, 4, 1]>
-// CHECK-DAG: %[[ARG0:.+]] = iree.placeholder {{.*}} {binding = @legacy_io::@arg0
-// CHECK-DAG: %[[ARG1:.+]] = iree.placeholder {{.*}} {binding = @legacy_io::@arg1
-// CHECK-DAG: %[[RET0:.+]] = iree.placeholder {{.*}} {binding = @legacy_io::@ret0
+// CHECK-DAG: %[[ARG0:.+]] = iree.placeholder {{.*}} @legacy_io::@arg0
+// CHECK-DAG: %[[ARG1:.+]] = iree.placeholder {{.*}} @legacy_io::@arg1
+// CHECK-DAG: %[[RET0:.+]] = iree.placeholder {{.*}} @legacy_io::@ret0
// CHECK-DAG: %[[BIDX:.+]] = "gpu.block_id"() {dimension = "x"}
// CHECK-DAG: %[[BIDY:.+]] = "gpu.block_id"() {dimension = "y"}
// CHECK-DAG: %[[BIDZ:.+]] = "gpu.block_id"() {dimension = "z"}
// CHECK: %[[LBY:.+]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
// CHECK: %[[LBX:.+]] = affine.apply #[[MAP1]]()[%[[BIDX]]]
-// CHECK: %[[VIEW1:.+]] = subview %[[ARG1]]
+// CHECK: %[[SV_ARG1:.+]] = subview %[[ARG1]]
// CHECK-SAME: [%[[BIDZ]], %[[LBY]], %[[LBX]], 0]
-// CHECK: %[[LBY_2:.+]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
-// CHECK: %[[LBX_2:.+]] = affine.apply #[[MAP1]]()[%[[BIDX]]]
-// CHECK: %[[VIEW2:.+]] = subview %[[RET0]]
-// CHECK-SAME: [%[[BIDZ]], %[[LBY_2]], %[[LBX_2]], 0]
-// CHECK: %[[VIEW3:.+]] = subview %[[RET0]]
-// CHECK-SAME: [%[[BIDZ]], %[[LBY_2]], %[[LBX_2]], 0]
-// CHECK: linalg.fill(%[[VIEW3]], %{{.*}})
+// CHECK: %[[SV_RET0:.+]] = subview %[[RET0]]
+// CHECK-SAME: [%[[BIDZ]], %[[LBY]], %[[LBX]], 0]
+// CHECK: linalg.fill(%[[SV_RET0]], %{{.*}})
// CHECK-SAME: "workgroup"
// CHECK: linalg.conv
-// CHECK-SAME: %[[ARG0]], %[[VIEW1]], %[[VIEW2]]
+// CHECK-SAME: %[[ARG0]], %[[SV_ARG1]], %[[SV_RET0]]
// CHECK-SAME: "workgroup"
+
+// CHECK: func private @[[NWGFN]]
+// CHECK-DAG: %[[C0:.+]] = constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = constant 1 : index
+// CHECK-DAG: %[[C2:.+]] = constant 2 : index
+// CHECK-DAG: %[[ARG0:.+]] = iree.placeholder {{.*}} @legacy_io::@arg0
+// CHECK-DAG: %[[ARG1:.+]] = iree.placeholder {{.*}} @legacy_io::@arg1
+// CHECK-DAG: %[[RET0:.+]] = iree.placeholder {{.*}} @legacy_io::@ret0
+// CHECK: %[[N:.+]] = dim %[[ARG1]], %[[C0]]
+// CHECK: %[[R:.+]] = dim %[[RET0]], %[[C1]]
+// CHECK: %[[S:.+]] = dim %[[RET0]], %[[C2]]
+// CHECK: %[[WGY_N:.+]] = addi %[[R]], %{{.+}}
+// CHECK: %[[WGY:.+]] = divi_signed %[[WGY_N]], %{{.+}}
+// CHECK: %[[WGX_N:.+]] = addi %[[S]], %{{.+}}
+// CHECK: %[[WGX:.+]] = divi_signed %[[WGX_N]], %{{.+}}
+// CHECK: return %[[WGX]], %[[WGY]], %[[N]]
+
+// -----
+
+module attributes {
+ spv.target_env =
+ #spv.target_env<#spv.vce<v1.3,
+ [Shader], [SPV_KHR_storage_buffer_storage_class]>,
+ {max_compute_workgroup_invocations = 128 : i32,
+ max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
+ func @three_op_fusion()
+ attributes {
+ hal.num_workgroups_fn = @three_op_fusion__num_workgroups__} {
+ %cst = constant 0.000000e+00 : f32
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %0 = iree.placeholder for "interface buffer"
+ {binding = @legacy_io::@arg0, operand_result_index = 0 : i32}
+ : memref<?x?xf32>
+ %1 = iree.placeholder for "interface buffer"
+ {binding = @legacy_io::@arg1, operand_result_index = 1 : i32}
+ : memref<?x?xf32>
+ %d0 = dim %0, %c0 : memref<?x?xf32>
+ %d1 = dim %1, %c1 : memref<?x?xf32>
+ %2 = alloc(%d0, %d1) : memref<?x?xf32>
+ %3 = iree.placeholder for "interface buffer"
+ {binding = @legacy_io::@arg2, operand_result_index = 2 : i32}
+ : memref<?xf32>
+ %4 = iree.placeholder for "interface buffer"
+ {binding = @legacy_io::@ret0, operand_result_index = 3 : i32}
+ : memref<?x?xf32>
+ linalg.fill(%2, %cst) : memref<?x?xf32>, f32
+ linalg.matmul ins(%0, %1 : memref<?x?xf32>, memref<?x?xf32>)
+ outs(%2 : memref<?x?xf32>)
+ linalg.generic
+ {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d1)>,
+ affine_map<(d0, d1) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%2, %3 : memref<?x?xf32>, memref<?xf32>)
+ outs(%4 : memref<?x?xf32>) {
+ ^bb0(%arg0 : f32, %arg1 : f32, %arg2 : f32) :
+ %5 = addf %arg0, %arg1 : f32
+ linalg.yield %5 : f32
+ }
+ return
+ }
+ func private @three_op_fusion__num_workgroups__
+ (!shapex.ranked_shape<[?,?]>, !shapex.ranked_shape<[?,?]>,
+ !shapex.ranked_shape<[?]>, !shapex.ranked_shape<[?,?]>)
+ -> (index, index, index)
+ hal.interface @legacy_io attributes {sym_visibility = "private"} {
+ hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+ hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+ hal.interface.binding @arg2, set=0, binding=2, type="StorageBuffer", access="Read"
+ hal.interface.binding @ret0, set=0, binding=3, type="StorageBuffer", access="Write"
+ }
+}
+
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 * 8)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (8, s1 - s0 * 8)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<()[s0] -> (s0 * 16)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<()[s0, s1] -> (16, s1 - s0 * 16)>
+// CHECK: func @three_op_fusion
+// CHECK-DAG: %[[C0:.+]] = constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = constant 1 : index
+// CHECK-DAG: %[[ALLOC:.+]] = alloc() : memref<8x16xf32, 3>
+// CHECK-DAG: %[[ARG0:.+]] = iree.placeholder {{.*}} @legacy_io::@arg0
+// CHECK-DAG: %[[ARG1:.+]] = iree.placeholder {{.*}} @legacy_io::@arg1
+// CHECK-DAG: %[[M:.+]] = dim %[[ARG0]], %[[C0]]
+// CHECK-DAG: %[[N:.+]] = dim %[[ARG1]], %[[C1]]
+// CHECK-DAG: %[[ARG2:.+]] = iree.placeholder {{.*}} @legacy_io::@arg2
+// CHECK-DAG: %[[RET0:.+]] = iree.placeholder {{.*}} @legacy_io::@ret0
+// CHECK-DAG: %[[BIDX:.+]] = "gpu.block_id"() {dimension = "x"}
+// CHECK-DAG: %[[BIDY:.+]] = "gpu.block_id"() {dimension = "y"}
+// CHECK-NOT: scf.parallel
+// CHECK-NOT: scf.for
+// CHECK: %[[LBY:.+]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
+// CHECK: %[[TILE_M:.+]] = affine.min #[[MAP1]]()[%[[BIDY]], %[[M]]]
+// CHECK: %[[LBX:.+]] = affine.apply #[[MAP2]]()[%[[BIDX]]]
+// CHECK: %[[TILE_N:.+]] = affine.min #[[MAP3]]()[%[[BIDX]], %[[N]]]
+// CHECK: %[[N_2:.+]] = dim %[[ARG2]], %[[C0]]
+// CHECK: %[[TILE_N_2:.+]] = affine.min #[[MAP3]]()[%[[BIDX]], %[[N_2]]]
+// CHECK: %[[SV_ARG2:.+]] = subview %[[ARG2]][%[[LBX]]] [%[[TILE_N_2]]]
+// CHECK: %[[M_2:.+]] = dim %[[RET0]], %[[C0]]
+// CHECK: %[[TILE_M_2:.+]] = affine.min #[[MAP1]]()[%[[BIDY]], %[[M_2]]]
+// CHECK: %[[N_3:.+]] = dim %[[RET0]], %[[C1]]
+// CHECK: %[[TILE_N_3:.+]] = affine.min #[[MAP3]]()[%[[BIDX]], %[[N_3]]]
+// CHECK: %[[SV_RET0:.+]] = subview %[[RET0]][%[[LBY]], %[[LBX]]
+// CHECK-SAME: [%[[TILE_M_2]], %[[TILE_N_3]]]
+// CHECK: %[[K:.+]] = dim %[[ARG0]], %[[C1]]
+// CHECK: %[[SV_ARG0:.+]] = subview %[[ARG0]][%[[LBY]], 0]
+// CHECK-SAME: [%[[TILE_M]], %[[K]]]
+// CHECK: %[[SV_ARG1:.+]] = subview %[[ARG1]][0, %[[LBX]]]
+// CHECK-SAME: [%[[K]], %[[TILE_N]]]
+// CHECK: %[[SV_ALLOC:.+]] = subview %[[ALLOC]][0, 0]
+// CHECK-SAME: [%[[TILE_M]], %[[TILE_N]]]
+// CHECK: linalg.fill(%[[SV_ALLOC]], %{{.+}})
+// CHECK-SAME: "workgroup"
+// CHECK: linalg.matmul
+// CHECK-SAME: "workgroup"
+// CHECK-SAME: ins(%[[SV_ARG0]], %[[SV_ARG1]]
+// CHECK-SAME: )
+// CHECK-SAME: outs(%[[SV_ALLOC]]
+// CHECK-SAME: )
+// CHECK: linalg.generic
+// CHECK-SAME: ins(%[[SV_ALLOC]], %[[SV_ARG2]]
+// CHECK-SAME: )
+// CHECK-SAME: outs(%[[SV_RET0]]
+// CHECK-SAME: )
+
+// -----
+
+module attributes {
+ spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader, Float16, Int16, Int8, StorageBuffer16BitAccess, StorageUniform16, StoragePushConstant16, StorageBuffer8BitAccess, UniformAndStorageBuffer8BitAccess, StoragePushConstant8, GroupNonUniform, VariablePointers, VariablePointersStorageBuffer], [SPV_KHR_16bit_storage, SPV_KHR_8bit_storage, SPV_KHR_storage_buffer_storage_class, SPV_KHR_variable_pointers]>, ARM:IntegratedGPU, {max_compute_shared_memory_size = 32768 : i32, max_compute_workgroup_invocations = 512 : i32, max_compute_workgroup_size = dense<512> : vector<3xi32>, subgroup_size = 16 : i32}>
+} {
+ func @conv_tiled_and_vectorized() attributes {hal.num_workgroups_fn = @get_num_workgroups} {
+ %cst = constant 0.000000e+00 : f32
+ %0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<1x112x112x32xf32>
+ %1 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<1x225x225x16xf32>
+ %2 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg1} : memref<3x3x16x32xf32>
+ linalg.fill(%0, %cst) : memref<1x112x112x32xf32>, f32
+ linalg.conv(%2, %1, %0) {dilations = [1, 1], strides = [2, 2]} : memref<3x3x16x32xf32>, memref<1x225x225x16xf32>, memref<1x112x112x32xf32>
+ return
+ }
+
+ func private @get_num_workgroups(!shapex.ranked_shape<[1,225,225,16]>, !shapex.ranked_shape<[3,3,16,32]>, !shapex.ranked_shape<[1,112,112,32]>) -> (index, index, index)
+
+ hal.interface @legacy_io attributes {sym_visibility = "private"} {
+ hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+ hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+ hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
+ }
+}
+
+// CHECK-LABEL: func @conv_tiled_and_vectorized()
+
+// CHECK-COUNT-4: vector.transfer_read
+
+// Check the tiling loops along filter height/width and input channel.
+// CHECK: scf.for %{{.*}} = %c0 to %c3 step %c1
+// CHECK: scf.for %{{.*}} = %c0 to %c3 step %c1
+// CHECK: scf.for %{{.*}} = %c0 to %c16 step %c4
+
+// CHECK-COUNT-16: vector.contract
+
+// CHECK-COUNT-3: scf.yield
+// CHECK-COUNT-4: vector.transfer_write
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/pipeline_test.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/pipeline_test.mlir
index dc4f5c9..6c9a4e3 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/pipeline_test.mlir
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/pipeline_test.mlir
@@ -46,4 +46,53 @@
// CHECK-COUNT-32: spv.FMul %{{.*}}, %{{.*}} : vector<4xf32>
// CHECK-COUNT-8: spv.Store "StorageBuffer" %{{.*}}, %{{.*}} : vector<4xf32>
+// -----
+module attributes {
+ spv.target_env =
+ #spv.target_env<#spv.vce<v1.3,
+ [Shader, Float64, Float16, Int64, Int16, Int8, StorageBuffer16BitAccess,
+ StorageUniform16, StoragePushConstant16, StorageBuffer8BitAccess,
+ UniformAndStorageBuffer8BitAccess, StoragePushConstant8, GroupNonUniform,
+ GroupNonUniformVote, GroupNonUniformArithmetic, GroupNonUniformBallot,
+ GroupNonUniformShuffle, GroupNonUniformShuffleRelative, VariablePointers,
+ VariablePointersStorageBuffer],
+ [SPV_KHR_16bit_storage, SPV_KHR_8bit_storage,
+ SPV_KHR_storage_buffer_storage_class, SPV_KHR_variable_pointers]>,
+ ARM:IntegratedGPU,
+ {max_compute_shared_memory_size = 32768 : i32,
+ max_compute_workgroup_invocations = 512 : i32,
+ max_compute_workgroup_size = dense<512> : vector<3xi32>,
+ subgroup_size = 16 : i32}>} {
+ func @matmul_fill_fused()
+ attributes {vkspv.num_workgroups_fn = @matmul_fill_fused__num_workgroups__} {
+ %arg0 = iree.placeholder for "interface buffer"
+ {binding = @legacy_io::@arg0, operand_result_num = 0 : i32} : memref<4096x4096xf32>
+ %arg1 = iree.placeholder for "interface buffer"
+ {binding = @legacy_io::@arg1, operand_result_num = 1 : i32} : memref<4096x4096xf32>
+ %ret0 = iree.placeholder for "interface buffer"
+ {binding = @legacy_io::@ret0, operand_result_num = 2 : i32} : memref<4096x4096xf32>
+ %cst = constant 0.000000e+00 : f32
+ linalg.fill(%ret0, %cst) : memref<4096x4096xf32>, f32
+ linalg.matmul ins(%arg0, %arg1 : memref<4096x4096xf32>, memref<4096x4096xf32>)
+ outs(%ret0 : memref<4096x4096xf32>)
+ return
+ }
+ func @matmul_fill_fused__num_workgroups__
+ (!shapex.ranked_shape<[4096, 4096]>, !shapex.ranked_shape<[4096, 4096]>,
+ !shapex.ranked_shape<[4096, 4096]>) -> (index, index, index)
+ attributes {sym_visibility = "private"}
+ hal.interface @legacy_io attributes {sym_visibility = "private"} {
+ hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+ hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+ hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write"
+ }
+}
+
+// CHECK-LABEL: spv.func @matmul_fill_fused
+// CHECK-NOT: spv.Store "StorageBuffer"
+// CHECK-NOT: spv.Load "StorageBuffer"
+// CHECK: spv.loop
+// CHECK-COUNT-12: spv.Load "StorageBuffer" %{{.*}} : vector<4xf32>
+// CHECK-COUNT-32: spv.FMul %{{.*}}, %{{.*}} : vector<4xf32>
+// CHECK-COUNT-8: spv.Store "StorageBuffer" %{{.*}}, %{{.*}} : vector<4xf32>
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/split_dispatch_function.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/split_dispatch_function.mlir
index 7f7ad16..52779a9 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/split_dispatch_function.mlir
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/split_dispatch_function.mlir
@@ -1,9 +1,11 @@
// RUN: iree-opt -allow-unregistered-dialect -split-input-file -iree-codegen-split-dispatch-function -verify-diagnostics %s | IreeFileCheck %s
module {
- // CHECK: func @kernel_fusable_fill_conv_ops
- // CHECK: linalg.fill
- // CHECK: linalg.conv
+ // CHECK: func @kernel_fusable_fill_conv_ops
+ // CHECK: linalg.fill
+ // CHECK-NOT: return
+ // CHECK: linalg.conv
+ // CHECK: return
func @kernel_fusable_fill_conv_ops()
attributes {hal.num_workgroups_fn = @kernel_fusable_fill_conv_ops_num_workgroups__} {
@@ -20,10 +22,9 @@
linalg.conv(%1, %ts1, %ts2) {dilations = [1, 1], padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>, strides = [2, 2]} : memref<3x3x512x1xf32>, memref<?x2x2x512xf32>, memref<?x1x1x512xf32>
return
}
- func @kernel_fill_conv_ops_num_workgroups__(!shapex.ranked_shape<[?,2,2,512]>,
- !shapex.ranked_shape<[3,3,512,1]>,
- !shapex.ranked_shape<[?,1,1,512]>)
- -> (index, index, index)
+ func @kernel_fusable_fill_conv_ops_num_workgroups__
+ (!shapex.ranked_shape<[?,2,2,512]>, !shapex.ranked_shape<[3,3,512,1]>,
+ !shapex.ranked_shape<[?,1,1,512]>) -> (index, index, index)
attributes {sym_visibility = "private"}
hal.interface @legacy_io attributes {push_constants = 1 : i32, sym_visibility = "private"} {
hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
@@ -35,9 +36,11 @@
// -----
module {
- // CHECK: func @kernel_fusable_fill_matmul_ops
- // CHECK: linalg.fill
- // CHECK: linalg.matmul
+ // CHECK: func @kernel_fusable_fill_matmul_ops
+ // CHECK: linalg.fill
+ // CHECK-NOT: return
+ // CHECK: linalg.matmul
+ // CHECK: return
func @kernel_fusable_fill_matmul_ops()
attributes {hal.num_workgroups_fn = @kernel_fusable_fill_matmul_ops_num_workgroups__} {
@@ -73,9 +76,11 @@
// -----
module {
- // CHECK: func @kernel_fusable_pooling()
- // CHECK: linalg.fill
- // CHECK: linalg.pooling
+ // CHECK: func @kernel_fusable_pooling()
+ // CHECK: linalg.fill
+ // CHECK-NOT: return
+ // CHECK: linalg.pooling
+ // CHECK: return
func @kernel_fusable_pooling() attributes {hal.num_workgroups_fn = @kernel_fusable_pooling__num_workgroups__} {
%cst = constant 0.000000e+00 : f32
%0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<?x?xf32>
@@ -221,7 +226,6 @@
}
}
-
// -----
// Nothing to do if there is just one Linalg op.
@@ -246,6 +250,8 @@
}
}
+
+
// -----
// Do not split when Linalg and non-Linalg ops are interleaving each other.
@@ -270,7 +276,6 @@
}
// -----
-
#map0 = affine_map<(d0, d1) -> (d0 * 12 + d1 + 53)>
module {
@@ -417,3 +422,62 @@
// CHECK-NEXT: linalg.copy
// CHECK-NOT: linalg
// CHECK: return
+
+// -----
+
+module {
+ // CHECK: func @kernel_fusable_fill_matmul_generic_ops
+ // CHECK: linalg.fill
+ // CHECK-NOT: return
+ // CHECK: linalg.matmul
+ // CHECK-NOT: return
+ // CHECK: linalg.generic
+ // CHECK: return
+
+ func @kernel_fusable_fill_matmul_generic_ops()
+ attributes {hal.num_workgroups_fn = @kernel_fusable_fill_matmul_generic_ops_num_workgroups__} {
+ %cst = constant 0.000000e+00 : f32
+ %dimM = hal.interface.load.constant offset = 0 : index
+ %dimN = hal.interface.load.constant offset = 1 : index
+ %shape1 = shapex.make_ranked_shape %dimM : (index) -> !shapex.ranked_shape<[?,512]>
+ %shape2 = shapex.make_ranked_shape %dimN : (index) -> !shapex.ranked_shape<[512,?]>
+ %shape3 = shapex.make_ranked_shape %dimM, %dimN : (index, index) -> !shapex.ranked_shape<[?,?]>
+ %0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<?x512xf32>
+ %ts0 = shapex.tie_shape %0, %shape1 : memref<?x512xf32>, !shapex.ranked_shape<[?,512]>
+ %1 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg1} : memref<512x?xf32>
+ %ts1 = shapex.tie_shape %1, %shape2 : memref<512x?xf32>, !shapex.ranked_shape<[512, ?]>
+ %2 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg2} : memref<?x?xf32>
+ %ts2 = shapex.tie_shape %2, %shape3 : memref<?x?xf32>, !shapex.ranked_shape<[?, ?]>
+ %3 = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<?x?xf32>
+ %ts3 = shapex.tie_shape %3, %shape3 : memref<?x?xf32>, !shapex.ranked_shape<[?,?]>
+ %4 = alloc(%dimM, %dimN) : memref<?x?xf32>
+ %ts4 = shapex.tie_shape %4, %shape3 : memref<?x?xf32>, !shapex.ranked_shape<[?,?]>
+ linalg.fill(%ts4, %cst) : memref<?x?xf32>, f32
+ linalg.matmul ins(%ts0, %ts1 : memref<?x512xf32>, memref<512x?xf32>)
+ outs(%ts4 : memref<?x?xf32>)
+ linalg.generic
+ {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%ts2, %ts4 : memref<?x?xf32>, memref<?x?xf32>)
+ outs(%ts3 : memref<?x?xf32>) {
+ ^bb0(%arg0 : f32, %arg1 : f32, %arg2 : f32):
+ %5 = addf %arg0, %arg1 : f32
+ linalg.yield %5 : f32
+ }
+ return
+ }
+  func @kernel_fusable_fill_matmul_generic_ops_num_workgroups__(!shapex.ranked_shape<[?,512]>,
+ !shapex.ranked_shape<[512,?]>,
+ !shapex.ranked_shape<[?,?]>,
+ !shapex.ranked_shape<[?,?]>)
+ -> (index, index, index)
+ attributes {sym_visibility = "private"}
+ hal.interface @legacy_io attributes {push_constants = 1 : i32, sym_visibility = "private"} {
+ hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+ hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+ hal.interface.binding @arg2, set=0, binding=1, type="StorageBuffer", access="Read"
+ hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
+ }
+}
diff --git a/iree/compiler/Conversion/LinalgToVector/BUILD b/iree/compiler/Conversion/LinalgToVector/BUILD
index 95676fb..53eaa8d 100644
--- a/iree/compiler/Conversion/LinalgToVector/BUILD
+++ b/iree/compiler/Conversion/LinalgToVector/BUILD
@@ -22,6 +22,7 @@
name = "LinalgToVector",
srcs = [
"LoadStoreVectorization.cpp",
+ "VectorizeConv.cpp",
],
hdrs = [
"Passes.h",
@@ -32,6 +33,7 @@
"//iree/compiler/Dialect/Shape/IR",
"//iree/compiler/Dialect/Shape/Transforms",
"@llvm-project//llvm:Support",
+ "@llvm-project//mlir:DialectUtils",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:LinalgOps",
"@llvm-project//mlir:LinalgTransforms",
diff --git a/iree/compiler/Conversion/LinalgToVector/CMakeLists.txt b/iree/compiler/Conversion/LinalgToVector/CMakeLists.txt
index b292c28..1e56d8c 100644
--- a/iree/compiler/Conversion/LinalgToVector/CMakeLists.txt
+++ b/iree/compiler/Conversion/LinalgToVector/CMakeLists.txt
@@ -21,6 +21,7 @@
"Passes.h"
SRCS
"LoadStoreVectorization.cpp"
+ "VectorizeConv.cpp"
DEPS
LLVMSupport
MLIRIR
diff --git a/iree/compiler/Conversion/LinalgToVector/Passes.h b/iree/compiler/Conversion/LinalgToVector/Passes.h
index 98f07bd..4f7340b 100644
--- a/iree/compiler/Conversion/LinalgToVector/Passes.h
+++ b/iree/compiler/Conversion/LinalgToVector/Passes.h
@@ -23,6 +23,16 @@
/// Creates a pass to vectorize Linalg operations.
std::unique_ptr<Pass> createLoadStoreVectorizationPass();
+/// Creates a pass to vectorize a very specific form of linalg.conv ops.
+std::unique_ptr<Pass> createVectorizeLinalgConvPass();
+
+/// Populates `patterns` with a very specific pattern that vectorizes a
+/// linalg.conv op for a single thread. The linalg.conv should compute on
+/// static-sized subviews. To match, the output shape must be 1x1xWoxCo, where
+/// Co is a multiple of 4, and the filter shape must be 1x1x4xCo.
+void populateVectorizeLinalgConvPatterns(MLIRContext *context,
+ OwningRewritePatternList &patterns);
+
} // namespace iree_compiler
} // namespace mlir
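For context, the new populate entry point can also be driven outside the registered pass. A minimal sketch under that assumption; the wrapper name below is hypothetical, and the in-tree VectorizeLinalgConvPass added later in this patch is the authoritative wiring:

#include "iree/compiler/Conversion/LinalgToVector/Passes.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir {
namespace iree_compiler {

// Hypothetical helper: applies the conv vectorization pattern to the regions
// of `op` (typically a FuncOp) via the greedy pattern rewrite driver.
static void vectorizeLinalgConvs(Operation *op) {
  MLIRContext *context = op->getContext();
  OwningRewritePatternList patterns;
  populateVectorizeLinalgConvPatterns(context, patterns);
  applyPatternsAndFoldGreedily(op, std::move(patterns));
}

}  // namespace iree_compiler
}  // namespace mlir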
diff --git a/iree/compiler/Conversion/LinalgToVector/VectorizeConv.cpp b/iree/compiler/Conversion/LinalgToVector/VectorizeConv.cpp
new file mode 100644
index 0000000..6d4bca5
--- /dev/null
+++ b/iree/compiler/Conversion/LinalgToVector/VectorizeConv.cpp
@@ -0,0 +1,234 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iree/compiler/Conversion/LinalgToVector/Passes.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Debug.h"
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#define DEBUG_TYPE "iree-vectorize-conv"
+
+namespace mlir {
+namespace iree_compiler {
+
+namespace {
+
+/// Vectorizes linalg.conv for a single GPU invocation. The linalg.conv op
+/// must therefore have a very specific form; other patterns are expected to
+/// tile and distribute larger convolutions into this form for a single GPU
+/// invocation to handle.
+///
+/// The linalg.conv op should satisfy the following:
+/// - Filter: HfWfCiCo format
+/// - Input : NHiWiCi format
+/// - Output: NHoWoCo format
+/// - For output:
+/// - N must be 1.
+/// - Ho must be 1.
+/// - Co must be a multiple of 4.
+/// - For filter:
+/// - Hf must be 1.
+///     - Wf must be 1.
+/// - Ci must be 4.
+/// - No dilation.
+/// - No padding.
+///
+/// The output channel count is required to be a multiple of 4 so that we can
+/// process it with load4/store4, which is native to GPUs. The same reasoning
+/// applies to the input channel size requirement.
+struct VectorizeLinalgConv : OpRewritePattern<linalg::ConvOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(linalg::ConvOp convOp,
+ PatternRewriter &rewriter) const override {
+ LLVM_DEBUG(llvm::dbgs() << "inspecting " << convOp << "\n");
+
+ // This pattern does not handle convolutions with dilation.
+ if (auto dilations = convOp.dilations()) {
+ auto values = dilations->getAsValueRange<IntegerAttr>();
+ if (llvm::any_of(values, [](const APInt &value) {
+ return value.getSExtValue() != 1;
+ })) {
+ return failure();
+ }
+ }
+
+ auto filterViewOp = convOp.filter().getDefiningOp<SubViewOp>();
+ auto inputViewOp = convOp.input().getDefiningOp<SubViewOp>();
+ auto outputViewOp = convOp.output().getDefiningOp<SubViewOp>();
+ if (!filterViewOp || !inputViewOp || !outputViewOp) return failure();
+
+ // The filter/input/output view should have static sizes to vectorize.
+ if (!llvm::empty(filterViewOp.getDynamicSizes()) ||
+ !llvm::empty(inputViewOp.getDynamicSizes()) ||
+ !llvm::empty(outputViewOp.getDynamicSizes())) {
+ return failure();
+ }
+
+ // The output batch and height dimensions should be 1. If not, other
+    // patterns can generate parallel loops to distribute them.
+ if (outputViewOp.getStaticSize(0) != 1 ||
+ outputViewOp.getStaticSize(1) != 1) {
+ return failure();
+ }
+
+    // We additionally expect the filter height/width dimensions to both be 1
+    // to simplify vectorization. Other patterns can generate loops to create
+    // 1x1 filter subviews.
+ if (filterViewOp.getStaticSize(0) != 1 ||
+ filterViewOp.getStaticSize(1) != 1) {
+ return failure();
+ }
+
+ int64_t numInputChannels = filterViewOp.getStaticSize(2);
+ int64_t numOutputChannels = filterViewOp.getStaticSize(3);
+ if (numInputChannels != 4 || numOutputChannels % 4 != 0) return failure();
+
+ int64_t numOutputWidths = outputViewOp.getStaticSize(2);
+ int64_t widthStride = convOp.getStride(1);
+
+ // This invocation handles a batch of (numOutputWidths * numOutputChannels).
+ LLVM_DEBUG({
+ llvm::dbgs() << "# output width: " << numOutputWidths << "\n";
+ llvm::dbgs() << "# output channels: " << numOutputChannels << "\n";
+ llvm::dbgs() << "width stride: " << widthStride << "\n";
+ });
+
+ MLIRContext *context = convOp.getContext();
+ Location loc = convOp.getLoc();
+
+ Type elementType = filterViewOp.getType().getElementType();
+ auto filterVectorType =
+ VectorType::get({numInputChannels, numOutputChannels}, elementType);
+ auto vector1x4Type = VectorType::get({1, 4}, elementType);
+ Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
+
+ // Load the entire filter subview.
+ SmallVector<Value, 4> filterIndices(4, zero);
+ Value wholeFilter = rewriter.create<vector::TransferReadOp>(
+ loc, filterVectorType, filterViewOp, filterIndices);
+
+ // Get filter slices so that later we can use them for dot product with the
+    // input. Both the height and width dimensions are 1, so we just need to
+ // loop over input and output channel dimensions.
+ SmallVector<SmallVector<Value, 1>, 4> filterVectors(numInputChannels);
+ for (int ic = 0; ic < numInputChannels; ++ic) {
+ auto &thisInputChannel = filterVectors[ic];
+ thisInputChannel.reserve(numOutputChannels / 4);
+ for (int oc = 0; oc < numOutputChannels / 4; ++oc) {
+ Value slice = rewriter.create<vector::ExtractStridedSliceOp>(
+ loc, wholeFilter, /*offsets=*/ArrayRef<int64_t>({ic, oc * 4}),
+ /*sizes=*/ArrayRef<int64_t>({1, 4}),
+ /*strides=*/ArrayRef<int64_t>({1, 1}));
+ thisInputChannel.push_back(slice);
+ }
+ }
+
+ // Build indexing maps for a later vector contraction op.
+ AffineExpr dim0 = getAffineDimExpr(0, context); // M
+ AffineExpr dim1 = getAffineDimExpr(1, context); // N
+ AffineExpr dim2 = getAffineDimExpr(2, context); // K
+ auto map02 = AffineMap::get(3, 0, {dim0, dim2}, context);
+ auto map21 = AffineMap::get(3, 0, {dim2, dim1}, context);
+ auto map01 = AffineMap::get(3, 0, {dim0, dim1}, context);
+ ArrayAttr indexingMaps =
+ rewriter.getAffineMapArrayAttr({map02, map21, map01});
+
+ // Also build iterator types for the vector contraction op.
+ ArrayAttr iterators = rewriter.getStrArrayAttr(
+ {getParallelIteratorTypeName(), getParallelIteratorTypeName(),
+ getReductionIteratorTypeName()});
+
+    // Compute the (numOutputWidths * numOutputChannels) batch. Only
+    // numInputChannels elements are accumulated along the reduction dimension,
+    // so read in the current result from the output, compose a chain of
+    // numInputChannels vector dot operations, and then write the result out.
+ for (int ow = 0; ow < numOutputWidths; ++ow) {
+      // Read in the input vector for these 4 input channels as a batch. The
+      // input vector is used for computing all output channels so the data
+      // can be reused.
+ SmallVector<Value, 4> inputIndices(4, zero);
+ inputIndices[2] = rewriter.create<ConstantIndexOp>(loc, ow * widthStride);
+ Value inputVector = rewriter.create<vector::TransferReadOp>(
+ loc, vector1x4Type, inputViewOp, inputIndices);
+
+ for (int oc = 0; oc < numOutputChannels / 4; ++oc) {
+ // Read in the initial value for this output vector.
+ SmallVector<Value, 4> outputIndices(4, zero);
+ outputIndices[2] = rewriter.create<ConstantIndexOp>(loc, ow);
+ outputIndices[3] = rewriter.create<ConstantIndexOp>(loc, oc * 4);
+ Value outputVector = rewriter.create<vector::TransferReadOp>(
+ loc, vector1x4Type, outputViewOp, outputIndices);
+
+        // Perform a chain of dot products and accumulations.
+ for (int i = 0; i < numInputChannels; ++i) {
+ auto inputSlice = rewriter.create<vector::ExtractStridedSliceOp>(
+ loc, inputVector, /*offsets=*/ArrayRef<int64_t>({0, i}),
+ /*sizes=*/ArrayRef<int64_t>({1, 1}),
+ /*strides=*/ArrayRef<int64_t>({1, 1}));
+ outputVector = rewriter.create<vector::ContractionOp>(
+ loc, inputSlice, filterVectors[i][oc], outputVector, indexingMaps,
+ iterators);
+ }
+
+ // Write out the output vector.
+ rewriter.create<vector::TransferWriteOp>(loc, outputVector,
+ outputViewOp, outputIndices);
+ }
+ }
+
+ rewriter.eraseOp(convOp);
+ return success();
+ }
+};
+
+struct VectorizeLinalgConvPass
+ : public PassWrapper<VectorizeLinalgConvPass, OperationPass<FuncOp>> {
+  void getDependentDialects(DialectRegistry &registry) const override {
+ registry.insert<linalg::LinalgDialect, vector::VectorDialect>();
+ }
+
+ void runOnOperation() override {
+ MLIRContext *context = &getContext();
+ OwningRewritePatternList patterns;
+ patterns.insert<VectorizeLinalgConv>(context);
+ applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+ }
+};
+
+} // namespace
+
+void populateVectorizeLinalgConvPatterns(MLIRContext *context,
+ OwningRewritePatternList &patterns) {
+ patterns.insert<VectorizeLinalgConv>(context);
+}
+
+std::unique_ptr<Pass> createVectorizeLinalgConvPass() {
+ return std::make_unique<VectorizeLinalgConvPass>();
+}
+
+static PassRegistration<VectorizeLinalgConvPass> pass(
+ "iree-codegen-vectorize-linalg-conv",
+ "Vectorize a very specific form of linalg.conv");
+
+} // namespace iree_compiler
+} // namespace mlir
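To make the emitted contraction chain concrete: with the indexing maps (d0, d2), (d2, d1), (d0, d1) and iterator types (parallel, parallel, reduction) built above, each vector.contract is a tiny matmul with M = 1, N = 4, K = 1. A hedged scalar equivalent of one such step, for illustration only and not code from this patch; the pattern chains numInputChannels (= 4) of these per output slice before writing the result back:

// Scalar equivalent of one vector.contract emitted per input channel:
// out[m][n] += in[m][k] * filter[k][n], with m in [0,1), n in [0,4), k in [0,1).
void contractStep(const float in[1][1], const float filter[1][4],
                  float out[1][4]) {
  for (int n = 0; n < 4; ++n)
    out[0][n] += in[0][0] * filter[0][n];
}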
diff --git a/iree/compiler/Conversion/LinalgToVector/test/vectorize_linalg_conv.mlir b/iree/compiler/Conversion/LinalgToVector/test/vectorize_linalg_conv.mlir
new file mode 100644
index 0000000..194607a
--- /dev/null
+++ b/iree/compiler/Conversion/LinalgToVector/test/vectorize_linalg_conv.mlir
@@ -0,0 +1,142 @@
+// RUN: iree-opt -split-input-file -iree-codegen-vectorize-linalg-conv -canonicalize -cse %s | IreeFileCheck %s
+
+func @vectorize_conv(%filter: memref<1x1x4x4xf32>, %input: memref<1x1x7x4xf32>, %output: memref<1x1x4x4xf32>) {
+ %0 = subview %filter[0, 0, 0, 0] [1, 1, 4, 4] [1, 1, 1, 1] : memref<1x1x4x4xf32> to memref<1x1x4x4xf32>
+ %1 = subview %input[0, 0, 0, 0] [1, 1, 7, 4] [1, 1, 1, 1] : memref<1x1x7x4xf32> to memref<1x1x7x4xf32>
+ %2 = subview %output[0, 0, 0, 0] [1, 1, 4, 4] [1, 1, 1, 1] : memref<1x1x4x4xf32> to memref<1x1x4x4xf32>
+ linalg.conv(%0, %1, %2) {dilations = [1, 1], strides = [2, 2]} : memref<1x1x4x4xf32>, memref<1x1x7x4xf32>, memref<1x1x4x4xf32>
+ return
+}
+
+// CHECK: #map0 = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK: #map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK: #map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK: func @vectorize_conv
+// CHECK-SAME: %[[FILTER_ARG:.+]]: memref<1x1x4x4xf32>,
+// CHECK-SAME: %[[INPUT_ARG:.+]]: memref<1x1x7x4xf32>,
+// CHECK-SAME: %[[OUTPUT_ARG:.+]]: memref<1x1x4x4xf32>
+
+// CHECK: %[[FLOAT_ZERO:.+]] = constant 0.000000e+00 : f32
+// CHECK: %[[FILTER:.+]] = subview %[[FILTER_ARG]]
+// CHECK: %[[INPUT:.+]] = subview %[[INPUT_ARG]]
+// CHECK: %[[OUTPUT:.+]] = subview %[[OUTPUT_ARG]]
+
+// Read in the filter and get slices
+// CHECK: %[[FILTER_VECTOR:.+]] = vector.transfer_read %[[FILTER]][%c0, %c0, %c0, %c0], %cst {masked = [false, false]} : memref<1x1x4x4xf32>, vector<4x4xf32>
+// CHECK: %[[FILTER_0:.+]] = vector.extract_strided_slice %[[FILTER_VECTOR]] {offsets = [0, 0], sizes = [1, 4], strides = [1, 1]} : vector<4x4xf32> to vector<1x4xf32>
+// CHECK: %[[FILTER_1:.+]] = vector.extract_strided_slice %[[FILTER_VECTOR]] {offsets = [1, 0], sizes = [1, 4], strides = [1, 1]} : vector<4x4xf32> to vector<1x4xf32>
+// CHECK: %[[FILTER_2:.+]] = vector.extract_strided_slice %[[FILTER_VECTOR]] {offsets = [2, 0], sizes = [1, 4], strides = [1, 1]} : vector<4x4xf32> to vector<1x4xf32>
+// CHECK: %[[FILTER_3:.+]] = vector.extract_strided_slice %[[FILTER_VECTOR]] {offsets = [3, 0], sizes = [1, 4], strides = [1, 1]} : vector<4x4xf32> to vector<1x4xf32>
+
+// Handle batch #0
+// CHECK: %[[INPUT_0:.+]] = vector.transfer_read %[[INPUT]][%c0, %c0, %c0, %c0], %[[FLOAT_ZERO]] {masked = [false, false]} : memref<1x1x7x4xf32>, vector<1x4xf32>
+// CHECK: %[[OUTPUT_0:.+]] = vector.transfer_read %[[OUTPUT]][%c0, %c0, %c0, %c0], %[[FLOAT_ZERO]] {masked = [false, false]} : memref<1x1x4x4xf32>, vector<1x4xf32>
+// CHECK: %[[INPUT_0_0:.+]] = vector.extract_strided_slice %[[INPUT_0]] {offsets = [0, 0], sizes = [1, 1], strides = [1, 1]} : vector<1x4xf32> to vector<1x1xf32>
+// CHECK: %[[DOT_0:.+]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[INPUT_0_0]], %[[FILTER_0]], %[[OUTPUT_0]] : vector<1x1xf32>, vector<1x4xf32> into vector<1x4xf32>
+// CHECK: %[[INPUT_0_1:.+]] = vector.extract_strided_slice %[[INPUT_0]] {offsets = [0, 1], sizes = [1, 1], strides = [1, 1]} : vector<1x4xf32> to vector<1x1xf32>
+// CHECK: %[[DOT_1:.+]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[INPUT_0_1]], %[[FILTER_1]], %[[DOT_0]] : vector<1x1xf32>, vector<1x4xf32> into vector<1x4xf32>
+// CHECK: %[[INPUT_0_2:.+]] = vector.extract_strided_slice %[[INPUT_0]] {offsets = [0, 2], sizes = [1, 1], strides = [1, 1]} : vector<1x4xf32> to vector<1x1xf32>
+// CHECK: %[[DOT_2:.+]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[INPUT_0_2]], %[[FILTER_2]], %[[DOT_1]] : vector<1x1xf32>, vector<1x4xf32> into vector<1x4xf32>
+// CHECK: %[[INPUT_0_3:.+]] = vector.extract_strided_slice %[[INPUT_0]] {offsets = [0, 3], sizes = [1, 1], strides = [1, 1]} : vector<1x4xf32> to vector<1x1xf32>
+// CHECK: %[[DOT_3:.+]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[INPUT_0_3]], %[[FILTER_3]], %[[DOT_2]] : vector<1x1xf32>, vector<1x4xf32> into vector<1x4xf32>
+// CHECK: vector.transfer_write %[[DOT_3]], %[[OUTPUT]][%c0, %c0, %c0, %c0] {masked = [false, false]} : vector<1x4xf32>, memref<1x1x4x4xf32>
+
+// Handle batch #1
+// CHECK: %[[INPUT_1:.+]] = vector.transfer_read %[[INPUT]][%c0, %c0, %c2, %c0], %[[FLOAT_ZERO]] {masked = [false, false]} : memref<1x1x7x4xf32>, vector<1x4xf32>
+// CHECK: %[[OUTPUT_1:.+]] = vector.transfer_read %[[OUTPUT]][%c0, %c0, %c1, %c0], %[[FLOAT_ZERO]] {masked = [false, false]} : memref<1x1x4x4xf32>, vector<1x4xf32>
+// CHECK: %[[INPUT_1_0:.+]] = vector.extract_strided_slice %[[INPUT_1]] {offsets = [0, 0], sizes = [1, 1], strides = [1, 1]} : vector<1x4xf32> to vector<1x1xf32>
+// CHECK: %[[DOT_0:.+]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[INPUT_1_0]], %[[FILTER_0]], %[[OUTPUT_1]] : vector<1x1xf32>, vector<1x4xf32> into vector<1x4xf32>
+// CHECK: %[[INPUT_1_1:.+]] = vector.extract_strided_slice %[[INPUT_1]] {offsets = [0, 1], sizes = [1, 1], strides = [1, 1]} : vector<1x4xf32> to vector<1x1xf32>
+// CHECK: %[[DOT_1:.+]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[INPUT_1_1]], %[[FILTER_1]], %[[DOT_0]] : vector<1x1xf32>, vector<1x4xf32> into vector<1x4xf32>
+// CHECK: %[[INPUT_1_2:.+]] = vector.extract_strided_slice %[[INPUT_1]] {offsets = [0, 2], sizes = [1, 1], strides = [1, 1]} : vector<1x4xf32> to vector<1x1xf32>
+// CHECK: %[[DOT_2:.+]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[INPUT_1_2]], %[[FILTER_2]], %[[DOT_1]] : vector<1x1xf32>, vector<1x4xf32> into vector<1x4xf32>
+// CHECK: %[[INPUT_1_3:.+]] = vector.extract_strided_slice %[[INPUT_1]] {offsets = [0, 3], sizes = [1, 1], strides = [1, 1]} : vector<1x4xf32> to vector<1x1xf32>
+// CHECK: %[[DOT_3:.+]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[INPUT_1_3]], %[[FILTER_3]], %[[DOT_2]] : vector<1x1xf32>, vector<1x4xf32> into vector<1x4xf32>
+// CHECK: vector.transfer_write %[[DOT_3]], %[[OUTPUT]][%c0, %c0, %c1, %c0] {masked = [false, false]} : vector<1x4xf32>, memref<1x1x4x4xf32>
+
+// Handle batch #2
+// CHECK: %[[INPUT_2:.+]] = vector.transfer_read %[[INPUT]][%c0, %c0, %c4, %c0], %[[FLOAT_ZERO]] {masked = [false, false]} : memref<1x1x7x4xf32>, vector<1x4xf32>
+// CHECK: %[[OUTPUT_2:.+]] = vector.transfer_read %[[OUTPUT]][%c0, %c0, %c2, %c0], %[[FLOAT_ZERO]] {masked = [false, false]} : memref<1x1x4x4xf32>, vector<1x4xf32>
+// CHECK: %[[INPUT_2_0:.+]] = vector.extract_strided_slice %[[INPUT_2]] {offsets = [0, 0], sizes = [1, 1], strides = [1, 1]} : vector<1x4xf32> to vector<1x1xf32>
+// CHECK: %[[DOT_0:.+]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[INPUT_2_0]], %[[FILTER_0]], %[[OUTPUT_2]] : vector<1x1xf32>, vector<1x4xf32> into vector<1x4xf32>
+// CHECK: %[[INPUT_2_1:.+]] = vector.extract_strided_slice %[[INPUT_2]] {offsets = [0, 1], sizes = [1, 1], strides = [1, 1]} : vector<1x4xf32> to vector<1x1xf32>
+// CHECK: %[[DOT_1:.+]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[INPUT_2_1]], %[[FILTER_1]], %[[DOT_0]] : vector<1x1xf32>, vector<1x4xf32> into vector<1x4xf32>
+// CHECK: %[[INPUT_2_2:.+]] = vector.extract_strided_slice %[[INPUT_2]] {offsets = [0, 2], sizes = [1, 1], strides = [1, 1]} : vector<1x4xf32> to vector<1x1xf32>
+// CHECK: %[[DOT_2:.+]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[INPUT_2_2]], %[[FILTER_2]], %[[DOT_1]] : vector<1x1xf32>, vector<1x4xf32> into vector<1x4xf32>
+// CHECK: %[[INPUT_2_3:.+]] = vector.extract_strided_slice %[[INPUT_2]] {offsets = [0, 3], sizes = [1, 1], strides = [1, 1]} : vector<1x4xf32> to vector<1x1xf32>
+// CHECK: %[[DOT_3:.+]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[INPUT_2_3]], %[[FILTER_3]], %[[DOT_2]] : vector<1x1xf32>, vector<1x4xf32> into vector<1x4xf32>
+// CHECK: vector.transfer_write %[[DOT_3]], %[[OUTPUT]][%c0, %c0, %c2, %c0] {masked = [false, false]} : vector<1x4xf32>, memref<1x1x4x4xf32>
+
+// Handle batch #3
+// CHECK: %[[INPUT_3:.+]] = vector.transfer_read %[[INPUT]][%c0, %c0, %c6, %c0], %[[FLOAT_ZERO]] {masked = [false, false]} : memref<1x1x7x4xf32>, vector<1x4xf32>
+// CHECK: %[[OUTPUT_3:.+]] = vector.transfer_read %[[OUTPUT]][%c0, %c0, %c3, %c0], %[[FLOAT_ZERO]] {masked = [false, false]} : memref<1x1x4x4xf32>, vector<1x4xf32>
+// CHECK: %[[INPUT_3_0:.+]] = vector.extract_strided_slice %[[INPUT_3]] {offsets = [0, 0], sizes = [1, 1], strides = [1, 1]} : vector<1x4xf32> to vector<1x1xf32>
+// CHECK: %[[DOT_0:.+]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[INPUT_3_0]], %[[FILTER_0]], %[[OUTPUT_3]] : vector<1x1xf32>, vector<1x4xf32> into vector<1x4xf32>
+// CHECK: %[[INPUT_3_1:.+]] = vector.extract_strided_slice %[[INPUT_3]] {offsets = [0, 1], sizes = [1, 1], strides = [1, 1]} : vector<1x4xf32> to vector<1x1xf32>
+// CHECK: %[[DOT_1:.+]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[INPUT_3_1]], %[[FILTER_1]], %[[DOT_0]] : vector<1x1xf32>, vector<1x4xf32> into vector<1x4xf32>
+// CHECK: %[[INPUT_3_2:.+]] = vector.extract_strided_slice %[[INPUT_3]] {offsets = [0, 2], sizes = [1, 1], strides = [1, 1]} : vector<1x4xf32> to vector<1x1xf32>
+// CHECK: %[[DOT_2:.+]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[INPUT_3_2]], %[[FILTER_2]], %[[DOT_1]] : vector<1x1xf32>, vector<1x4xf32> into vector<1x4xf32>
+// CHECK: %[[INPUT_3_3:.+]] = vector.extract_strided_slice %[[INPUT_3]] {offsets = [0, 3], sizes = [1, 1], strides = [1, 1]} : vector<1x4xf32> to vector<1x1xf32>
+// CHECK: %[[DOT_3:.+]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[INPUT_3_3]], %[[FILTER_3]], %[[DOT_2]] : vector<1x1xf32>, vector<1x4xf32> into vector<1x4xf32>
+// CHECK: vector.transfer_write %[[DOT_3]], %[[OUTPUT]][%c0, %c0, %c3, %c0] {masked = [false, false]} : vector<1x4xf32>, memref<1x1x4x4xf32>
+
+// -----
+
+// CHECK-LABEL: func @do_not_vectorize_conv_with_non_1_batch
+func @do_not_vectorize_conv_with_non_1_batch(%filter: memref<1x1x4x4xf32>, %input: memref<2x1x7x4xf32>, %output: memref<2x1x4x4xf32>) {
+ %0 = subview %filter[0, 0, 0, 0] [1, 1, 4, 4] [1, 1, 1, 1] : memref<1x1x4x4xf32> to memref<1x1x4x4xf32>
+ %1 = subview %input[0, 0, 0, 0] [2, 1, 7, 4] [1, 1, 1, 1] : memref<2x1x7x4xf32> to memref<2x1x7x4xf32>
+ %2 = subview %output[0, 0, 0, 0] [2, 1, 4, 4] [1, 1, 1, 1] : memref<2x1x4x4xf32> to memref<2x1x4x4xf32>
+ // CHECK: linalg.conv
+ linalg.conv(%0, %1, %2) {dilations = [1, 1], strides = [2, 2]} : memref<1x1x4x4xf32>, memref<2x1x7x4xf32>, memref<2x1x4x4xf32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func @do_not_vectorize_conv_with_non_1_output_height
+func @do_not_vectorize_conv_with_non_1_output_height(%filter: memref<1x1x4x4xf32>, %input: memref<1x3x7x4xf32>, %output: memref<1x2x4x4xf32>) {
+ %0 = subview %filter[0, 0, 0, 0] [1, 1, 4, 4] [1, 1, 1, 1] : memref<1x1x4x4xf32> to memref<1x1x4x4xf32>
+ %1 = subview %input[0, 0, 0, 0] [1, 3, 7, 4] [1, 1, 1, 1] : memref<1x3x7x4xf32> to memref<1x3x7x4xf32>
+ %2 = subview %output[0, 0, 0, 0] [1, 2, 4, 4] [1, 1, 1, 1] : memref<1x2x4x4xf32> to memref<1x2x4x4xf32>
+ // CHECK: linalg.conv
+ linalg.conv(%0, %1, %2) {dilations = [1, 1], strides = [2, 2]} : memref<1x1x4x4xf32>, memref<1x3x7x4xf32>, memref<1x2x4x4xf32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func @do_not_vectorize_conv_with_non_1_filter_height
+func @do_not_vectorize_conv_with_non_1_filter_height(%filter: memref<2x1x4x4xf32>, %input: memref<1x2x7x4xf32>, %output: memref<1x1x4x4xf32>) {
+ %0 = subview %filter[0, 0, 0, 0] [2, 1, 4, 4] [1, 1, 1, 1] : memref<2x1x4x4xf32> to memref<2x1x4x4xf32>
+ %1 = subview %input[0, 0, 0, 0] [1, 2, 7, 4] [1, 1, 1, 1] : memref<1x2x7x4xf32> to memref<1x2x7x4xf32>
+ %2 = subview %output[0, 0, 0, 0] [1, 1, 4, 4] [1, 1, 1, 1] : memref<1x1x4x4xf32> to memref<1x1x4x4xf32>
+ // CHECK: linalg.conv
+ linalg.conv(%0, %1, %2) {dilations = [1, 1], strides = [2, 2]} : memref<2x1x4x4xf32>, memref<1x2x7x4xf32>, memref<1x1x4x4xf32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func @do_not_vectorize_conv_with_non_1_filter_width
+func @do_not_vectorize_conv_with_non_1_filter_width(%filter: memref<1x2x4x4xf32>, %input: memref<1x1x8x4xf32>, %output: memref<1x1x4x4xf32>) {
+ %0 = subview %filter[0, 0, 0, 0] [1, 2, 4, 4] [1, 1, 1, 1] : memref<1x2x4x4xf32> to memref<1x2x4x4xf32>
+ %1 = subview %input[0, 0, 0, 0] [1, 1, 8, 4] [1, 1, 1, 1] : memref<1x1x8x4xf32> to memref<1x1x8x4xf32>
+ %2 = subview %output[0, 0, 0, 0] [1, 1, 4, 4] [1, 1, 1, 1] : memref<1x1x4x4xf32> to memref<1x1x4x4xf32>
+ // CHECK: linalg.conv
+ linalg.conv(%0, %1, %2) {dilations = [1, 1], strides = [2, 2]} : memref<1x2x4x4xf32>, memref<1x1x8x4xf32>, memref<1x1x4x4xf32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func @do_not_vectorize_conv_with_non_1_dilation
+func @do_not_vectorize_conv_with_non_1_dilation(%filter: memref<1x1x4x4xf32>, %input: memref<1x1x7x4xf32>, %output: memref<1x1x4x4xf32>) {
+ %0 = subview %filter[0, 0, 0, 0] [1, 1, 4, 4] [1, 1, 1, 1] : memref<1x1x4x4xf32> to memref<1x1x4x4xf32>
+ %1 = subview %input[0, 0, 0, 0] [1, 1, 7, 4] [1, 1, 1, 1] : memref<1x1x7x4xf32> to memref<1x1x7x4xf32>
+ %2 = subview %output[0, 0, 0, 0] [1, 1, 4, 4] [1, 1, 1, 1] : memref<1x1x4x4xf32> to memref<1x1x4x4xf32>
+ // CHECK: linalg.conv
+ linalg.conv(%0, %1, %2) {dilations = [2, 1], strides = [2, 2]} : memref<1x1x4x4xf32>, memref<1x1x7x4xf32>, memref<1x1x4x4xf32>
+ return
+}
diff --git a/iree/compiler/Conversion/init_conversions.h b/iree/compiler/Conversion/init_conversions.h
index 6d72146..d53c0ee 100644
--- a/iree/compiler/Conversion/init_conversions.h
+++ b/iree/compiler/Conversion/init_conversions.h
@@ -33,6 +33,7 @@
createDecomposeHLOClampPass();
createHLOToLinalgOnBuffersPass();
createHLOToLinalgOnTensorsPass();
+ createDemoteF32ToF16Pass();
}
inline void registerLinalgToVectorPasses() {
diff --git a/iree/compiler/Dialect/Flow/Transforms/DispatchConfig.cpp b/iree/compiler/Dialect/Flow/Transforms/DispatchConfig.cpp
index 049b5c1..4750793 100644
--- a/iree/compiler/Dialect/Flow/Transforms/DispatchConfig.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/DispatchConfig.cpp
@@ -17,12 +17,19 @@
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
#include "iree/compiler/Dialect/Flow/Utils/DispatchUtils.h"
#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
+#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#define DEBUG_TYPE "iree-detail"
+static llvm::cl::opt<bool> clEnableMatmulFusion(
+ "iree-enable-matmul-fusion",
+    llvm::cl::desc("Enable fusion of matmul with its consumers; experimental "
+                   "flag to evaluate fusion"),
+ llvm::cl::init(false));
+
namespace mlir {
namespace iree_compiler {
namespace IREE {
@@ -104,7 +111,7 @@
}
int OpDispatchPolicy::getAnchorBenefit(Operation *op) {
- if (isUnsupportedFusionOp(op)) {
+ if (isUnsupportedFusionOp(op) || isFusableWithConsumersOnly(op)) {
return 100;
}
@@ -142,6 +149,9 @@
if (isUnsupportedFusionOp(anchorOp) || isUnsupportedFusionOp(inputOp)) {
return FusionType::DISABLED;
}
+ if (isFusableWithConsumersOnly(anchorOp) && !isa<mhlo::ReshapeOp>(inputOp)) {
+ return FusionType::DISABLED;
+ }
// By default for operands, they are duplicated into the dispatch region.
// Typically at the initial fusion stage, there is not a sufficient cost
@@ -167,6 +177,10 @@
if (isUnsupportedFusionOp(anchorOp) || isUnsupportedFusionOp(outputOp)) {
return FusionType::DISABLED;
}
+ if (isFusableWithConsumersOnly(anchorOp) &&
+ !isFusableWithConsumersOnly(outputOp)) {
+ return FusionType::MOVE_INTO;
+ }
// Generally, it is hard to reason locally about the legality of fusing an
// output, since additional analysis may need to be done to determine
@@ -176,12 +190,16 @@
return FusionType::DISABLED;
}
+bool OpDispatchPolicy::isFusableWithConsumersOnly(Operation *op) {
+ return clEnableMatmulFusion && isa<mhlo::DotOp>(op);
+}
+
// TODO(b/144530470): replace with tablegen attributes/interfaces.
bool OpDispatchPolicy::isUnsupportedFusionOp(Operation *op) {
- return isa<mhlo::ConcatenateOp, mhlo::ConvOp, mhlo::DotGeneralOp, mhlo::DotOp,
- mhlo::PadOp, mhlo::ReduceOp, mhlo::ReduceWindowOp,
- mhlo::TorchIndexSelectOp>(op) ||
- isRootOnlyOp(op);
+ return isa<mhlo::ConcatenateOp, mhlo::ConvOp, mhlo::DotGeneralOp, mhlo::PadOp,
+ mhlo::ReduceOp, mhlo::ReduceWindowOp, mhlo::TorchIndexSelectOp>(
+ op) ||
+ (!clEnableMatmulFusion && isa<mhlo::DotOp>(op)) || isRootOnlyOp(op);
}
bool OpDispatchPolicy::isRootOnlyOp(Operation *op) {
diff --git a/iree/compiler/Dialect/Flow/Transforms/DispatchConfig.h b/iree/compiler/Dialect/Flow/Transforms/DispatchConfig.h
index ee9299f..36e8518 100644
--- a/iree/compiler/Dialect/Flow/Transforms/DispatchConfig.h
+++ b/iree/compiler/Dialect/Flow/Transforms/DispatchConfig.h
@@ -41,6 +41,9 @@
OpDispatchPolicy(Dispatchability &dispatchability)
: dispatchability(dispatchability) {}
+ // Returns true if |op| is only fusable with its consumers.
+ static bool isFusableWithConsumersOnly(Operation *op);
+
// Returns true if |op| is not able to fuse with either producer or consumer.
static bool isUnsupportedFusionOp(Operation *op);
diff --git a/iree/compiler/Dialect/Flow/Transforms/FoldCompatibleDispatchRegions.cpp b/iree/compiler/Dialect/Flow/Transforms/FoldCompatibleDispatchRegions.cpp
index 0982a12..2cbfaba 100644
--- a/iree/compiler/Dialect/Flow/Transforms/FoldCompatibleDispatchRegions.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/FoldCompatibleDispatchRegions.cpp
@@ -202,7 +202,8 @@
for (auto &block : regionOp.body().getBlocks()) {
for (auto &op : block) {
// A root only op is mergable.
- if (OpDispatchPolicy::isUnsupportedFusionOp(&op) &&
+ if ((OpDispatchPolicy::isUnsupportedFusionOp(&op) ||
+ OpDispatchPolicy::isFusableWithConsumersOnly(&op)) &&
!OpDispatchPolicy::isRootOnlyOp(&op)) {
return false;
}
diff --git a/iree/compiler/Dialect/Flow/Transforms/IdentifyDispatchRegions2.cpp b/iree/compiler/Dialect/Flow/Transforms/IdentifyDispatchRegions2.cpp
index 6da8bef..f604b39 100644
--- a/iree/compiler/Dialect/Flow/Transforms/IdentifyDispatchRegions2.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/IdentifyDispatchRegions2.cpp
@@ -254,6 +254,44 @@
return success();
}
+// Inlining an op into a dispatch region turns the op's operands into operands
+// of the dispatch region (unless they are already defined inside the region).
+// For SSA use-def chains to remain valid, the dispatch region must be moved to
+// just after the last-defined operand.
+static LogicalResult moveDispatchOp(DispatchRegionOp dispatchRegionOp,
+ Operation *inlinedOp) {
+  // Find the operation, among those defining the inlinedOp's operands, that
+  // appears last in the block.
+ Optional<Operation *> lastOperandDef = llvm::None;
+ for (Value operand : inlinedOp->getOperands()) {
+ if (Operation *definingOp = operand.getDefiningOp()) {
+ if (!lastOperandDef ||
+ lastOperandDef.getValue()->isBeforeInBlock(definingOp)) {
+ lastOperandDef = definingOp;
+ }
+ }
+ }
+ // If the last operand def is already before the dispatch region, there is
+ // nothing to do.
+ if (!lastOperandDef ||
+ lastOperandDef.getValue()->isBeforeInBlock(dispatchRegionOp)) {
+ return success();
+ }
+
+ // The dispatch region needs to be moved after the lastOperandDef, but before
+ // the first use.
+ Optional<Operation *> firstUse = llvm::None;
+ for (Operation *user : dispatchRegionOp.getOperation()->getUsers()) {
+ if (!firstUse || user->isBeforeInBlock(*firstUse)) {
+ firstUse = user;
+ }
+ }
+ if (firstUse && firstUse.getValue()->isBeforeInBlock(*lastOperandDef))
+ return failure();
+ dispatchRegionOp.getOperation()->moveAfter(lastOperandDef.getValue());
+ return success();
+}
+
LogicalResult fuseOutputs(DispatchRegion &dispatchRegion,
OpDispatchPolicy &policy) {
LLVM_DEBUG(llvm::dbgs() << "++ FUSING OUTPUT\n");
@@ -275,6 +313,11 @@
return nextOp->emitError()
<< "cannot fuse output except with MOVE_INTO action";
}
+ if (failed(moveDispatchOp(dispatchRegion.op, nextOp))) {
+ LLVM_DEBUG(llvm::dbgs() << "- SKIP Fusion due to SSA use-def violation "
+ << *nextOp << "\n");
+ continue;
+ }
LLVM_DEBUG(llvm::dbgs() << "- FUSABLE OUTPUT(" << static_cast<int>(action)
<< "): " << *nextOp << "\n");
// Since results will be redirected to the region results, need to scan
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/identify_dispatch_regions2_enable_matmul_fusion.mlir b/iree/compiler/Dialect/Flow/Transforms/test/identify_dispatch_regions2_enable_matmul_fusion.mlir
new file mode 100644
index 0000000..f8d9fa8
--- /dev/null
+++ b/iree/compiler/Dialect/Flow/Transforms/test/identify_dispatch_regions2_enable_matmul_fusion.mlir
@@ -0,0 +1,109 @@
+// RUN: iree-opt -split-input-file -iree-flow-dispatchability-analysis -iree-flow-identify-dispatch-regions2 -iree-enable-matmul-fusion %s | IreeFileCheck %s
+
+func @simpleDotAddMul
+ (%arg0 : tensor<16x32xf32>, %arg1 : tensor<32x48xf32>,
+ %arg2 : tensor<16x48xf32>, %arg3 : tensor<16x48xf32>) -> tensor<16x48xf32> {
+ %0 = "mhlo.dot"(%arg0, %arg1) :
+ (tensor<16x32xf32>, tensor<32x48xf32>) -> tensor<16x48xf32>
+ %1 = mhlo.add %0, %arg2 : tensor<16x48xf32>
+ %2 = mhlo.multiply %1, %arg3 : tensor<16x48xf32>
+ return %2 : tensor<16x48xf32>
+}
+// CHECK-LABEL: func @simpleDotAddMul
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<16x32xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<32x48xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<16x48xf32>
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: tensor<16x48xf32>
+// CHECK-NEXT: %[[WORKLOAD:.+]] = constant 768
+// CHECK-NEXT: %[[RESULT:.+]] = flow.dispatch.region[%[[WORKLOAD]] : index]
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]] = %[[ARG0]]
+// CHECK-SAME: %[[ARG5:[a-zA-Z0-9_]+]] = %[[ARG1]]
+// CHECK-SAME: %[[ARG6:[a-zA-Z0-9_]+]] = %[[ARG2]]
+// CHECK-SAME: %[[ARG7:[a-zA-Z0-9_]+]] = %[[ARG3]]
+// CHECK-SAME: {
+// CHECK-NEXT: %[[T1:.+]] = "mhlo.dot"(%[[ARG4]], %[[ARG5]])
+// CHECK-NEXT: %[[T2:.+]] = mhlo.add %[[T1]], %[[ARG6]]
+// CHECK-NEXT: %[[T3:.+]] = mhlo.multiply %[[T2]], %[[ARG7]]
+// CHECK-NEXT: flow.return %[[T3]]
+// CHECK-NEXT: }
+// CHECK-NEXT: return %[[RESULT]]
+
+// -----
+
+func @twoDots
+ (%arg0 : tensor<16x32xf32>, %arg1 : tensor<32x48xf32>,
+    %arg2 : tensor<16x48xf32>, %arg3 : tensor<48x64xf32>,
+ %arg4 : tensor<16x64xf32>) -> tensor<16x64xf32> {
+ %0 = "mhlo.dot"(%arg0, %arg1) :
+ (tensor<16x32xf32>, tensor<32x48xf32>) -> tensor<16x48xf32>
+ %1 = mhlo.add %0, %arg2 : tensor<16x48xf32>
+ %2 = "mhlo.dot"(%1, %arg3) :
+       (tensor<16x48xf32>, tensor<48x64xf32>) -> tensor<16x64xf32>
+ %3 = mhlo.multiply %2, %arg4 : tensor<16x64xf32>
+ return %3 : tensor<16x64xf32>
+}
+// CHECK-LABEL: func @twoDots
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<16x32xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<32x48xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<16x48xf32>
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: tensor<48x64xf32>
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: tensor<16x64xf32>
+// CHECK-NEXT: %[[WORKLOAD1:.+]] = constant 1024
+// CHECK-NEXT: %[[WORKLOAD2:.+]] = constant 768
+// CHECK-NEXT: %[[RESULT1:.+]] = flow.dispatch.region[%[[WORKLOAD2]] : index]
+// CHECK-SAME: %[[ARG5:[a-zA-Z0-9_]+]] = %[[ARG0]]
+// CHECK-SAME: %[[ARG6:[a-zA-Z0-9_]+]] = %[[ARG1]]
+// CHECK-SAME: %[[ARG7:[a-zA-Z0-9_]+]] = %[[ARG2]]
+// CHECK-SAME: {
+// CHECK-NEXT: %[[T1:.+]] = "mhlo.dot"(%[[ARG5]], %[[ARG6]])
+// CHECK-NEXT: %[[T2:.+]] = mhlo.add %[[T1]], %[[ARG7]]
+// CHECK-NEXT: flow.return %[[T2]]
+// CHECK-NEXT: }
+// CHECK-NEXT: %[[RESULT2:.+]] = flow.dispatch.region[%[[WORKLOAD1]] : index]
+// CHECK-SAME: %[[ARG5:[a-zA-Z0-9_]+]] = %[[RESULT1]]
+// CHECK-SAME: %[[ARG6:[a-zA-Z0-9_]+]] = %[[ARG3]]
+// CHECK-SAME: %[[ARG7:[a-zA-Z0-9_]+]] = %[[ARG4]]
+// CHECK-SAME: {
+// CHECK-NEXT: %[[T3:.+]] = "mhlo.dot"(%[[ARG5]], %[[ARG6]])
+// CHECK-NEXT: %[[T4:.+]] = mhlo.multiply %[[T3]], %[[ARG7]]
+// CHECK-NEXT: flow.return %[[T4]]
+// CHECK-NEXT: }
+// CHECK-NEXT: return %[[RESULT2]]
+
+// -----
+
+func @moveDispatchOp
+ (%arg0 : tensor<1x384x384xf32>, %arg1 : tensor<384x512xf32>,
+ %arg2 : tensor<512xf32>) -> tensor<1x384x512xf32> {
+ %0 = "mhlo.reshape"(%arg0) : (tensor<1x384x384xf32>) -> tensor<384x384xf32>
+ %1 = "mhlo.dot"(%0, %arg1) :
+ (tensor<384x384xf32>, tensor<384x512xf32>) -> tensor<384x512xf32>
+ %2 = "mhlo.broadcast_in_dim"(%arg2)
+ {broadcast_dimensions = dense<1> : tensor<1xi64>} :
+ (tensor<512xf32>) -> tensor<384x512xf32>
+ %3 = mhlo.add %1, %2 : tensor<384x512xf32>
+ %4 = "mhlo.reshape"(%3) : (tensor<384x512xf32>) -> tensor<1x384x512xf32>
+ return %4 : tensor<1x384x512xf32>
+}
+// CHECK-LABEL: func @moveDispatchOp
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<1x384x384xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<384x512xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<512xf32>
+// CHECK: %[[RESULT1:.+]] = flow.dispatch.region
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]] = %[[ARG2]]
+// CHECK-SAME: {
+// CHECK-NEXT: %[[T1:.+]] = "mhlo.broadcast_in_dim"(%[[ARG3]])
+// CHECK-NEXT: flow.return %[[T1]]
+// CHECK-NEXT: }
+// CHECK-NEXT: %[[RESULT2:.+]] = flow.dispatch.region
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]] = %[[ARG1]]
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]] = %[[RESULT1]]
+// CHECK-SAME: %[[ARG5:[a-zA-Z0-9_]+]] = %[[ARG0]]
+// CHECK-SAME: {
+// CHECK-NEXT: %[[T2:.+]] = "mhlo.reshape"(%[[ARG5]])
+// CHECK-NEXT: %[[T3:.+]] = "mhlo.dot"(%[[T2]], %[[ARG3]])
+// CHECK-NEXT: %[[T4:.+]] = mhlo.add %[[T3]], %[[ARG4]]
+// CHECK-NEXT: %[[T5:.+]] = "mhlo.reshape"(%[[T4]])
+// CHECK-NEXT: flow.return %[[T5]]
+// CHECK-NEXT: }
+// CHECK-NEXT: return %[[RESULT2]]
diff --git a/iree/compiler/Dialect/Modules/TensorList/Conversion/test/convert_hal_to_vm.mlir b/iree/compiler/Dialect/Modules/TensorList/Conversion/test/convert_hal_to_vm.mlir
index c5c8748..5974d45 100644
--- a/iree/compiler/Dialect/Modules/TensorList/Conversion/test/convert_hal_to_vm.mlir
+++ b/iree/compiler/Dialect/Modules/TensorList/Conversion/test/convert_hal_to_vm.mlir
@@ -3,7 +3,7 @@
// CHECK-LABEL: @Reserve
func @Reserve(%element_shape: !hal.buffer_view, %num_elements: !hal.buffer_view) -> !tensorlist.list {
// CHECK: vm.call @tensorlist.reserve
- %0 = "tensorlist.Reserve"(%element_shape, %num_elements) : (!hal.buffer_view, !hal.buffer_view) -> !tensorlist.list
+ %0 = "tensorlist.Reserve"(%element_shape, %num_elements) { element_type = 50331680 : i32 } : (!hal.buffer_view, !hal.buffer_view) -> !tensorlist.list
return %0 : !tensorlist.list
}
// CHECK: vm.import @tensorlist.reserve
@@ -11,9 +11,9 @@
// -----
// CHECK-LABEL: @GetItem
-func @GetItem(%list: !tensorlist.list, %index: !hal.buffer_view, %element_shape: !hal.buffer_view) -> !hal.buffer_view {
+func @GetItem(%list: !tensorlist.list, %index: !hal.buffer_view) -> !hal.buffer_view {
// CHECK: vm.call @tensorlist.get_item
- %0 = "tensorlist.GetItem"(%list, %index, %element_shape) : (!tensorlist.list, !hal.buffer_view, !hal.buffer_view) -> !hal.buffer_view
+ %0 = "tensorlist.GetItem"(%list, %index) : (!tensorlist.list, !hal.buffer_view) -> !hal.buffer_view
return %0 : !hal.buffer_view
}
// CHECK: vm.import @tensorlist.get_item
@@ -31,9 +31,11 @@
// -----
// CHECK-LABEL: @Stack
-func @Stack(%list: !tensorlist.list, %element_shape: !hal.buffer_view, %num_elements: !hal.buffer_view) -> !hal.buffer_view {
+func @Stack(%list: !tensorlist.list, %num_elements: !hal.buffer_view) -> !hal.buffer_view {
+ %dev = hal.ex.shared_device : !hal.device
+ %allocator = hal.device.allocator %dev : !hal.allocator
// CHECK: vm.call @tensorlist.stack
- %0 = "tensorlist.Stack"(%list, %element_shape, %num_elements) : (!tensorlist.list, !hal.buffer_view, !hal.buffer_view) -> !hal.buffer_view
+ %0 = "tensorlist.Stack"(%allocator, %list, %num_elements) : (!hal.allocator, !tensorlist.list, !hal.buffer_view) -> !hal.buffer_view
return %0 : !hal.buffer_view
}
@@ -41,7 +43,9 @@
// CHECK-LABEL: @Concat
func @Concat(%list: !tensorlist.list) -> !hal.buffer_view {
+ %dev = hal.ex.shared_device : !hal.device
+ %allocator = hal.device.allocator %dev : !hal.allocator
// CHECK: vm.call @tensorlist.concat
- %0 = "tensorlist.Concat"(%list) : (!tensorlist.list) -> !hal.buffer_view
+ %0 = "tensorlist.Concat"(%allocator, %list) : (!hal.allocator, !tensorlist.list) -> !hal.buffer_view
return %0 : !hal.buffer_view
}
diff --git a/iree/compiler/Dialect/Modules/TensorList/IR/TensorListOps.cpp b/iree/compiler/Dialect/Modules/TensorList/IR/TensorListOps.cpp
index 04ea528..296e199 100644
--- a/iree/compiler/Dialect/Modules/TensorList/IR/TensorListOps.cpp
+++ b/iree/compiler/Dialect/Modules/TensorList/IR/TensorListOps.cpp
@@ -15,6 +15,7 @@
#include "iree/compiler/Dialect/Modules/TensorList/IR/TensorListOps.h"
#include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
+#include "mlir/IR/Builders.h"
#define GET_OP_CLASSES
#include "iree/compiler/Dialect/Modules/TensorList/IR/TensorListOps.cpp.inc"
diff --git a/iree/compiler/Dialect/Modules/TensorList/IR/TensorListOps.td b/iree/compiler/Dialect/Modules/TensorList/IR/TensorListOps.td
index d7bc2c9..2da11b9 100644
--- a/iree/compiler/Dialect/Modules/TensorList/IR/TensorListOps.td
+++ b/iree/compiler/Dialect/Modules/TensorList/IR/TensorListOps.td
@@ -27,10 +27,10 @@
}];
let arguments = (ins
- // TODO(silvasean): Do we need element_shape?
HAL_BufferView:$element_shape,
// TODO(silvasean): Convert to `I32:$count` instead.
- HAL_BufferView:$count
+ HAL_BufferView:$count,
+ HAL_ElementTypeAttr:$element_type
);
let results = (outs
@@ -47,9 +47,7 @@
let arguments = (ins
TensorList_TensorList:$list,
// TODO(silvasean): Convert to `I32:$index` instead.
- HAL_BufferView:$index,
- // TODO(silvasean): Do we need element_shape?
- HAL_BufferView:$element_shape
+ HAL_BufferView:$index
);
let results = (outs
@@ -83,9 +81,7 @@
a tensorlist `list` of length equal to `tensor`'s leading dimension.
}];
let arguments = (ins
- HAL_BufferView:$tensor,
- // TODO(silvasean): Do we need element_shape?
- HAL_BufferView:$element_shape
+ HAL_BufferView:$tensor
);
let results = (outs
TensorList_TensorList:$list
@@ -102,8 +98,8 @@
Requires all tensors contained in `list` to be the same shape.
}];
let arguments = (ins
+ HAL_Allocator:$allocator,
TensorList_TensorList:$list,
- HAL_BufferView:$element_shape,
HAL_BufferView:$num_elements
);
let results = (outs
@@ -122,6 +118,7 @@
the non-leading axes.
}];
let arguments = (ins
+ HAL_Allocator:$allocator,
TensorList_TensorList:$list
);
let results = (outs
diff --git a/iree/compiler/Dialect/Modules/TensorList/IR/test/ops.mlir b/iree/compiler/Dialect/Modules/TensorList/IR/test/ops.mlir
index 72cef98..688bc47 100644
--- a/iree/compiler/Dialect/Modules/TensorList/IR/test/ops.mlir
+++ b/iree/compiler/Dialect/Modules/TensorList/IR/test/ops.mlir
@@ -3,16 +3,16 @@
// CHECK-LABEL: @Reserve
func @Reserve(%element_shape: !hal.buffer_view, %num_elements: !hal.buffer_view) -> !tensorlist.list {
// CHECK: tensorlist.Reserve
- %0 = "tensorlist.Reserve"(%element_shape, %num_elements) : (!hal.buffer_view, !hal.buffer_view) -> !tensorlist.list
+ %0 = "tensorlist.Reserve"(%element_shape, %num_elements) {element_type = 13 : i32 } : (!hal.buffer_view, !hal.buffer_view) -> !tensorlist.list
return %0 : !tensorlist.list
}
// -----
// CHECK-LABEL: @GetItem
-func @GetItem(%list: !tensorlist.list, %index: !hal.buffer_view, %element_shape: !hal.buffer_view) -> !hal.buffer_view {
+func @GetItem(%list: !tensorlist.list, %index: !hal.buffer_view) -> !hal.buffer_view {
// CHECK: tensorlist.GetItem
- %0 = "tensorlist.GetItem"(%list, %index, %element_shape) : (!tensorlist.list, !hal.buffer_view, !hal.buffer_view) -> !hal.buffer_view
+ %0 = "tensorlist.GetItem"(%list, %index) : (!tensorlist.list, !hal.buffer_view) -> !hal.buffer_view
return %0 : !hal.buffer_view
}
@@ -28,17 +28,17 @@
// -----
// CHECK-LABEL: @Stack
-func @Stack(%list: !tensorlist.list, %element_shape: !hal.buffer_view, %num_elements: !hal.buffer_view) -> !hal.buffer_view {
+func @Stack(%allocator: !hal.allocator, %list: !tensorlist.list, %num_elements: !hal.buffer_view) -> !hal.buffer_view {
// CHECK: tensorlist.Stack
- %0 = "tensorlist.Stack"(%list, %element_shape, %num_elements) : (!tensorlist.list, !hal.buffer_view, !hal.buffer_view) -> !hal.buffer_view
+ %0 = "tensorlist.Stack"(%allocator, %list, %num_elements) : (!hal.allocator, !tensorlist.list, !hal.buffer_view) -> !hal.buffer_view
return %0 : !hal.buffer_view
}
// -----
// CHECK-LABEL: @Concat
-func @Concat(%list: !tensorlist.list) -> !hal.buffer_view {
+func @Concat(%allocator: !hal.allocator, %list: !tensorlist.list) -> !hal.buffer_view {
// CHECK: tensorlist.Concat
- %0 = "tensorlist.Concat"(%list) : (!tensorlist.list) -> !hal.buffer_view
+ %0 = "tensorlist.Concat"(%allocator, %list) : (!hal.allocator, !tensorlist.list) -> !hal.buffer_view
return %0 : !hal.buffer_view
-}
\ No newline at end of file
+}
diff --git a/iree/compiler/Dialect/Modules/TensorList/tensorlist.imports.mlir b/iree/compiler/Dialect/Modules/TensorList/tensorlist.imports.mlir
index fa55ef5..6a16add 100644
--- a/iree/compiler/Dialect/Modules/TensorList/tensorlist.imports.mlir
+++ b/iree/compiler/Dialect/Modules/TensorList/tensorlist.imports.mlir
@@ -17,15 +17,15 @@
// Maps to IREE::TensorList::Reserve.
vm.import @reserve(
%element_shape : !vm.ref<!hal.buffer_view>,
- %num_elements : !vm.ref<!hal.buffer_view>
+ %num_elements : !vm.ref<!hal.buffer_view>,
+ %element_type : i32
) -> !vm.ref<!tensorlist.list>
attributes {nosideeffects}
// Maps to IREE::TensorList::GetItem.
vm.import @get_item(
%list : !vm.ref<!tensorlist.list>,
- %index : !vm.ref<!hal.buffer_view>,
- %element_shape: !vm.ref<!hal.buffer_view>
+ %index : !vm.ref<!hal.buffer_view>
) -> !vm.ref<!hal.buffer_view>
attributes {nosideeffects}
@@ -39,21 +39,21 @@
// Maps to IREE:TensorList::FromTensor
vm.import @from_tensor(
- %tensor : !vm.ref<!hal.buffer_view>,
- %element_shape : !vm.ref<!hal.buffer_view>
+ %tensor : !vm.ref<!hal.buffer_view>
) -> !vm.ref<!tensorlist.list>
attributes {nosideeffects}
// Maps to IREE:TensorList::Concat
vm.import @concat(
+ %allocator : !vm.ref<!hal.allocator>,
%list : !vm.ref<!tensorlist.list>
) -> !vm.ref<!hal.buffer_view>
attributes {nosideeffects}
// Maps to IREE:TensorList::Stack
vm.import @stack(
+ %allocator : !vm.ref<!hal.allocator>,
%list : !vm.ref<!tensorlist.list>,
- %element_shape : !vm.ref<!hal.buffer_view>,
%num_elements : !vm.ref<!hal.buffer_view>
) -> !vm.ref<!hal.buffer_view>
attributes {nosideeffects}
diff --git a/iree/compiler/Dialect/VM/Conversion/VMToEmitC/CMakeLists.txt b/iree/compiler/Dialect/VM/Conversion/VMToEmitC/CMakeLists.txt
index f4f29c9..4dafc73 100644
--- a/iree/compiler/Dialect/VM/Conversion/VMToEmitC/CMakeLists.txt
+++ b/iree/compiler/Dialect/VM/Conversion/VMToEmitC/CMakeLists.txt
@@ -28,6 +28,9 @@
MLIREmitC
MLIRTransforms
iree::compiler::Dialect::VM::IR
+ INCLUDES
+ "${PROJECT_SOURCE_DIR}/third_party/mlir-emitc/include"
+ "${PROJECT_BINARY_DIR}/third_party/mlir-emitc/include"
PUBLIC
)
endif()
diff --git a/iree/compiler/Dialect/VM/Target/C/CMakeLists.txt b/iree/compiler/Dialect/VM/Target/C/CMakeLists.txt
index 0b0b4da..691c6d3 100644
--- a/iree/compiler/Dialect/VM/Target/C/CMakeLists.txt
+++ b/iree/compiler/Dialect/VM/Target/C/CMakeLists.txt
@@ -32,6 +32,8 @@
MLIRSupport
iree::compiler::Dialect::VM::IR
iree::compiler::Dialect::VM::Conversion::VMToEmitC
+ INCLUDES
+ "${PROJECT_SOURCE_DIR}/third_party/mlir-emitc/include"
PUBLIC
)
endif()
diff --git a/iree/hal/vulkan/descriptor_set_arena.cc b/iree/hal/vulkan/descriptor_set_arena.cc
index 238384f..ee60e6b 100644
--- a/iree/hal/vulkan/descriptor_set_arena.cc
+++ b/iree/hal/vulkan/descriptor_set_arena.cc
@@ -131,10 +131,11 @@
// Pick a bucket based on the number of descriptors required.
// NOTE: right now we are 1:1 with bindings.
- int required_descriptor_count = static_cast<int>(bindings.size() * 1);
- int max_descriptor_count =
- std::max(8, RoundUpToNearestPow2(required_descriptor_count));
- int bucket = TrailingZeros(max_descriptor_count >> 3);
+  uint32_t required_descriptor_count = static_cast<uint32_t>(bindings.size() * 1);
+ uint32_t max_descriptor_count =
+ std::max(8u, iree_math_round_up_to_pow2_u32(required_descriptor_count));
+ uint32_t bucket =
+ iree_math_count_trailing_zeros_u32(max_descriptor_count >> 3);
if (bucket >= descriptor_pool_buckets_.size()) {
return OutOfRangeErrorBuilder(IREE_LOC)
<< "Too many descriptors required: " << required_descriptor_count
diff --git a/iree/hal/vulkan/internal_vk_mem_alloc.h b/iree/hal/vulkan/internal_vk_mem_alloc.h
index 4541e9f..d2fa59e 100644
--- a/iree/hal/vulkan/internal_vk_mem_alloc.h
+++ b/iree/hal/vulkan/internal_vk_mem_alloc.h
@@ -27,5 +27,6 @@
// to be omitted and not have VMA poking around where it shouldn't.
#define VMA_DYNAMIC_VULKAN_FUNCTIONS 0
-#include "vk_mem_alloc.h"
+#include <vk_mem_alloc.h>
+
#endif // IREE_HAL_VULKAN_INTERNAL_VK_MEM_ALLOC_H_
diff --git a/iree/hal/vulkan/vulkan_device.cc b/iree/hal/vulkan/vulkan_device.cc
index a5e2cd0..582a8c3 100644
--- a/iree/hal/vulkan/vulkan_device.cc
+++ b/iree/hal/vulkan/vulkan_device.cc
@@ -171,7 +171,8 @@
const ref_ptr<DynamicSymbols>& syms) {
absl::InlinedVector<std::unique_ptr<CommandQueue>, 4> command_queues;
- uint64_t compute_queue_count = CountOnes64(compute_queue_set.queue_indices);
+ uint64_t compute_queue_count =
+ iree_math_count_ones_u64(compute_queue_set.queue_indices);
for (uint32_t i = 0; i < compute_queue_count; ++i) {
if (!(compute_queue_set.queue_indices & (1ull << i))) continue;
@@ -193,7 +194,8 @@
}
}
- uint64_t transfer_queue_count = CountOnes64(transfer_queue_set.queue_indices);
+ uint64_t transfer_queue_count =
+ iree_math_count_ones_u64(transfer_queue_set.queue_indices);
for (uint32_t i = 0; i < transfer_queue_count; ++i) {
if (!(transfer_queue_set.queue_indices & (1ull << i))) continue;
@@ -411,8 +413,10 @@
const ref_ptr<DynamicSymbols>& syms) {
IREE_TRACE_SCOPE0("VulkanDevice::Wrap");
- uint64_t compute_queue_count = CountOnes64(compute_queue_set.queue_indices);
- uint64_t transfer_queue_count = CountOnes64(transfer_queue_set.queue_indices);
+ uint64_t compute_queue_count =
+ iree_math_count_ones_u64(compute_queue_set.queue_indices);
+ uint64_t transfer_queue_count =
+ iree_math_count_ones_u64(transfer_queue_set.queue_indices);
if (compute_queue_count == 0) {
return InvalidArgumentErrorBuilder(IREE_LOC)
diff --git a/iree/modules/tensorlist/native_module.cc b/iree/modules/tensorlist/native_module.cc
index b23db9c..18d4770 100644
--- a/iree/modules/tensorlist/native_module.cc
+++ b/iree/modules/tensorlist/native_module.cc
@@ -38,6 +38,14 @@
namespace {
class TensorList final : public RefObject<TensorList> {
public:
+ TensorList(absl::Span<const int32_t> shape, iree_hal_element_type_t dtype)
+ : shape_(shape.begin(), shape.end()), dtype_(dtype) {}
+
+ TensorList(const vm::ref<TensorList>& other)
+ : shape_(other->shape_), dtype_(other->dtype_) {
+ CopyFrom(other);
+ }
+
void Resize(int32_t num_elements) { list_.resize(num_elements); }
// Copy from another iree_tensorlist.
// vm::ref has deleted copy operator=, so we can't use vector's operator=.
@@ -62,6 +70,9 @@
}
}
size_t Size() { return list_.size(); }
+ absl::Span<int32_t> Shape() {
+ return absl::Span<int32_t>(shape_.data(), shape_.size());
+ }
static StatusOr<vm::ref<TensorList>> FromTensor(
vm::ref<iree_hal_buffer_view_t> tensor) {
@@ -74,15 +85,21 @@
IREE_RETURN_IF_ERROR(
iree_hal_buffer_view_shape(tensor.get(), rank, shape.data(), nullptr));
- TensorList* list = new TensorList;
- list->Resize(shape[0]);
+ auto element_type = iree_hal_buffer_view_element_type(tensor.get());
+
+ int32_t list_elements = shape[0];
+ absl::Span<int32_t> element_shape(shape.data() + 1, shape.size() - 1);
+
+ TensorList* list = new TensorList(element_shape, element_type);
+ list->Resize(list_elements);
+
// The python pseudocode for this is:
// for i in range(t.shape[0]):
// list[i] = t[i,...]
absl::InlinedVector<int32_t, 6> start_indices(shape.size());
absl::InlinedVector<int32_t, 6> lengths = shape;
lengths[0] = 1;
- for (int i = 0, e = shape[0]; i < e; i++) {
+ for (int i = 0, e = list_elements; i < e; i++) {
start_indices[0] = i;
iree_device_size_t start_offset = 0;
iree_device_size_t subview_length = 0;
@@ -96,7 +113,7 @@
iree_hal_buffer_view_t* slice = nullptr;
IREE_RETURN_IF_ERROR(iree_hal_buffer_view_create(
- subview_buffer.get(), shape.data() + 1, shape.size() - 1,
+ subview_buffer.get(), element_shape.data(), element_shape.size(),
iree_hal_buffer_view_element_type(tensor.get()),
iree_allocator_system(), &slice));
list->SetItem(i, slice);
@@ -104,35 +121,29 @@
return list;
}
- StatusOr<vm::ref<iree_hal_buffer_view_t>> Stack() {
+ StatusOr<vm::ref<iree_hal_buffer_view_t>> Stack(
+ vm::ref<iree_hal_allocator_t> hal_allocator) {
size_t num_tensors = Size();
if (num_tensors == 0) {
return InvalidArgumentErrorBuilder(IREE_LOC) << "expected non-empty list";
}
- for (size_t i = 0; i < num_tensors; i++) {
- if (!GetItem(i).get()) {
- return InvalidArgumentErrorBuilder(IREE_LOC)
- << "uninitialized element in list";
- }
- }
- size_t rank = iree_hal_buffer_view_shape_rank(GetItem(0).get());
- iree_hal_element_type_t type =
- iree_hal_buffer_view_element_type(GetItem(0).get());
- absl::InlinedVector<int32_t, 6> shape(rank);
- IREE_RETURN_IF_ERROR(iree_hal_buffer_view_shape(GetItem(0).get(), rank,
- shape.data(), nullptr));
+ // Validate that all buffers are of the right shape/type.
+ absl::Span<int32_t> shape(shape_);
+ iree_hal_element_type_t type(dtype_);
for (size_t i = 0; i < num_tensors; i++) {
- size_t element_rank = iree_hal_buffer_view_shape_rank(GetItem(i).get());
+ auto item = GetItem(i).get();
+ if (!item) continue;
+ size_t element_rank = iree_hal_buffer_view_shape_rank(item);
absl::InlinedVector<int32_t, 6> element_shape(element_rank);
IREE_RETURN_IF_ERROR(iree_hal_buffer_view_shape(
- GetItem(i).get(), element_rank, element_shape.data(), nullptr));
+ item, element_rank, element_shape.data(), nullptr));
if (absl::MakeSpan(shape) != absl::MakeSpan(element_shape) ||
- iree_hal_buffer_view_element_type(GetItem(i).get()) != type) {
+ iree_hal_buffer_view_element_type(item) != type) {
return InvalidArgumentErrorBuilder(IREE_LOC)
<< "stacking list with elements of different shapes or element "
- "types. Mismatch between element 0 and element "
- << i;
+             << "types. Mismatch between element 0 and element " << i;
}
}
@@ -141,13 +152,12 @@
for (int32_t dim : shape) {
num_elements_per_tensor *= dim;
}
- size_t element_size = iree_hal_buffer_view_element_size(GetItem(0).get());
+
+ size_t element_size = iree_hal_element_byte_count(type);
size_t num_result_elements = num_elements_per_tensor * num_tensors;
size_t result_byte_size = num_result_elements * element_size;
- iree_hal_allocator_t* hal_allocator = iree_hal_buffer_allocator(
- iree_hal_buffer_view_buffer(GetItem(0).get()));
IREE_RETURN_IF_ERROR(iree_hal_allocator_allocate_buffer(
- hal_allocator,
+ hal_allocator.get(),
static_cast<iree_hal_memory_type_t>(
IREE_HAL_MEMORY_TYPE_HOST_LOCAL |
IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE),
@@ -167,27 +177,28 @@
return std::move(result_view);
}
- StatusOr<vm::ref<iree_hal_buffer_view_t>> Concat() {
+ StatusOr<vm::ref<iree_hal_buffer_view_t>> Concat(
+ vm::ref<iree_hal_allocator_t> hal_allocator) {
size_t num_tensors = Size();
if (num_tensors == 0) {
return InvalidArgumentErrorBuilder(IREE_LOC) << "expected non-empty list";
}
- for (size_t i = 0; i < num_tensors; i++) {
- if (!GetItem(i).get()) {
- return InvalidArgumentErrorBuilder(IREE_LOC)
- << "uninitialized element in list";
- }
+
+ if (shape_.empty()) {
+ return InvalidArgumentErrorBuilder(IREE_LOC)
+ << "stacking rank must be greater than zero.";
}
size_t rank = iree_hal_buffer_view_shape_rank(GetItem(0).get());
- iree_hal_element_type_t type =
- iree_hal_buffer_view_element_type(GetItem(0).get());
+ iree_hal_element_type_t type = dtype_;
absl::InlinedVector<int32_t, 6> shape(rank);
IREE_RETURN_IF_ERROR(iree_hal_buffer_view_shape(GetItem(0).get(), rank,
shape.data(), nullptr));
- size_t num_rows = 0;
+ const size_t num_rows = num_tensors * shape[0];
for (size_t i = 0; i < num_tensors; i++) {
- size_t element_rank = iree_hal_buffer_view_shape_rank(GetItem(i).get());
+ auto item = GetItem(i).get();
+ if (!item) continue;
+ size_t element_rank = iree_hal_buffer_view_shape_rank(item);
if (element_rank < 1) {
return InvalidArgumentErrorBuilder(IREE_LOC)
<< "stacking rank must be greater than zero." << i;
@@ -196,7 +207,6 @@
absl::InlinedVector<int32_t, 6> element_shape(element_rank);
IREE_RETURN_IF_ERROR(iree_hal_buffer_view_shape(
GetItem(i).get(), element_rank, element_shape.data(), nullptr));
- num_rows += element_shape.front();
if (absl::MakeSpan(shape).subspan(1) !=
absl::MakeSpan(element_shape).subspan(1) ||
@@ -216,10 +226,8 @@
size_t element_size = iree_hal_buffer_view_element_size(GetItem(0).get());
size_t num_result_elements = num_elements_per_row * num_rows;
size_t result_byte_size = num_result_elements * element_size;
- iree_hal_allocator_t* hal_allocator = iree_hal_buffer_allocator(
- iree_hal_buffer_view_buffer(GetItem(0).get()));
IREE_RETURN_IF_ERROR(iree_hal_allocator_allocate_buffer(
- hal_allocator,
+ hal_allocator.get(),
static_cast<iree_hal_memory_type_t>(
IREE_HAL_MEMORY_TYPE_HOST_LOCAL |
IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE),
@@ -261,9 +269,19 @@
// in the compiler at which point there will be no "stack" function inside
// this module at all.
size_t num_tensors = Size();
- size_t offset = 0;
+ size_t tensor_byte_size = iree_hal_element_byte_count(dtype_);
+ for (auto dim : shape_) tensor_byte_size *= dim;
for (size_t i = 0; i < num_tensors; i++) {
iree_hal_buffer_view_t* tensor = GetItem(i).get();
+
+ auto block_begin = result_mapping.contents.data + i * tensor_byte_size;
+ auto block_size = tensor_byte_size;
+
+ if (!tensor) {
+ memset(block_begin, 0, block_size);
+ continue;
+ }
+
iree_hal_buffer_t* tensor_buffer = iree_hal_buffer_view_buffer(tensor);
iree_hal_mapped_memory_t tensor_mapping;
iree_device_size_t tensor_byte_size =
@@ -273,9 +291,7 @@
iree_hal_buffer_map(tensor_buffer, IREE_HAL_MEMORY_ACCESS_READ, 0,
tensor_byte_size, &tensor_mapping));
- memcpy(result_mapping.contents.data + offset,
- tensor_mapping.contents.data, tensor_byte_size);
- offset += tensor_byte_size;
+ memcpy(block_begin, tensor_mapping.contents.data, block_size);
IREE_RETURN_IF_ERROR(
iree_hal_buffer_unmap(tensor_buffer, &tensor_mapping));
@@ -285,6 +301,8 @@
}
std::vector<vm::ref<iree_hal_buffer_view_t>> list_;
+ std::vector<int32_t> shape_;
+ iree_hal_element_type_t dtype_;
};
} // namespace
@@ -339,6 +357,35 @@
return scalar;
}
+static StatusOr<std::vector<int32_t>> ReadInt32VectorFromBufferView(
+ iree_hal_buffer_view_t* buffer_view) {
+ if (iree_hal_buffer_view_element_type(buffer_view) !=
+ IREE_HAL_ELEMENT_TYPE_SINT_32) {
+ return InvalidArgumentErrorBuilder(IREE_LOC) << "expected i32 buffer view";
+ }
+ if (iree_hal_buffer_view_shape_rank(buffer_view) != 1) {
+ return InvalidArgumentErrorBuilder(IREE_LOC)
+ << "expected rank-1 buffer view";
+ }
+
+ int32_t length;
+ IREE_RETURN_IF_ERROR(iree_hal_buffer_view_shape(
+ buffer_view, /*rank_capacity=*/1, &length, nullptr));
+
+ iree_hal_buffer_t* buffer = iree_hal_buffer_view_buffer(buffer_view);
+ iree_hal_mapped_memory_t mapped_memory;
+ IREE_RETURN_IF_ERROR(iree_hal_buffer_map(buffer, IREE_HAL_MEMORY_ACCESS_READ,
+ 0, length * sizeof(int32_t),
+ &mapped_memory));
+
+ std::vector<int32_t> contents(
+ reinterpret_cast<int32_t*>(mapped_memory.contents.data),
+ reinterpret_cast<int32_t*>(mapped_memory.contents.data) + length);
+
+ IREE_RETURN_IF_ERROR(iree_hal_buffer_unmap(buffer, &mapped_memory));
+ return contents;
+}
+
namespace {
class TensorListModuleState final {
public:
@@ -348,10 +395,11 @@
// tensorlist.reserve(%element_shape, %num_elements) -> %list
StatusOr<vm::ref<TensorList>> Reserve(
vm::ref<iree_hal_buffer_view_t> element_shape,
- vm::ref<iree_hal_buffer_view_t> num_elements_buf) {
- // TODO(silvasean): Emulate element shape and dtype tracking in TensorList.
- (void)element_shape;
- TensorList* tensorlist = new TensorList;
+ vm::ref<iree_hal_buffer_view_t> num_elements_buf,
+ iree_hal_element_type_t element_type) {
+ IREE_ASSIGN_OR_RETURN(std::vector<int32_t> shape,
+ ReadInt32VectorFromBufferView(element_shape.get()));
+ TensorList* tensorlist = new TensorList(shape, element_type);
IREE_ASSIGN_OR_RETURN(int32_t num_elements, ReadInt32FromScalarBufferView(
num_elements_buf.get()));
tensorlist->Resize(num_elements);
@@ -360,10 +408,8 @@
// tensorlist.get_item(%list, %index, %element_shape) -> %item
StatusOr<vm::ref<iree_hal_buffer_view_t>> GetItem(
- vm::ref<TensorList> tensorlist, vm::ref<iree_hal_buffer_view_t> index_buf,
- vm::ref<iree_hal_buffer_view_t> element_shape) {
- // TODO(silvasean): Emulate element shape and dtype tracking in TensorList.
- (void)element_shape;
+ vm::ref<TensorList> tensorlist,
+ vm::ref<iree_hal_buffer_view_t> index_buf) {
IREE_ASSIGN_OR_RETURN(int32_t index,
ReadInt32FromScalarBufferView(index_buf.get()));
return vm::retain_ref(tensorlist->GetItem(index).get());
@@ -373,35 +419,29 @@
StatusOr<vm::ref<TensorList>> SetItem(
vm::ref<TensorList> list, vm::ref<iree_hal_buffer_view_t> index_buf,
vm::ref<iree_hal_buffer_view_t> item) {
- TensorList* new_list = new TensorList;
IREE_ASSIGN_OR_RETURN(int32_t index,
ReadInt32FromScalarBufferView(index_buf.get()));
- new_list->CopyFrom(list);
+ TensorList* new_list = new TensorList(list);
new_list->SetItem(index, vm::retain_ref(item));
return new_list;
}
// tensorlist.from_tensor(%tensor, %element_shape) -> %list
StatusOr<vm::ref<TensorList>> FromTensor(
- vm::ref<iree_hal_buffer_view_t> tensor,
- vm::ref<iree_hal_buffer_view_t> element_shape) {
- // TODO(silvasean): Emulate element shape and dtype tracking in TensorList.
- (void)element_shape;
+ vm::ref<iree_hal_buffer_view_t> tensor) {
return TensorList::FromTensor(tensor);
}
// tensorlist.concat(%list) -> %list
- StatusOr<vm::ref<iree_hal_buffer_view_t>> Concat(vm::ref<TensorList> list) {
- return list->Concat();
+ StatusOr<vm::ref<iree_hal_buffer_view_t>> Concat(
+ vm::ref<iree_hal_allocator_t> allocator, vm::ref<TensorList> list) {
+ return list->Concat(allocator);
}
// tensorlist.stack(%list, %element_shape, %num_elements) -> %list
StatusOr<vm::ref<iree_hal_buffer_view_t>> Stack(
- vm::ref<TensorList> list,
- vm::ref<iree_hal_buffer_view_t> element_shape_buffer_view,
+ vm::ref<iree_hal_allocator_t> allocator, vm::ref<TensorList> list,
vm::ref<iree_hal_buffer_view_t> num_elements_buffer_view) {
- // TODO(silvasean): Emulate element shape and dtype tracking in TensorList.
- (void)element_shape_buffer_view;
IREE_ASSIGN_OR_RETURN(
int32_t num_elements,
ReadInt32FromScalarBufferView(num_elements_buffer_view.get()));
@@ -410,7 +450,7 @@
<< "num_elements arg to tensorlist.stack doesn't match the list "
"size";
}
- return list->Stack();
+ return list->Stack(allocator);
}
};
} // namespace
diff --git a/iree/modules/tensorlist/tensorlist_test.cc b/iree/modules/tensorlist/tensorlist_test.cc
index 3293d57..8eba8e9 100644
--- a/iree/modules/tensorlist/tensorlist_test.cc
+++ b/iree/modules/tensorlist/tensorlist_test.cc
@@ -96,6 +96,91 @@
return function;
}
+ void Invoke(absl::string_view function_name,
+ absl::Span<const float> input_values,
+ absl::Span<const int32_t> input_shape,
+ absl::Span<const float> expected_values,
+ absl::Span<const int32_t> expected_shape) {
+ vm::ref<iree_hal_buffer_view_t> input_buffer_view;
+ CreateBufferView(input_values, input_shape, device_, &input_buffer_view);
+
+ // Pass in the tensor as a HAL buffer view.
+ vm::ref<iree_vm_list_t> inputs;
+ IREE_ASSERT_OK(iree_vm_list_create(/*element_type=*/nullptr, 1,
+ iree_allocator_system(), &inputs));
+ iree_vm_ref_t input_buffer_view_ref =
+ iree_hal_buffer_view_move_ref(input_buffer_view.get());
+ IREE_ASSERT_OK(
+ iree_vm_list_push_ref_retain(inputs.get(), &input_buffer_view_ref));
+
+ // Prepare outputs list to accept the results from the invocation.
+ vm::ref<iree_vm_list_t> outputs;
+ IREE_ASSERT_OK(iree_vm_list_create(/*element_type=*/nullptr, 1,
+ iree_allocator_system(), &outputs));
+
+ // Synchronously invoke the function.
+ IREE_ASSERT_OK(iree_vm_invoke(context_, LookupFunction(function_name),
+ /*policy=*/nullptr, inputs.get(),
+ outputs.get(), iree_allocator_system()));
+
+ auto* returned_buffer_view =
+ reinterpret_cast<iree_hal_buffer_view_t*>(iree_vm_list_get_ref_deref(
+ outputs.get(), 0, iree_hal_buffer_view_get_descriptor()));
+
+ absl::InlinedVector<int32_t, 5> returned_shape(
+ iree_hal_buffer_view_shape_rank(returned_buffer_view));
+ iree_hal_buffer_view_shape(returned_buffer_view, returned_shape.size(),
+ returned_shape.data(), nullptr);
+
+ EXPECT_EQ(returned_shape, expected_shape);
+
+ iree_hal_buffer_t* returned_buffer =
+ iree_hal_buffer_view_buffer(returned_buffer_view);
+ ASSERT_NE(returned_buffer, nullptr);
+
+ iree_hal_mapped_memory_t mapped_memory;
+ IREE_ASSERT_OK(iree_hal_buffer_map(returned_buffer,
+ IREE_HAL_MEMORY_ACCESS_READ, 0,
+ IREE_WHOLE_BUFFER, &mapped_memory));
+ for (int i = 0; i < expected_values.size(); i++) {
+ EXPECT_EQ(reinterpret_cast<float*>(mapped_memory.contents.data)[i],
+ expected_values[i]);
+ }
+
+ IREE_ASSERT_OK(iree_hal_buffer_unmap(returned_buffer, &mapped_memory));
+ }
+
+ void CreateBufferView(absl::Span<const float> contents,
+ absl::Span<const int32_t> shape,
+ iree_hal_device_t* device,
+ iree_hal_buffer_view_t** out_buffer_view) {
+ size_t num_elements = 1;
+ for (int32_t dim : shape) {
+ num_elements *= dim;
+ }
+ ASSERT_EQ(contents.size(), num_elements);
+ vm::ref<iree_hal_buffer_t> buffer;
+ iree_hal_allocator_t* allocator = iree_hal_device_allocator(device);
+ IREE_ASSERT_OK(iree_hal_allocator_allocate_buffer(
+ allocator,
+ static_cast<iree_hal_memory_type_t>(
+ IREE_HAL_MEMORY_TYPE_HOST_LOCAL |
+ IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE),
+ IREE_HAL_BUFFER_USAGE_ALL, contents.size() * sizeof(float), &buffer));
+ iree_hal_mapped_memory_t mapped_memory;
+ IREE_ASSERT_OK(iree_hal_buffer_map(buffer.get(),
+ IREE_HAL_MEMORY_ACCESS_WRITE, 0,
+ IREE_WHOLE_BUFFER, &mapped_memory));
+ memcpy(mapped_memory.contents.data,
+ static_cast<const void*>(contents.data()),
+ mapped_memory.contents.data_length);
+ IREE_ASSERT_OK(iree_hal_buffer_unmap(buffer.get(), &mapped_memory));
+ IREE_ASSERT_OK(iree_hal_buffer_view_create(
+ buffer.get(), shape.data(), shape.size(),
+ IREE_HAL_ELEMENT_TYPE_FLOAT_32, iree_allocator_system(),
+ &*out_buffer_view));
+ }
+
iree_hal_device_t* device_ = nullptr;
iree_vm_instance_t* instance_ = nullptr;
iree_vm_context_t* context_ = nullptr;
@@ -104,176 +189,53 @@
iree_vm_module_t* hal_module_ = nullptr;
};
-void CreateBufferView(absl::Span<float> contents,
- absl::Span<const int32_t> shape,
- iree_hal_device_t* device,
- iree_hal_buffer_view_t** out_buffer_view) {
- size_t num_elements = 1;
- for (int32_t dim : shape) {
- num_elements *= dim;
- }
- ASSERT_EQ(contents.size(), num_elements);
- vm::ref<iree_hal_buffer_t> buffer;
- iree_hal_allocator_t* allocator = iree_hal_device_allocator(device);
- IREE_ASSERT_OK(iree_hal_allocator_allocate_buffer(
- allocator,
- static_cast<iree_hal_memory_type_t>(IREE_HAL_MEMORY_TYPE_HOST_LOCAL |
- IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE),
- IREE_HAL_BUFFER_USAGE_ALL, contents.size() * sizeof(float), &buffer));
- iree_hal_mapped_memory_t mapped_memory;
- IREE_ASSERT_OK(iree_hal_buffer_map(buffer.get(), IREE_HAL_MEMORY_ACCESS_WRITE,
- 0, IREE_WHOLE_BUFFER, &mapped_memory));
- memcpy(mapped_memory.contents.data, static_cast<void*>(contents.data()),
- mapped_memory.contents.data_length);
- IREE_ASSERT_OK(iree_hal_buffer_unmap(buffer.get(), &mapped_memory));
- IREE_ASSERT_OK(iree_hal_buffer_view_create(
- buffer.get(), shape.data(), shape.size(), IREE_HAL_ELEMENT_TYPE_FLOAT_32,
- iree_allocator_system(), &*out_buffer_view));
-}
-
TEST_F(TensorListModulesTest, IdentityThroughSetItemGetItem) {
// Allocate the buffer we'll be passing through.
- static float kBufferContents[1] = {42.0f};
- absl::InlinedVector<int32_t, 4> shape;
- vm::ref<iree_hal_buffer_view_t> input_buffer_view;
- CreateBufferView(kBufferContents, shape, device_, &input_buffer_view);
+ std::vector<float> input = {42.0f};
+ std::vector<int32_t> input_shape = {};
+ Invoke("identity_through_set_item_get_item", input, input_shape, input,
+ input_shape);
+}
- // Pass in the tensor as a HAL buffer view.
- vm::ref<iree_vm_list_t> inputs;
- IREE_ASSERT_OK(iree_vm_list_create(/*element_type=*/nullptr, 1,
- iree_allocator_system(), &inputs));
- iree_vm_ref_t input_buffer_view_ref =
- iree_hal_buffer_view_move_ref(input_buffer_view.get());
- IREE_ASSERT_OK(
- iree_vm_list_push_ref_retain(inputs.get(), &input_buffer_view_ref));
-
- // Prepare outputs list to accept the results from the invocation.
- vm::ref<iree_vm_list_t> outputs;
- IREE_ASSERT_OK(iree_vm_list_create(/*element_type=*/nullptr, 1,
- iree_allocator_system(), &outputs));
-
- // Synchronously invoke the function.
- IREE_ASSERT_OK(iree_vm_invoke(
- context_, LookupFunction("identity_through_set_item_get_item"),
- /*policy=*/nullptr, inputs.get(), outputs.get(),
- iree_allocator_system()));
-
- auto* returned_buffer_view =
- reinterpret_cast<iree_hal_buffer_view_t*>(iree_vm_list_get_ref_deref(
- outputs.get(), 0, iree_hal_buffer_view_get_descriptor()));
- ASSERT_NE(nullptr, returned_buffer_view);
- iree_hal_buffer_t* returned_buffer =
- iree_hal_buffer_view_buffer(returned_buffer_view);
- ASSERT_NE(nullptr, returned_buffer);
-
- iree_hal_mapped_memory_t mapped_memory;
- IREE_ASSERT_OK(iree_hal_buffer_map(returned_buffer,
- IREE_HAL_MEMORY_ACCESS_READ, 0,
- IREE_WHOLE_BUFFER, &mapped_memory));
- EXPECT_EQ(reinterpret_cast<float*>(mapped_memory.contents.data)[0],
- kBufferContents[0]);
- IREE_ASSERT_OK(iree_hal_buffer_unmap(returned_buffer, &mapped_memory));
+TEST_F(TensorListModulesTest, IdentityThroughSetItemGetItem2D) {
+ // Allocate the buffer we'll be passing through.
+ std::vector<float> input = {42.0f};
+ std::vector<int32_t> input_shape = {1, 1};
+ Invoke("identity_through_set_item_get_item", input, input_shape, input,
+ input_shape);
}
TEST_F(TensorListModulesTest, IdentityThroughConcat) {
// Allocate the buffer we'll be passing through.
- static float kBufferContents[4] = {42.0f, 43.0f, 44.0f, 45.0f};
- absl::InlinedVector<int32_t, 4> shape = {4, 1};
- vm::ref<iree_hal_buffer_view_t> input_buffer_view;
- CreateBufferView(kBufferContents, shape, device_, &input_buffer_view);
+ std::vector<float> input = {42.0f, 43.0f, 44.0f, 45.0f};
+ absl::InlinedVector<int32_t, 4> input_shape = {4, 1};
+ absl::InlinedVector<int32_t, 4> expected_shape = {4};
+ Invoke("identity_through_concat", input, input_shape, input, expected_shape);
+}
- // Pass in the tensor as a HAL buffer view.
- vm::ref<iree_vm_list_t> inputs;
- IREE_ASSERT_OK(iree_vm_list_create(/*element_type=*/nullptr, 1,
- iree_allocator_system(), &inputs));
- iree_vm_ref_t input_buffer_view_ref =
- iree_hal_buffer_view_move_ref(input_buffer_view.get());
- IREE_ASSERT_OK(
- iree_vm_list_push_ref_retain(inputs.get(), &input_buffer_view_ref));
-
- // Prepare outputs list to accept the results from the invocation.
- vm::ref<iree_vm_list_t> outputs;
- IREE_ASSERT_OK(iree_vm_list_create(/*element_type=*/nullptr, 1,
- iree_allocator_system(), &outputs));
-
- // Synchronously invoke the function.
- IREE_ASSERT_OK(iree_vm_invoke(context_,
- LookupFunction("identity_through_concat"),
- /*policy=*/nullptr, inputs.get(), outputs.get(),
- iree_allocator_system()));
-
- auto* returned_buffer_view =
- reinterpret_cast<iree_hal_buffer_view_t*>(iree_vm_list_get_ref_deref(
- outputs.get(), 0, iree_hal_buffer_view_get_descriptor()));
- ASSERT_NE(nullptr, returned_buffer_view);
- iree_hal_buffer_t* returned_buffer =
- iree_hal_buffer_view_buffer(returned_buffer_view);
- ASSERT_NE(nullptr, returned_buffer);
-
- // Dimemsionality is reduced by 1.
- auto returned_rank = iree_hal_buffer_view_shape_rank(returned_buffer_view);
- ASSERT_EQ(returned_rank, shape.size() - 1);
-
- iree_hal_dim_t returned_shape[1];
- iree_hal_buffer_view_shape(returned_buffer_view, 1, returned_shape,
- &returned_rank);
- EXPECT_EQ(returned_shape[0], shape[0] * shape[1]);
-
- iree_hal_mapped_memory_t mapped_memory;
- IREE_ASSERT_OK(iree_hal_buffer_map(returned_buffer,
- IREE_HAL_MEMORY_ACCESS_READ, 0,
- IREE_WHOLE_BUFFER, &mapped_memory));
- EXPECT_EQ(std::memcmp(mapped_memory.contents.data,
- static_cast<void*>(&kBufferContents[0]),
- mapped_memory.contents.data_length),
- 0);
- IREE_ASSERT_OK(iree_hal_buffer_unmap(returned_buffer, &mapped_memory));
+TEST_F(TensorListModulesTest, ConcatAppendsEmpty) {
+ // Allocate the buffer we'll be passing through.
+ std::vector<float> input = {42.0f};
+ absl::InlinedVector<int32_t, 4> input_shape = {1};
+ std::vector<float> expected = {42.0f, 0.0f};
+ absl::InlinedVector<int32_t, 4> expected_shape = {2};
+ Invoke("concat_appends_empty", input, input_shape, expected, expected_shape);
}
TEST_F(TensorListModulesTest, IdentityThroughStack) {
// Allocate the buffer we'll be passing through.
- static float kBufferContents[2] = {42.0f, 43.0f};
- absl::InlinedVector<int32_t, 4> shape = {2, 1};
- vm::ref<iree_hal_buffer_view_t> input_buffer_view;
- CreateBufferView(kBufferContents, shape, device_, &input_buffer_view);
+ std::vector<float> input = {42.0f, 43.0f};
+ absl::InlinedVector<int32_t, 4> input_shape = {2, 1};
+ Invoke("identity_through_stack", input, input_shape, input, input_shape);
+}
- // Pass in the tensor as a HAL buffer view.
- vm::ref<iree_vm_list_t> inputs;
- IREE_ASSERT_OK(iree_vm_list_create(/*element_type=*/nullptr, 1,
- iree_allocator_system(), &inputs));
- iree_vm_ref_t input_buffer_view_ref =
- iree_hal_buffer_view_move_ref(input_buffer_view.get());
- IREE_ASSERT_OK(
- iree_vm_list_push_ref_retain(inputs.get(), &input_buffer_view_ref));
-
- // Prepare outputs list to accept the results from the invocation.
- vm::ref<iree_vm_list_t> outputs;
- IREE_ASSERT_OK(iree_vm_list_create(/*element_type=*/nullptr, 1,
- iree_allocator_system(), &outputs));
-
- // Synchronously invoke the function.
- IREE_ASSERT_OK(iree_vm_invoke(context_,
- LookupFunction("identity_through_stack"),
- /*policy=*/nullptr, inputs.get(), outputs.get(),
- iree_allocator_system()));
-
- auto* returned_buffer_view =
- reinterpret_cast<iree_hal_buffer_view_t*>(iree_vm_list_get_ref_deref(
- outputs.get(), 0, iree_hal_buffer_view_get_descriptor()));
- ASSERT_NE(nullptr, returned_buffer_view);
- iree_hal_buffer_t* returned_buffer =
- iree_hal_buffer_view_buffer(returned_buffer_view);
- ASSERT_NE(nullptr, returned_buffer);
-
- iree_hal_mapped_memory_t mapped_memory;
- IREE_ASSERT_OK(iree_hal_buffer_map(returned_buffer,
- IREE_HAL_MEMORY_ACCESS_READ, 0,
- IREE_WHOLE_BUFFER, &mapped_memory));
- EXPECT_EQ(std::memcmp(mapped_memory.contents.data,
- static_cast<void*>(&kBufferContents[0]),
- mapped_memory.contents.data_length),
- 0);
- IREE_ASSERT_OK(iree_hal_buffer_unmap(returned_buffer, &mapped_memory));
+TEST_F(TensorListModulesTest, StackAppendsEmpty) {
+ // Allocate the buffer we'll be passing through.
+ std::vector<float> input = {42.0f};
+ absl::InlinedVector<int32_t, 4> input_shape = {};
+ std::vector<float> expected = {42.0f, 0.0f};
+ absl::InlinedVector<int32_t, 4> expected_shape = {2};
+ Invoke("stack_appends_empty", input, input_shape, expected, expected_shape);
}
} // namespace
diff --git a/iree/modules/tensorlist/tensorlist_test.mlir b/iree/modules/tensorlist/tensorlist_test.mlir
index dfe10a4..6425b1b 100644
--- a/iree/modules/tensorlist/tensorlist_test.mlir
+++ b/iree/modules/tensorlist/tensorlist_test.mlir
@@ -4,27 +4,62 @@
%0 = hal.buffer_view.const %allocator, "HostLocal|DeviceVisible", "All" : !hal.buffer_view = dense<1> : tensor<i32>
%1 = hal.buffer_view.const %allocator, "HostLocal|DeviceVisible", "All" : !hal.buffer_view = dense<[]> : tensor<0xi32>
%2 = hal.buffer_view.const %allocator, "HostLocal|DeviceVisible", "All" : !hal.buffer_view = dense<0> : tensor<i32>
- %3 = "tensorlist.Reserve"(%1, %0) : (!hal.buffer_view, !hal.buffer_view) -> !tensorlist.list
+ %3 = "tensorlist.Reserve"(%1, %0) { element_type = 50331680 : i32} : (!hal.buffer_view, !hal.buffer_view) -> !tensorlist.list
%4 = "tensorlist.SetItem"(%3, %2, %arg0) : (!tensorlist.list, !hal.buffer_view, !hal.buffer_view) -> !tensorlist.list
- %5 = "tensorlist.GetItem"(%4, %2, %1) : (!tensorlist.list, !hal.buffer_view, !hal.buffer_view) -> !hal.buffer_view
+ %5 = "tensorlist.GetItem"(%4, %2) : (!tensorlist.list, !hal.buffer_view) -> !hal.buffer_view
return %5 : !hal.buffer_view
}
+func @identity_through_set_item_get_item_2D(%arg0: !hal.buffer_view) -> !hal.buffer_view attributes {iree.module.export, iree.abi.none} {
+ %dev = hal.ex.shared_device : !hal.device
+ %allocator = hal.device.allocator %dev : !hal.allocator
+ %0 = hal.buffer_view.const %allocator, "HostLocal|DeviceVisible", "All" : !hal.buffer_view = dense<1> : tensor<i32>
+ %1 = hal.buffer_view.const %allocator, "HostLocal|DeviceVisible", "All" : !hal.buffer_view = dense<[1, 1]> : tensor<2xi32>
+ %2 = hal.buffer_view.const %allocator, "HostLocal|DeviceVisible", "All" : !hal.buffer_view = dense<0> : tensor<i32>
+ %3 = "tensorlist.Reserve"(%1, %0) { element_type = 50331680 : i32} : (!hal.buffer_view, !hal.buffer_view) -> !tensorlist.list
+ %4 = "tensorlist.SetItem"(%3, %2, %arg0) : (!tensorlist.list, !hal.buffer_view, !hal.buffer_view) -> !tensorlist.list
+ %stacked = "tensorlist.Stack"(%allocator, %4, %0) : (!hal.allocator, !tensorlist.list, !hal.buffer_view) -> !hal.buffer_view
+ return %stacked : !hal.buffer_view
+}
+
func @identity_through_concat(%arg0: !hal.buffer_view) -> !hal.buffer_view attributes {iree.module.export, iree.abi.none} {
%dev = hal.ex.shared_device : !hal.device
%allocator = hal.device.allocator %dev : !hal.allocator
%element_shape = hal.buffer_view.const %allocator, "HostLocal|DeviceVisible", "All" : !hal.buffer_view = dense<[]> : tensor<0xi32>
- %list = "tensorlist.FromTensor"(%arg0, %element_shape) : (!hal.buffer_view, !hal.buffer_view) -> !tensorlist.list
- %concat = "tensorlist.Concat"(%list) : (!tensorlist.list) -> !hal.buffer_view
+ %list = "tensorlist.FromTensor"(%arg0) : (!hal.buffer_view) -> !tensorlist.list
+ %concat = "tensorlist.Concat"(%allocator, %list) : (!hal.allocator, !tensorlist.list) -> !hal.buffer_view
+ return %concat : !hal.buffer_view
+}
+
+func @concat_appends_empty(%arg0: !hal.buffer_view) -> !hal.buffer_view attributes {iree.module.export, iree.abi.none} {
+ %dev = hal.ex.shared_device : !hal.device
+ %allocator = hal.device.allocator %dev : !hal.allocator
+ %0 = hal.buffer_view.const %allocator, "HostLocal|DeviceVisible", "All" : !hal.buffer_view = dense<2> : tensor<i32>
+ %1 = hal.buffer_view.const %allocator, "HostLocal|DeviceVisible", "All" : !hal.buffer_view = dense<[1]> : tensor<1xi32>
+ %2 = hal.buffer_view.const %allocator, "HostLocal|DeviceVisible", "All" : !hal.buffer_view = dense<0> : tensor<i32>
+ %3 = "tensorlist.Reserve"(%1, %0) { element_type = 50331680 : i32} : (!hal.buffer_view, !hal.buffer_view) -> !tensorlist.list
+ %4 = "tensorlist.SetItem"(%3, %2, %arg0) : (!tensorlist.list, !hal.buffer_view, !hal.buffer_view) -> !tensorlist.list
+ %concat = "tensorlist.Concat"(%allocator, %4) : (!hal.allocator, !tensorlist.list) -> !hal.buffer_view
return %concat : !hal.buffer_view
}
func @identity_through_stack(%arg0: !hal.buffer_view) -> !hal.buffer_view attributes {iree.module.export, iree.abi.none} {
%dev = hal.ex.shared_device : !hal.device
%allocator = hal.device.allocator %dev : !hal.allocator
- %element_shape = hal.buffer_view.const %allocator, "HostLocal|DeviceVisible", "All" : !hal.buffer_view = dense<[]> : tensor<0xi32>
%num_elements = hal.buffer_view.const %allocator, "HostLocal|DeviceVisible", "All" : !hal.buffer_view = dense<2> : tensor<i32>
- %list = "tensorlist.FromTensor"(%arg0, %element_shape) : (!hal.buffer_view, !hal.buffer_view) -> !tensorlist.list
- %stacked = "tensorlist.Stack"(%list, %element_shape, %num_elements) : (!tensorlist.list, !hal.buffer_view, !hal.buffer_view) -> !hal.buffer_view
+ %list = "tensorlist.FromTensor"(%arg0) : (!hal.buffer_view) -> !tensorlist.list
+ %stacked = "tensorlist.Stack"(%allocator, %list, %num_elements) : (!hal.allocator, !tensorlist.list, !hal.buffer_view) -> !hal.buffer_view
+ return %stacked : !hal.buffer_view
+}
+
+func @stack_appends_empty(%arg0: !hal.buffer_view) -> !hal.buffer_view attributes {iree.module.export, iree.abi.none} {
+ %dev = hal.ex.shared_device : !hal.device
+ %allocator = hal.device.allocator %dev : !hal.allocator
+ %0 = hal.buffer_view.const %allocator, "HostLocal|DeviceVisible", "All" : !hal.buffer_view = dense<2> : tensor<i32>
+ %1 = hal.buffer_view.const %allocator, "HostLocal|DeviceVisible", "All" : !hal.buffer_view = dense<[]> : tensor<0xi32>
+ %2 = hal.buffer_view.const %allocator, "HostLocal|DeviceVisible", "All" : !hal.buffer_view = dense<0> : tensor<i32>
+ %3 = "tensorlist.Reserve"(%1, %0) { element_type = 50331680 : i32} : (!hal.buffer_view, !hal.buffer_view) -> !tensorlist.list
+ %4 = "tensorlist.SetItem"(%3, %2, %arg0) : (!tensorlist.list, !hal.buffer_view, !hal.buffer_view) -> !tensorlist.list
+ %stacked = "tensorlist.Stack"(%allocator, %4, %0) : (!hal.allocator, !tensorlist.list, !hal.buffer_view) -> !hal.buffer_view
return %stacked : !hal.buffer_view
}
diff --git a/iree/test/e2e/models/fullyconnected.mlir b/iree/test/e2e/models/fullyconnected.mlir
index a7f4cd4..81a5dee 100644
--- a/iree/test/e2e/models/fullyconnected.mlir
+++ b/iree/test/e2e/models/fullyconnected.mlir
@@ -1,5 +1,5 @@
// RUN: iree-run-mlir -export-all %s -iree-hal-target-backends=vmla -function-input="1x5xf32=1,-2,-3,4,-5" -function-input="1x5x3x1xf32=15,14,13,12,11,10,9,8,7,6,5,4,3,2,1" | IreeFileCheck %s
-// RUN: [[ $IREE_LLVMJIT_DISABLE == 1 ]] || (iree-run-mlir -export-all %s -iree-hal-target-backends=llvm-ir -function-input="1x5xf32=1,-2,-3,4,-5" -function-input="1x5x3x1xf32=15,14,13,12,11,10,9,8,7,6,5,4,3,2,1" | IreeFileCheck %s)
+// RUN: [[ $IREE_LLVMJIT_DISABLE == 1 ]] || (iree-run-mlir -export-all %s -iree-hal-target-backends=llvm-ir -function-input="1x5xf32=1,-2,-3,4,-5" -function-input="1x5x3x1xf32=15,14,13,12,11,10,9,8,7,6,5,4,3,2,1" -iree-enable-matmul-fusion | IreeFileCheck %s)
// RUN: [[ $IREE_VULKAN_DISABLE == 1 ]] || (iree-run-mlir -export-all %s -iree-hal-target-backends=vulkan-spirv -function-input="1x5xf32=1,-2,-3,4,-5" -function-input="1x5x3x1xf32=15,14,13,12,11,10,9,8,7,6,5,4,3,2,1" | IreeFileCheck %s)
// CHECK-LABEL: EXEC @main
diff --git a/iree/test/e2e/regression/matmul_add.mlir b/iree/test/e2e/regression/matmul_add.mlir
new file mode 100644
index 0000000..1584692
--- /dev/null
+++ b/iree/test/e2e/regression/matmul_add.mlir
@@ -0,0 +1,26 @@
+// RUN: [[ $IREE_LLVMJIT_DISABLE == 1 ]] || (iree-run-mlir -export-all -iree-hal-target-backends=llvm-ir -iree-enable-matmul-fusion %s | IreeFileCheck %s)
+// RUN: [[ $IREE_VULKAN_DISABLE == 1 ]] || (iree-run-mlir -export-all -iree-hal-target-backends=vulkan-spirv -iree-enable-matmul-fusion %s | IreeFileCheck %s)
+
+func @matmul_add() -> tensor<2x4xf32> {
+ %0 = iree.unfoldable_constant dense<[
+ [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0],
+ [9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0]]> : tensor<2x8xf32>
+ %1 = iree.unfoldable_constant dense<[
+ [ 1.0, 2.0, 3.0, 4.0],
+ [ 5.0, 6.0, 7.0, 8.0],
+ [ 9.0, 10.0, 11.0, 12.0],
+ [13.0, 14.0, 15.0, 16.0],
+ [17.0, 18.0, 19.0, 20.0],
+ [21.0, 22.0, 23.0, 24.0],
+ [25.0, 26.0, 27.0, 28.0],
+ [29.0, 30.0, 31.0, 32.0]]> : tensor<8x4xf32>
+ %2 = iree.unfoldable_constant dense<[
+ [1.0, 2.0, 3.0, 4.0],
+ [5.0, 6.0, 7.0, 8.0]]> : tensor<2x4xf32>
+ %3 = "mhlo.dot"(%0, %1) : (tensor<2x8xf32>, tensor<8x4xf32>) -> tensor<2x4xf32>
+ %4 = mhlo.add %3, %2 : tensor<2x4xf32>
+ return %4 : tensor<2x4xf32>
+}
+
+// CHECK: EXEC @matmul_add
+// CHECK: 2x4xf32=[709 746 783 820][1673 1774 1875 1976]
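As a spot check of the expected values: the (0,0) output entry is the dot of the first row of
the 2x8 operand with the first column of the 8x4 operand, plus the corresponding bias entry:

1*1 + 2*5 + 3*9 + 4*13 + 5*17 + 6*21 + 7*25 + 8*29 = 708, and 708 + 1 = 709.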
diff --git a/iree/test/e2e/vulkan_specific/BUILD b/iree/test/e2e/vulkan_specific/BUILD
index adde1a1..b340804 100644
--- a/iree/test/e2e/vulkan_specific/BUILD
+++ b/iree/test/e2e/vulkan_specific/BUILD
@@ -57,3 +57,16 @@
driver = "vulkan",
target_backend = "vulkan-spirv",
)
+
+iree_check_single_backend_test_suite(
+ name = "check_vulkan-spirv_vulkan_vectorized_conv",
+ srcs = [
+ "vectorized_conv.mlir",
+ ],
+ compiler_flags = [
+ "-iree-spirv-enable-vectorization",
+ "-iree-vulkan-target-triple=valhall-g77-unknown-android10",
+ ],
+ driver = "vulkan",
+ target_backend = "vulkan-spirv",
+)
diff --git a/iree/test/e2e/vulkan_specific/CMakeLists.txt b/iree/test/e2e/vulkan_specific/CMakeLists.txt
index a6126da..0cc1b30 100644
--- a/iree/test/e2e/vulkan_specific/CMakeLists.txt
+++ b/iree/test/e2e/vulkan_specific/CMakeLists.txt
@@ -56,3 +56,17 @@
COMPILER_FLAGS
"-iree-spirv-enable-memref-vectorization"
)
+
+iree_check_single_backend_test_suite(
+ NAME
+ check_vulkan-spirv_vulkan_vectorized_conv
+ SRCS
+ "vectorized_conv.mlir"
+ TARGET_BACKEND
+ vulkan-spirv
+ DRIVER
+ vulkan
+ COMPILER_FLAGS
+ "-iree-spirv-enable-vectorization"
+ "-iree-vulkan-target-triple=valhall-g77-unknown-android10"
+)
diff --git a/iree/test/e2e/vulkan_specific/vectorized_conv.mlir b/iree/test/e2e/vulkan_specific/vectorized_conv.mlir
new file mode 100644
index 0000000..39ad667
--- /dev/null
+++ b/iree/test/e2e/vulkan_specific/vectorized_conv.mlir
@@ -0,0 +1,70 @@
+func @conv() attributes { iree.module.export } {
+ %input = iree.unfoldable_constant dense<
+ [[[[6.0, 7.5, 0.0, 1.5],
+ [1.5, 3.5, 4.5, 2.0],
+ [3.0, 6.0, 0.5, 3.0]],
+ [[3.5, 7.0, 2.5, 6.5],
+ [4.0, 4.5, 8.0, 2.5],
+ [7.5, 7.5, 0.0, 1.5]],
+ [[7.0, 3.5, 0.0, 0.5],
+ [4.5, 0.0, 5.0, 1.5],
+ [5.5, 1.0, 0.0, 0.0]]]]>
+ : tensor<1x3x3x4xf32>
+ %filter = iree.unfoldable_constant dense<
+ [[[[2.0, 2.5, 2.5, 3.0, 4.0, 2.0, 0.5, 2.0, 4.5, 5.0, 5.0, 4.0, 0.5, 0.5, 3.5, 4.5,
+ 4.5, 1.5, 3.0, 3.5, 1.0, 0.0, 1.5, 2.5, 4.5, 5.0, 2.0, 2.0, 3.0, 2.0, 2.0, 1.5],
+ [2.0, 2.0, 4.0, 2.0, 1.5, 5.0, 3.5, 2.5, 2.5, 0.0, 0.5, 2.5, 4.5, 1.5, 0.0, 2.5,
+ 0.0, 0.5, 1.0, 2.0, 1.0, 0.0, 1.5, 1.0, 5.0, 0.0, 3.5, 2.5, 4.5, 0.0, 5.0, 1.0],
+ [5.0, 3.5, 1.0, 4.5, 1.0, 1.5, 1.5, 1.0, 1.5, 2.0, 0.5, 1.0, 4.5, 5.0, 0.5, 2.0,
+ 5.0, 3.0, 4.0, 1.0, 1.5, 0.0, 0.0, 3.0, 0.0, 3.0, 1.5, 5.0, 1.5, 4.0, 4.0, 4.0],
+ [1.0, 1.5, 1.0, 0.0, 4.0, 4.0, 1.5, 4.0, 5.0, 1.0, 4.0, 2.0, 1.5, 0.0, 2.0, 1.5,
+ 3.0, 4.5, 4.0, 0.0, 4.0, 2.5, 4.5, 0.0, 4.5, 3.0, 2.5, 1.5, 0.5, 4.0, 0.0, 2.0]],
+ [[4.5, 3.0, 2.5, 3.5, 4.0, 4.0, 4.5, 1.0, 4.0, 3.0, 3.0, 4.5, 0.5, 3.0, 4.0, 4.0,
+ 1.5, 1.0, 1.5, 5.0, 3.0, 1.5, 3.0, 2.5, 3.5, 0.0, 4.0, 2.0, 5.0, 3.0, 2.5, 4.0],
+ [1.0, 1.5, 4.5, 3.5, 2.5, 1.5, 2.0, 2.5, 1.5, 1.5, 3.5, 4.5, 4.5, 4.5, 3.5, 1.5,
+ 5.0, 1.0, 1.5, 4.5, 5.0, 3.5, 3.5, 2.5, 0.5, 1.0, 1.0, 4.0, 0.5, 2.5, 4.0, 2.0],
+ [0.0, 1.0, 2.5, 2.5, 0.0, 4.0, 0.5, 0.5, 0.0, 1.5, 4.0, 4.0, 2.0, 2.0, 0.0, 4.5,
+ 1.5, 3.5, 1.5, 1.0, 0.5, 0.5, 1.0, 0.5, 2.0, 1.0, 2.5, 2.5, 2.5, 1.0, 2.5, 3.5],
+ [3.5, 3.0, 0.5, 3.0, 3.5, 1.0, 1.5, 0.5, 4.5, 2.5, 4.5, 4.5, 1.0, 0.0, 4.5, 0.5,
+ 4.5, 5.0, 0.0, 3.0, 0.0, 5.0, 2.0, 4.0, 2.0, 1.5, 1.5, 4.0, 4.0, 3.5, 0.0, 1.5]]],
+ [[[4.0, 3.5, 3.5, 5.0, 0.5, 4.0, 2.0, 3.5, 0.0, 2.0, 4.5, 0.0, 5.0, 3.0, 2.0, 1.0,
+ 2.0, 3.0, 1.5, 5.0, 1.5, 3.5, 4.0, 2.5, 0.0, 4.0, 2.5, 2.0, 3.5, 5.0, 5.0, 2.0],
+ [0.5, 1.5, 1.5, 4.5, 1.0, 2.5, 1.0, 1.5, 2.5, 5.0, 3.5, 1.0, 3.5, 0.5, 3.0, 5.0,
+ 2.5, 0.0, 0.0, 5.0, 1.5, 5.0, 0.5, 5.0, 4.5, 4.5, 3.0, 3.0, 3.5, 4.0, 4.0, 3.5],
+ [0.0, 4.0, 3.0, 4.0, 4.5, 4.0, 1.5, 3.0, 0.5, 3.5, 2.0, 4.5, 1.0, 0.0, 4.0, 1.0,
+ 3.5, 4.0, 2.0, 2.0, 0.5, 3.5, 3.0, 4.5, 2.0, 0.5, 2.5, 4.5, 3.5, 0.5, 1.5, 2.5],
+ [3.5, 1.5, 3.0, 3.0, 3.5, 4.5, 0.5, 4.5, 3.0, 0.0, 1.5, 4.0, 2.0, 0.5, 2.0, 2.5,
+ 0.0, 1.5, 5.0, 0.5, 2.0, 2.0, 2.0, 0.0, 0.0, 5.0, 4.0, 2.0, 3.0, 4.5, 1.5, 1.5]],
+ [[1.0, 0.5, 5.0, 1.0, 0.5, 1.5, 2.0, 5.0, 0.5, 0.5, 0.0, 3.5, 4.0, 5.0, 2.0, 1.5,
+ 2.5, 3.0, 1.5, 1.0, 4.5, 4.0, 0.5, 2.0, 5.0, 0.0, 4.0, 1.5, 4.5, 2.5, 2.5, 0.5],
+ [3.5, 4.0, 3.0, 2.0, 3.5, 1.5, 2.5, 1.5, 3.0, 2.0, 3.5, 1.5, 0.0, 2.5, 4.5, 1.5,
+ 3.5, 2.5, 2.5, 4.0, 0.0, 4.0, 1.5, 3.0, 4.5, 5.0, 1.5, 1.0, 3.5, 0.0, 1.5, 5.0],
+ [0.0, 1.5, 3.0, 0.5, 4.5, 1.0, 4.5, 2.0, 4.5, 0.5, 1.5, 1.0, 2.0, 4.5, 3.5, 2.0,
+ 4.5, 2.0, 0.5, 1.0, 3.5, 1.0, 1.5, 4.5, 5.0, 3.5, 5.0, 3.0, 3.0, 1.0, 5.0, 1.5],
+ [3.0, 0.0, 5.0, 4.0, 0.0, 5.0, 3.5, 3.0, 2.5, 4.5, 3.0, 2.5, 1.0, 3.5, 0.5, 4.5,
+ 1.0, 1.0, 2.5, 3.0, 2.0, 1.0, 1.0, 0.5, 0.0, 4.5, 0.0, 1.0, 4.0, 1.5, 5.0, 0.0]]]]>
+ : tensor<2x2x4x32xf32>
+
+ %0 = "mhlo.convolution"(%input, %filter) {batch_group_count = 1 : i64, dimension_numbers = {input_batch_dimension = 0 : i64, input_feature_dimension = 3 : i64, input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>, kernel_input_feature_dimension = 2 : i64, kernel_output_feature_dimension = 3 : i64, kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>, output_batch_dimension = 0 : i64, output_feature_dimension = 3 : i64, output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>}, feature_group_count = 1 : i64, padding = dense<0> : tensor<2x2xi64>, rhs_dilation = dense<1> : tensor<2xi64>, window_strides = dense<1> : tensor<2xi64>} : (tensor<1x3x3x4xf32>, tensor<2x2x4x32xf32>) -> tensor<1x2x2x32xf32>
+
+ check.expect_almost_eq_const(%0, dense<
+ [[[[113.25, 127.0, 198.0, 173.25, 159.5, 190.75, 135.5, 160.0,
+ 169.5, 130.0, 173.75, 174.5, 158.5, 136.75, 159.75, 177.75,
+ 164.5, 122.25, 116.0, 168.0, 124.75, 144.0, 113.5, 159.0,
+ 208.0, 186.5, 190.5, 158.5, 213.75, 140.5, 206.75, 135.25],
+ [129.75, 147.25, 181.25, 181.75, 142.5, 161.75, 117.75, 153.25,
+ 119.5, 128.75, 149.25, 171.0, 152.5, 142.5, 166.0, 122.25,
+ 177.75, 142.75, 116.5, 170.0, 117.5, 176.75, 116.75, 162.25,
+ 161.25, 135.0, 145.5, 163.25, 190.5, 138.25, 162.5, 146.75]],
+ [[111.75, 115.75, 173.5, 158.25, 122.5, 187.25, 129.0, 142.5,
+ 142.25, 109.0, 175.75, 158.5, 172.75, 146.25, 122.25, 157.25,
+ 157.5, 141.25, 104.25, 151.25, 136.25, 122.0, 127.75, 125.75,
+ 180.5, 131.25, 168.75, 151.5, 180.75, 152.75, 193.5, 128.75],
+ [138.25, 133.75, 157.5, 168.5, 131.0, 149.75, 115.25, 130.75,
+ 114.5, 107.25, 127.75, 163.75, 153.5, 149.25, 133.5, 114.0,
+ 164.75, 120.75, 116.0, 149.5, 127.5, 113.5, 116.0, 129.75,
+ 126.75, 94.25, 135.0, 157.75, 158.75, 142.0, 158.75, 126.25]]]]>
+ : tensor<1x2x2x32xf32>) : tensor<1x2x2x32xf32>
+ return
+}
+
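The dimension_numbers above describe an NHWC input with an HWIO filter, stride 1, no padding, and no dilation. A plain NumPy sketch of those semantics (illustrative only, not part of the patch; conv_nhwc_hwio is a hypothetical helper) should reproduce the 1x2x2x32 constant checked by expect_almost_eq_const:

import numpy as np

def conv_nhwc_hwio(x, w):
    # x: N x H x W x Ci, w: KH x KW x Ci x Co, VALID padding, stride 1.
    n, h, w_in, ci = x.shape
    kh, kw, _, co = w.shape
    out = np.zeros((n, h - kh + 1, w_in - kw + 1, co), dtype=x.dtype)
    for oy in range(out.shape[1]):
        for ox in range(out.shape[2]):
            window = x[:, oy:oy + kh, ox:ox + kw, :]  # N x KH x KW x Ci
            # Contract the KH, KW, Ci dims of the window against the filter.
            out[:, oy, ox, :] = np.tensordot(window, w, axes=([1, 2, 3], [0, 1, 2]))
    return out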
diff --git a/iree/tools/CMakeLists.txt b/iree/tools/CMakeLists.txt
index de1c9b9..0f63c8c 100644
--- a/iree/tools/CMakeLists.txt
+++ b/iree/tools/CMakeLists.txt
@@ -14,13 +14,9 @@
# bazel_to_cmake: DO NOT EDIT (Special logic is used throughout this file)
-# This need to come first so targets in the android/ directory can depend on it.
-# TODO(#3317): this seems to indicate an issue somewhere regarding dynamic
-# library dependency management.
-add_subdirectory(utils)
-
add_subdirectory(android)
add_subdirectory(test)
+add_subdirectory(utils)
# Enable compiler targets based on options.
set(IREE_COMPILER_TARGETS "")
@@ -53,6 +49,9 @@
MLIREmitC
MLIRTargetCpp
)
+ set(IREE_EMITC_INCLUDES
+ "${PROJECT_SOURCE_DIR}/third_party/mlir-emitc/include"
+ )
endif()
iree_cc_binary(
@@ -208,6 +207,8 @@
iree::compiler::Conversion::init_conversions
iree::compiler::Conversion::HLOToLinalg
iree::compiler::Dialect::HAL::Conversion::Passes
+ INCLUDES
+ "${IREE_EMITC_INCLUDES}"
PUBLIC
)
@@ -295,6 +296,8 @@
iree::compiler::Dialect::VM::Target::init_targets
iree::compiler::Translation::IREEVM
${IREE_TRANSLATE_CONDITIONAL_DEPS}
+ INCLUDES
+ "${IREE_EMITC_INCLUDES}"
PUBLIC
)
diff --git a/iree/vm/bytecode_dispatch.c b/iree/vm/bytecode_dispatch.c
index bb5241f..2a9ef1b 100644
--- a/iree/vm/bytecode_dispatch.c
+++ b/iree/vm/bytecode_dispatch.c
@@ -14,27 +14,12 @@
#include <string.h>
+#include "iree/base/math.h"
#include "iree/base/tracing.h"
#include "iree/vm/api.h"
#include "iree/vm/bytecode_dispatch_util.h"
//===----------------------------------------------------------------------===//
-// Math utilities, kept here to limit dependencies
-//===----------------------------------------------------------------------===//
-
-// Rounds up the value to the nearest power of 2 (if not already a power of 2).
-static inline uint32_t iree_math_round_up_to_pow2_u32(uint32_t n) {
- n--;
- n |= n >> 1;
- n |= n >> 2;
- n |= n >> 4;
- n |= n >> 8;
- n |= n >> 16;
- n++;
- return n;
-}
-
-//===----------------------------------------------------------------------===//
// Register remapping utilities
//===----------------------------------------------------------------------===//
diff --git a/scripts/get_e2e_artifacts.py b/scripts/get_e2e_artifacts.py
index 2440ab7..21d9259 100755
--- a/scripts/get_e2e_artifacts.py
+++ b/scripts/get_e2e_artifacts.py
@@ -41,8 +41,6 @@
'//integrations/tensorflow/e2e:e2e_tests',
'mobile_bert_squad_tests':
'//integrations/tensorflow/e2e:mobile_bert_squad_tests',
- 'keras_tests':
- '//integrations/tensorflow/e2e/keras:keras_tests',
'layers_tests':
'//integrations/tensorflow/e2e/keras/layers:layers_tests',
'layers_dynamic_batch_tests':
@@ -54,7 +52,7 @@
'keyword_spotting_internal_streaming_tests':
'//integrations/tensorflow/e2e/keras:keyword_spotting_internal_streaming_tests',
'imagenet_non_hermetic_tests':
- '//integrations/tensorflow/e2e/keras:imagenet_non_hermetic_tests',
+ '//integrations/tensorflow/e2e/keras/applications:imagenet_non_hermetic_tests',
'slim_vision_tests':
'//integrations/tensorflow/e2e/slim_vision_models:slim_vision_tests',
}
diff --git a/scripts/update_e2e_coverage.py b/scripts/update_e2e_coverage.py
index 3b5509c..fba79b2 100755
--- a/scripts/update_e2e_coverage.py
+++ b/scripts/update_e2e_coverage.py
@@ -60,7 +60,7 @@
'//integrations/tensorflow/e2e/keras:keyword_spotting_internal_streaming_tests',
],
'vision_coverage': [
- '//integrations/tensorflow/e2e/keras:imagenet_non_hermetic_tests',
+ '//integrations/tensorflow/e2e/keras/applications:imagenet_non_hermetic_tests',
'//integrations/tensorflow/e2e/slim_vision_models:slim_vision_tests',
],
}
@@ -116,7 +116,7 @@
'//integrations/tensorflow/e2e/keras:keyword_spotting_internal_streaming_tests':
f'End to end tests of {KWS_LINK} models in internal streaming mode',
# vision_coverage
- '//integrations/tensorflow/e2e/keras:imagenet_non_hermetic_tests':
+ '//integrations/tensorflow/e2e/keras/applications:imagenet_non_hermetic_tests':
'End to end tests of tf.keras.applications vision models on Imagenet',
'//integrations/tensorflow/e2e/slim_vision_models:slim_vision_tests':
'End to end tests of TensorFlow slim vision models',
@@ -162,7 +162,7 @@
'//integrations/tensorflow/e2e/keras:keyword_spotting_internal_streaming_tests':
'model',
# vision_coverage
- '//integrations/tensorflow/e2e/keras:imagenet_non_hermetic_tests':
+ '//integrations/tensorflow/e2e/keras/applications:imagenet_non_hermetic_tests':
'model',
'//integrations/tensorflow/e2e/slim_vision_models:slim_vision_tests':
'model',
@@ -193,7 +193,7 @@
'//integrations/tensorflow/e2e/keras:keyword_spotting_internal_streaming_tests':
'keyword_spotting_streaming_test',
# vision_coverage
- '//integrations/tensorflow/e2e/keras:imagenet_non_hermetic_tests':
+ '//integrations/tensorflow/e2e/keras/applications:imagenet_non_hermetic_tests':
'vision_model_test',
'//integrations/tensorflow/e2e/slim_vision_models:slim_vision_tests':
'slim_vision_model_test',
diff --git a/third_party/googletest b/third_party/googletest
index f2fb48c..b1fbd33 160000
--- a/third_party/googletest
+++ b/third_party/googletest
@@ -1 +1 @@
-Subproject commit f2fb48c3b3d79a75a88a99fba6576b25d42ec528
+Subproject commit b1fbd33c06cdb0024c67733c6fdec2009d17b384
diff --git a/third_party/tracy b/third_party/tracy
index d7059ec..9c3dac3 160000
--- a/third_party/tracy
+++ b/third_party/tracy
@@ -1 +1 @@
-Subproject commit d7059eca6351546d1f51e248fc75e49dfeee709e
+Subproject commit 9c3dac3ed2bd647b8d63f197fed058fee97a7e1e