Support GEMM Pipelining *without* Epilogue Peeling (#10388)
diff --git a/.github/workflows/benchmarks.yml b/.github/workflows/benchmarks.yml
index 734ce3a..0dd66ae 100644
--- a/.github/workflows/benchmarks.yml
+++ b/.github/workflows/benchmarks.yml
@@ -19,9 +19,6 @@
runner-env:
required: true
type: string
- gcs-dir:
- required: true
- type: string
build-dir:
required: true
type: string
@@ -41,6 +38,14 @@
required: true
type: string
+env:
+ # This duplicates the variable from ci.yml. The variable needs to be in env
+ # instead of the outputs of setup because it contains the run attempt and we
+ # want that to be the current attempt, not whatever attempt the setup step
+  # last ran in. It can't be passed in via inputs either, because the env
+  # context isn't available there.
+ GCS_DIR: gs://iree-github-actions-${{ github.event_name == 'pull_request' && 'presubmit' || 'postsubmit' }}-artifacts/${{ github.run_id }}/${{ github.run_attempt }}
+
jobs:
build_suites:
runs-on:
@@ -141,7 +146,7 @@
id: upload
env:
BENCHMARK_TOOLS_ARCHIVE: ${{ steps.archive.outputs.benchmark-tools-archive }}
- BENCHMARK_TOOLS_GCS_ARTIFACT: ${{ inputs.gcs-dir }}/${{ steps.archive.outputs.benchmark-tools-archive }}
+ BENCHMARK_TOOLS_GCS_ARTIFACT: ${{ env.GCS_DIR }}/${{ steps.archive.outputs.benchmark-tools-archive }}
run: |
gcloud alpha storage cp "${BENCHMARK_TOOLS_ARCHIVE}" "${BENCHMARK_TOOLS_GCS_ARTIFACT}"
echo "::set-output name=${PLATFORM}-${ARCHITECTURE}-benchmark-tools-gcs-artifact::${BENCHMARK_TOOLS_GCS_ARTIFACT}"
diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 2f68eb2..3e6143d 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -54,6 +54,10 @@
# See note above regarding lack of proper variables. Also see note about
# pseudo-ternary hack.
_CI_STAGE: ${{ github.event_name == 'pull_request' && 'presubmit' || 'postsubmit' }}
+ # This needs to be in env instead of the outputs of setup because it contains
+ # the run attempt and we want that to be the current attempt, not whatever
+ # attempt the setup step last ran in.
+ GCS_DIR: gs://iree-github-actions-${{ github.event_name == 'pull_request' && 'presubmit' || 'postsubmit' }}-artifacts/${{ github.run_id }}/${{ github.run_attempt }}
# Jobs are organized into groups and topologically sorted by dependencies
jobs:
@@ -68,7 +72,6 @@
should-run: ${{ steps.should-run.outputs.should-run }}
# Variables for dependent jobs. See comment at top.
runner-env: prod
- gcs-dir: gs://iree-github-actions-${{ env._CI_STAGE }}-artifacts/${{ github.run_id }}/${{ github.run_attempt }}
runner-group: ${{ env._CI_STAGE }}
# Note that we can't flip the condition here because 0 is falsey. See
# comment at top.
@@ -151,7 +154,7 @@
id: upload
env:
BUILD_DIR_ARCHIVE: ${{ steps.archive.outputs.build-dir-archive }}
- BUILD_DIR_GCS_ARTIFACT: ${{ needs.setup.outputs.gcs-dir }}/${{ steps.archive.outputs.build-dir-archive }}
+ BUILD_DIR_GCS_ARTIFACT: ${{ env.GCS_DIR }}/${{ steps.archive.outputs.build-dir-archive }}
run: |
gcloud alpha storage cp "${BUILD_DIR_ARCHIVE}" "${BUILD_DIR_GCS_ARTIFACT}"
echo "::set-output name=build-dir-gcs-artifact::${BUILD_DIR_GCS_ARTIFACT}"
@@ -180,7 +183,7 @@
run: |
./build_tools/github_actions/docker_run.sh \
--env "IREE_BAZEL_WRITE_REMOTE_CACHE=${IREE_BAZEL_WRITE_REMOTE_CACHE}" \
- gcr.io/iree-oss/frontends-swiftshader@sha256:41e516b8c1b432e3c02896c4bf4b7f06df6a67371aa167b88767b8d4d2018ea6 \
+ gcr.io/iree-oss/frontends-swiftshader@sha256:3d5b879672d7f302124ab3d1aa533a6949bd0adfc176884177844ac6767e23e9 \
./build_tools/bazel/build_core.sh
test_all:
@@ -343,7 +346,7 @@
./build_tools/github_actions/docker_run.sh \
--env "IREE_BAZEL_WRITE_REMOTE_CACHE=${IREE_BAZEL_WRITE_REMOTE_CACHE}" \
--env "IREE_TF_BINARIES_OUTPUT_DIR=${IREE_TF_BINARIES_OUTPUT_DIR}" \
- gcr.io/iree-oss/frontends-swiftshader@sha256:3090418a8d8a64c356d35eff285af32570a72f41127aa123209c1562f57abb01 \
+ gcr.io/iree-oss/frontends-swiftshader@sha256:3d5b879672d7f302124ab3d1aa533a6949bd0adfc176884177844ac6767e23e9 \
build_tools/cmake/build_tf_binaries.sh
echo "::set-output name=binaries-dir::${IREE_TF_BINARIES_OUTPUT_DIR}"
- name: "Creating archive of binaries"
@@ -358,7 +361,7 @@
id: upload
env:
BINARIES_ARCHIVE: ${{ steps.archive.outputs.binaries-archive }}
- BINARIES_GCS_ARTIFACT: ${{ needs.setup.outputs.gcs-dir }}/${{ steps.archive.outputs.binaries-archive }}
+ BINARIES_GCS_ARTIFACT: ${{ env.GCS_DIR }}/${{ steps.archive.outputs.binaries-archive }}
run: |
gcloud alpha storage cp "${BINARIES_ARCHIVE}" "${BINARIES_GCS_ARTIFACT}"
echo "::set-output name=binaries-gcs-artifact::${BINARIES_GCS_ARTIFACT}"
@@ -398,7 +401,7 @@
- name: "Running TF integrations tests"
run: |
./build_tools/github_actions/docker_run.sh \
- gcr.io/iree-oss/frontends-swiftshader@sha256:3090418a8d8a64c356d35eff285af32570a72f41127aa123209c1562f57abb01 \
+ gcr.io/iree-oss/frontends-swiftshader@sha256:3d5b879672d7f302124ab3d1aa533a6949bd0adfc176884177844ac6767e23e9 \
build_tools/cmake/run_tf_tests.sh \
"${BUILD_DIR}"
@@ -440,7 +443,7 @@
--env IREE_LLVM_CPU_DISABLE=1 \
--gpus all \
--env NVIDIA_DRIVER_CAPABILITIES=all \
- gcr.io/iree-oss/frontends-nvidia@sha256:e934ed09e9e60c28ebe11a02f37a993dd975db40118d410c4279d0fa2d4e6b9a \
+ gcr.io/iree-oss/frontends-nvidia@sha256:28cd43f36b1ca0633bbd915911abe6d22b4aa16093f074e87016305322a0eba1 \
bash -euo pipefail -c \
"./build_tools/scripts/check_cuda.sh
./build_tools/scripts/check_vulkan.sh
@@ -494,9 +497,10 @@
if: needs.setup.outputs.should-run == 'true'
uses: ./.github/workflows/benchmarks.yml
with:
+ # env.GCS_DIR is also duplicated in this workflow. See the note there on
+      # why this is necessary.
runner-group: ${{ needs.setup.outputs.runner-group }}
runner-env: ${{ needs.setup.outputs.runner-env }}
- gcs-dir: ${{ needs.setup.outputs.gcs-dir }}
build-dir: ${{ needs.build_all.outputs.build-dir }}
build-dir-archive: ${{ needs.build_all.outputs.build-dir-archive }}
build-dir-gcs-artifact: ${{ needs.build_all.outputs.build-dir-gcs-artifact }}
@@ -536,7 +540,7 @@
build_tools/github_actions/docker_run.sh \
--env "ANDROID_ABI=${ANDROID_ABI}" \
--env "IREE_HOST_BINARY_ROOT=${BUILD_DIR}/install" \
- gcr.io/iree-oss/android@sha256:9bc723fc707a18bd0c1be9c12e01ea5bb7c7d77f607427879e10ffcffd7d2bb5 \
+ gcr.io/iree-oss/android@sha256:76c2a52dcd6d07601227b965ac87d021c1d2d5e2d01f46ad58da28c89267f2ab \
build_tools/cmake/build_android.sh
riscv32:
@@ -621,7 +625,8 @@
##############################################################################
# Depends on all the other jobs to provide a single anchor that indicates the
- # final status. Status reporting will become more sophisticated in the future.
+ # final status. Status reporting will become more sophisticated in the future
+ # and we can hopefully avoid the need to explicitly list every single job...
summary:
# Even if you have an explicit if condition, you still need to override
# GitHub's default behavior of not running if any dependencies failed.
@@ -639,6 +644,7 @@
- test_runtime
# Tensorflow
+ - build_tf_integrations
- test_tf_integrations
- test_tf_integrations_gpu
diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml
index 5813290..e07e9ef 100644
--- a/.github/workflows/lint.yml
+++ b/.github/workflows/lint.yml
@@ -128,3 +128,11 @@
run: git fetch --no-tags --prune --depth=1 origin "${GITHUB_BASE_REF?}:${GITHUB_BASE_REF?}"
- name: yamllint
run: ./build_tools/scripts/run_yamllint.sh "${GITHUB_BASE_REF?}"
+
+ path_lengths:
+ runs-on: ubuntu-20.04
+ steps:
+ - name: Checking out repository
+ uses: actions/checkout@7884fcad6b5d53d10323aee724dc68d8b9096a2e # v2
+ - name: Running check_path_lengths
+ run: ./build_tools/scripts/check_path_lengths.py
diff --git a/build_tools/bazel/iree.bazelrc b/build_tools/bazel/iree.bazelrc
index 8937507..4cf79c2 100644
--- a/build_tools/bazel/iree.bazelrc
+++ b/build_tools/bazel/iree.bazelrc
@@ -268,7 +268,7 @@
# specific docker container the CI Bazel builds are run in. The image URL is
# included for clarity and so that this reference is automatically updated by
# manage_images.py
-build:remote_cache_bazel_ci --host_platform_remote_properties_override='properties:{name:"cache-silo-key" value:"gcr.io/iree-oss/frontends-swiftshader@sha256:3090418a8d8a64c356d35eff285af32570a72f41127aa123209c1562f57abb01"}'
+build:remote_cache_bazel_ci --host_platform_remote_properties_override='properties:{name:"cache-silo-key" value:"gcr.io/iree-oss/frontends-swiftshader@sha256:3d5b879672d7f302124ab3d1aa533a6949bd0adfc176884177844ac6767e23e9"}'
###############################################################################
# Configuration for uploading build results to Result Store UI
diff --git a/build_tools/benchmarks/comparisons/mobilebert_fp32_commands.py b/build_tools/benchmarks/comparisons/mobilebert_fp32_commands.py
index 370c108..4adb80f 100644
--- a/build_tools/benchmarks/comparisons/mobilebert_fp32_commands.py
+++ b/build_tools/benchmarks/comparisons/mobilebert_fp32_commands.py
@@ -118,7 +118,8 @@
# Generate IREE benchmarks.
driver = "local-task"
- iree_model_path = os.path.join(self._base_dir, "models", "iree", driver,
+ backend = "llvm-cpu"
+ iree_model_path = os.path.join(self._base_dir, "models", "iree", backend,
self._model_name + ".vmfb")
iree_mobilebert = IreeMobilebertFP32(self._iree_benchmark_binary_path,
self._model_name,
@@ -130,7 +131,7 @@
if device == "mobile":
model_mmt4d_name = self._model_name + "_mmt4d"
iree_mmt4d_model_path = os.path.join(self._base_dir, "models", "iree",
- driver, model_mmt4d_name + ".vmfb")
+ backend, model_mmt4d_name + ".vmfb")
iree_mmt4d_mobilebert = IreeMobilebertFP32(
self._iree_benchmark_binary_path,
model_mmt4d_name,
@@ -138,6 +139,17 @@
driver=driver)
commands.append(iree_mmt4d_mobilebert)
+ model_im2col_mmt4d_name = self._model_name + "_im2col_mmt4d"
+ iree_im2col_mmt4d_model_path = os.path.join(
+ self._base_dir, "models", "iree", backend,
+ model_im2col_mmt4d_name + ".vmfb")
+ iree_im2col_mmt4d_mobilebert = IreeMobilebertFP32(
+ self._iree_benchmark_binary_path,
+ model_im2col_mmt4d_name,
+ iree_im2col_mmt4d_model_path,
+ driver=driver)
+ commands.append(iree_im2col_mmt4d_mobilebert)
+
return commands
def _generate_gpu(self, driver: str):
@@ -171,7 +183,22 @@
self._model_name,
iree_model_path,
driver=driver)
+ iree_fp16_model_path = os.path.join(self._base_dir, "models", "iree",
+ driver, self._model_name + "_fp16.vmfb")
+ iree_mobilebert_fp16 = IreeMobilebertFP32(self._iree_benchmark_binary_path,
+ self._model_name + "_fp16",
+ iree_fp16_model_path,
+ driver=driver)
+ iree_padfuse_model_path = os.path.join(self._base_dir, "models", "iree",
+ driver,
+ self._model_name + "_padfuse.vmfb")
+ iree_mobilebert_padfuse = IreeMobilebertFP32(
+ self._iree_benchmark_binary_path,
+ self._model_name + "_padfuse",
+ iree_padfuse_model_path,
+ driver=driver)
+
return [
tflite_mobilebert, tflite_mobilebert_noxnn, tflite_mobilebert_fp16,
- iree_mobilebert
+ iree_mobilebert, iree_mobilebert_fp16, iree_mobilebert_padfuse
]
diff --git a/build_tools/benchmarks/comparisons/setup_desktop.sh b/build_tools/benchmarks/comparisons/setup_desktop.sh
index 1de08d7..e227f21 100644
--- a/build_tools/benchmarks/comparisons/setup_desktop.sh
+++ b/build_tools/benchmarks/comparisons/setup_desktop.sh
@@ -9,7 +9,7 @@
set -euo pipefail
# Install Bazel. From https://www.tensorflow.org/install/source
-npm install -g @bazel/bazelisk
+#npm install -g @bazel/bazelisk
# Create root dir.
ROOT_DIR=/tmp/mobilebert_benchmarks
@@ -40,9 +40,10 @@
export CC=clang
export CXX=clang++
-python configure_bazel.py
+python3 configure_bazel.py
cd integrations/tensorflow
+./symlink_binaries.sh
bazel build -c opt iree_tf_compiler:iree-import-tflite
+IREE_IMPORT_TFLITE_PATH=${SOURCE_DIR}/iree/integrations/tensorflow/bazel-bin/iree_tf_compiler/iree-import-tflite
IREE_COMPILE_PATH="${SOURCE_DIR}/iree-build/tools/iree-compile"
@@ -52,30 +53,37 @@
mkdir -p "${IREE_MODEL_DIR}/cuda"
mkdir -p "${IREE_MODEL_DIR}/llvm-cpu"
-MODEL_NAME="mobilebert_float_384_gpu"
-bazel-bin/iree_tf_compiler/iree-import-tflite "${TFLITE_MODEL_DIR}/${MODEL_NAME}.tflite" -o "${IREE_MODEL_DIR}/${MODEL_NAME}.mlir"
-# Build for CUDA.
-echo "Compiling ${MODEL_NAME}.vmfb for cuda..."
-"${IREE_COMPILE_PATH}" \
- --iree-input-type=tosa \
- --iree-hal-target-backends=cuda \
- --iree-hal-cuda-llvm-target-arch=sm_80 \
- --iree-llvm-debug-symbols=false \
- --iree-vm-bytecode-module-strip-source-map=true \
- --iree-vm-emit-polyglot-zip=false \
- "${IREE_MODEL_DIR}/${MODEL_NAME}.mlir" \
- --o "${IREE_MODEL_DIR}/cuda/${MODEL_NAME}.vmfb"
-# Build for x86.
-echo "Compiling ${MODEL_NAME}.vmfb for llvm-cpu..."
-"${IREE_COMPILE_PATH}" \
- --iree-input-type=tosa \
- --iree-llvm-target-cpu-features=host \
- --iree-hal-target-backends=llvm-cpu \
- --iree-llvm-debug-symbols=false \
- --iree-vm-bytecode-module-strip-source-map=true \
- --iree-vm-emit-polyglot-zip=false \
- "${IREE_MODEL_DIR}/${MODEL_NAME}.mlir" \
- --o "${IREE_MODEL_DIR}/llvm-cpu/${MODEL_NAME}.vmfb"
+# Imports and compiles every TFLite file in the directory. If compilation
+# fails, we keep going.
+for i in $(ls ${ROOT_DIR}/models/tflite/); do
+ MODEL_NAME=$(basename $i .tflite)
+ echo "Processing ${MODEL_NAME} ..."
+
+ ${IREE_IMPORT_TFLITE_PATH} "${TFLITE_MODEL_DIR}/${MODEL_NAME}.tflite" -o "${IREE_MODEL_DIR}/${MODEL_NAME}.mlir"
+ # Build for CUDA.
+ echo "Compiling ${MODEL_NAME}.vmfb for cuda..."
+ "${IREE_COMPILE_PATH}" \
+ --iree-input-type=tosa \
+ --iree-hal-target-backends=cuda \
+ --iree-hal-cuda-llvm-target-arch=sm_80 \
+ --iree-llvm-debug-symbols=false \
+ --iree-vm-bytecode-module-strip-source-map=true \
+ --iree-vm-emit-polyglot-zip=false \
+ "${IREE_MODEL_DIR}/${MODEL_NAME}.mlir" \
+ --o "${IREE_MODEL_DIR}/cuda/${MODEL_NAME}.vmfb" || true
+ # Build for x86.
+ echo "Compiling ${MODEL_NAME}.vmfb for llvm-cpu..."
+ "${IREE_COMPILE_PATH}" \
+ --iree-input-type=tosa \
+ --iree-llvm-target-cpu-features=host \
+ --iree-hal-target-backends=llvm-cpu \
+ --iree-llvm-debug-symbols=false \
+ --iree-vm-bytecode-module-strip-source-map=true \
+ --iree-vm-emit-polyglot-zip=false \
+ "${IREE_MODEL_DIR}/${MODEL_NAME}.mlir" \
+ --o "${IREE_MODEL_DIR}/llvm-cpu/${MODEL_NAME}.vmfb" || true
+ done
+
cp "${SOURCE_DIR}/iree-build/tools/iree-benchmark-module" "${ROOT_DIR}/"
@@ -89,7 +97,7 @@
git clone https://github.com/tensorflow/tensorflow.git
cd tensorflow
# Select defaults and answer No to all questions.
-python configure.py
+python3 configure.py
bazel build -c opt --copt=-DCL_DELEGATE_NO_GL \
--copt=-DMESA_EGL_NO_X11_HEADERS=1 \
diff --git a/build_tools/benchmarks/comparisons/setup_mobile.sh b/build_tools/benchmarks/comparisons/setup_mobile.sh
index 9592879..92ef72e 100644
--- a/build_tools/benchmarks/comparisons/setup_mobile.sh
+++ b/build_tools/benchmarks/comparisons/setup_mobile.sh
@@ -57,11 +57,13 @@
python configure_bazel.py
cd integrations/tensorflow
+./symlink_binaries.sh
bazel build -c opt iree_tf_compiler:iree-import-tflite
echo "Done building iree-import-tflite"
echo
+IREE_IMPORT_TFLITE_PATH=${SOURCE_DIR}/iree/integrations/tensorflow/bazel-bin/iree_tf_compiler/iree-import-tflite
IREE_COMPILE_PATH="${SOURCE_DIR}/iree-build/tools/iree-compile"
TFLITE_MODEL_DIR="${ROOT_DIR}/models/tflite"
@@ -69,11 +71,13 @@
mkdir -p "${IREE_MODEL_DIR}/vulkan"
mkdir -p "${IREE_MODEL_DIR}/llvm-cpu"
+# Imports and compiles every TFLite file in the directory. If compilation
+# fails, we keep going.
for i in $(ls ${ROOT_DIR}/models/tflite/); do
MODEL_NAME=$(basename $i .tflite)
echo "Processing ${MODEL_NAME} ..."
- bazel-bin/iree_tf_compiler/iree-import-tflite "${TFLITE_MODEL_DIR}/${MODEL_NAME}.tflite" -o "${IREE_MODEL_DIR}/${MODEL_NAME}.mlir"
+ ${IREE_IMPORT_TFLITE_PATH} "${TFLITE_MODEL_DIR}/${MODEL_NAME}.tflite" -o "${IREE_MODEL_DIR}/${MODEL_NAME}.mlir"
echo -e "\tCompiling ${MODEL_NAME}.vmfb for aarch64..."
"${IREE_COMPILE_PATH}" \
--iree-input-type=tosa \
@@ -83,7 +87,7 @@
--iree-vm-bytecode-module-strip-source-map=true \
--iree-vm-emit-polyglot-zip=false \
"${IREE_MODEL_DIR}/${MODEL_NAME}.mlir" \
- --o "${IREE_MODEL_DIR}/llvm-cpu/${MODEL_NAME}.vmfb"
+ --o "${IREE_MODEL_DIR}/llvm-cpu/${MODEL_NAME}.vmfb" || true
echo -e "\tCompiling ${MODEL_NAME}_mmt4d.vmfb for aarch64..."
"${IREE_COMPILE_PATH}" \
@@ -96,7 +100,22 @@
--iree-vm-bytecode-module-strip-source-map=true \
--iree-vm-emit-polyglot-zip=false \
"${IREE_MODEL_DIR}/${MODEL_NAME}.mlir" \
- --o "${IREE_MODEL_DIR}/llvm-cpu/${MODEL_NAME}_mmt4d.vmfb"
+ --o "${IREE_MODEL_DIR}/llvm-cpu/${MODEL_NAME}_mmt4d.vmfb" || true
+
+ echo -e "\tCompiling ${MODEL_NAME}_im2col_mmt4d.vmfb for aarch64..."
+ "${IREE_COMPILE_PATH}" \
+ --iree-input-type=tosa \
+ --iree-hal-target-backends=llvm-cpu \
+ --iree-llvm-target-triple=aarch64-none-linux-android29 \
+ "--iree-flow-mmt4d-target-options=arch=aarch64 features=+dotprod" \
+ --iree-llvm-target-cpu-features=+dotprod \
+ --iree-flow-enable-conv-img2col-transform \
+ --iree-llvm-debug-symbols=false \
+ --iree-vm-bytecode-module-strip-source-map=true \
+ --iree-vm-emit-polyglot-zip=false \
+ "${IREE_MODEL_DIR}/${MODEL_NAME}.mlir" \
+ --o "${IREE_MODEL_DIR}/llvm-cpu/${MODEL_NAME}_im2col_mmt4d.vmfb" || true
+
if [[ "${GPU_TYPE}" = "mali" ]]; then
echo -e "\tCompiling ${MODEL_NAME}.vmfb for vulkan mali..."
@@ -108,7 +127,29 @@
--iree-vm-bytecode-module-strip-source-map=true \
--iree-vm-emit-polyglot-zip=false \
"${IREE_MODEL_DIR}/${MODEL_NAME}.mlir" \
- --o "${IREE_MODEL_DIR}/vulkan/${MODEL_NAME}.vmfb"
+ --o "${IREE_MODEL_DIR}/vulkan/${MODEL_NAME}.vmfb" || true
+ echo -e "\tCompiling ${MODEL_NAME}_fp16.vmfb for vulkan mali..."
+ "${IREE_COMPILE_PATH}" \
+ --iree-input-type=tosa \
+ --iree-hal-target-backends=vulkan-spirv \
+ --iree-vulkan-target-triple=valhall-unknown-android31 \
+ --iree-flow-demote-f32-to-f16 \
+ --iree-llvm-debug-symbols=false \
+ --iree-vm-bytecode-module-strip-source-map=true \
+ --iree-vm-emit-polyglot-zip=false \
+ "${IREE_MODEL_DIR}/${MODEL_NAME}.mlir" \
+ --o "${IREE_MODEL_DIR}/vulkan/${MODEL_NAME}_fp16.vmfb" || true
+ echo -e "\tCompiling ${MODEL_NAME}_padfuse.vmfb for vulkan mali..."
+ "${IREE_COMPILE_PATH}" \
+ --iree-input-type=tosa \
+ --iree-hal-target-backends=vulkan-spirv \
+ --iree-vulkan-target-triple=valhall-unknown-android31 \
+ --iree-flow-enable-fuse-padding-into-consumer-ops \
+ --iree-llvm-debug-symbols=false \
+ --iree-vm-bytecode-module-strip-source-map=true \
+ --iree-vm-emit-polyglot-zip=false \
+ "${IREE_MODEL_DIR}/${MODEL_NAME}.mlir" \
+ --o "${IREE_MODEL_DIR}/vulkan/${MODEL_NAME}_padfuse.vmfb" || true
else
echo -e "\tCompiling ${MODEL_NAME}.vmfb for vulkan adreno..."
"${IREE_COMPILE_PATH}" \
@@ -119,7 +160,7 @@
--iree-vm-bytecode-module-strip-source-map=true \
--iree-vm-emit-polyglot-zip=false \
"${IREE_MODEL_DIR}/${MODEL_NAME}.mlir" \
- --o "${IREE_MODEL_DIR}/vulkan/${MODEL_NAME}.vmfb"
+ --o "${IREE_MODEL_DIR}/vulkan/${MODEL_NAME}.vmfb" || true
fi
done
diff --git a/build_tools/benchmarks/comparisons/simple_commands.py b/build_tools/benchmarks/comparisons/simple_commands.py
index cfc8101..6b5ef6e 100644
--- a/build_tools/benchmarks/comparisons/simple_commands.py
+++ b/build_tools/benchmarks/comparisons/simple_commands.py
@@ -114,10 +114,19 @@
driver="cpu")
commands.append(tflite)
+ tflite_noxnn = TfliteWrapper(self._tflite_benchmark_binary_path,
+ self._model_name + "_noxnn",
+ tflite_model_path,
+ self._input_name,
+ driver="cpu")
+ tflite_noxnn.args.append("--use_xnnpack=false")
+ commands.append(tflite_noxnn)
+
# Generate IREE benchmarks.
driver = "local-task"
+ backend = "llvm-cpu"
- iree_model_path = os.path.join(self._base_dir, "models", "iree", driver,
+ iree_model_path = os.path.join(self._base_dir, "models", "iree", backend,
self._model_name + ".vmfb")
iree = IreeWrapper(self._iree_benchmark_binary_path,
self._model_name,
@@ -130,7 +139,7 @@
if device == "mobile":
model_mmt4d_name = self._model_name + "_mmt4d"
iree_mmt4d_model_path = os.path.join(self._base_dir, "models", "iree",
- driver, model_mmt4d_name + ".vmfb")
+ backend, model_mmt4d_name + ".vmfb")
iree_mmt4d = IreeWrapper(self._iree_benchmark_binary_path,
model_mmt4d_name,
iree_mmt4d_model_path,
@@ -138,6 +147,17 @@
driver=driver)
commands.append(iree_mmt4d)
+ model_im2col_mmt4d_name = self._model_name + "_im2col_mmt4d"
+ iree_im2col_mmt4d_model_path = os.path.join(
+ self._base_dir, "models", "iree", backend,
+ model_im2col_mmt4d_name + ".vmfb")
+ iree_im2col_mmt4d = IreeWrapper(self._iree_benchmark_binary_path,
+ model_im2col_mmt4d_name,
+ iree_im2col_mmt4d_model_path,
+ self._function_input,
+ driver=driver)
+ commands.append(iree_im2col_mmt4d)
+
return commands
def _generate_gpu(self, driver: str):
@@ -153,6 +173,24 @@
tflite.args.append("--gpu_precision_loss_allowed=false")
commands.append(tflite)
+ tflite_noxnn = TfliteWrapper(self._tflite_benchmark_binary_path,
+ self._model_name + "_noxnn",
+ tflite_model_path,
+ self._input_name,
+ self._input_layer,
+ driver="gpu")
+    tflite_noxnn.args.append("--use_xnnpack=false")
+ commands.append(tflite_noxnn)
+
+ tflite_fp16 = TfliteWrapper(self._tflite_benchmark_binary_path,
+ self._model_name + "_fp16",
+ tflite_model_path,
+ self._input_name,
+ self._input_layer,
+ driver="gpu")
+    tflite_fp16.args.append("--gpu_precision_loss_allowed=true")
+ commands.append(tflite_fp16)
+
iree_model_path = os.path.join(self._base_dir, "models", "iree", driver,
self._model_name + ".vmfb")
iree = IreeWrapper(self._iree_benchmark_binary_path,
@@ -161,4 +199,22 @@
self._function_input,
driver=driver)
commands.append(iree)
+
+ iree_model_path = os.path.join(self._base_dir, "models", "iree", driver,
+ self._model_name + "_fp16.vmfb")
+ iree = IreeWrapper(self._iree_benchmark_binary_path,
+ self._model_name + "_fp16",
+ iree_model_path,
+ self._function_input,
+ driver=driver)
+ commands.append(iree)
+
+ iree_model_path = os.path.join(self._base_dir, "models", "iree", driver,
+ self._model_name + "_padfuse.vmfb")
+ iree = IreeWrapper(self._iree_benchmark_binary_path,
+ self._model_name + "_padfuse",
+ iree_model_path,
+ self._function_input,
+ driver=driver)
+ commands.append(iree)
return commands
diff --git a/build_tools/benchmarks/suites/cmake_rule_generator.py b/build_tools/benchmarks/suites/cmake_rule_generator.py
index a2dfa43..3075ac8 100644
--- a/build_tools/benchmarks/suites/cmake_rule_generator.py
+++ b/build_tools/benchmarks/suites/cmake_rule_generator.py
@@ -125,6 +125,7 @@
model_id: str,
model_name: str,
model_source_type: common_definitions.ModelSourceType,
+ model_entry_function: str,
source_model_rule: ModelRule,
) -> IreeModelImportRule:
"""Adds a rule to fetch the model and import into MLIR. Reuses the rule when
@@ -161,6 +162,7 @@
cmake_rule = TF_IMPORT_CMAKE_TEMPLATE.substitute(
__TARGET_NAME=target_name,
__SOURCE_MODEL_PATH=source_model_rule.file_path,
+ __ENTRY_FUNCTION=model_entry_function,
__OUTPUT_PATH=output_file_path)
mlir_dialect_type = "mhlo"
else:
@@ -270,6 +272,7 @@
model_id=model.id,
model_name=model.name,
model_source_type=model.source_type,
+ model_entry_function=model.entry_function,
source_model_rule=source_model_rule)
iree_rule_factory.add_compile_module_rule(compile_config=compile_config,
model_import_rule=import_rule)
diff --git a/build_tools/benchmarks/suites/cmake_rule_generator_test.py b/build_tools/benchmarks/suites/cmake_rule_generator_test.py
index 582ae81..3ade14f 100644
--- a/build_tools/benchmarks/suites/cmake_rule_generator_test.py
+++ b/build_tools/benchmarks/suites/cmake_rule_generator_test.py
@@ -67,6 +67,7 @@
model_id="1234",
model_name="abcd",
model_source_type=common_definitions.ModelSourceType.EXPORTED_TFLITE,
+ model_entry_function="main",
source_model_rule=cmake_rule_generator.ModelRule(
target_name="model-1234", file_path="aaa", cmake_rule="bbb"))
@@ -85,6 +86,7 @@
model_name="abcd",
model_source_type=common_definitions.ModelSourceType.
EXPORTED_LINALG_MLIR,
+ model_entry_function="main",
source_model_rule=model_rule)
self.assertEqual(rule.target_name, model_rule.target_name)
@@ -122,12 +124,14 @@
model_id="1234",
model_name="abcd",
model_source_type=common_definitions.ModelSourceType.EXPORTED_TFLITE,
+ model_entry_function="main",
source_model_rule=cmake_rule_generator.ModelRule(
target_name="model-1234", file_path="aaa", cmake_rule="bbb"))
import_rule_2 = self._factory.add_import_model_rule(
model_id="5678",
model_name="efgh",
model_source_type=common_definitions.ModelSourceType.EXPORTED_TF,
+ model_entry_function="main",
source_model_rule=cmake_rule_generator.ModelRule(
target_name="model-5678", file_path="ccc", cmake_rule="eee"))
compile_config = iree_definitions.CompileConfig(
diff --git a/build_tools/benchmarks/suites/iree_tf_import_template.cmake b/build_tools/benchmarks/suites/iree_tf_import_template.cmake
index 18850ea..c8b33d0 100644
--- a/build_tools/benchmarks/suites/iree_tf_import_template.cmake
+++ b/build_tools/benchmarks/suites/iree_tf_import_template.cmake
@@ -2,5 +2,6 @@
iree_import_tf_model(
TARGET_NAME "$${_PACKAGE_NAME}_$__TARGET_NAME"
SOURCE "$__SOURCE_MODEL_PATH"
+ ENTRY_FUNCTION "$__ENTRY_FUNCTION"
OUTPUT_MLIR_FILE "$__OUTPUT_PATH"
)
diff --git a/build_tools/buildkite/cmake/android/arm64-v8a/benchmark2.yml b/build_tools/buildkite/cmake/android/arm64-v8a/benchmark2.yml
index 1debddb..47fdbe6 100644
--- a/build_tools/buildkite/cmake/android/arm64-v8a/benchmark2.yml
+++ b/build_tools/buildkite/cmake/android/arm64-v8a/benchmark2.yml
@@ -9,7 +9,7 @@
steps:
- label: "Build"
commands:
- - "docker run --user=$(id -u):$(id -g) --volume=\\${HOME?}:\\${HOME?} --volume=/etc/passwd:/etc/passwd:ro --volume=/etc/group:/etc/group:ro --volume=\\$PWD:\\$IREE_DOCKER_WORKDIR --workdir=\\$IREE_DOCKER_WORKDIR --rm gcr.io/iree-oss/frontends@sha256:bad174c580cdefaf435ce31a7df6bdd7f7cb7bfdcdff5d1acf40f630acf85bf5 build_tools/cmake/build_android_benchmark.sh"
+ - "docker run --user=$(id -u):$(id -g) --volume=\\${HOME?}:\\${HOME?} --volume=/etc/passwd:/etc/passwd:ro --volume=/etc/group:/etc/group:ro --volume=\\$PWD:\\$IREE_DOCKER_WORKDIR --workdir=\\$IREE_DOCKER_WORKDIR --rm gcr.io/iree-oss/frontends@sha256:7a7a6d2fce60f3db82bfd2f18316231f9e4662cd9307b079d5adfbb6e119b817 build_tools/cmake/build_android_benchmark.sh"
- "tar --exclude='*.tar.gz' --exclude='*.tgz' --exclude='*.mlir' --exclude='*.tflite' --exclude='*tf-model' -czvf benchmark-suites-${BUILDKITE_BUILD_NUMBER}.tgz build-host/benchmark_suites"
- "find build-host/benchmark_suites -name '*.mlir' | tar -czvf source-mlir-models-${BUILDKITE_BUILD_NUMBER}.tgz -T -"
- "tar -czvf iree-android-tools-${BUILDKITE_BUILD_NUMBER}.tgz build-android/tools/iree-benchmark-module build-android-trace/tools/iree-benchmark-module build-android/tools/build_config.txt"
diff --git a/build_tools/buildkite/cmake/android/arm64-v8a/pipeline.yml b/build_tools/buildkite/cmake/android/arm64-v8a/pipeline.yml
index 92e5bc7..a05aae9 100644
--- a/build_tools/buildkite/cmake/android/arm64-v8a/pipeline.yml
+++ b/build_tools/buildkite/cmake/android/arm64-v8a/pipeline.yml
@@ -8,7 +8,7 @@
- label: "build"
commands:
- "git submodule sync && git submodule update --init --jobs 8 --depth 1"
- - "docker run --user=$(id -u):$(id -g) --volume=\\$PWD:\\$IREE_DOCKER_WORKDIR --workdir=\\$IREE_DOCKER_WORKDIR --rm gcr.io/iree-oss/android@sha256:9bc723fc707a18bd0c1be9c12e01ea5bb7c7d77f607427879e10ffcffd7d2bb5 build_tools/cmake/build_host_and_android.sh arm64-v8a"
+ - "docker run --user=$(id -u):$(id -g) --volume=\\$PWD:\\$IREE_DOCKER_WORKDIR --workdir=\\$IREE_DOCKER_WORKDIR --rm gcr.io/iree-oss/android@sha256:76c2a52dcd6d07601227b965ac87d021c1d2d5e2d01f46ad58da28c89267f2ab build_tools/cmake/build_host_and_android.sh arm64-v8a"
- "tar --exclude='*.o' --exclude='*.a' -czvf build-artifacts.tgz build-android"
agents:
- "queue=build"
diff --git a/build_tools/buildkite/cmake/linux/pipeline.yml b/build_tools/buildkite/cmake/linux/pipeline.yml
index 0b26789..df2b2e3 100644
--- a/build_tools/buildkite/cmake/linux/pipeline.yml
+++ b/build_tools/buildkite/cmake/linux/pipeline.yml
@@ -5,7 +5,7 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
env:
- DOCKER_IMAGE: "gcr.io/iree-oss/frontends@sha256:bad174c580cdefaf435ce31a7df6bdd7f7cb7bfdcdff5d1acf40f630acf85bf5"
+ DOCKER_IMAGE: "gcr.io/iree-oss/frontends@sha256:7a7a6d2fce60f3db82bfd2f18316231f9e4662cd9307b079d5adfbb6e119b817"
IREE_DOCKER_WORKDIR: "/usr/src/github/iree"
steps:
diff --git a/build_tools/buildkite/cmake/linux/x86_64/benchmark.yml b/build_tools/buildkite/cmake/linux/x86_64/benchmark.yml
index 435b698..45af655 100644
--- a/build_tools/buildkite/cmake/linux/x86_64/benchmark.yml
+++ b/build_tools/buildkite/cmake/linux/x86_64/benchmark.yml
@@ -15,7 +15,7 @@
--volume="$$PWD:$$IREE_DOCKER_WORKDIR" \
--workdir="$$IREE_DOCKER_WORKDIR" \
--rm \
- gcr.io/iree-oss/frontends@sha256:bad174c580cdefaf435ce31a7df6bdd7f7cb7bfdcdff5d1acf40f630acf85bf5 \
+ gcr.io/iree-oss/frontends@sha256:7a7a6d2fce60f3db82bfd2f18316231f9e4662cd9307b079d5adfbb6e119b817 \
build_tools/cmake/build_linux_benchmark.sh
tar --exclude="*.tar.gz" \
--exclude="*.tgz" \
diff --git a/build_tools/cmake/iree_bytecode_module.cmake b/build_tools/cmake/iree_bytecode_module.cmake
index 166fe31..85c44d7 100644
--- a/build_tools/cmake/iree_bytecode_module.cmake
+++ b/build_tools/cmake/iree_bytecode_module.cmake
@@ -97,12 +97,6 @@
list(APPEND _ARGS "--iree-llvm-sanitize=thread")
endif()
- if(_RULE_FRIENDLY_NAME)
- set(_FRIENDLY_NAME "${_RULE_FRIENDLY_NAME}")
- else()
- get_filename_component(_FRIENDLY_NAME "${_RULE_SRC}" NAME)
- endif()
-
set(_OUTPUT_FILES "${_MODULE_FILE_NAME}")
# Check LLVM static library setting. If the static libary output path is set,
# retrieve the object path and the corresponding header file path.
@@ -115,6 +109,21 @@
list(APPEND _OUTPUT_FILES "${_RULE_STATIC_LIB_PATH}" "${_STATIC_HDR_PATH}")
endif()
+ if(CMAKE_SYSTEM_PROCESSOR STREQUAL "riscv" AND
+ RISCV_CPU STREQUAL "rv64" AND
+ NOT _RULE_FLAGS MATCHES "iree-llvm-target-triple")
+ # RV64 Linux crosscompile toolchain can support iree-compile with
+ # specific CPU flags. Add the llvm flags to support RV64 RVV codegen if
+ # llvm-target-triple is not specified.
+ list(APPEND _RULE_FLAGS ${RISCV64_TEST_DEFAULT_LLVM_FLAGS})
+ endif()
+
+ if(_RULE_FRIENDLY_NAME)
+ set(_FRIENDLY_NAME "${_RULE_FRIENDLY_NAME}")
+ else()
+ get_filename_component(_FRIENDLY_NAME "${_RULE_SRC}" NAME)
+ endif()
+
# Depending on the binary instead of the target here given we might not have
# a target in this CMake invocation when cross-compiling.
add_custom_command(
diff --git a/build_tools/cmake/iree_cc_test.cmake b/build_tools/cmake/iree_cc_test.cmake
index 48746c0..6bfc9fb 100644
--- a/build_tools/cmake/iree_cc_test.cmake
+++ b/build_tools/cmake/iree_cc_test.cmake
@@ -160,6 +160,18 @@
TEST_TMPDIR=${_ANDROID_ABS_DIR}/test_tmpdir
)
set_property(TEST ${_NAME_PATH} PROPERTY ENVIRONMENT ${_ENVIRONMENT_VARS})
+ elseif(CMAKE_SYSTEM_PROCESSOR STREQUAL "riscv" AND RISCV_CPU STREQUAL "rv64")
+ # The test target needs to run within the QEMU emulator for RV64 Linux
+ # crosscompile build or on-device.
+ add_test(
+ NAME
+ ${_NAME_PATH}
+ COMMAND
+ "${IREE_ROOT_DIR}/build_tools/cmake/run_riscv64_test.sh"
+ "$<TARGET_FILE:${_NAME}>"
+ ${_RULE_ARGS}
+ )
+ iree_configure_test(${_NAME_PATH})
else(ANDROID)
add_test(
NAME
diff --git a/build_tools/cmake/iree_check_test.cmake b/build_tools/cmake/iree_check_test.cmake
index 1f1f44a..eae1901 100644
--- a/build_tools/cmake/iree_check_test.cmake
+++ b/build_tools/cmake/iree_check_test.cmake
@@ -6,6 +6,11 @@
include(CMakeParseArguments)
+set(IREE_TARGET_BACKENDS_SUPPORTING_TARGET_CPU_FEATURES
+ llvm-cpu
+ vmvx
+)
+
# Helper for iree_check_test and iree_trace_runner_test.
# Just a thin wrapper around iree_bytecode_module, passing it some
# common flags, including the appropriate --iree-llvm-target-triple in the
@@ -44,22 +49,14 @@
# RV64 Linux crosscompile toolchain can support iree_check_test with
# specific CPU flags. Add the llvm flags to support RV64 RVV codegen if
# llvm-target-triple is not specified.
- list(APPEND _RULE_FLAGS "--iree-llvm-target-triple=riscv64")
- list(APPEND _RULE_FLAGS "--iree-llvm-target-cpu=generic-rv64")
- list(APPEND _RULE_FLAGS "--iree-llvm-target-abi=lp64d")
- if(NOT _RULE_TARGET_CPU_FEATURES)
- list(APPEND _RULE_FLAGS "--iree-llvm-target-cpu-features=+m,+a,+f,+d,+c,+v")
- list(APPEND _RULE_FLAGS "--riscv-v-fixed-length-vector-lmul-max=8")
- list(APPEND _RULE_FLAGS "--riscv-v-vector-bits-min=512")
- endif()
+ list(APPEND _RULE_FLAGS ${RISCV64_TEST_DEFAULT_LLVM_FLAGS})
endif()
if(_RULE_TARGET_CPU_FEATURES)
- if(NOT _RULE_TARGET_BACKEND STREQUAL "llvm-cpu")
+ if(NOT _RULE_TARGET_BACKEND IN_LIST IREE_TARGET_BACKENDS_SUPPORTING_TARGET_CPU_FEATURES)
message(SEND_ERROR "TARGET_CPU_FEATURES should be empty when \
-TARGET_BACKEND is not llvm-cpu. Actual values: \
-TARGET_CPU_FEATURES=${_RULE_TARGET_CPU_FEATURES}, \
-TARGET_BACKEND=${_RULE_TARGET_BACKEND}.")
+TARGET_BACKEND is not in the list (${IREE_TARGET_BACKENDS_SUPPORTING_TARGET_CPU_FEATURES}). Actual values: \
+TARGET_CPU_FEATURES=${_RULE_TARGET_CPU_FEATURES}, TARGET_BACKEND=${_RULE_TARGET_BACKEND}.")
endif()
list(APPEND _RULE_FLAGS "--iree-llvm-target-cpu-features=${_RULE_TARGET_CPU_FEATURES}")
endif()
diff --git a/build_tools/cmake/iree_native_test.cmake b/build_tools/cmake/iree_native_test.cmake
index a9e5bbd..c386283 100644
--- a/build_tools/cmake/iree_native_test.cmake
+++ b/build_tools/cmake/iree_native_test.cmake
@@ -115,23 +115,16 @@
set_property(TEST ${_TEST_NAME} PROPERTY ENVIRONMENT ${_ENVIRONMENT_VARS})
elseif(CMAKE_SYSTEM_PROCESSOR STREQUAL "riscv" AND RISCV_CPU STREQUAL "rv64")
# The test target needs to run within the QEMU emulator for RV64 Linux
- # crosscompile build. A QEMU 64 Linux emulator must be available at the
- # path specified by the `QEMU_RV64_BIN` environment variable.
- if(DEFINED ENV{QEMU_RV64_BIN})
- add_test(
- NAME
- ${_TEST_NAME}
- COMMAND
- "$ENV{QEMU_RV64_BIN}"
- -cpu rv64,x-v=true,x-k=true,vlen=512,elen=64,vext_spec=v1.0
- -L "${RISCV_TOOLCHAIN_ROOT}/sysroot"
- "$<TARGET_FILE:${_SRC_TARGET}>"
- ${_RULE_ARGS}
- )
- iree_configure_test(${_TEST_NAME})
- else()
- message(SEND_ERROR "QEMU not found.")
- endif()
+ # crosscompile build or on-device.
+ add_test(
+ NAME
+ ${_TEST_NAME}
+ COMMAND
+ "${IREE_ROOT_DIR}/build_tools/cmake/run_riscv64_test.sh"
+ "$<TARGET_FILE:${_SRC_TARGET}>"
+ ${_RULE_ARGS}
+ )
+ iree_configure_test(${_TEST_NAME})
else()
add_test(
NAME
diff --git a/build_tools/cmake/iree_trace_runner_test.cmake b/build_tools/cmake/iree_trace_runner_test.cmake
index 49c985a..634083d 100644
--- a/build_tools/cmake/iree_trace_runner_test.cmake
+++ b/build_tools/cmake/iree_trace_runner_test.cmake
@@ -311,7 +311,7 @@
foreach(_INDEX RANGE "${_MAX_INDEX}")
list(GET _RULE_TARGET_BACKENDS ${_INDEX} _TARGET_BACKEND)
list(GET _RULE_DRIVERS ${_INDEX} _DRIVER)
- if(_TARGET_BACKEND STREQUAL "llvm-cpu" AND _RULE_TARGET_CPU_FEATURES_VARIANTS)
+ if((_TARGET_BACKEND IN_LIST IREE_TARGET_BACKENDS_SUPPORTING_TARGET_CPU_FEATURES) AND _RULE_TARGET_CPU_FEATURES_VARIANTS)
set(_TARGET_CPU_FEATURES_VARIANTS "${_RULE_TARGET_CPU_FEATURES_VARIANTS}")
else()
set(_TARGET_CPU_FEATURES_VARIANTS "default")
diff --git a/build_tools/cmake/riscv.toolchain.cmake b/build_tools/cmake/riscv.toolchain.cmake
index a715469..8575a83 100644
--- a/build_tools/cmake/riscv.toolchain.cmake
+++ b/build_tools/cmake/riscv.toolchain.cmake
@@ -50,6 +50,14 @@
set(CMAKE_SYSTEM_LIBRARY_PATH "${RISCV_TOOLCHAIN_ROOT}/sysroot/usr/lib")
set(RISCV_COMPILER_FLAGS "${RISCV_COMPILER_FLAGS} -march=rv64gc -mabi=lp64d")
set(RISCV_LINKER_FLAGS "${RISCV_LINKER_FLAGS} -lstdc++ -lpthread -lm -ldl")
+ set(RISCV64_TEST_DEFAULT_LLVM_FLAGS
+ "--iree-llvm-target-triple=riscv64"
+ "--iree-llvm-target-cpu=generic-rv64"
+ "--iree-llvm-target-abi=lp64d"
+ "--iree-llvm-target-cpu-features=+m,+a,+f,+d,+c,+v"
+ "--riscv-v-fixed-length-vector-lmul-max=8"
+ "--riscv-v-vector-bits-min=512"
+ CACHE INTERNAL "Default llvm codegen flags for testing purposes")
elseif(RISCV_CPU STREQUAL "rv32-baremetal")
set(CMAKE_SYSTEM_NAME Generic)
set(CMAKE_CROSSCOMPILING ON CACHE BOOL "")
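For context, `RISCV64_TEST_DEFAULT_LLVM_FLAGS` centralizes the default RV64 codegen flags that iree_check_test.cmake previously hardcoded. A hedged sketch of roughly what a test compile step looks like once a rule appends them (the input/output paths are illustrative):

```
iree-compile \
  --iree-hal-target-backends=llvm-cpu \
  --iree-llvm-target-triple=riscv64 \
  --iree-llvm-target-cpu=generic-rv64 \
  --iree-llvm-target-abi=lp64d \
  --iree-llvm-target-cpu-features=+m,+a,+f,+d,+c,+v \
  --riscv-v-fixed-length-vector-lmul-max=8 \
  --riscv-v-vector-bits-min=512 \
  input.mlir -o output.vmfb
```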
diff --git a/build_tools/cmake/run_riscv64_test.sh b/build_tools/cmake/run_riscv64_test.sh
new file mode 100755
index 0000000..833cf5a
--- /dev/null
+++ b/build_tools/cmake/run_riscv64_test.sh
@@ -0,0 +1,23 @@
+#!/bin/bash
+
+# Copyright 2022 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+# Wrapper script to run an artifact on a RISC-V 64-bit Linux device.
+# This script checks whether the QEMU emulator is set and uses either the
+# emulator or the actual device to run the cross-compiled RISC-V 64-bit Linux
+# artifacts.
+
+set -x
+set -e
+
+# A QEMU 64 Linux emulator must be available at the path specified by the
+# `QEMU_RV64_BIN` environment variable to run the artifacts under the emulator.
+if [[ -n "${QEMU_RV64_BIN}" ]]; then
+  "${QEMU_RV64_BIN}" -cpu rv64,x-v=true,x-k=true,vlen=512,elen=64,vext_spec=v1.0 \
+    -L "${RISCV_RV64_LINUX_TOOLCHAIN_ROOT}/sysroot" "$@"
+fi
+
+# TODO(dcaballe): Add on-device run commands.
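A sketch of how the CMake test rules above are expected to invoke this wrapper; the environment paths and the test binary are illustrative, not taken from the diff:

```
# Values normally come from the toolchain/CI environment.
export QEMU_RV64_BIN="${HOME}/qemu/qemu-riscv64"
export RISCV_RV64_LINUX_TOOLCHAIN_ROOT="${HOME}/toolchains/riscv64-linux"
# The generated test command is the wrapper plus the test binary and its args:
build_tools/cmake/run_riscv64_test.sh \
  build-riscv/tools/iree-run-module --help
```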
diff --git a/build_tools/cmake/test_riscv64.sh b/build_tools/cmake/test_riscv64.sh
index f6cfc4e..ee0bc77 100755
--- a/build_tools/cmake/test_riscv64.sh
+++ b/build_tools/cmake/test_riscv64.sh
@@ -74,9 +74,13 @@
${PYTHON_BIN} "${ROOT_DIR}/third_party/llvm-project/llvm/utils/lit/lit.py" \
-v --path "${LLVM_BIN_DIR}" "${ROOT_DIR}/tests/riscv64"
-# Test e2e models. Excluding mobilebert for now.
-ctest --test-dir ${BUILD_RISCV_DIR}/tests/e2e/models -R llvm-cpu_local-task_mobilenet -E bert
-# Test all tosa ops
-ctest --test-dir ${BUILD_RISCV_DIR}/tests/e2e/tosa_ops -R check_llvm-cpu_local-task
-# Test all xla ops except fp16, which is not supported properly
-ctest --test-dir ${BUILD_RISCV_DIR}/tests/e2e/xla_ops -R check_llvm-cpu_local-task -E fp16
+# Test runtime unit tests
+ctest --test-dir ${BUILD_RISCV_DIR}/runtime/ --timeout 900 --output-on-failure \
+ --no-tests=error --label-exclude \
+ '(^nokokoro$|^driver=vulkan$|^driver=cuda$|^vulkan_uses_vk_khr_shader_float16_int8$|^requires-filesystem$|^requires-dtz$)'
+
+# Test e2e models. Excluding mobilebert and fp16 for now.
+ctest --test-dir ${BUILD_RISCV_DIR}/tests/e2e --timeout 900 --output-on-failure \
+ --no-tests=error --label-exclude \
+ '(^nokokoro$|^driver=vulkan$|^driver=cuda$|^vulkan_uses_vk_khr_shader_float16_int8$)' \
+ -E '(bert|fp16)'
diff --git a/build_tools/docker/android/Dockerfile b/build_tools/docker/android/Dockerfile
index e54d11d..87bfd6a 100644
--- a/build_tools/docker/android/Dockerfile
+++ b/build_tools/docker/android/Dockerfile
@@ -7,13 +7,13 @@
# An image for cross-compiling IREE towards Android.
FROM gcr.io/iree-oss/base@sha256:5d43683c6b50aebe1fca6c85f2012f3b0fa153bf4dd268e8767b619b1891423a
-ARG NDK_VERSION=r21d
+ARG NDK_VERSION=r25b
WORKDIR /install-ndk
ENV ANDROID_NDK "/usr/src/android-ndk-${NDK_VERSION}"
-RUN wget -q "https://dl.google.com/android/repository/android-ndk-${NDK_VERSION?}-linux-x86_64.zip" \
- && unzip -q "android-ndk-${NDK_VERSION?}-linux-x86_64.zip" -d /usr/src/ \
+RUN wget -q "https://dl.google.com/android/repository/android-ndk-${NDK_VERSION?}-linux.zip" \
+ && unzip -q "android-ndk-${NDK_VERSION?}-linux.zip" -d /usr/src/ \
&& rm -rf /install-ndk
WORKDIR /
diff --git a/build_tools/docker/docker_run.sh b/build_tools/docker/docker_run.sh
index f03beab..8ef2db9 100755
--- a/build_tools/docker/docker_run.sh
+++ b/build_tools/docker/docker_run.sh
@@ -22,7 +22,7 @@
# Make the source repository available and launch containers in that
# directory.
DOCKER_RUN_ARGS=(
- --volume="${DOCKER_HOST_WORKDIR}:${DOCKER_CONTAINER_WORKDIR}"
+ --mount="type=bind,source=${DOCKER_HOST_WORKDIR},dst=${DOCKER_CONTAINER_WORKDIR}"
--workdir="${DOCKER_CONTAINER_WORKDIR}"
)
@@ -53,17 +53,17 @@
# want these scripts to be runnable locally for debugging.
# Instead we dump the results of `getent` to some fake files.
local fake_etc_dir="${DOCKER_HOST_TMPDIR}/fake_etc"
- mkdir -p "${fake_etc_dir?}"
+ mkdir -p "${fake_etc_dir}"
- local fake_group="${fake_etc_dir?}/group"
- local fake_passwd="${fake_etc_dir?}/passwd"
+ local fake_group="${fake_etc_dir}/group"
+ local fake_passwd="${fake_etc_dir}/passwd"
- getent group > "${fake_group?}"
- getent passwd > "${fake_passwd?}"
+ getent group > "${fake_group}"
+ getent passwd > "${fake_passwd}"
DOCKER_RUN_ARGS+=(
- --volume="${fake_group?}:/etc/group:ro"
- --volume="${fake_passwd?}:/etc/passwd:ro"
+ --mount="type=bind,src=${fake_group},dst=/etc/group,readonly"
+ --mount="type=bind,src=${fake_passwd},dst=/etc/passwd,readonly"
)
@@ -84,19 +84,22 @@
mkdir -p "${fake_home_dir}"
DOCKER_RUN_ARGS+=(
- --volume="${fake_home_dir?}:${HOME?}"
+ --mount="type=bind,src=${fake_home_dir},dst=${HOME}"
)
- # Make gcloud credentials available. This isn't necessary when running in
- # GCE but enables using this script locally with remote caching.
- DOCKER_RUN_ARGS+=(
- --volume="${HOME?}/.config/gcloud:${HOME?}/.config/gcloud:ro"
- )
+ # Make gcloud credentials available if they are present. This isn't
+ # necessary when running in GCE but enables using this script locally with
+ # remote caching.
+ if [[ -d "${HOME}/.config/gcloud" ]]; then
+ DOCKER_RUN_ARGS+=(
+ --mount="type=bind,src=${HOME}/.config/gcloud,dst=${HOME}/.config/gcloud,readonly"
+ )
+ fi
# Give the container a ramdisk and set the Bazel sandbox base to point to
# it. This helps a lot with Bazel getting IO bound.
DOCKER_RUN_ARGS+=(
- --tmpfs /dev/shm
+ --mount="type=tmpfs,dst=/dev/shm"
--env SANDBOX_BASE=/dev/shm
)
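One motivation for moving from `--volume` to `--mount` (assuming standard Docker CLI behavior): when the host source path is missing, `--volume` silently creates it as a root-owned directory, while `--mount` fails with a "bind source path does not exist" error. That failure mode is louder but safer, and it is why the gcloud bind mount above is now guarded by an existence check:

```
# Fails loudly if ~/.config/gcloud does not exist, instead of creating a
# root-owned empty directory the way --volume would:
docker run --rm \
  --mount="type=bind,src=${HOME}/.config/gcloud,dst=/gcloud,readonly" \
  ubuntu:22.04 ls /gcloud
```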
diff --git a/build_tools/docker/frontends-nvidia/Dockerfile b/build_tools/docker/frontends-nvidia/Dockerfile
index ed2bea8..d20ff05 100644
--- a/build_tools/docker/frontends-nvidia/Dockerfile
+++ b/build_tools/docker/frontends-nvidia/Dockerfile
@@ -8,7 +8,7 @@
# The NVidia drivers need to *exactly* match between the host machine and the
# docker image.
-FROM gcr.io/iree-oss/frontends@sha256:bad174c580cdefaf435ce31a7df6bdd7f7cb7bfdcdff5d1acf40f630acf85bf5
+FROM gcr.io/iree-oss/frontends@sha256:7a7a6d2fce60f3db82bfd2f18316231f9e4662cd9307b079d5adfbb6e119b817
# We use .deb files that we host because we have to pin the version exactly to
match the host machine and packages routinely disappear from the Ubuntu
diff --git a/build_tools/docker/frontends-swiftshader/Dockerfile b/build_tools/docker/frontends-swiftshader/Dockerfile
index 13cc7b4..299d55f 100644
--- a/build_tools/docker/frontends-swiftshader/Dockerfile
+++ b/build_tools/docker/frontends-swiftshader/Dockerfile
@@ -4,7 +4,7 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-FROM gcr.io/iree-oss/frontends@sha256:bad174c580cdefaf435ce31a7df6bdd7f7cb7bfdcdff5d1acf40f630acf85bf5
+FROM gcr.io/iree-oss/frontends@sha256:7a7a6d2fce60f3db82bfd2f18316231f9e4662cd9307b079d5adfbb6e119b817
COPY --from=gcr.io/iree-oss/swiftshader@sha256:5027d56cdfee743d956bffd035668f7784166a486c48c74b42e5882cb0c289bf \
/swiftshader /swiftshader
diff --git a/build_tools/docker/frontends/Dockerfile b/build_tools/docker/frontends/Dockerfile
index 990a56b..e844c98 100644
--- a/build_tools/docker/frontends/Dockerfile
+++ b/build_tools/docker/frontends/Dockerfile
@@ -4,7 +4,7 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-FROM gcr.io/iree-oss/android@sha256:9bc723fc707a18bd0c1be9c12e01ea5bb7c7d77f607427879e10ffcffd7d2bb5
+FROM gcr.io/iree-oss/android@sha256:76c2a52dcd6d07601227b965ac87d021c1d2d5e2d01f46ad58da28c89267f2ab
WORKDIR /install-kws
diff --git a/build_tools/docker/prod_digests.txt b/build_tools/docker/prod_digests.txt
index 56a84ae..004071c 100644
--- a/build_tools/docker/prod_digests.txt
+++ b/build_tools/docker/prod_digests.txt
@@ -1,12 +1,12 @@
gcr.io/iree-oss/base@sha256:5d43683c6b50aebe1fca6c85f2012f3b0fa153bf4dd268e8767b619b1891423a
gcr.io/iree-oss/swiftshader@sha256:5027d56cdfee743d956bffd035668f7784166a486c48c74b42e5882cb0c289bf
gcr.io/iree-oss/samples@sha256:ea1bfce1c853e0b3d1afad094086535f903950dc81810024c4cf6347d90aea8a
-gcr.io/iree-oss/frontends@sha256:bad174c580cdefaf435ce31a7df6bdd7f7cb7bfdcdff5d1acf40f630acf85bf5
-gcr.io/iree-oss/frontends-nvidia@sha256:e934ed09e9e60c28ebe11a02f37a993dd975db40118d410c4279d0fa2d4e6b9a
-gcr.io/iree-oss/frontends-swiftshader@sha256:3090418a8d8a64c356d35eff285af32570a72f41127aa123209c1562f57abb01
+gcr.io/iree-oss/frontends@sha256:7a7a6d2fce60f3db82bfd2f18316231f9e4662cd9307b079d5adfbb6e119b817
+gcr.io/iree-oss/frontends-nvidia@sha256:28cd43f36b1ca0633bbd915911abe6d22b4aa16093f074e87016305322a0eba1
+gcr.io/iree-oss/frontends-swiftshader@sha256:3d5b879672d7f302124ab3d1aa533a6949bd0adfc176884177844ac6767e23e9
gcr.io/iree-oss/gradle-android@sha256:d9d0f880c3ac995b9e8a23bbf8079b80f6842851654016c5f362c747c09aaf93
gcr.io/iree-oss/riscv@sha256:720bc0215d8462ea14352edc22710a6ce4c0c1daff581d179dd173885f1d8a35
gcr.io/iree-oss/nvidia@sha256:7c2f56db65e656c15e6c96b5812a8275dd53c82bf41221192f9ba8a451aad870
gcr.io/iree-oss/emscripten@sha256:afa6aab07d753631c37a935695e69165d8f7598dec249b31d459b046593ccd56
-gcr.io/iree-oss/android@sha256:9bc723fc707a18bd0c1be9c12e01ea5bb7c7d77f607427879e10ffcffd7d2bb5
+gcr.io/iree-oss/android@sha256:76c2a52dcd6d07601227b965ac87d021c1d2d5e2d01f46ad58da28c89267f2ab
gcr.io/iree-oss/manylinux2014_x86_64-release@sha256:b09c10868f846308bad2eab253a77d0a3f097816c40342bc289d8e62509bc5f9
diff --git a/build_tools/github_actions/runner/gcp/create_image.sh b/build_tools/github_actions/runner/gcp/create_image.sh
new file mode 100755
index 0000000..a1006e2
--- /dev/null
+++ b/build_tools/github_actions/runner/gcp/create_image.sh
@@ -0,0 +1,163 @@
+#!/bin/bash
+
+# Copyright 2022 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+set -euo pipefail
+
+TIME_STRING="$(date +%Y-%m-%d-%s)"
+
+INSTANCE_NAME="${INSTANCE_NAME:-github-runner-template-cpu-${TIME_STRING}}"
+IMAGE_NAME="${IMAGE_NAME:-github-runner-cpu-${TIME_STRING}}"
+ZONE="${ZONE:-us-central1-a}"
+PROJECT=iree-oss
+BASE_IMAGE="${BASE_IMAGE:-projects/ubuntu-os-cloud/global/images/ubuntu-2204-jammy-v20220902}"
+# It takes a little bit to bring up ssh on the instance. I haven't found a
+# better way to wait for this than just polling.
+MAX_IP_ATTEMPTS=5
+MAX_SSH_ATTEMPTS=10
+MAX_SCP_ATTEMPTS=5
+
+SCRIPT_DIR="$(dirname -- "$( readlink -f -- "$0"; )")";
+
+CREATE_INSTANCE_ARGS=(
+ "${INSTANCE_NAME}"
+ --project=iree-oss
+ --zone="${ZONE}"
+ --machine-type=e2-medium
+ # `address=''` indicates an ephemeral IP. This *shouldn't* be necessary here,
+ # as the gcloud docs say that this is the default, but in fact if you leave it
+ # off the VM gets no external IP and is impossible to SSH into. This knowledge
+ # was hard won.
+ --network-interface=network=default,address='',network-tier=PREMIUM
+ --maintenance-policy=MIGRATE
+ --provisioning-model=STANDARD
+ --no-service-account
+ --no-scopes
+ --create-disk="boot=yes,device-name=${INSTANCE_NAME},image=${BASE_IMAGE},mode=rw,size=10,type=projects/${PROJECT}/zones/${ZONE}/diskTypes/pd-balanced"
+ --no-shielded-secure-boot
+ --shielded-vtpm
+ --shielded-integrity-monitoring
+ --reservation-affinity=any
+ --metadata-from-file=startup-script="${SCRIPT_DIR}/image_setup.sh"
+)
+
+function get_ip() {
+ gcloud compute instances describe \
+ "${INSTANCE_NAME}" \
+ --zone="${ZONE}" \
+ --format='value(networkInterfaces[0].accessConfigs[0].ip)'
+}
+
+function ssh_ping() {
+ gcloud compute ssh "${INSTANCE_NAME}" \
+ --zone="${ZONE}" \
+ --command=":"
+}
+
+function wait_for_ip() {
+ local -i max_attempts="$1"
+ local -i failed_attempts=0
+  while (( failed_attempts <= max_attempts )) && [[ "$(get_ip)" == "" ]]; do
+ echo -n '.'
+ failed_attempts="$(( failed_attempts+1 ))"
+ sleep 1
+ done
+
+ if (( failed_attempts > max_attempts )); then
+ echo "Instance was never assigned an external IP. Aborting"
+ exit 1
+ fi
+}
+
+function wait_for_ssh() {
+ local -i max_attempts="$1"
+ local -i failed_attempts=0
+  local ssh_output=""
+ while (( failed_attempts <= max_attempts )) && ! ssh_output="$(ssh_ping 2>&1)"; do
+ echo -n '.'
+ failed_attempts="$(( failed_attempts+1 ))"
+ sleep 1
+ done
+
+ if (( failed_attempts > max_attempts )); then
+ echo "Failed to connect to instance via ssh. Output from ssh command:"
+ echo "${ssh_output}"
+ exit 1
+ fi
+}
+
+function create_image() {
+ echo "Creating instance for boot disk"
+ (set -x; gcloud compute instances create "${CREATE_INSTANCE_ARGS[@]}")
+
+  # We could use only the ssh check below, but it's much nicer to know why an
+  # instance isn't responsive, and this is something we can check first.
+ echo "Waiting for instance to start up"
+ wait_for_ip "${MAX_IP_ATTEMPTS}"
+ wait_for_ssh "${MAX_SSH_ATTEMPTS}"
+
+ local log_file="$(mktemp)"
+ touch "${log_file}"
+
+ echo ""
+ echo "Streaming startup logs from instance"
+ tail -f "${log_file}" &
+ local -i failed_scp_attempts=0
+ local last_line=""
+ local scp_output=""
+ # Is waiting for a certain line in the logs kind of hacky? yes
+ # Is there a better way to do it? probably
+ # Does the better way involve a bunch of fiddling about? also probably
+ while (( failed_scp_attempts < MAX_SCP_ATTEMPTS )) && [[ "${last_line}" != "Setup complete" ]]; do
+ ret=0
+ scp_output="$(gcloud compute scp \
+ --zone="${ZONE}" \
+ "${INSTANCE_NAME}:/startup.log" \
+ "${log_file}" 2>&1)" || ret=$?
+ if (( ret != 0 )); then
+ failed_scp_attempts="$(( failed_scp_attempts+1 ))"
+ sleep 1
+ else
+ last_line="$(tail --lines=1 "${log_file}")"
+ fi
+ done
+
+ if (( failed_scp_attempts >= MAX_SCP_ATTEMPTS )); then
+ echo "Was unable to copy logs from instance. Output from scp:"
+ echo "${scp_output}"
+ exit 1
+ fi
+
+ if [[ "${last_line}" != "Setup complete" ]]; then
+ echo "Instance did not complete its setup. Please check the logs above."
+ exit 1
+ fi
+
+ echo "Startup finished successfully."
+
+ echo "Deleting log file"
+ gcloud compute ssh "${INSTANCE_NAME}" --zone="${ZONE}" \
+ --no-user-output-enabled \
+ --command="sudo rm /startup.log"
+
+ echo "Shutting down instance"
+ # This actually does things synchronously, so we don't need our own loop to
+ # wait.
+ gcloud compute instances stop "${INSTANCE_NAME}" --zone="${ZONE}"
+
+ echo "Creating disk image"
+ gcloud compute images create "${IMAGE_NAME}" \
+ --source-disk="${INSTANCE_NAME}" \
+ --source-disk-zone="${ZONE}"
+
+ echo "Deleting instance"
+ gcloud compute instances delete "${INSTANCE_NAME}" --zone="${ZONE}" --quiet
+
+ echo "Successfully created image: ${IMAGE_NAME}"
+}
+
+create_image
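Everything in this script is parameterized through environment variables with timestamped defaults, so a typical run only needs to override a few of them (the values below are illustrative):

```
ZONE=us-west1-b \
IMAGE_NAME=github-runner-cpu-test \
  ./build_tools/github_actions/runner/gcp/create_image.sh
```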
diff --git a/build_tools/github_actions/runner/gcp/image_setup.sh b/build_tools/github_actions/runner/gcp/image_setup.sh
new file mode 100644
index 0000000..f942241
--- /dev/null
+++ b/build_tools/github_actions/runner/gcp/image_setup.sh
@@ -0,0 +1,157 @@
+#!/bin/bash
+
+# Copyright 2022 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+# This is the series of commands run on a VM from a fresh image in order to
+# set up the disk to be used as a boot image. This script must be run as root.
+
+set -o verbose # Print all command lines literally as they are read
+set -o xtrace # Print all commands after they are expanded
+set -o errexit # Exit if any command fails
+set -o errtrace # make ERR trap inherit
+set -o pipefail # return error if any part of a pipe errors
+set -o nounset # error if an undefined variable is used
+
+
+function startup() {
+ #################################### APT #####################################
+ # Disable apt prompts
+ export DEBIAN_FRONTEND="noninteractive"
+
+ # Disable automatic updates and upgrades. These are ephemeral machines. We don't
+  # want the latency or inconsistency of automatic updates.
+ systemctl stop apt-daily.timer
+ systemctl disable apt-daily.timer
+ systemctl disable apt-daily.service
+ systemctl stop apt-daily-upgrade.timer
+ systemctl disable apt-daily-upgrade.timer
+ systemctl disable apt-daily-upgrade.service
+
+ # Don't install documentation (except copyrights) since this is a CI system.
+ cat > /etc/dpkg/dpkg.cfg.d/github-actions <<EOF
+force-all
+no-pager
+# don't install docs
+path-exclude /usr/share/doc/*
+path-exclude /usr/share/man/*
+path-exclude /usr/share/groff/*
+path-exclude /usr/share/info/*
+# keep copyright files for legal reasons
+path-include /usr/share/doc/*/copyright
+EOF
+
+ # Provide default apt options like --assume-yes and --quiet since this is
+ # designed to run on CI.
+ cat > /etc/apt/apt.conf.d/github-actions <<EOF
+APT {
+ Install-Recommends "false";
+ HideAutoRemove "true";
+}
+Aptitude {
+ CmdLine {
+ Assume-Yes "true";
+ }
+}
+Acquire {
+ Retries "5";
+}
+DPkg {
+ Use-Pty "0";
+ Options {
+ "--force-confdef";
+ "--force-confnew";
+ "--force-confold";
+ }
+}
+Quiet "2";
+EOF
+
+ # Install apt-fast for parallel apt package installation.
+ add-apt-repository -y ppa:apt-fast/stable
+ apt-get update
+ apt-get install apt-fast
+ apt-get upgrade
+ apt-get dist-upgrade
+ apt-get full-upgrade
+ # Install common deps.
+ apt-get install \
+ apt-transport-https \
+ aria2 \
+ ca-certificates \
+ curl \
+ git \
+ gnupg2 \
+ jq \
+ lsb-release \
+ software-properties-common
+
+ ########################### Create the runner user ###########################
+
+ # GCE "helpfully" creates users for apparently any account that has ever
+ # logged in on any VM. Delete it if it's there.
+ userdel --force --remove runner || true
+ adduser --system --group "runner"
+ groupadd docker
+ usermod --append --groups docker runner
+ usermod --append --groups sudo runner
+ groups runner # Print out the groups of runner to verify this worked
+
+ echo "enabling passwordless sudo for runner user"
+ echo "runner ALL=(ALL) NOPASSWD:ALL" > /etc/sudoers.d/runner
+
+
+
+ ############################### Install Docker ###############################
+
+ # Remove Docker stuff that may already be installed, proceeding if they're not.
+ apt-get remove containerd docker docker-engine docker.io moby-engine moby-cli runc || true
+
+ # Install the latest Docker
+ curl -sfSL https://download.docker.com/linux/ubuntu/gpg | gpg --dearmor -o /usr/share/keyrings/docker-archive-keyring.gpg
+ echo \
+ "deb [arch=amd64 signed-by=/usr/share/keyrings/docker-archive-keyring.gpg] https://download.docker.com/linux/ubuntu \
+ $(lsb_release -cs) stable" | tee /etc/apt/sources.list.d/docker.list
+ apt-get update
+ apt-get install docker-ce docker-ce-cli containerd.io
+
+ # Enable docker.service.
+ sudo systemctl enable docker.service
+ sudo systemctl start docker.service
+ sudo systemctl enable containerd.service
+ sudo systemctl start containerd.service
+
+ # Docker daemon takes time to come up after installing.
+ for i in $(seq 1 30); do
+ if docker info; then
+ break
+    fi
+    sleep 1
+  done
+
+ # Make sure the runner user can use docker
+ runuser --user runner -- docker ps
+
+ ################################### Cleanup ####################################
+
+ apt-get clean
+ rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/*
+ rm -rf /var/lib/dhcp/*
+
+ # Delete unnecessary log files
+ find /var/log -type f -regex ".*\.gz$" -delete
+ find /var/log -type f -regex ".*\.[0-9]$" -delete
+
+ # Clear all journal files
+ journalctl --rotate --vacuum-time=1s
+
+ # And clear others
+ find /var/log/ -type f -exec truncate -s 0 {} \;
+
+  # This specific log line is load-bearing, as it's referenced in create_image.sh
+ echo "Setup complete"
+}
+
+startup 2>&1 | tee /startup.log
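Because create_image.sh passes this file via `--metadata-from-file=startup-script=`, GCE runs it as root on first boot, and the polling loop in create_image.sh watches `/startup.log` for the final "Setup complete" line. For debugging, one way to re-run the script by hand on the instance, using GCE's standard metadata endpoint:

```
# Fetch the startup script from the metadata server and re-run it as root:
curl -s -H "Metadata-Flavor: Google" \
  "http://metadata.google.internal/computeMetadata/v1/instance/attributes/startup-script" \
  | sudo bash
```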
diff --git a/build_tools/scripts/check_path_lengths.py b/build_tools/scripts/check_path_lengths.py
new file mode 100755
index 0000000..645ba7d
--- /dev/null
+++ b/build_tools/scripts/check_path_lengths.py
@@ -0,0 +1,99 @@
+#!/usr/bin/env python3
+# Copyright 2022 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+# This scans the IREE source tree for long path lengths, which are problematic
+# on Windows: https://docs.microsoft.com/en-us/windows/win32/fileio/maximum-file-path-limitation
+#
+# We ultimately care that the build system is happy, but CMake on Windows in
+# particular does not actually give early or easy to understand error messages,
+# and developers/CI using Linux may still want to see warnings. We'll use
+# relative directory path length as a reasonable heuristic for "will the build
+# system be happy?", since CMake tends to create paths like this:
+# `iree/compiler/.../Foo/CMakeFiles/iree_compiler_Foo_Foo.objects.dir/bar.obj`.
+# Note that 'Foo' appears three times in that path, so that's typically the best
+# place to trim characters (and not file names).
+#
+# To check that all relative paths are shorter than the default limit:
+# python check_path_lengths.py
+#
+# To check that all relative paths are shorter than a custom limit:
+# python check_path_lengths.py --limit=50
+
+import argparse
+import os
+import pathlib
+import sys
+
+
+def parse_arguments():
+ parser = argparse.ArgumentParser(description="Path length checker")
+ # The default limit was selected based on repository state when this script
+ # was added. If the max path length decreases, consider lowering this too.
+ parser.add_argument("--limit",
+ help="Path length limit (inclusive)",
+ type=int,
+ default=75)
+ parser.add_argument(
+ "--include_tests",
+ help=
+ "Includes /test directories. False by default as these don't usually generate problematic files during the build",
+ action="store_true",
+ default=False)
+ parser.add_argument("--verbose",
+ help="Outputs detailed information about path lengths",
+ action="store_true",
+ default=False)
+ args = parser.parse_args()
+ return args
+
+
+def main(args):
+ repo_root = pathlib.Path(__file__).parent.parent.parent
+
+ # Just look at the compiler directory for now, since it has historically had
+ # by far the longest paths.
+ walk_root = os.path.join(repo_root, "compiler")
+
+ longest_path_length = -1
+ long_paths = []
+ short_paths = []
+ for dirpath, dirnames, _ in os.walk(walk_root):
+ # Don't descend into test directories, since they typically don't generate
+ # object files or binaries that could trip up the build system.
+ if not args.include_tests and "test" in dirnames:
+ dirnames.remove("test")
+
+ path = pathlib.Path(dirpath).relative_to(repo_root).as_posix()
+ if len(path) > args.limit:
+ long_paths.append(path)
+ else:
+ short_paths.append(path)
+ longest_path_length = max(longest_path_length, len(path))
+ long_paths.sort(key=len)
+ short_paths.sort(key=len)
+
+ if args.verbose and short_paths:
+ print(f"These paths are shorter than the limit of {args.limit} characters:")
+ for path in short_paths:
+ print("{:3d}, {}".format(len(path), path))
+
+ if long_paths:
+ print(f"These paths are longer than the limit of {args.limit} characters:")
+ for path in long_paths:
+ print("{:3d}, {}".format(len(path), path))
+ print(
+ f"Error: {len(long_paths)} source paths are longer than {args.limit} characters."
+ )
+ print(" Long paths can be problematic when building on Windows.")
+ print(" Please look at the output above and trim the paths.")
+ sys.exit(1)
+ else:
+ print(f"All path lengths are under the limit of {args.limit} characters.")
+
+
+if __name__ == "__main__":
+ main(parse_arguments())
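Two details of the checker are worth calling out: the pruning works because `os.walk` honors in-place edits to `dirnames`, and the measured quantity is the POSIX-style directory path relative to the repository root. A condensed sketch of that core, with a toy `limit` (the real script above also sorts results and reports short paths under `--verbose`):

```python
import os

def long_relative_paths(root: str, limit: int = 75):
    """Yield (length, relpath) for directories whose relative path exceeds limit."""
    for dirpath, dirnames, _ in os.walk(root):
        if "test" in dirnames:
            dirnames.remove("test")  # os.walk skips pruned entries entirely
        rel = os.path.relpath(dirpath, root).replace(os.sep, "/")
        if len(rel) > limit:
            yield len(rel), rel

if __name__ == "__main__":
    for length, path in sorted(long_relative_paths(".", limit=40)):
        print(f"{length:3d} {path}")
```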
diff --git a/build_tools/scripts/integrate/README.md b/build_tools/scripts/integrate/README.md
index 43788d7..c035a57 100644
--- a/build_tools/scripts/integrate/README.md
+++ b/build_tools/scripts/integrate/README.md
@@ -351,8 +351,8 @@
An example from a log:
```
-[18:30:23 UTC] docker run --volume=/tmpfs/src/github/iree:/tmpfs/src/github/iree --workdir=/tmpfs/src/github/iree --rm --user=1003:1004 --volume=/tmpfs/fake_etc/group:/etc/group:ro --volume=/tmpfs/fake_etc/passwd:/etc/passwd:ro --volume=/tmpfs/fake_home:/home/kbuilder --volume=/home/kbuilder/.config/gcloud:/home/kbuilder/.config/gcloud:ro gcr.io/iree-oss/frontends-swiftshader@sha256:3090418a8d8a64c356d35eff285af32570a72f41127aa123209c1562f57abb01 build_tools/kokoro/gcp_ubuntu/bazel/linux/x86-swiftshader/core/build.sh
-Unable to find image 'gcr.io/iree-oss/frontends-swiftshader@sha256:3090418a8d8a64c356d35eff285af32570a72f41127aa123209c1562f57abb01' locally
+[18:30:23 UTC] docker run --volume=/tmpfs/src/github/iree:/tmpfs/src/github/iree --workdir=/tmpfs/src/github/iree --rm --user=1003:1004 --volume=/tmpfs/fake_etc/group:/etc/group:ro --volume=/tmpfs/fake_etc/passwd:/etc/passwd:ro --volume=/tmpfs/fake_home:/home/kbuilder --volume=/home/kbuilder/.config/gcloud:/home/kbuilder/.config/gcloud:ro gcr.io/iree-oss/frontends-swiftshader@sha256:3d5b879672d7f302124ab3d1aa533a6949bd0adfc176884177844ac6767e23e9 build_tools/kokoro/gcp_ubuntu/bazel/linux/x86-swiftshader/core/build.sh
+Unable to find image 'gcr.io/iree-oss/frontends-swiftshader@sha256:3d5b879672d7f302124ab3d1aa533a6949bd0adfc176884177844ac6767e23e9' locally
sha256:aeb8de9fb7af3913d385ec6b274320197d61aa7bc51a6e8bc0deba644da3e405: Pulling from iree-oss/frontends-swiftshader
```
@@ -360,7 +360,7 @@
you have the same environment as the CI bot and it requires less local setup.
```
-docker run --interactive --tty --rm --volume=$PWD:/src/iree --workdir=/src/iree gcr.io/iree-oss/frontends-swiftshader@sha256:3090418a8d8a64c356d35eff285af32570a72f41127aa123209c1562f57abb01
+docker run --interactive --tty --rm --volume=$PWD:/src/iree --workdir=/src/iree gcr.io/iree-oss/frontends-swiftshader@sha256:3d5b879672d7f302124ab3d1aa533a6949bd0adfc176884177844ac6767e23e9
```
To repro failures in `iree/e2e/`:
diff --git a/build_tools/scripts/lint.sh b/build_tools/scripts/lint.sh
index 06ac125..ebab0dd 100755
--- a/build_tools/scripts/lint.sh
+++ b/build_tools/scripts/lint.sh
@@ -120,6 +120,9 @@
echo "'yamllint' not found. Skipping check"
fi
+echo "***** Path Lengths *****"
+./build_tools/scripts/check_path_lengths.py
+
if (( "${FINAL_RET}" != 0 )); then
echo "Encountered failures. Check error messages and changes to the working" \
"directory and git index (which may contain fixes) and try again."
diff --git a/compiler/src/iree/compiler/Codegen/Common/BUILD b/compiler/src/iree/compiler/Codegen/Common/BUILD
index 31c60d3..af1d920 100644
--- a/compiler/src/iree/compiler/Codegen/Common/BUILD
+++ b/compiler/src/iree/compiler/Codegen/Common/BUILD
@@ -180,6 +180,7 @@
"GPUDistributeSharedMemoryCopy.cpp",
"GPUPipelining.cpp",
"GPUVectorization.cpp",
+ "LinalgOpInfo.cpp",
"MemrefCopyToLinalg.cpp",
"PadDynamicAlloc.cpp",
"RemoveTrivialLoops.cpp",
@@ -190,6 +191,9 @@
"VectorizeConv.cpp",
"WorkGroupSwizzle.cpp",
],
+ hdrs = [
+ "LinalgOpInfo.h",
+ ],
deps = [
":CommonPasses",
":TransformDialectInterpreterPass",
diff --git a/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
index 59a6dac..17ff3ad 100644
--- a/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
@@ -145,12 +145,15 @@
iree_cc_library(
NAME
Common
+ HDRS
+ "LinalgOpInfo.h"
SRCS
"DecomposeLinalgGeneric.cpp"
"FoldAffineMinInDistributedLoops.cpp"
"GPUDistributeSharedMemoryCopy.cpp"
"GPUPipelining.cpp"
"GPUVectorization.cpp"
+ "LinalgOpInfo.cpp"
"MemrefCopyToLinalg.cpp"
"PadDynamicAlloc.cpp"
"RemoveTrivialLoops.cpp"
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPUDistributeSharedMemoryCopy.cpp b/compiler/src/iree/compiler/Codegen/Common/GPUDistributeSharedMemoryCopy.cpp
index a493bb6..3c02c82 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPUDistributeSharedMemoryCopy.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPUDistributeSharedMemoryCopy.cpp
@@ -7,6 +7,7 @@
#include <algorithm>
#include <numeric>
+#include "iree-dialects/Dialect/LinalgExt/Transforms/Transforms.h"
#include "iree/compiler/Codegen/Dialect/LoweringConfig.h"
#include "iree/compiler/Codegen/PassDetail.h"
#include "iree/compiler/Codegen/Passes.h"
@@ -16,6 +17,7 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/SCF/Utils/Utils.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/MLIRContext.h"
@@ -23,11 +25,18 @@
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
+using mlir::iree_compiler::IREE::LinalgExt::LinalgVectorizationPattern;
+using mlir::iree_compiler::IREE::LinalgExt::VectorizationPatterns;
+
//====---------------------------------------------------------------------===//
// Pass to lower workgroup memory copy to distributed
// transfer_read/transfer_write ops.
//====---------------------------------------------------------------------===//
+// Markers for intermediate transformations.
+static const llvm::StringRef kCopyToDistribute = "copy_to_distribute";
+static const llvm::StringRef kCopyDistributed = "copy_distributed";
+
namespace mlir {
namespace iree_compiler {
@@ -85,26 +94,17 @@
StringAttr::get(patterns.getContext(), getVectorizeMarker())));
}
-static void populateVectorizationPatterns(RewritePatternSet &patterns) {
- linalg::VectorizationPatterns<linalg::GenericOp>::insert(
- patterns, linalg::LinalgVectorizationOptions(),
- linalg::LinalgTransformationFilter(StringAttr::get(
- patterns.getContext(), getCopyToWorkgroupMemoryMarker())));
-}
-
-/// Compute a vector size so that the numer of elements is equal to the flat
+/// Compute a tile size so that the number of iterations is equal to the flat
/// workgroup size.
-static Optional<SmallVector<int64_t, 4>> getGPUNativeVectorSize(
- Operation *op, int64_t flatWorkgroupSize,
- const llvm::SmallDenseSet<VectorTransferOpInterface> &opsToIgnore) {
- auto vt = dyn_cast<VectorTransferOpInterface>(op);
- if (!vt) return llvm::None;
- if (opsToIgnore.count(vt)) return llvm::None;
- if (!vt.permutation_map().isMinorIdentity()) return llvm::None;
- ArrayRef<int64_t> shape = vt.getVectorType().getShape();
- int targetVectorSize =
- copyVectorNumBits / vt.getVectorType().getElementTypeBitWidth();
- SmallVector<int64_t, 4> unroll;
+static Optional<SmallVector<int64_t>> getTileToDistributableSize(
+ linalg::GenericOp copyOp, int64_t flatWorkgroupSize) {
+ SmallVector<int64_t, 4> shape = copyOp.getStaticLoopRanges();
+ unsigned bitWidth = copyOp->getOperand(0)
+ .getType()
+ .cast<MemRefType>()
+ .getElementTypeBitWidth();
+ int targetVectorSize = copyVectorNumBits / bitWidth;
+ SmallVector<int64_t> unroll;
assert(shape.back() % targetVectorSize == 0);
int64_t threadsAvailable = flatWorkgroupSize;
for (auto &dim : llvm::enumerate(llvm::reverse(shape))) {
@@ -119,18 +119,130 @@
assert(threadsAvailable == 1);
unroll.resize(shape.size(), 1);
std::reverse(unroll.begin(), unroll.end());
- if (unroll == shape) return llvm::None;
return unroll;
}
-static void populateVectorUnrollPatterns(
- RewritePatternSet &patterns, int64_t flatWorkgroupSize,
- const llvm::SmallDenseSet<VectorTransferOpInterface> &opsToIgnore) {
- auto getShape = [flatWorkgroupSize, &opsToIgnore](Operation *op) {
- return getGPUNativeVectorSize(op, flatWorkgroupSize, opsToIgnore);
+/// Pattern to tile copies using serial loops into a shape that can be
+/// distributed onto threads.
+static void populateTileToUnroll(RewritePatternSet &patterns,
+ int64_t flatWorkgroupSize) {
+ linalg::TileSizeComputationFunction wgCopyTileSizeFn =
+ [flatWorkgroupSize](OpBuilder &builder, Operation *operation) {
+ SmallVector<Value, 4> tileSizesVal;
+ auto copyOp = dyn_cast<linalg::GenericOp>(operation);
+ if (!copyOp) return tileSizesVal;
+ Optional<SmallVector<int64_t>> staticSize =
+ getTileToDistributableSize(copyOp, flatWorkgroupSize);
+ for (int64_t dim : *staticSize) {
+ tileSizesVal.push_back(
+ builder.create<arith::ConstantIndexOp>(operation->getLoc(), dim));
+ }
+ return tileSizesVal;
+ };
+
+ auto tilingOptions = linalg::LinalgTilingOptions()
+ .setLoopType(linalg::LinalgTilingLoopType::Loops)
+ .setTileSizeComputationFunction(wgCopyTileSizeFn);
+ patterns.insert<linalg::LinalgTilingPattern>(
+ linalg::GenericOp::getOperationName(), patterns.getContext(),
+ tilingOptions,
+ linalg::LinalgTransformationFilter(
+ {StringAttr::get(patterns.getContext(),
+ getCopyToWorkgroupMemoryMarker())},
+ StringAttr::get(patterns.getContext(), kCopyToDistribute)));
+}
+
+/// Break up the flat id over the static loop ranges.
+SmallVector<linalg::ProcInfo> getIds(OpBuilder &b, Location loc,
+ ArrayRef<Range> parallelLoopRanges,
+ Value flatThreadId) {
+ SmallVector<linalg::ProcInfo> infos;
+ Value id = flatThreadId;
+ AffineExpr d0 = getAffineDimExpr(0, b.getContext());
+ for (Range r : llvm::reverse(parallelLoopRanges)) {
+ linalg::ProcInfo info;
+ auto offset = r.offset.dyn_cast<Attribute>();
+ auto stride = r.stride.dyn_cast<Attribute>();
+ auto size = r.size.dyn_cast<Attribute>();
+ assert(offset && stride && size);
+ int64_t numThreadsDim = (size.cast<IntegerAttr>().getInt() -
+ offset.cast<IntegerAttr>().getInt()) /
+ stride.cast<IntegerAttr>().getInt();
+ Value dimId = id;
+ if (infos.size() != parallelLoopRanges.size() - 1)
+ dimId = makeComposedAffineApply(b, loc, d0 % numThreadsDim, {dimId});
+ info.procId = dimId;
+ info.nprocs = b.create<arith::ConstantIndexOp>(loc, numThreadsDim);
+ info.distributionMethod =
+ linalg::DistributionMethod::CyclicNumProcsEqNumIters;
+ infos.push_back(info);
+ id = makeComposedAffineApply(b, loc, d0.floorDiv(numThreadsDim), {id});
+ }
+ std::reverse(infos.begin(), infos.end());
+ return infos;
+}
+
+/// Return the shape of the copy op that can be vectorized to a
+/// transfer_read/transfer_write of size `targetVectorSize`.
+SmallVector<int64_t> getNativeDstShape(linalg::GenericOp copyOp) {
+ unsigned bitWidth = copyOp->getOperand(0)
+ .getType()
+ .cast<MemRefType>()
+ .getElementTypeBitWidth();
+ int targetVectorSize = copyVectorNumBits / bitWidth;
+ SmallVector<int64_t> dstShape;
+ for (int64_t dim : copyOp.getStaticLoopRanges()) {
+ // Skip tiling of unit dimensions to simplify distribution.
+ dstShape.push_back(dim == 1 ? 0 : 1);
+ }
+ dstShape.back() = targetVectorSize;
+ return dstShape;
+}
+
+/// Distribute linalg copy onto threads based on the flat id.
+static void populateTilingAndDistribute(RewritePatternSet &patterns,
+ Value flatThreadId) {
+ linalg::TileSizeComputationFunction wgCopyTileSizeFn =
+ [](OpBuilder &builder, Operation *operation) {
+ SmallVector<Value, 4> tileSizesVal;
+ auto copyOp = dyn_cast<linalg::GenericOp>(operation);
+ if (!copyOp) return tileSizesVal;
+ SmallVector<int64_t> staticSize = getNativeDstShape(copyOp);
+ for (int64_t dim : staticSize) {
+ tileSizesVal.push_back(
+ builder.create<arith::ConstantIndexOp>(operation->getLoc(), dim));
+ }
+ return tileSizesVal;
+ };
+ auto getCopyThreadProcInfoFn = [flatThreadId](
+ OpBuilder &builder, Location loc,
+ ArrayRef<Range> parallelLoopRanges) {
+ return getIds(builder, loc, parallelLoopRanges, flatThreadId);
};
- vector::populateVectorUnrollPatterns(
- patterns, vector::UnrollVectorOptions().setNativeShapeFn(getShape));
+ linalg::LinalgLoopDistributionOptions copyInvocationDistributionOptions;
+ copyInvocationDistributionOptions.procInfo = getCopyThreadProcInfoFn;
+
+ auto tilingOptions =
+ linalg::LinalgTilingOptions()
+ .setLoopType(linalg::LinalgTilingLoopType::ParallelLoops)
+ .setTileSizeComputationFunction(wgCopyTileSizeFn)
+ .setDistributionOptions(copyInvocationDistributionOptions);
+ patterns.insert<linalg::LinalgTilingPattern>(
+ linalg::GenericOp::getOperationName(), patterns.getContext(),
+ tilingOptions,
+ linalg::LinalgTransformationFilter(
+ {StringAttr::get(patterns.getContext(), kCopyToDistribute)},
+ StringAttr::get(patterns.getContext(), kCopyDistributed)));
+}
+
+static void populateVectorizationPatterns(RewritePatternSet &patterns) {
+ VectorizationPatterns<linalg::GenericOp>::insert(
+ patterns, linalg::LinalgVectorizationOptions(),
+ linalg::LinalgTransformationFilter(
+ {StringAttr::get(patterns.getContext(),
+ getCopyToWorkgroupMemoryMarker()),
+ StringAttr::get(patterns.getContext(), kCopyDistributed)},
+ llvm::None));
}
/// Return a flattened Id Value by combining the 3D gpu thread IDs.
@@ -154,58 +266,6 @@
return flatThreadId;
}
-/// Distribute a transfer read operations on the given thread ids.
-static void distributeTransferRead(
- func::FuncOp funcOp, Value flatThreadId, int64_t flatWorkgroupSize,
- const llvm::SmallDenseSet<VectorTransferOpInterface> &opsToIgnore) {
- funcOp.walk([&](vector::TransferReadOp readOp) {
- if (opsToIgnore.count(
- cast<VectorTransferOpInterface>(readOp.getOperation())))
- return WalkResult::advance();
- OpBuilder b(readOp);
- Value id = flatThreadId;
- SmallVector<int64_t, 2> multiplier;
- auto shape = readOp.getVectorType().getShape();
- int targetVectorSize =
- copyVectorNumBits / readOp.getVectorType().getElementTypeBitWidth();
- SmallVector<Value> ids;
- SmallVector<AffineExpr> exprs;
- AffineExpr d0 = getAffineDimExpr(0, b.getContext());
- int64_t numThreads = flatWorkgroupSize;
- for (auto &dim : llvm::enumerate(llvm::reverse(shape))) {
- int64_t threads =
- dim.index() == 0 ? (dim.value() / targetVectorSize) : dim.value();
- // If we don't need to distribute the dimension, skip it.
- if (threads == 1) continue;
- exprs.push_back(getAffineDimExpr(shape.size() - dim.index() - 1,
- funcOp->getContext()));
- multiplier.push_back(threads);
- Value dimId = id;
- assert(numThreads % threads == 0);
- if (numThreads / threads > 1) {
- dimId =
- makeComposedAffineApply(b, funcOp.getLoc(), d0 % threads, {dimId});
- }
- ids.push_back(dimId);
- numThreads = numThreads / threads;
- id = makeComposedAffineApply(b, funcOp.getLoc(), d0.floorDiv(threads),
- {id});
- if (numThreads <= 1) break;
- }
- std::reverse(ids.begin(), ids.end());
- Optional<mlir::vector::DistributeOps> ops =
- vector::distributPointwiseVectorOp(
- b, readOp, ids, multiplier,
- AffineMap::get(shape.size(), 0, exprs, funcOp.getContext()));
- if (ops.has_value()) {
- SmallPtrSet<Operation *, 1> extractOp({ops->extract, ops->insert});
- readOp.getResult().replaceAllUsesExcept(ops->insert.getResult(),
- extractOp);
- }
- return WalkResult::advance();
- });
-}
-
/// Hoist allocations to the top of the loop if they have no dependencies.
static void hoistAlloc(func::FuncOp funcOp) {
SmallVector<memref::AllocOp> allocs;
@@ -239,6 +299,31 @@
});
}
+/// Return the number of iterations if static, otherwise return 0.
+static int64_t numIteration(scf::ForOp forOp) {
+ auto lbCstOp = forOp.getLowerBound().getDefiningOp<arith::ConstantIndexOp>();
+ auto ubCstOp = forOp.getUpperBound().getDefiningOp<arith::ConstantIndexOp>();
+ auto stepCstOp = forOp.getStep().getDefiningOp<arith::ConstantIndexOp>();
+ if (!lbCstOp || !ubCstOp || !stepCstOp || lbCstOp.value() < 0 ||
+ ubCstOp.value() < 0 || stepCstOp.value() < 0)
+ return 0;
+ int64_t tripCount =
+ mlir::ceilDiv(ubCstOp.value() - lbCstOp.value(), stepCstOp.value());
+ return tripCount;
+}
+
+/// Fully unroll all the static loops unless they are in the ignore set.
+static void UnrollSharedMemoryLoops(
+ func::FuncOp funcOp, const llvm::SmallDenseSet<scf::ForOp> &loopsToIgnore) {
+ SmallVector<scf::ForOp> forOpsToUnroll;
+ funcOp.walk([&](scf::ForOp forOp) {
+ if (!loopsToIgnore.count(forOp)) forOpsToUnroll.push_back(forOp);
+ });
+ for (scf::ForOp forOp : llvm::reverse(forOpsToUnroll)) {
+ (void)loopUnrollByFactor(forOp, numIteration(forOp));
+ }
+}
+
namespace {
class GPUDistributeSharedMemoryCopyPass
@@ -278,12 +363,30 @@
targetVectorSize);
});
if (isAligned) {
- // Ignore all the exisiting vector transfer ops.
- llvm::SmallDenseSet<VectorTransferOpInterface> opsToIgnore;
- funcOp.walk([&](VectorTransferOpInterface transferOp) {
- opsToIgnore.insert(transferOp);
- });
- // Step 1. Vectorize the shared memory copy.
+ // Ignore all the existing loops.
+ llvm::SmallDenseSet<scf::ForOp> loopsToIgnore;
+ funcOp.walk([&](scf::ForOp loop) { loopsToIgnore.insert(loop); });
+
+ // Step 1. Tile copies to a shape that can be distributed as
+ // 128-bit-per-lane copies.
+ RewritePatternSet serialTilingPatterns(context);
+ populateTileToUnroll(serialTilingPatterns, flatWorkgroupSize);
+ if (failed(applyPatternsAndFoldGreedily(
+ funcOp, std::move(serialTilingPatterns)))) {
+ return signalPassFailure();
+ }
+
+ // Calculate a flat id that will then be broken down during distribution.
+ Value flatId = createFlatId(funcOp, workgroupSize);
+ // Step 2. Distribute the linalg op onto threads.
+ RewritePatternSet tileAndDistributePatterns(context);
+ populateTilingAndDistribute(tileAndDistributePatterns, flatId);
+ if (failed(applyPatternsAndFoldGreedily(
+ funcOp, std::move(tileAndDistributePatterns)))) {
+ return signalPassFailure();
+ }
+
+ // Step 3. Vectorize the distributed copies.
RewritePatternSet vectorizationPatterns(context);
populateVectorizationPatterns(vectorizationPatterns);
if (failed(applyPatternsAndFoldGreedily(
@@ -291,27 +394,8 @@
return signalPassFailure();
}
- // Step 2. Unroll transfer_read/transfer_write to a vector with the number
- // of element equal to `targetVectorSize * targetVectorSize`. The.
- // transfer op generated can. then be distributed to a single op of target
- // size.
- RewritePatternSet vectorUnrollPatterns(context);
- populateVectorUnrollPatterns(vectorUnrollPatterns, flatWorkgroupSize,
- opsToIgnore);
- if (failed(applyPatternsAndFoldGreedily(
- funcOp, std::move(vectorUnrollPatterns)))) {
- return signalPassFailure();
- }
- // Step 3. Distribute the transfer ops onto the flat ids.
- Value flatId = createFlatId(funcOp, workgroupSize);
- distributeTransferRead(funcOp, flatId, flatWorkgroupSize, opsToIgnore);
- // Propagate vector distribution to the chain of ops.
- RewritePatternSet distributePatterns(context);
- vector::populatePropagateVectorDistributionPatterns(distributePatterns);
- if (failed(applyPatternsAndFoldGreedily(funcOp,
- std::move(distributePatterns)))) {
- return signalPassFailure();
- }
+ // Step 4. Finally, unroll all the loops created above.
+ UnrollSharedMemoryLoops(funcOp, loopsToIgnore);
} else {
// Fall back to basic tiling for cases where workgroup memory size is not
// well aligned on the number of threads.
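The heart of the new distribution step is `getIds`, which delinearizes the flat thread id over the tiled copy's parallel loop ranges: inner dimensions take `id mod size` and pass `id floordiv size` outward, and the outermost dimension keeps the remaining id unmodified. A Python model of that index arithmetic, for intuition only (the pass emits `affine.apply` ops rather than doing host-side math):

```python
def delinearize(flat_id: int, sizes: list[int]) -> list[int]:
    """Break a flat id into per-dimension ids, mirroring getIds.

    Dimensions are processed innermost (last) first: each takes
    flat_id % size and forwards flat_id // size; the outermost
    dimension keeps whatever id remains, so no modulo is needed.
    """
    ids = []
    for i, size in enumerate(reversed(sizes)):
        if i == len(sizes) - 1:
            ids.append(flat_id)  # outermost dimension: leftover id
        else:
            ids.append(flat_id % size)
            flat_id //= size
    ids.reverse()
    return ids

# 128 threads distributed over 4x8x4 iterations of a tiled copy:
assert delinearize(77, [4, 8, 4]) == [2, 3, 1]
```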
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPUVectorization.cpp b/compiler/src/iree/compiler/Codegen/Common/GPUVectorization.cpp
index 4455b11..b633279 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPUVectorization.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPUVectorization.cpp
@@ -4,6 +4,7 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+#include "iree-dialects/Dialect/LinalgExt/Transforms/Transforms.h"
#include "iree/compiler/Codegen/PassDetail.h"
#include "iree/compiler/Codegen/Passes.h"
#include "iree/compiler/Codegen/Transforms/Transforms.h"
@@ -17,6 +18,9 @@
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
+using mlir::iree_compiler::IREE::LinalgExt::LinalgVectorizationPattern;
+using mlir::iree_compiler::IREE::LinalgExt::VectorizationPatterns;
+
#define DEBUG_TYPE "iree-codegen-gpu-vectorization"
namespace mlir {
@@ -34,10 +38,10 @@
StringAttr::get(ctx, getVectorizeMarker())},
llvm::None);
f.setMatchByDefault();
- linalg::VectorizationPatterns<linalg::FillOp, linalg::GenericOp>::insert(
- patterns, opt, f);
+ VectorizationPatterns<linalg::FillOp, linalg::GenericOp>::insert(patterns,
+ opt, f);
patterns.add<linalg::CopyVectorizationPattern>(ctx);
- patterns.add<linalg::LinalgVectorizationPattern>(
+ patterns.add<LinalgVectorizationPattern>(
ctx, f.addOpFilter<linalg::ContractionOpInterface>(), opt);
}
diff --git a/compiler/src/iree/compiler/Codegen/Common/LinalgOpInfo.cpp b/compiler/src/iree/compiler/Codegen/Common/LinalgOpInfo.cpp
new file mode 100644
index 0000000..f96bd67
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Common/LinalgOpInfo.cpp
@@ -0,0 +1,108 @@
+// Copyright 2022 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Codegen/Common/LinalgOpInfo.h"
+
+#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
+
+using namespace mlir::linalg;
+
+namespace mlir {
+namespace iree_compiler {
+
+LinalgOpInfo::LinalgOpInfo(linalg::LinalgOp linalgOp) { computeInfo(linalgOp); }
+
+/// Returns true if `map` is a transpose. A transpose map is a projected
+/// permutation with or without zeros in results where there exist at least two
+/// dimensions di and dj such that di < dj and result_pos(di) > result_pos(dj).
+/// Examples:
+///
+/// (d0, d1, d2) -> (d0, d2) is not a transpose map.
+/// (d0, d1, d2) -> (d2, d0) is a transpose map.
+/// (d0, d1, d2) -> (d1, d2) is not a transpose map.
+/// (d0, d1, d2) -> (d0, 0, d1) is not a transpose map.
+/// (d0, d1, d2) -> (d2, 0, d1) is a transpose map.
+/// (d0, d1, d2) -> (d1, 0) is not a transpose map.
+///
+// TODO(dcaballe): Discern between "memcopy" transposes and "shuffle"
+// transposes.
+// TODO(dcaballe): Move to Affine utils?
+static bool isTransposeMap(AffineMap map) {
+ // A transpose map must be a projected permutation with or without
+ // broadcasted/reduction dimensions.
+ if (!map.isProjectedPermutation(/*allowZeroInResults=*/true)) {
+ return false;
+ }
+
+ // Check that the projected permutation has at least two result dimensions
+ // that are actually transposed by comparing their input positions.
+ unsigned prevDim = 0;
+ for (AffineExpr expr : map.getResults()) {
+ if (auto constExpr = expr.dyn_cast<AffineConstantExpr>()) {
+ // Constant zero expression, guaranteed by 'allowZeroInResults' above.
+ continue;
+ } else if (auto dimExpr = expr.dyn_cast<AffineDimExpr>()) {
+ if (prevDim > dimExpr.getPosition()) {
+ return true;
+ }
+ prevDim = dimExpr.getPosition();
+ } else {
+ return false;
+ }
+ }
+
+ return false;
+}
+
+/// Returns true if a LinalgOp implements a transpose.
+// TODO(dcaballe):
+// * Consider transpose + reductions.
+// * Consider input and output transposes.
+static bool isTransposeLinalgOp(linalg::LinalgOp linalgOp) {
+ // Reductions are not supported.
+ if (linalgOp.getNumReductionLoops() > 0) {
+ return false;
+ }
+
+ // Multiple outputs are not supported yet.
+ if (linalgOp.getNumOutputs() != 1) {
+ return false;
+ }
+
+ // Inverse map to use transfer op permutation logic.
+ AffineMap outputInversedMap = inversePermutation(
+ linalgOp.getTiedIndexingMap(linalgOp.getOutputOperand(0)));
+ SmallVector<AffineMap> inputInversedMaps;
+ for (OpOperand *linalgOperand : linalgOp.getInputOperands()) {
+ auto map = linalgOp.getTiedIndexingMap(linalgOperand);
+ if (!map.isProjectedPermutation(/*allowZeroInResults=*/true)) {
+ return false;
+ }
+ inputInversedMaps.push_back(inverseAndBroadcastProjectedPermutation(map));
+ }
+
+ bool isInputTransposed = llvm::any_of(
+ inputInversedMaps, [](AffineMap map) { return isTransposeMap(map); });
+ bool isOutputTransposed = isTransposeMap(outputInversedMap);
+
+ return isInputTransposed || isOutputTransposed;
+}
+
+static bool computeTransposeInfo(LinalgOp linalgOp) {
+ return isTransposeLinalgOp(linalgOp);
+}
+
+static bool computeReductionInfo(LinalgOp linalgOp) {
+ return linalgOp.getNumReductionLoops() > 1;
+}
+
+void LinalgOpInfo::computeInfo(LinalgOp linalgOp) {
+ transposeTrait = computeTransposeInfo(linalgOp);
+ reductionTrait = computeReductionInfo(linalgOp);
+}
+
+} // namespace iree_compiler
+} // namespace mlir
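The transpose test in `isTransposeMap` boils down to a single scan: skipping constant-zero results, a transpose is flagged as soon as a result's dimension position falls below the largest position seen so far. A Python rendering over maps encoded as lists of dim positions (`None` stands in for the constant-zero results the C++ allows):

```python
from typing import Optional, Sequence

def is_transpose_map(results: Sequence[Optional[int]]) -> bool:
    """True if two result dims appear in swapped order, as in isTransposeMap."""
    prev_dim = 0
    for dim in results:
        if dim is None:  # constant-zero (broadcast) result: ignored
            continue
        if prev_dim > dim:
            return True  # an earlier dim appears after a later one
        prev_dim = dim
    return False

# The examples from the function's comment:
assert not is_transpose_map([0, 2])        # (d0, d1, d2) -> (d0, d2)
assert is_transpose_map([2, 0])            # (d0, d1, d2) -> (d2, d0)
assert not is_transpose_map([0, None, 1])  # (d0, d1, d2) -> (d0, 0, d1)
assert is_transpose_map([2, None, 1])      # (d0, d1, d2) -> (d2, 0, d1)
```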
diff --git a/compiler/src/iree/compiler/Codegen/Common/LinalgOpInfo.h b/compiler/src/iree/compiler/Codegen/Common/LinalgOpInfo.h
new file mode 100644
index 0000000..2a84f69
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Common/LinalgOpInfo.h
@@ -0,0 +1,35 @@
+// Copyright 2022 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_COMPILER_CODEGEN_COMMON_LINALGOPINFO_H_
+#define IREE_COMPILER_CODEGEN_COMMON_LINALGOPINFO_H_
+
+namespace mlir {
+
+namespace linalg {
+class LinalgOp;
+}
+
+namespace iree_compiler {
+
+class LinalgOpInfo {
+ public:
+ LinalgOpInfo(linalg::LinalgOp linalgOp);
+
+ bool isTranspose() const { return transposeTrait; }
+ bool isReduction() const { return reductionTrait; }
+
+ private:
+ void computeInfo(linalg::LinalgOp);
+
+ bool transposeTrait;
+ bool reductionTrait;
+};
+
+} // namespace iree_compiler
+} // namespace mlir
+
+#endif // IREE_COMPILER_CODEGEN_COMMON_LINALGOPINFO_H_
diff --git a/compiler/src/iree/compiler/Codegen/Common/VectorizeConv.cpp b/compiler/src/iree/compiler/Codegen/Common/VectorizeConv.cpp
index 578541c..85a1e91 100644
--- a/compiler/src/iree/compiler/Codegen/Common/VectorizeConv.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/VectorizeConv.cpp
@@ -153,9 +153,14 @@
rewriter.getAffineMapArrayAttr({map02, map21, map01});
// Also build iterator types for the vector contraction op.
- ArrayAttr iterators = rewriter.getStrArrayAttr(
- {getParallelIteratorTypeName(), getParallelIteratorTypeName(),
- getReductionIteratorTypeName()});
+ auto parallelIteratorTypeAttr = vector::IteratorTypeAttr::get(
+ rewriter.getContext(), vector::IteratorType::parallel);
+ auto reductionIteratorTypeAttr = vector::IteratorTypeAttr::get(
+ rewriter.getContext(), vector::IteratorType::reduction);
+ SmallVector<Attribute> iteratorsList = {parallelIteratorTypeAttr,
+ parallelIteratorTypeAttr,
+ reductionIteratorTypeAttr};
+ ArrayAttr iterators = rewriter.getArrayAttr(iteratorsList);
// Compute the (numOutputHeights * numOutputWidths * numOutputChannels)
// batch. We only contribute numInputChannels accumulation along the
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/distribute_gpu_shared_memory.mlir b/compiler/src/iree/compiler/Codegen/Common/test/distribute_gpu_shared_memory.mlir
index adb183d..232f77a 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/distribute_gpu_shared_memory.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/distribute_gpu_shared_memory.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt --pass-pipeline='hal.executable(hal.executable.variant(builtin.module(func.func(iree-gpu-distribute-shared-memory-copy))))' --cse %s | FileCheck %s
+// RUN: iree-opt --pass-pipeline='hal.executable(hal.executable.variant(builtin.module(func.func(iree-gpu-distribute-shared-memory-copy))))' --fold-memref-alias-ops --canonicalize --cse %s | FileCheck %s
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0, s1, s2] -> (s1 * 8 + s2 * 32 + s0 floordiv 4)>
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 4) * 16)>
@@ -48,9 +48,9 @@
// CHECK-DAG: %[[Y0:.*]] = affine.apply #[[$MAP0]]()[%[[TX]], %[[TY]], %[[TZ]]]
// CHECK-DAG: %[[X0:.*]] = affine.apply #[[$MAP1]]()[%[[TX]]]
// CHECK: %[[R0:.*]] = vector.transfer_read %{{.*}}[%[[Y0]], %[[X0]]], %{{.*}} {in_bounds = [true, true]} : memref<64x16xf32>, vector<1x4xf32>
+ // CHECK: vector.transfer_write %[[R0]], %{{.*}}[%[[Y0]], %[[X0]]] {in_bounds = [true, true]} : vector<1x4xf32>, memref<64x16xf32, 3>
// CHECK-DAG: %[[Y1:.*]] = affine.apply #[[$MAP2]]()[%[[TX]], %[[TY]], %[[TZ]]]
// CHECK: %[[R1:.*]] = vector.transfer_read %{{.*}}[%[[Y1]], %[[X0]]], %{{.*}} {in_bounds = [true, true]} : memref<64x16xf32>, vector<1x4xf32>
- // CHECK: vector.transfer_write %[[R0]], %{{.*}}[%[[Y0]], %[[X0]]] {in_bounds = [true, true]} : vector<1x4xf32>, memref<64x16xf32, 3>
// CHECK: vector.transfer_write %[[R1]], %{{.*}}[%[[Y1]], %[[X0]]] {in_bounds = [true, true]} : vector<1x4xf32>, memref<64x16xf32, 3>
linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
@@ -64,9 +64,9 @@
// CHECK: %[[Y1:.*]] = affine.apply #[[$MAP3]]()[%[[TX]], %[[TY]], %[[TZ]]]
// CHECK: %[[R2:.*]] = vector.transfer_read %{{.*}}[%[[Y1]], %[[C0]]], %{{.*}} {in_bounds = [true, true]} : memref<256x4xf32>, vector<1x4xf32>
+ // CHECK: vector.transfer_write %[[R2]], %{{.*}}[%[[Y1]], %[[C0]]] {in_bounds = [true, true]} : vector<1x4xf32>, memref<256x4xf32, 3>
// CHECK: %[[Y2:.*]] = affine.apply #[[$MAP4]]()[%[[TX]], %[[TY]], %[[TZ]]]
// CHECK: %[[R3:.*]] = vector.transfer_read %{{.*}}[%[[Y2]], %[[C0]]], %{{.*}} {in_bounds = [true, true]} : memref<256x4xf32>, vector<1x4xf32>
- // CHECK: vector.transfer_write %[[R2]], %{{.*}}[%[[Y1]], %[[C0]]] {in_bounds = [true, true]} : vector<1x4xf32>, memref<256x4xf32, 3>
// CHECK: vector.transfer_write %[[R3]], %{{.*}}[%[[Y2]], %[[C0]]] {in_bounds = [true, true]} : vector<1x4xf32>, memref<256x4xf32, 3>
linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
@@ -80,11 +80,11 @@
// CHECK: %[[X1:.*]] = affine.apply #[[$MAP5]]()[%[[TX]], %[[TY]], %[[TZ]]]
// CHECK: %[[R4:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[X1]]], %{{.*}} {in_bounds = [true, true]} : memref<3x512xf32>, vector<1x4xf32>
+ // CHECK: vector.transfer_write %[[R4]], %{{.*}}[%[[C0]], %[[X1]]] {in_bounds = [true, true]} : vector<1x4xf32>, memref<3x512xf32, 3>
// CHECK: %[[R5:.*]] = vector.transfer_read %{{.*}}[%[[C1]], %[[X1]]], %{{.*}} {in_bounds = [true, true]} : memref<3x512xf32>, vector<1x4xf32>
+ // CHECK: vector.transfer_write %[[R5]], %{{.*}}[%[[C1]], %[[X1]]] {in_bounds = [true, true]} : vector<1x4xf32>, memref<3x512xf32, 3>
// CHECK: %[[R6:.*]] = vector.transfer_read %{{.*}}[%[[C2]], %[[X1]]], %{{.*}} {in_bounds = [true, true]} : memref<3x512xf32>, vector<1x4xf32>
- // CHECK: vector.transfer_write %[[R4]], %{{.*}}[%c0, %15] {in_bounds = [true, true]} : vector<1x4xf32>, memref<3x512xf32, 3>
- // CHECK: vector.transfer_write %[[R5]], %{{.*}}[%c1, %15] {in_bounds = [true, true]} : vector<1x4xf32>, memref<3x512xf32, 3>
- // CHECK: vector.transfer_write %[[R6]], %{{.*}}[%c2, %15] {in_bounds = [true, true]} : vector<1x4xf32>, memref<3x512xf32, 3>
+ // CHECK: vector.transfer_write %[[R6]], %{{.*}}[%[[C2]], %[[X1]]] {in_bounds = [true, true]} : vector<1x4xf32>, memref<3x512xf32, 3>
linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
iterator_types = ["parallel", "parallel"]}
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/BUILD b/compiler/src/iree/compiler/Codegen/LLVMCPU/BUILD
index ef0af0b..4e143ee 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/BUILD
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/BUILD
@@ -24,11 +24,13 @@
"LLVMCPUSynchronizeSymbolVisibility.cpp",
"LLVMCPUUnfuseFMAOps.cpp",
"Passes.cpp",
+ "TargetMLTransformInfo.cpp",
"VectorContractCustomKernels.cpp",
"VerifyLinalgTransformLegality.cpp",
],
hdrs = [
"KernelDispatch.h",
+ "TargetMLTransformInfo.h",
],
deps = [
"//compiler/src/iree/compiler/Codegen:PassHeaders",
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt
index f5aee1a..455d22e 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt
@@ -15,6 +15,7 @@
LLVMCPU
HDRS
"KernelDispatch.h"
+ "TargetMLTransformInfo.h"
SRCS
"ConvertToLLVM.cpp"
"KernelDispatch.cpp"
@@ -25,6 +26,7 @@
"LLVMCPUSynchronizeSymbolVisibility.cpp"
"LLVMCPUUnfuseFMAOps.cpp"
"Passes.cpp"
+ "TargetMLTransformInfo.cpp"
"VectorContractCustomKernels.cpp"
"VerifyLinalgTransformLegality.cpp"
DEPS
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
index fa1e795..c7b5af5 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
@@ -9,6 +9,8 @@
#include <numeric>
#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
+#include "iree/compiler/Codegen/Common/LinalgOpInfo.h"
+#include "iree/compiler/Codegen/LLVMCPU/TargetMLTransformInfo.h"
#include "iree/compiler/Codegen/Transforms/Transforms.h"
#include "iree/compiler/Codegen/Utils/Utils.h"
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
@@ -177,10 +179,12 @@
// tile sizes for vectorization/unrolling in one shot.
static SmallVector<int64_t> getMinTilingSizesForEachDim(
func::FuncOp entryPointFn, linalg::LinalgOp op,
- unsigned maxUnrollFactor = 8) {
+ const LinalgOpInfo &linalgOpInfo,
+ const TargetMLTransformInfo &targetMLTransInfo) {
unsigned numLoops = op.getNumLoops();
SmallVector<int64_t> minTileSizes(numLoops, 1);
auto inputOutputOpOperands = op.getInputAndOutputOperands();
+
for (auto map : llvm::enumerate(op.getIndexingMapsArray())) {
// Check the fastest varying dimension of the operand. Set the vector size
// of the corresponding loop to the vector size.
@@ -194,16 +198,36 @@
auto operandType =
inputOutputOpOperands[map.index()]->get().getType().cast<ShapedType>();
int64_t tileSize = getVectorSize(entryPointFn, operandType);
- // Vectorization of reductions is driven by input tensors and considering
- // the output's fastest varying dim leads to large unroll factors. We limit
- // the tile size for this case to 'maxUnrollFactor'.
- if (op.isOutputTensor(inputOutputOpOperands[map.index()]) &&
- op.getNumReductionLoops() > 0)
- tileSize = std::min<int64_t>(tileSize, maxUnrollFactor);
minTileSizes[fastestVaryingDim] =
std::max<int64_t>(minTileSizes[fastestVaryingDim], tileSize);
}
+
+ // Limit unroll factor. For now, we assume the rightmost non-one tiled
+ // dimension is for vectorization and any other non-one dimension is for
+ // unrolling.
+ auto limitUnrollFactor = [&](int64_t maxUnrollFactor) {
+ int vecDim;
+ for (vecDim = minTileSizes.size() - 1; vecDim >= 0; --vecDim) {
+ if (minTileSizes[vecDim] > 1) {
+ break;
+ }
+ }
+ for (int unrollDim = vecDim - 1; unrollDim >= 0; --unrollDim) {
+ minTileSizes[unrollDim] =
+ std::min<int64_t>(minTileSizes[unrollDim], maxUnrollFactor);
+ }
+ };
+
+ if (linalgOpInfo.isTranspose()) {
+ // Limit unrolling on transpose operations.
+ // TODO(dcaballe): Consider input and output transposes.
+ limitUnrollFactor(targetMLTransInfo.defaultMaxTransposeUnrollFactor);
+ } else {
+ // Limit unrolling to the default target maximum.
+ limitUnrollFactor(targetMLTransInfo.defaultMaxUnrollFactor);
+ }
+
return minTileSizes;
}
@@ -990,7 +1014,9 @@
/// Sets the default lowering configuration for a generic op to use
/// CPUDoubleTilingExpert pipeline.
static LogicalResult setDefaultGenericOpRootConfig(
- func::FuncOp entryPointFn, linalg::GenericOp genericOp) {
+ func::FuncOp entryPointFn, linalg::GenericOp genericOp,
+ const LinalgOpInfo &linalgOpInfo,
+ const TargetMLTransformInfo &targetMLTransInfo) {
if (getLoweringConfig(genericOp)) {
return success();
}
@@ -1003,8 +1029,8 @@
DispatchLoweringPassPipeline::CPUDefault);
}
- SmallVector<int64_t> minTileSizes =
- getMinTilingSizesForEachDim(entryPointFn, genericOp);
+ SmallVector<int64_t> minTileSizes = getMinTilingSizesForEachDim(
+ entryPointFn, genericOp, linalgOpInfo, targetMLTransInfo);
// For generic ops we'll use the default divided by 2 to control the stack
// allocation limit. See #9469 for an example.
SmallVector<int64_t> maxTileSizes(numLoops, defaultWorkgroupTileSize / 2);
@@ -1047,8 +1073,10 @@
/// Sets the lowering configuration for a generic op implementing a
/// transposition to use CPUDoubleTilingExpert pipeline.
-static LogicalResult setTransposeLikeOpRootConfig(func::FuncOp entryPointFn,
- linalg::GenericOp genericOp) {
+static LogicalResult setTransposeLikeOpRootConfig(
+ func::FuncOp entryPointFn, linalg::GenericOp genericOp,
+ const LinalgOpInfo &linalgOpInfo,
+ const TargetMLTransformInfo &targetMLTransInfo) {
if (getLoweringConfig(genericOp)) {
return success();
}
@@ -1060,8 +1088,8 @@
}
unsigned numLoops = genericOp.getNumLoops();
- SmallVector<int64_t> minTileSizes =
- getMinTilingSizesForEachDim(entryPointFn, genericOp);
+ SmallVector<int64_t> minTileSizes = getMinTilingSizesForEachDim(
+ entryPointFn, genericOp, linalgOpInfo, targetMLTransInfo);
SmallVector<int64_t> maxTileSizes(numLoops, defaultWorkgroupTileSize);
if (llvm::all_of(minTileSizes, [](int64_t vs) { return vs == 1; })) {
// Nothing to vectorize just lower to loops.
@@ -1116,7 +1144,9 @@
/// workload per workgroup to a larger number, which prevents runtime overheads
/// from tiny dispatches.
static LogicalResult setElementwiseGenericOpRootConfig(
- func::FuncOp entryPointFn, linalg::GenericOp genericOp) {
+ func::FuncOp entryPointFn, linalg::GenericOp genericOp,
+ const LinalgOpInfo &linalgOpInfo,
+ const TargetMLTransformInfo &targetMLTransInfo) {
if (getLoweringConfig(genericOp)) {
return success();
}
@@ -1126,8 +1156,8 @@
if (!linalg::isElementwise(genericOp)) return success();
// Set the flow level tiling to the default.
- SmallVector<int64_t> minTileSizes =
- getMinTilingSizesForEachDim(entryPointFn, genericOp);
+ SmallVector<int64_t> minTileSizes = getMinTilingSizesForEachDim(
+ entryPointFn, genericOp, linalgOpInfo, targetMLTransInfo);
SmallVector<int64_t> maxTileSizes(numLoops, defaultWorkgroupTileSize);
SmallVector<int64_t> flowTileSizes =
getDefaultDistributedLevelTileSizes(genericOp, minTileSizes, maxTileSizes,
@@ -1193,11 +1223,16 @@
/// Sets the lowering configuration for a generic op to use
/// CPUDoubleTilingExpert pipeline.
-static LogicalResult setRootConfig(func::FuncOp entryPointFn,
- linalg::GenericOp genericOp) {
- if (failed(setTransposeLikeOpRootConfig(entryPointFn, genericOp)) ||
- failed(setElementwiseGenericOpRootConfig(entryPointFn, genericOp)) ||
- failed(setDefaultGenericOpRootConfig(entryPointFn, genericOp))) {
+static LogicalResult setRootConfig(
+ func::FuncOp entryPointFn, linalg::GenericOp genericOp,
+ const LinalgOpInfo &linalgOpInfo,
+ const TargetMLTransformInfo &targetMLTransInfo) {
+ if (failed(setTransposeLikeOpRootConfig(entryPointFn, genericOp, linalgOpInfo,
+ targetMLTransInfo)) ||
+ failed(setElementwiseGenericOpRootConfig(
+ entryPointFn, genericOp, linalgOpInfo, targetMLTransInfo)) ||
+ failed(setDefaultGenericOpRootConfig(entryPointFn, genericOp,
+ linalgOpInfo, targetMLTransInfo))) {
return failure();
}
return success();
@@ -1356,16 +1391,21 @@
}
/// Redirects to methods that set the configuration based on operation type.
-static LogicalResult setRootConfigImpl(func::FuncOp entryPointFn,
- Operation *op) {
+static LogicalResult setRootConfigImpl(
+ func::FuncOp entryPointFn, Operation *op,
+ const TargetMLTransformInfo &targetMLTransInfo) {
// Do not overwrite default configuration.
if (getLoweringConfig(op)) return success();
// Redirect to individual operations.
auto setRootConfigFn = [&](Operation *op) -> LogicalResult {
return TypeSwitch<Operation *, LogicalResult>(op)
- .Case<IREE::LinalgExt::FftOp, linalg::GenericOp, linalg::Mmt4DOp,
- linalg::Conv2DNhwcHwcfOp, linalg::DepthwiseConv2DNhwcHwcOp>(
+ .Case<linalg::GenericOp>([&](auto op) {
+ return setRootConfig(entryPointFn, op, LinalgOpInfo(op),
+ targetMLTransInfo);
+ })
+ .Case<IREE::LinalgExt::FftOp, linalg::Mmt4DOp, linalg::Conv2DNhwcHwcfOp,
+ linalg::DepthwiseConv2DNhwcHwcOp>(
[&](auto op) { return setRootConfig(entryPointFn, op); })
.Case<linalg::ContractionOpInterface>(
[&](auto op) { return setRootConfig(entryPointFn, op); })
@@ -1451,7 +1491,10 @@
return failure();
}
} else {
- if (failed(setRootConfigImpl(entryPointFn, rootOperation))) {
+ auto targetMLTransInfo =
+ TargetMLTransformInfo::getTargetMLTransformInfo(*variantOp);
+ if (failed(setRootConfigImpl(entryPointFn, rootOperation,
+ targetMLTransInfo))) {
return failure();
}
}
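The unroll clamp introduced in `getMinTilingSizesForEachDim` treats the rightmost tile size greater than one as the vectorized dimension and caps every dimension to its left at the target's maximum. A Python sketch of that clamp (the maxima come from `TargetMLTransformInfo`, defined in the next file, e.g. 1 for transposes on RISC-V):

```python
def limit_unroll_factor(min_tile_sizes: list[int], max_unroll: int) -> list[int]:
    """Cap unroll dimensions, leaving the rightmost non-one (vector) dim alone."""
    sizes = list(min_tile_sizes)
    vec_dim = len(sizes) - 1
    while vec_dim >= 0 and sizes[vec_dim] <= 1:
        vec_dim -= 1  # scan right-to-left for the vectorized dimension
    for unroll_dim in range(vec_dim):
        sizes[unroll_dim] = min(sizes[unroll_dim], max_unroll)
    return sizes

# A transpose-like generic on RISC-V (defaultMaxTransposeUnrollFactor = 1):
assert limit_unroll_factor([16, 8, 16], max_unroll=1) == [1, 1, 16]
```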
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/TargetMLTransformInfo.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/TargetMLTransformInfo.cpp
new file mode 100644
index 0000000..bf26b16
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/TargetMLTransformInfo.cpp
@@ -0,0 +1,38 @@
+// Copyright 2022 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Codegen/LLVMCPU/TargetMLTransformInfo.h"
+
+#include "iree/compiler/Codegen/Utils/Utils.h"
+
+using namespace mlir;
+using namespace mlir::iree_compiler;
+
+namespace {
+
+struct RISCVTargetMLTransformInfo : TargetMLTransformInfo {
+ RISCVTargetMLTransformInfo() {
+ defaultMaxUnrollFactor = 8;
+ defaultMaxTransposeUnrollFactor = 1;
+ }
+};
+
+} // namespace
+
+namespace mlir {
+namespace iree_compiler {
+
+const TargetMLTransformInfo TargetMLTransformInfo::getTargetMLTransformInfo(
+ IREE::HAL::ExecutableVariantOp variantOp) {
+ if (isRISCV(variantOp)) {
+ return RISCVTargetMLTransformInfo();
+ }
+
+ return TargetMLTransformInfo();
+};
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/TargetMLTransformInfo.h b/compiler/src/iree/compiler/Codegen/LLVMCPU/TargetMLTransformInfo.h
new file mode 100644
index 0000000..bbdf4d3
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/TargetMLTransformInfo.h
@@ -0,0 +1,31 @@
+// Copyright 2022 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_COMPILER_CODEGEN_LLVMCPU_TARGETMLTRANSFORMINFO_H_
+#define IREE_COMPILER_CODEGEN_LLVMCPU_TARGETMLTRANSFORMINFO_H_
+
+#include <limits>
+
+#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+/// Holds target specific information to specialize ML transformations.
+// TODO(dcaballe): Move to a Concept-Model implementation when it's worth it.
+struct TargetMLTransformInfo {
+ unsigned defaultMaxUnrollFactor = 8;
+ unsigned defaultMaxTransposeUnrollFactor =
+ std::numeric_limits<unsigned>::max();
+
+ static const TargetMLTransformInfo getTargetMLTransformInfo(
+ IREE::HAL::ExecutableVariantOp variantOp);
+};
+
+} // namespace iree_compiler
+} // namespace mlir
+
+#endif // IREE_COMPILER_CODEGEN_LLVMCPU_TARGETMLTRANSFORMINFO_H_
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/VectorContractCustomKernels.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/VectorContractCustomKernels.cpp
index 5301830..514b070 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/VectorContractCustomKernels.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/VectorContractCustomKernels.cpp
@@ -47,9 +47,9 @@
SmallVector<int, 3> parallelIterators;
SmallVector<int, 3> reductionIterators;
for (int i = 0; i < 3; i++) {
- if (isParallelIterator(iteratorTypes[i])) {
+ if (vector::isParallelIterator(iteratorTypes[i])) {
parallelIterators.push_back(i);
- } else if (isReductionIterator(iteratorTypes[i])) {
+ } else if (vector::isReductionIterator(iteratorTypes[i])) {
reductionIterators.push_back(i);
} else {
return false;
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/transform_dialect_bufferize.mlir b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/transform_dialect_bufferize.mlir
index ed99f19..1b9ba95 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/transform_dialect_bufferize.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/transform_dialect_bufferize.mlir
@@ -33,10 +33,7 @@
}
}
-transform.with_pdl_patterns {
-^bb0(%arg0: !pdl.operation):
- transform.structured.canonicalized_sequence %arg0 failures(propagate) {
- ^bb1(%variant_op: !pdl.operation):
- transform.iree.bufferize %variant_op
- }
+transform.structured.canonicalized_sequence failures(propagate) {
+^bb1(%variant_op: !pdl.operation):
+ transform.iree.bufferize %variant_op
}
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
index 28ca796..9ded7ea 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
@@ -529,11 +529,26 @@
return success();
}
+/// Returns true if the index map represents a transpose that benefits from
+/// shared mem. Currently supports 2D transposes.
+static bool isSharedMemTranspose(AffineMap indexMap) {
+ if (!indexMap.isEmpty() && indexMap.isPermutation() &&
+ indexMap.getNumInputs() == 2) {
+ // Ensure that the fastest-moving dimension (the last one) is permuted;
+ // otherwise shared memory promotion will not benefit the operation.
+ if (indexMap.getDimPosition(indexMap.getNumDims() - 1) !=
+ indexMap.getNumDims() - 1) {
+ return true;
+ }
+ }
+ return false;
+}
+
/// Returns true if the operation is a GenericOp implementing a 2D transpose.
static bool isTransposeOp(linalg::LinalgOp linalgOp) {
if (!isa<linalg::GenericOp>(linalgOp)) return false;
- // Check that the op has 2 parallel loops.
- if (linalgOp.getNumParallelLoops() != 2) {
+ // Check that the op has at least 2 parallel loops.
+ if (linalgOp.getNumParallelLoops() < 2) {
return false;
}
@@ -542,31 +557,19 @@
return false;
}
- // Check that the op has only one input and one output.
- if ((linalgOp.getNumInputs() != 1) || (linalgOp.getNumOutputs() != 1)) {
- return false;
- }
- // Check for 2D operations
- auto inputShape =
- linalgOp.inputs()[0].getType().cast<ShapedType>().getShape();
- auto outputShape =
- linalgOp.outputs()[0].getType().cast<ShapedType>().getShape();
- if (inputShape.size() != 2 || outputShape.size() != 2) {
- return false;
- }
-
// Only transpose static shapes
if (linalgOp.hasDynamicShape()) {
return false;
}
- // Check that the two indexing maps are a permutation of each other.
- auto indexing_maps = linalgOp.getIndexingMapsArray();
- return !indexing_maps[0].isEmpty() && !indexing_maps[1].isEmpty() &&
- ((indexing_maps[0].isIdentity() && !indexing_maps[1].isIdentity() &&
- indexing_maps[1].isPermutation()) ||
- (!indexing_maps[0].isIdentity() && indexing_maps[0].isPermutation() &&
- indexing_maps[1].isIdentity()));
+ // Check that at least one operand is transposed.
+ bool hasPermutation = false;
+ for (auto indexMap : linalgOp.getIndexingMapsArray()) {
+ if (isSharedMemTranspose(indexMap)) {
+ hasPermutation = true;
+ }
+ }
+ return hasPermutation;
}
static LogicalResult setTransposeConfig(func::FuncOp entryPoint,
@@ -576,12 +579,13 @@
TileSizesListType tileSizes;
tileSizes.push_back({tileM, tileN});
- // Check alignment with tile size
+ // Check alignment with tile size for each transpose.
if (auto genericOp = dyn_cast<linalg::GenericOp>(op)) {
- auto inputShape =
- genericOp.inputs()[0].getType().cast<ShapedType>().getShape();
- if (inputShape[0] % tileM != 0 || inputShape[1] % tileN != 0) {
- return failure();
+ auto loopRanges = genericOp.getStaticLoopRanges();
+ for (auto loopRange : loopRanges) {
+ if (loopRange % 32 != 0) {
+ return failure();
+ }
}
} else {
return failure();
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUTensorCoreVectorization.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUTensorCoreVectorization.cpp
index 99e669a..fe87b59 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUTensorCoreVectorization.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUTensorCoreVectorization.cpp
@@ -4,6 +4,7 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+#include "iree-dialects/Dialect/LinalgExt/Transforms/Transforms.h"
#include "iree/compiler/Codegen/Dialect/LoweringConfig.h"
#include "iree/compiler/Codegen/PassDetail.h"
#include "iree/compiler/Codegen/Passes.h"
@@ -14,6 +15,9 @@
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
+using mlir::iree_compiler::IREE::LinalgExt::LinalgVectorizationPattern;
+using mlir::iree_compiler::IREE::LinalgExt::VectorizationPatterns;
+
namespace mlir {
namespace iree_compiler {
@@ -28,9 +32,9 @@
linalg::LinalgVectorizationOptions opt;
linalg::LinalgTransformationFilter f(
StringAttr::get(patterns.getContext(), getVectorizeMarker()));
- linalg::VectorizationPatterns<linalg::FillOp, linalg::GenericOp>::insert(
- patterns, opt, f);
- patterns.add<linalg::LinalgVectorizationPattern>(
+ VectorizationPatterns<linalg::FillOp, linalg::GenericOp>::insert(patterns,
+ opt, f);
+ patterns.add<LinalgVectorizationPattern>(
patterns.getContext(), f.addOpFilter<linalg::ContractionOpInterface>(),
opt);
vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
@@ -42,10 +46,12 @@
if (!contract) return llvm::None;
SmallVector<int64_t> order;
// Pick an unrolling order that will allow tensorcore operation to reuse LHS
- // register. THis is needed to get good performance on sm_80 target.
- // First make reduction the outter dimensions.
+ // register. This is needed to get good performance on the sm_80 target.
+ // First, make the reductions the outer dimensions.
for (auto iter : llvm::enumerate(contract.getIteratorTypes())) {
- if (isReductionIterator(iter.value())) order.push_back(iter.index());
+ if (vector::isReductionIterator(iter.value())) {
+ order.push_back(iter.index());
+ }
}
llvm::SmallDenseSet<int64_t> dims;
@@ -54,13 +60,15 @@
}
// Then parallel dimensions that are part of Lhs as we want to re-use Lhs.
for (auto iter : llvm::enumerate(contract.getIteratorTypes())) {
- if (isParallelIterator(iter.value()) && dims.count(iter.index()))
+ if (vector::isParallelIterator(iter.value()) && dims.count(iter.index())) {
order.push_back(iter.index());
+ }
}
// Then the remaining parallel loops.
for (auto iter : llvm::enumerate(contract.getIteratorTypes())) {
- if (isParallelIterator(iter.value()) && !dims.count(iter.index()))
+ if (vector::isParallelIterator(iter.value()) && !dims.count(iter.index())) {
order.push_back(iter.index());
+ }
}
return order;
}
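The unrolling order above is assembled in three passes so that tensor-core ops reuse the LHS register tile: reduction dims first (outermost), then parallel dims that index the LHS, then the remaining parallel dims. A Python model of the ordering (iterator kinds and LHS dims are passed in directly here; the pass derives them from the contraction's iterator types and indexing maps):

```python
def unroll_order(iterators: list[str], lhs_dims: set[int]) -> list[int]:
    """Order contraction dims for unrolling with LHS register reuse."""
    # Pass 1: reductions become the outermost unrolled loops.
    order = [i for i, it in enumerate(iterators) if it == "reduction"]
    # Pass 2: parallel dims used by the LHS come next, enabling reuse.
    order += [i for i, it in enumerate(iterators)
              if it == "parallel" and i in lhs_dims]
    # Pass 3: any remaining parallel dims go last.
    order += [i for i, it in enumerate(iterators)
              if it == "parallel" and i not in lhs_dims]
    return order

# A matmul (m, n, k) whose LHS indexes (m, k):
assert unroll_order(["parallel", "parallel", "reduction"], {0, 2}) == [2, 0, 1]
```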
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUTileAndDistribute.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUTileAndDistribute.cpp
index 4f0952a..d72ae32 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUTileAndDistribute.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUTileAndDistribute.cpp
@@ -29,6 +29,8 @@
#include "mlir/Transforms/Passes.h"
#include "mlir/Transforms/SideEffectUtils.h"
+using mlir::iree_compiler::IREE::LinalgExt::TilingPatterns;
+
#define DEBUG_TYPE "iree-llvmgpu-tile-and-distribute"
namespace mlir {
@@ -66,9 +68,8 @@
StringAttr::get(context, getWorkgroupMemoryMarker())},
StringAttr::get(context, getWorkgroupKTiledMarker()));
filter.setMatchByDefault();
- linalg::TilingPatterns<linalg::MatmulOp, linalg::BatchMatmulOp,
- linalg::GenericOp>::insert(patterns, tilingOptions,
- filter);
+ TilingPatterns<linalg::MatmulOp, linalg::BatchMatmulOp,
+ linalg::GenericOp>::insert(patterns, tilingOptions, filter);
}
/// Return the tile size associated to one thread or warp based on the number of
@@ -134,10 +135,8 @@
StringAttr::get(context, getWorkgroupMemoryMarker())},
StringAttr::get(context, getVectorizeMarker()));
filter.setMatchByDefault();
- linalg::TilingPatterns<linalg::MatmulOp, linalg::FillOp,
- linalg::BatchMatmulOp,
- linalg::GenericOp>::insert(patterns, tilingOptions,
- filter);
+ TilingPatterns<linalg::MatmulOp, linalg::FillOp, linalg::BatchMatmulOp,
+ linalg::GenericOp>::insert(patterns, tilingOptions, filter);
}
/// Patterns for thread level tiling.
@@ -182,13 +181,57 @@
return success();
}
+/// Returns the indices of the transposed operands in a linalg generic.
+static SmallVector<int64_t> getTransposedOperands(linalg::GenericOp linalgOp) {
+ // Determine which operands to promote:
+ SmallVector<int64_t> transposedOperands;
+ if (linalgOp.getNumParallelLoops() < 2) {
+ return transposedOperands;
+ }
+ for (auto indexValue : llvm::enumerate(linalgOp.getIndexingMapsArray())) {
+ int64_t opIndex = indexValue.index();
+ auto indexMap = indexValue.value();
+ if (!indexMap.isEmpty() && indexMap.isPermutation()) {
+ // Ensure that the fastest-moving dimension (the last one) is permuted;
+ // otherwise data isn't moved.
+ if (indexMap.getDimPosition(indexMap.getNumDims() - 1) !=
+ indexMap.getNumDims() - 1) {
+ // Add operand to promote to list and mark the linalg for this
+ // promotion.
+ transposedOperands.push_back(opIndex);
+ }
+ }
+ }
+ return transposedOperands;
+}
+
+using PromotionFilterFunction = std::function<LogicalResult(Operation *op)>;
+
+/// Returns success if `op` is the transpose targeted for promotion.
+static LogicalResult transposeFilter(Operation *op,
+ linalg::GenericOp promotedFilterOp) {
+ return success(op == promotedFilterOp.getOperation());
+}
+
+/// Returns success if `op` is a contraction appropriate for promotion.
+static LogicalResult contractOpFilter(Operation *op) {
+ auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
+ if (!linalgOp) return failure();
+ // Limit promotion to matmul and batch matmul, there may be generic
+ // ops with more batch dimensions we didn't distribute and therefore
+ // cannot find a higher bound.
+ return success(linalg::isaContractionOpInterface(op) &&
+ linalgOp.getNumParallelLoops() >= 2 &&
+ linalgOp.getNumParallelLoops() <= 3);
+}
+
template <typename T>
using LinalgPromotionPattern =
mlir::iree_compiler::IREE::LinalgExt::LinalgPromotionPattern<T>;
-static void populatePromotionPatterns(
- MLIRContext *context, RewritePatternSet &patterns,
- GPUPromoteSharedMemPattern promoteSharedMemPattern,
- ArrayRef<int64_t> operandsToPromote) {
+static void populatePromotionPatterns(MLIRContext *context,
+ RewritePatternSet &patterns,
+ PromotionFilterFunction filterFunction,
+ ArrayRef<int64_t> operandsToPromote) {
patterns.insert<LinalgPromotionPattern<linalg::MatmulOp>,
LinalgPromotionPattern<linalg::BatchMatmulOp>,
LinalgPromotionPattern<linalg::GenericOp>>(
@@ -203,20 +246,7 @@
{StringAttr::get(context, getWorkgroupKTiledMarker())},
StringAttr::get(context, getWorkgroupMemoryMarker()))
.setMatchByDefault()
- .addFilter([promoteSharedMemPattern](Operation *op) {
- auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
- if (!linalgOp) return failure();
- if (promoteSharedMemPattern ==
- GPUPromoteSharedMemPattern::TransposeOpPattern) {
- return success(linalgOp.getNumParallelLoops() == 2);
- }
- // Limit promotion to matmul and batch matmul, there may be generic
- // ops with more batch dimensions we didn't distribute and therefore
- // cannot find a higher bound.
- return success(linalg::isaContractionOpInterface(op) &&
- linalgOp.getNumParallelLoops() >= 2 &&
- linalgOp.getNumParallelLoops() <= 3);
- }));
+ .addFilter(filterFunction));
}
/// Transformation to propagate FillOp + CopyOp to temp allocation.
@@ -275,8 +305,8 @@
// allocation. This needs to be done before reduction tiling.
if (llvmgpuUseMMASync) {
RewritePatternSet promotionPatterns(&getContext());
- populatePromotionPatterns(context, promotionPatterns,
- promoteSharedMemPattern, {2});
+ populatePromotionPatterns(context, promotionPatterns, contractOpFilter,
+ {2});
if (failed(applyPatternsAndFoldGreedily(funcOp,
std::move(promotionPatterns)))) {
return signalPassFailure();
@@ -309,12 +339,29 @@
switch (promoteSharedMemPattern) {
case GPUPromoteSharedMemPattern::ContractionOpPattern:
populatePromotionPatterns(context, promotionPatterns,
- promoteSharedMemPattern, {0, 1});
+ contractOpFilter, {0, 1});
break;
case GPUPromoteSharedMemPattern::TransposeOpPattern:
- populatePromotionPatterns(context, promotionPatterns,
- promoteSharedMemPattern, {0});
+ funcOp.walk(
+ [&context, &promotionPatterns](linalg::GenericOp linalgOp) {
+              // Promotion patterns accept a fixed list of operands to promote
+              // before determining which op is being promoted. To support
+              // multiple linalg generic ops with different promoted operands,
+              // we walk each linalg generic op to determine which operands to
+              // promote, then create a filter that only applies to its
+              // configuration.
+ SmallVector<int64_t> operandsToPromote =
+ getTransposedOperands(linalgOp);
+ if (!operandsToPromote.empty()) {
+ populatePromotionPatterns(
+ context, promotionPatterns,
+ [linalgOp](Operation *op) -> LogicalResult {
+ return transposeFilter(op, linalgOp);
+ },
+ operandsToPromote);
+ }
+ });
break;
}
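
The walk above builds one promotion pattern per `linalg.generic`, each guarded by a filter that closes over that exact op, so several generics with different transposed operands can coexist in a single pattern set. A minimal standalone sketch of that closure idea, assuming only core MLIR headers (the helper name here is hypothetical, not IREE API):

```cpp
#include <functional>

#include "mlir/IR/Operation.h"
#include "mlir/Support/LogicalResult.h"

using namespace mlir;

// Builds a filter that succeeds only for one specific operation, so a
// pattern configured with a per-op operand list cannot fire elsewhere.
static std::function<LogicalResult(Operation *)> makeExactOpFilter(
    Operation *anchor) {
  // Capture the anchor by value; this mirrors the `[linalgOp](Operation *op)`
  // lambda handed to populatePromotionPatterns above.
  return [anchor](Operation *op) { return success(op == anchor); };
}
```

Each such filter then travels into `populatePromotionPatterns` together with its operand list, which is exactly what the new `PromotionFilterFunction` parameter enables.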
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUTileTensor.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUTileTensor.cpp
index bda6d07..5ab1571 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUTileTensor.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUTileTensor.cpp
@@ -7,7 +7,7 @@
#include <numeric>
#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
-#include "iree-dialects/Dialect/LinalgExt/Passes/Transforms.h"
+#include "iree-dialects/Dialect/LinalgExt/Transforms/Transforms.h"
#include "iree/compiler/Codegen/Dialect/LoweringConfig.h"
#include "iree/compiler/Codegen/PassDetail.h"
#include "iree/compiler/Codegen/Passes.h"
@@ -19,6 +19,8 @@
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+using mlir::iree_compiler::IREE::LinalgExt::TilingPatterns;
+
#define DEBUG_TYPE "iree-llvmgpu-tile-tensor"
namespace mlir {
@@ -53,9 +55,8 @@
StringAttr::get(context, getWorkgroupMemoryMarker())},
StringAttr::get(context, getWorkgroupKTiledMarker()));
filter.setMatchByDefault();
- linalg::TilingPatterns<linalg::MatmulOp, linalg::BatchMatmulOp,
- linalg::GenericOp>::insert(patterns, tilingOptions,
- filter);
+ TilingPatterns<linalg::MatmulOp, linalg::BatchMatmulOp,
+ linalg::GenericOp>::insert(patterns, tilingOptions, filter);
}
LogicalResult tileReduction(func::FuncOp funcOp) {
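
Both this file and the pass above now take `TilingPatterns` from LinalgExt rather than the upstream `linalg` namespace. The helper is a small variadic template: it adds one tiling pattern per listed op type, all sharing the same options and filter. A self-contained sketch of the recursion shape, with stand-in (`Fake*`) names rather than the real LinalgExt classes:

```cpp
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

struct FakeOptions {};  // stand-in for tiling options plus filter

template <typename OpTy>
struct FakeTilingPattern : public OpRewritePattern<OpTy> {
  FakeTilingPattern(MLIRContext *ctx, FakeOptions)
      : OpRewritePattern<OpTy>(ctx) {}
  LogicalResult matchAndRewrite(OpTy, PatternRewriter &) const override {
    return failure();  // a real pattern would tile the op here
  }
};

// Base case: an empty op list inserts nothing.
template <typename... OpTys>
struct FakeTilingPatterns {
  static void insert(RewritePatternSet &, const FakeOptions &) {}
};

// Recursive case: insert a pattern for the head type, recurse on the tail.
template <typename OpTy, typename... Rest>
struct FakeTilingPatterns<OpTy, Rest...> {
  static void insert(RewritePatternSet &patterns, const FakeOptions &opts) {
    patterns.add<FakeTilingPattern<OpTy>>(patterns.getContext(), opts);
    FakeTilingPatterns<Rest...>::insert(patterns, opts);
  }
};
```

So `TilingPatterns<linalg::MatmulOp, linalg::GenericOp>::insert(patterns, options, filter)` is shorthand for two `patterns.add` calls with shared configuration.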
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp
index 41ab038..2f95f6a 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp
@@ -179,9 +179,7 @@
}
// Step 5. syncthreads.
- if (syncAfterDistribute) {
- rewriter.create<gpu::BarrierOp>(loc);
- }
+ if (syncAfterDistribute) rewriter.create<gpu::BarrierOp>(loc);
// Step 6. Erase old op.
rewriter.eraseOp(foreachThreadOp);
@@ -489,6 +487,23 @@
return laneVal;
}
+/// Return a value yielded by `warpOp` that satisfies the filter lambda
+/// condition and is not dead.
+static OpOperand *getWarpResult(vector::WarpExecuteOnLane0Op warpOp,
+ function_ref<bool(Operation *)> fn) {
+ auto yield = cast<vector::YieldOp>(
+ warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+ for (OpOperand &yieldOperand : yield->getOpOperands()) {
+ Value yieldValues = yieldOperand.get();
+ Operation *definedOp = yieldValues.getDefiningOp();
+ if (definedOp && fn(definedOp)) {
+ if (!warpOp.getResult(yieldOperand.getOperandNumber()).use_empty())
+ return &yieldOperand;
+ }
+ }
+ return {};
+}
+
namespace {
/// Pattern to convert InsertElement to broadcast, this is a workaround until
@@ -507,60 +522,113 @@
}
};
+/// Sink out a load op that feeds into a warp op yield.
+/// ```
+/// %0 = vector.warp_execute_on_lane_0(%arg0) -> (f32) {
+/// ...
+/// %2 = memref.load %src[%c0] : memref<1024xf32>
+/// vector.yield %2 : f32
+/// }
+/// ```
+/// To
+/// ```
+/// %dead = vector.warp_execute_on_lane_0(%arg0) -> (f32) {
+/// ...
+/// %2 = memref.load %src[%c0] : memref<1024xf32>
+/// vector.yield %2 : f32
+/// }
+/// gpu.barrier
+/// %0 = memref.load %src[%c0] : memref<1024xf32>
+/// ```
+struct WarpOpLoad : public OpRewritePattern<vector::WarpExecuteOnLane0Op> {
+ using OpRewritePattern<vector::WarpExecuteOnLane0Op>::OpRewritePattern;
+ LogicalResult matchAndRewrite(vector::WarpExecuteOnLane0Op warpOp,
+ PatternRewriter &rewriter) const override {
+ OpOperand *operand = getWarpResult(
+ warpOp, [](Operation *op) { return isa<memref::LoadOp>(op); });
+ if (!operand) return failure();
+ auto load = operand->get().getDefiningOp<memref::LoadOp>();
+ unsigned operandIndex = operand->getOperandNumber();
+ Value distributedVal = warpOp.getResult(operandIndex);
+
+ SmallVector<Value, 4> indices(load.getIndices().begin(),
+ load.getIndices().end());
+ if (!indices.empty()) return failure();
+
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPointAfter(warpOp);
+ // TODO: generalize this.
+ // options.warpSyncronizationFn currently must take a WarpExecuteOnLane0Op
+ // which we don't have here.
+ rewriter.create<gpu::BarrierOp>(load.getLoc());
+ Value newRead = rewriter.create<memref::LoadOp>(
+ load.getLoc(), distributedVal.getType(), load.getMemref(), indices);
+
+ // The result type of WarpExecuteOnLane0Op may or may not match the yielded
+ // type depending on whether the op has "broadcast" behavior (see the doc
+ // of WarpExecuteOnLane0Op).
+ for (OpOperand &use : distributedVal.getUses()) {
+ rewriter.startRootUpdate(use.getOwner());
+ Value replacement = newRead;
+ if (use.get().getType() != newRead.getType()) {
+ replacement = rewriter.create<vector::BroadcastOp>(
+ load.getLoc(), use.get().getType(), newRead);
+ }
+ use.getOwner()->setOperand(use.getOperandNumber(), replacement);
+ rewriter.finalizeRootUpdate(use.getOwner());
+ }
+ return success();
+ }
+};
} // namespace
-static LogicalResult applyMultiReductionLoweringPatterns(Operation *target) {
+static void populateMultiReductionLoweringPatterns(Operation *target,
+ RewritePatternSet &patterns,
+ PatternBenefit benefit) {
assert(target->hasTrait<OpTrait::IsIsolatedFromAbove>());
- MLIRContext *ctx = target->getContext();
- RewritePatternSet patterns(ctx);
-
vector::populateVectorMultiReductionLoweringPatterns(
- patterns, vector::VectorMultiReductionLowering::InnerReduction);
- patterns.add<InsertElementToBroadcast>(ctx);
- return applyPatternsAndFoldGreedily(target, std::move(patterns));
+ patterns, vector::VectorMultiReductionLowering::InnerReduction, benefit);
+ patterns.add<InsertElementToBroadcast>(target->getContext(), benefit);
}
-static LogicalResult applyVectorTransferWriteDistribution(Operation *target) {
- assert(target->hasTrait<OpTrait::IsIsolatedFromAbove>());
-
- auto distributionFn = [](vector::TransferWriteOp writeOp) {
- // Create a map (d0, d1) -> (d1) to distribute along the inner
- // dimension. Once we support n-d distribution we can add more
- // complex cases.
- int64_t vecRank = writeOp.getVectorType().getRank();
- OpBuilder builder(writeOp.getContext());
- auto map =
- AffineMap::get(vecRank, 0, builder.getAffineDimExpr(vecRank - 1));
- return map;
- };
- MLIRContext *ctx = target->getContext();
- RewritePatternSet patterns(ctx);
- vector::populateDistributeTransferWriteOpPatterns(patterns, distributionFn);
- return applyPatternsAndFoldGreedily(target, std::move(patterns));
+static AffineMap simpleDistributionFunction(vector::TransferWriteOp writeOp) {
+ // Create a map (d0, d1) -> (d1) to distribute along the inner
+ // dimension. Once we support n-d distribution we can add more
+ // complex cases.
+ int64_t vecRank = writeOp.getVectorType().getRank();
+ OpBuilder builder(writeOp.getContext());
+ auto map = AffineMap::get(vecRank, 0, builder.getAffineDimExpr(vecRank - 1));
+ return map;
}
-static LogicalResult applyPropagateVectorDistribution(Operation *target) {
+static void populateVectorTransferWriteDistribution(Operation *target,
+ RewritePatternSet &patterns,
+ PatternBenefit benefit) {
assert(target->hasTrait<OpTrait::IsIsolatedFromAbove>());
-
- MLIRContext *ctx = target->getContext();
- RewritePatternSet patterns(ctx);
- vector::populatePropagateWarpVectorDistributionPatterns(patterns);
- vector::populateDistributeReduction(patterns, warpReduction);
- return applyPatternsAndFoldGreedily(target, std::move(patterns));
+ vector::populateDistributeTransferWriteOpPatterns(
+ patterns, simpleDistributionFunction, benefit);
}
-static LogicalResult applyWarpExecuteOnLane0ToScf(Operation *target) {
+static void populatePropagateVectorDistribution(Operation *target,
+ RewritePatternSet &patterns,
+ PatternBenefit benefit) {
assert(target->hasTrait<OpTrait::IsIsolatedFromAbove>());
+ vector::populatePropagateWarpVectorDistributionPatterns(patterns, benefit);
+ vector::populateDistributeReduction(patterns, warpReduction, benefit);
+ patterns.add<WarpOpLoad>(target->getContext(), benefit);
+}
- MLIRContext *ctx = target->getContext();
- RewritePatternSet patterns(ctx);
- vector::WarpExecuteOnLane0LoweringOptions options;
- options.warpAllocationFn = allocateGlobalSharedMemory;
- options.warpSyncronizationFn = [](Location loc, OpBuilder &builder,
- vector::WarpExecuteOnLane0Op warpOp) {};
- vector::populateWarpExecuteOnLane0OpToScfForPattern(patterns, options);
- return applyPatternsAndFoldGreedily(target, std::move(patterns));
+static void warpSyncronizationFn(Location loc, OpBuilder &builder,
+ vector::WarpExecuteOnLane0Op warpOp) {
+ builder.create<gpu::BarrierOp>(loc);
+}
+
+static void populateWarpExecuteOnLane0ToScf(
+ Operation *target, RewritePatternSet &patterns,
+ vector::WarpExecuteOnLane0LoweringOptions options, PatternBenefit benefit) {
+ assert(target->hasTrait<OpTrait::IsIsolatedFromAbove>());
+ vector::populateWarpExecuteOnLane0OpToScfForPattern(patterns, options,
+ benefit);
}
DiagnosedSilenceableFailure
@@ -577,22 +645,24 @@
// TODO: Hook up into the ApplyPatternOp in CommonExtensions.cpp to
// automatically get listening capabilities.
+ MLIRContext *ctx = target->getContext();
+ RewritePatternSet patterns(ctx);
// MultiReduction lowering is necessary until we have explicit support for
// distributing that op.
- if (failed(applyMultiReductionLoweringPatterns(target))) {
- target->emitOpError("multi-reduction lowering patterns failed to apply");
+ populateMultiReductionLoweringPatterns(target, patterns, /*benefit=*/3);
+ populateVectorTransferWriteDistribution(target, patterns, /*benefit=*/2);
+ populatePropagateVectorDistribution(target, patterns, /*benefit=*/1);
+ if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns)))) {
+ target->emitOpError("warp distribution patterns failed to apply");
return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
}
- if (failed(applyVectorTransferWriteDistribution(target))) {
- target->emitOpError("transfer write distribution patterns failed to apply");
- return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
- }
- if (failed(applyPropagateVectorDistribution(target))) {
- target->emitOpError(
- "propagate vector distribution patterns failed to apply");
- return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
- }
- if (failed(applyWarpExecuteOnLane0ToScf(target))) {
+
+ RewritePatternSet endPatterns(ctx);
+ vector::WarpExecuteOnLane0LoweringOptions options;
+ options.warpAllocationFn = allocateGlobalSharedMemory;
+ options.warpSyncronizationFn = warpSyncronizationFn;
+ populateWarpExecuteOnLane0ToScf(target, endPatterns, options, /*benefit=*/0);
+ if (failed(applyPatternsAndFoldGreedily(target, std::move(endPatterns)))) {
target->emitOpError(
"warp execute on lane 0 to scf patterns failed to apply");
return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
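
The restructuring above converts three separate greedy rewrites into `populate*` functions feeding one `RewritePatternSet`, with `PatternBenefit` values (3, 2, 1) biasing which patterns the driver tries first, plus a trailing set only for the final warp-execute-to-scf step. A minimal sketch of the combined-driver idiom, with hypothetical `populateStage*` declarations standing in for the real populate helpers:

```cpp
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Hypothetical stand-ins for the three populate* helpers defined above.
void populateStageA(RewritePatternSet &patterns, PatternBenefit benefit);
void populateStageB(RewritePatternSet &patterns, PatternBenefit benefit);
void populateStageC(RewritePatternSet &patterns, PatternBenefit benefit);

static LogicalResult runCombinedDistribution(Operation *target) {
  RewritePatternSet patterns(target->getContext());
  // Higher benefit => tried first when several patterns match the same op.
  populateStageA(patterns, /*benefit=*/3);
  populateStageB(patterns, /*benefit=*/2);
  populateStageC(patterns, /*benefit=*/1);
  return applyPatternsAndFoldGreedily(target, std::move(patterns));
}
```

Running one driver lets the stages interleave toward a common fixpoint instead of each phase finishing in isolation.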
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/nvvm_pipeline_test.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/nvvm_pipeline_test.mlir
index 1992d69..7d02e83 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/nvvm_pipeline_test.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/nvvm_pipeline_test.mlir
@@ -522,8 +522,10 @@
// MMASYNC-COUNT-4: nvvm.ldmatrix{{.*}} : (!llvm.ptr<f32, 3>) -> !llvm.struct<(i32, i32)
// MMASYNC-COUNT-8: nvvm.mma.sync
// MMASYNC-COUNT-4: llvm.store {{.*}} : !llvm.ptr<vector<2xf32>, 3>
-// MMASYNC-COUNT-2: llvm.load {{.*}} : !llvm.ptr<vector<4xf32>, 3>
-// MMASYNC-COUNT-2: llvm.store {{.*}} : !llvm.ptr<vector<4xf32>>
+// MMASYNC: llvm.load {{.*}} : !llvm.ptr<vector<4xf32>, 3>
+// MMASYNC: llvm.store {{.*}} : !llvm.ptr<vector<4xf32>>
+// MMASYNC: llvm.load {{.*}} : !llvm.ptr<vector<4xf32>, 3>
+// MMASYNC: llvm.store {{.*}} : !llvm.ptr<vector<4xf32>>
// C matrix promotion prevents efficient fusion with the matmul consumer; this
// needs to be fixed to get good performance.
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_bufferize.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_bufferize.mlir
index ad1370a..2e1078a 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_bufferize.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_bufferize.mlir
@@ -26,11 +26,9 @@
return %6: tensor<250x1020xf32>
}
}
- transform.with_pdl_patterns {
- ^bb0(%arg0: !pdl.operation):
- transform.structured.canonicalized_sequence %arg0 failures(propagate) {
- ^bb1(%variant_op: !pdl.operation):
- transform.iree.bufferize { target_gpu } %variant_op
- }
+
+ transform.structured.canonicalized_sequence failures(propagate) {
+ ^bb1(%variant_op: !pdl.operation):
+ transform.iree.bufferize { target_gpu } %variant_op
}
}
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_codegen_bufferize_spec.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_codegen_bufferize_spec.mlir
index 24c98ce..ea6606b 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_codegen_bufferize_spec.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_codegen_bufferize_spec.mlir
@@ -1,7 +1,4 @@
-transform.with_pdl_patterns {
-^bb0(%arg0: !pdl.operation):
- transform.structured.canonicalized_sequence %arg0 failures(propagate) {
- ^bb1(%variant_op: !pdl.operation):
- transform.iree.bufferize %variant_op
- }
+transform.structured.canonicalized_sequence failures(propagate) {
+^bb1(%variant_op: !pdl.operation):
+ transform.iree.bufferize %variant_op
}
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_codegen_foreach_to_gpu_spec.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_codegen_foreach_to_gpu_spec.mlir
index c7f5571..93a56af 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_codegen_foreach_to_gpu_spec.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_codegen_foreach_to_gpu_spec.mlir
@@ -1,18 +1,15 @@
-transform.with_pdl_patterns {
-^bb0(%arg0: !pdl.operation):
- transform.structured.canonicalized_sequence %arg0 failures(propagate) {
- ^bb1(%variant_op: !pdl.operation):
- %0 = transform.structured.match ops{["linalg.fill"]} in %variant_op
- %foreach_thread, %tiled_fill = transform.structured.tile_to_foreach_thread_op %0 num_threads [5, 1] (mapped to dims [1, 0, 2])
+transform.structured.canonicalized_sequence failures(propagate) {
+^bb1(%variant_op: !pdl.operation):
+ %0 = transform.structured.match ops{["linalg.fill"]} in %variant_op
+ %foreach_thread, %tiled_fill = transform.structured.tile_to_foreach_thread_op %0 num_threads [5, 1] (mapped to dims [1, 0, 2])
- %1 = transform.structured.match ops{["linalg.matmul"]} in %variant_op
- %foreach_thread_2, %tiled_matmul = transform.structured.tile_to_foreach_thread_op %1 num_threads [7, 9]
+ %1 = transform.structured.match ops{["linalg.matmul"]} in %variant_op
+ %foreach_thread_2, %tiled_matmul = transform.structured.tile_to_foreach_thread_op %1 num_threads [7, 9]
- %variant_op_2 = transform.iree.bufferize %variant_op
+ %variant_op_2 = transform.iree.bufferize %variant_op
- // Get the function to which to apply to.
- %2 = transform.structured.match ops{["linalg.matmul"]} in %variant_op_2
- %func = transform.get_closest_isolated_parent %2
- transform.iree.foreach_thread_to_gpu_and_translation_info %func { workgroup_size = [10, 11]}
- }
+  // Get the function to apply the transformation to.
+ %2 = transform.structured.match ops{["linalg.matmul"]} in %variant_op_2
+ %func = transform.get_closest_isolated_parent %2
+ transform.iree.foreach_thread_to_gpu_and_translation_info %func { workgroup_size = [10, 11]}
}
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_codegen_vector_distribution_spec.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_codegen_vector_distribution_spec.mlir
index 3a0135e..48ef112 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_codegen_vector_distribution_spec.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_codegen_vector_distribution_spec.mlir
@@ -1,10 +1,7 @@
-transform.with_pdl_patterns {
-^bb0(%arg0: !pdl.operation):
- transform.structured.canonicalized_sequence %arg0 failures(propagate) {
- ^bb1(%arg1: !pdl.operation):
- %if_op = transform.structured.match ops{["scf.if"]} in %arg1
- %warp = transform.iree.vector.to_warp_execute_on_lane_0 %if_op { warp_size = 32 }
- %isolated = transform.get_closest_isolated_parent %warp
- transform.iree.vector.warp_distribute %isolated
- }
+transform.structured.canonicalized_sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+ %if_op = transform.structured.match ops{["scf.if"]} in %arg1
+ %warp = transform.iree.vector.to_warp_execute_on_lane_0 %if_op { warp_size = 32 }
+ %isolated = transform.get_closest_isolated_parent %warp
+ transform.iree.vector.warp_distribute %isolated
}
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_codegen_vector_warp_execute_on_lane_0_spec.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_codegen_vector_warp_execute_on_lane_0_spec.mlir
index e4c1a20..e24e76b 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_codegen_vector_warp_execute_on_lane_0_spec.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_codegen_vector_warp_execute_on_lane_0_spec.mlir
@@ -1,8 +1,5 @@
-transform.with_pdl_patterns {
-^bb0(%arg0: !pdl.operation):
- transform.structured.canonicalized_sequence %arg0 failures(propagate) {
- ^bb1(%arg1: !pdl.operation):
- %if_op = transform.structured.match ops{["scf.if"]} in %arg1
- transform.iree.vector.to_warp_execute_on_lane_0 %if_op { warp_size = 32 }
- }
+transform.structured.canonicalized_sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+ %if_op = transform.structured.match ops{["scf.if"]} in %arg1
+ transform.iree.vector.to_warp_execute_on_lane_0 %if_op { warp_size = 32 }
}
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transpose_pipeline_test.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transpose_pipeline_test.mlir
index 5967f8e..b418007 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transpose_pipeline_test.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transpose_pipeline_test.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt --split-input-file --pass-pipeline='hal.executable(hal.executable.variant(iree-llvmgpu-lower-executable-target-pass))' %s | FileCheck %s
+// RUN: iree-opt --split-input-file --pass-pipeline='hal.executable(hal.executable.variant(iree-llvmgpu-lower-executable-target-pass))' %s --fold-memref-alias-ops -canonicalize -cse | FileCheck %s
#device_target_cuda = #hal.device.target<"cuda", {executable_targets = [#hal.executable.target<"cuda", "cuda-nvptx-fb", {target_arch = "sm_80"}>], legacy_sync}>
#executable_target_cuda_nvptx_fb = #hal.executable.target<"cuda", "cuda-nvptx-fb", {target_arch = "sm_80"}>
@@ -34,18 +34,68 @@
// CHECK: hal.executable.variant public @cuda
// CHECK-DAG: %[[C0:.+]] = arith.constant 0
// CHECK-DAG: %[[CST:.+]] = arith.constant 0
-// CHECK: %[[D3:.+]] = memref.alloc() : memref<32x33xf32, 3>
-// CHECK: %[[D4:.+]] = memref.subview %[[D3]][0, 0] [32, 32] [1, 1] : memref<32x33xf32, 3> to memref<32x32xf32, #{{.*}}, 3>
-// CHECK: %[[D9:.+]] = memref.subview %[[D6:.+]][%{{.*}}, %{{.*}}] [32, 32] [1, 1] : memref<4096x4096xf32> to memref<32x32xf32, #{{.*}}>
-// CHECK: %[[D10:.+]] = memref.subview %[[D5:.+]][%{{.*}}, %{{.*}}] [32, 32] [1, 1] : memref<4096x4096xf32> to memref<32x32xf32, #{{.*}}>
+// CHECK-DAG: %[[IN:.+]] = hal.interface.binding.subspan set(0) binding(0)
+// CHECK-DAG: %[[OUT:.+]] = hal.interface.binding.subspan set(0) binding(1)
+// CHECK-DAG: %[[ALLOC:.+]] = memref.alloc() : memref<32x33xf32, 3>
// CHECK: gpu.barrier
-// CHECK: %[[D13:.+]] = vector.transfer_read %[[D10]][%{{.*}}, %{{.*}}], %[[CST]] {in_bounds = [true]} : memref<32x32xf32, #{{.*}}>, vector<4xf32>
-// CHECK: vector.transfer_write %[[D13]], %[[D4]][%{{.*}}, %{{.*}}] {in_bounds = [true]} : vector<4xf32>, memref<32x32xf32, #{{.*}}, 3>
+// CHECK: %[[R0:.+]] = vector.transfer_read %[[IN]][%{{.*}}, %{{.*}}], %[[CST]] {in_bounds = [true]} : memref<4096x4096xf32>, vector<4xf32>
+// CHECK: vector.transfer_write %[[R0]], %[[ALLOC]][%{{.*}}, %{{.*}}] {in_bounds = [true]} : vector<4xf32>, memref<32x33xf32, 3>
// CHECK: gpu.barrier
-// CHECK: %[[D15:.+]] = memref.subview %[[D4]][%{{.*}}, %{{.*}}] [4, 1] [1, 1] : memref<32x32xf32, #{{.*}}, 3> to memref<4x1xf32, #{{.*}}, 3>
-// CHECK: %[[D16:.+]] = memref.subview %[[D9]][%{{.*}}, %{{.*}}] [1, 4] [1, 1] : memref<32x32xf32, #{{.*}}> to memref<1x4xf32, #{{.*}}>
-// CHECK: %[[D17:.+]] = vector.transfer_read %[[D15]][%{{.*}}, %{{.*}}], %[[CST]] {in_bounds = [true, true]} : memref<4x1xf32, #{{.*}}, 3>, vector<4x1xf32>
-// CHECK: %[[D18:.+]] = vector.shape_cast %[[D17]] : vector<4x1xf32> to vector<1x4xf32>
-// CHECK: %[[D19:.+]] = vector.extract %[[D18]][0] : vector<1x4xf32>
-// CHECK: vector.transfer_write %[[D19]], %[[D16]][%{{.*}}, %{{.*}}] {in_bounds = [true]} : vector<4xf32>, memref<1x4xf32, #{{.*}}>
+// CHECK: %[[R1:.+]] = vector.transfer_read %[[ALLOC]][%{{.*}}, %{{.*}}], %[[CST]] {in_bounds = [true, true]} : memref<32x33xf32, 3>, vector<4x1xf32>
+// CHECK: %[[R2:.+]] = vector.shape_cast %[[R1]] : vector<4x1xf32> to vector<1x4xf32>
+// CHECK: %[[R3:.+]] = vector.extract %[[R2]][0] : vector<1x4xf32>
+// CHECK: vector.transfer_write %[[R3]], %[[OUT]][%{{.*}}, %{{.*}}] {in_bounds = [true]} : vector<4xf32>, memref<4096x4096xf32>
+
// -----
+
+#device_target_cuda = #hal.device.target<"cuda", {executable_targets = [#hal.executable.target<"cuda", "cuda-nvptx-fb", {target_arch = "sm_80"}>], legacy_sync}>
+#executable_target_cuda_nvptx_fb = #hal.executable.target<"cuda", "cuda-nvptx-fb", {target_arch = "sm_80"}>
+#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer>]>]>
+module attributes {hal.device.targets = [#device_target_cuda]} {
+ hal.executable @transpose_single_operand_dispatch_0_generic_768x2048 {
+ hal.executable.variant public @cuda_nvptx_fb, target = #executable_target_cuda_nvptx_fb {
+ hal.executable.export public @transpose_single_operand_dispatch_0_generic_768x2048 ordinal(0) layout(#pipeline_layout) {
+ ^bb0(%arg0: !hal.device, %arg1: index, %arg2: index):
+ %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2
+ hal.return %x, %y, %z : index, index, index
+ }
+ builtin.module {
+ func.func @transpose_single_operand_dispatch_0_generic_768x2048() {
+ %c0 = arith.constant 0 : index
+ %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<readonly:2048x768xf32>
+ %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<readonly:768x2048xf32>
+ %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<writeonly:768x2048xf32>
+ %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [2048, 768], strides = [1, 1] : !flow.dispatch.tensor<readonly:2048x768xf32> -> tensor<2048x768xf32>
+ %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [768, 2048], strides = [1, 1] : !flow.dispatch.tensor<readonly:768x2048xf32> -> tensor<768x2048xf32>
+ %5 = linalg.init_tensor [768, 2048] : tensor<768x2048xf32>
+ %6 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%3, %4 : tensor<2048x768xf32>, tensor<768x2048xf32>) outs(%5 : tensor<768x2048xf32>) {
+ ^bb0(%arg0: f32, %arg1: f32, %arg2: f32):
+ %7 = arith.addf %arg0, %arg1 : f32
+ linalg.yield %7 : f32
+ } -> tensor<768x2048xf32>
+ flow.dispatch.tensor.store %6, %2, offsets = [0, 0], sizes = [768, 2048], strides = [1, 1] : tensor<768x2048xf32> -> !flow.dispatch.tensor<writeonly:768x2048xf32>
+ return
+ }
+ }
+ }
+ }
+}
+
+// CHECK-LABEL: hal.executable public @transpose_single_operand_dispatch_0_generic_768x2048
+// CHECK: hal.executable.variant public @cuda
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0
+// CHECK-DAG: %[[CST:.+]] = arith.constant 0
+// CHECK-DAG: %[[ALLOC:.+]] = memref.alloc() : memref<32x33xf32, 3>
+// CHECK-DAG: %[[IN0:.+]] = hal.interface.binding.subspan set(0) binding(0)
+// CHECK-DAG: %[[IN1:.+]] = hal.interface.binding.subspan set(0) binding(1)
+// CHECK-DAG: %[[OUT:.+]] = hal.interface.binding.subspan set(0) binding(2)
+// CHECK: gpu.barrier
+// CHECK: %[[R0:.+]] = vector.transfer_read %[[IN0]][%{{.*}}, %{{.*}}], %[[CST]] {in_bounds = [true]} : memref<2048x768xf32>, vector<4xf32>
+// CHECK: vector.transfer_write %[[R0]], %[[ALLOC]][%{{.*}}, %{{.*}}] {in_bounds = [true]} : vector<4xf32>, memref<32x33xf32, 3>
+// CHECK: gpu.barrier
+// CHECK: %[[R1:.+]] = vector.transfer_read %[[ALLOC]][%{{.*}}, %{{.*}}], %[[CST]] {in_bounds = [true, true]} : memref<32x33xf32, 3>, vector<4x1xf32>
+// CHECK: %[[R2:.+]] = vector.shape_cast %[[R1]] : vector<4x1xf32> to vector<1x4xf32>
+// CHECK: %[[R3:.+]] = vector.transfer_read %[[IN1]][%{{.*}}, %{{.*}}], %[[CST]] {in_bounds = [true]} : memref<768x2048xf32>, vector<4xf32>
+// CHECK: %[[R4:.+]] = vector.extract %[[R2]][0] : vector<1x4xf32>
+// CHECK: %[[R5:.+]] = arith.addf %[[R4]], %[[R3]] : vector<4xf32>
+// CHECK: vector.transfer_write %[[R5]], %[[OUT]][%{{.*}}, %{{.*}}] {in_bounds = [true]} : vector<4xf32>, memref<768x2048xf32>
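
The `memref<32x33xf32, 3>` allocations in these checks are the classic padded shared-memory tile: a row stride of 33 instead of 32 keeps column walks from landing in a single bank. A tiny plain-C++ illustration of the stride arithmetic, assuming 32 banks of 4-byte words (the usual NVIDIA layout):

```cpp
#include <cstdio>

int main() {
  const int kBanks = 32;
  // Walking down a column of a row-major tile touches element row*stride.
  // With stride 32 every row lands in the same bank; with stride 33 the
  // accesses spread across the banks.
  for (int row = 0; row < 4; ++row)
    std::printf("row %d: stride 32 -> bank %2d, stride 33 -> bank %2d\n",
                row, (row * 32) % kBanks, (row * 33) % kBanks);
  return 0;
}
```

That is why the transpose pipeline pads the inner dimension by one element rather than allocating a square 32x32 tile.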
diff --git a/compiler/src/iree/compiler/Codegen/Passes.cpp b/compiler/src/iree/compiler/Codegen/Passes.cpp
index bca39a5..f7eb1ea 100644
--- a/compiler/src/iree/compiler/Codegen/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/Passes.cpp
@@ -45,10 +45,9 @@
static PassPipelineRegistration<> LinalgSPIRVPipeline(
"iree-codegen-linalg-to-spirv-pipeline",
- "Runs the progressive lowering pipeline from XLA HLO to Linalg to "
- "SPIR-V",
+ "Runs the progressive lowering pipeline from linalg to SPIR-V",
[](OpPassManager &passManager) {
- buildSPIRVCodegenPassPipeline(passManager);
+ buildSPIRVCodegenPassPipeline(passManager, /*enableFastMath=*/false);
});
}
diff --git a/compiler/src/iree/compiler/Codegen/Passes.h b/compiler/src/iree/compiler/Codegen/Passes.h
index d4969cc..689cc34 100644
--- a/compiler/src/iree/compiler/Codegen/Passes.h
+++ b/compiler/src/iree/compiler/Codegen/Passes.h
@@ -437,7 +437,8 @@
/// This pass converts remaining interface ops into SPIR-V global variables,
/// GPU processor ID ops into SPIR-V global variables, loop/standard ops into
/// corresponding SPIR-V ops.
-std::unique_ptr<OperationPass<ModuleOp>> createConvertToSPIRVPass();
+std::unique_ptr<OperationPass<ModuleOp>> createConvertToSPIRVPass(
+ bool enableFastMath = false);
/// Creates a pass to fold processor ID uses where possible.
std::unique_ptr<OperationPass<func::FuncOp>>
@@ -490,14 +491,10 @@
// SPIRV Codegen Pass Pipelines.
//----------------------------------------------------------------------------//
-/// Populates passes needed to lower a XLA HLO op to SPIR-V dialect via the
-/// structured ops path. The pass manager `pm` in here operate on the module
-/// within the IREE::HAL::ExecutableOp. The `workGroupSize` can be used to
-/// control the work group size used in the code generation and is intended for
-/// testing purposes only. The pass pipeline will set an appropriate workgroup
-/// size.
-/// TODO: Are both of these needed and does this one still work on HLO?
-void buildSPIRVCodegenPassPipeline(OpPassManager &pm);
+/// Populates passes needed to lower linalg/arith/math ops to SPIR-V ops via the
+/// structured ops path. The pass manager `pm` here operates on the module
+/// within the IREE::HAL::ExecutableOp.
+void buildSPIRVCodegenPassPipeline(OpPassManager &pm, bool enableFastMath);
//------------------------------------------------------------------------------
// VMVX passes
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/ConvertToSPIRVPass.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/ConvertToSPIRVPass.cpp
index a59de84..3ea7be5 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/ConvertToSPIRVPass.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/ConvertToSPIRVPass.cpp
@@ -288,15 +288,19 @@
/// This pass converts remaining interface ops into SPIR-V global variables,
/// GPU processor ID ops into SPIR-V global variables, loop/standard ops into
/// corresponding SPIR-V ops.
-struct ConvertToSPIRVPass : public ConvertToSPIRVBase<ConvertToSPIRVPass> {
+class ConvertToSPIRVPass : public ConvertToSPIRVBase<ConvertToSPIRVPass> {
+ public:
void getDependentDialects(DialectRegistry ®istry) const override {
registry.insert<spirv::SPIRVDialect>();
}
- ConvertToSPIRVPass() {}
- ConvertToSPIRVPass(const ConvertToSPIRVPass &pass) {}
+ explicit ConvertToSPIRVPass(bool enableFastMath)
+ : enableFastMath(enableFastMath) {}
void runOnOperation() override;
+
+ private:
+ bool enableFastMath;
};
} // namespace
@@ -327,7 +331,10 @@
spirv::TargetEnvAttr targetAttr = getSPIRVTargetEnvAttr(moduleOp);
moduleOp->setAttr(spirv::getTargetEnvAttrName(), targetAttr);
- SPIRVTypeConverter typeConverter(targetAttr);
+
+ SPIRVConversionOptions options = {};
+ options.enableFastMathMode = this->enableFastMath;
+ SPIRVTypeConverter typeConverter(targetAttr, options);
RewritePatternSet patterns(&getContext());
ScfToSPIRVContext scfToSPIRVContext;
@@ -432,8 +439,9 @@
// Pass entry point and registration
//===----------------------------------------------------------------------===//
-std::unique_ptr<OperationPass<ModuleOp>> createConvertToSPIRVPass() {
- return std::make_unique<ConvertToSPIRVPass>();
+std::unique_ptr<OperationPass<ModuleOp>> createConvertToSPIRVPass(
+ bool enableFastMath) {
+ return std::make_unique<ConvertToSPIRVPass>(enableFastMath);
}
} // namespace iree_compiler
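
The fast-math plumbing follows a common pass-option pattern: the flag enters at the pipeline entry point, is forwarded through the pass factory into a constructor-initialized member, and is finally consulted where the type converter is built. A compressed, runnable sketch of the same flow (all names hypothetical):

```cpp
#include <cstdio>
#include <memory>

// Stand-ins for SPIRVConversionOptions / the converting pass.
struct ConversionOptions {
  bool enableFastMathMode = false;
};

class ConvertPass {
 public:
  explicit ConvertPass(bool enableFastMath) : enableFastMath(enableFastMath) {}
  void runOnOperation() {
    ConversionOptions options;
    options.enableFastMathMode = enableFastMath;  // consumed here
    std::printf("fast math: %d\n", options.enableFastMathMode);
  }

 private:
  bool enableFastMath;
};

// Factory mirrors the shape of createConvertToSPIRVPass(bool = false).
std::unique_ptr<ConvertPass> createConvertPass(bool enableFastMath = false) {
  return std::make_unique<ConvertPass>(enableFastMath);
}

int main() { createConvertPass(/*enableFastMath=*/true)->runOnOperation(); }
```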
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp
index dc551c6..a0cbf27 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp
@@ -211,7 +211,7 @@
int bIndex = -1, mIndex = -1, nIndex = -1, kIndex = -1;
int lastParallelDim = -1;
for (unsigned i = 0; i < op.getNumLoops(); ++i) {
- if (isReductionIterator(op.getIteratorTypes()[i])) {
+ if (linalg::isReductionIterator(op.getIteratorTypes()[i])) {
kIndex = i;
continue;
}
@@ -663,8 +663,10 @@
for (const auto &it : llvm::enumerate(linalgOp.getIteratorTypes())) {
auto i = it.index();
if (loopBounds[i] % 4 != 0) continue;
- if (isReductionIterator(it.value()) || workgroupTileSizes[i] == 0)
+ if (linalg::isReductionIterator(it.value()) ||
+ workgroupTileSizes[i] == 0) {
loopTileSizes[it.index()] = 4;
+ }
}
if (llvm::any_of(loopTileSizes, [](int64_t s) { return s != 0; })) {
tileSizes.push_back(loopTileSizes);
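
The loop above picks a loop tile size of 4 for dimensions whose bound is divisible by 4 and that are either reduction dims or received no workgroup tile. The same selection, lifted into a runnable standalone form with plain containers standing in for the Linalg queries:

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// Mirrors the selection in the hunk above: tile by 4 only when the bound
// divides evenly and the dim is a reduction or untiled at workgroup level.
std::vector<int64_t> pickLoopTileSizes(const std::vector<int64_t> &bounds,
                                       const std::vector<bool> &isReduction,
                                       const std::vector<int64_t> &wgTile) {
  std::vector<int64_t> tiles(bounds.size(), 0);
  for (size_t i = 0; i < bounds.size(); ++i) {
    if (bounds[i] % 4 != 0) continue;
    if (isReduction[i] || wgTile[i] == 0) tiles[i] = 4;
  }
  return tiles;
}

int main() {
  auto t = pickLoopTileSizes({128, 30, 64}, {false, false, true}, {32, 8, 0});
  for (int64_t v : t) std::printf("%lld ", static_cast<long long>(v));
  // prints: 0 0 4  (dim0 already workgroup-tiled, dim1 not divisible by 4,
  // dim2 is a reduction dim)
}
```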
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp
index bbf8d4a..7fbf0d1 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp
@@ -177,7 +177,7 @@
}
/// Adds passes to perform the final SPIR-V conversion.
-static void addSPIRVLoweringPasses(OpPassManager &pm) {
+static void addSPIRVLoweringPasses(OpPassManager &pm, bool enableFastMath) {
pm.addPass(createCanonicalizerPass());
pm.addPass(createCSEPass());
@@ -186,7 +186,7 @@
pm.addPass(createCSEPass());
pm.addPass(createMapMemRefStorageClassPass());
- pm.addPass(createConvertToSPIRVPass());
+ pm.addPass(createConvertToSPIRVPass(enableFastMath));
OpPassManager &spirvPM = pm.nest<spirv::ModuleOp>();
spirvPM.addPass(spirv::createUnifyAliasedResourcePass());
@@ -358,13 +358,13 @@
// Entry Point
//===----------------------------------------------------------------------===//
-void buildSPIRVCodegenPassPipeline(OpPassManager &pm) {
+void buildSPIRVCodegenPassPipeline(OpPassManager &pm, bool enableFastMath) {
pm.nest<ModuleOp>().nest<func::FuncOp>().addPass(createTypePropagationPass());
pm.nest<ModuleOp>().addPass(createBufferizeCopyOnlyDispatchesPass());
pm.addPass(createSPIRVLowerExecutableTargetPass());
addMemRefLoweringPasses(pm.nest<ModuleOp>());
- addSPIRVLoweringPasses(pm.nest<ModuleOp>());
+ addSPIRVLoweringPasses(pm.nest<ModuleOp>(), enableFastMath);
LLVM_DEBUG({
llvm::dbgs() << "Using SPIR-V pass pipeline:\n";
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTile.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTile.cpp
index b128945..2b929e6 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTile.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTile.cpp
@@ -11,6 +11,7 @@
//===----------------------------------------------------------------------===//
#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
+#include "iree-dialects/Dialect/LinalgExt/Transforms/Transforms.h"
#include "iree/compiler/Codegen/Dialect/LoweringConfig.h"
#include "iree/compiler/Codegen/PassDetail.h"
#include "iree/compiler/Codegen/Passes.h"
@@ -29,6 +30,8 @@
#include "mlir/IR/Matchers.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+using mlir::iree_compiler::IREE::LinalgExt::TilingPatterns;
+
#define DEBUG_TYPE "iree-spirv-tile"
namespace mlir {
@@ -51,10 +54,9 @@
auto marker = StringAttr::get(context, getTileReductionMarker());
auto filter = linalg::LinalgTransformationFilter({marker}, llvm::None);
- linalg::TilingPatterns<linalg::BatchMatmulOp, linalg::Conv2DNhwcHwcfOp,
- linalg::DepthwiseConv2DNhwcHwcOp, linalg::GenericOp,
- linalg::MatmulOp>::insert(patterns, tilingOptions,
- filter);
+ TilingPatterns<linalg::BatchMatmulOp, linalg::Conv2DNhwcHwcfOp,
+ linalg::DepthwiseConv2DNhwcHwcOp, linalg::GenericOp,
+ linalg::MatmulOp>::insert(patterns, tilingOptions, filter);
}
//===----------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTileAndDistribute.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTileAndDistribute.cpp
index de0fdd4..ce33da8 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTileAndDistribute.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTileAndDistribute.cpp
@@ -13,6 +13,7 @@
#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "iree-dialects/Dialect/LinalgExt/Passes/Transforms.h"
+#include "iree-dialects/Dialect/LinalgExt/Transforms/Transforms.h"
#include "iree/compiler/Codegen/Dialect/LoweringConfig.h"
#include "iree/compiler/Codegen/PassDetail.h"
#include "iree/compiler/Codegen/Passes.h"
@@ -40,6 +41,8 @@
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+using mlir::iree_compiler::IREE::LinalgExt::TilingPatterns;
+
#define DEBUG_TYPE "iree-spirv-tile-and-distribute"
namespace mlir {
@@ -97,10 +100,9 @@
.setLoopType(linalg::LinalgTilingLoopType::Loops)
.setTileSizeComputationFunction(getTileSizeFn);
- linalg::TilingPatterns<linalg::BatchMatmulOp, linalg::Conv2DNhwcHwcfOp,
- linalg::DepthwiseConv2DNhwcHwcOp,
- linalg::MatmulOp>::insert(patterns, tilingOptions,
- filter);
+ TilingPatterns<linalg::BatchMatmulOp, linalg::Conv2DNhwcHwcfOp,
+ linalg::DepthwiseConv2DNhwcHwcOp,
+ linalg::MatmulOp>::insert(patterns, tilingOptions, filter);
}
//===----------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTileAndPromote.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTileAndPromote.cpp
index 18b66b5..dd3ff94 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTileAndPromote.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTileAndPromote.cpp
@@ -30,6 +30,8 @@
#define DEBUG_TYPE "iree-spirv-tile-and-promote"
+using mlir::iree_compiler::IREE::LinalgExt::TilingPatterns;
+
namespace mlir {
namespace iree_compiler {
@@ -46,7 +48,7 @@
auto tilingOptions = linalg::LinalgTilingOptions()
.setLoopType(linalg::LinalgTilingLoopType::Loops)
.setTileSizeComputationFunction(getTileSizeFn);
- linalg::TilingPatterns<linalg::BatchMatmulOp, linalg::MatmulOp>::insert(
+ TilingPatterns<linalg::BatchMatmulOp, linalg::MatmulOp>::insert(
patterns, tilingOptions, filter);
}
@@ -74,10 +76,8 @@
.setTileSizeComputationFunction(getTileSizeFn)
.setDistributionOptions(distributionOptions);
- linalg::TilingPatterns<linalg::BatchMatmulOp, linalg::FillOp,
- linalg::GenericOp,
- linalg::MatmulOp>::insert(patterns, tilingOptions,
- filter);
+ TilingPatterns<linalg::BatchMatmulOp, linalg::FillOp, linalg::GenericOp,
+ linalg::MatmulOp>::insert(patterns, tilingOptions, filter);
}
//===----------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTileAndVectorizeToCooperativeOps.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTileAndVectorizeToCooperativeOps.cpp
index 1ebd710..6ee225a 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTileAndVectorizeToCooperativeOps.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTileAndVectorizeToCooperativeOps.cpp
@@ -13,6 +13,7 @@
#include <algorithm>
+#include "iree-dialects/Dialect/LinalgExt/Transforms/Transforms.h"
#include "iree/compiler/Codegen/Dialect/LoweringConfig.h"
#include "iree/compiler/Codegen/PassDetail.h"
#include "iree/compiler/Codegen/Passes.h"
@@ -34,6 +35,10 @@
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+using mlir::iree_compiler::IREE::LinalgExt::LinalgVectorizationPattern;
+using mlir::iree_compiler::IREE::LinalgExt::TilingPatterns;
+using mlir::iree_compiler::IREE::LinalgExt::VectorizationPatterns;
+
#define DEBUG_TYPE "iree-spirv-tile-and-vectorize-to-cooperative-ops"
namespace mlir {
@@ -131,9 +136,8 @@
auto filter = linalg::LinalgTransformationFilter(
ArrayRef<StringAttr>{}, StringAttr::get(context, getVectorizeMarker()));
- linalg::TilingPatterns<linalg::FillOp, linalg::MatmulOp,
- linalg::GenericOp>::insert(patterns, tilingOptions,
- filter);
+ TilingPatterns<linalg::FillOp, linalg::MatmulOp, linalg::GenericOp>::insert(
+ patterns, tilingOptions, filter);
}
//===----------------------------------------------------------------------===//
@@ -146,9 +150,9 @@
linalg::LinalgVectorizationOptions opt;
linalg::LinalgTransformationFilter f(
StringAttr::get(context, getVectorizeMarker()));
- linalg::VectorizationPatterns<linalg::FillOp, linalg::GenericOp>::insert(
- patterns, opt, f);
- patterns.add<linalg::LinalgVectorizationPattern>(
+ VectorizationPatterns<linalg::FillOp, linalg::GenericOp>::insert(patterns,
+ opt, f);
+ patterns.add<LinalgVectorizationPattern>(
context, f.addOpFilter<linalg::ContractionOpInterface>(), opt);
vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
vector::populateVectorReductionToContractPatterns(patterns);
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVVectorToCooperativeOps.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVVectorToCooperativeOps.cpp
index 6b51c7a..35d6537 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVVectorToCooperativeOps.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVVectorToCooperativeOps.cpp
@@ -125,9 +125,9 @@
// Check that this is a matmul operation.
auto iterators = contractOp.getIteratorTypes().getValue();
- if (iterators.size() != 3 || !isParallelIterator(iterators[0]) ||
- !isParallelIterator(iterators[1]) ||
- !isReductionIterator(iterators[2])) {
+ if (iterators.size() != 3 || !vector::isParallelIterator(iterators[0]) ||
+ !vector::isParallelIterator(iterators[1]) ||
+ !vector::isReductionIterator(iterators[2])) {
return failure();
}
if (contractOp.getKind() != vector::CombiningKind::ADD) return failure();
@@ -233,19 +233,19 @@
// a result of performing cooperative matrix conversions earlier (it needs
// to be done before FlattenMemRefSubspanPass because we need 2-D MemRefs)
// and conversions spreading across upstream and IREE repos.
- typeConverter.addConversion(
- [&typeConverter](MemRefType type) -> Optional<Type> {
- if (!type.hasStaticShape()) return llvm::None;
- // In IREE all MemRefs are originated from subspan ops, which should
- // have identity layout.
- if (!type.getLayout().isIdentity()) return llvm::None;
- auto storage = spirv::mapMemorySpaceToVulkanStorageClass(
- type.getMemorySpaceAsInt());
- auto flattenedType = MemRefType::get(
- ShapedType::kDynamicSize, type.getElementType(), AffineMap(),
- spirv::StorageClassAttr::get(type.getContext(), *storage));
- return typeConverter.convertType(flattenedType);
- });
+ typeConverter.addConversion([&typeConverter](
+ MemRefType type) -> Optional<Type> {
+ if (!type.hasStaticShape()) return llvm::None;
+    // In IREE all MemRefs originate from subspan ops, which should
+ // have identity layout.
+ if (!type.getLayout().isIdentity()) return llvm::None;
+ auto storage =
+ spirv::mapMemorySpaceToVulkanStorageClass(type.getMemorySpaceAsInt());
+ auto flattenedType = MemRefType::get(
+ ShapedType::kDynamicSize, type.getElementType(), AffineMap(),
+ spirv::StorageClassAttr::get(type.getContext(), *storage));
+ return typeConverter.convertType(flattenedType);
+ });
// Add unrealized conversion cast ops to bridge type conversions: we are
// only converting the cooperative matrix subset; the rest needs to be done
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVVectorize.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVVectorize.cpp
index 3ecc37f..f73caab 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVVectorize.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVVectorize.cpp
@@ -10,6 +10,7 @@
//
//===----------------------------------------------------------------------===//
+#include "iree-dialects/Dialect/LinalgExt/Transforms/Transforms.h"
#include "iree/compiler/Codegen/PassDetail.h"
#include "iree/compiler/Codegen/Passes.h"
#include "iree/compiler/Codegen/SPIRV/Utils.h"
@@ -30,6 +31,9 @@
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+using mlir::iree_compiler::IREE::LinalgExt::LinalgVectorizationPattern;
+using mlir::iree_compiler::IREE::LinalgExt::VectorizationPatterns;
+
#define DEBUG_TYPE "iree-spirv-vectorize"
namespace mlir {
@@ -82,7 +86,7 @@
} else if (auto contractOp = dyn_cast<vector::ContractionOp>(op)) {
unsigned lastParallelDim = 0;
for (const auto &it : llvm::enumerate(contractOp.getIteratorTypes())) {
- if (isParallelIterator(it.value())) lastParallelDim = it.index();
+ if (vector::isParallelIterator(it.value())) lastParallelDim = it.index();
}
SmallVector<int64_t, 4> nativeSize(contractOp.getIteratorTypes().size(), 1);
SmallVector<int64_t, 4> bounds;
@@ -116,9 +120,9 @@
void populateVectorizationPatterns(RewritePatternSet &patterns) {
linalg::LinalgVectorizationOptions opt;
linalg::LinalgTransformationFilter f;
- linalg::VectorizationPatterns<linalg::FillOp, linalg::GenericOp>::insert(
- patterns, opt, f);
- patterns.add<linalg::LinalgVectorizationPattern>(
+ VectorizationPatterns<linalg::FillOp, linalg::GenericOp>::insert(patterns,
+ opt, f);
+ patterns.add<LinalgVectorizationPattern>(
patterns.getContext(), f.addOpFilter<linalg::ContractionOpInterface>(),
opt);
  // Additionally pull in patterns to canonicalize transfer ops and to shuffle
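
Earlier in this file, the contraction branch scans for the last parallel iterator before building `nativeSize`. The hunk stops mid-function, so the following is only a plausible reconstruction of the shape of that computation: unroll every dimension to 1 except the innermost parallel one, which keeps a target-friendly width (plain-C++ stand-ins for the iterator query):

```cpp
#include <cstdint>
#include <cstdio>
#include <string>
#include <vector>

// Returns a native unroll shape: 1 everywhere except the last parallel
// dimension, which keeps the given width.
std::vector<int64_t> nativeContractionSize(
    const std::vector<std::string> &iterators, int64_t width) {
  std::vector<int64_t> native(iterators.size(), 1);
  int lastParallel = -1;
  for (int i = 0, e = static_cast<int>(iterators.size()); i < e; ++i)
    if (iterators[i] == "parallel") lastParallel = i;
  if (lastParallel >= 0) native[lastParallel] = width;
  return native;
}

int main() {
  // (m, n, k) matmul iterators: parallel, parallel, reduction.
  for (int64_t s : nativeContractionSize({"parallel", "parallel", "reduction"},
                                         /*width=*/4))
    std::printf("%lld ", static_cast<long long>(s));
  std::printf("\n");  // prints: 1 4 1
}
```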
diff --git a/compiler/src/iree/compiler/Codegen/Sandbox/BUILD b/compiler/src/iree/compiler/Codegen/Sandbox/BUILD
index d2306fd..65886f9 100644
--- a/compiler/src/iree/compiler/Codegen/Sandbox/BUILD
+++ b/compiler/src/iree/compiler/Codegen/Sandbox/BUILD
@@ -55,6 +55,8 @@
":PassesIncGen",
"//compiler/src/iree/compiler/Codegen/Dialect:IREECodegenDialect",
"//compiler/src/iree/compiler/Codegen/Utils",
+ "//llvm-external-projects/iree-dialects:IREELinalgExtPasses",
+ "//llvm-external-projects/iree-dialects:IREELinalgExtTransforms",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:AffineDialect",
"@llvm-project//mlir:ArithmeticDialect",
diff --git a/compiler/src/iree/compiler/Codegen/Sandbox/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Sandbox/CMakeLists.txt
index 0fde1f5..720357a 100644
--- a/compiler/src/iree/compiler/Codegen/Sandbox/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Sandbox/CMakeLists.txt
@@ -45,6 +45,8 @@
DEPS
::PassHeaders
::PassesIncGen
+ IREELinalgExtPasses
+ IREELinalgExtTransforms
LLVMSupport
MLIRAffineDialect
MLIRArithmeticDialect
diff --git a/compiler/src/iree/compiler/Codegen/Sandbox/LinalgTensorCodegenDriver.cpp b/compiler/src/iree/compiler/Codegen/Sandbox/LinalgTensorCodegenDriver.cpp
index 845f8aa..65e5aac 100644
--- a/compiler/src/iree/compiler/Codegen/Sandbox/LinalgTensorCodegenDriver.cpp
+++ b/compiler/src/iree/compiler/Codegen/Sandbox/LinalgTensorCodegenDriver.cpp
@@ -4,6 +4,8 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+#include "iree-dialects/Dialect/LinalgExt/Passes/Passes.h"
+#include "iree-dialects/Dialect/LinalgExt/Transforms/CodegenStrategy.h"
#include "iree/compiler/Codegen/Dialect/LoweringConfig.h"
#include "iree/compiler/Codegen/Sandbox/PassDetail.h"
#include "iree/compiler/Codegen/Sandbox/Passes.h"
@@ -12,8 +14,6 @@
#include "mlir/AsmParser/AsmParser.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
-#include "mlir/Dialect/Linalg/Passes.h"
-#include "mlir/Dialect/Linalg/Transforms/CodegenStrategy.h"
#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
@@ -26,7 +26,9 @@
#include "mlir/Transforms/Passes.h"
using namespace mlir;
-using namespace mlir::linalg;
+
+using mlir::iree_compiler::IREE::LinalgExt::CodegenStrategy;
#define DEBUG_TYPE "iree-linalg-tensor-codegen-driver"
@@ -53,8 +55,9 @@
/// Default method to initialize the tiling options in IREE. These could be
/// overridden by the command line options if specified. For now the sentinel
/// -1 is used to avoid querying the lowering config.
-static bool getTilingOptionsFromConfig(func::FuncOp funcOp, int64_t tilingLevel,
- LinalgTilingOptions &tilingOptions) {
+static bool getTilingOptionsFromConfig(
+ func::FuncOp funcOp, int64_t tilingLevel,
+ linalg::LinalgTilingOptions &tilingOptions) {
if (tilingLevel != -1) {
FailureOr<Operation *> rootOp = getRootOp(funcOp);
if (failed(rootOp)) {
@@ -115,10 +118,10 @@
/// Default method to initialize the tiling options for fusion in IREE. These
/// could be overridden by the command line options if specified.
-static FailureOr<LinalgTilingAndFusionOptions> getTileAndFuseOptionsFromConfig(
- func::FuncOp funcOp, int64_t tilingLevel) {
+static FailureOr<linalg::LinalgTilingAndFusionOptions>
+getTileAndFuseOptionsFromConfig(func::FuncOp funcOp, int64_t tilingLevel) {
if (tilingLevel == -1) {
- return LinalgTilingAndFusionOptions();
+ return linalg::LinalgTilingAndFusionOptions();
}
FailureOr<Operation *> rootOp = getRootOp(funcOp);
@@ -127,7 +130,7 @@
iree_compiler::IREE::Codegen::LoweringConfigAttr loweringConfig =
iree_compiler::getLoweringConfig(rootOp.value());
- LinalgTilingAndFusionOptions options;
+ linalg::LinalgTilingAndFusionOptions options;
options.tileSizes.assign(loweringConfig.getTileSizeVals(tilingLevel));
options.tileInterchange.assign(
loweringConfig.getTileInterchangeVals(tilingLevel));
@@ -260,12 +263,13 @@
func::FuncOp funcOp = getOperation();
// Set up tiling and vectorization options.
- FailureOr<LinalgTilingAndFusionOptions> defaultTilingOptions =
+ FailureOr<linalg::LinalgTilingAndFusionOptions> defaultTilingOptions =
getTileAndFuseOptionsFromConfig(funcOp, tilingLevel);
if (failed(defaultTilingOptions)) {
return signalPassFailure();
}
- LinalgTilingAndFusionOptions tilingOptions = defaultTilingOptions.value();
+ linalg::LinalgTilingAndFusionOptions tilingOptions =
+ defaultTilingOptions.value();
bool doTiling = !tilingOptions.tileSizes.empty();
if (!tileSizes.empty()) {
doTiling = true;
@@ -336,7 +340,7 @@
transposePaddingVectors.push_back(transposeVector);
}
- LinalgPaddingOptions paddingOptions;
+ linalg::LinalgPaddingOptions paddingOptions;
paddingOptions.setPaddingValues(paddingValueAttributes);
paddingOptions.setPaddingDimensions(
SmallVector<int64_t>{paddingDimensions.begin(), paddingDimensions.end()});
@@ -354,6 +358,8 @@
  // Create a nested OpPassManager and run it.
OpPassManager dynamicPM(func::FuncOp::getOperationName());
strategy.configurePassPipeline(dynamicPM, funcOp.getContext());
+ dynamicPM.addPass(
+ iree_compiler::IREE::LinalgExt::createLinalgStrategyEnablePass());
if (failed(runPipeline(dynamicPM, funcOp))) {
return signalPassFailure();
@@ -364,7 +370,7 @@
func::FuncOp funcOp = getOperation();
// Set up tiling and vectorization options.
- LinalgTilingOptions tilingOptions;
+ linalg::LinalgTilingOptions tilingOptions;
bool doTiling =
getTilingOptionsFromConfig(funcOp, tilingLevel, tilingOptions);
if (!tileSizes.empty()) {
@@ -399,7 +405,7 @@
transposePaddingVectors.push_back(transposeVector);
}
- LinalgPaddingOptions paddingOptions;
+ linalg::LinalgPaddingOptions paddingOptions;
paddingOptions.setPaddingValues(paddingValueAttributes);
paddingOptions.setPackPaddings(
SmallVector<bool>{packPaddings.begin(), packPaddings.end()});
@@ -409,7 +415,7 @@
// Gather tiled loops that aren't distribution loops from previous tiling
// stages.
- LinalgPeelOptions peelingOptions;
+ linalg::LinalgPeelOptions peelingOptions;
peelingOptions.loopsToPeelComputationFunction =
[](OpBuilder &builder, Operation *op,
SmallVectorImpl<scf::ForOp> &loopsToPeel) {
@@ -432,7 +438,7 @@
};
CodegenStrategy strategy;
- StringRef genericOpName = GenericOp::getOperationName();
+ StringRef genericOpName = linalg::GenericOp::getOperationName();
strategy.tileIf(doTiling, anchorOpName, tilingOptions)
.padIf(pad, anchorOpName, paddingOptions)
.decomposeIf(decomposeToLowerDimOp)
@@ -443,6 +449,8 @@
  // Create a nested OpPassManager and run it.
OpPassManager dynamicPM(func::FuncOp::getOperationName());
strategy.configurePassPipeline(dynamicPM, funcOp.getContext());
+ dynamicPM.addPass(
+ iree_compiler::IREE::LinalgExt::createLinalgStrategyEnablePass());
if (failed(runPipeline(dynamicPM, funcOp))) {
return signalPassFailure();
}
@@ -490,8 +498,8 @@
.enableFullUnroll(unrollVectorTransfers)
.enableLowerPermutationMaps();
- LinalgVectorLoweringOptions vectorLoweringOptions =
- LinalgVectorLoweringOptions()
+ linalg::LinalgVectorLoweringOptions vectorLoweringOptions =
+ linalg::LinalgVectorLoweringOptions()
// Lowering of vector contractions.
.enableContractionLowering(vectorLoweringStage >= 0)
// Lowering of vector multi_reduction.
@@ -526,6 +534,8 @@
OpPassManager dynamicPM(func::FuncOp::getOperationName());
func::FuncOp funcOp = getOperation();
strategy.configurePassPipeline(dynamicPM, funcOp.getContext());
+ dynamicPM.addPass(
+ iree_compiler::IREE::LinalgExt::createLinalgStrategyEnablePass());
if (failed(runPipeline(dynamicPM, funcOp))) {
return signalPassFailure();
}
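
`CodegenStrategy` (now taken from LinalgExt) is a fluent builder: each `*If` call appends a stage only when its predicate holds, and `configurePassPipeline` later materializes the chosen stages into a nested pass manager. A toy version of the idiom, runnable as plain C++ with hypothetical names:

```cpp
#include <cstdio>
#include <string>
#include <vector>

class ToyStrategy {
 public:
  // Append a stage only when `cond` is true; returning *this enables the
  // tileIf(...).padIf(...).vectorizeIf(...) chaining used in the pass.
  ToyStrategy &stageIf(bool cond, const std::string &name) {
    if (cond) stages.push_back(name);
    return *this;
  }
  void materialize() const {
    for (const auto &s : stages) std::printf("running stage: %s\n", s.c_str());
  }

 private:
  std::vector<std::string> stages;
};

int main() {
  bool doTiling = true, doPadding = false;
  ToyStrategy strategy;
  strategy.stageIf(doTiling, "tile")
      .stageIf(doPadding, "pad")  // skipped: predicate is false
      .stageIf(true, "vectorize");
  strategy.materialize();  // prints: tile, vectorize
}
```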
diff --git a/compiler/src/iree/compiler/Dialect/Flow/TransformExtensions/FlowExtensions.cpp b/compiler/src/iree/compiler/Dialect/Flow/TransformExtensions/FlowExtensions.cpp
index c1c0bfc..a224bbd 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/TransformExtensions/FlowExtensions.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/TransformExtensions/FlowExtensions.cpp
@@ -10,6 +10,7 @@
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
#include "iree/compiler/Dialect/Flow/Transforms/ConvertRegionToWorkgroups.h"
#include "iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.h"
+#include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -482,244 +483,6 @@
return DiagnosedSilenceableFailure(success());
}
-/// Return `true` if the given type is a ShapedType and has at least one
-/// dynamic dimension.
-static bool hasDynamicShape(Type t) {
- auto shapedType = t.dyn_cast<ShapedType>();
- if (!shapedType) return false;
- return !shapedType.hasStaticShape();
-}
-
-/// Reify the dynamic dimensions of the given value.
-static LogicalResult reifyDynamicResultDims(OpBuilder b, Value value,
- SmallVector<Value> &dynamicDims) {
- OpBuilder::InsertionGuard guard(b);
-
- // Case 1: No dynamic result dims.
- if (!hasDynamicShape(value.getType())) return success();
-
- // There is at least one dynamic dimension, continue...
- ShapedType shapedType = value.getType().cast<ShapedType>();
-
- // Case 2: Value is a block argument.
- if (auto bbArg = value.dyn_cast<BlockArgument>()) {
- b.setInsertionPointToStart(bbArg.getOwner());
- for (int64_t i = 0; i < shapedType.getRank(); ++i) {
- if (shapedType.isDynamicDim(i)) {
- Value dim = b.create<tensor::DimOp>(bbArg.getLoc(), bbArg, i);
- dynamicDims.push_back(dim);
- }
- }
- return success();
- }
-
- // Value is an OpResult.
- Operation *op = value.getDefiningOp();
- OpResult opResult = value.cast<OpResult>();
- b.setInsertionPoint(op);
-
- // Case 3: Value is tied. Reify the dimensions of the tied operand.
- auto tiedOp = dyn_cast<IREE::Util::TiedOpInterface>(op);
- if (tiedOp) {
- Value tiedOperand = tiedOp.getTiedResultOperand(value);
- if (tiedOperand) {
-#ifndef NDEBUG
- ShapedType tiedOperandType = tiedOperand.getType().cast<ShapedType>();
- assert(tiedOperandType == shapedType && "expected same type");
-#endif // NDEBUG
- return reifyDynamicResultDims(b, tiedOperand, dynamicDims);
- }
- }
-
- // Case 4: Query ShapeAwareOpInterface.
- auto shapeAwareOp = dyn_cast<IREE::Util::ShapeAwareOpInterface>(op);
- if (shapeAwareOp) {
- ValueRange dims =
- shapeAwareOp.getResultDynamicDims(opResult.getResultNumber());
- dynamicDims.append(dims.begin(), dims.end());
- return success();
- }
-
- // Case 5: Query ReifyRankedShapedTypeOpInterface.
- auto reifyShapeOp = dyn_cast<ReifyRankedShapedTypeOpInterface>(op);
- if (reifyShapeOp) {
- ReifiedRankedShapedTypeDims dims;
- if (failed(reifyShapeOp.reifyResultShapes(b, dims))) return failure();
- for (int64_t i = 0; i < shapedType.getRank(); ++i)
- if (shapedType.isDynamicDim(i))
- dynamicDims.push_back(dims[opResult.getResultNumber()][i]);
- return success();
- }
-
- return failure();
-}
-
-// Append a result to the given DispatchRegionOp. The newly created
-// DispatchRegionOp is returned.
-static FailureOr<Flow::DispatchRegionOp> appendDispatchRegionResult(
- RewriterBase &rewriter, Flow::DispatchRegionOp regionOp, Value result) {
- OpBuilder::InsertionGuard guard(rewriter);
-
- // Determine dynamic result dims.
- rewriter.setInsertionPoint(regionOp);
- SmallVector<Value> dynamicDims(regionOp.getResultDims().begin(),
- regionOp.getResultDims().end());
- if (failed(reifyDynamicResultDims(rewriter, result, dynamicDims)))
- return failure();
-
- // Determine result types of new RegionOp.
- SmallVector<Type> resultTypes(regionOp.getResultTypes().begin(),
- regionOp.getResultTypes().end());
- resultTypes.push_back(result.getType());
-
- // Create new DispatchRegionOp and move over the body.
- auto newRegionOp = rewriter.create<Flow::DispatchRegionOp>(
- regionOp->getLoc(), resultTypes, dynamicDims);
- newRegionOp.getBody().takeBody(regionOp.getBody());
- rewriter.replaceOp(
- regionOp, newRegionOp.getResults().take_front(regionOp->getNumResults()));
-
- // Update terminator.
- Flow::ReturnOp returnOp =
- cast<Flow::ReturnOp>(newRegionOp.getBody().front().getTerminator());
- SmallVector<Value> returnedValues(returnOp.getOperands().begin(),
- returnOp.getOperands().end());
- returnedValues.push_back(result);
- returnOp.operandsMutable().assign(returnedValues);
-
- return newRegionOp;
-}
-
-/// Return `true` if the given op post-dominates the dispatch region.
-static bool isAfterRegion(Operation *op, Flow::DispatchRegionOp regionOp) {
- Operation *ancestor = regionOp->getBlock()->findAncestorOpInBlock(*op);
- assert(ancestor && "expected that op and regionOp are in the same block");
- return regionOp->isBeforeInBlock(ancestor);
-}
-
-// Clone a `target` op that is preceding the given dispatch region op into the
-// dispatch region.
-//
-// All uses of the target inside of the dispatch region are replaced with the
-// results of the cloned op.
-//
-// If `updateUsesOutsideOfRegion` is set, all uses of the target op after the
-// dispatch region, are also updated: The target op's results are returned from
-// the dispatch region an used in those places.
-//
-// Example when `updateUsesOutsideOfRegion` is unset:
-//
-// %0 = "some_op"() : () -> (tensor<?xf32>)
-// %r = flow.dispatch.region -> (tensor<?xf32>{%d0}) {
-// %1 = "another_op"(%0) : (tensor<?xf32>) -> (tensor<?xf32>)
-// flow.return %1 : tensor<?xf32>
-// }
-// %2 = "yet_another_use"(%0) : (tensor<?xf32>) -> (tensor<?xf32>)
-//
-// In this example, "some_op" will be cloned into the dispatch region and the
-// OpOperand of "another_op" will be replaced:
-//
-// %0 = "some_op"() : () -> (tensor<?xf32>)
-// %r = flow.dispatch.region -> (tensor<?xf32>{%d0}) {
-// %0_clone = "some_op"() : () -> (tensor<?xf32>)
-// %1 = "another_op"(%0_clone) : (tensor<?xf32>) -> (tensor<?xf32>)
-// flow.return %1 : tensor<?xf32>
-// }
-// %2 = "yet_another_use"(%0) : (tensor<?xf32>) -> (tensor<?xf32>)
-static FailureOr<Flow::DispatchRegionOp> clonePrecedingOpIntoDispatchRegion(
- RewriterBase &rewriter, Operation *target, Flow::DispatchRegionOp regionOp,
- bool updateUsesOutsideOfRegion) {
- assert(target->isBeforeInBlock(regionOp) &&
- "expected that target comes first");
- Block &body = regionOp.getBody().front();
-
- // Gather all uses of `target`.
- SmallVector<OpOperand *> usesInsideOfRegion, usesAfterRegion;
- bool hasUsesBeforeRegion = false;
- for (OpOperand &use : target->getUses()) {
- if (regionOp->isProperAncestor(use.getOwner())) {
- usesInsideOfRegion.push_back(&use);
- } else {
- // Collect only uses that post-dominate the region.
- if (isAfterRegion(use.getOwner(), regionOp)) {
- usesAfterRegion.push_back(&use);
- } else {
- hasUsesBeforeRegion = true;
- }
- }
- }
-
- // Clone op into dispatch region.
- Operation *newTargetOp;
- if (usesAfterRegion.empty() && !hasUsesBeforeRegion) {
- // Optimization: If there are not uses outside of the region, we can simply
- // move the target instead of cloning it.
- target->moveBefore(&body.front());
- newTargetOp = target;
- } else {
- OpBuilder::InsertionGuard guard(rewriter);
- rewriter.setInsertionPointToStart(&body);
- newTargetOp = rewriter.clone(*target);
-
- // Replace all uses in the dispatch region.
- for (OpOperand *use : usesInsideOfRegion) {
- rewriter.updateRootInPlace(use->getOwner(), [&]() {
- use->set(newTargetOp->getResult(
- use->get().cast<OpResult>().getResultNumber()));
- });
- }
- }
-
- // Replace all uses outside of the dispatch region.
- if (updateUsesOutsideOfRegion && !usesAfterRegion.empty()) {
- // Fail if there are uses before the dispatch region. In that case it does
- // usually not make sense to update uses after the region; we can just keep
- // using the original op result.
- if (hasUsesBeforeRegion) return failure();
-
- unsigned previousNumResults = regionOp->getNumResults();
-
- // Note: Appending results one-by-one here so that this can be extended to
- // specific results in the future. Many ops have just one result, so this
- // should not be a large overhead.
- for (Value v : newTargetOp->getResults()) {
- auto newRegionOp = appendDispatchRegionResult(rewriter, regionOp, v);
- if (failed(newRegionOp)) return failure();
- regionOp = *newRegionOp;
- }
-
- // Replace uses of `target` after the dispatch region.
- for (OpOperand *use : usesAfterRegion) {
- assert(DominanceInfo().properlyDominates(regionOp, use->getOwner()) &&
- "all target uses must be inside or after regionOp");
- rewriter.updateRootInPlace(use->getOwner(), [&]() {
- use->set(
- regionOp->getResult(previousNumResults +
- use->get().cast<OpResult>().getResultNumber()));
- });
- }
- }
-
- // Remove the original target if it no longer has any uses.
- if (target->use_empty()) rewriter.eraseOp(target);
-
- return regionOp;
-}
-
-// Move a `target` op that is preceding the given dispatch region op into the
-// dispatch region. All uses of the target must be inside the region.
-static FailureOr<Flow::DispatchRegionOp> movePrecedingOpIntoDispatchRegion(
- RewriterBase &rewriter, Operation *target,
- Flow::DispatchRegionOp regionOp) {
- assert(llvm::all_of(target->getUses(),
- [&](OpOperand &use) {
- return regionOp->isProperAncestor(use.getOwner());
- }) &&
- "cannot move target into region");
- return clonePrecedingOpIntoDispatchRegion(
- rewriter, target, regionOp, /*updateUsesOutsideOfRegion=*/false);
-}
-
DiagnosedSilenceableFailure
transform_dialect::ClonePrecedingOpIntoDispatchRegionOp::apply(
transform::TransformResults &transformResults,
@@ -753,9 +516,51 @@
SmallVector<Operation *> orderedTargets =
llvm::to_vector(llvm::reverse(targetOps));
IRRewriter rewriter(regionOp->getContext());
+ for (Operation *target : orderedTargets)
+ if (failed(clonePrecedingOpIntoDispatchRegion(rewriter, target, regionOp)))
+ return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
+
+ transformResults.set(getTransformed().cast<OpResult>(),
+ regionOp.getOperation());
+ return DiagnosedSilenceableFailure(success());
+}
+
+DiagnosedSilenceableFailure
+transform_dialect::MovePrecedingOpIntoDispatchRegionOp::apply(
+ transform::TransformResults &transformResults,
+ transform::TransformState &state) {
+ ArrayRef<Operation *> targetOps = state.getPayloadOps(getTarget());
+ ArrayRef<Operation *> dispatchRegion =
+ state.getPayloadOps(getDispatchRegion());
+
+ if (targetOps.empty() && dispatchRegion.empty()) {
+ transformResults.set(getResult().cast<OpResult>(),
+ SmallVector<mlir::Operation *>{});
+ return DiagnosedSilenceableFailure::success();
+ }
+
+ if (dispatchRegion.size() != 1)
+ return DiagnosedSilenceableFailure(this->emitOpError(
+ "requires exactly one target/dispatch region handle"));
+
+ auto regionOp = dyn_cast<Flow::DispatchRegionOp>(dispatchRegion.front());
+ if (!regionOp)
+ return DiagnosedSilenceableFailure(
+ this->emitOpError("expected 'dispatch.region' operand"));
+
+ // We are moving ops one-by-one, so the order must be reversed (as opposed
+ // to moving all ops in one go).
+ SmallVector<Operation *> targetOpsList(targetOps.begin(), targetOps.end());
+ bool sortResult = computeTopologicalSorting(
+ dispatchRegion.front()->getBlock(), targetOpsList);
+ (void)sortResult;
+ assert(sortResult && "unable to sort topologically");
+ SmallVector<Operation *> orderedTargets =
+ llvm::to_vector(llvm::reverse(targetOps));
+ IRRewriter rewriter(regionOp->getContext());
for (Operation *target : orderedTargets) {
- auto newRegionOp = clonePrecedingOpIntoDispatchRegion(
- rewriter, target, regionOp, getUpdateUsesOutsideOfRegion());
+ auto newRegionOp =
+ movePrecedingOpIntoDispatchRegion(rewriter, target, regionOp);
if (failed(newRegionOp))
return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
regionOp = *newRegionOp;
@@ -908,19 +713,11 @@
Operation *target, SmallVectorImpl<Operation *> &results,
transform::TransformState &state) {
IRRewriter rewriter(target->getContext());
-
- // Make an empty dispatch region right before the target.
- rewriter.setInsertionPointAfter(target);
- Flow::DispatchRegionOp regionOp =
- makeEmptyDispatchRegion(rewriter, target->getLoc());
-
- // Move the target into the dispatch region.
- auto newRegionOp = clonePrecedingOpIntoDispatchRegion(
- rewriter, target, regionOp, /*updateUsesOutsideOfRegion=*/true);
- if (failed(newRegionOp))
+ auto regionOp = Flow::wrapOpInDispatchRegion(rewriter, target);
+ if (failed(regionOp))
return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
- results.push_back(*newRegionOp);
+ results.push_back(*regionOp);
return DiagnosedSilenceableFailure(success());
}
diff --git a/compiler/src/iree/compiler/Dialect/Flow/TransformExtensions/FlowExtensionsOps.td b/compiler/src/iree/compiler/Dialect/Flow/TransformExtensions/FlowExtensionsOps.td
index 20b955b..2ec4c0d 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/TransformExtensions/FlowExtensionsOps.td
+++ b/compiler/src/iree/compiler/Dialect/Flow/TransformExtensions/FlowExtensionsOps.td
@@ -111,27 +111,57 @@
handle must be mapped to exactly one payload op.
All uses of the target inside of the dispatch region are replaced with the
- results of the cloned op.
-
- If `update_uses_outside_of_region` is set (default value: `false`), all
- uses outside of the dispatch region are also replaced: The results of the
- cloned target op are yielded from the dispatch region and used in all uses
- outside of the dispatch region. The transform fails if there are uses that
- appear before the dispatch region.
+ results of the cloned op. Uses of the target outside of the dispatch region
+ remain unchanged.
#### Return modes
- This transform consumes both the `target` handle and the `dispatch_region`
+ This transform reads the `target` handle and consumes the `dispatch_region`
handle. It produces a new handle to the extended dispatch region.
}];
let arguments = (ins Arg<PDL_Operation, "",
- [TransformMappingRead,
- TransformMappingFree]>:$target,
+ [TransformMappingRead]>:$target,
Arg<PDL_Operation, "",
[TransformMappingRead,
- TransformMappingFree]>:$dispatch_region,
- DefaultValuedAttr<BoolAttr, "false">:$update_uses_outside_of_region);
+ TransformMappingFree]>:$dispatch_region);
+ let results = (outs Res<PDL_Operation, "",
+ [TransformMappingAlloc,
+ TransformMappingWrite]>:$transformed);
+ let assemblyFormat = "$target `into` $dispatch_region attr-dict";
+ let cppNamespace = "mlir::iree_compiler::IREE::transform_dialect";
+ let extraClassDeclaration = [{
+ ::mlir::DiagnosedSilenceableFailure apply(
+ ::mlir::transform::TransformResults &transformResults,
+ ::mlir::transform::TransformState &state);
+ }];
+}
+
+def MovePrecedingOpIntoDispatchRegionOp : Op<
+ Transform_Dialect, "iree.move_preceding_op_into_dispatch_region",
+ [TransformOpInterface]> {
+ let description = [{
+ Move the `target` op into the given dispatch region op. The dispatch region
+ handle must be mapped to exactly one payload op.
+
+ An extra result is added to the dispatch region for every result of the
+ target op. All uses of the target op are replaced with the newly added
+ results of the dispatch region.
+
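+ An illustrative sketch (schematic IR; op names like "some_op" and the
+ dynamic dimension values are placeholders, not verified output):
+
+   %0 = "some_op"() : () -> (tensor<?xf32>)
+   %r = flow.dispatch.region -> (tensor<?xf32>{%d0}) {
+     %1 = "another_op"(%0) : (tensor<?xf32>) -> (tensor<?xf32>)
+     flow.return %1 : tensor<?xf32>
+   }
+   %2 = "user"(%0) : (tensor<?xf32>) -> (tensor<?xf32>)
+
+ becomes:
+
+   %r:2 = flow.dispatch.region -> (tensor<?xf32>{%d0}, tensor<?xf32>{%d1}) {
+     %0 = "some_op"() : () -> (tensor<?xf32>)
+     %1 = "another_op"(%0) : (tensor<?xf32>) -> (tensor<?xf32>)
+     flow.return %1, %0 : tensor<?xf32>, tensor<?xf32>
+   }
+   %2 = "user"(%r#1) : (tensor<?xf32>) -> (tensor<?xf32>)
+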
+ Note: This transform generates invalid IR if there are uses of the target op
+ that appear before (i.e., dominate) the dispatch region.
+
+ #### Return modes
+
+ This transform reads the `target` handle and consumes the `dispatch_region`
+ handle. It produces a new handle to the extended dispatch region.
+ }];
+
+ let arguments = (ins Arg<PDL_Operation, "",
+ [TransformMappingRead]>:$target,
+ Arg<PDL_Operation, "",
+ [TransformMappingRead,
+ TransformMappingFree]>:$dispatch_region);
let results = (outs Res<PDL_Operation, "",
[TransformMappingAlloc,
TransformMappingWrite]>:$transformed);
@@ -152,7 +182,7 @@
handle must be mapped to exactly one payload op.
All operands of the target are replaced with values that are defined inside
- of the dispatch region when possible.
+ of the dispatch region when possible.
If `update_uses_outside_of_region` is set (default value: `true`), all uses
of the original target op are replaced: The results of the cloned target op
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD b/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD
index 4f5a384..51c1a05 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD
@@ -38,9 +38,11 @@
"ConvertConv2DToImg2Col.cpp",
"ConvertLinalgMatmulToMmt4D.cpp",
"ConvertRegionToWorkgroups.cpp",
+ "ConvertToFlow.cpp",
"DeduplicateExecutables.cpp",
"DetachElementwiseFromNamedOps.cpp",
"DispatchLinalgOnTensors.cpp",
+ "DispatchLinalgOnTensorsViaRegionOps.cpp",
"DispatchWithTransformDialect.cpp",
"DumpDispatchGraph.cpp",
"ExpandTensorShapes.cpp",
@@ -58,6 +60,7 @@
"PadTensorToTensorInsertSlice.cpp",
"PassDetail.h",
"Passes.cpp",
+ "RegionOpUtils.cpp",
"SplitReduction.cpp",
"StripAndSplatConstantVariables.cpp",
"StripSignedness.cpp",
@@ -69,6 +72,7 @@
"FusionUtils.h",
"Passes.h",
"Passes.h.inc",
+ "RegionOpUtils.h",
],
deps = [
":PassesIncGen",
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
index 06895a5..efdf6d3 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
@@ -28,6 +28,7 @@
"FusionUtils.h"
"Passes.h"
"Passes.h.inc"
+ "RegionOpUtils.h"
SRCS
"CaptureDispatchDynamicDims.cpp"
"CleanupNumericNarrowing.cpp"
@@ -36,9 +37,11 @@
"ConvertConv2DToImg2Col.cpp"
"ConvertLinalgMatmulToMmt4D.cpp"
"ConvertRegionToWorkgroups.cpp"
+ "ConvertToFlow.cpp"
"DeduplicateExecutables.cpp"
"DetachElementwiseFromNamedOps.cpp"
"DispatchLinalgOnTensors.cpp"
+ "DispatchLinalgOnTensorsViaRegionOps.cpp"
"DispatchWithTransformDialect.cpp"
"DumpDispatchGraph.cpp"
"ExpandTensorShapes.cpp"
@@ -56,6 +59,7 @@
"PadTensorToTensorInsertSlice.cpp"
"PassDetail.h"
"Passes.cpp"
+ "RegionOpUtils.cpp"
"SplitReduction.cpp"
"StripAndSplatConstantVariables.cpp"
"StripSignedness.cpp"
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/ConvertRegionToWorkgroups.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/ConvertRegionToWorkgroups.cpp
index 62d8ea3..367e32a 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/ConvertRegionToWorkgroups.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/ConvertRegionToWorkgroups.cpp
@@ -79,7 +79,8 @@
/// explicitly and makes them available inside the region via block arguments.
FailureOr<Flow::DispatchWorkgroupsOp>
rewriteFlowDispatchRegionToFlowDispatchWorkgroups(
- Flow::DispatchRegionOp regionOp, RewriterBase &rewriter) {
+ Flow::DispatchRegionOp regionOp, RewriterBase &rewriter,
+ ValueRange workload, WorkloadBuilderFn workloadRegionBuilder) {
// Only ops with a single block are supported.
Region &region = regionOp.getBody();
if (!region.hasOneBlock()) return failure();
@@ -137,8 +138,8 @@
}
}
auto workgroupsOp = rewriter.create<IREE::Flow::DispatchWorkgroupsOp>(
- loc, /*workload=*/ValueRange(), regionOp.getResultTypes(),
- regionOp.getResultDims(), arguments, argumentDims, tiedArguments);
+ loc, workload, regionOp.getResultTypes(), regionOp.getResultDims(),
+ arguments, argumentDims, tiedArguments);
BlockAndValueMapping bvm;
bvm.map(arguments, workgroupsOp.getInputBlockArguments());
@@ -200,6 +201,17 @@
rewriter.create<IREE::Flow::ReturnOp>(loc);
rewriter.eraseOp(terminator);
+ // Create workload region.
+ if (workloadRegionBuilder) {
+ Region &workgroupCountRegion = workgroupsOp.getWorkgroupCount();
+ Block *body = rewriter.createBlock(&workgroupCountRegion);
+ SmallVector<BlockArgument> workloadArgs;
+ for (Value v : workload)
+ workloadArgs.push_back(body->addArgument(v.getType(), loc));
+ rewriter.setInsertionPointToStart(body);
+ workloadRegionBuilder(rewriter, loc, workloadArgs);
+ }
+
rewriter.replaceOp(regionOp, workgroupsOp.getResults());
return workgroupsOp;
}
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/ConvertRegionToWorkgroups.h b/compiler/src/iree/compiler/Dialect/Flow/Transforms/ConvertRegionToWorkgroups.h
index a6746d6..172eeb2 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/ConvertRegionToWorkgroups.h
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/ConvertRegionToWorkgroups.h
@@ -7,9 +7,13 @@
#ifndef IREE_COMPILER_DIALECT_FLOW_TRANSFORMS_CONVERTREGIONTOWORKGROUPS_H_
#define IREE_COMPILER_DIALECT_FLOW_TRANSFORMS_CONVERTREGIONTOWORKGROUPS_H_
+#include "mlir/IR/ValueRange.h"
#include "mlir/Support/LogicalResult.h"
namespace mlir {
+class BlockArgument;
+class Location;
+class OpBuilder;
class RewriterBase;
namespace iree_compiler {
@@ -18,13 +22,18 @@
class DispatchRegionOp;
class DispatchWorkgroupsOp;
+/// A function that builds the workload body of a DispatchWorkgroupsOp.
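+/// The builder is invoked inside the workgroup count region of the generated
+/// DispatchWorkgroupsOp, with the workload values rematerialized as block
+/// arguments, and is expected to create the region's terminator.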
+using WorkloadBuilderFn =
+ std::function<void(OpBuilder &, Location, ArrayRef<BlockArgument>)>;
+
/// Rewrite the DispatchRegionOp into a DispatchWorkgroupsOp. The
/// DispatchRegionOp is not isolated from above and may capture any SSA value
/// that is in scope. The generated DispatchWorkgroupsOp captures all SSA values
/// explicitly and makes them available inside the region via block arguments.
FailureOr<DispatchWorkgroupsOp>
-rewriteFlowDispatchRegionToFlowDispatchWorkgroups(DispatchRegionOp regionOp,
- RewriterBase &rewriter);
+rewriteFlowDispatchRegionToFlowDispatchWorkgroups(
+ DispatchRegionOp regionOp, RewriterBase &rewriter, ValueRange workload = {},
+ WorkloadBuilderFn workloadRegionBuilder = nullptr);
} // namespace Flow
} // namespace IREE
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/ConvertToFlow.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/ConvertToFlow.cpp
new file mode 100644
index 0000000..e0a1dd9
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/ConvertToFlow.cpp
@@ -0,0 +1,47 @@
+// Copyright 2020 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Dialect/Flow/Conversion/TensorToFlow/ConvertTensorToFlow.h"
+#include "iree/compiler/Dialect/Flow/Transforms/PassDetail.h"
+#include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/MemRef/Transforms/Passes.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace mlir::iree_compiler;
+using namespace mlir::iree_compiler::IREE;
+
+namespace {
+// Pass to test conversion to flow patterns.
+struct ConvertToFlowPass : public Flow::ConvertToFlowBase<ConvertToFlowPass> {
+ void getDependentDialects(DialectRegistry &registry) const override {
+ registry
+ .insert<AffineDialect, IREE::Flow::FlowDialect, linalg::LinalgDialect,
+ scf::SCFDialect, tensor::TensorDialect>();
+ }
+
+ void runOnOperation() override {
+ MLIRContext *context = &getContext();
+ RewritePatternSet convertToFlowPatterns(context);
+ Flow::populateTensorToFlowConversionPatterns(context,
+ convertToFlowPatterns);
+ memref::populateResolveRankedShapeTypeResultDimsPatterns(
+ convertToFlowPatterns);
+ if (failed(applyPatternsAndFoldGreedily(
+ getOperation(), std::move(convertToFlowPatterns)))) {
+ return signalPassFailure();
+ }
+ }
+};
+} // namespace
+
+std::unique_ptr<Pass> Flow::createConvertToFlowPass() {
+ return std::make_unique<ConvertToFlowPass>();
+}
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/DetachElementwiseFromNamedOps.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/DetachElementwiseFromNamedOps.cpp
index e3c73b6..f8db58a 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/DetachElementwiseFromNamedOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/DetachElementwiseFromNamedOps.cpp
@@ -49,8 +49,12 @@
// we see multiple output ops.
if (outputOperands.size() != 1) return failure();
Value outputOperand = outputOperands.front()->get();
- if (outputOperand.getDefiningOp<linalg::FillOp>()) return failure();
+ auto outsDefiningOp = outputOperand.getDefiningOp<linalg::LinalgOp>();
+ if (!outsDefiningOp || isa<linalg::FillOp>(outsDefiningOp.getOperation())) {
+ // If the `outs` operand is not defined by a linalg op, or is defined by a
+ // fill op, do nothing.
+ return failure();
+ }
auto outputType = outputOperand.getType().cast<RankedTensorType>();
if (!outputType.getElementType().isIntOrFloat()) return failure();
auto elementType = outputType.getElementType();
@@ -88,7 +92,7 @@
for (int i = 0, e = outputMap.getNumResults(); i < e; ++i) {
int pos = outputMap.getResult(i).cast<AffineDimExpr>().getPosition();
auto attr = linalgOp.getIteratorTypes()[pos].cast<StringAttr>();
- if (!isParallelIterator(attr)) return failure();
+ if (!linalg::isParallelIterator(attr)) return failure();
iterators.push_back(attr.getValue());
}
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp
index d8c991b..4a8f212 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp
@@ -17,6 +17,7 @@
#include "iree/compiler/Dialect/Flow/Transforms/FusionUtils.h"
#include "iree/compiler/Dialect/Flow/Transforms/PassDetail.h"
#include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
+#include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/CommandLine.h"
@@ -195,70 +196,22 @@
// Methods for getting the workload information for dispatch region creation.
//===----------------------------------------------------------------------===//
-/// For a given operation returns the loop ranges needed to compute the op.
-template <typename T>
-static SmallVector<Range> getLoopRanges(T operation, Location loc,
- PatternRewriter &rewriter);
-
-template <>
-SmallVector<Range> getLoopRanges<TilingInterface>(TilingInterface tilableOp,
- Location loc,
- PatternRewriter &rewriter) {
- SmallVector<Range> loopRanges = tilableOp.getIterationDomain(rewriter);
- Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
- for (auto iteratorType : llvm::enumerate(tilableOp.getLoopIteratorTypes())) {
- if (iteratorType.value() == getReductionIteratorTypeName()) {
- loopRanges[iteratorType.index()].size = one;
- }
- }
- return loopRanges;
-}
-
-template <>
-SmallVector<Range> getLoopRanges<tensor::InsertSliceOp>(
- tensor::InsertSliceOp insertSliceOp, Location loc,
- PatternRewriter &rewriter) {
- OpFoldResult zero = rewriter.getIndexAttr(0);
- OpFoldResult one = rewriter.getIndexAttr(1);
- Value source = insertSliceOp.getSource();
- SmallVector<Range> loopRanges(insertSliceOp.getSourceType().getRank(),
- Range{zero, one, one});
- for (auto dim : llvm::seq<unsigned>(0, loopRanges.size())) {
- loopRanges[dim].size =
- rewriter.create<tensor::DimOp>(loc, source, dim).getResult();
- }
- return loopRanges;
-}
-
-template <>
-SmallVector<Range> getLoopRanges<tensor::ExtractSliceOp>(
- tensor::ExtractSliceOp sliceOp, Location loc, PatternRewriter &rewriter) {
- Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
- Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
- ReifiedRankedShapedTypeDims resultDims;
- (void)sliceOp.reifyResultShapes(rewriter, resultDims);
- return llvm::to_vector(llvm::map_range(resultDims[0], [&](Value v) {
- return Range{zero, v, one};
- }));
-}
-
/// Compute the workload to use for the workgroup based on the root op.
-template <typename OpTy>
-static SmallVector<Value> getWorkloadForRootOp(PatternRewriter &rewriter,
- OpTy rootOp) {
+static SmallVector<Value> getWorkloadForRootOp(OpBuilder &builder,
+ Operation *rootOp) {
// Compute workgroup count to use for the dispatch op. These are the ranges
// of the outermost parallel loops that can be distributed.
Location loc = rootOp->getLoc();
- SmallVector<Range> loopRanges = getLoopRanges(rootOp, loc, rewriter);
+ SmallVector<Range> loopRanges = getLoopRanges(rootOp, loc, builder);
AffineExpr s0, s1, s2;
- bindSymbols(rewriter.getContext(), s0, s1, s2);
+ bindSymbols(builder.getContext(), s0, s1, s2);
AffineMap workload = AffineMap::get(0, 3, (s1 - s0).ceilDiv(s2));
return llvm::to_vector(llvm::map_range(loopRanges, [&](Range r) -> Value {
- Value offset = getValueOrCreateConstantIndexOp(rewriter, loc, r.offset);
- Value size = getValueOrCreateConstantIndexOp(rewriter, loc, r.size);
- Value stride = getValueOrCreateConstantIndexOp(rewriter, loc, r.stride);
- return rewriter.create<AffineApplyOp>(rootOp.getLoc(), workload,
- ValueRange{offset, size, stride});
+ Value offset = getValueOrCreateConstantIndexOp(builder, loc, r.offset);
+ Value size = getValueOrCreateConstantIndexOp(builder, loc, r.size);
+ Value stride = getValueOrCreateConstantIndexOp(builder, loc, r.stride);
+ return builder.create<AffineApplyOp>(rootOp->getLoc(), workload,
+ ValueRange{offset, size, stride});
}));
}
@@ -450,88 +403,6 @@
// Methods to legalize a dispatch region op, i.e. make it isolated from above.
//===---------------------------------------------------------------------===//
-/// Reorders the operations in `ops` such that they could be inlined into the
-/// dispatch region in that order to satisfy dependencies.
-static SmallVector<Operation *> orderOperations(ArrayRef<Operation *> ops) {
- LLVM_DEBUG({
- llvm::dbgs() << "Ops to be inlined :\n";
- for (auto op : ops) {
- llvm::dbgs() << "\t";
- op->print(llvm::dbgs());
- llvm::dbgs() << "\n";
- }
- });
-
- llvm::SmallMapVector<Operation *, SmallVector<Operation *>, 16>
- insertAfterMap;
- llvm::SetVector<Operation *> opSet(ops.begin(), ops.end());
- llvm::SetVector<Operation *> leafOps(ops.begin(), ops.end());
- // For each operation compute the list of operations in `ops` that use its
- // results. Also compute the operations that form the leafs of the DAG of
- // operations in `ops`.
- for (auto op : ops) {
- for (auto operand : op->getOperands()) {
- auto definingOp = operand.getDefiningOp();
- if (!definingOp || !opSet.count(definingOp)) continue;
- insertAfterMap[definingOp].push_back(op);
- if (leafOps.count(op)) leafOps.remove(op);
- }
- }
-
- // The leaves are at the head of the ordered list.
- SmallVector<Operation *> orderedOps(leafOps.begin(), leafOps.end());
- orderedOps.reserve(ops.size());
- llvm::SmallPtrSet<Operation *, 16> processed;
- processed.insert(leafOps.begin(), leafOps.end());
-
- // `readyOps` contains the list of operations that have been just added to the
- // `orderedOps` list. With these marked ready, they might make further
- // operations in `ops` ready as well.
- // The complexity of the algorithm is driven by these
- // - Each operations is added to `readyOps` list at most once, and is removed
- // after being processed
- // - For every operation in `readyOps` every use of its results (within `ops`)
- // is looked at once.
- // - For every use, the operands of the user are processed.
- // Assuming operands is O(1), i.e. constant order, the complexity is O(sum of
- // number of uses of each operation). Given that the size of `ops` is at max
- // O(10), and not O(100), this is assumed to be reasonable.
- ArrayRef<Operation *> readyOps(orderedOps);
- size_t startPos = 0;
- while (!readyOps.empty()) {
- auto op = readyOps.front();
- startPos++;
- // Check all uses of `op` within `ops`. If all of the operations that define
- // the operands of the user have been added to `orderedOps`, then the user
- // is ready to be scheduled.
- for (auto insertAfterOp : insertAfterMap[op]) {
- if (processed.count(insertAfterOp)) continue;
- if (llvm::all_of(insertAfterOp->getOperands(), [&](Value operand) {
- Operation *operandDefiningOp = operand.getDefiningOp();
- return !operandDefiningOp || !opSet.count(operandDefiningOp) ||
- processed.count(operandDefiningOp);
- })) {
- // readyOps.push_back(insertAfterOp);
- orderedOps.push_back(insertAfterOp);
- processed.insert(insertAfterOp);
- }
- }
- readyOps = ArrayRef<Operation *>(orderedOps).drop_front(startPos);
- }
-
- LLVM_DEBUG({
- llvm::dbgs() << "Ops to be inlined (sorted) : \n";
- for (auto op : orderedOps) {
- llvm::dbgs() << "\t";
- op->print(llvm::dbgs());
- llvm::dbgs() << "\n";
- }
- });
- assert(orderedOps.size() == ops.size() &&
- "ordering of inlined operations failed");
- return orderedOps;
-}
-
/// Checks if the `Value` has a use within the dispatch that is unfusable.
static bool hasUnfusableUseInDispatch(
Value v, IREE::Flow::DispatchWorkgroupsOp dispatchOp) {
@@ -820,7 +691,7 @@
// Get the workload to use for the dispatch.
FailureOr<SmallVector<Value>> workload =
- getWorkloadForRootOp(rewriter, rootOp);
+ getWorkloadForRootOp(rewriter, rootOp.getOperation());
if (failed(workload)) {
return failure();
}
@@ -1033,28 +904,6 @@
Statistic numDispatches{this, "number of dispatches",
"Number of Flow dispatches created"};
};
-
-// Pass to test conversion to flow patterns.
-struct ConvertToFlowPass : public ConvertToFlowBase<ConvertToFlowPass> {
- void getDependentDialects(DialectRegistry &registry) const override {
- registry
- .insert<AffineDialect, IREE::Flow::FlowDialect, linalg::LinalgDialect,
- scf::SCFDialect, tensor::TensorDialect>();
- }
-
- void runOnOperation() override {
- MLIRContext *context = &getContext();
- RewritePatternSet convertToFlowPatterns(context);
- populateTensorToFlowConversionPatterns(context, convertToFlowPatterns);
- memref::populateResolveRankedShapeTypeResultDimsPatterns(
- convertToFlowPatterns);
- if (failed(applyPatternsAndFoldGreedily(
- getOperation(), std::move(convertToFlowPatterns)))) {
- return signalPassFailure();
- }
- }
-};
-
} // namespace
/// For all ops within `funcOp` tagged as root ops, create dispatch regions.
@@ -1218,10 +1067,6 @@
return std::make_unique<DispatchLinalgOnTensorsPass>();
}
-std::unique_ptr<Pass> createConvertToFlowPass() {
- return std::make_unique<ConvertToFlowPass>();
-}
-
} // namespace Flow
} // namespace IREE
} // namespace iree_compiler
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensorsViaRegionOps.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensorsViaRegionOps.cpp
new file mode 100644
index 0000000..0b8f008
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensorsViaRegionOps.cpp
@@ -0,0 +1,753 @@
+// Copyright 2020 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+// This is a variant of DispatchLinalgOnTensors.cpp, in which
+// DispatchWorkgroupsOps are built from DispatchRegionOps. This file can
+// eventually replace the original DispatchLinalgOnTensors.cpp.
+//
+// Note: The heuristic part of the implementation is unchanged and copied from
+// DispatchLinalgOnTensors.cpp.
+
+#include "iree/compiler/Dialect/Flow/Conversion/TensorToFlow/ConvertTensorToFlow.h"
+#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
+#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
+#include "iree/compiler/Dialect/Flow/Transforms/ConvertRegionToWorkgroups.h"
+#include "iree/compiler/Dialect/Flow/Transforms/FusionUtils.h"
+#include "iree/compiler/Dialect/Flow/Transforms/PassDetail.h"
+#include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
+#include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
+#include "llvm/ADT/SmallPtrSet.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/MemRef/Transforms/Passes.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/RegionUtils.h"
+
+using namespace mlir;
+using namespace mlir::iree_compiler;
+using namespace mlir::iree_compiler::IREE;
+
+#define DEBUG_TYPE "iree-flow-dispatch-linalg-on-tensors-via-region-ops"
+
+static const int kInlineConstantByteLength = 256;
+static const bool kEnableMultiResultDispatches = false;
+static const char kRootOpAttr[] = "__root_op__";
+static const char kFusionGroupsAttr[] = "__fused_op__";
+
+//===----------------------------------------------------------------------===//
+// Helpers for fusion group formation
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// A rewriter that keeps track of all tensor::DimOps.
+class TensorDimTrackingRewriter : public IRRewriter {
+ public:
+ /// Create a new rewriter: Scan the given op for tensor::DimOps.
+ TensorDimTrackingRewriter(Operation *op) : IRRewriter(op->getContext()) {
+ op->walk([&](tensor::DimOp dimOp) { dimOps.insert(dimOp.getOperation()); });
+ }
+
+ /// Return all tracked tensor::DimOps.
+ SmallVector<tensor::DimOp> getTensorDimOps() {
+ SmallVector<tensor::DimOp> result;
+ for (Operation *op : dimOps) result.push_back(cast<tensor::DimOp>(op));
+ return result;
+ }
+
+ protected:
+ void notifyOperationRemoved(Operation *op) override {
+ IRRewriter::notifyOperationRemoved(op);
+ if (isa<tensor::DimOp>(op)) dimOps.erase(op);
+ }
+
+ void notifyOperationInserted(Operation *op) override {
+ IRRewriter::notifyOperationInserted(op);
+ if (isa<tensor::DimOp>(op)) dimOps.insert(op);
+ }
+
+ private:
+ SmallPtrSet<Operation *, 16> dimOps;
+};
+} // namespace
+
+/// Simplify the given tensor::DimOps as much as possible.
+/// * Static dimensions are replaced by constants.
+/// * Dynamic dim ops are pushed as far up the function as possible, i.e., if
+///   the dim of a value is known to equal the dim of another value further up
+///   the reverse SSA use-def chain, the dim op is rewritten in terms of that
+///   value.
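+/// Example (schematic; `%t` is a placeholder of type tensor<4x?xf32>):
+///   tensor.dim %t, %c0  ->  folded to `arith.constant 4 : index`
+///   tensor.dim %t, %c1  ->  replaced with the SSA value that already defines
+///                           the dynamic extent of %t, when known.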
+static LogicalResult simplifyDimOps(RewriterBase &rewriter,
+ const SmallVector<tensor::DimOp> &dimOps) {
+ for (tensor::DimOp dimOp : dimOps) {
+ // Only DimOps with static indices are supported.
+ Optional<int64_t> idx = dimOp.getConstantIndex();
+ if (!idx.hasValue()) continue;
+ // Only DimOps with ranked tensors are supported.
+ auto tensorType = dimOp.getSource().getType().dyn_cast<RankedTensorType>();
+ if (!tensorType) continue;
+
+ if (!tensorType.isDynamicDim(*idx)) {
+ // Rewrite static dimension with constant.
+ int64_t size = tensorType.getShape()[*idx];
+ rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(dimOp, size);
+ continue;
+ }
+
+ // Try to simplify dynamic dims.
+ SmallVector<Value> dynamicDims;
+ if (failed(Flow::reifyDynamicResultDims(rewriter, dimOp.getSource(),
+ dynamicDims)))
+ return failure();
+ unsigned ctr = 0;
+ for (int64_t i = 0; i < *dimOp.getConstantIndex(); ++i)
+ if (tensorType.isDynamicDim(i)) ++ctr;
+ rewriter.replaceOp(dimOp, dynamicDims[ctr]);
+ }
+
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Root and fusion group attribute handling
+//===----------------------------------------------------------------------===//
+
+/// Returns true if an op has a root operation.
+static bool hasRootOpAttribute(Operation *op) {
+ return static_cast<bool>(op->getAttrOfType<IntegerAttr>(kRootOpAttr));
+}
+
+/// Removes root attribute. Asserts if root attribute is not present.
+static void removeRootOpAttribute(Operation *op) {
+ op->removeAttr(kRootOpAttr);
+}
+
+/// Sets the root attribute for an operation. The root attribute needs a number
+/// to identify the root. Asserts if root attribute is already set on an
+/// operation.
+static void setRootAttribute(MLIRContext *context, Operation *op,
+ int64_t rootNumber) {
+ assert(!op->hasAttr(kRootOpAttr) &&
+ "invalid to update root attribute on an op");
+ op->setAttr(kRootOpAttr,
+ IntegerAttr::get(IntegerType::get(context, 64), rootNumber));
+}
+
+/// Returns the number of the root. Asserts if the operation is not already set
+/// as a root.
+static int64_t getRootNumber(Operation *op) {
+ return op->getAttrOfType<IntegerAttr>(kRootOpAttr).getInt();
+}
+
+/// Returns true if an op is part of a fusion group.
+static bool hasFusionGroupsAttribute(Operation *op) {
+ return static_cast<bool>(op->getAttrOfType<ArrayAttr>(kFusionGroupsAttr));
+}
+
+/// Returns the fusion groups for the given `op`.
+static SmallVector<int64_t, 1> getFusionGroups(Operation *op) {
+ SmallVector<int64_t, 1> fusionGroups = {};
+ if (auto fusionGroupsAttr = op->getAttrOfType<ArrayAttr>(kFusionGroupsAttr)) {
+ fusionGroups = llvm::to_vector<1>(llvm::map_range(
+ fusionGroupsAttr,
+ [](Attribute attr) { return attr.cast<IntegerAttr>().getInt(); }));
+ }
+ return fusionGroups;
+}
+
+/// Appends the given `op` to the `newGroups` fusion groups.
+static void appendToFusionGroup(Operation *op, ArrayRef<int64_t> newGroups) {
+ SmallVector<int64_t, 1> fusionGroups = getFusionGroups(op);
+ fusionGroups.append(newGroups.begin(), newGroups.end());
+ op->setAttr(kFusionGroupsAttr, Builder(op).getI64ArrayAttr(fusionGroups));
+}
+
+/// Returns true if the given `op` is in the `targetGroup` fusion group.
+static bool isInFusionGroup(Operation *op, unsigned targetGroup) {
+ if (ArrayAttr opGroupAttr = op->getAttrOfType<ArrayAttr>(kFusionGroupsAttr)) {
+ return llvm::any_of(opGroupAttr, [&targetGroup](Attribute attr) {
+ return attr.cast<IntegerAttr>().getInt() == targetGroup;
+ });
+ }
+ return false;
+}
+
+/// Removes the fusion groups attribute.
+static void removeFusionGroupsAttribute(Operation *op) {
+ op->removeAttr(kFusionGroupsAttr);
+}
+
+//===----------------------------------------------------------------------===//
+// Op property characterizations
+//===----------------------------------------------------------------------===//
+
+/// Operations that are treated as root operations for dispatch region
+/// formation.
+static bool isRootOp(Operation *op) {
+ if (op->getParentOfType<IREE::Flow::DispatchRegionOp>() ||
+ op->getParentOfType<IREE::Flow::DispatchWorkgroupsOp>()) {
+ return false;
+ }
+ // Any Linalg named op or generic op with reduction iterator types is a root
+ // op.
+ if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
+ if (isa<linalg::GenericOp>(op)) {
+ return linalgOp.getNumReductionLoops() != 0;
+ }
+ return !isa<linalg::FillOp>(op);
+ }
+ return isa<TilingInterface>(op);
+}
+
+/// Operations that are cloned into dispatch regions formed with other
+/// operations as roots.
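+/// These currently include index casts, `linalg.init_tensor`, tensor cast /
+/// extract / extract_slice / pad ops, splat or sufficiently small constants,
+/// and ops that operate purely on scalar values.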
+bool isClonableIntoDispatchOp(Operation *op) {
+ // TODO(#8637): `tensor.collapse_shape` and `tensor.expand_shape` are
+ // trivially clonable too, but they cause problems
+ // with bufferization. Make them clonable when fixed.
+ if (isa<arith::IndexCastOp, linalg::InitTensorOp, tensor::CastOp,
+ tensor::ExtractOp, tensor::ExtractSliceOp, tensor::PadOp>(op)) {
+ return true;
+ }
+ if (auto constantOp = dyn_cast<arith::ConstantOp>(op)) {
+ auto constantValueAttr = constantOp.getValue();
+ auto constantType = constantOp.getType();
+ if (constantValueAttr.isa<SplatElementsAttr>()) {
+ return true;
+ } else if (auto denseAttr =
+ constantValueAttr.dyn_cast<DenseElementsAttr>()) {
+ auto shapedType = constantOp.getType().cast<ShapedType>();
+ uint64_t estimatedByteLength =
+ (shapedType.getNumElements() * shapedType.getElementTypeBitWidth()) /
+ 8;
+ return denseAttr.isSplat() ||
+ estimatedByteLength <= kInlineConstantByteLength;
+ } else if (constantType.isIntOrIndexOrFloat()) {
+ return true;
+ }
+ }
+ if (llvm::all_of(op->getOperands(),
+ [&](Value v) { return v.getType().isIntOrFloat(); }) &&
+ llvm::all_of(op->getResults(),
+ [&](Value v) { return v.getType().isIntOrFloat(); })) {
+ return true;
+ }
+ return false;
+}
+
+/// Checks if the `Value` has a use within the dispatch that is unfusable.
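+/// Currently, the only use considered unfusable is a `tensor.insert_slice`
+/// whose `dest` operand is the value in question.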
+static bool hasUnfusableUseInDispatch(Value v, Operation *dispatchOp) {
+ for (OpOperand &use : v.getUses()) {
+ Operation *user = use.getOwner();
+ Operation *ownerWorkgroups =
+ user->getParentOfType<IREE::Flow::DispatchWorkgroupsOp>();
+ Operation *ownerRegion =
+ user->getParentOfType<IREE::Flow::DispatchRegionOp>();
+ Operation *owner = ownerWorkgroups ? ownerWorkgroups : ownerRegion;
+
+ // Ignore uses outside of dispatch workgroups op.
+ if (owner != dispatchOp) continue;
+
+ // Cannot fuse producer of `dest` with `tensor.insert_slice`.
+ if (auto insertSliceUser = dyn_cast<tensor::InsertSliceOp>(user)) {
+ if (insertSliceUser.getDest() == v) return true;
+ }
+ }
+ return false;
+}
+
+//===----------------------------------------------------------------------===//
+// Methods for getting the workload information for dispatch region creation.
+//===----------------------------------------------------------------------===//
+
+/// Compute the workload to use for the workgroup based on the root op.
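+/// For each range of the root op's iteration domain, the workload entry is
+/// ceilDiv(size - offset, stride); e.g., a loop over [0, 100) with stride 1
+/// contributes a workload entry of 100.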
+static SmallVector<Value> getWorkloadForRootOp(OpBuilder &builder,
+ Operation *rootOp) {
+ // Compute workgroup count to use for the dispatch op. These are the ranges
+ // of the outermost parallel loops that can be distributed.
+ Location loc = rootOp->getLoc();
+ SmallVector<Range> loopRanges = Flow::getLoopRanges(rootOp, loc, builder);
+ AffineExpr s0, s1, s2;
+ bindSymbols(builder.getContext(), s0, s1, s2);
+ AffineMap workload = AffineMap::get(0, 3, (s1 - s0).ceilDiv(s2));
+ return llvm::to_vector(llvm::map_range(loopRanges, [&](Range r) -> Value {
+ Value offset = getValueOrCreateConstantIndexOp(builder, loc, r.offset);
+ Value size = getValueOrCreateConstantIndexOp(builder, loc, r.size);
+ Value stride = getValueOrCreateConstantIndexOp(builder, loc, r.stride);
+ return builder.create<AffineApplyOp>(rootOp->getLoc(), workload,
+ ValueRange{offset, size, stride});
+ }));
+}
+
+//===----------------------------------------------------------------------===//
+// Heuristics for fusing dispatchable ops with root ops using tile + fuse.
+//===----------------------------------------------------------------------===//
+
+/// Collect all ops that should be cloned into the given dispatch region op.
+static SmallVector<Operation *> getCloneableOps(
+ Flow::DispatchRegionOp regionOp) {
+ // Find values that are used inside of the dispatch region but defined outside
+ // of the dispatch region.
+ llvm::SetVector<Value> valuesDefinedAbove;
+ mlir::getUsedValuesDefinedAbove(regionOp.getBody(), valuesDefinedAbove);
+ if (valuesDefinedAbove.empty()) return {};
+
+ // Traverse the defining ops of these values (and the ops on their reverse
+ // SSA use-def chain).
+ SmallVector<Operation *> result;
+ llvm::SetVector<Value> visited;
+ SmallVector<Value, 4> worklist;
+ worklist.assign(valuesDefinedAbove.begin(), valuesDefinedAbove.end());
+ while (!worklist.empty()) {
+ Value outsideValue = worklist.pop_back_val();
+ // Skip values that were already visited.
+ if (visited.count(outsideValue)) continue;
+ visited.insert(outsideValue);
+
+ Operation *definingOp = outsideValue.getDefiningOp();
+ if (!definingOp || !(isClonableIntoDispatchOp(definingOp)) ||
+ hasUnfusableUseInDispatch(outsideValue, regionOp)) {
+ valuesDefinedAbove.insert(outsideValue);
+ continue;
+ }
+ result.push_back(definingOp);
+ worklist.append(definingOp->operand_begin(), definingOp->operand_end());
+ }
+
+ return result;
+}
+
+/// Checks if the producer and consumer LinalgOps can be fused.
+static bool areFusableLinalgOps(OpOperand &use) {
+ return Flow::areLinalgOpsFusableUsingTileAndFuse(use);
+}
+
+/// Returns true if this is a fusable use.
+static bool isFusableWithConsumer(OpOperand &use) {
+ // Check for linalg producer -> consumer fusion with tile + fuse.
+ return areFusableLinalgOps(use);
+}
+
+/// For all uses of an operation, finds the use that dominates all other uses.
+static Optional<OpOperand *> getFusableUse(Operation *op,
+ DominanceInfo const &dominanceInfo) {
+ if (!kEnableMultiResultDispatches) {
+ if (op->hasOneUse()) {
+ OpOperand &use = *(op->use_begin());
+ return &use;
+ }
+ return llvm::None;
+ }
+ for (auto &use : op->getUses()) {
+ Operation *user = use.getOwner();
+ if (llvm::all_of(op->getUsers(), [&](Operation *c) {
+ return dominanceInfo.dominates(user, c);
+ })) {
+ return &use;
+ }
+ }
+ return llvm::None;
+}
+
+/// Fuses roots with their consumers. If a root is fused with its consumer, it
+/// is no longer tagged as a root, to aid with dispatch region formation.
+static void fuseRootsWithConsumers(MLIRContext *context,
+ ArrayRef<Operation *> roots,
+ DominanceInfo const &dominanceInfo) {
+ SmallVector<Operation *> workList(roots.begin(), roots.end());
+ // Fuse with consumers where possible.
+ while (!workList.empty()) {
+ Operation *currRoot = workList.pop_back_val();
+ assert(hasRootOpAttribute(currRoot) &&
+ "unexpected non-root op in worklist");
+
+ // Helper function to make the consumer the root instead of the producer
+ // when they are to be fused.
+ auto updateRootTo = [&context, &currRoot](Operation *newRoot) {
+ int64_t rootNumber = getRootNumber(currRoot);
+ setRootAttribute(context, newRoot, rootNumber);
+ removeRootOpAttribute(currRoot);
+ appendToFusionGroup(currRoot, rootNumber);
+ };
+
+ Optional<OpOperand *> fusableUse = getFusableUse(currRoot, dominanceInfo);
+ if (!fusableUse) continue;
+
+ // Analyse the use to see if it is fusable.
+ Operation *consumerOp = fusableUse.value()->getOwner();
+ if (hasRootOpAttribute(consumerOp) ||
+ hasFusionGroupsAttribute(consumerOp)) {
+ continue;
+ }
+
+ if (isFusableWithConsumer(*(fusableUse.value()))) {
+ updateRootTo(consumerOp);
+ workList.push_back(consumerOp);
+ }
+ }
+}
+
+/// Method to check if the consumer of a use can be fused with its producer.
+static bool isFusableWithProducer(OpOperand &operand) {
+ Operation *producer = operand.get().getDefiningOp();
+ Operation *consumer = operand.getOwner();
+
+ if (isa<linalg::LinalgOp>(consumer) && isa<linalg::LinalgOp>(producer)) {
+ auto consumerLinalgOp = cast<linalg::LinalgOp>(consumer);
+ auto producerLinalgOp = cast<linalg::LinalgOp>(producer);
+ if (consumerLinalgOp.isOutputTensor(&operand) &&
+ producerLinalgOp.getNumLoops() ==
+ producerLinalgOp.getNumParallelLoops()) {
+ return true;
+ }
+ }
+ return false;
+}
+
+/// Starting from the `root` op, traverse the operand use-def chain
+/// in reverse to fuse with producers.
+static void fuseRootsWithProducers(MLIRContext *context, Operation *root,
+ unsigned groupNum,
+ DominanceInfo const &dominanceInfo) {
+ // We probably want a worklist algorithm here, but for now just look at
+ // immediate producers.
+ for (OpOperand &operand : root->getOpOperands()) {
+ Operation *producer = operand.get().getDefiningOp();
+ if (!producer) continue;
+ if (hasFusionGroupsAttribute(producer) || hasRootOpAttribute(producer)) {
+ continue;
+ }
+
+ Optional<OpOperand *> fusableUse = getFusableUse(producer, dominanceInfo);
+ if (!fusableUse || fusableUse.value()->getOwner() != root) continue;
+
+ if (isFusableWithProducer(operand)) {
+ appendToFusionGroup(producer, groupNum);
+ }
+ }
+}
+
+/// Some heuristic is needed to fuse dispatchable ops with root operations
+/// using tile + fuse. Here, each root operation is tagged with an ID (an
+/// IntegerAttr with name `kRootOpAttr`) and all dispatchable ops to be fused
+/// with it are tagged with the same ID (a list of IntegerAttrs with name
+/// `kFusionGroupsAttr`). Each dispatchable operation can be marked to fuse
+/// with multiple root operations (i.e. replicated). For now a very simple
+/// heuristic is used below, but the mechanism should be general enough to
+/// capture any heuristic.
+static unsigned decideFusableLinalgOps(FunctionOpInterface funcOp,
+ DominanceInfo const &dominanceInfo) {
+ unsigned numRootOps = 0;
+ MLIRContext *context = funcOp->getContext();
+ OpBuilder builder(context);
+ for (Block &block : funcOp.getBody()) {
+ // Dispatch region formation works by first cloning the root into
+ // the dispatch region and then pulling operations in.
+ // So procedure here is to
+ // - First find the roots
+ // - To fuse with consumers make the consumer the root.
+ SmallVector<Operation *> roots;
+ for (Operation &op : llvm::reverse(block)) {
+ // Start with a root operation and fuse its producers.
+ if (hasFusionGroupsAttribute(&op) || !isRootOp(&op)) continue;
+ unsigned newGroup = numRootOps++;
+ setRootAttribute(context, &op, newGroup);
+
+ fuseRootsWithProducers(context, &op, newGroup, dominanceInfo);
+ roots.push_back(&op);
+ }
+ roots = llvm::to_vector(llvm::reverse(roots));
+ fuseRootsWithConsumers(context, roots, dominanceInfo);
+ }
+
+ // Once all root linalg ops have been tagged, put all remaining generic ops
+ // into their own dispatches.
+ for (Block &block : funcOp.getBody()) {
+ SmallVector<Operation *> roots;
+ for (Operation &op : llvm::reverse(block)) {
+ // If it is part of a fusion group or root op, ignore it.
+ if (hasFusionGroupsAttribute(&op) || hasRootOpAttribute(&op)) continue;
+ // Only look for Linalg ops here. Avoid moving `linalg.fill` ops that
+ // aren't fused with anything else into their own dispatches, since it is
+ // better to convert them to splats.
+ if (!isa<linalg::LinalgOp>(op) || isa<linalg::FillOp>(op)) continue;
+
+ unsigned newGroup = numRootOps++;
+ setRootAttribute(context, &op, newGroup);
+ roots.push_back(&op);
+ }
+ roots = llvm::to_vector(llvm::reverse(roots));
+ fuseRootsWithConsumers(context, roots, dominanceInfo);
+ }
+
+ LLVM_DEBUG({
+ llvm::dbgs() << "\n--- After annotating linalg op fusion scheme ---\n";
+ funcOp->print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
+ llvm::dbgs() << "\n\n";
+ });
+
+ return numRootOps;
+}
+
+//===----------------------------------------------------------------------===//
+// Dispatch region formation
+//===----------------------------------------------------------------------===//
+
+/// Clone producers into the dispatch region.
+static LogicalResult cloneProducers(RewriterBase &rewriter,
+ Flow::DispatchRegionOp regionOp) {
+ SmallVector<Operation *> cloneableOps = getCloneableOps(regionOp);
+ SmallVector<Operation *> orderedProducers =
+ Flow::orderOperations(cloneableOps);
+
+ for (Operation *producer : llvm::reverse(orderedProducers))
+ if (failed(
+ clonePrecedingOpIntoDispatchRegion(rewriter, producer, regionOp)))
+ return failure();
+
+ return success();
+}
+
+/// Helper function that builds the workload region body.
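+/// The generated body forwards the workload block arguments to a
+/// `flow.dispatch.workgroup_count_from_dag_root` op and returns its results,
+/// roughly (schematic):
+///   count(%arg0: index, ...) {
+///     %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg0, ...
+///     flow.return %x, %y, %z : index, index, index
+///   }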
+static void buildWorkloadRegionBody(OpBuilder &builder, Location loc,
+ ArrayRef<BlockArgument> args) {
+ auto numWorkgroupsOp =
+ builder.create<Flow::DispatchWorkgroupCountFromDagRootOp>(loc, args);
+ builder.create<Flow::ReturnOp>(loc, numWorkgroupsOp.getResults());
+}
+
+/// Create Flow::DispatchWorkgroupsOps based on a fusion heuristic.
+static FailureOr<SmallVector<Flow::DispatchWorkgroupsOp>> createFusionGroups(
+ TensorDimTrackingRewriter &rewriter, FunctionOpInterface funcOp,
+ DominanceInfo const &dominanceInfo, bool generateWorkloadRegion) {
+ // Decide fusion groups (heuristic).
+ unsigned numRoots = decideFusableLinalgOps(funcOp, dominanceInfo);
+ SmallVector<Operation *> roots(numRoots, nullptr);
+ DenseMap<unsigned, SmallVector<Operation *>> producers;
+
+ // TODO: Incrementally add ops to an empty DispatchGroupOp instead of
+ // annotating fusion group IDs via attributes.
+ funcOp.walk([&](Operation *op) {
+ if (hasRootOpAttribute(op)) roots[getRootNumber(op)] = op;
+ if (hasFusionGroupsAttribute(op)) {
+ assert(getFusionGroups(op).size() == 1 && "expected exactly one group");
+ producers[getFusionGroups(op).front()].push_back(op);
+ }
+ });
+
+ // Create a DispatchRegionOp for every fusion group.
+ OpBuilder::InsertionGuard g(rewriter);
+ SmallVector<Flow::DispatchRegionOp> regionOps;
+ DenseMap<Flow::DispatchRegionOp, SmallVector<Value>> workloads;
+ for (const auto &it : llvm::enumerate(roots)) {
+ // Compute workload.
+ SmallVector<Value> workload;
+ if (generateWorkloadRegion) {
+ rewriter.setInsertionPoint(it.value());
+ FailureOr<SmallVector<Value>> maybeWorkload =
+ getWorkloadForRootOp(rewriter, it.value());
+ if (failed(maybeWorkload)) return failure();
+ workload = *maybeWorkload;
+ }
+
+ // Simplify tensor::DimOps.
+ SmallVector<tensor::DimOp> dimOps = rewriter.getTensorDimOps();
+ if (failed(simplifyDimOps(rewriter, dimOps))) return failure();
+
+ // Create fusion group.
+ Flow::DispatchRegionOp regionOp;
+ auto maybeRegionOp = Flow::wrapOpInDispatchRegion(rewriter, it.value());
+ if (failed(maybeRegionOp)) return failure();
+ regionOp = *maybeRegionOp;
+ workloads[regionOp] = workload;
+
+ // Sort producers topologically. All producers must be in the same block as
+ // the root.
+ // TODO: Use mlir::computeTopologicalSorting. This is currently not possible
+ // because some of the producers are in different blocks.
+ SmallVector<Operation *> orderedProducers =
+ Flow::orderOperations(producers[it.index()]);
+
+ // Move ops into the region.
+ for (Operation *producer : llvm::reverse(orderedProducers)) {
+ auto newRegionOp =
+ movePrecedingOpIntoDispatchRegion(rewriter, producer, regionOp);
+ if (failed(newRegionOp)) return failure();
+ regionOp = *newRegionOp;
+ }
+
+ regionOps.push_back(regionOp);
+ }
+
+ // Clone additional producers and rewrite to DispatchWorkgroupsOp.
+ SmallVector<Flow::DispatchWorkgroupsOp> result;
+ for (auto regionOp : regionOps) {
+ if (failed(cloneProducers(rewriter, regionOp))) return failure();
+ auto maybeWorkgroupOp =
+ Flow::rewriteFlowDispatchRegionToFlowDispatchWorkgroups(
+ regionOp, rewriter, workloads[regionOp],
+ generateWorkloadRegion ? buildWorkloadRegionBody : nullptr);
+ if (failed(maybeWorkgroupOp)) return failure();
+
+ result.push_back(*maybeWorkgroupOp);
+ }
+
+ return result;
+}
+
+/// Wrap a single op in a DispatchWorkgroupsOp.
+static FailureOr<Flow::DispatchWorkgroupsOp> wrapInWorkgroupsOp(
+ TensorDimTrackingRewriter &rewriter, Operation *op,
+ bool generateWorkloadRegion) {
+ // Compute workload.
+ SmallVector<Value> workload;
+ if (generateWorkloadRegion) {
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(op);
+ FailureOr<SmallVector<Value>> maybeWorkload =
+ getWorkloadForRootOp(rewriter, op);
+ if (failed(maybeWorkload)) return failure();
+ workload = *maybeWorkload;
+ }
+
+ // Simplify tensor::DimOps.
+ SmallVector<tensor::DimOp> dimOps = rewriter.getTensorDimOps();
+ if (failed(simplifyDimOps(rewriter, dimOps))) return failure();
+
+ // Wrap operation.
+ auto regionOp = Flow::wrapOpInDispatchRegion(rewriter, op);
+ if (failed(regionOp)) return failure();
+ if (failed(cloneProducers(rewriter, *regionOp))) return failure();
+ auto workgroupsOp = Flow::rewriteFlowDispatchRegionToFlowDispatchWorkgroups(
+ *regionOp, rewriter, workload,
+ generateWorkloadRegion ? buildWorkloadRegionBody : nullptr);
+ if (failed(workgroupsOp)) return failure();
+ return *workgroupsOp;
+}
+
+/// Wrap all ops of the given type that are direct children of the given op in
+/// a DispatchWorkgroupsOp.
+template <typename OpTy>
+static FailureOr<SmallVector<Flow::DispatchWorkgroupsOp>> wrapInWorkgroupsOp(
+ TensorDimTrackingRewriter &rewriter, Operation *op,
+ bool generateWorkloadRegion) {
+ // Find ops of type OpTy.
+ SmallVector<Operation *> rootOps;
+ for (Region &r : op->getRegions())
+ for (Block &b : r.getBlocks())
+ for (auto op : b.getOps<OpTy>()) rootOps.push_back(op.getOperation());
+
+ // Wrap ops in DispatchWorkgroupsOps.
+ SmallVector<Flow::DispatchWorkgroupsOp> result;
+ for (Operation *rootOp : rootOps) {
+ auto workgroupsOp =
+ wrapInWorkgroupsOp(rewriter, rootOp, generateWorkloadRegion);
+ if (failed(workgroupsOp)) return failure();
+ result.push_back(*workgroupsOp);
+ }
+ return result;
+}
+
+namespace {
+/// Pass declaration.
+struct DispatchLinalgOnTensorsViaRegionOpsPass
+ : public Flow::DispatchLinalgOnTensorsViaRegionOpsBase<
+ DispatchLinalgOnTensorsViaRegionOpsPass> {
+ void getDependentDialects(DialectRegistry &registry) const override {
+ registry
+ .insert<AffineDialect, IREE::Flow::FlowDialect, linalg::LinalgDialect,
+ scf::SCFDialect, tensor::TensorDialect>();
+ }
+ DispatchLinalgOnTensorsViaRegionOpsPass(bool generateWorkloadRegion) {
+ this->generateWorkloadRegion = generateWorkloadRegion;
+ }
+ DispatchLinalgOnTensorsViaRegionOpsPass(
+ const DispatchLinalgOnTensorsViaRegionOpsPass &pass) {
+ this->generateWorkloadRegion = pass.generateWorkloadRegion;
+ }
+ void runOnOperation() override;
+
+ private:
+ bool generateWorkloadRegion = true;
+};
+} // namespace
+
+void DispatchLinalgOnTensorsViaRegionOpsPass::runOnOperation() {
+ auto funcOp = getOperation();
+ MLIRContext *context = &getContext();
+
+ DominanceInfo const &dominanceInfo = getAnalysis<DominanceInfo>();
+ TensorDimTrackingRewriter rewriter(funcOp);
+
+ // Step 1: Create a DispatchWorkgroupsOp for every fusion group.
+ auto maybeWorkgroupsOps = createFusionGroups(rewriter, funcOp, dominanceInfo,
+ generateWorkloadRegion);
+ if (failed(maybeWorkgroupsOps)) return signalPassFailure();
+ SmallVector<Flow::DispatchWorkgroupsOp> workgroupsOps = *maybeWorkgroupsOps;
+
+ LLVM_DEBUG({
+ llvm::dbgs() << "\n--- After first step of dispatch region formation ---\n";
+ funcOp->print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
+ llvm::dbgs() << "\n\n";
+ });
+
+ // Step 2: Create a DispatchWorkgroupsOp for every remaining InsertSliceOp.
+ FailureOr<SmallVector<Flow::DispatchWorkgroupsOp>> newWorkgroupsOps =
+ wrapInWorkgroupsOp<tensor::InsertSliceOp>(rewriter, funcOp,
+ generateWorkloadRegion);
+ if (failed(newWorkgroupsOps)) return signalPassFailure();
+ workgroupsOps.append(newWorkgroupsOps->begin(), newWorkgroupsOps->end());
+
+ // Step 3: Create a DispatchWorkgroupsOp for every remaining ExtractSliceOp.
+ newWorkgroupsOps = wrapInWorkgroupsOp<tensor::ExtractSliceOp>(
+ rewriter, funcOp, generateWorkloadRegion);
+ if (failed(newWorkgroupsOps)) return signalPassFailure();
+ workgroupsOps.append(newWorkgroupsOps->begin(), newWorkgroupsOps->end());
+
+ // A few extra canonicalizations/lowerings.
+ {
+ RewritePatternSet convertToFlowPatterns(context);
+ Flow::populateTensorToFlowConversionPatterns(context,
+ convertToFlowPatterns);
+ memref::populateResolveRankedShapeTypeResultDimsPatterns(
+ convertToFlowPatterns);
+ IREE::Flow::TensorReshapeOp::getCanonicalizationPatterns(
+ convertToFlowPatterns, context);
+ if (failed(applyPatternsAndFoldGreedily(funcOp,
+ std::move(convertToFlowPatterns))))
+ return signalPassFailure();
+
+ // Finally fold `tensor.insert_slice/extract_slice` operations with
+ // `flow.dispatch.tensor.load/store`.
+ RewritePatternSet foldExtractInsertSliceOps(context);
+ Flow::populateTensorSliceOpWithDispatchTensorOpFoldingPatterns(
+ foldExtractInsertSliceOps, context);
+ if (failed(applyPatternsAndFoldGreedily(
+ funcOp, std::move(foldExtractInsertSliceOps))))
+ return signalPassFailure();
+ }
+
+ // Finally, walk all ops and remove the temporary attributes used during
+ // dispatch region formation.
+ funcOp.walk([](Operation *op) {
+ removeFusionGroupsAttribute(op);
+ removeRootOpAttribute(op);
+ op->removeAttr(linalg::LinalgTransforms::kLinalgTransformMarker);
+ });
+}
+
+std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
+Flow::createDispatchLinalgOnTensorsViaRegionOpsPass(
+ bool generateWorkloadRegion) {
+ return std::make_unique<DispatchLinalgOnTensorsViaRegionOpsPass>(
+ generateWorkloadRegion);
+}
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/InterchangeGenericOps.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/InterchangeGenericOps.cpp
index 3f0966d..0d187ea 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/InterchangeGenericOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/InterchangeGenericOps.cpp
@@ -35,7 +35,7 @@
unsigned numParallelLoop = genericOp.getNumParallelLoops();
if (numParallelLoop == 0) return failure();
for (auto iter : llvm::enumerate(genericOp.iterator_types())) {
- if (isParallelIterator(iter.value())) {
+ if (linalg::isParallelIterator(iter.value())) {
interchange.push_back(iter.index());
if (iter.index() >= numParallelLoop) needInterchange = true;
}
@@ -43,7 +43,7 @@
// If all the parallel loops are outer loops, skip the pattern.
if (!needInterchange) return failure();
for (auto iter : llvm::enumerate(genericOp.iterator_types())) {
- if (isReductionIterator(iter.value())) {
+ if (linalg::isReductionIterator(iter.value())) {
interchange.push_back(iter.index());
}
}
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
index 1492a72..8fbb162 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
@@ -110,6 +110,17 @@
"the transformations to apply to form dispatch regions."),
llvm::cl::init(""));
+static llvm::cl::opt<bool> clDispatchViaRegionOps(
+ "iree-flow-dispatch-via-region-ops",
+ llvm::cl::desc("Create dispatches via DispatchRegionOps"),
+ llvm::cl::init(false));
+
+static llvm::cl::opt<bool> clDispatchViaRegionOpsGenerateWorkloadRegion(
+ "iree-flow-dispatch-via-region-ops-generate-workload-region",
+ llvm::cl::desc("Generate the workload region when running with "
+ "iree-flow-dispatch-via-region-ops"),
+ llvm::cl::init(true));
+
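+// Illustrative invocation (sketch): only the two flag names above are defined
+// here; the driver and any surrounding pipeline flags are assumptions.
+//
+//   iree-opt --iree-flow-dispatch-via-region-ops \
+//       --iree-flow-dispatch-via-region-ops-generate-workload-region=false \
+//       input.mlir
+//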
namespace mlir {
namespace iree_compiler {
namespace IREE {
@@ -187,8 +198,9 @@
.addPass(IREE::Flow::createConvertConv2D1x1ToMatmulPass)
.addPredicatedPass(clEnableConvToImg2Col,
IREE::Flow::createConvertConv2DToImg2ColPass)
- .addPredicatedPass(clDispatchTransformFileName.empty(),
- IREE::Flow::createDetachElementwiseFromNamedOpsPass)
+ .addPredicatedPass(
+ clDispatchTransformFileName.empty() && !clDispatchViaRegionOps,
+ IREE::Flow::createDetachElementwiseFromNamedOpsPass)
// Input should now be legal.
.addPass(IREE::Flow::createVerifyInputLegalityPass)
// Catch matmul ops before we do anything else with them.
@@ -252,8 +264,18 @@
clDispatchTransformFileName);
})
// Only want to use the transform dialect for some dispatch regions and let
- // the DispatchLinalgOnTensorsPass unconditionally handle the rest.
- .addPass(createDispatchLinalgOnTensorsPass)
+ // the DispatchLinalgOnTensorsPass handle the rest.
+ .addPredicatedPass(!clDispatchViaRegionOps,
+ createDispatchLinalgOnTensorsPass)
+ // DispatchLinalgOnTensorsViaRegionOpsPass is a variant of
+ // DispatchLinalgOnTensorsPass that lowers via DispatchRegionOps. This is
+ // on an opt-in basis until the pass is stable enough to replace
+ // DispatchLinalgOnTensorsPass.
+ .addPredicatedPass(clDispatchViaRegionOps,
+ [&]() {
+ return createDispatchLinalgOnTensorsViaRegionOpsPass(
+ clDispatchViaRegionOpsGenerateWorkloadRegion);
+ })
////////////////////////////////////////////////////////////////////////
.addPass(createCaptureDispatchDynamicDimsPass)
.addPass(mlir::createCanonicalizerPass)
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.h b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.h
index 6805810..76b51ae 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.h
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.h
@@ -142,6 +142,13 @@
std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
createDispatchLinalgOnTensorsPass();
+// Pass to perform dispatch of Linalg on tensor ops by tiling and distribution.
+// A dispatch region is created for each tiled loop nest. (First create
+// DispatchRegionOps, then DispatchWorkgroupsOps.)
+std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
+createDispatchLinalgOnTensorsViaRegionOpsPass(
+ bool generateWorkloadRegion = true);
+
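+// Usage sketch (assumes a pass manager `fpm` nested on the function op, which
+// is not part of this header):
+//   fpm.addPass(createDispatchLinalgOnTensorsViaRegionOpsPass(
+//       /*generateWorkloadRegion=*/true));
+//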
// Pass to perform dispatch of Linalg on tensor ops by using the transform
// dialect. Dispatch regions are created as specified by the transform module
// that is parsed from `transformFileName`.
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td
index 593726c..b7c85a9 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td
@@ -68,6 +68,12 @@
let constructor = "mlir::iree_compiler::IREE::Flow::createDispatchLinalgOnTensorsPass()";
}
+def DispatchLinalgOnTensorsViaRegionOps :
+ InterfacePass<"iree-flow-dispatch-linalg-on-tensors-via-regionops-pass", "mlir::FunctionOpInterface"> {
+ let summary = "Dispatch Linalg operations on tensors by using tile and distribute (via DispatchRegionOps)";
+ let constructor = "mlir::iree_compiler::IREE::Flow::createDispatchLinalgOnTensorsViaRegionOpsPass()";
+}
+
def DispatchWithTransformDialect :
InterfacePass<"iree-flow-dispatch-with-transform-dialect", "mlir::FunctionOpInterface"> {
let summary = "Dispatch Linalg operations on tensors by using the transform dialect interpreter";
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.cpp
new file mode 100644
index 0000000..81cf629
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.cpp
@@ -0,0 +1,369 @@
+// Copyright 2022 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
+
+#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Dominance.h"
+
+using namespace mlir;
+using namespace mlir::iree_compiler;
+using namespace mlir::iree_compiler::IREE;
+
+#define DEBUG_TYPE "iree-flow-region-op-utils"
+
+static SmallVector<Range> getLoopRangesImpl(TilingInterface tilableOp,
+ Location loc, OpBuilder &builder) {
+ SmallVector<Range> loopRanges = tilableOp.getIterationDomain(builder);
+ Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
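+ // Note: each reduction dimension is clamped to a single iteration below, so
+ // the returned ranges only span the parallel loops.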
+ for (auto iteratorType : llvm::enumerate(tilableOp.getLoopIteratorTypes())) {
+ if (iteratorType.value() == getReductionIteratorTypeName()) {
+ loopRanges[iteratorType.index()].size = one;
+ }
+ }
+ return loopRanges;
+}
+
+static SmallVector<Range> getLoopRangesImpl(tensor::InsertSliceOp insertSliceOp,
+ Location loc, OpBuilder &builder) {
+ OpFoldResult zero = builder.getIndexAttr(0);
+ OpFoldResult one = builder.getIndexAttr(1);
+ Value source = insertSliceOp.getSource();
+ SmallVector<Range> loopRanges(insertSliceOp.getSourceType().getRank(),
+ Range{zero, one, one});
+ for (auto dim : llvm::seq<unsigned>(0, loopRanges.size())) {
+ loopRanges[dim].size =
+ builder.create<tensor::DimOp>(loc, source, dim).getResult();
+ }
+ return loopRanges;
+}
+
+static SmallVector<Range> getLoopRangesImpl(tensor::ExtractSliceOp sliceOp,
+ Location loc, OpBuilder &builder) {
+ Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
+ Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
+ ReifiedRankedShapedTypeDims resultDims;
+ LogicalResult status = sliceOp.reifyResultShapes(builder, resultDims);
+ (void)status;
+ assert(succeeded(status) && "reifyResultShapes failed");
+ return llvm::to_vector(llvm::map_range(resultDims[0], [&](Value v) {
+ return Range{zero, v, one};
+ }));
+}
+
+/// For a given operation, returns the loop ranges needed to compute the op.
+SmallVector<Range> Flow::getLoopRanges(Operation *op, Location loc,
+ OpBuilder &builder) {
+ return llvm::TypeSwitch<Operation *, SmallVector<Range>>(op)
+ .Case([&](TilingInterface op) {
+ return getLoopRangesImpl(op, loc, builder);
+ })
+ .Case([&](tensor::InsertSliceOp op) {
+ return getLoopRangesImpl(op, loc, builder);
+ })
+ .Case([&](tensor::ExtractSliceOp op) {
+ return getLoopRangesImpl(op, loc, builder);
+ })
+ .Default([](Operation *op) -> SmallVector<Range> {
+ llvm_unreachable("op not supported");
+ });
+}
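+
+// Example (sketch): for a tensor.insert_slice whose source has type
+// tensor<?x10xf32>, getLoopRanges returns one range per source dimension,
+// each with offset 0, stride 1, and a tensor.dim of the source as the size.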
+
+/// Return `true` if the given type is a ShapedType and has at least one
+/// dynamic dimension.
+static bool hasDynamicShape(Type t) {
+ auto shapedType = t.dyn_cast<ShapedType>();
+ if (!shapedType) return false;
+ return !shapedType.hasStaticShape();
+}
+
+/// Reify the dynamic dimensions of the given value.
+LogicalResult Flow::reifyDynamicResultDims(OpBuilder &b, Value value,
+ SmallVector<Value> &dynamicDims) {
+ OpBuilder::InsertionGuard guard(b);
+
+ // Case 1: No dynamic result dims.
+ if (!hasDynamicShape(value.getType())) return success();
+
+ // There is at least one dynamic dimension, continue...
+ ShapedType shapedType = value.getType().cast<ShapedType>();
+
+ // Case 2: Value is a block argument.
+ if (auto bbArg = value.dyn_cast<BlockArgument>()) {
+ b.setInsertionPointToStart(bbArg.getOwner());
+ for (int64_t i = 0; i < shapedType.getRank(); ++i) {
+ if (shapedType.isDynamicDim(i)) {
+ Value dim = b.create<tensor::DimOp>(bbArg.getLoc(), bbArg, i);
+ dynamicDims.push_back(dim);
+ }
+ }
+ return success();
+ }
+
+ // Value is an OpResult.
+ Operation *op = value.getDefiningOp();
+ OpResult opResult = value.cast<OpResult>();
+ b.setInsertionPoint(op);
+
+ // Case 3: Value is tied. Reify the dimensions of the tied operand.
+ auto tiedOp = dyn_cast<IREE::Util::TiedOpInterface>(op);
+ if (tiedOp) {
+ Value tiedOperand = tiedOp.getTiedResultOperand(value);
+ if (tiedOperand && tiedOperand.getType() == value.getType())
+ return reifyDynamicResultDims(b, tiedOperand, dynamicDims);
+ }
+
+ // Case 4: Query ShapeAwareOpInterface.
+ auto shapeAwareOp = dyn_cast<IREE::Util::ShapeAwareOpInterface>(op);
+ if (shapeAwareOp) {
+ ValueRange dims =
+ shapeAwareOp.getResultDynamicDims(opResult.getResultNumber());
+ dynamicDims.append(dims.begin(), dims.end());
+ return success();
+ }
+
+ // Case 5: Query ReifyRankedShapedTypeOpInterface.
+ auto reifyShapeOp = dyn_cast<ReifyRankedShapedTypeOpInterface>(op);
+ if (reifyShapeOp) {
+ ReifiedRankedShapedTypeDims dims;
+ if (failed(reifyShapeOp.reifyResultShapes(b, dims))) return failure();
+ for (int64_t i = 0; i < shapedType.getRank(); ++i)
+ if (shapedType.isDynamicDim(i))
+ dynamicDims.push_back(dims[opResult.getResultNumber()][i]);
+ return success();
+ }
+
+ return failure();
+}
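+
+// Usage sketch for reifyDynamicResultDims: collect the dynamic dims of a
+// value, e.g. before building a dispatch region that returns it. For a value
+// of type tensor<?x8x?xf32>, two index Values are appended in dimension
+// order:
+//
+//   SmallVector<Value> dynamicDims;
+//   if (failed(Flow::reifyDynamicResultDims(b, value, dynamicDims)))
+//     return failure();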
+
+// Append a result to the given DispatchRegionOp. The newly created
+// DispatchRegionOp is returned.
+FailureOr<Flow::DispatchRegionOp> Flow::appendDispatchRegionResult(
+ RewriterBase &rewriter, Flow::DispatchRegionOp regionOp, Value result) {
+ OpBuilder::InsertionGuard guard(rewriter);
+
+ // Determine dynamic result dims.
+ rewriter.setInsertionPoint(regionOp);
+ SmallVector<Value> dynamicDims(regionOp.getResultDims().begin(),
+ regionOp.getResultDims().end());
+ if (failed(reifyDynamicResultDims(rewriter, result, dynamicDims)))
+ return failure();
+
+ // Determine result types of new RegionOp.
+ SmallVector<Type> resultTypes(regionOp.getResultTypes().begin(),
+ regionOp.getResultTypes().end());
+ resultTypes.push_back(result.getType());
+
+ // Create new DispatchRegionOp and move over the body.
+ auto newRegionOp = rewriter.create<Flow::DispatchRegionOp>(
+ regionOp->getLoc(), resultTypes, dynamicDims);
+ newRegionOp.getBody().takeBody(regionOp.getBody());
+ rewriter.replaceOp(
+ regionOp, newRegionOp.getResults().take_front(regionOp->getNumResults()));
+
+ // Update terminator.
+ Flow::ReturnOp returnOp =
+ cast<Flow::ReturnOp>(newRegionOp.getBody().front().getTerminator());
+ SmallVector<Value> returnedValues(returnOp.getOperands().begin(),
+ returnOp.getOperands().end());
+ returnedValues.push_back(result);
+ returnOp.operandsMutable().assign(returnedValues);
+
+ return newRegionOp;
+}
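+
+// Example (sketch): appending a value %v : tensor<?xf32> to a single-result
+// region op rebuilds it with an extra result and return operand, roughly:
+//
+//   %r = flow.dispatch.region -> (tensor<?xf32>{%d0}) { ... flow.return %0 }
+// becomes
+//   %r:2 = flow.dispatch.region -> (tensor<?xf32>{%d0}, tensor<?xf32>{%d1}) {
+//     ... flow.return %0, %v
+//   }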
+
+Flow::DispatchRegionOp Flow::makeEmptyDispatchRegion(OpBuilder &builder,
+ Location loc) {
+ OpBuilder::InsertionGuard guard(builder);
+
+ // Create RegionOp.
+ auto regionOp = builder.create<Flow::DispatchRegionOp>(
+ loc, /*resultTypes=*/TypeRange(), /*dynamicDims=*/ValueRange());
+ Block &body = regionOp.getBody().emplaceBlock();
+ builder.setInsertionPointToStart(&body);
+ builder.create<Flow::ReturnOp>(loc, ValueRange());
+
+ return regionOp;
+}
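+
+// The created op is empty except for its terminator (sketch):
+//
+//   %r = flow.dispatch.region {
+//     flow.return
+//   }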
+
+// Clone a `target` op that is preceding the given dispatch region op into the
+// dispatch region.
+LogicalResult Flow::clonePrecedingOpIntoDispatchRegion(
+ RewriterBase &rewriter, Operation *target,
+ Flow::DispatchRegionOp regionOp) {
+ Block &body = regionOp.getBody().front();
+
+ // Gather all uses of `target`.
+ SmallVector<OpOperand *> usesInsideOfRegion;
+ for (OpOperand &use : target->getUses()) {
+ if (regionOp->isProperAncestor(use.getOwner()))
+ usesInsideOfRegion.push_back(&use);
+ }
+
+ // Clone op into dispatch region.
+ OpBuilder::InsertionGuard guard(rewriter);
+ rewriter.setInsertionPointToStart(&body);
+ Operation *newTargetOp = rewriter.clone(*target);
+
+ // Replace all uses in the dispatch region.
+ for (OpOperand *use : usesInsideOfRegion) {
+ rewriter.updateRootInPlace(use->getOwner(), [&]() {
+ use->set(newTargetOp->getResult(
+ use->get().cast<OpResult>().getResultNumber()));
+ });
+ }
+
+ return success();
+}
+
+// Move a `target` op that is preceding the given dispatch region op into the
+// dispatch region.
+FailureOr<Flow::DispatchRegionOp> Flow::movePrecedingOpIntoDispatchRegion(
+ RewriterBase &rewriter, Operation *target,
+ Flow::DispatchRegionOp regionOp) {
+#ifndef NDEBUG
+ DominanceInfo domInfo;
+ for (OpOperand &use : target->getUses()) {
+ if (regionOp->isProperAncestor(use.getOwner())) continue;
+ assert(domInfo.properlyDominates(regionOp, use.getOwner()) &&
+ "found use that is not dominated by the dispatch region");
+ }
+#endif // NDEBUG
+
+ Block &body = regionOp.getBody().front();
+
+ // Gather all uses of `target`.
+ SmallVector<OpOperand *> usesOutsideOfRegion;
+ for (OpOperand &use : target->getUses())
+ if (!regionOp->isProperAncestor(use.getOwner()))
+ usesOutsideOfRegion.push_back(&use);
+
+ // Move op into dispatch region.
+ target->moveBefore(&body.front());
+
+ // Replace all uses outside of the dispatch region.
+ if (!usesOutsideOfRegion.empty()) {
+ unsigned previousNumResults = regionOp->getNumResults();
+
+ // Note: Appending results one-by-one here so that this can be extended to
+ // specific results in the future. Many ops have just one result, so this
+ // should not be a large overhead.
+ for (Value v : target->getResults()) {
+ auto newRegionOp = appendDispatchRegionResult(rewriter, regionOp, v);
+ if (failed(newRegionOp)) return failure();
+ regionOp = *newRegionOp;
+ }
+
+ // Replace uses of `target` after the dispatch region.
+ for (OpOperand *use : usesOutsideOfRegion) {
+ rewriter.updateRootInPlace(use->getOwner(), [&]() {
+ use->set(
+ regionOp->getResult(previousNumResults +
+ use->get().cast<OpResult>().getResultNumber()));
+ });
+ }
+ }
+
+ return regionOp;
+}
+
+FailureOr<Flow::DispatchRegionOp> Flow::wrapOpInDispatchRegion(
+ RewriterBase &rewriter, Operation *op) {
+ // Make an empty dispatch region right after the op.
+ rewriter.setInsertionPointAfter(op);
+ Flow::DispatchRegionOp regionOp =
+ Flow::makeEmptyDispatchRegion(rewriter, op->getLoc());
+
+ // Move the op into the dispatch region.
+ auto newRegionOp = movePrecedingOpIntoDispatchRegion(rewriter, op, regionOp);
+ return newRegionOp;
+}
+
+/// Reorders the operations in `ops` such that they can be inlined into the
+/// dispatch region in that order while satisfying dependencies.
+SmallVector<Operation *> Flow::orderOperations(ArrayRef<Operation *> ops) {
+ LLVM_DEBUG({
+ llvm::dbgs() << "Ops to be inlined :\n";
+ for (auto op : ops) {
+ llvm::dbgs() << "\t";
+ op->print(llvm::dbgs());
+ llvm::dbgs() << "\n";
+ }
+ });
+
+ llvm::SmallMapVector<Operation *, SmallVector<Operation *>, 16>
+ insertAfterMap;
+ llvm::SetVector<Operation *> opSet(ops.begin(), ops.end());
+ llvm::SetVector<Operation *> leafOps(ops.begin(), ops.end());
+ // For each operation compute the list of operations in `ops` that use its
+ // results. Also compute the operations that form the leaves of the DAG of
+ // operations in `ops`.
+ for (auto op : ops) {
+ for (auto operand : op->getOperands()) {
+ auto definingOp = operand.getDefiningOp();
+ if (!definingOp || !opSet.count(definingOp)) continue;
+ insertAfterMap[definingOp].push_back(op);
+ if (leafOps.count(op)) leafOps.remove(op);
+ }
+ }
+
+ // The leaves are at the head of the ordered list.
+ SmallVector<Operation *> orderedOps(leafOps.begin(), leafOps.end());
+ orderedOps.reserve(ops.size());
+ llvm::SmallPtrSet<Operation *, 16> processed;
+ processed.insert(leafOps.begin(), leafOps.end());
+
+ // `readyOps` contains the list of operations that have just been added to
+ // the `orderedOps` list. With these marked ready, they might make further
+ // operations in `ops` ready as well.
+ // The complexity of the algorithm is driven by the following:
+ // - Each operation is added to the `readyOps` list at most once, and is
+ //   removed after being processed.
+ // - For every operation in `readyOps`, every use of its results (within
+ //   `ops`) is looked at once.
+ // - For every use, the operands of the user are processed.
+ // Assuming the number of operands per operation is O(1), i.e. constant, the
+ // complexity is O(sum of the number of uses of each operation). Given that
+ // the size of `ops` is on the order of 10 rather than 100, this is assumed
+ // to be reasonable.
+ ArrayRef<Operation *> readyOps(orderedOps);
+ size_t startPos = 0;
+ while (!readyOps.empty()) {
+ auto op = readyOps.front();
+ startPos++;
+ // Check all uses of `op` within `ops`. If all of the operations that define
+ // the operands of the user have been added to `orderedOps`, then the user
+ // is ready to be scheduled.
+ for (auto insertAfterOp : insertAfterMap[op]) {
+ if (processed.count(insertAfterOp)) continue;
+ if (llvm::all_of(insertAfterOp->getOperands(), [&](Value operand) {
+ Operation *operandDefiningOp = operand.getDefiningOp();
+ return !operandDefiningOp || !opSet.count(operandDefiningOp) ||
+ processed.count(operandDefiningOp);
+ })) {
+ orderedOps.push_back(insertAfterOp);
+ processed.insert(insertAfterOp);
+ }
+ }
+ readyOps = ArrayRef<Operation *>(orderedOps).drop_front(startPos);
+ }
+
+ LLVM_DEBUG({
+ llvm::dbgs() << "Ops to be inlined (sorted) : \n";
+ for (auto op : orderedOps) {
+ llvm::dbgs() << "\t";
+ op->print(llvm::dbgs());
+ llvm::dbgs() << "\n";
+ }
+ });
+ assert(orderedOps.size() == ops.size() &&
+ "ordering of inlined operations failed");
+ return orderedOps;
+}
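+
+// Worked example (sketch): for ops = {A, B} where B uses A's result and A has
+// no operands inside the set, A is the only leaf and is scheduled first;
+// processing A makes B ready, so the returned order is [A, B].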
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h b/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h
new file mode 100644
index 0000000..7d107b3
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h
@@ -0,0 +1,99 @@
+// Copyright 2022 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+#ifndef IREE_COMPILER_DIALECT_FLOW_TRANSFORMS_REGIONOPUTILS_H_
+#define IREE_COMPILER_DIALECT_FLOW_TRANSFORMS_REGIONOPUTILS_H_
+
+#include "llvm/ADT/SmallVector.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/Support/LogicalResult.h"
+
+namespace mlir {
+class Location;
+class OpBuilder;
+class Operation;
+class RewriterBase;
+class Value;
+
+namespace iree_compiler {
+namespace IREE {
+namespace Flow {
+class DispatchRegionOp;
+
+/// For a given operation, returns the loop ranges needed to compute the op.
+SmallVector<Range> getLoopRanges(Operation *op, Location loc,
+ OpBuilder &builder);
+
+/// Reify the dynamic dimensions of the given value.
+LogicalResult reifyDynamicResultDims(OpBuilder &b, Value value,
+ SmallVector<Value> &dynamicDims);
+
+/// Append a result to the given DispatchRegionOp. The newly created
+/// DispatchRegionOp is returned.
+FailureOr<Flow::DispatchRegionOp> appendDispatchRegionResult(
+ RewriterBase &rewriter, Flow::DispatchRegionOp regionOp, Value result);
+
+/// Create an empty DispatchRegionOp.
+Flow::DispatchRegionOp makeEmptyDispatchRegion(OpBuilder &builder,
+ Location loc);
+
+/// Clone a `target` op that is preceding the given dispatch region op into the
+/// dispatch region.
+///
+/// All uses of the target inside of the dispatch region are replaced with the
+/// results of the cloned op.
+///
+/// Example (after cloning):
+///
+/// %0 = "some_op"() : () -> (tensor<?xf32>)
+/// %r = flow.dispatch.region -> (tensor<?xf32>{%d0}) {
+///   %0_clone = "some_op"() : () -> (tensor<?xf32>)
+///   %1 = "another_op"(%0_clone) : (tensor<?xf32>) -> (tensor<?xf32>)
+///   flow.return %1 : tensor<?xf32>
+/// }
+/// %2 = "yet_another_use"(%0) : (tensor<?xf32>) -> (tensor<?xf32>)
+LogicalResult clonePrecedingOpIntoDispatchRegion(
+ RewriterBase &rewriter, Operation *target, Flow::DispatchRegionOp regionOp);
+
+/// Move a `target` op that is preceding the given dispatch region op into the
+/// dispatch region.
+///
+/// All uses of the target outside of the dispatch region are replaced with
+/// the newly added results of the dispatch region.
+///
+/// Example (after moving):
+///
+/// %r:2 = flow.dispatch.region -> (tensor<?xf32>{%d0}, tensor<?xf32>{%d1}) {
+///   %0 = "some_op"() : () -> (tensor<?xf32>)
+///   %1 = "another_op"(%0) : (tensor<?xf32>) -> (tensor<?xf32>)
+///   flow.return %1, %0 : tensor<?xf32>, tensor<?xf32>
+/// }
+/// %2 = "yet_another_use"(%r#1) : (tensor<?xf32>) -> (tensor<?xf32>)
+FailureOr<Flow::DispatchRegionOp> movePrecedingOpIntoDispatchRegion(
+ RewriterBase &rewriter, Operation *target, Flow::DispatchRegionOp regionOp);
+
+/// Wrap the given op in a new dispatch region op.
+FailureOr<Flow::DispatchRegionOp> wrapOpInDispatchRegion(RewriterBase &rewriter,
+ Operation *op);
+
+/// Sort the given ops topologically, so that they can be inlined into a
+/// dispatch region without dominance violations.
+///
+/// Example:
+///
+/// %0 = "some_op"()
+/// %1 = "another_op"(%0)
+///
+/// In the above example, "some_op" is before "another_op" in the result.
+// TODO: Improve mlir::sortTopologically. This function currently does not
+// support ops from different blocks.
+SmallVector<Operation *> orderOperations(ArrayRef<Operation *> ops);
+
+} // namespace Flow
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
+
+#endif // IREE_COMPILER_DIALECT_FLOW_TRANSFORMS_REGIONOPUTILS_H_
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/detach_elementwise_from_named_ops.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/detach_elementwise_from_named_ops.mlir
index aedca26..acd5fa8 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/detach_elementwise_from_named_ops.mlir
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/detach_elementwise_from_named_ops.mlir
@@ -1,17 +1,26 @@
// RUN: iree-opt --split-input-file --iree-flow-detach-elementwise-from-named-ops --mlir-print-local-scope %s | FileCheck %s
func.func @matmul(%a: tensor<?x64xf32>, %b: tensor<64x?xf32>, %c: tensor<?x?xf32>) -> tensor<?x?xf32> {
- %0 = linalg.matmul ins(%a, %b : tensor<?x64xf32>, tensor<64x?xf32>) outs(%c : tensor<?x?xf32>) -> tensor<?x?xf32>
- return %0 : tensor<?x?xf32>
+ %0 = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%c : tensor<?x?xf32>) outs(%c : tensor<?x?xf32>) {
+ ^bb0(%b0 : f32, %b1 : f32):
+ %1 = arith.addf %b0, %b0 : f32
+ linalg.yield %1 : f32
+ } -> tensor<?x?xf32>
+ %1 = linalg.matmul ins(%a, %b : tensor<?x64xf32>, tensor<64x?xf32>) outs(%0 : tensor<?x?xf32>) -> tensor<?x?xf32>
+ return %1 : tensor<?x?xf32>
}
// CHECK-LABEL: func @matmul
-// CHECK-SAME: (%[[A:.+]]: tensor<?x64xf32>, %[[B:.+]]: tensor<64x?xf32>, %[[C:.+]]: tensor<?x?xf32>)
+// CHECK-SAME: (%[[A:.+]]: tensor<?x64xf32>, %[[B:.+]]: tensor<64x?xf32>, %[[ARG2:.+]]: tensor<?x?xf32>)
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[F0:.+]] = arith.constant 0.000000e+00 : f32
-
+// CHECK: %[[C:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[ARG2]] :
// CHECK: %[[DIM0:.+]] = tensor.dim %[[C]], %[[C0]]
// CHECK: %[[DIM1:.+]] = tensor.dim %[[C]], %[[C1]]
// CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[DIM0]], %[[DIM1]]]
@@ -32,16 +41,25 @@
// -----
func.func @batch_matmul(%a: tensor<?x8x?xi32>, %b: tensor<?x?x16xi32>, %c: tensor<?x8x16xi32>) -> tensor<?x8x16xi32> {
- %0 = linalg.batch_matmul ins(%a, %b : tensor<?x8x?xi32>, tensor<?x?x16xi32>) outs(%c : tensor<?x8x16xi32>) -> tensor<?x8x16xi32>
- return %0 : tensor<?x8x16xi32>
+ %0 = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
+ iterator_types = ["parallel", "parallel", "parallel"]}
+ ins(%c : tensor<?x8x16xi32>) outs(%c : tensor<?x8x16xi32>) {
+ ^bb0(%b0 : i32, %b1 : i32):
+ %1 = arith.addi %b0, %b0 : i32
+ linalg.yield %1 : i32
+ } -> tensor<?x8x16xi32>
+ %1 = linalg.batch_matmul ins(%a, %b : tensor<?x8x?xi32>, tensor<?x?x16xi32>) outs(%0 : tensor<?x8x16xi32>) -> tensor<?x8x16xi32>
+ return %1 : tensor<?x8x16xi32>
}
// CHECK-LABEL: func @batch_matmul
-// CHECK-SAME: (%[[A:.+]]: tensor<?x8x?xi32>, %[[B:.+]]: tensor<?x?x16xi32>, %[[C:.+]]: tensor<?x8x16xi32>)
+// CHECK-SAME: (%[[A:.+]]: tensor<?x8x?xi32>, %[[B:.+]]: tensor<?x?x16xi32>, %[[ARG2:.+]]: tensor<?x8x16xi32>)
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[I0:.+]] = arith.constant 0 : i32
-
+// CHECK: %[[C:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[ARG2]] :
// CHECK: %[[DIM0:.+]] = tensor.dim %[[C]], %[[C0]] : tensor<?x8x16xi32>
// CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[DIM0]], 8, 16] : tensor<?x8x16xi32>
// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[I0]] : i32) outs(%[[INIT]] : tensor<?x8x16xi32>) -> tensor<?x8x16xi32>
@@ -57,14 +75,24 @@
// -----
-func.func @conv(%input: tensor<1x225x225x3xf32>, %filter: tensor<3x3x3x32xf32>, %init: tensor<1x112x112x32xf32>) -> tensor<1x112x112x32xf32> {
- %0 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>}
- ins(%input, %filter : tensor<1x225x225x3xf32>, tensor<3x3x3x32xf32>) outs(%init : tensor<1x112x112x32xf32>) -> tensor<1x112x112x32xf32>
- return %0 : tensor<1x112x112x32xf32>
+func.func @conv(%input: tensor<1x225x225x3xf32>, %filter: tensor<3x3x3x32xf32>, %init: tensor<32xf32>) -> tensor<1x112x112x32xf32> {
+ %init0 = linalg.init_tensor [1, 112, 112, 32] : tensor<1x112x112x32xf32>
+ %0 = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
+ iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+ ins(%init : tensor<32xf32>) outs(%init0 : tensor<1x112x112x32xf32>) {
+ ^bb0(%b0 : f32, %b1 : f32):
+ linalg.yield %b0 : f32
+ } -> tensor<1x112x112x32xf32>
+ %1 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>}
+ ins(%input, %filter : tensor<1x225x225x3xf32>, tensor<3x3x3x32xf32>) outs(%0 : tensor<1x112x112x32xf32>) -> tensor<1x112x112x32xf32>
+ return %1 : tensor<1x112x112x32xf32>
}
// CHECK-LABEL: func @conv
-// CHECK-SAME: (%{{.+}}: tensor<1x225x225x3xf32>, %{{.+}}: tensor<3x3x3x32xf32>, %[[INIT:.+]]: tensor<1x112x112x32xf32>)
+// CHECK-SAME: (%{{.+}}: tensor<1x225x225x3xf32>, %{{.+}}: tensor<3x3x3x32xf32>, %[[BIAS:.+]]: tensor<32xf32>)
+// CHECK: %[[INIT:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[BIAS]] :
// CHECK: %[[FILL:.+]] = linalg.fill
// CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf
// CHECK: linalg.generic
@@ -73,6 +101,33 @@
// -----
+func.func @keep_fill(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %cst = arith.constant 0.0 : f32
+ %d0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+ %d1 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
+ %init = linalg.init_tensor [%d0, %d1] : tensor<?x?xf32>
+ %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<?x?xf32>) -> tensor<?x?xf32>
+ %gemm = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%fill : tensor<?x?xf32>) -> tensor<?x?xf32>
+ return %gemm : tensor<?x?xf32>
+}
+// CHECK-LABEL: func.func @keep_fill
+// CHECK-NOT: linalg.generic
+
+// -----
+
+func.func @keep_arg(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %0 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+// CHECK-LABEL: func.func @keep_arg
+// CHECK-NOT: linalg.generic
+
+// -----
+
func.func @fft_cst_output(%arg0 : tensor<3x2190x1x512xf32>)
-> (tensor<3x2190x1x512xf32>, tensor<3x2190x1x512xf32>) {
%c1 = arith.constant 1 : index
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/transform_dialect_dispatch_spec.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/transform_dialect_dispatch_spec.mlir
index 2274e22..87315ba 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/transform_dialect_dispatch_spec.mlir
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/transform_dialect_dispatch_spec.mlir
@@ -1,9 +1,6 @@
-transform.with_pdl_patterns {
-^bb0(%arg0: !pdl.operation):
- transform.structured.canonicalized_sequence %arg0 failures(propagate) {
- ^bb1(%arg1: !pdl.operation):
- %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
- %foreach_op, %tiled_op = transform.structured.tile_to_foreach_thread_op %0 num_threads [42, 67]
- %dispatch_op = transform.iree.foreach_thread_to_flow %foreach_op
- }
+transform.structured.canonicalized_sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
+ %foreach_op, %tiled_op = transform.structured.tile_to_foreach_thread_op %0 num_threads [42, 67]
+ %dispatch_op = transform.iree.foreach_thread_to_flow %foreach_op
}
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/transform_dispatch_region_formation.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/transform_dispatch_region_formation.mlir
index d8a1061..687eeb5 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/transform_dispatch_region_formation.mlir
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/transform_dispatch_region_formation.mlir
@@ -84,7 +84,7 @@
%0 = transform.structured.match ops{["tensor.insert_slice"]} in %arg1
%dispatch_op = transform.iree.wrap_in_dispatch_region %0
%1 = transform.structured.match ops{["tensor.extract_slice"]} in %arg1
- transform.iree.clone_preceding_op_into_dispatch_region %1 into %dispatch_op {update_uses_outside_of_region = true}
+ transform.iree.move_preceding_op_into_dispatch_region %1 into %dispatch_op
}
}
@@ -116,20 +116,18 @@
// -----
-// CHECK-LABEL: func @move_multiple_preceding
+// CHECK-LABEL: func @clone_multiple_preceding
// CHECK-DAG: arith.constant
// CHECK-DAG: arith.constant
// CHECK-DAG: tensor.dim
// CHECK-DAG: tensor.dim
-// CHECK-NEXT: "test.dummy_op"
-// CHECK-NEXT: "test.third_user"
-// CHECK-NEXT: flow.dispatch.region
+// CHECK: flow.dispatch.region
// CHECK-NEXT: "test.dummy_op"
// CHECK-NEXT: "test.first_user"
// CHECK-NEXT: "test.second_user"
// CHECK-NEXT: "test.merge1"
// CHECK-NEXT: "test.merge2"
-func.func @move_multiple_preceding(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %s1: index, %s2: index) -> (tensor<?x?xf32>) {
+func.func @clone_multiple_preceding(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %s1: index, %s2: index) -> (tensor<?x?xf32>) {
%0 = "test.dummy_op"(%arg0) {__tagged__} : (tensor<?x?xf32>) -> (tensor<?x?xf32>)
%1 = "test.first_user"(%0) {__tagged__} : (tensor<?x?xf32>) -> (tensor<?x?xf32>)
%2 = "test.second_user"(%0) {__tagged__} : (tensor<?x?xf32>) -> (tensor<?x?xf32>)
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/BUILD b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/BUILD
index ea644e0..59d8737 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/BUILD
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/BUILD
@@ -24,7 +24,6 @@
"ConvertExperimentalOps.cpp",
"ConvertFenceOps.cpp",
"ConvertHALToVM.cpp",
- "ConvertSemaphoreOps.cpp",
],
hdrs = [
"ConvertHALToVM.h",
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/CMakeLists.txt
index efa8b6f..29261e7 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/CMakeLists.txt
@@ -25,7 +25,6 @@
"ConvertExperimentalOps.cpp"
"ConvertFenceOps.cpp"
"ConvertHALToVM.cpp"
- "ConvertSemaphoreOps.cpp"
DEPS
LLVMSupport
MLIRArithmeticDialect
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertFenceOps.cpp b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertFenceOps.cpp
index a674336..e9fe827 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertFenceOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertFenceOps.cpp
@@ -6,57 +6,21 @@
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "iree/compiler/Dialect/VM/Conversion/ImportUtils.h"
-#include "iree/compiler/Dialect/VM/IR/VMOps.h"
#include "mlir/Transforms/DialectConversion.h"
namespace mlir {
namespace iree_compiler {
-namespace {
-
-struct FenceCreateOpConversion
- : public OpConversionPattern<IREE::HAL::FenceCreateOp> {
- FenceCreateOpConversion(MLIRContext *context, SymbolTable &importSymbols,
- TypeConverter &typeConverter, StringRef importName)
- : OpConversionPattern(context) {
- importOp = importSymbols.lookup<IREE::VM::ImportOp>(importName);
- assert(importOp);
- }
- LogicalResult matchAndRewrite(
- IREE::HAL::FenceCreateOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- auto importType = importOp.getFunctionType();
-
- SmallVector<Value, 8> callOperands;
- SmallVector<int16_t, 5> segmentSizes = {
- /*timepoints=*/
- static_cast<int16_t>(adaptor.getSemaphores().size()),
- };
- for (auto it : llvm::zip(adaptor.getSemaphores(), adaptor.getMinValues())) {
- callOperands.push_back(std::get<0>(it));
- callOperands.push_back(std::get<1>(it));
- }
-
- auto callOp = rewriter.replaceOpWithNewOp<IREE::VM::CallVariadicOp>(
- op, SymbolRefAttr::get(importOp), importType.getResults(), segmentSizes,
- importType.getInputs(), callOperands);
- copyImportAttrs(importOp, callOp);
- return success();
- }
-
- mutable IREE::VM::ImportOp importOp;
-};
-
-} // namespace
-
void populateHALFenceToVMPatterns(MLIRContext *context,
SymbolTable &importSymbols,
TypeConverter &typeConverter,
RewritePatternSet &patterns) {
- patterns.insert<FenceCreateOpConversion>(context, importSymbols,
- typeConverter, "hal.fence.create");
+ patterns.insert<VMImportOpConversion<IREE::HAL::FenceCreateOp>>(
+ context, importSymbols, typeConverter, "hal.fence.create");
patterns.insert<VMImportOpConversion<IREE::HAL::FenceJoinOp>>(
context, importSymbols, typeConverter, "hal.fence.join");
+ patterns.insert<VMImportOpConversion<IREE::HAL::FenceQueryOp>>(
+ context, importSymbols, typeConverter, "hal.fence.query");
patterns.insert<VMImportOpConversion<IREE::HAL::FenceSignalOp>>(
context, importSymbols, typeConverter, "hal.fence.signal");
patterns.insert<VMImportOpConversion<IREE::HAL::FenceFailOp>>(
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertHALToVM.cpp b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertHALToVM.cpp
index 2f9f45d..fda5153 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertHALToVM.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertHALToVM.cpp
@@ -60,10 +60,6 @@
SymbolTable &importSymbols,
TypeConverter &typeConverter,
RewritePatternSet &patterns);
-extern void populateHALSemaphoreToVMPatterns(MLIRContext *context,
- SymbolTable &importSymbols,
- TypeConverter &typeConverter,
- RewritePatternSet &patterns);
void populateHALToVMPatterns(MLIRContext *context, SymbolTable &importSymbols,
RewritePatternSet &patterns,
@@ -83,8 +79,6 @@
populateHALExperimentalToVMPatterns(context, importSymbols, typeConverter,
patterns);
populateHALFenceToVMPatterns(context, importSymbols, typeConverter, patterns);
- populateHALSemaphoreToVMPatterns(context, importSymbols, typeConverter,
- patterns);
}
namespace {
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertSemaphoreOps.cpp b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertSemaphoreOps.cpp
deleted file mode 100644
index 1972b68..0000000
--- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertSemaphoreOps.cpp
+++ /dev/null
@@ -1,32 +0,0 @@
-// Copyright 2019 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
-#include "iree/compiler/Dialect/VM/Conversion/ImportUtils.h"
-#include "iree/compiler/Dialect/VM/IR/VMOps.h"
-#include "mlir/Transforms/DialectConversion.h"
-
-namespace mlir {
-namespace iree_compiler {
-
-void populateHALSemaphoreToVMPatterns(MLIRContext *context,
- SymbolTable &importSymbols,
- TypeConverter &typeConverter,
- RewritePatternSet &patterns) {
- patterns.insert<VMImportOpConversion<IREE::HAL::SemaphoreCreateOp>>(
- context, importSymbols, typeConverter, "hal.semaphore.create");
- patterns.insert<VMImportOpConversion<IREE::HAL::SemaphoreQueryOp>>(
- context, importSymbols, typeConverter, "hal.semaphore.query");
- patterns.insert<VMImportOpConversion<IREE::HAL::SemaphoreSignalOp>>(
- context, importSymbols, typeConverter, "hal.semaphore.signal");
- patterns.insert<VMImportOpConversion<IREE::HAL::SemaphoreFailOp>>(
- context, importSymbols, typeConverter, "hal.semaphore.fail");
- patterns.insert<VMImportOpConversion<IREE::HAL::SemaphoreAwaitOp>>(
- context, importSymbols, typeConverter, "hal.semaphore.await");
-}
-
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/fence_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/fence_ops.mlir
index 940087a..6eb570e 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/fence_ops.mlir
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/fence_ops.mlir
@@ -1,17 +1,11 @@
// RUN: iree-opt --split-input-file --iree-convert-hal-to-vm %s | FileCheck %s
// CHECK-LABEL: @fence_create
-// CHECK-SAME: (%[[SEMAPHORE0:.+]]: !vm.ref<!hal.semaphore>, %[[TIME0:.+]]: i64,
-// CHECK-SAME: %[[SEMAPHORE1:.+]]: !vm.ref<!hal.semaphore>, %[[TIME1:.+]]: i64)
-func.func @fence_create(
- %semaphore0: !hal.semaphore, %time0: i64,
- %semaphore1: !hal.semaphore, %time1: i64) -> !hal.fence {
- // CHECK: %[[FENCE:.+]] = vm.call.variadic @hal.fence.create
- // CHECK-SAME: ([(%[[SEMAPHORE0]], %[[TIME0]]), (%[[SEMAPHORE1]], %[[TIME1]])])
- %fence = hal.fence.create
- at<%semaphore0 : !hal.semaphore>(%time0)
- at<%semaphore1 : !hal.semaphore>(%time1)
- -> !hal.fence
+// CHECK-SAME: (%[[DEVICE:.+]]: !vm.ref<!hal.device>)
+func.func @fence_create(%device: !hal.device) -> !hal.fence {
+ // CHECK: %[[FLAGS:.+]] = vm.const.i32.zero
+ // CHECK: %[[FENCE:.+]] = vm.call @hal.fence.create(%[[DEVICE]], %[[FLAGS]])
+ %fence = hal.fence.create device(%device : !hal.device) flags("None") : !hal.fence
// CHECK: vm.return %[[FENCE]]
return %fence : !hal.fence
}
@@ -30,6 +24,17 @@
// -----
+// CHECK-LABEL: @fence_query
+// CHECK-SAME: (%[[FENCE:.+]]: !vm.ref<!hal.fence>)
+func.func @fence_query(%fence: !hal.fence) -> i32 {
+ // CHECK: %[[STATUS:.+]] = vm.call @hal.fence.query(%[[FENCE]])
+ %status = hal.fence.query<%fence : !hal.fence> : i32
+ // CHECK: vm.return %[[STATUS]]
+ return %status : i32
+}
+
+// -----
+
// CHECK-LABEL: @fence_signal
// CHECK-SAME: (%[[FENCE:.+]]: !vm.ref<!hal.fence>)
func.func @fence_signal(%fence: !hal.fence) {
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/ConvertStreamToHAL.cpp b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/ConvertStreamToHAL.cpp
index 595fede..5508286 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/ConvertStreamToHAL.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/ConvertStreamToHAL.cpp
@@ -255,8 +255,9 @@
// Gather wait/signal fence, which are optional.
Value waitFence =
getOrCreateWaitFence(loc, adaptor.getAwaitTimepoint(), rewriter);
- Value signalFence = rewriter.create<IREE::HAL::TimelineAdvanceOp>(
- loc, rewriter.getType<IREE::HAL::FenceType>());
+ Value signalFence = rewriter.create<IREE::HAL::FenceCreateOp>(
+ loc, rewriter.getType<IREE::HAL::FenceType>(), device,
+ IREE::HAL::FenceFlagBitfield::None);
// Queue allocation.
auto queueAffinity = rewriter.create<arith::ConstantIntOp>(loc, -1, 64);
@@ -282,8 +283,9 @@
// Gather wait/signal fence, which are optional.
Value waitFence =
getOrCreateWaitFence(loc, adaptor.getAwaitTimepoint(), rewriter);
- Value signalFence = rewriter.create<IREE::HAL::TimelineAdvanceOp>(
- loc, rewriter.getType<IREE::HAL::FenceType>());
+ Value signalFence = rewriter.create<IREE::HAL::FenceCreateOp>(
+ loc, rewriter.getType<IREE::HAL::FenceType>(), device,
+ IREE::HAL::FenceFlagBitfield::None);
// Queue allocation.
auto queueAffinity = rewriter.create<arith::ConstantIntOp>(loc, -1, 64);
@@ -569,8 +571,10 @@
}
rewriter.replaceOpWithNewOp<IREE::HAL::BufferViewCreateOp>(
- exportOp, adaptor.getSource(), elementType.value(),
- encodingType.value(), dims);
+ exportOp, adaptor.getSource(),
+ rewriter.create<arith::ConstantIndexOp>(loc, 0),
+ adaptor.getSourceSize(), elementType.value(), encodingType.value(),
+ dims);
return success();
}
};
@@ -855,8 +859,9 @@
// Gather wait/signal fence, which are optional.
Value waitFence =
getOrCreateWaitFence(loc, adaptor.getAwaitTimepoint(), rewriter);
- Value signalFence = rewriter.create<IREE::HAL::TimelineAdvanceOp>(
- loc, rewriter.getType<IREE::HAL::FenceType>());
+ Value signalFence = rewriter.create<IREE::HAL::FenceCreateOp>(
+ loc, rewriter.getType<IREE::HAL::FenceType>(), device,
+ IREE::HAL::FenceFlagBitfield::None);
// Queue execution.
auto queueAffinity = rewriter.create<arith::ConstantIntOp>(loc, -1, 64);
@@ -928,17 +933,9 @@
operands[0].getType().isa<IREE::HAL::FenceType>()) {
rewriter.replaceOp(importOp, operands[0]);
return success();
- } else if (operands.size() == 2 &&
- operands[0].getType().isa<IREE::HAL::SemaphoreType>() &&
- operands[1].getType().isIntOrIndex()) {
- rewriter.replaceOpWithNewOp<IREE::HAL::FenceCreateOp>(
- importOp, rewriter.getType<IREE::HAL::FenceType>(),
- ValueRange{operands[0]}, ValueRange{operands[1]});
- return success();
} else {
- return rewriter.notifyMatchFailure(importOp,
- "only imports from HAL semaphore + "
- "sequence value tuples are supported");
+ return rewriter.notifyMatchFailure(
+ importOp, "only imports from HAL fences are supported");
}
}
};
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/cmd_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/cmd_ops.mlir
index f7e2a89..507e76a 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/cmd_ops.mlir
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/cmd_ops.mlir
@@ -90,7 +90,7 @@
}
} => !stream.timepoint
// CHECK-NEXT: hal.command_buffer.finalize<%[[CMD]]
- // CHECK: %[[SIGNAL_FENCE:.+]] = hal.timeline.advance
+ // CHECK: %[[SIGNAL_FENCE:.+]] = hal.fence.create
// CHECK: hal.device.queue.execute
// CHECK-SAME: affinity(%c-1
// CHECK-SAME: wait(%arg4)
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/resource_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/resource_ops.mlir
index 6a1f1ee..5fc5f9d 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/resource_ops.mlir
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/resource_ops.mlir
@@ -21,7 +21,7 @@
// CHECK-SAME: (%[[SIZE:.+]]: index)
func.func @resourceAlloca(%size: index) -> (!stream.resource<staging>, !stream.timepoint) {
// CHECK: %[[WAIT_FENCE:.+]] = util.null : !hal.fence
- // CHECK: %[[SIGNAL_FENCE:.+]] = hal.timeline.advance
+ // CHECK: %[[SIGNAL_FENCE:.+]] = hal.fence.create
// CHECK: %[[RET0:.+]] = hal.device.queue.alloca
// CHECK-SAME: affinity(%c-1
// CHECK-SAME: wait(%[[WAIT_FENCE]])
@@ -40,7 +40,7 @@
// CHECK-LABEL: @resourceAllocaAwait
// CHECK-SAME: (%[[SIZE:.+]]: index, %[[WAIT_FENCE:.+]]: !hal.fence)
func.func @resourceAllocaAwait(%size: index, %await_timepoint: !stream.timepoint) -> (!stream.resource<staging>, !stream.timepoint) {
- // CHECK: %[[SIGNAL_FENCE:.+]] = hal.timeline.advance
+ // CHECK: %[[SIGNAL_FENCE:.+]] = hal.fence.create
// CHECK: %[[RET0:.+]] = hal.device.queue.alloca
// CHECK-SAME: affinity(%c-1
// CHECK-SAME: wait(%[[WAIT_FENCE]])
@@ -60,7 +60,7 @@
// CHECK-SAME: (%[[SIZE:.+]]: index, %[[RESOURCE:.+]]: !hal.buffer)
func.func @resourceDealloca(%size: index, %resource: !stream.resource<staging>) -> !stream.timepoint {
// CHECK: %[[WAIT_FENCE:.+]] = util.null : !hal.fence
- // CHECK: %[[SIGNAL_FENCE:.+]] = hal.timeline.advance
+ // CHECK: %[[SIGNAL_FENCE:.+]] = hal.fence.create
// CHECK: hal.device.queue.dealloca
// CHECK-SAME: affinity(%c-1
// CHECK-SAME: wait(%[[WAIT_FENCE]])
@@ -78,7 +78,7 @@
// CHECK-LABEL: @resourceDeallocaAwait
// CHECK-SAME: (%[[SIZE:.+]]: index, %[[RESOURCE:.+]]: !hal.buffer, %[[WAIT_FENCE:.+]]: !hal.fence)
func.func @resourceDeallocaAwait(%size: index, %resource: !stream.resource<staging>, %await_timepoint: !stream.timepoint) -> !stream.timepoint {
- // CHECK: %[[SIGNAL_FENCE:.+]] = hal.timeline.advance
+ // CHECK: %[[SIGNAL_FENCE:.+]] = hal.fence.create
// CHECK: hal.device.queue.dealloca
// CHECK-SAME: affinity(%c-1
// CHECK-SAME: wait(%[[WAIT_FENCE]])
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/timepoint_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/timepoint_ops.mlir
index 80ae081..d5683f6 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/timepoint_ops.mlir
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/timepoint_ops.mlir
@@ -33,16 +33,6 @@
// -----
-// CHECK-LABEL: @timepointImportSemaphore
-func.func @timepointImportSemaphore(%arg0: !hal.semaphore, %arg1: i64) -> !stream.timepoint {
- // CHECK: %[[FENCE:.+]] = hal.fence.create at<%arg0 : !hal.semaphore>(%arg1) -> !hal.fence
- %0 = stream.timepoint.import %arg0, %arg1 : (!hal.semaphore, i64) => !stream.timepoint
- // CHECK: return %[[FENCE]]
- return %0 : !stream.timepoint
-}
-
-// -----
-
// CHECK-LABEL: @timepointExportFence
func.func @timepointExportFence(%arg0: !stream.timepoint) -> !hal.fence {
%0 = stream.timepoint.export %arg0 => (!hal.fence)
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALBase.td b/compiler/src/iree/compiler/Dialect/HAL/IR/HALBase.td
index 135b1c4..c6c718e 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALBase.td
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALBase.td
@@ -185,6 +185,16 @@
let cppNamespace = "mlir::iree_compiler::IREE::HAL";
}
+def HAL_FenceFlag_None : I32BitEnumAttrCase<"None", 0x0000>;
+def HAL_FenceFlag_Reserved : I32BitEnumAttrCase<"Reserved", 0x0001>;
+def HAL_FenceFlagBitfieldAttr :
+ I32BitEnumAttr<"FenceFlagBitfield", "valid FenceFlag", [
+ HAL_FenceFlag_None,
+ HAL_FenceFlag_Reserved,
+ ]> {
+ let cppNamespace = "mlir::iree_compiler::IREE::HAL";
+}
+
def HAL_AccessScope_None : I32BitEnumAttrCase<"None", 0x0000>;
def HAL_AccessScope_IndirectCommandRead : I32BitEnumAttrCase<"IndirectCommandRead", 0x0001>;
def HAL_AccessScope_ConstantRead : I32BitEnumAttrCase<"ConstantRead", 0x0002>;
@@ -327,31 +337,6 @@
let builderCall = "$_builder.getType<IREE::HAL::FenceType>()";
}
-def HAL_RingBuffer : DialectType<
- HAL_Dialect,
- CPred<"$_self.isa<IREE::HAL::RingBufferType>()">,
- "ring_buffer"> {
- let description = [{
- Ringbuffer used for transient buffer allocation.
- }];
- let builderCall = "$_builder.getType<IREE::HAL::RingBufferType>()";
-}
-
-def HAL_Semaphore : DialectType<
- HAL_Dialect,
- CPred<"$_self.isa<IREE::HAL::SemaphoreType>()">,
- "semaphore"> {
- let description = [{
- Synchronization mechanism for host->device, device->host, host->host,
- and device->device notification. Semaphores behave like Vulkan timeline
- semaphores (or D3D12 fences) and contain a monotonically increasing
- uint64_t payload. They may be waited on any number of times even if they
- have already been signaled for a particular value. They may also be waited
- on for a particular value prior to the signal for that value.
- }];
- let builderCall = "$_builder.getType<IREE::HAL::SemaphoreType>()";
-}
-
def HAL_ObjectType : AnyTypeOf<[
HAL_Allocator,
HAL_Buffer,
@@ -361,9 +346,8 @@
HAL_Device,
HAL_Event,
HAL_Executable,
+ HAL_Fence,
HAL_PipelineLayout,
- HAL_RingBuffer,
- HAL_Semaphore,
]>;
def HAL_BufferType : AnyTypeOf<[
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp
index 5b42a9f..84748a3 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp
@@ -55,6 +55,44 @@
namespace {
+/// Folds hal.buffer.subspans into buffer view creation subspans.
+struct FoldBufferViewCreateSubspan
+ : public OpRewritePattern<BufferViewCreateOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(BufferViewCreateOp op,
+ PatternRewriter &rewriter) const override {
+ auto ip = rewriter.saveInsertionPoint();
+ rewriter.setInsertionPoint(op);
+ bool needsUpdate = false;
+ auto newSourceBuffer = op.getSourceBuffer();
+ auto newSourceOffset = op.getSourceOffset();
+ if (auto subspanOp = dyn_cast_or_null<BufferSubspanOp>(
+ op.getSourceBuffer().getDefiningOp())) {
+ newSourceBuffer = subspanOp.getSourceBuffer();
+ newSourceOffset = rewriter.createOrFold<mlir::arith::AddIOp>(
+ subspanOp.getLoc(), subspanOp.getSourceOffset(),
+ op.getSourceOffset());
+ needsUpdate = true;
+ }
+ rewriter.restoreInsertionPoint(ip);
+ if (!needsUpdate) return failure();
+ rewriter.updateRootInPlace(op, [&]() {
+ op.getSourceBufferMutable().assign(newSourceBuffer);
+ op.getSourceOffsetMutable().assign(newSourceOffset);
+ });
+ return success();
+ }
+};
+
+} // namespace
+
+void BufferViewCreateOp::getCanonicalizationPatterns(RewritePatternSet &results,
+ MLIRContext *context) {
+ results.insert<FoldBufferViewCreateSubspan>(context);
+}
+
+namespace {
+
/// Skips a hal.buffer_view.buffer accessor when the buffer view was created in
/// the same scope and we know the origin buffer.
struct SkipBufferViewBufferOp : public OpRewritePattern<BufferViewBufferOp> {
@@ -64,7 +102,7 @@
PatternRewriter &rewriter) const override {
if (auto createOp = dyn_cast_or_null<BufferViewCreateOp>(
op.getBufferView().getDefiningOp())) {
- rewriter.replaceOp(op, createOp.getBuffer());
+ rewriter.replaceOp(op, createOp.getSourceBuffer());
return success();
}
return failure();
@@ -267,64 +305,16 @@
namespace {
-/// Replaces a fence with no timepoints with a null value.
+/// Erases fence create ops that have no uses.
-struct ElideEmptyFenceCreate : public OpRewritePattern<FenceCreateOp> {
+struct ElideUnusedFenceCreate : public OpRewritePattern<FenceCreateOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(FenceCreateOp op,
PatternRewriter &rewriter) const override {
- if (op.getNumOperands() != 0) return failure();
- rewriter.replaceOpWithNewOp<IREE::Util::NullOp>(op,
- op.getResult().getType());
- return success();
- }
-};
-
-/// Deduplicates timepoints by taking the maximum payload value of any that
-/// share the same semaphore.
-struct DeduplicateFenceCreateTimepoints
- : public OpRewritePattern<FenceCreateOp> {
- using OpRewritePattern::OpRewritePattern;
- LogicalResult matchAndRewrite(FenceCreateOp op,
- PatternRewriter &rewriter) const override {
- // Check to see if the fence is over a single (semaphore, value) timepoint.
- if (op.getSemaphores().size() <= 1) {
- return failure(); // just 0 or 1 timepoint
+ if (op.use_empty()) {
+ rewriter.eraseOp(op);
+ return success();
+ } else {
+ return failure();
}
-
- // Build a map of all timepoints keyed on semaphore.
- // This will implicitly deduplicate the semaphores and the values for each.
- llvm::MapVector<Value, SetVector<Value>> timepoints;
- for (auto it : llvm::zip(op.getSemaphores(), op.getMinValues())) {
- auto semaphore = std::get<0>(it);
- auto minValue = std::get<1>(it);
- timepoints[semaphore].insert(minValue);
- }
-
- // Check for no-op when we don't deduplicate anything.
- if (timepoints.size() == op.getSemaphores().size()) return failure();
-
- // Build the timepoints.
- // A single semaphore may have multiple values and we need to take the max.
- SmallVector<Value> semaphores;
- SmallVector<Value> minValues;
- semaphores.reserve(timepoints.size());
- minValues.reserve(timepoints.size());
- for (auto it : timepoints) {
- semaphores.push_back(it.first);
- if (it.second.size() == 1) {
- // Single timepoint.
- minValues.push_back(it.second.front());
- } else {
- // Join timepoints. This will fold if constant.
- minValues.push_back(rewriter.createOrFold<IREE::Util::RangeMaxOp>(
- op.getLoc(), it.second.takeVector()));
- }
- }
-
- // Build new op. The map/set vectors we used will ensure the relative order
- // of the timepoints matches the original.
- rewriter.replaceOpWithNewOp<FenceCreateOp>(op, op.getResult().getType(),
- semaphores, minValues);
- return success();
}
};
@@ -332,8 +322,7 @@
void FenceCreateOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.insert<ElideEmptyFenceCreate>(context);
- results.insert<DeduplicateFenceCreateTimepoints>(context);
+ results.insert<ElideUnusedFenceCreate>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp
index 1553b9f..21961d0 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp
@@ -119,54 +119,6 @@
}
//===----------------------------------------------------------------------===//
-// custom<TimepointList>($semaphores, $values)
-//===----------------------------------------------------------------------===//
-// at<%semaphore : !hal.semaphore>(%value) ...
-
-static ParseResult parseTimepointList(
- OpAsmParser &parser,
- SmallVectorImpl<OpAsmParser::UnresolvedOperand> &semaphores,
- SmallVectorImpl<Type> &semaphoreTypes,
- SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values) {
- while (succeeded(parser.parseOptionalKeyword("at"))) {
- OpAsmParser::UnresolvedOperand semaphore;
- Type semaphoreType;
- OpAsmParser::UnresolvedOperand value;
- if (failed(parser.parseLess()) || failed(parser.parseOperand(semaphore)) ||
- failed(parser.parseColonType(semaphoreType)) ||
- failed(parser.parseGreater()) || failed(parser.parseLParen()) ||
- failed(parser.parseOperand(value)) || failed(parser.parseRParen())) {
- return failure();
- }
- semaphores.push_back(semaphore);
- semaphoreTypes.push_back(semaphoreType);
- values.push_back(value);
- }
- return success();
-}
-
-static void printTimepointList(OpAsmPrinter &p, Operation *op,
- ValueRange semaphores, TypeRange semaphoreTypes,
- ValueRange values) {
- if (semaphores.empty()) return;
- llvm::interleave(
- llvm::zip(semaphores, semaphoreTypes, values), p,
- [&](std::tuple<Value, Type, Value> it) {
- auto semaphore = std::get<0>(it);
- auto semaphoreType = std::get<1>(it);
- auto value = std::get<2>(it);
- p << "at<";
- p.printOperand(semaphore);
- p << " : ";
- p.printType(semaphoreType);
- p << ">(";
- p.printOperand(value);
- p << ")";
- },
- " ");
-}
-
-//===----------------------------------------------------------------------===//
// hal.ex.shared_device
//===----------------------------------------------------------------------===//
@@ -401,9 +353,10 @@
//===----------------------------------------------------------------------===//
void BufferViewCreateOp::build(OpBuilder &builder, OperationState &state,
- Value buffer, int32_t elementType,
+ Value sourceBuffer, Value sourceOffset,
+ Value sourceLength, int32_t elementType,
int32_t encodingType, ValueRange shape) {
- build(builder, state, buffer,
+ build(builder, state, sourceBuffer, sourceOffset, sourceLength,
builder.createOrFold<arith::ConstantIntOp>(state.location, elementType,
32),
builder.createOrFold<arith::ConstantIntOp>(state.location, encodingType,
@@ -412,9 +365,11 @@
}
void BufferViewCreateOp::build(OpBuilder &builder, OperationState &state,
- Value buffer, Value elementType,
+ Value sourceBuffer, Value sourceOffset,
+ Value sourceLength, Value elementType,
Value encodingType, ValueRange shape) {
- state.addOperands({buffer, elementType, encodingType});
+ state.addOperands(
+ {sourceBuffer, sourceOffset, sourceLength, elementType, encodingType});
state.addOperands(shape);
state.addTypes({BufferViewType::get(builder.getContext())});
}
@@ -1078,25 +1033,6 @@
setNameFn(getStatus(), "status");
}
-//===----------------------------------------------------------------------===//
-// hal.semaphore.*
-//===----------------------------------------------------------------------===//
-
-void SemaphoreCreateOp::getAsmResultNames(
- function_ref<void(Value, StringRef)> setNameFn) {
- setNameFn(getResult(), "semaphore");
-}
-
-void SemaphoreQueryOp::getAsmResultNames(
- function_ref<void(Value, StringRef)> setNameFn) {
- setNameFn(getStatus(), "status");
-}
-
-void SemaphoreAwaitOp::getAsmResultNames(
- function_ref<void(Value, StringRef)> setNameFn) {
- setNameFn(getStatus(), "status");
-}
-
} // namespace HAL
} // namespace IREE
} // namespace iree_compiler
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td
index 9eabdb5..8f1ab3e 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td
@@ -172,34 +172,6 @@
let hasFolder = 1;
}
-// NOTE: this has side-effects as it is mutating the global timeline.
-// Eventually we'll probably want a dedicated hal.timeline type instead.
-def HAL_TimelineAdvanceOp : HAL_Op<"timeline.advance"> {
- let summary = [{advances a program timeline by one step}];
- let description = [{
- Returns a fence indicating when the timeline has been advanced one step.
- This fence can be used to wait until the timeline reaches or exceeds the
- timepoint or used to signal that it has.
-
- This is a pseudo-op that is expanded into a semaphore and target value
- pair during timeline materialization. The op represents when the advancement
- should occur in program order but not what the actual live timepoint would
- be.
- }];
-
- // TODO(benvanik): discriminator when multiple devices or timelines are
- // present. Today we only have a single timeline.
- let arguments = (ins);
- let results = (outs
- HAL_Fence:$fence
- );
-
- let assemblyFormat = [{
- `:` type($fence)
- attr-dict-with-keyword
- }];
-}
-
//===----------------------------------------------------------------------===//
// !hal.allocator / iree_hal_allocator_t
//===----------------------------------------------------------------------===//
@@ -458,7 +430,9 @@
}];
let arguments = (ins
- HAL_BufferType:$buffer,
+ HAL_BufferType:$source_buffer,
+ HAL_DeviceSize:$source_offset,
+ HAL_DeviceSize:$source_length,
HAL_ElementType:$element_type,
HAL_EncodingType:$encoding_type,
HAL_Shape:$shape
@@ -468,7 +442,8 @@
);
let assemblyFormat = [{
- `buffer` `(` $buffer `:` type($buffer) `)`
+ `buffer` `(` $source_buffer `:` type($source_buffer) `)`
+ `` `[` $source_offset `,` $source_length `]`
`shape` `(` `[` $shape `]` `)`
`type` `(` $element_type `)`
`encoding` `(` $encoding_type `)`
@@ -479,18 +454,24 @@
let skipDefaultBuilders = 1;
let builders = [
OpBuilder<(ins
- "Value":$buffer,
+ "Value":$sourceBuffer,
+ "Value":$sourceOffset,
+ "Value":$sourceLength,
"int32_t":$elementType,
"int32_t":$encodingType,
"ValueRange":$shape
)>,
OpBuilder<(ins
- "Value":$buffer,
+ "Value":$sourceBuffer,
+ "Value":$sourceOffset,
+ "Value":$sourceLength,
"Value":$elementType,
"Value":$encodingType,
"ValueRange":$shape
)>,
];
+
+ let hasCanonicalizer = 1;
}
def HAL_BufferViewAssertOp : HAL_Op<"buffer_view.assert", []> {
@@ -2056,24 +2037,28 @@
def HAL_FenceCreateOp : HAL_Op<"fence.create", [
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
- SameVariadicOperandSize,
+ MemoryEffects<[MemAlloc]>,
]> {
- let summary = [{creates a fence from the given timepoints}];
+ let summary = [{creates an unsignaled fence}];
let description = [{
- Returns a fence that defines a point in time across one or more timelines.
+ Returns a fence that defines a point in time. Fences remain unsignaled
+ until explicitly signaled with `hal.fence.signal` or asynchronously
+ signaled by the device when passed as a signal operand to queue
+ submission ops.
}];
let arguments = (ins
- Variadic<HAL_Semaphore>:$semaphores,
- Variadic<HAL_TimelineValue>:$min_values
+ HAL_Device:$device,
+ HAL_FenceFlagBitfieldAttr:$flags
);
let results = (outs
HAL_Fence:$result
);
let assemblyFormat = [{
- custom<TimepointList>($semaphores, type($semaphores), $min_values)
- `->` type($result)
+ `device` `(` $device `:` type($device) `)`
+ `flags` `(` $flags `)`
+ `:` type($result)
attr-dict-with-keyword
}];
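As a usage sketch of the new form (the fence_ops.mlir test below exercises the same assembly; `%device` is assumed to be in scope):

```mlir
%fence = hal.fence.create device(%device : !hal.device) flags("None") : !hal.fence
// Later signaled on the host, or asynchronously by a queue submission:
hal.fence.signal<%fence : !hal.fence>
```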
@@ -2104,6 +2089,28 @@
let hasCanonicalizer = 1;
}
+def HAL_FenceQueryOp : HAL_Op<"fence.query"> {
+ let summary = [{fence query operation}];
+ let description = [{
+ Queries whether the fence has been reached and its status.
+ Returns OK if the fence has been signaled successfully, DEFERRED if it is
+ unsignaled, and otherwise an error indicating the failure.
+ }];
+
+ let arguments = (ins
+ HAL_Fence:$fence
+ );
+ let results = (outs
+ Util_Status:$status
+ );
+
+ let assemblyFormat = [{
+ `<` $fence `:` type($fence) `>`
+ `:` type($status)
+ attr-dict-with-keyword
+ }];
+}
+
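A minimal usage sketch, mirroring the fence_ops.mlir test below; the i32 result is a status code that is OK once signaled and DEFERRED while unsignaled:

```mlir
%status = hal.fence.query<%fence : !hal.fence> : i32
```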
def HAL_FenceSignalOp : HAL_Op<"fence.signal"> {
let summary = [{fence signal operation}];
let description = [{
@@ -2169,125 +2176,4 @@
let hasCanonicalizer = 1;
}
-//===----------------------------------------------------------------------===//
-// !hal.semaphore / iree_hal_semaphore_t
-//===----------------------------------------------------------------------===//
-
-def HAL_SemaphoreCreateOp : HAL_Op<"semaphore.create", [
- DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
- ]> {
- let summary = [{semaphore allocation operation}];
- let description = [{
- Returns a semaphore from the device pool with the given initial value.
- }];
-
- let arguments = (ins
- HAL_Device:$device,
- HAL_TimelineValue:$initial_value
- );
- let results = (outs
- HAL_Semaphore:$result
- );
-
- let assemblyFormat = [{
- `device` `(` $device `:` type($device) `)`
- `initial` `(` $initial_value `)`
- `:` type($result)
- attr-dict-with-keyword
- }];
-}
-
-def HAL_SemaphoreQueryOp : HAL_Op<"semaphore.query", [
- DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
- ]> {
- let summary = [{semaphore payload value query}];
- let description = [{
- Queries the current payload and returns a tuple of `(status, value)`.
- As the payload is monotonically increasing it is guaranteed that
- the value is at least equal to the previous result of a
- `hal.semaphore.signal` call and coherent with any waits for a
- specified value via `hal.semaphore.await`.
- }];
-
- let arguments = (ins
- HAL_Semaphore:$semaphore
- );
- let results = (outs
- Util_Status:$status,
- HAL_TimelineValue:$value
- );
-
- let assemblyFormat = [{
- `<` $semaphore `:` type($semaphore) `>`
- `:` type($status) `,` type($value)
- attr-dict-with-keyword
- }];
-}
-
-def HAL_SemaphoreSignalOp : HAL_Op<"semaphore.signal"> {
- let summary = [{semaphore payload value signal operation}];
- let description = [{
- Signals the semaphore to the given payload value.
- The call is ignored if the current payload value exceeds `new_value`.
- }];
-
- let arguments = (ins
- HAL_Semaphore:$semaphore,
- HAL_TimelineValue:$new_value
- );
-
- let assemblyFormat = [{
- `<` $semaphore `:` type($semaphore) `>`
- `value` `(` $new_value `)`
- attr-dict-with-keyword
- }];
-}
-
-def HAL_SemaphoreFailOp : HAL_Op<"semaphore.fail"> {
- let summary = [{semaphore asynchronous failure operation}];
- let description = [{
- Signals the semaphore with a failure. The `status` will be returned from
- `hal.semaphore.query` and `hal.semaphore.signal` for the lifetime
- of the semaphore.
- }];
-
- let arguments = (ins
- HAL_Semaphore:$semaphore,
- Util_Status:$status
- );
-
- let assemblyFormat = [{
- `<` $semaphore `:` type($semaphore) `>`
- `status` `(` $status `)`
- attr-dict-with-keyword
- }];
-}
-
-def HAL_SemaphoreAwaitOp : HAL_Op<"semaphore.await", [
- Util_YieldPoint,
- DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
- ]> {
- let summary = [{asynchronous semaphore wait operation}];
- let description = [{
- Yields the caller until the semaphore reaches or exceeds the specified
- payload `min_value`. Returns the `status` of the semaphore after the wait,
- with a non-zero value indicating failure.
- }];
-
- let arguments = (ins
- HAL_Semaphore:$semaphore,
- HAL_TimelineValue:$min_value
- );
- let results = (outs
- Util_Status:$status
- );
-
- let assemblyFormat = [{
- `<` $semaphore `:` type($semaphore) `>`
- `until` `(` $min_value `)`
- `:` type($status)
- attr-dict-with-keyword
- }];
-}
-
#endif // IREE_DIALECT_HAL_OPS
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/test/BUILD b/compiler/src/iree/compiler/Dialect/HAL/IR/test/BUILD
index 614c52d..8c00970 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/test/BUILD
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/test/BUILD
@@ -32,7 +32,6 @@
"fence_ops.mlir",
"interface_ops.mlir",
"invalid.mlir",
- "semaphore_ops.mlir",
"tensor_op_folding.mlir",
"tensor_ops.mlir",
],
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/HAL/IR/test/CMakeLists.txt
index 8131976..351933a 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/test/CMakeLists.txt
@@ -30,7 +30,6 @@
"fence_ops.mlir"
"interface_ops.mlir"
"invalid.mlir"
- "semaphore_ops.mlir"
"tensor_op_folding.mlir"
"tensor_ops.mlir"
TOOLS
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/test/buffer_view_folding.mlir b/compiler/src/iree/compiler/Dialect/HAL/IR/test/buffer_view_folding.mlir
index 39659c2..5e12563 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/test/buffer_view_folding.mlir
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/test/buffer_view_folding.mlir
@@ -1,13 +1,37 @@
// RUN: iree-opt --split-input-file --canonicalize -cse %s | iree-opt --allow-unregistered-dialect --split-input-file | FileCheck %s
-// CHECK-LABEL: func.func @skip_buffer_view_buffer
+// CHECK-LABEL: @FoldBufferViewCreateSubspan
+// CHECK-SAME: (%[[BASE_BUFFER:.+]]: !hal.buffer, %[[SUBSPAN_OFFSET:.+]]: index, %[[SUBSPAN_LENGTH:.+]]: index)
+func.func @FoldBufferViewCreateSubspan(%base_buffer: !hal.buffer, %subspan_offset: index, %subspan_length: index) -> !hal.buffer_view {
+ %subspan = hal.buffer.subspan<%base_buffer : !hal.buffer>[%subspan_offset, %subspan_length] : !hal.buffer
+ // CHECK-DAG: %[[VIEW_OFFSET:.+]] = arith.constant 512
+ %view_offset = arith.constant 512 : index
+ // CHECK-DAG: %[[VIEW_LENGTH:.+]] = arith.constant 1024
+ %view_length = arith.constant 1024 : index
+ // CHECK-DAG: %[[FOLDED_OFFSET:.+]] = arith.addi %[[SUBSPAN_OFFSET]], %[[VIEW_OFFSET]]
+ // CHECK: = hal.buffer_view.create
+ // CHECK-SAME: buffer(%[[BASE_BUFFER]] : !hal.buffer)[%[[FOLDED_OFFSET]], %[[VIEW_LENGTH]]]
+ %dim0 = arith.constant 128 : index
+ %type = arith.constant 32 : i32
+ %encoding = arith.constant 1 : i32
+ %view = hal.buffer_view.create buffer(%subspan : !hal.buffer)[%view_offset, %view_length]
+ shape([%dim0])
+ type(%type)
+ encoding(%encoding) : !hal.buffer_view
+ return %view : !hal.buffer_view
+}
+
+// -----
+
+// CHECK-LABEL: func.func @SkipBufferViewBufferOp
// CHECK-SAME: %[[BUFFER:.+]]: !hal.buffer
-func.func @skip_buffer_view_buffer(%buffer : !hal.buffer) -> !hal.buffer {
+func.func @SkipBufferViewBufferOp(%buffer : !hal.buffer) -> !hal.buffer {
+ %c0 = arith.constant 0 : index
%c1 = arith.constant 1 : i32
%c10 = arith.constant 10 : index
%c11 = arith.constant 11 : index
%c32 = arith.constant 32 : i32
- %view = hal.buffer_view.create buffer(%buffer : !hal.buffer)
+ %view = hal.buffer_view.create buffer(%buffer : !hal.buffer)[%c0, %c10]
shape([%c10, %c11])
type(%c32)
encoding(%c1) : !hal.buffer_view
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/test/buffer_view_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/IR/test/buffer_view_ops.mlir
index a4b1f07..63615f4 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/test/buffer_view_ops.mlir
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/test/buffer_view_ops.mlir
@@ -2,17 +2,19 @@
// CHECK-LABEL: @buffer_view_create
func.func @buffer_view_create(%arg0: !hal.buffer, %arg1: index, %arg2: index) -> !hal.buffer_view {
- %c1 = arith.constant 1 : i32
- %c32 = arith.constant 32 : i32
+ %c0 = arith.constant 0 : index
+ %c128 = arith.constant 128 : index
+ %c1_i32 = arith.constant 1 : i32
+ %c32_i32 = arith.constant 32 : i32
// CHECK: %view = hal.buffer_view.create
- // CHECK-SAME: buffer(%arg0 : !hal.buffer)
+ // CHECK-SAME: buffer(%arg0 : !hal.buffer)[%c0, %c128]
// CHECK-SAME: shape([%arg1, %arg2])
// CHECK-SAME: type(%c32_i32)
// CHECK-SAME: encoding(%c1_i32) : !hal.buffer_view
- %view = hal.buffer_view.create buffer(%arg0 : !hal.buffer)
+ %view = hal.buffer_view.create buffer(%arg0 : !hal.buffer)[%c0, %c128]
shape([%arg1, %arg2])
- type(%c32)
- encoding(%c1) : !hal.buffer_view
+ type(%c32_i32)
+ encoding(%c1_i32) : !hal.buffer_view
return %view : !hal.buffer_view
}
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/test/fence_folding.mlir b/compiler/src/iree/compiler/Dialect/HAL/IR/test/fence_folding.mlir
index bc176e8..ec9487e 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/test/fence_folding.mlir
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/test/fence_folding.mlir
@@ -4,60 +4,11 @@
-// This avoids the allocation and lets the null propagate through the rest of
-// the program to simplify submissions.
+// Tests that an unused fence create is elided entirely, avoiding the
+// allocation.
-// CHECK-LABEL: @fence_create_empty
-func.func @fence_create_empty() -> !hal.fence {
- // CHECK: %[[FENCE:.+]] = util.null : !hal.fence
- %fence = hal.fence.create -> !hal.fence
- // CHECK: return %[[FENCE]]
- return %fence : !hal.fence
-}
-
-// -----
-
-// Tests that a fence with multiple timepoints sharing the same semaphore are
-// deduplicated to the max of the timepoints.
-
-// CHECK-LABEL: @fence_create_duplicate_semaphores
-// CHECK-SAME: %[[SEMAPHORE0:.+]]: !hal.semaphore, %[[TIME0:.+]]: i64, %[[SEMAPHORE1:.+]]: !hal.semaphore, %[[TIME1:.+]]: i64, %[[TIME2:.+]]: i64
-func.func @fence_create_duplicate_semaphores(
- %semaphore0: !hal.semaphore, %time0: i64,
- %semaphore1: !hal.semaphore, %time1: i64, %time2: i64) -> !hal.fence {
- // CHECK: %[[TIMEMAX:.+]] = arith.maxui %[[TIME1]], %[[TIME2]] : i64
- // CHECK: %[[FENCE:.+]] = hal.fence.create
- // CHECK-SAME: at<%[[SEMAPHORE0]] : !hal.semaphore>(%[[TIME0]])
- // CHECK-SAME: at<%[[SEMAPHORE1]] : !hal.semaphore>(%[[TIMEMAX]])
- %fence = hal.fence.create
- at<%semaphore0 : !hal.semaphore>(%time0)
- at<%semaphore1 : !hal.semaphore>(%time1)
- at<%semaphore1 : !hal.semaphore>(%time2)
- -> !hal.fence
- // CHECK: return %[[FENCE]]
- return %fence : !hal.fence
-}
-
-// -----
-
-// Tests that timepoints with the same values are deduplicated.
-// This would be handled by util.range.max canonicalizations as above but this
-// avoids emitting additional IR and is effectively free.
-
-// CHECK-LABEL: @fence_create_duplicate_values
-// CHECK-SAME: %[[SEMAPHORE0:.+]]: !hal.semaphore, %[[TIME0:.+]]: i64, %[[SEMAPHORE1:.+]]: !hal.semaphore, %[[TIME1:.+]]: i64
-func.func @fence_create_duplicate_values(
- %semaphore0: !hal.semaphore, %time0: i64,
- %semaphore1: !hal.semaphore, %time1: i64) -> !hal.fence {
- // CHECK: %[[FENCE:.+]] = hal.fence.create
- %fence = hal.fence.create
- // CHECK-SAME: at<%[[SEMAPHORE0]] : !hal.semaphore>(%[[TIME0]])
- at<%semaphore0 : !hal.semaphore>(%time0)
- at<%semaphore0 : !hal.semaphore>(%time0)
- at<%semaphore0 : !hal.semaphore>(%time0)
- // CHECK-SAME: at<%[[SEMAPHORE1]] : !hal.semaphore>(%[[TIME1]])
- at<%semaphore1 : !hal.semaphore>(%time1)
- at<%semaphore1 : !hal.semaphore>(%time1)
- -> !hal.fence
- // CHECK: return %[[FENCE]]
- return %fence : !hal.fence
+// CHECK-LABEL: @fence_create_unused
+func.func @fence_create_unused(%device: !hal.device) {
+ // CHECK-NOT: hal.fence.create
+ %fence = hal.fence.create device(%device : !hal.device) flags("None") : !hal.fence
+ return
}
// -----
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/test/fence_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/IR/test/fence_ops.mlir
index de8cedb..65f7d60 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/test/fence_ops.mlir
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/test/fence_ops.mlir
@@ -1,24 +1,9 @@
// RUN: iree-opt --split-input-file %s | iree-opt --split-input-file | FileCheck %s
-// CHECK-LABEL: @timeline_advance
-func.func @timeline_advance() -> !hal.fence {
- // CHECK: = hal.timeline.advance : !hal.fence
- %fence = hal.timeline.advance : !hal.fence
- return %fence : !hal.fence
-}
-
-// -----
-
// CHECK-LABEL: @fence_create
-func.func @fence_create(%arg0: !hal.semaphore, %arg1: i64, %arg2: i64) -> !hal.fence {
- // CHECK: = hal.fence.create
- // CHECK-SAME: at<%arg0 : !hal.semaphore>(%arg1)
- // CHECK-SAME: at<%arg0 : !hal.semaphore>(%arg2)
- // CHECK-SAME: -> !hal.fence
- %fence = hal.fence.create
- at<%arg0 : !hal.semaphore>(%arg1)
- at<%arg0 : !hal.semaphore>(%arg2)
- -> !hal.fence
+func.func @fence_create(%arg0: !hal.device) -> !hal.fence {
+ // CHECK: = hal.fence.create device(%arg0 : !hal.device) flags("None") : !hal.fence
+ %fence = hal.fence.create device(%arg0 : !hal.device) flags("None") : !hal.fence
return %fence : !hal.fence
}
@@ -33,6 +18,15 @@
// -----
+// CHECK-LABEL: @fence_query
+func.func @fence_query(%arg0: !hal.fence) -> i32 {
+ // CHECK: = hal.fence.query<%arg0 : !hal.fence> : i32
+ %status = hal.fence.query<%arg0 : !hal.fence> : i32
+ return %status : i32
+}
+
+// -----
+
// CHECK-LABEL: @fence_signal
func.func @fence_signal(%arg0: !hal.fence) {
// CHECK: hal.fence.signal<%arg0 : !hal.fence>
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/test/semaphore_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/IR/test/semaphore_ops.mlir
deleted file mode 100644
index 842ddc5..0000000
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/test/semaphore_ops.mlir
+++ /dev/null
@@ -1,52 +0,0 @@
-// RUN: iree-opt --split-input-file %s | iree-opt --split-input-file | FileCheck %s
-
-// CHECK-LABEL: @semaphore_create
-func.func @semaphore_create(%arg0 : !hal.device) -> !hal.semaphore {
- // CHECK: %[[C0:.+]] = arith.constant 0
- %c0 = arith.constant 0 : i64
- // CHECK: %semaphore = hal.semaphore.create device(%arg0 : !hal.device) initial(%[[C0]]) : !hal.semaphore
- %semaphore = hal.semaphore.create device(%arg0 : !hal.device) initial(%c0) : !hal.semaphore
- return %semaphore : !hal.semaphore
-}
-
-// -----
-
-// CHECK-LABEL: @semaphore_query
-func.func @semaphore_query(%arg0 : !hal.semaphore) {
- // CHECK: = hal.semaphore.query<%arg0 : !hal.semaphore> : i32, i64
- %status, %value = hal.semaphore.query<%arg0 : !hal.semaphore> : i32, i64
- return
-}
-
-// -----
-
-// CHECK-LABEL: @semaphore_signal
-func.func @semaphore_signal(%arg0 : !hal.semaphore) {
- // CHECK: %[[C0:.+]] = arith.constant 0
- %c0 = arith.constant 0 : i64
- // CHECK: hal.semaphore.signal<%arg0 : !hal.semaphore> value(%[[C0]])
- hal.semaphore.signal<%arg0 : !hal.semaphore> value(%c0)
- return
-}
-
-// -----
-
-// CHECK-LABEL: @semaphore_fail
-func.func @semaphore_fail(%arg0 : !hal.semaphore) {
- // CHECK: %[[C0:.+]] = arith.constant 0
- %c0 = arith.constant 0 : i32
- // CHECK: hal.semaphore.fail<%arg0 : !hal.semaphore> status(%[[C0]])
- hal.semaphore.fail<%arg0 : !hal.semaphore> status(%c0)
- return
-}
-
-// -----
-
-// CHECK-LABEL: @semaphore_await
-func.func @semaphore_await(%arg0 : !hal.semaphore) {
- // CHECK: %[[C0:.+]] = arith.constant 0
- %c0 = arith.constant 0 : i64
- // CHECK: = hal.semaphore.await<%arg0 : !hal.semaphore> until(%[[C0]]) : i32
- %0 = hal.semaphore.await<%arg0 : !hal.semaphore> until(%c0) : i32
- return
-}
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/MetalSPIRV/MetalSPIRVTarget.cpp b/compiler/src/iree/compiler/Dialect/HAL/Target/MetalSPIRV/MetalSPIRVTarget.cpp
index 5796680..7e2386e 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Target/MetalSPIRV/MetalSPIRVTarget.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Target/MetalSPIRV/MetalSPIRVTarget.cpp
@@ -67,7 +67,7 @@
}
void buildTranslationPassPipeline(OpPassManager &passManager) override {
- buildSPIRVCodegenPassPipeline(passManager);
+ buildSPIRVCodegenPassPipeline(passManager, /*enableFastMath=*/false);
}
LogicalResult serializeExecutable(const SerializationOptions &options,
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp b/compiler/src/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp
index 459ab3d..fffc2d5 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp
@@ -139,7 +139,7 @@
}
void buildTranslationPassPipeline(OpPassManager &passManager) override {
- buildSPIRVCodegenPassPipeline(passManager);
+ buildSPIRVCodegenPassPipeline(passManager, /*enableFastMath=*/false);
}
// TODO(antiagainst): Re-enable SPIR-V linking once the tensorflow integration
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/WebGPU/WebGPUTarget.cpp b/compiler/src/iree/compiler/Dialect/HAL/Target/WebGPU/WebGPUTarget.cpp
index 2343ecf..a402686 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Target/WebGPU/WebGPUTarget.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Target/WebGPU/WebGPUTarget.cpp
@@ -84,7 +84,18 @@
}
void buildTranslationPassPipeline(OpPassManager &passManager) override {
- buildSPIRVCodegenPassPipeline(passManager);
+ // From WGSL spec, "Floating Point Evaluation"
+ // (https://www.w3.org/TR/WGSL/#floating-point-evaluation):
+ // - Implementations may assume that NaNs and infinities are not present at
+ // runtime.
+ // - In such an implementation, when an evaluation would produce an
+ // infinity or a NaN, an undefined value of the target type is produced
+ // instead.
+ // So WebGPU effectively assumes fast math mode. We also don't have reliable
+ // ways to check whether a floating point number is NaN or infinity.
+ // Therefore, just let SPIR-V CodeGen avoid generating guards for NaN and
+ // infinity.
+ buildSPIRVCodegenPassPipeline(passManager, /*enableFastMath=*/true);
// TODO(scotttodd): additional passes for WebGPU/WGSL
// (here or during serialization?)
}
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/BUILD b/compiler/src/iree/compiler/Dialect/HAL/Transforms/BUILD
index 5847ccc..21bd93d 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/BUILD
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/BUILD
@@ -26,7 +26,6 @@
"LinkExecutables.cpp",
"MaterializeInterfaces.cpp",
"MaterializeResourceCaches.cpp",
- "MaterializeTimelines.cpp",
"MemoizeDeviceQueries.cpp",
"Passes.cpp",
"ResolveExportOrdinals.cpp",
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/HAL/Transforms/CMakeLists.txt
index 9231c12..340364b 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/CMakeLists.txt
@@ -27,7 +27,6 @@
"LinkExecutables.cpp"
"MaterializeInterfaces.cpp"
"MaterializeResourceCaches.cpp"
- "MaterializeTimelines.cpp"
"MemoizeDeviceQueries.cpp"
"Passes.cpp"
"ResolveExportOrdinals.cpp"
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/DumpExecutableBenchmarks.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/DumpExecutableBenchmarks.cpp
index cd0dc4b..5c7db3c 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/DumpExecutableBenchmarks.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/DumpExecutableBenchmarks.cpp
@@ -323,8 +323,9 @@
// TODO(benvanik): add fences to ABI so the benchmark tool can pipeline.
Value waitFence = funcBuilder.create<IREE::Util::NullOp>(
loc, funcBuilder.getType<IREE::HAL::FenceType>());
- Value signalFence = funcBuilder.create<IREE::HAL::TimelineAdvanceOp>(
- loc, funcBuilder.getType<IREE::HAL::FenceType>());
+ Value signalFence = funcBuilder.create<IREE::HAL::FenceCreateOp>(
+ loc, funcBuilder.getType<IREE::HAL::FenceType>(), device,
+ IREE::HAL::FenceFlagBitfield::None);
// Queue execution.
auto queueAffinity = funcBuilder.create<arith::ConstantIntOp>(loc, -1, 64);
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeTimelines.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeTimelines.cpp
deleted file mode 100644
index 96edf34..0000000
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeTimelines.cpp
+++ /dev/null
@@ -1,147 +0,0 @@
-// Copyright 2022 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#include <memory>
-#include <utility>
-
-#include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
-#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
-#include "iree/compiler/Dialect/HAL/Transforms/Passes.h"
-#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
-#include "mlir/IR/Attributes.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/BuiltinTypes.h"
-#include "mlir/Pass/Pass.h"
-
-namespace mlir {
-namespace iree_compiler {
-namespace IREE {
-namespace HAL {
-namespace {
-
-//===----------------------------------------------------------------------===//
-// hal.timeline analysis
-//===----------------------------------------------------------------------===//
-
-// This pass is provisional and only works because we have a single device and
-// don't do multi-queue scheduling. When we want to do that we'll need to attach
-// device information to each `hal.timeline.advance` or have it take a device
-// SSA value. We may also want a top-level timeline type we insert before
-// lowering streams to hal - possibly even in the stream dialect as a final
-// stage.
-
-struct Timeline {
- IREE::Util::GlobalOp semaphore;
- IREE::Util::GlobalOp value;
-};
-
-static Timeline defineGlobalTimeline(mlir::ModuleOp moduleOp) {
- auto moduleBuilder = OpBuilder::atBlockBegin(moduleOp.getBody());
-
- // When we support multiple devices and queues we'd want to name the globals
- // based on them and use their canonical location information (maybe all
- // places that touch the timeline).
- Timeline timeline;
- std::string namePrefix = "_timeline";
- auto loc = moduleBuilder.getUnknownLoc();
-
- // Internal timelines start at zero.
- auto initialValueAttr = moduleBuilder.getI64IntegerAttr(0);
-
- timeline.semaphore = moduleBuilder.create<IREE::Util::GlobalOp>(
- loc, namePrefix + "_semaphore", /*isMutable=*/false,
- moduleBuilder.getType<IREE::HAL::SemaphoreType>());
- timeline.semaphore.setPrivate();
- auto initializerOp = moduleBuilder.create<IREE::Util::InitializerOp>(loc);
- auto initializerBuilder =
- OpBuilder::atBlockBegin(initializerOp.addEntryBlock());
- Value device = initializerBuilder.create<IREE::HAL::ExSharedDeviceOp>(loc);
- Value initialValue =
- initializerBuilder.create<arith::ConstantOp>(loc, initialValueAttr);
- auto semaphore = initializerBuilder.create<IREE::HAL::SemaphoreCreateOp>(
- loc, initializerBuilder.getType<IREE::HAL::SemaphoreType>(), device,
- initialValue);
- initializerBuilder.create<IREE::Util::GlobalStoreOp>(loc, semaphore,
- timeline.semaphore);
- initializerBuilder.create<IREE::Util::InitializerReturnOp>(loc);
-
- timeline.value = moduleBuilder.create<IREE::Util::GlobalOp>(
- loc, namePrefix + "_value", /*isMutable=*/true,
- moduleBuilder.getI64Type(), initialValueAttr);
- timeline.value.setPrivate();
-
- return timeline;
-}
-
-static void rewriteTimelineOps(Timeline timeline, mlir::ModuleOp rootOp) {
- for (auto funcOp : rootOp.getOps<FunctionOpInterface>()) {
- funcOp.walk([&](IREE::HAL::TimelineAdvanceOp advanceOp) {
- auto builder = OpBuilder(advanceOp);
- Value semaphore = builder.create<IREE::Util::GlobalLoadOp>(
- advanceOp.getLoc(), timeline.semaphore);
- Value currentValue = builder.create<IREE::Util::GlobalLoadOp>(
- advanceOp.getLoc(), timeline.value);
- Value one =
- builder.create<arith::ConstantIntOp>(advanceOp.getLoc(), 1, 64);
- Value nextValue =
- builder.create<arith::AddIOp>(advanceOp.getLoc(), currentValue, one);
- builder.create<IREE::Util::GlobalStoreOp>(advanceOp.getLoc(), nextValue,
- timeline.value);
- Value fence = builder.create<IREE::HAL::FenceCreateOp>(
- advanceOp.getLoc(), builder.getType<IREE::HAL::FenceType>(),
- ValueRange{semaphore}, ValueRange{nextValue});
- advanceOp.replaceAllUsesWith(fence);
- advanceOp.erase();
- });
- }
-}
-
-//===----------------------------------------------------------------------===//
-// -iree-hal-materialize-timelines
-//===----------------------------------------------------------------------===//
-
-class MaterializeTimelinesPass
- : public PassWrapper<MaterializeTimelinesPass,
- OperationPass<mlir::ModuleOp>> {
- public:
- MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MaterializeTimelinesPass)
-
- MaterializeTimelinesPass() = default;
-
- void getDependentDialects(DialectRegistry ®istry) const override {
- registry.insert<IREE::HAL::HALDialect>();
- registry.insert<arith::ArithmeticDialect>();
- }
-
- StringRef getArgument() const override {
- return "iree-hal-materialize-timelines";
- }
-
- StringRef getDescription() const override {
- return "Materializes timelines for device queues.";
- }
-
- void runOnOperation() override {
- auto moduleOp = getOperation();
- auto timeline = defineGlobalTimeline(moduleOp);
- rewriteTimelineOps(timeline, moduleOp);
- }
-};
-
-} // namespace
-
-std::unique_ptr<OperationPass<ModuleOp>> createMaterializeTimelinesPass() {
- return std::make_unique<MaterializeTimelinesPass>();
-}
-
-static PassRegistration<MaterializeTimelinesPass> pass([] {
- return std::make_unique<MaterializeTimelinesPass>();
-});
-
-} // namespace HAL
-} // namespace IREE
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.cpp
index 025927d..2d6bd79 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.cpp
@@ -149,9 +149,6 @@
// Convert supported input dialects (std, stream, etc) into the HAL dialect.
passManager.addPass(createConvertToHALPass());
- // Materialize timelines for device queues.
- passManager.addPass(createMaterializeTimelinesPass());
-
// If any devices require the legacy synchronous execution behavior then
// make all async operations blocking.
passManager.addPass(createFixupLegacySyncPass());
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.h b/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.h
index 942b54d..abe4ace 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.h
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.h
@@ -54,9 +54,6 @@
// Converts input flow/std/etc dialects to the IREE HAL dialect.
std::unique_ptr<OperationPass<mlir::ModuleOp>> createConvertToHALPass();
-// Materializes timelines for device queues.
-std::unique_ptr<OperationPass<mlir::ModuleOp>> createMaterializeTimelinesPass();
-
//===----------------------------------------------------------------------===//
// Device management
//===----------------------------------------------------------------------===//
@@ -171,7 +168,6 @@
createLinkTargetExecutablesPass("");
createMaterializeInterfacesPass();
createMaterializeResourceCachesPass(targetOptions);
- createMaterializeTimelinesPass();
createMemoizeDeviceQueriesPass();
createResolveExportOrdinalsPass();
createSerializeExecutablesPass();
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/BUILD b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/BUILD
index 3b0dfc9..4f9c17c 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/BUILD
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/BUILD
@@ -26,7 +26,6 @@
"inline_device_switches.mlir",
"materialize_interfaces.mlir",
"materialize_resource_caches.mlir",
- "materialize_timelines.mlir",
"memoize_device_queries.mlir",
"resolve_export_ordinals.mlir",
"verify_target_environment.mlir",
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/CMakeLists.txt
index 6b8e6a2..0d854b8 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/CMakeLists.txt
@@ -24,7 +24,6 @@
"inline_device_switches.mlir"
"materialize_interfaces.mlir"
"materialize_resource_caches.mlir"
- "materialize_timelines.mlir"
"memoize_device_queries.mlir"
"resolve_export_ordinals.mlir"
"verify_target_environment.mlir"
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/convert_to_hal.mlir b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/convert_to_hal.mlir
index a74a2b1..d3c7fb3 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/convert_to_hal.mlir
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/convert_to_hal.mlir
@@ -116,7 +116,7 @@
} => !stream.timepoint
// CHECK: %[[WAIT_FENCE:.+]] = util.null : !hal.fence
- // CHECK: %[[SIGNAL_FENCE:.+]] = hal.timeline.advance
+ // CHECK: %[[SIGNAL_FENCE:.+]] = hal.fence.create
// CHECK: hal.device.queue.execute<%[[DEVICE]]
// CHECK-SAME: wait(%[[WAIT_FENCE]])
// CHECK-SAME: signal(%[[SIGNAL_FENCE]])
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/materialize_timelines.mlir b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/materialize_timelines.mlir
deleted file mode 100644
index b5a03bd..0000000
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/materialize_timelines.mlir
+++ /dev/null
@@ -1,44 +0,0 @@
-// RUN: iree-opt --split-input-file --iree-hal-materialize-timelines %s | FileCheck %s
-
-// CHECK: util.global private @_timeline_semaphore : !hal.semaphore
-// CHECK: util.initializer {
-// CHECK: %[[DEVICE:.+]] = hal.ex.shared_device
-// CHECK: %[[SEMAPHORE:.+]] = hal.semaphore.create
-// CHECK-SAME: device(%[[DEVICE]] : !hal.device)
-// CHECK-SAME: initial(%c0_i64)
-// CHECK-NEXT: util.global.store %[[SEMAPHORE]], @_timeline_semaphore
-// CHECK: }
-
-// CHECK: util.global private mutable @_timeline_value = 0 : i64
-
-// CHECK-LABEL: @fn1
-func.func @fn1() -> !hal.fence {
- // CHECK: %[[SEMAPHORE:.+]] = util.global.load @_timeline_semaphore
- // CHECK: %[[CURRENT_VALUE:.+]] = util.global.load @_timeline_value
- // CHECK: %[[NEXT_VALUE:.+]] = arith.addi %[[CURRENT_VALUE]], %c1
- // CHECK: util.global.store %[[NEXT_VALUE]], @_timeline_value
- // CHECK: %[[FENCE0:.+]] = hal.fence.create at<%[[SEMAPHORE]] : !hal.semaphore>(%[[NEXT_VALUE]])
- %0 = hal.timeline.advance : !hal.fence
- // CHECK: return %[[FENCE0]]
- return %0 : !hal.fence
-}
-
-// CHECK-LABEL: @fn2
-func.func @fn2(%arg0: i1, %arg1: !hal.fence) -> !hal.fence {
- // CHECK: %[[FENCE:.+]] = scf.if
- %0 = scf.if %arg0 -> (!hal.fence) {
- // CHECK: scf.yield %arg1
- scf.yield %arg1 : !hal.fence
- } else {
- // CHECK: %[[SEMAPHORE:.+]] = util.global.load @_timeline_semaphore
- // CHECK: %[[CURRENT_VALUE:.+]] = util.global.load @_timeline_value
- // CHECK: %[[NEXT_VALUE:.+]] = arith.addi %[[CURRENT_VALUE]], %c1
- // CHECK: util.global.store %[[NEXT_VALUE]], @_timeline_value
- // CHECK: %[[NEW_FENCE:.+]] = hal.fence.create at<%[[SEMAPHORE]] : !hal.semaphore>(%[[NEXT_VALUE]])
- %1 = hal.timeline.advance : !hal.fence
- // CHECK: scf.yield %[[NEW_FENCE]]
- scf.yield %1 : !hal.fence
- }
- // CHECK: return %[[FENCE]]
- return %0 : !hal.fence
-}
diff --git a/compiler/src/iree/compiler/Dialect/HAL/hal.imports.mlir b/compiler/src/iree/compiler/Dialect/HAL/hal.imports.mlir
index 0e6afa0..d6c745c 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/hal.imports.mlir
+++ b/compiler/src/iree/compiler/Dialect/HAL/hal.imports.mlir
@@ -108,6 +108,8 @@
// Creates a reference to a buffer with a particular shape and element type.
vm.import @buffer_view.create(
%buffer : !vm.ref<!hal.buffer>,
+ %source_offset : i64,
+ %source_length : i64,
%element_type : i32,
%encoding_type : i32,
%shape : i64 ...
@@ -358,12 +360,11 @@
// iree_hal_fence_t
//===----------------------------------------------------------------------===//
-// Returns a fence that defines a point in time across one or more timelines.
+// Returns an unsignaled fence that defines a point in time.
vm.import @fence.create(
- // <semaphore, min_value>
- %timepoints : tuple<!vm.ref<!hal.semaphore>, i64>...
+ %device : !vm.ref<!hal.device>,
+ %flags : i32
) -> !vm.ref<!hal.fence>
-attributes {nosideeffects}
// Returns a fence that joins the input fences as a wait-all operation.
vm.import @fence.join(
@@ -371,13 +372,20 @@
) -> !vm.ref<!hal.fence>
attributes {nosideeffects}
+// Queries whether the fence has been reached and returns its status.
+// Returns OK if the fence has been signaled successfully, DEFERRED if it is
+// unsignaled, and otherwise an error indicating the failure.
+vm.import @fence.query(
+ %fence : !vm.ref<!hal.fence>
+) -> i32
+
// Signals the fence.
vm.import @fence.signal(
%fence : !vm.ref<!hal.fence>
)
-// Signals the fence with a failure. The |status| will be returned from the
-// `hal.semaphore.query` and `hal.semaphore.signal` of each timepoint semaphore.
+// Signals the fence with a failure. The |status| will be returned from
+// `hal.fence.query` and `hal.fence.await`.
vm.import @fence.fail(
%fence : !vm.ref<!hal.fence>,
%status : i32
@@ -403,50 +411,4 @@
) -> !vm.ref<!hal.pipeline_layout>
attributes {nosideeffects}
-//===----------------------------------------------------------------------===//
-// iree_hal_semaphore_t
-//===----------------------------------------------------------------------===//
-
-// Returns a semaphore from the device pool with the given initial value.
-vm.import @semaphore.create(
- %device : !vm.ref<!hal.device>,
- %initial_value : i64
-) -> !vm.ref<!hal.semaphore>
-attributes {nosideeffects}
-
-// Queries the current payload and returns a tuple of `(status, value)`.
-// As the payload is monotonically increasing it is guaranteed that
-// the value is at least equal to the previous result of a
-// `hal.semaphore.signal` call and coherent with any waits for a
-// specified value via `hal.semaphore.await`.
-vm.import @semaphore.query(
- %semaphore : !vm.ref<!hal.semaphore>
-) -> (i32, i64)
-
-// Signals the semaphore to the given payload value.
-// The call is ignored if the current payload value exceeds |new_value|.
-vm.import @semaphore.signal(
- %semaphore : !vm.ref<!hal.semaphore>,
- %new_value : i64
-)
-
-// Signals the semaphore with a failure. The |status| will be returned from
-// `hal.semaphore.query` and `hal.semaphore.signal` for the lifetime
-// of the semaphore.
-vm.import @semaphore.fail(
- %semaphore : !vm.ref<!hal.semaphore>,
- %status : i32
-)
-
-// Yields the caller until the semaphore reaches or exceeds the specified
-// payload |value|.
-//
-// Returns the status of the semaphore after the wait, with a non-zero value
-// indicating failure.
-vm.import @semaphore.await(
- %semaphore : !vm.ref<!hal.semaphore>,
- %min_value : i64
-) -> i32
-attributes {vm.yield}
-
} // module
diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp
index 0c1f89a..e4aac4a 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp
@@ -1604,6 +1604,29 @@
// stream.async.load
//===----------------------------------------------------------------------===//
+namespace {
+
+// Folds a subsequent bitcast into the load op. The bit widths match, so this
+// just retypes the loaded value and avoids an extra conversion op.
+struct FoldAsyncLoadBitcast : public OpRewritePattern<AsyncLoadOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(AsyncLoadOp loadOp,
+ PatternRewriter &rewriter) const override {
+ auto loadedValue = loadOp.getResult();
+ if (!loadedValue.hasOneUse()) return failure();
+ auto bitcastOp =
+ dyn_cast<arith::BitcastOp>(*loadedValue.getUsers().begin());
+ if (!bitcastOp) return failure();
+ rewriter.updateRootInPlace(
+ loadOp, [&]() { loadedValue.setType(bitcastOp.getType()); });
+ bitcastOp.getResult().replaceAllUsesWith(loadedValue);
+ rewriter.eraseOp(bitcastOp);
+ return success();
+ }
+};
+
+} // namespace
+
void AsyncLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
// TODO(benvanik): splat + load -> splat value.
@@ -1611,16 +1634,40 @@
// TODO(benvanik): slice + ex load -> slice (ranged) + load.
// TODO(benvanik): value->transfer->load -> value->slice->transfer->load?
// TODO(benvanik): combine multiple loads from the same target if contiguous.
+ results.insert<FoldAsyncLoadBitcast>(context);
}
//===----------------------------------------------------------------------===//
// stream.async.store
//===----------------------------------------------------------------------===//
+namespace {
+
+// Folds a preceding bitcast into the store op. The bit widths match, so this
+// just stores the pre-bitcast value and avoids an extra conversion op.
+struct FoldAsyncStoreBitcast : public OpRewritePattern<AsyncStoreOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(AsyncStoreOp storeOp,
+ PatternRewriter &rewriter) const override {
+ auto storedValue = storeOp.getValue();
+ if (auto bitcastOp =
+ dyn_cast_or_null<arith::BitcastOp>(storedValue.getDefiningOp())) {
+ rewriter.updateRootInPlace(storeOp, [&]() {
+ storeOp.getValueMutable().assign(bitcastOp.getOperand());
+ });
+ return success();
+ }
+ return failure();
+ }
+};
+
+} // namespace
+
void AsyncStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
// TODO(benvanik): if value is a constant splat then turn into fill.
// TODO(benvanik): combine multiple stores to the same target if contiguous.
+ results.insert<FoldAsyncStoreBitcast>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.cpp b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.cpp
index 7e88b79..94a073a 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.cpp
@@ -146,6 +146,7 @@
availableResources.insert(result);
}
for (auto operand : op.getOperands()) {
+ if (!operand) continue;
if (!operand.getType().isa<IREE::Stream::ResourceType>()) continue;
if (!availableResources.contains(operand)) {
return op.emitOpError() << "used resource not listed in explicit "
diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamTypes.cpp b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamTypes.cpp
index 782a06c..f422079 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamTypes.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamTypes.cpp
@@ -25,7 +25,7 @@
namespace IREE {
namespace Stream {
-static llvm::cl::opt<Favor> partitioningFavor(
+static llvm::cl::opt<Favor> clPartitioningFavor(
"iree-stream-partitioning-favor",
llvm::cl::desc("Default stream partitioning favor configuration."),
llvm::cl::init(Favor::MaxConcurrency),
@@ -264,7 +264,7 @@
op = op->getParentOp();
}
// No config found; use defaults.
- auto favorAttr = FavorAttr::get(attrId.getContext(), partitioningFavor);
+ auto favorAttr = FavorAttr::get(attrId.getContext(), clPartitioningFavor);
return PartitioningConfigAttr::get(favorAttr);
}
@@ -359,7 +359,11 @@
void StreamDialect::registerAttributes() {
// Register command line flags:
- (void)partitioningFavor;
+ (void)clPartitioningFavor;
+ (void)clResourceMaxAllocationSize;
+ (void)clResourceMinOffsetAlignment;
+ (void)clResourceMaxRange;
+ (void)clResourceIndexBits;
addAttributes<
#define GET_ATTRDEF_LIST
diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/test/async_folding.mlir b/compiler/src/iree/compiler/Dialect/Stream/IR/test/async_folding.mlir
index 7d5440d..dc583b2 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/IR/test/async_folding.mlir
+++ b/compiler/src/iree/compiler/Dialect/Stream/IR/test/async_folding.mlir
@@ -184,6 +184,30 @@
// -----
+// CHECK-LABEL: @FoldAsyncLoadBitcast
+func.func @FoldAsyncLoadBitcast(%arg0: !stream.resource<staging>, %arg1: index) -> f32 {
+ %c0 = arith.constant 0 : index
+ // CHECK: %[[F32:.+]] = stream.async.load %arg0[%c0] : !stream.resource<staging>{%arg1} -> f32
+ %0 = stream.async.load %arg0[%c0] : !stream.resource<staging>{%arg1} -> i32
+ // CHECK-NOT: arith.bitcast
+ %1 = arith.bitcast %0 : i32 to f32
+ // CHECK: return %[[F32]]
+ return %1 : f32
+}
+
+// -----
+
+// CHECK-LABEL: @FoldAsyncStoreBitcast
+func.func @FoldAsyncStoreBitcast(%arg0: !stream.resource<staging>, %arg1: index, %arg2: f32) -> !stream.resource<staging> {
+ %c0 = arith.constant 0 : index
+ %0 = arith.bitcast %arg2 : f32 to i32
+ // CHECK: = stream.async.store %arg2, %arg0[%c0] : f32 -> %arg0 as !stream.resource<staging>{%arg1}
+ %1 = stream.async.store %0, %arg0[%c0] : i32 -> %arg0 as !stream.resource<staging>{%arg1}
+ return %1 : !stream.resource<staging>
+}
+
+// -----
+
// CHECK-LABEL: @ElideImmediateAsyncExecuteWaits
func.func @ElideImmediateAsyncExecuteWaits(%arg0: !stream.resource<*>, %arg1: index) -> (!stream.resource<*>, !stream.timepoint) {
%c1 = arith.constant 1 : index
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/PackConstants.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/PackConstants.cpp
index 99f3b1d..b448021 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/PackConstants.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/PackConstants.cpp
@@ -47,8 +47,11 @@
// Returns the length, in bytes, of the constant value prior to alignment or
// padding.
- uint64_t getRawLength() const {
- if (auto denseAttr = value.dyn_cast<DenseElementsAttr>()) {
+ uint64_t getStorageSize() const {
+ if (auto serializableAttr =
+ value.dyn_cast<IREE::Util::SerializableAttrInterface>()) {
+ return serializableAttr.getStorageSize();
+ } else if (auto denseAttr = value.dyn_cast<DenseElementsAttr>()) {
return denseAttr.getRawData().size();
} else {
assert(false && "invalid constant attr type");
@@ -92,7 +95,7 @@
for (auto slice : slices) {
uint64_t offset = IREE::Util::align(
currentBuffer->totalSize, resourceConfig.getMinBufferOffsetAlignment());
- uint64_t unpaddedLength = slice.getRawLength();
+ uint64_t unpaddedLength = slice.getStorageSize();
uint64_t paddedLength = IREE::Util::align(
unpaddedLength, resourceConfig.getMinBufferRangeAlignment());
if (offset + unpaddedLength > resourceConfig.getMaxAllocationSize()) {
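As a worked example consistent with the pack_constants.mlir test below: the `dense_resource<__elided__> : tensor<3x4xf32>` slice has 12 elements of 4 bytes each, so getStorageSize() returns 48 bytes; padding each slice to the 64-byte range alignment places the three slices at offsets 0, 64, and 128 within the 192-byte composite.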
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/pack_constants.mlir b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/pack_constants.mlir
index ee1efc5..6bad315 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/pack_constants.mlir
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/pack_constants.mlir
@@ -4,50 +4,55 @@
// Subsequent tests focus on individual components.
// Constants get packed into composite attributes.
-// CHECK: #composite_of_128b = #util.composite<128xi8, [
+// CHECK: #composite_of_192b = #util.composite<192xi8, [
// CHECK-NEXT: dense<100> : tensor<1xi32>,
// CHECK-NEXT: dense<0> : vector<60xi8>,
// CHECK-NEXT: dense<[101, 102]> : tensor<2xi32>,
// CHECK-NEXT: dense<0> : vector<56xi8>,
+// CHECK-NEXT: dense_resource<__elided__> : tensor<3x4xf32>,
+// CHECK-NEXT: dense<0> : vector<16xi8>,
// CHECK-NEXT: ]>
// CHECK-LABEL: @resourceConstants
-func.func @resourceConstants() -> (!stream.resource<constant>, !stream.resource<constant>, !stream.timepoint) {
+func.func @resourceConstants() -> (!stream.resource<constant>, !stream.resource<constant>, !stream.resource<constant>, !stream.timepoint) {
%c4 = arith.constant 4 : index
%c8 = arith.constant 8 : index
+ %c48 = arith.constant 48 : index
// Fetch the read-only host data containing the constants.
- // CHECK: %[[RODATA:.+]] = util.buffer.constant {alignment = 64 : index} : !util.buffer = #composite_of_128b
- %0:3 = stream.resource.constants :
+ // CHECK: %[[RODATA:.+]] = util.buffer.constant {alignment = 64 : index} : !util.buffer = #composite_of_192b
+ %0:4 = stream.resource.constants :
!stream.resource<constant>{%c4} = dense<100> : tensor<1xi32>,
- !stream.resource<constant>{%c8} = dense<[101, 102]> : tensor<2xi32>
+ !stream.resource<constant>{%c8} = dense<[101, 102]> : tensor<2xi32>,
+ !stream.resource<constant>{%c48} = dense_resource<__elided__> : tensor<3x4xf32>
=> !stream.timepoint
// Try first to map the memory directly into a usable resource. If this
// succeeds we are done and can avoid allocation/complete immediately.
// CHECK: %[[DID_MAP:.+]], %[[TRY_MAP:.+]] = stream.resource.try_map %[[RODATA]][%c0] :
- // CHECK-SAME: !util.buffer -> i1, !stream.resource<constant>{%c128}
+ // CHECK-SAME: !util.buffer -> i1, !stream.resource<constant>{%c192}
// CHECK: %[[IF:.+]]:2 = scf.if %[[DID_MAP]] -> (!stream.resource<constant>, !stream.timepoint) {
// CHECK-NEXT: %[[IMMEDIATE:.+]] = stream.timepoint.immediate => !stream.timepoint
// CHECK-NEXT: scf.yield %[[TRY_MAP]], %[[IMMEDIATE]]
// CHECK-NEXT: } else {
// If the mapping fails we need to perform an upload via a staging buffer.
- // CHECK: %[[STAGING:.+]] = stream.resource.map %[[RODATA]][%c0] : !util.buffer -> !stream.resource<staging>{%c128}
- // CHECK: %[[ALLOC:.+]] = stream.resource.alloc uninitialized : !stream.resource<constant>{%c128}
+ // CHECK: %[[STAGING:.+]] = stream.resource.map %[[RODATA]][%c0] : !util.buffer -> !stream.resource<staging>{%c192}
+ // CHECK: %[[ALLOC:.+]] = stream.resource.alloc uninitialized : !stream.resource<constant>{%c192}
// CHECK: %[[EXEC_TIMEPOINT:.+]] = stream.cmd.execute
- // CHECK-SAME: with(%[[STAGING]] as %[[STAGING_CAPTURE:.+]]: !stream.resource<staging>{%c128},
- // CHECK-SAME: %[[ALLOC]] as %[[ALLOC_CAPTURE:.+]]: !stream.resource<constant>{%c128}) {
- // CHECK: stream.cmd.copy %[[STAGING_CAPTURE]][%c0], %[[ALLOC_CAPTURE]][%c0], %c128 : !stream.resource<staging>{%c128} -> !stream.resource<constant>{%c128}
+ // CHECK-SAME: with(%[[STAGING]] as %[[STAGING_CAPTURE:.+]]: !stream.resource<staging>{%c192},
+ // CHECK-SAME: %[[ALLOC]] as %[[ALLOC_CAPTURE:.+]]: !stream.resource<constant>{%c192}) {
+ // CHECK: stream.cmd.copy %[[STAGING_CAPTURE]][%c0], %[[ALLOC_CAPTURE]][%c0], %c192 : !stream.resource<staging>{%c192} -> !stream.resource<constant>{%c192}
// CHECK: } => !stream.timepoint
// CHECK: scf.yield %[[ALLOC]], %[[EXEC_TIMEPOINT]]
// Get subviews pointing to the subresources within the packed resource.
- // CHECK: %[[RES0:.+]] = stream.resource.subview %[[IF]]#0[%c0] : !stream.resource<constant>{%c128} -> !stream.resource<constant>{%c4}
- // CHECK: %[[RES1:.+]] = stream.resource.subview %[[IF]]#0[%c64] : !stream.resource<constant>{%c128} -> !stream.resource<constant>{%c8}
+ // CHECK: %[[RES0:.+]] = stream.resource.subview %[[IF]]#0[%c0] : !stream.resource<constant>{%c192} -> !stream.resource<constant>{%c4}
+ // CHECK: %[[RES1:.+]] = stream.resource.subview %[[IF]]#0[%c64] : !stream.resource<constant>{%c192} -> !stream.resource<constant>{%c8}
+ // CHECK: %[[RES2:.+]] = stream.resource.subview %[[IF]]#0[%c128] : !stream.resource<constant>{%c192} -> !stream.resource<constant>{%c48}
- // CHECK: return %[[RES0]], %[[RES1]], %[[IF]]#1
- return %0#0, %0#1, %0#2 : !stream.resource<constant>, !stream.resource<constant>, !stream.timepoint
+ // CHECK: return %[[RES0]], %[[RES1]], %[[RES2]], %[[IF]]#1
+ return %0#0, %0#1, %0#2, %0#3 : !stream.resource<constant>, !stream.resource<constant>, !stream.resource<constant>, !stream.timepoint
}
// -----
diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilAttrs.cpp b/compiler/src/iree/compiler/Dialect/Util/IR/UtilAttrs.cpp
index 38e21d1..d53d317 100644
--- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilAttrs.cpp
+++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilAttrs.cpp
@@ -7,11 +7,14 @@
#include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
#include "llvm/ADT/BitVector.h"
+#include "llvm/Support/CommandLine.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/DialectResourceBlobManager.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/TypeSupport.h"
@@ -28,6 +31,11 @@
namespace IREE {
namespace Util {
+static llvm::cl::opt<bool> clZeroFillElidedAttrs(
+ "iree-util-zero-fill-elided-attrs",
+ llvm::cl::desc("Fills elided attributes with zeros when serializing."),
+ llvm::cl::init(false));
+
//===----------------------------------------------------------------------===//
// ostream utilities
//===----------------------------------------------------------------------===//
@@ -470,8 +478,8 @@
SerializableDenseElementsAttrModel, DenseIntOrFPElementsAttr> {
int64_t getStorageSize(Attribute baseAttr) const {
auto attr = baseAttr.cast<ElementsAttr>();
- int32_t bitwidth = attr.getType().getElementTypeBitWidth();
- return attr.getNumElements() * (bitwidth / 8);
+ return attr.getNumElements() * IREE::Util::getRoundedElementByteWidth(
+ attr.getType().getElementType());
}
LogicalResult serializeToVector(Attribute baseAttr,
@@ -514,6 +522,59 @@
}
};
+// External interface applied to DenseResourceElementsAttrs so that we can
+// serialize them to byte buffers.
+struct SerializableDenseResourceElementsAttrModel
+ : public SerializableAttrInterface::ExternalModel<
+ SerializableDenseResourceElementsAttrModel,
+ DenseResourceElementsAttr> {
+ int64_t getStorageSize(Attribute baseAttr) const {
+ auto attr = baseAttr.cast<DenseResourceElementsAttr>();
+ return attr.getNumElements() * IREE::Util::getRoundedElementByteWidth(
+ attr.getType().getElementType());
+ }
+
+ LogicalResult serializeToVector(Attribute baseAttr,
+ llvm::support::endianness endian,
+ SmallVectorImpl<char> &buffer) const {
+ buffer.resize(getStorageSize(baseAttr));
+ return serializeToBuffer(baseAttr, endian, buffer);
+ }
+
+ LogicalResult serializeToBuffer(Attribute baseAttr,
+ llvm::support::endianness endian,
+ ArrayRef<char> buffer) const {
+ raw_inplace_ostream os(buffer);
+ return serializeToStream(baseAttr, endian, os);
+ }
+
+ LogicalResult serializeToStream(Attribute baseAttr,
+ llvm::support::endianness endian,
+ llvm::raw_ostream &os) const {
+ auto attr = baseAttr.cast<DenseResourceElementsAttr>();
+ auto handle = attr.getRawHandle();
+
+ // Special testing path for elided attributes. We want this to be an
+ // error in normal circumstances, as the output will produce garbage
+ // results if executed, but it can be useful when building reproducers.
+ if (handle.getKey() == "__elided__") {
+ if (!clZeroFillElidedAttrs) {
+ return mlir::emitError(UnknownLoc::get(baseAttr.getContext()))
+ << "elided attributes cannot be serialized; provide non-elided "
+ "values or pass --iree-util-zero-fill-elided-attrs for "
+ "testing and expect invalid execution results";
+ }
+ os.write_zeros(attr.getNumElements() *
+ IREE::Util::getRoundedElementByteWidth(
+ attr.getType().getElementType()));
+ return success();
+ }
+
+ return mlir::emitError(UnknownLoc::get(baseAttr.getContext()))
+ << "DenseResourceElementsAttr not yet supported for serialization";
+ }
+};
+
// External interface applied to string attrs so that we can serialize them to
// byte buffers. We don't include NUL terminators as it's 2022.
struct SerializableStringAttrModel
@@ -558,6 +619,9 @@
#include "iree/compiler/Dialect/Util/IR/UtilAttrInterfaces.cpp.inc"
void UtilDialect::registerAttributes() {
+ // Register command line flags:
+ (void)clZeroFillElidedAttrs;
+
addAttributes<
#define GET_ATTRDEF_LIST
#include "iree/compiler/Dialect/Util/IR/UtilAttrs.cpp.inc" // IWYU pragma: keep
@@ -567,11 +631,14 @@
// serialization mechanism and may be something we want to handle much higher
// up in the stack - things that end up here are generally already in a target
// encoding.
+ auto &context = *getContext();
DenseIntElementsAttr::attachInterface<SerializableDenseElementsAttrModel>(
- *getContext());
+ context);
DenseFPElementsAttr::attachInterface<SerializableDenseElementsAttrModel>(
- *getContext());
- StringAttr::attachInterface<SerializableStringAttrModel>(*getContext());
+ context);
+ DenseResourceElementsAttr::attachInterface<
+ SerializableDenseResourceElementsAttrModel>(context);
+ StringAttr::attachInterface<SerializableStringAttrModel>(context);
}
} // namespace Util
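
The storage-size change above is what makes sub-byte element types survive serialization: bitwidth / 8 truncates to zero bytes per element for types like i1 or i4, while a rounded byte width reserves at least one. A standalone sketch of the rounding, assuming getRoundedElementByteWidth rounds bit widths up to whole bytes as its name suggests:

#include <cstdint>
#include <cstdio>

// Hypothetical stand-in for IREE::Util::getRoundedElementByteWidth: round a
// bit width up to whole bytes so sub-byte types get nonzero storage.
static int64_t roundedElementByteWidth(int64_t bitWidth) {
  return (bitWidth + 7) / 8;
}

int main() {
  std::printf("%d\n", 1 / 8);  // old math: i1 maps to 0 bytes per element
  std::printf("%lld\n", (long long)roundedElementByteWidth(1));   // 1
  std::printf("%lld\n", (long long)roundedElementByteWidth(4));   // 1
  std::printf("%lld\n", (long long)roundedElementByteWidth(32));  // 4
  return 0;
}
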
diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilBase.td b/compiler/src/iree/compiler/Dialect/Util/IR/UtilBase.td
index 5efe0df..5c61a05 100644
--- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilBase.td
+++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilBase.td
@@ -71,6 +71,7 @@
def Util_AnySerializableAttr : Attr<Or<[
CPred<"$_self.isa<mlir::DenseElementsAttr>()">,
+ CPred<"$_self.isa<mlir::DenseResourceElementsAttr>()">,
CPred<"$_self.isa<IREE::Util::SerializableAttrInterface>()">,
]>, "buffer-like constant attribute values"> {
let storageType = [{ ::mlir::Attribute }];
diff --git a/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/ArchiveWriter.cpp b/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/ArchiveWriter.cpp
index edf5f9f..79c1368 100644
--- a/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/ArchiveWriter.cpp
+++ b/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/ArchiveWriter.cpp
@@ -357,7 +357,7 @@
// appendZIPFile implementation used when |os| is a stream without random
// access (like stdout). This requires us to serialize the file twice in order
// to compute the total length and CRC32.
-static ZIPFileRef appendZIPFileToStream(
+static Optional<ZIPFileRef> appendZIPFileToStream(
std::string fileName, uint64_t filePadding, uint64_t fileLength,
std::function<LogicalResult(llvm::raw_ostream &os)> write,
llvm::raw_ostream &os) {
@@ -376,7 +376,7 @@
uint32_t crc32 = 0;
null_crc32_ostream crcStream(crc32);
if (failed(write(crcStream))) {
- return {};
+ return None;
}
// Write the ZIP header and padding up to the start of the file.
@@ -386,7 +386,7 @@
// Stream out the file contents to the output stream.
uint64_t start = os.tell();
if (failed(write(os))) {
- return {};
+ return None;
}
fileRef.totalLength = os.tell() - start;
assert(fileRef.totalLength == fileLength && "declared length mismatch");
@@ -416,7 +416,7 @@
// appendZIPFile implementation used when |os| is a file with random access.
// This allows us to write the header and backpatch the CRC computed while
// serializing the file contents.
-static ZIPFileRef appendZIPFileToFD(
+static Optional<ZIPFileRef> appendZIPFileToFD(
std::string fileName, uint64_t filePadding, uint64_t fileLength,
std::function<LogicalResult(llvm::raw_ostream &os)> write,
llvm::raw_fd_ostream &os) {
@@ -431,7 +431,7 @@
{
crc32_ostream crcStream(os, fileRef.crc32);
if (failed(write(crcStream))) {
- return {};
+ return None;
}
crcStream.flush();
}
@@ -450,7 +450,7 @@
// Appends a file wrapped in a ZIP header and data descriptor.
// |write| is used to stream the file contents to |os| while also capturing its
// CRC as required for the central directory.
-static ZIPFileRef appendZIPFile(
+static Optional<ZIPFileRef> appendZIPFile(
std::string fileName, uint64_t filePadding, uint64_t fileLength,
std::function<LogicalResult(llvm::raw_ostream &os)> write,
llvm::raw_ostream &os) {
@@ -627,7 +627,7 @@
sizeof(flatbuffers_uoffset_t));
// Stream out the FlatBuffer contents.
- fileRefs.push_back(appendZIPFile(
+ auto zipFile = appendZIPFile(
moduleName, modulePadding, paddedModuleLength,
[&](llvm::raw_ostream &os) -> LogicalResult {
os.write(reinterpret_cast<char *>(&paddedModuleLength),
@@ -637,7 +637,11 @@
os.write_zeros(paddedModuleLength - moduleData.size());
return success();
},
- os));
+ os);
+ if (!zipFile.has_value()) {
+ return mlir::emitError(loc) << "failed to serialize flatbuffer module";
+ }
+ fileRefs.push_back(*zipFile);
// Pad out to the start of the external rodata segment.
// This ensures we begin writing at an aligned offset; all relative offsets
@@ -655,7 +659,7 @@
baseOffset + file.relativeOffset + file.prefixLength - os.tell());
// Write file header and payload.
- fileRefs.push_back(appendZIPFile(
+ auto zipFile = appendZIPFile(
file.fileName, filePadding, file.fileLength,
[this, file](llvm::raw_ostream &os) -> LogicalResult {
if (failed(file.write(os))) {
@@ -666,7 +670,9 @@
}
return success();
},
- os));
+ os);
+ if (!zipFile.has_value()) return failure();
+ fileRefs.push_back(*zipFile);
}
// Append the central directory containing an index of all the files.
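
The Optional switch above matters because "return {};" on a plain ZIPFileRef produced a zero-filled record that callers pushed into fileRefs regardless; an empty Optional makes the failure visible at each call site. A minimal sketch of the pattern, with std::optional standing in for llvm::Optional and hypothetical names:

#include <cstdint>
#include <iostream>
#include <optional>

struct FileRef {
  uint32_t crc32 = 0;
};

// An empty optional distinguishes "write failed" from a legitimately
// zero-valued record, which a default-constructed FileRef could not.
static std::optional<FileRef> appendFile(bool writeSucceeded) {
  if (!writeSucceeded) return std::nullopt;
  return FileRef{0x1234abcd};
}

int main() {
  auto ref = appendFile(false);
  if (!ref.has_value()) {
    std::cerr << "failed to serialize file\n";
    return 1;
  }
  std::cout << ref->crc32 << "\n";
  return 0;
}
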
diff --git a/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeEncoder.h b/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeEncoder.h
index 85cf7f0..f58100e 100644
--- a/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeEncoder.h
+++ b/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeEncoder.h
@@ -32,7 +32,7 @@
class BytecodeEncoder : public VMFuncEncoder {
public:
// Matches IREE_VM_BYTECODE_VERSION_MAJOR.
- static constexpr uint32_t kVersionMajor = 10;
+ static constexpr uint32_t kVersionMajor = 12;
// Matches IREE_VM_BYTECODE_VERSION_MINOR.
static constexpr uint32_t kVersionMinor = 0;
static constexpr uint32_t kVersion = (kVersionMajor << 16) | kVersionMinor;
diff --git a/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeModuleTarget.cpp b/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeModuleTarget.cpp
index 6ffdf53..0ba032e 100644
--- a/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeModuleTarget.cpp
+++ b/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeModuleTarget.cpp
@@ -120,6 +120,8 @@
llvm::support::endianness::little,
ArrayRef<char>(reinterpret_cast<char *>(bytePtr),
static_cast<size_t>(totalSize))))) {
+ mlir::emitError(loc) << "constant attribute failed to serialize: "
+ "unsupported format or encoding";
return {};
}
@@ -420,23 +422,25 @@
// layout planning by preserving the order in the IR is useful.
SmallVector<iree_vm_RodataSegmentDef_ref_t, 8> rodataSegmentRefs;
for (auto &rodataRef : llvm::reverse(rodataRefs)) {
- flatbuffers_uint8_vec_ref_t embedded_ref = 0;
- if (!rodataRef.archiveFile.has_value()) {
- embedded_ref = serializeEmbeddedData(
- rodataRef.rodataOp.getLoc(), rodataRef.rodataOp.getValue(),
- rodataRef.alignment, rodataRef.totalSize, fbb);
- }
- iree_vm_RodataSegmentDef_start(fbb);
if (rodataRef.archiveFile.has_value()) {
+ // Data is already in the file at a calculated offset.
+ iree_vm_RodataSegmentDef_start(fbb);
iree_vm_RodataSegmentDef_external_data_offset_add(
fbb, rodataRef.archiveFile->relativeOffset +
rodataRef.archiveFile->prefixLength);
iree_vm_RodataSegmentDef_external_data_length_add(
fbb, rodataRef.archiveFile->fileLength);
+ rodataSegmentRefs.push_back(iree_vm_RodataSegmentDef_end(fbb));
} else {
- iree_vm_RodataSegmentDef_embedded_data_add(fbb, embedded_ref);
+ // Serialize the embedded data first so that we can reference it.
+ flatbuffers_uint8_vec_ref_t embeddedRef = serializeEmbeddedData(
+ rodataRef.rodataOp.getLoc(), rodataRef.rodataOp.getValue(),
+ rodataRef.alignment, rodataRef.totalSize, fbb);
+ if (!embeddedRef) return failure();
+ iree_vm_RodataSegmentDef_start(fbb);
+ iree_vm_RodataSegmentDef_embedded_data_add(fbb, embeddedRef);
+ rodataSegmentRefs.push_back(iree_vm_RodataSegmentDef_end(fbb));
}
- rodataSegmentRefs.push_back(iree_vm_RodataSegmentDef_end(fbb));
}
std::reverse(rodataSegmentRefs.begin(), rodataSegmentRefs.end());
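
The restructuring above is load-bearing rather than cosmetic: FlatBuffers builders forbid creating nested objects while a table is open, so the embedded data vector has to be fully serialized before iree_vm_RodataSegmentDef_start, and a null ref can now surface as failure() instead of being silently embedded. A minimal illustration of the same ordering rule via the C++ flatbuffers builder (the patch itself uses the flatcc C API, which shares the constraint; the field index is illustrative):

#include "flatbuffers/flatbuffers.h"

#include <cstdint>
#include <vector>

int main() {
  flatbuffers::FlatBufferBuilder fbb;
  std::vector<uint8_t> bytes = {1, 2, 3};
  // Create the child vector first; calling CreateVector while a table is
  // open asserts inside the builder (nested serialization is not allowed).
  auto vec = fbb.CreateVector(bytes);
  auto start = fbb.StartTable();
  fbb.AddOffset(flatbuffers::FieldIndexToOffset(0), vec);
  fbb.EndTable(start);
  return 0;
}
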
diff --git a/compiler/src/iree/compiler/Dialect/VMVX/Transforms/ResolveBufferDescriptors.cpp b/compiler/src/iree/compiler/Dialect/VMVX/Transforms/ResolveBufferDescriptors.cpp
index b7c15f4..1483925 100644
--- a/compiler/src/iree/compiler/Dialect/VMVX/Transforms/ResolveBufferDescriptors.cpp
+++ b/compiler/src/iree/compiler/Dialect/VMVX/Transforms/ResolveBufferDescriptors.cpp
@@ -33,6 +33,7 @@
auto sourceType = source.getType().cast<MemRefType>();
int sourceRank = sourceType.getRank();
int subRank = subType.getRank();
+ (void)subRank;
// Create a descriptor for the source.
IndexType indexType = rewriter.getIndexType();
@@ -44,20 +45,28 @@
loc, op.getBaseBuffer().getType(), indexType, sizeStrideTypes,
sizeStrideTypes, subview.getSource());
- // For sizes, we just use the new ones, discarding the source.
+ // For sizes, we just use the new ones.
+ llvm::SmallBitVector droppedDims = subview.getDroppedDims();
+ unsigned insertedDims = 0;
SmallVector<Value> newSizes;
- for (int i = 0; i < subRank; ++i) {
+ for (int i = 0; i < sourceRank; ++i) {
+ // Skip the sizes that don't show up in the final type.
+ if (droppedDims.test(i)) continue;
+
if (subview.isDynamicSize(i)) {
newSizes.push_back(subview.getDynamicSize(i));
} else {
newSizes.push_back(indexSet.get(subview.getStaticSize(i)));
}
- op.getSizes()[i].replaceAllUsesWith(newSizes.back());
+ op.getSizes()[insertedDims++].replaceAllUsesWith(newSizes.back());
}
+ assert(insertedDims == subRank &&
+ "Should have populated all the non-reduced sizes");
// Apply stride multipliers.
SmallVector<Value> strides;
- for (int i = 0; i < subRank; ++i) {
+ insertedDims = 0;
+ for (int i = 0; i < sourceRank; ++i) {
Value currentStride;
if (subview.isDynamicStride(i)) {
currentStride = subview.getDynamicStride(i);
@@ -67,12 +76,20 @@
currentStride = rewriter.createOrFold<arith::MulIOp>(
loc, sourceDesc.getStrides()[i], currentStride);
strides.push_back(currentStride);
- op.getStrides()[i].replaceAllUsesWith(currentStride);
+
+ // Don't replace values for dropped dimensions: although the new stride
+ // still feeds into the final offset computation, there is no result
+ // value to replace for them.
+ if (droppedDims.test(i)) continue;
+
+ op.getStrides()[insertedDims++].replaceAllUsesWith(currentStride);
}
+ assert(insertedDims == subRank &&
+ "Should have populated all the non-reduced strides");
// Offsets.
Value offset = sourceDesc.getOffset();
- for (int i = 0; i < subRank; ++i) {
+ for (int i = 0; i < sourceRank; ++i) {
Value logicalOffset;
if (subview.isDynamicOffset(i)) {
logicalOffset = subview.getDynamicOffset(i);
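
The key change above is iterating over the source rank and filtering with droppedDims, instead of iterating over the smaller subview rank and silently misaligning indices between the source and result descriptors. A standalone sketch of the filtering on the test case that follows (a memref<8x16x4> subview that drops dim 0):

#include <bitset>
#include <cassert>
#include <cstdio>
#include <vector>

int main() {
  const int sourceRank = 3;
  std::bitset<8> droppedDims;  // stands in for llvm::SmallBitVector
  droppedDims.set(0);          // rank-reducing subview drops dim 0
  std::vector<int> subviewSizes = {1, 6, 3};
  std::vector<int> resultSizes;
  for (int i = 0; i < sourceRank; ++i) {
    if (droppedDims.test(i)) continue;  // size absent from the result type
    resultSizes.push_back(subviewSizes[i]);
  }
  assert(resultSizes.size() == 2);  // the reduced rank
  std::printf("%d %d\n", resultSizes[0], resultSizes[1]);  // prints: 6 3
  return 0;
}
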
diff --git a/compiler/src/iree/compiler/Dialect/VMVX/Transforms/test/resolve_buffer_descriptors.mlir b/compiler/src/iree/compiler/Dialect/VMVX/Transforms/test/resolve_buffer_descriptors.mlir
index d2d38aa..d4fda74 100644
--- a/compiler/src/iree/compiler/Dialect/VMVX/Transforms/test/resolve_buffer_descriptors.mlir
+++ b/compiler/src/iree/compiler/Dialect/VMVX/Transforms/test/resolve_buffer_descriptors.mlir
@@ -28,7 +28,9 @@
// CHECK-DAG: %[[C64:.*]] = arith.constant 64 : index
// CHECK-DAG: %[[I0:.*]] = arith.muli %arg1, %[[BASE_STRIDES]]#0 : index
// CHECK: %[[I1:.*]] = arith.addi %[[BASE_OFFSET]], %[[I0]] : index
- // CHECK: return %[[BASE_BUFFER]], %[[I1]], %[[C64]], %[[BASE_STRIDES]]#0
+ // CHECK: %[[I2:.*]] = arith.muli %arg2, %[[BASE_STRIDES]]#1 : index
+ // CHECK: %[[I3:.*]] = arith.addi %[[I1]], %[[I2]] : index
+ // CHECK: return %[[BASE_BUFFER]], %[[I3]], %[[C64]], %[[BASE_STRIDES]]#0
%0 = memref.subview %arg0[%arg1, %arg2] [64, 1] [1, 1] : memref<384x128xf32> to memref<64xf32, #map0>
%base_buffer, %offset, %size, %stride = vmvx.get_buffer_descriptor %0 : memref<64xf32, #map0> -> !util.buffer, index, index, index
return %base_buffer, %offset, %size, %stride : !util.buffer, index, index, index
@@ -36,6 +38,37 @@
// -----
+// Check that we properly resolve a rank-reducing subview when the dropped
+// dimension is not the last one.
+// Orig strides: [%strides#0, %strides#1, %strides#2]
+// Sub strides: [1, 1, 1]
+// => New strides: [%strides#0, %strides#1, %strides#2]
+// Final strides == filterOutReducedDim(new strides, 0) == [%strides#1, %strides#2]
+//
+// Orig offset: %offset
+// Sub offsets: [%arg1, %arg2, 0]
+// => Final offset: %arg1 * %strides#0 + %arg2 * %strides#1 + 0 * %strides#2 + %offset
+//
+// Final sizes == filterOutReducedDim(subview sizes, 0) == [6, 3]
+//
+// CHECK-LABEL: @resolve_subview_rankreducing_not_at_the_end
+func.func @resolve_subview_rankreducing_not_at_the_end(%arg0: memref<8x16x4xf32>, %arg1 : index, %arg2 : index) -> (!util.buffer, index, index, index, index, index) {
+ // CHECK-DAG: %[[BASE_BUFFER:.*]], %[[BASE_OFFSET:.*]], %[[BASE_SIZES:.*]]:3, %[[BASE_STRIDES:.*]]:3 = vmvx.get_buffer_descriptor %arg0
+ // CHECK-DAG: %[[C6:.*]] = arith.constant 6 : index
+ // CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
+ // CHECK-DAG: %[[I0:.*]] = arith.muli %arg1, %[[BASE_STRIDES]]#0 : index
+ // CHECK: %[[I1:.*]] = arith.addi %[[BASE_OFFSET]], %[[I0]] : index
+ // CHECK: %[[I2:.*]] = arith.muli %arg2, %[[BASE_STRIDES]]#1 : index
+ // CHECK: %[[I3:.*]] = arith.addi %[[I1]], %[[I2]] : index
+ // CHECK: return %[[BASE_BUFFER]], %[[I3]], %[[C6]], %[[C3]], %[[BASE_STRIDES]]#1, %[[BASE_STRIDES]]#2
+
+ %0 = memref.subview %arg0[%arg1, %arg2, 0] [1, 6, 3] [1, 1, 1] : memref<8x16x4xf32> to memref<6x3xf32, strided<[4,1], offset : ?>>
+ %base_buffer, %offset, %sizes:2, %strides:2 = vmvx.get_buffer_descriptor %0 : memref<6x3xf32, strided<[4,1], offset : ?>> -> !util.buffer, index, index, index, index, index
+ return %base_buffer, %offset, %sizes#0, %sizes#1, %strides#0, %strides#1 : !util.buffer, index, index, index, index, index
+}
+
+// -----
+
// CHECK-LABEL: @resolve_binding_subspan_zero_offset
func.func @resolve_binding_subspan_zero_offset() -> (!util.buffer, index, index, index, index, index) {
// CHECK-DAG: %[[C512:.*]] = arith.constant 512 : index
diff --git a/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/HALToHALInline/ConvertHALToHALInline.cpp b/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/HALToHALInline/ConvertHALToHALInline.cpp
index a2a8911..4aa7954 100644
--- a/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/HALToHALInline/ConvertHALToHALInline.cpp
+++ b/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/HALToHALInline/ConvertHALToHALInline.cpp
@@ -91,7 +91,8 @@
IREE::HAL::BufferViewCreateOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<IREE::HAL::Inline::BufferViewCreateOp>(
- op, adaptor.getBuffer(), adaptor.getElementType(),
+ op, adaptor.getSourceBuffer(), adaptor.getSourceOffset(),
+ adaptor.getSourceLength(), adaptor.getElementType(),
adaptor.getEncodingType(), adaptor.getShape());
return success();
}
diff --git a/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/HALToHALInline/test/buffer_view_ops.mlir b/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/HALToHALInline/test/buffer_view_ops.mlir
index c65b1c0..ca71d84 100644
--- a/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/HALToHALInline/test/buffer_view_ops.mlir
+++ b/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/HALToHALInline/test/buffer_view_ops.mlir
@@ -5,11 +5,11 @@
%c1 = arith.constant 1 : i32
%c32 = arith.constant 32 : i32
// CHECK: %view = hal_inline.buffer_view.create
- // CHECK-SAME: buffer(%arg0 : !hal.buffer)
+ // CHECK-SAME: buffer(%arg0 : !hal.buffer)[%arg1, %arg2]
// CHECK-SAME: shape([%arg1, %arg2])
// CHECK-SAME: type(%c32_i32)
// CHECK-SAME: encoding(%c1_i32) : !hal.buffer_view
- %view = hal.buffer_view.create buffer(%arg0 : !hal.buffer)
+ %view = hal.buffer_view.create buffer(%arg0 : !hal.buffer)[%arg1, %arg2]
shape([%arg1, %arg2])
type(%c32)
encoding(%c1) : !hal.buffer_view
diff --git a/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/StreamToHALInline/ConvertStreamToHALInline.cpp b/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/StreamToHALInline/ConvertStreamToHALInline.cpp
index 3b13f27..aa73b2f 100644
--- a/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/StreamToHALInline/ConvertStreamToHALInline.cpp
+++ b/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/StreamToHALInline/ConvertStreamToHALInline.cpp
@@ -318,8 +318,10 @@
}
rewriter.replaceOpWithNewOp<IREE::HAL::Inline::BufferViewCreateOp>(
- exportOp, adaptor.getSource(), elementType.value(),
- encodingType.value(), dims);
+ exportOp, adaptor.getSource(),
+ rewriter.create<arith::ConstantIndexOp>(loc, 0),
+ adaptor.getSourceSize(), elementType.value(), encodingType.value(),
+ dims);
return success();
}
};
diff --git a/compiler/src/iree/compiler/Modules/HAL/Inline/IR/HALInlineOps.cpp b/compiler/src/iree/compiler/Modules/HAL/Inline/IR/HALInlineOps.cpp
index 9b926ed..b6ed61d 100644
--- a/compiler/src/iree/compiler/Modules/HAL/Inline/IR/HALInlineOps.cpp
+++ b/compiler/src/iree/compiler/Modules/HAL/Inline/IR/HALInlineOps.cpp
@@ -123,9 +123,10 @@
//===----------------------------------------------------------------------===//
void BufferViewCreateOp::build(OpBuilder &builder, OperationState &state,
- Value buffer, int32_t elementType,
+ Value sourceBuffer, Value sourceOffset,
+ Value sourceLength, int32_t elementType,
int32_t encodingType, ValueRange shape) {
- build(builder, state, buffer,
+ build(builder, state, sourceBuffer, sourceOffset, sourceLength,
builder.createOrFold<arith::ConstantIntOp>(state.location, elementType,
32),
builder.createOrFold<arith::ConstantIntOp>(state.location, encodingType,
@@ -134,9 +135,11 @@
}
void BufferViewCreateOp::build(OpBuilder &builder, OperationState &state,
- Value buffer, Value elementType,
+ Value sourceBuffer, Value sourceOffset,
+ Value sourceLength, Value elementType,
Value encodingType, ValueRange shape) {
- state.addOperands({buffer, elementType, encodingType});
+ state.addOperands(
+ {sourceBuffer, sourceOffset, sourceLength, elementType, encodingType});
state.addOperands(shape);
state.addTypes({BufferViewType::get(builder.getContext())});
}
@@ -146,6 +149,44 @@
setNameFn(getResult(), "view");
}
+namespace {
+
+/// Folds hal_inline.buffer.subspan ops into the buffer view creation op.
+struct FoldBufferViewCreateSubspan
+ : public OpRewritePattern<BufferViewCreateOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(BufferViewCreateOp op,
+ PatternRewriter &rewriter) const override {
+ auto ip = rewriter.saveInsertionPoint();
+ rewriter.setInsertionPoint(op);
+ bool needsUpdate = false;
+ auto newSourceBuffer = op.getSourceBuffer();
+ auto newSourceOffset = op.getSourceOffset();
+ if (auto subspanOp = dyn_cast_or_null<BufferSubspanOp>(
+ op.getSourceBuffer().getDefiningOp())) {
+ newSourceBuffer = subspanOp.getSourceBuffer();
+ newSourceOffset = rewriter.createOrFold<mlir::arith::AddIOp>(
+ subspanOp.getLoc(), subspanOp.getSourceOffset(),
+ op.getSourceOffset());
+ needsUpdate = true;
+ }
+ rewriter.restoreInsertionPoint(ip);
+ if (!needsUpdate) return failure();
+ rewriter.updateRootInPlace(op, [&]() {
+ op.getSourceBufferMutable().assign(newSourceBuffer);
+ op.getSourceOffsetMutable().assign(newSourceOffset);
+ });
+ return success();
+ }
+};
+
+} // namespace
+
+void BufferViewCreateOp::getCanonicalizationPatterns(RewritePatternSet &results,
+ MLIRContext *context) {
+ results.insert<FoldBufferViewCreateSubspan>(context);
+}
+
//===----------------------------------------------------------------------===//
// hal_inline.buffer_view.buffer
//===----------------------------------------------------------------------===//
@@ -161,12 +202,11 @@
/// the same scope and we know the origin buffer.
struct SkipBufferViewBufferOp : public OpRewritePattern<BufferViewBufferOp> {
using OpRewritePattern<BufferViewBufferOp>::OpRewritePattern;
-
LogicalResult matchAndRewrite(BufferViewBufferOp op,
PatternRewriter &rewriter) const override {
if (auto createOp = dyn_cast_or_null<BufferViewCreateOp>(
op.getBufferView().getDefiningOp())) {
- rewriter.replaceOp(op, createOp.getBuffer());
+ rewriter.replaceOp(op, createOp.getSourceBuffer());
return success();
}
return failure();
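
The canonicalization above rebases a view created on a buffer subspan directly onto the subspan's source buffer, summing the two offsets and keeping the view's own length. A toy model of the arithmetic, with plain integers standing in for the IR values:

#include <cassert>

struct Subspan {
  int sourceOffset;
};
struct ViewCreate {
  int sourceOffset;
  int sourceLength;
};

// Toy model of FoldBufferViewCreateSubspan: fold the subspan away by
// adding its offset into the view's offset (arith.addi in the pattern).
static ViewCreate foldSubspan(const Subspan &span, ViewCreate view) {
  view.sourceOffset += span.sourceOffset;
  return view;
}

int main() {
  ViewCreate view = foldSubspan(Subspan{100}, ViewCreate{512, 1024});
  assert(view.sourceOffset == 612 && view.sourceLength == 1024);
  return 0;
}
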
diff --git a/compiler/src/iree/compiler/Modules/HAL/Inline/IR/HALInlineOps.td b/compiler/src/iree/compiler/Modules/HAL/Inline/IR/HALInlineOps.td
index ae9cbed..55f9f45 100644
--- a/compiler/src/iree/compiler/Modules/HAL/Inline/IR/HALInlineOps.td
+++ b/compiler/src/iree/compiler/Modules/HAL/Inline/IR/HALInlineOps.td
@@ -213,7 +213,9 @@
}];
let arguments = (ins
- HAL_BufferType:$buffer,
+ HAL_BufferType:$source_buffer,
+ HAL_DeviceSize:$source_offset,
+ HAL_DeviceSize:$source_length,
HAL_ElementType:$element_type,
HAL_EncodingType:$encoding_type,
HAL_Shape:$shape
@@ -223,7 +225,8 @@
);
let assemblyFormat = [{
- `buffer` `(` $buffer `:` type($buffer) `)`
+ `buffer` `(` $source_buffer `:` type($source_buffer) `)`
+ `` `[` $source_offset `,` $source_length `]`
`shape` `(` `[` $shape `]` `)`
`type` `(` $element_type `)`
`encoding` `(` $encoding_type `)`
@@ -234,18 +237,24 @@
let skipDefaultBuilders = 1;
let builders = [
OpBuilder<(ins
- "Value":$buffer,
+ "Value":$sourceBuffer,
+ "Value":$sourceOffset,
+ "Value":$sourceLength,
"int32_t":$elementType,
"int32_t":$encodingType,
"ValueRange":$shape
)>,
OpBuilder<(ins
- "Value":$buffer,
+ "Value":$sourceBuffer,
+ "Value":$sourceOffset,
+ "Value":$sourceLength,
"Value":$elementType,
"Value":$encodingType,
"ValueRange":$shape
)>,
];
+
+ let hasCanonicalizer = 1;
}
def HALInline_BufferViewAssertOp : HALInline_Op<"buffer_view.assert"> {
diff --git a/compiler/src/iree/compiler/Modules/HAL/Inline/IR/test/buffer_folding.mlir b/compiler/src/iree/compiler/Modules/HAL/Inline/IR/test/buffer_folding.mlir
index eead77f..4fae891 100644
--- a/compiler/src/iree/compiler/Modules/HAL/Inline/IR/test/buffer_folding.mlir
+++ b/compiler/src/iree/compiler/Modules/HAL/Inline/IR/test/buffer_folding.mlir
@@ -1,8 +1,8 @@
// RUN: iree-opt --split-input-file --canonicalize -cse %s | iree-opt --allow-unregistered-dialect --split-input-file | FileCheck %s
-// CHECK-LABEL: func @fold_buffer_length
+// CHECK-LABEL: func @FoldBufferLengthOp
// CHECK-SAME: (%[[LENGTH:.+]]: index)
-func.func @fold_buffer_length(%length: index) -> index {
+func.func @FoldBufferLengthOp(%length: index) -> index {
%c64 = arith.constant 64 : index
%buffer, %storage = hal_inline.buffer.allocate alignment(%c64) : !hal.buffer{%length} in !util.buffer
// CHECK-NOT: hal_inline.buffer.length
@@ -13,8 +13,8 @@
// -----
-// CHECK-LABEL: func @fold_buffer_storage
-func.func @fold_buffer_storage(%length: index) -> !util.buffer {
+// CHECK-LABEL: func @FoldBufferStorageOp
+func.func @FoldBufferStorageOp(%length: index) -> !util.buffer {
%c64 = arith.constant 64 : index
// CHECK: %[[BUFFER:.+]], %[[STORAGE:.+]] = hal_inline.buffer.allocate
%buffer, %storage = hal_inline.buffer.allocate alignment(%c64) : !hal.buffer{%length} in !util.buffer
@@ -26,14 +26,39 @@
// -----
-// CHECK-LABEL: func @skip_buffer_view_buffer
+// CHECK-LABEL: @FoldBufferViewCreateSubspan
+// CHECK-SAME: (%[[BASE_BUFFER:.+]]: !hal.buffer, %[[SUBSPAN_OFFSET:.+]]: index, %[[SUBSPAN_LENGTH:.+]]: index)
+func.func @FoldBufferViewCreateSubspan(%base_buffer: !hal.buffer, %subspan_offset: index, %subspan_length: index) -> !hal.buffer_view {
+ %subspan = hal_inline.buffer.subspan<%base_buffer : !hal.buffer>[%subspan_offset, %subspan_length] : !hal.buffer
+ // CHECK-DAG: %[[VIEW_OFFSET:.+]] = arith.constant 512
+ %view_offset = arith.constant 512 : index
+ // CHECK-DAG: %[[VIEW_LENGTH:.+]] = arith.constant 1024
+ %view_length = arith.constant 1024 : index
+ // CHECK-DAG: %[[FOLDED_OFFSET:.+]] = arith.addi %[[SUBSPAN_OFFSET]], %[[VIEW_OFFSET]]
+ // CHECK: = hal_inline.buffer_view.create
+ // CHECK-SAME: buffer(%[[BASE_BUFFER]] : !hal.buffer)[%[[FOLDED_OFFSET]], %[[VIEW_LENGTH]]]
+ %dim0 = arith.constant 128 : index
+ %type = arith.constant 32 : i32
+ %encoding = arith.constant 1 : i32
+ %view = hal_inline.buffer_view.create buffer(%subspan : !hal.buffer)[%view_offset, %view_length]
+ shape([%dim0])
+ type(%type)
+ encoding(%encoding) : !hal.buffer_view
+ return %view : !hal.buffer_view
+}
+
+// -----
+
+// CHECK-LABEL: func @SkipBufferViewBufferOp
// CHECK-SAME: (%[[BUFFER:.+]]: !hal.buffer)
-func.func @skip_buffer_view_buffer(%buffer: !hal.buffer) -> !hal.buffer {
+func.func @SkipBufferViewBufferOp(%buffer: !hal.buffer) -> !hal.buffer {
+ %c0 = arith.constant 0 : index
%c1 = arith.constant 1 : i32
%c10 = arith.constant 10 : index
%c11 = arith.constant 11 : index
%c32 = arith.constant 32 : i32
- %view = hal_inline.buffer_view.create buffer(%buffer : !hal.buffer)
+ %c64 = arith.constant 64 : index
+ %view = hal_inline.buffer_view.create buffer(%buffer : !hal.buffer)[%c0, %c64]
shape([%c10, %c11])
type(%c32)
encoding(%c1) : !hal.buffer_view
diff --git a/compiler/src/iree/compiler/Modules/HAL/Inline/hal_inline.imports.mlir b/compiler/src/iree/compiler/Modules/HAL/Inline/hal_inline.imports.mlir
index e9ee06d..dce15ad 100644
--- a/compiler/src/iree/compiler/Modules/HAL/Inline/hal_inline.imports.mlir
+++ b/compiler/src/iree/compiler/Modules/HAL/Inline/hal_inline.imports.mlir
@@ -64,7 +64,9 @@
// Creates a reference to a buffer with a particular shape and element type.
vm.import @buffer_view.create(
- %buffer : !vm.ref<!hal.buffer>,
+ %source_buffer : !vm.ref<!hal.buffer>,
+ %source_offset : i64,
+ %source_length : i64,
%element_type : i32,
%encoding_type : i32,
%shape : i64 ...
diff --git a/integrations/tensorflow/iree_tf_compiler/BUILD b/integrations/tensorflow/iree_tf_compiler/BUILD
index 195b006..96c8d42 100644
--- a/integrations/tensorflow/iree_tf_compiler/BUILD
+++ b/integrations/tensorflow/iree_tf_compiler/BUILD
@@ -78,6 +78,7 @@
"//iree_tf_compiler/MHLO",
"//iree_tf_compiler/TF",
"@llvm-project//llvm:Support",
+ "@llvm-project//mlir:BytecodeWriter",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
@@ -96,6 +97,7 @@
deps = [
"//iree_tf_compiler/TFL",
"@llvm-project//llvm:Support",
+ "@llvm-project//mlir:BytecodeWriter",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Parser",
"@llvm-project//mlir:Pass",
@@ -116,6 +118,7 @@
"//iree_tf_compiler/MHLO",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:ArithmeticDialect",
+ "@llvm-project//mlir:BytecodeWriter",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:MathDialect",
diff --git a/integrations/tensorflow/iree_tf_compiler/TFL/test/import/add.mlir b/integrations/tensorflow/iree_tf_compiler/TFL/test/import/add.mlir
index cb99070..76bf350 100644
--- a/integrations/tensorflow/iree_tf_compiler/TFL/test/import/add.mlir
+++ b/integrations/tensorflow/iree_tf_compiler/TFL/test/import/add.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-import-tflite %S/add.tflite | FileCheck %s
+// RUN: iree-import-tflite --output-format=mlir-ir %S/add.tflite | FileCheck %s
// CHECK: module {
// CHECK-NEXT: func.func @main(%arg0: tensor<1x8x8x3xf32> {iree.identifier = "input"}) -> (tensor<1x8x8x3xf32> {iree.identifier = "output"}) {
diff --git a/integrations/tensorflow/iree_tf_compiler/TFL/test/import/multi_add.mlir b/integrations/tensorflow/iree_tf_compiler/TFL/test/import/multi_add.mlir
index ba1eb5f..e9444e6 100644
--- a/integrations/tensorflow/iree_tf_compiler/TFL/test/import/multi_add.mlir
+++ b/integrations/tensorflow/iree_tf_compiler/TFL/test/import/multi_add.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-import-tflite %S/multi_add.tflite | FileCheck %s
+// RUN: iree-import-tflite --output-format=mlir-ir %S/multi_add.tflite | FileCheck %s
// CHECK: module {
// CHECK-NEXT: func.func @main(%arg0: tensor<1x8x8x3xf32> {iree.identifier = "a"}, %arg1: tensor<1x8x8x3xf32> {iree.identifier = "b"}, %arg2: tensor<1x8x8x3xf32> {iree.identifier = "c"}, %arg3: tensor<1x8x8x3xf32> {iree.identifier = "d"}) -> (tensor<1x8x8x3xf32> {iree.identifier = "x"}, tensor<1x8x8x3xf32> {iree.identifier = "y"}) {
diff --git a/integrations/tensorflow/iree_tf_compiler/iree-import-tf-main.cpp b/integrations/tensorflow/iree_tf_compiler/iree-import-tf-main.cpp
index 412a421..6962815 100644
--- a/integrations/tensorflow/iree_tf_compiler/iree-import-tf-main.cpp
+++ b/integrations/tensorflow/iree_tf_compiler/iree-import-tf-main.cpp
@@ -19,6 +19,7 @@
#include "llvm/Support/InitLLVM.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/ToolOutputFile.h"
+#include "mlir/Bytecode/BytecodeWriter.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dialect.h"
@@ -42,6 +43,12 @@
savedmodel_v1,
};
+enum class OutputFormat {
+ none,
+ mlir_ir,
+ mlir_bytecode,
+};
+
} // namespace
static OwningOpRef<mlir::ModuleOp> importSavedModelV2(
@@ -73,7 +80,7 @@
return nullptr;
}
- return std::move(loadedModule).ValueOrDie();
+ return std::move(loadedModule).value();
}
static OwningOpRef<mlir::ModuleOp> importSavedModelV1(
@@ -117,7 +124,7 @@
return nullptr;
}
- return std::move(loadedModule).ValueOrDie();
+ return std::move(loadedModule).value();
}
int main(int argc, char **argv) {
@@ -135,6 +142,15 @@
clEnumVal(savedmodel_v1,
"Import a TensorFlow SavedModel V1 (directory)")));
+ // The output format flag is the master control for what we do with the
+ // in-memory compiled form.
+ llvm::cl::opt<OutputFormat> outputFormat(
+ "output-format", llvm::cl::desc("Format of imported output"),
+ llvm::cl::values(clEnumValN(OutputFormat::mlir_bytecode, "mlir-bytecode",
+ "MLIR Bytecode (default)"),
+ clEnumValN(OutputFormat::mlir_ir, "mlir-ir", "MLIR IR")),
+ llvm::cl::init(OutputFormat::mlir_bytecode));
+
static llvm::cl::opt<std::string> savedModelExportedNames(
"tf-savedmodel-exported-names",
llvm::cl::desc("Names to export from SavedModel, separated by ','. Empty "
@@ -190,14 +206,22 @@
llvm::errs() << "Could not open output file: " << savePath << "\n";
return failure();
}
- OpPrintingFlags printFlags;
- // TODO: Re-enable custom assembly format once fixed:
- // https://github.com/tensorflow/mlir-hlo/issues/25
- printFlags.printGenericOpForm();
- module->print(outputFile->os(), printFlags);
- outputFile->os() << "\n";
- outputFile->keep();
- return success();
+
+ if (outputFormat == OutputFormat::mlir_ir) {
+ OpPrintingFlags printFlags;
+ module->print(outputFile->os(), printFlags);
+ outputFile->os() << "\n";
+ outputFile->keep();
+ return success();
+ }
+
+ if (outputFormat == OutputFormat::mlir_bytecode) {
+ mlir::writeBytecodeToFile(*module, outputFile->os());
+ outputFile->keep();
+ return success();
+ }
+ llvm::errs() << "Unknown output format\n";
+ return failure();
};
// First stage import.
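
All three importers (TF, TFLite, XLA) gain the same emission switch, repeated verbatim below for each tool. The essential shape, as a sketch assuming module is a parsed mlir::ModuleOp and os an open output stream:

#include "mlir/Bytecode/BytecodeWriter.h"
#include "mlir/IR/BuiltinOps.h"
#include "llvm/Support/raw_ostream.h"

// Bytecode is now the default output; --output-format=mlir-ir keeps the
// old textual printing path for FileCheck-style tests.
static void emitModule(mlir::ModuleOp module, llvm::raw_ostream &os,
                       bool emitBytecode) {
  if (emitBytecode) {
    (void)mlir::writeBytecodeToFile(module, os);  // compact, versioned binary
  } else {
    module.print(os);
    os << "\n";
  }
}
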
diff --git a/integrations/tensorflow/iree_tf_compiler/iree-import-tflite-main.cpp b/integrations/tensorflow/iree_tf_compiler/iree-import-tflite-main.cpp
index 71907f9..d5113b4 100644
--- a/integrations/tensorflow/iree_tf_compiler/iree-import-tflite-main.cpp
+++ b/integrations/tensorflow/iree_tf_compiler/iree-import-tflite-main.cpp
@@ -11,6 +11,7 @@
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/ToolOutputFile.h"
+#include "mlir/Bytecode/BytecodeWriter.h"
#include "mlir/Dialect/Quant/QuantOps.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/AsmState.h"
@@ -28,6 +29,12 @@
using namespace llvm;
using namespace mlir;
+enum class OutputFormat {
+ none,
+ mlir_ir,
+ mlir_bytecode,
+};
+
int main(int argc, char **argv) {
llvm::InitLLVM y(argc, argv);
@@ -37,6 +44,15 @@
cl::value_desc("filename"),
cl::init("-"));
+ // The output format flag is the master control for what we do with the
+ // in-memory compiled form.
+ llvm::cl::opt<OutputFormat> outputFormat(
+ "output-format", llvm::cl::desc("Format of imported output"),
+ llvm::cl::values(clEnumValN(OutputFormat::mlir_bytecode, "mlir-bytecode",
+ "MLIR Bytecode (default)"),
+ clEnumValN(OutputFormat::mlir_ir, "mlir-ir", "MLIR IR")),
+ llvm::cl::init(OutputFormat::mlir_bytecode));
+
static cl::opt<std::string> saveTempTflInput(
"save-temp-tfl-input",
cl::desc("Save the TFL pipeline input to this file"), cl::init(""));
@@ -119,11 +135,22 @@
llvm::errs() << "Could not open output file: " << savePath << "\n";
return failure();
}
- OpPrintingFlags printFlags;
- module->print(outputFile->os(), printFlags);
- outputFile->os() << "\n";
- outputFile->keep();
- return success();
+
+ if (outputFormat == OutputFormat::mlir_ir) {
+ OpPrintingFlags printFlags;
+ module->print(outputFile->os(), printFlags);
+ outputFile->os() << "\n";
+ outputFile->keep();
+ return success();
+ }
+
+ if (outputFormat == OutputFormat::mlir_bytecode) {
+ mlir::writeBytecodeToFile(*module, outputFile->os());
+ outputFile->keep();
+ return success();
+ }
+ llvm::errs() << "Unknown output format\n";
+ return failure();
};
// Save temp input.
diff --git a/integrations/tensorflow/iree_tf_compiler/iree-import-xla-main.cpp b/integrations/tensorflow/iree_tf_compiler/iree-import-xla-main.cpp
index 7c5f434..09a47c1 100644
--- a/integrations/tensorflow/iree_tf_compiler/iree-import-xla-main.cpp
+++ b/integrations/tensorflow/iree_tf_compiler/iree-import-xla-main.cpp
@@ -16,6 +16,7 @@
#include "llvm/Support/InitLLVM.h"
#include "llvm/Support/ToolOutputFile.h"
#include "mlir-hlo/Dialect/mhlo/IR/register.h"
+#include "mlir/Bytecode/BytecodeWriter.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Math/IR/Math.h"
@@ -45,6 +46,12 @@
mlir_text,
};
+enum class OutputFormat {
+ none,
+ mlir_ir,
+ mlir_bytecode,
+};
+
// Error collector that prints errors.
class PrintErrorCollector : public tensorflow::protobuf::io::ErrorCollector {
public:
@@ -119,6 +126,15 @@
clEnumVal(hlo_text, "Parse an HLO module in its native text format"),
clEnumVal(mlir_text, "Parse MLIR text containing MHLO ops")));
+ // The output format flag is the master control for what we do with the
+ // in-memory compiled form.
+ llvm::cl::opt<OutputFormat> outputFormat(
+ "output-format", llvm::cl::desc("Format of imported output"),
+ llvm::cl::values(clEnumValN(OutputFormat::mlir_bytecode, "mlir-bytecode",
+ "MLIR Bytecode (default)"),
+ clEnumValN(OutputFormat::mlir_ir, "mlir-ir", "MLIR IR")),
+ llvm::cl::init(OutputFormat::mlir_bytecode));
+
// Register any command line options.
registerAsmPrinterCLOptions();
registerMLIRContextCLOptions();
@@ -252,11 +268,22 @@
llvm::errs() << "Could not open output file: " << savePath << "\n";
return failure();
}
- OpPrintingFlags printFlags;
- module->print(outputFile->os(), printFlags);
- outputFile->os() << "\n";
- outputFile->keep();
- return success();
+
+ if (outputFormat == OutputFormat::mlir_ir) {
+ OpPrintingFlags printFlags;
+ module->print(outputFile->os(), printFlags);
+ outputFile->os() << "\n";
+ outputFile->keep();
+ return success();
+ }
+
+ if (outputFormat == OutputFormat::mlir_bytecode) {
+ mlir::writeBytecodeToFile(*module, outputFile->os());
+ outputFile->keep();
+ return success();
+ }
+ llvm::errs() << "Unknown output format\n";
+ return failure();
};
// Save temp output.
diff --git a/integrations/tensorflow/test/python/iree_tfl_tests/test_util.py b/integrations/tensorflow/test/python/iree_tfl_tests/test_util.py
index 285b7df..7df4ce3 100644
--- a/integrations/tensorflow/test/python/iree_tfl_tests/test_util.py
+++ b/integrations/tensorflow/test/python/iree_tfl_tests/test_util.py
@@ -63,14 +63,14 @@
return
self.workdir = _setup_artifacts_dir("download")
print(f"TMPDIR = {self.workdir}")
- self.tflite_file = '/'.join([self.workdir, 'model.tflite'])
- self.tflite_ir = '/'.join([self.workdir, 'tflite.mlir'])
- self.iree_ir = '/'.join([self.workdir, 'tosa.mlir'])
+ self.tflite_file = '/'.join([self.workdir, 'model.mlirbc'])
+ self.tflite_ir = '/'.join([self.workdir, 'tflite.mlirbc'])
+ self.iree_ir = '/'.join([self.workdir, 'tosa.mlirbc'])
if os.path.exists(self.model_path):
self.tflite_file = self.model_path
else:
urllib.request.urlretrieve(self.model_path, self.tflite_file)
- self.binary = '/'.join([self.workdir, 'module.bytecode'])
+ self.binary = '/'.join([self.workdir, 'module.vmfb'])
def generate_inputs(self, input_details):
args = []
diff --git a/llvm-external-projects/iree-dialects/BUILD b/llvm-external-projects/iree-dialects/BUILD
index c98945e..8e9c22f 100644
--- a/llvm-external-projects/iree-dialects/BUILD
+++ b/llvm-external-projects/iree-dialects/BUILD
@@ -375,6 +375,7 @@
":IREELinalgExtPasses",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:AffineDialect",
+ "@llvm-project//mlir:AffineUtils",
"@llvm-project//mlir:ArithmeticDialect",
"@llvm-project//mlir:ArithmeticUtils",
"@llvm-project//mlir:AsyncDialect",
@@ -391,12 +392,14 @@
"@llvm-project//mlir:Parser",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:SCFDialect",
+ "@llvm-project//mlir:SCFTransforms",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TensorDialect",
"@llvm-project//mlir:TensorUtils",
"@llvm-project//mlir:TilingInterface",
"@llvm-project//mlir:TransformUtils",
"@llvm-project//mlir:Transforms",
+ "@llvm-project//mlir:VectorTransforms",
],
)
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/PassDetail.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/PassDetail.h
index 417a582..bac5ee8 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/PassDetail.h
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/PassDetail.h
@@ -8,6 +8,8 @@
#define IREE_DIALECTS_DIALECT_LINALGEXT_TRANSFORMS_PASS_DETAIL_H_
#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Pass/Pass.h"
namespace mlir {
@@ -16,6 +18,7 @@
namespace LinalgExt {
#define GEN_PASS_CLASSES
+
#include "iree-dialects/Dialect/LinalgExt/Passes/Passes.h.inc" // IWYU pragma: keep
} // namespace LinalgExt
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/Passes.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/Passes.h
index 503e210..f261bff 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/Passes.h
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/Passes.h
@@ -43,6 +43,70 @@
// Marker used as attribute the depth of the split reduction transformations.
const StringLiteral kSplitReductionDepthMarker = "__split_reduction_depth__";
+//===---------------------------------------------------------------------===//
+// Codegen Strategy passes that are moved into IREE.
+//===---------------------------------------------------------------------===//
+/// Create a LinalgStrategyTileAndFusePass.
+std::unique_ptr<OperationPass<func::FuncOp>>
+createLinalgStrategyTileAndFusePass(
+ StringRef opName = "", const linalg::LinalgTilingAndFusionOptions &opt = {},
+ const linalg::LinalgTransformationFilter &filter =
+ linalg::LinalgTransformationFilter());
+
+/// Create a LinalgStrategyTilePass.
+std::unique_ptr<OperationPass<func::FuncOp>> createLinalgStrategyTilePass(
+ StringRef opName = "",
+ const linalg::LinalgTilingOptions &opt = linalg::LinalgTilingOptions(),
+ const linalg::LinalgTransformationFilter &filter =
+ linalg::LinalgTransformationFilter());
+
+/// Create a LinalgStrategyPadPass.
+std::unique_ptr<OperationPass<func::FuncOp>> createLinalgStrategyPadPass(
+ StringRef opName = "",
+ const linalg::LinalgPaddingOptions &opt = linalg::LinalgPaddingOptions(),
+ const linalg::LinalgTransformationFilter &filter =
+ linalg::LinalgTransformationFilter());
+
+/// Create a LinalgStrategyDecomposePass.
+// TODO: if/when we need finer control add an `opName` parameter.
+std::unique_ptr<OperationPass<func::FuncOp>> createLinalgStrategyDecomposePass(
+ const linalg::LinalgTransformationFilter &filter =
+ linalg::LinalgTransformationFilter());
+
+/// Create a LinalgStrategyPeelPass.
+std::unique_ptr<OperationPass<func::FuncOp>> createLinalgStrategyPeelPass(
+ StringRef opName = "",
+ const linalg::LinalgPeelOptions &opt = linalg::LinalgPeelOptions(),
+ const linalg::LinalgTransformationFilter &filter =
+ linalg::LinalgTransformationFilter());
+
+/// Create a LinalgStrategyVectorizePass.
+std::unique_ptr<OperationPass<func::FuncOp>> createLinalgStrategyVectorizePass(
+ StringRef opName = "",
+ linalg::LinalgVectorizationOptions opt =
+ linalg::LinalgVectorizationOptions(),
+ const linalg::LinalgTransformationFilter &filter =
+ linalg::LinalgTransformationFilter(),
+ bool padVectorize = false);
+
+/// Create a LinalgStrategyEnablePass.
+std::unique_ptr<OperationPass<func::FuncOp>> createLinalgStrategyEnablePass(
+ linalg::LinalgEnablingOptions opt = linalg::LinalgEnablingOptions(),
+ const linalg::LinalgTransformationFilter &filter =
+ linalg::LinalgTransformationFilter());
+
+/// Create a LinalgStrategyLowerVectorsPass.
+std::unique_ptr<OperationPass<func::FuncOp>>
+createLinalgStrategyLowerVectorsPass(
+ linalg::LinalgVectorLoweringOptions opt =
+ linalg::LinalgVectorLoweringOptions(),
+ const linalg::LinalgTransformationFilter &filter =
+ linalg::LinalgTransformationFilter());
+
+/// Create a LinalgStrategyRemoveMarkersPass.
+std::unique_ptr<OperationPass<func::FuncOp>>
+createLinalgStrategyRemoveMarkersPass();
+
void registerPasses();
} // namespace LinalgExt
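
The declarations above make the relocated strategy passes composable into an ordinary pass pipeline. A minimal sketch, assuming an IREE build where these headers resolve, that pm is nested on func.func, and with an illustrative anchor op name:

#include "iree-dialects/Dialect/LinalgExt/Passes/Passes.h"
#include "mlir/Pass/PassManager.h"

using namespace mlir;
using namespace mlir::iree_compiler::IREE::LinalgExt;

// Tile, vectorize, lower vectors, then strip the transformation markers,
// using the default-constructed option structs from the declarations above.
static void buildMatmulStrategy(OpPassManager &pm) {
  pm.addPass(createLinalgStrategyTilePass("linalg.matmul"));
  pm.addPass(createLinalgStrategyVectorizePass("linalg.matmul"));
  pm.addPass(createLinalgStrategyLowerVectorsPass());
  pm.addPass(createLinalgStrategyRemoveMarkersPass());
}
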
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/Passes.td b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/Passes.td
index db01f82..aba76b8 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/Passes.td
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/Passes.td
@@ -59,4 +59,124 @@
];
}
+//===---------------------------------------------------------------------===//
+// Codegen Strategy passes moved into IREE
+// TODO: Deprecate all this.
+//===---------------------------------------------------------------------===//
+
+def LinalgStrategyTileAndFusePass
+ : Pass<"iree-linalg-strategy-tile-and-fuse-pass", "func::FuncOp"> {
+ let summary = "Configurable pass to apply pattern-based tiling and fusion.";
+ let constructor = "createLinalgStrategyTileAndFusePass()";
+ let options = [
+ Option<"anchorFuncName", "anchor-func", "std::string", /*default=*/"",
+ "Which func op is the anchor to latch on.">,
+ Option<"anchorOpName", "anchor-op", "std::string", /*default=*/"",
+ "Which linalg op within the func is the anchor to latch on.">,
+ ];
+}
+
+def LinalgStrategyTilePass
+ : Pass<"iree-linalg-strategy-tile-pass", "func::FuncOp"> {
+ let summary = "Configurable pass to apply pattern-based linalg tiling.";
+ let constructor = "createLinalgStrategyTilePass()";
+ let dependentDialects = ["linalg::LinalgDialect"];
+ let options = [
+ Option<"anchorFuncName", "anchor-func", "std::string", /*default=*/"",
+ "Which func op is the anchor to latch on.">,
+ Option<"anchorOpName", "anchor-op", "std::string", /*default=*/"",
+ "Which linalg op within the func is the anchor to latch on.">,
+ ];
+}
+
+def LinalgStrategyPadPass
+ : Pass<"iree-linalg-strategy-pad-pass", "func::FuncOp"> {
+ let summary = "Configurable pass to apply padding and hoisting.";
+ let constructor = "createLinalgStrategyPadPass()";
+ let dependentDialects = ["linalg::LinalgDialect"];
+ let options = [
+ Option<"anchorFuncName", "anchor-func", "std::string", /*default=*/"",
+ "Which func op is the anchor to latch on.">,
+ Option<"anchorOpName", "anchor-op", "std::string", /*default=*/"",
+ "Which linalg op within the func is the anchor to latch on.">,
+ ];
+}
+
+// TODO: if/when we need finer control add an anchorOp option.
+def LinalgStrategyDecomposePass
+ : Pass<"iree-linalg-strategy-decompose-pass", "func::FuncOp"> {
+ let summary = "Configurable pass to apply pattern-based generalization.";
+ let constructor = "createLinalgStrategyDecomposePass()";
+ let dependentDialects = ["linalg::LinalgDialect"];
+ let options = [
+ Option<"anchorFuncName", "anchor-func", "std::string", /*default=*/"",
+ "Which func op is the anchor to latch on.">,
+ ];
+}
+
+def LinalgStrategyPeelPass
+ : Pass<"iree-linalg-strategy-peel-pass", "func::FuncOp"> {
+ let summary = "Configurable pass to apply pattern-based linalg peeling.";
+ let constructor = "createLinalgStrategyPeelPass()";
+ let dependentDialects = [
+ "linalg::LinalgDialect",
+ "scf::SCFDialect"
+ ];
+ let options = [
+ Option<"anchorFuncName", "anchor-func", "std::string", /*default=*/"",
+ "Which func op is the anchor to latch on.">,
+ Option<"anchorOpName", "anchor-op", "std::string", /*default=*/"",
+ "Which linalg op within the func is the anchor to latch on.">,
+ ];
+}
+
+def LinalgStrategyVectorizePass
+ : Pass<"iree-linalg-strategy-vectorize-pass", "func::FuncOp"> {
+ let summary = "Configurable pass to apply pattern-based linalg vectorization.";
+ let constructor = "createLinalgStrategyVectorizePass()";
+ let dependentDialects = ["linalg::LinalgDialect"];
+ let options = [
+ Option<"anchorFuncName", "anchor-func", "std::string", /*default=*/"",
+ "Which func op is the anchor to latch on.">,
+ Option<"anchorOpName", "anchor-op", "std::string", /*default=*/"",
+ "Which linalg op within the func is the anchor to latch on.">,
+ Option<"vectorizePadding", "vectorize-padding", "bool", "false",
+ "Enable vectorization of padding ops.">,
+ ];
+}
+
+def LinalgStrategyEnablePass
+ : Pass<"iree-linalg-strategy-enable-pass", "func::FuncOp"> {
+ let summary = "Configurable pass to enable the application of other "
+ "pattern-based linalg passes.";
+ let constructor = "createLinalgStrategyEnablePass()";
+ let dependentDialects = ["linalg::LinalgDialect"];
+ let options = [
+ Option<"anchorFuncName", "anchor-func", "std::string", /*default=*/"",
+ "Which func op is the anchor to latch on.">,
+ ];
+}
+
+def LinalgStrategyLowerVectorsPass
+ : Pass<"iree-linalg-strategy-lower-vectors-pass", "func::FuncOp"> {
+ let summary = "Configurable pass to lower vector operations.";
+ let constructor = "createLinalgStrategyLowerVectorsPass()";
+ let dependentDialects = ["linalg::LinalgDialect"];
+ let options = [
+ Option<"anchorFuncName", "anchor-func", "std::string", /*default=*/"",
+ "Which func op is the anchor to latch on.">,
+ ];
+}
+
+def LinalgStrategyRemoveMarkersPass
+ : Pass<"iree-linalg-strategy-remove-markers-pass", "func::FuncOp"> {
+ let summary = "Cleanup pass that drops markers.";
+ let constructor = "createLinalgStrategyRemoveMarkersPass()";
+ let dependentDialects = ["linalg::LinalgDialect"];
+ let options = [
+ Option<"anchorFuncName", "anchor-func", "std::string", /*default=*/"",
+ "Which func op is the anchor to latch on.">,
+ ];
+}
+
#endif // IREE_DIALECT_LINALGEXT_PASSES
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Transforms/CodegenStrategy.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Transforms/CodegenStrategy.h
new file mode 100644
index 0000000..d7dcfc9
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Transforms/CodegenStrategy.h
@@ -0,0 +1,287 @@
+// Copyright 2022 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_DIALECTS_DIALECT_LINALGEXT_TRANSFORMS_CODEGENSTRATEGY_H_
+#define IREE_DIALECTS_DIALECT_LINALGEXT_TRANSFORMS_CODEGENSTRATEGY_H_
+
+#include "iree-dialects/Dialect/LinalgExt/Passes/Passes.h"
+#include "mlir/Pass/PassManager.h"
+
+#include <utility>
+
+//===----------------------------------------------------------------------===//
+// Strategies moved from upstream MLIR as IREE still heavily relies on patterns
+// that compose through filters.
+// TODO: Deprecate everything below.
+//===----------------------------------------------------------------------===//
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace LinalgExt {
+
+/// Abstract Transformation class applied in a sequence that also handles state
+/// through markers.
+struct Transformation {
+ explicit Transformation(linalg::LinalgTransformationFilter::FilterFunction f)
+ : filter(std::move(f)) {}
+ virtual ~Transformation() = default;
+ virtual void
+ addToPassPipeline(OpPassManager &pm,
+ linalg::LinalgTransformationFilter m) const = 0;
+ linalg::LinalgTransformationFilter::FilterFunction filter = nullptr;
+};
+
+/// Represent one application of LinalgStrategyTileAndFusePass.
+struct TileAndFuse : public Transformation {
+ TileAndFuse(StringRef name, linalg::LinalgTilingAndFusionOptions options,
+ linalg::LinalgTransformationFilter::FilterFunction f = nullptr)
+ : Transformation(std::move(f)), opName(name),
+ options(std::move(options)) {}
+
+ void addToPassPipeline(OpPassManager &pm,
+ linalg::LinalgTransformationFilter m) const override {
+ pm.addPass(createLinalgStrategyTileAndFusePass(opName, options, m));
+ }
+
+private:
+ std::string opName;
+ linalg::LinalgTilingAndFusionOptions options;
+};
+
+/// Represent one application of LinalgStrategyTilePass.
+struct Tile : public Transformation {
+ Tile(StringRef name, linalg::LinalgTilingOptions options,
+ linalg::LinalgTransformationFilter::FilterFunction f = nullptr)
+ : Transformation(std::move(f)), opName(name),
+ options(std::move(options)) {}
+
+ void addToPassPipeline(OpPassManager &pm,
+ linalg::LinalgTransformationFilter m) const override {
+ pm.addPass(createLinalgStrategyTilePass(opName, options, m));
+ }
+
+private:
+ std::string opName;
+ linalg::LinalgTilingOptions options;
+};
+
+/// Represent one application of LinalgStrategyPadPass.
+struct Pad : public Transformation {
+ Pad(StringRef name, linalg::LinalgPaddingOptions options,
+ linalg::LinalgTransformationFilter::FilterFunction f = nullptr)
+ : Transformation(std::move(f)), opName(name),
+ options(std::move(options)) {}
+
+ void addToPassPipeline(OpPassManager &pm,
+ linalg::LinalgTransformationFilter m) const override {
+ pm.addPass(createLinalgStrategyPadPass(opName, options, m));
+ }
+
+private:
+ std::string opName;
+ linalg::LinalgPaddingOptions options;
+};
+
+/// Represent one application of createLinalgStrategyDecomposePass.
+struct Decompose : public Transformation {
+ explicit Decompose(
+ linalg::LinalgTransformationFilter::FilterFunction f = nullptr)
+ : Transformation(std::move(f)) {}
+
+ void addToPassPipeline(OpPassManager &pm,
+ linalg::LinalgTransformationFilter m) const override {
+ pm.addPass(createLinalgStrategyDecomposePass(m));
+ }
+};
+
+/// Represent one application of createLinalgStrategyPeelPass.
+struct Peel : public Transformation {
+ explicit Peel(linalg::LinalgPeelOptions options,
+ linalg::LinalgTransformationFilter::FilterFunction f = nullptr)
+ : Transformation(std::move(f)), options(options) {}
+
+ Peel(StringRef name, linalg::LinalgPeelOptions options,
+ linalg::LinalgTransformationFilter::FilterFunction f = nullptr)
+ : Transformation(std::move(f)), opName(name), options(options) {}
+
+ void addToPassPipeline(OpPassManager &pm,
+ linalg::LinalgTransformationFilter m) const override {
+ pm.addPass(createLinalgStrategyPeelPass(opName, options, m));
+ }
+
+private:
+ std::string opName;
+ linalg::LinalgPeelOptions options;
+};
+
+/// Represent one application of createLinalgStrategyVectorizePass.
+struct Vectorize : public Transformation {
+ explicit Vectorize(
+ linalg::LinalgVectorizationOptions options,
+ linalg::LinalgTransformationFilter::FilterFunction f = nullptr,
+ bool padVectorize = false)
+ : Transformation(std::move(f)), options(options),
+ vectorizePadding(padVectorize) {}
+
+ Vectorize(StringRef name, linalg::LinalgVectorizationOptions options,
+ linalg::LinalgTransformationFilter::FilterFunction f = nullptr,
+ bool padVectorize = false)
+ : Transformation(std::move(f)), opName(name), options(options),
+ vectorizePadding(padVectorize) {}
+
+ void addToPassPipeline(OpPassManager &pm,
+ linalg::LinalgTransformationFilter m) const override {
+ pm.addPass(createLinalgStrategyVectorizePass(opName, options, m,
+ vectorizePadding));
+ }
+
+private:
+ std::string opName;
+ linalg::LinalgVectorizationOptions options;
+ bool vectorizePadding;
+};
+
+/// Represent one application of createLinalgStrategyLowerVectorsPass.
+struct VectorLowering : public Transformation {
+ explicit VectorLowering(
+ linalg::LinalgVectorLoweringOptions options,
+ linalg::LinalgTransformationFilter::FilterFunction f = nullptr)
+ : Transformation(std::move(f)), options(options) {}
+
+ void addToPassPipeline(OpPassManager &pm,
+ linalg::LinalgTransformationFilter m) const override {
+ pm.addPass(createLinalgStrategyLowerVectorsPass(options, m));
+ }
+
+private:
+ linalg::LinalgVectorLoweringOptions options;
+};
+
+/// Codegen strategy controls how a Linalg op is progressively lowered.
+struct CodegenStrategy {
+ /// Append a pattern to tile the Op `opName` and fuse its producers with
+ /// tiling and fusion `options`.
+ CodegenStrategy &tileAndFuse(
+ StringRef opName, const linalg::LinalgTilingAndFusionOptions &options,
+ const linalg::LinalgTransformationFilter::FilterFunction &f = nullptr) {
+ transformationSequence.emplace_back(
+ std::make_unique<TileAndFuse>(opName, options, f));
+ return *this;
+ }
+ /// Conditionally append a pattern to tile the Op `opName` and fuse its
+ /// producers with tiling and fusion `options`.
+ CodegenStrategy &tileAndFuseIf(
+ bool b, StringRef opName, linalg::LinalgTilingAndFusionOptions options,
+ linalg::LinalgTransformationFilter::FilterFunction f = nullptr) {
+ return b ? tileAndFuse(opName, std::move(options), std::move(f)) : *this;
+ }
+ /// Append a pattern to add a level of tiling for Op `opName` with tiling
+ /// `options`.
+ CodegenStrategy &
+ tile(StringRef opName, const linalg::LinalgTilingOptions &options,
+ const linalg::LinalgTransformationFilter::FilterFunction &f = nullptr) {
+ transformationSequence.emplace_back(
+ std::make_unique<Tile>(opName, options, f));
+ return *this;
+ }
+  /// Conditionally append a pattern to add a level of tiling for the Op
+  /// `opName` with tiling `options`.
+ CodegenStrategy &
+ tileIf(bool b, StringRef opName, linalg::LinalgTilingOptions options,
+ linalg::LinalgTransformationFilter::FilterFunction f = nullptr) {
+ return b ? tile(opName, std::move(options), std::move(f)) : *this;
+ }
+ /// Append a pattern to pad and hoist the operands of Op `opName` with padding
+ /// `options`.
+ CodegenStrategy &
+ pad(StringRef opName, const linalg::LinalgPaddingOptions &options,
+ const linalg::LinalgTransformationFilter::FilterFunction &f = nullptr) {
+ transformationSequence.emplace_back(
+ std::make_unique<Pad>(opName, options, f));
+ return *this;
+ }
+ /// Conditionally append a pattern to pad and hoist the operands of Op
+ /// `opName` with padding `options`.
+ CodegenStrategy &
+ padIf(bool b, StringRef opName, linalg::LinalgPaddingOptions options,
+ linalg::LinalgTransformationFilter::FilterFunction f = nullptr) {
+ return b ? pad(opName, std::move(options), std::move(f)) : *this;
+ }
+ /// Append patterns to decompose convolutions.
+ CodegenStrategy &decompose(
+ const linalg::LinalgTransformationFilter::FilterFunction &f = nullptr) {
+ transformationSequence.emplace_back(std::make_unique<Decompose>(f));
+ return *this;
+ }
+ /// Conditionally append patterns to decompose convolutions.
+ CodegenStrategy &
+ decomposeIf(bool b,
+ linalg::LinalgTransformationFilter::FilterFunction f = nullptr) {
+ return b ? decompose(std::move(f)) : *this;
+ }
+  /// Append a pattern to peel the Op `opName` with peeling `options`.
+ CodegenStrategy &
+ peel(StringRef opName, const linalg::LinalgPeelOptions &options,
+ const linalg::LinalgTransformationFilter::FilterFunction &f = nullptr) {
+ transformationSequence.emplace_back(
+ std::make_unique<Peel>(opName, options, f));
+ return *this;
+ }
+  /// Conditionally append a pattern to peel the Op `opName`.
+ CodegenStrategy &
+ peelIf(bool b, StringRef opName, const linalg::LinalgPeelOptions &options,
+ linalg::LinalgTransformationFilter::FilterFunction f = nullptr) {
+ return b ? peel(opName, options, std::move(f)) : *this;
+ }
+  /// Append a pattern to rewrite the Op `opName` as a vector operation.
+ CodegenStrategy &vectorize(
+ StringRef opName,
+ const linalg::LinalgTransformationFilter::FilterFunction &f = nullptr,
+ bool vectorizePadding = false) {
+ transformationSequence.emplace_back(std::make_unique<Vectorize>(
+ opName, linalg::LinalgVectorizationOptions(), f, vectorizePadding));
+ return *this;
+ }
+  /// Conditionally append a pattern to rewrite the Op `opName` as a vector
+  /// operation.
+ CodegenStrategy &
+ vectorizeIf(bool b, StringRef opName,
+ linalg::LinalgTransformationFilter::FilterFunction f = nullptr,
+ bool vectorizePadding = false) {
+ return b ? vectorize(opName, std::move(f), vectorizePadding) : *this;
+ }
+ /// Append a pattern to lower all vector operations.
+ CodegenStrategy &vectorLowering(linalg::LinalgVectorLoweringOptions options) {
+ transformationSequence.emplace_back(
+ std::make_unique<VectorLowering>(options));
+ return *this;
+ }
+  /// Configure the options of the enabling passes that run after each staged
+  /// pattern application.
+ CodegenStrategy &
+ setVectorTransferToSCFOptions(linalg::LinalgEnablingOptions options) {
+ linalgEnablingOptions = options;
+ return *this;
+ }
+
+ /// Apply the transformation patterns in sequence with cleanup
+ /// transformations interleaved.
+ void configurePassPipeline(OpPassManager &pm, MLIRContext *context,
+ bool addEnablePass = true) const;
+
+private:
+ LogicalResult postPatternTransforms(Operation *func) const;
+
+ linalg::LinalgEnablingOptions linalgEnablingOptions;
+ SmallVector<std::unique_ptr<Transformation>, 4> transformationSequence;
+};
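+
+// Illustrative sketch (not part of this change; names such as `tilingOptions`
+// and `loweringOptions` are placeholders): a strategy is assembled with the
+// fluent builder API above and then materialized into a pass pipeline:
+//
+//   CodegenStrategy strategy;
+//   strategy.tile("linalg.matmul", tilingOptions)
+//       .vectorize("linalg.matmul")
+//       .vectorLowering(loweringOptions);
+//   OpPassManager pm("func.func");
+//   strategy.configurePassPipeline(pm, context);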
+
+} // namespace LinalgExt
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
+
+#endif // IREE_DIALECTS_DIALECT_LINALGEXT_TRANSFORMS_CODEGENSTRATEGY_H_
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Transforms/Transforms.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Transforms/Transforms.h
index f93fbf4..5a54d2c 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Transforms/Transforms.h
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Transforms/Transforms.h
@@ -96,17 +96,135 @@
//===----------------------------------------------------------------------===//
// Transformations exposed as patterns, moved from upstream MLIR as IREE still
// heavily relies on patterns that compose through filters.
-// TODO: Deprecate this.
+// TODO: Deprecate all the patterns below.
//===----------------------------------------------------------------------===//
///
+/// Linalg tiling pattern.
+///
+/// Apply the `tiling` transformation as a pattern.
+/// `filter` controls LinalgTransformMarker matching and update when specified.
+/// See `tiling` for more details.
+// TODO: TiledOpInterface
+struct LinalgTilingPattern
+ : public OpInterfaceRewritePattern<linalg::LinalgOp> {
+  /// Construct a generic pattern applied to all LinalgOps that satisfy `filter`.
+ LinalgTilingPattern(MLIRContext *context, linalg::LinalgTilingOptions options,
+ linalg::LinalgTransformationFilter f =
+ linalg::LinalgTransformationFilter(),
+ PatternBenefit benefit = 1);
+
+ /// Construct a pattern specifically applied to `opName`.
+ LinalgTilingPattern(StringRef opName, MLIRContext *context,
+ linalg::LinalgTilingOptions options,
+ linalg::LinalgTransformationFilter f =
+ linalg::LinalgTransformationFilter(),
+ PatternBenefit benefit = 1);
+
+ /// `matchAndRewrite` implementation that returns the significant transformed
+ /// pieces of IR.
+ FailureOr<linalg::TiledLinalgOp>
+ returningMatchAndRewrite(linalg::LinalgOp op,
+ PatternRewriter &rewriter) const;
+
+ LogicalResult matchAndRewrite(linalg::LinalgOp op,
+ PatternRewriter &rewriter) const override {
+ return returningMatchAndRewrite(op, rewriter);
+ }
+
+private:
+ /// LinalgTransformMarker handles special attribute manipulations.
+ linalg::LinalgTransformationFilter filter;
+  /// Options to control tiling.
+ linalg::LinalgTilingOptions options;
+};
+
+template <typename... OpTypes>
+class TilingPatterns;
+
+template <>
+class TilingPatterns<> {
+public:
+ static void insert(RewritePatternSet &patterns,
+ const linalg::LinalgTilingOptions &options,
+ const linalg::LinalgTransformationFilter &f) {}
+};
+
+template <typename OpTy, typename... OpTypes>
+class TilingPatterns<OpTy, OpTypes...> {
+public:
+ static void insert(RewritePatternSet &patterns,
+ const linalg::LinalgTilingOptions &options,
+ const linalg::LinalgTransformationFilter &f) {
+ patterns.add<LinalgTilingPattern>(OpTy::getOperationName(),
+ patterns.getContext(), options, f);
+ TilingPatterns<OpTypes...>::insert(patterns, options, f);
+ }
+};
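+
+// Illustrative sketch: the recursive specialization above registers one
+// LinalgTilingPattern per listed op type, e.g. (the op choice is an example):
+//
+//   RewritePatternSet patterns(ctx);
+//   TilingPatterns<linalg::MatmulOp, linalg::FillOp>::insert(patterns,
+//                                                            options, f);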
+
+///
+/// Linalg vectorization patterns.
+///
+/// `filter` controls LinalgTransformMarker matching and update when specified.
+/// See `vectorizeLinalgOp` for more details.
+struct LinalgVectorizationPattern
+ : public OpInterfaceRewritePattern<linalg::LinalgOp> {
+  /// Construct a generic pattern applied to all LinalgOps that satisfy `filter`.
+ LinalgVectorizationPattern(MLIRContext *context,
+ linalg::LinalgTransformationFilter f =
+ linalg::LinalgTransformationFilter(),
+ linalg::LinalgVectorizationOptions options =
+ linalg::LinalgVectorizationOptions(),
+ PatternBenefit benefit = 1);
+
+ /// Construct a pattern specifically applied to `opName`.
+ LinalgVectorizationPattern(StringRef opName, MLIRContext *context,
+ linalg::LinalgVectorizationOptions options =
+ linalg::LinalgVectorizationOptions(),
+ linalg::LinalgTransformationFilter f =
+ linalg::LinalgTransformationFilter(),
+ PatternBenefit benefit = 1);
+
+ LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp,
+ PatternRewriter &rewriter) const override;
+
+private:
+ /// LinalgTransformMarker handles special attribute manipulations.
+ linalg::LinalgTransformationFilter filter;
+};
+
+template <typename... OpTypes>
+class VectorizationPatterns;
+
+template <>
+class VectorizationPatterns<> {
+public:
+ static void insert(RewritePatternSet &patterns,
+ const linalg::LinalgVectorizationOptions &options,
+ const linalg::LinalgTransformationFilter &f) {}
+};
+
+template <typename OpTy, typename... OpTypes>
+class VectorizationPatterns<OpTy, OpTypes...> {
+public:
+ static void insert(RewritePatternSet &patterns,
+ const linalg::LinalgVectorizationOptions &options,
+ const linalg::LinalgTransformationFilter &f) {
+ patterns.add<LinalgVectorizationPattern>(OpTy::getOperationName(),
+ patterns.getContext(), options, f);
+ VectorizationPatterns<OpTypes...>::insert(patterns, options, f);
+ }
+};
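+
+// Illustrative sketch: as with TilingPatterns, this registers one
+// LinalgVectorizationPattern per listed op type, e.g.:
+//
+//   VectorizationPatterns<linalg::MatmulOp, linalg::GenericOp>::insert(
+//       patterns, linalg::LinalgVectorizationOptions(), f);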
+
+///
/// Linalg promotion patterns.
///
/// Apply the `promoteSubViews` transformation as a pattern.
/// `filter` controls LinalgTransformMarker matching and update when specified.
/// See `promoteSubViews` for more details.
struct LinalgBasePromotionPattern : public RewritePattern {
- /// Entry point to match any LinalgOp OpInterface.
- /// MatchAnyOpTag-based constructor with a mandatory `filter`.
+ /// Entry point to match any LinalgOp
+ /// OpInterface. MatchAnyOpTag-based constructor
+ /// with a mandatory `filter`.
LinalgBasePromotionPattern(
MLIRContext *context, linalg::LinalgTransformationFilter f,
linalg::LinalgPromotionOptions options = linalg::LinalgPromotionOptions(),
@@ -129,10 +247,13 @@
if (failed(promoteSubviewsPrecondition(op, options)))
return failure();
- // TODO: We cannot use root update here. This pattern is creating other ops,
- // so if the promotion fails, those need to be cleaned up, which doesnt seem
- // to be happening here. So to fail properly, we should be cloning the op
- // and deleting the previous op. This needs more investigation.
+ // TODO: We cannot use root update here. This
+ // pattern is creating other ops, so if the
+ // promotion fails, those need to be cleaned
+    // up, which doesn't seem to be happening here.
+ // So to fail properly, we should be cloning
+ // the op and deleting the previous op. This
+ // needs more investigation.
rewriter.startRootUpdate(op);
Optional<linalg::LinalgOp> promotedOp =
promoteSubViews(rewriter, op, options);
@@ -146,7 +267,8 @@
}
private:
- /// LinalgTransformMarker handles special attribute manipulations.
+ /// LinalgTransformMarker handles special
+ /// attribute manipulations.
linalg::LinalgTransformationFilter filter;
/// Promotion options.
linalg::LinalgPromotionOptions options;
@@ -154,8 +276,9 @@
template <typename OpTy>
struct LinalgPromotionPattern : public LinalgBasePromotionPattern {
- /// SFINAE: This constructor can only trigger for concrete ops that have a
- /// static `getOperationName` method.
+ /// SFINAE: This constructor can only trigger for
+ /// concrete ops that have a static
+ /// `getOperationName` method.
template <typename ConcreateOpTy = OpTy>
LinalgPromotionPattern(MLIRContext *context,
linalg::LinalgPromotionOptions options,
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/CMakeLists.txt
index 6656921..85f46a5 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/CMakeLists.txt
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/CMakeLists.txt
@@ -1,8 +1,10 @@
add_mlir_library(IREELinalgExtTransforms
+ CodegenStrategy.cpp
ForeachThreadToAsync.cpp
ForeachThreadToSequentialFor.cpp
Fusion.cpp
Tiling.cpp
+ Transforms.cpp
Utils.cpp
PARTIAL_SOURCES_INTENDED
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/CodegenStrategy.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/CodegenStrategy.cpp
new file mode 100644
index 0000000..8ef2223
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/CodegenStrategy.cpp
@@ -0,0 +1,46 @@
+// Copyright 2022 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree-dialects/Dialect/LinalgExt/Transforms/CodegenStrategy.h"
+#include "iree-dialects/Dialect/LinalgExt/Passes/Passes.h"
+#include "mlir/Pass/PassManager.h"
+
+using namespace mlir;
+
+#define DEBUG_TYPE "linalg-codegen-strategy"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace LinalgExt {
+
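+// Illustrative note on the marker protocol (an assumption spelled out for
+// readers, not new behavior): each step's filter consumes the marker written
+// by the previous step and writes its own, so for a two-step strategy
+//
+//   CodegenStrategy s;
+//   s.tile("linalg.matmul", tileOpts).vectorize("linalg.matmul");
+//
+// the tile pass matches unmarked ops and marks them "1", the vectorize pass
+// matches "1" and marks "2", and the trailing remove-markers pass erases
+// kLinalgTransformMarker.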
+void CodegenStrategy::configurePassPipeline(OpPassManager &pm,
+ MLIRContext *context,
+ bool addEnablePass) const {
+ for (unsigned stepCount = 0, e = transformationSequence.size(); stepCount < e;
+ ++stepCount) {
+ const std::unique_ptr<Transformation> &t =
+ transformationSequence[stepCount];
+ std::string currentStr = std::to_string(stepCount);
+ auto currentState = StringAttr::get(context, currentStr);
+ std::string nextStr = std::to_string(stepCount + 1);
+ auto nextState = StringAttr::get(context, nextStr);
+    auto filter = (stepCount == 0)
+                      ? linalg::LinalgTransformationFilter(
+                            t->filter, ArrayRef<StringAttr>{}, nextState)
+                      : linalg::LinalgTransformationFilter(
+                            t->filter, currentState, nextState);
+ t->addToPassPipeline(pm, filter);
+ if (addEnablePass)
+ pm.addPass(createLinalgStrategyEnablePass(linalgEnablingOptions));
+ }
+ pm.addPass(createLinalgStrategyRemoveMarkersPass());
+}
+
+} // namespace LinalgExt
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/Transforms.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/Transforms.cpp
new file mode 100644
index 0000000..9bc3b10
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/Transforms.cpp
@@ -0,0 +1,513 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree-dialects/Dialect/LinalgExt/Transforms/Transforms.h"
+
+#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
+#include "iree-dialects/Dialect/LinalgExt/Passes/PassDetail.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Affine/LoopUtils.h"
+#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/SCF/Transforms/Transforms.h"
+#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/LoopInvariantCodeMotionUtils.h"
+#include "mlir/Transforms/Passes.h"
+
+using namespace mlir;
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace LinalgExt {
+/// Linalg tiling pattern.
+LinalgTilingPattern::LinalgTilingPattern(MLIRContext *context,
+ linalg::LinalgTilingOptions options,
+ linalg::LinalgTransformationFilter f,
+ PatternBenefit benefit)
+ : OpInterfaceRewritePattern<linalg::LinalgOp>(context, benefit),
+ filter(std::move(f)), options(std::move(options)) {}
+
+LinalgTilingPattern::LinalgTilingPattern(StringRef opName, MLIRContext *context,
+ linalg::LinalgTilingOptions options,
+ linalg::LinalgTransformationFilter f,
+ PatternBenefit benefit)
+ : OpInterfaceRewritePattern<linalg::LinalgOp>(context, benefit),
+ filter(f.addOpNameFilter(opName)), options(std::move(options)) {}
+
+FailureOr<linalg::TiledLinalgOp>
+LinalgTilingPattern::returningMatchAndRewrite(linalg::LinalgOp op,
+ PatternRewriter &rewriter) const {
+ if (failed(filter.checkAndNotify(rewriter, op)))
+ return failure();
+
+ FailureOr<linalg::TiledLinalgOp> res = tileLinalgOp(rewriter, op, options);
+ if (failed(res))
+ return failure();
+
+ // Clear filter to stop recursive pattern application.
+ // This must be done here to properly propagate to peeling branches.
+ filter.replaceLinalgTransformationFilter(rewriter, res->op);
+
+ // Peel the loops of the TiledLinalgOp.
+ peelTiledLinalgOp(rewriter, *res, options.peeledLoops, options.loopType);
+
+ if (res->tensorResults.empty())
+ rewriter.eraseOp(op);
+ else
+ rewriter.replaceOp(op, res->tensorResults);
+
+ return res;
+}
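+
+// Illustrative sketch (assuming the upstream LinalgTilingOptions builder
+// methods): peeling is driven by the tiling options, so a caller requesting
+// that loop 0 of the tiled nest be peeled could write:
+//
+//   auto opts = linalg::LinalgTilingOptions()
+//                   .setTileSizes({32, 32, 32})
+//                   .setPeeledLoops({0});
+//   patterns.add<LinalgTilingPattern>(ctx, opts, filter);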
+
+LinalgVectorizationPattern::LinalgVectorizationPattern(
+ MLIRContext *context, linalg::LinalgTransformationFilter f,
+ linalg::LinalgVectorizationOptions options, PatternBenefit benefit)
+ : OpInterfaceRewritePattern<linalg::LinalgOp>(context, benefit),
+ filter(std::move(f)) {}
+
+LinalgVectorizationPattern::LinalgVectorizationPattern(
+ StringRef opName, MLIRContext *context,
+ linalg::LinalgVectorizationOptions options,
+ linalg::LinalgTransformationFilter f, PatternBenefit benefit)
+ : OpInterfaceRewritePattern<linalg::LinalgOp>(context, benefit),
+ filter(f.addOpNameFilter(opName)) {}
+
+LogicalResult
+LinalgVectorizationPattern::matchAndRewrite(linalg::LinalgOp linalgOp,
+ PatternRewriter &rewriter) const {
+ if (failed(filter.checkAndNotify(rewriter, linalgOp)))
+ return failure();
+ return vectorize(rewriter, linalgOp);
+}
+
+namespace {
+
+/// Configurable pass to apply pattern-based tiling and fusion.
+struct LinalgStrategyTileAndFusePass
+ : public LinalgStrategyTileAndFusePassBase<LinalgStrategyTileAndFusePass> {
+
+ LinalgStrategyTileAndFusePass() = default;
+
+ LinalgStrategyTileAndFusePass(StringRef opName,
+ linalg::LinalgTilingAndFusionOptions opt,
+ linalg::LinalgTransformationFilter filt)
+ : options(std::move(opt)), filter(std::move(filt)) {
+ this->anchorOpName.setValue(opName.str());
+ }
+
+ void runOnOperation() override {
+ auto funcOp = getOperation();
+ if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
+ return;
+
+ RewritePatternSet tilingAndFusionPattern(funcOp.getContext());
+ if (!anchorOpName.empty()) {
+ tilingAndFusionPattern.add<linalg::LinalgTileAndFuseTensorOpsPattern>(
+ anchorOpName, funcOp.getContext(), options, filter);
+ } else {
+ tilingAndFusionPattern.add<linalg::LinalgTileAndFuseTensorOpsPattern>(
+ funcOp.getContext(), options, filter);
+ }
+    // Search for the root operation using a bottom-up traversal.
+ GreedyRewriteConfig config;
+ config.useTopDownTraversal = false;
+ (void)applyPatternsAndFoldGreedily(
+ funcOp, std::move(tilingAndFusionPattern), config);
+ }
+
+ linalg::LinalgTilingAndFusionOptions options;
+ linalg::LinalgTransformationFilter filter;
+};
+
+/// Configurable pass to apply pattern-based linalg tiling.
+struct LinalgStrategyTilePass
+ : public LinalgStrategyTilePassBase<LinalgStrategyTilePass> {
+
+ LinalgStrategyTilePass() = default;
+
+ LinalgStrategyTilePass(StringRef opName, linalg::LinalgTilingOptions opt,
+ linalg::LinalgTransformationFilter filt)
+ : options(std::move(opt)), filter(std::move(filt)) {
+ this->anchorOpName.setValue(opName.str());
+ }
+
+ void runOnOperation() override {
+ auto funcOp = getOperation();
+ if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
+ return;
+
+ MLIRContext *ctx = funcOp.getContext();
+ RewritePatternSet tilingPattern(ctx);
+ if (!anchorOpName.empty())
+ tilingPattern.add<LinalgTilingPattern>(anchorOpName, ctx, options,
+ filter);
+ else
+ tilingPattern.add<LinalgTilingPattern>(ctx, options, filter);
+ if (anchorOpName == tensor::PadOp::getOperationName())
+ populatePadTensorTilingPatterns(tilingPattern, options);
+ (void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern));
+ }
+
+ linalg::LinalgTilingOptions options;
+ linalg::LinalgTransformationFilter filter;
+};
+
+/// Configurable pass to apply hoisting and padding.
+struct LinalgStrategyPadPass
+ : public LinalgStrategyPadPassBase<LinalgStrategyPadPass> {
+
+ LinalgStrategyPadPass() = default;
+
+ LinalgStrategyPadPass(StringRef opName, linalg::LinalgPaddingOptions opt,
+ linalg::LinalgTransformationFilter filt)
+ : options(std::move(opt)), filter(std::move(filt)) {
+ this->anchorOpName.setValue(opName.str());
+ }
+
+ void runOnOperation() override {
+ auto funcOp = getOperation();
+ if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
+ return;
+
+ RewritePatternSet paddingPattern(funcOp.getContext());
+ if (!anchorOpName.empty()) {
+ paddingPattern.add<linalg::LinalgPaddingPattern>(
+ anchorOpName, funcOp.getContext(), options, filter);
+ } else {
+ paddingPattern.add<linalg::LinalgPaddingPattern>(funcOp.getContext(),
+ options, filter);
+ }
+ (void)applyPatternsAndFoldGreedily(funcOp, std::move(paddingPattern));
+ }
+
+ linalg::LinalgPaddingOptions options;
+ linalg::LinalgTransformationFilter filter;
+};
+
+/// Configurable pass to decompose coarser-grained named Linalg ops into
+/// finer-grained named versions.
+struct LinalgStrategyDecomposePass
+ : public LinalgStrategyDecomposePassBase<LinalgStrategyDecomposePass> {
+
+ LinalgStrategyDecomposePass() = default;
+
+ LinalgStrategyDecomposePass(linalg::LinalgTransformationFilter filter)
+ : filter(std::move(filter)) {}
+
+ void runOnOperation() override {
+ auto funcOp = getOperation();
+ if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
+ return;
+ RewritePatternSet decompositionPattern(funcOp.getContext());
+ populateDecomposeConvolutionPatterns(decompositionPattern, filter);
+ if (failed(applyPatternsAndFoldGreedily(funcOp,
+ std::move(decompositionPattern))))
+ signalPassFailure();
+ }
+
+ linalg::LinalgTransformationFilter filter;
+};
+
+/// Configurable pass to apply pattern-based linalg peeling.
+struct LinalgStrategyPeelPass
+ : public LinalgStrategyPeelPassBase<LinalgStrategyPeelPass> {
+
+ LinalgStrategyPeelPass() = default;
+
+ LinalgStrategyPeelPass(StringRef opName, linalg::LinalgPeelOptions opt,
+ linalg::LinalgTransformationFilter filt)
+ : options(std::move(opt)), filter(std::move(filt)) {
+ this->anchorOpName.setValue(opName.str());
+ }
+
+ void runOnOperation() override {
+ auto funcOp = getOperation();
+ if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
+ return;
+
+ RewritePatternSet peelingPatterns(funcOp.getContext());
+ if (!anchorOpName.empty()) {
+ peelingPatterns.add<linalg::LinalgPeelingPattern>(
+ anchorOpName, funcOp.getContext(), options, filter);
+ } else {
+ peelingPatterns.add<linalg::LinalgPeelingPattern>(funcOp.getContext(),
+ filter, options);
+ }
+ if (failed(
+ applyPatternsAndFoldGreedily(funcOp, std::move(peelingPatterns))))
+ return signalPassFailure();
+ }
+
+ linalg::LinalgPeelOptions options;
+ linalg::LinalgTransformationFilter filter;
+};
+
+/// Configurable pass to apply pattern-based linalg vectorization.
+struct LinalgStrategyVectorizePass
+ : public LinalgStrategyVectorizePassBase<LinalgStrategyVectorizePass> {
+
+ LinalgStrategyVectorizePass() = default;
+
+ LinalgStrategyVectorizePass(StringRef opName,
+ linalg::LinalgVectorizationOptions opt,
+ linalg::LinalgTransformationFilter filt,
+ bool padVectorize = false)
+ : options(opt), filter(std::move(filt)) {
+ this->anchorOpName.setValue(opName.str());
+ this->vectorizePadding.setValue(padVectorize);
+ }
+
+ void runOnOperation() override {
+ auto funcOp = getOperation();
+ if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
+ return;
+
+ RewritePatternSet vectorizationPatterns(funcOp.getContext());
+ if (!anchorOpName.empty()) {
+ vectorizationPatterns.add<LinalgVectorizationPattern>(
+ anchorOpName, funcOp.getContext(), options, filter);
+ } else {
+ vectorizationPatterns.add<LinalgVectorizationPattern>(funcOp.getContext(),
+ filter, options);
+ }
+ vector::populateVectorTransferPermutationMapLoweringPatterns(
+ vectorizationPatterns);
+ vector::populateVectorReductionToContractPatterns(vectorizationPatterns);
+ vectorizationPatterns.add<linalg::LinalgCopyVTRForwardingPattern,
+ linalg::LinalgCopyVTWForwardingPattern>(
+ funcOp.getContext(), /*benefit=*/2);
+ vector::TransferReadOp::getCanonicalizationPatterns(vectorizationPatterns,
+ funcOp.getContext());
+ vector::TransferWriteOp::getCanonicalizationPatterns(vectorizationPatterns,
+ funcOp.getContext());
+ (void)applyPatternsAndFoldGreedily(funcOp,
+ std::move(vectorizationPatterns));
+
+ // Apply the pad tensor op vectorization separately to avoid running the
+ // GenericPadOpVectorizationPattern too early.
+ // TODO: Improve once we have better infrastructure to control pattern
+ // application.
+ if (vectorizePadding) {
+ RewritePatternSet patterns(funcOp.getContext());
+ linalg::populatePadOpVectorizationPatterns(patterns);
+ (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
+ }
+ }
+
+ linalg::LinalgVectorizationOptions options;
+ linalg::LinalgTransformationFilter filter;
+};
+
+/// Configurable pass to enable the application of other pattern-based linalg
+/// passes.
+struct LinalgStrategyEnablePass
+ : public LinalgStrategyEnablePassBase<LinalgStrategyEnablePass> {
+
+ LinalgStrategyEnablePass(linalg::LinalgEnablingOptions opt,
+ linalg::LinalgTransformationFilter filt)
+ : options(opt), filter(std::move(filt)) {}
+
+ void runOnOperation() override {
+ auto funcOp = getOperation();
+ if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
+ return;
+
+ MLIRContext *context = funcOp.getContext();
+ RewritePatternSet patterns =
+ linalg::getLinalgTilingCanonicalizationPatterns(context);
+ scf::populateSCFForLoopCanonicalizationPatterns(patterns);
+ if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
+ return signalPassFailure();
+
+ if (options.licm) {
+ funcOp->walk([&](LoopLikeOpInterface loopLike) {
+ moveLoopInvariantCode(loopLike);
+ });
+ }
+
+    // Promote single-iteration affine and scf loops found in a post-order walk.
+ funcOp.walk([](Operation *op) {
+ if (auto forOp = dyn_cast<AffineForOp>(op))
+ (void)promoteIfSingleIteration(forOp);
+ else if (auto forOp = dyn_cast<scf::ForOp>(op))
+ (void)promoteIfSingleIteration(forOp);
+ });
+ if (options.hoistRedundantVectorTransfers)
+ linalg::hoistRedundantVectorTransfers(funcOp);
+
+ if (options.hoistRedundantVectorTransfersOnTensor)
+ linalg::hoistRedundantVectorTransfersOnTensor(funcOp);
+
+ // Run CSE to cleanup after canonicalization.
+ OpPassManager dynamicPM("func.func");
+ dynamicPM.addPass(createCSEPass());
+ if (failed(runPipeline(dynamicPM, funcOp)))
+ return signalPassFailure();
+ }
+
+ linalg::LinalgEnablingOptions options;
+ linalg::LinalgTransformationFilter filter;
+};
+
+/// Configurable pass to lower vector operations.
+struct LinalgStrategyLowerVectorsPass
+ : public LinalgStrategyLowerVectorsPassBase<
+ LinalgStrategyLowerVectorsPass> {
+
+ LinalgStrategyLowerVectorsPass(linalg::LinalgVectorLoweringOptions opt,
+ linalg::LinalgTransformationFilter filt)
+ : options(opt), filter(std::move(filt)) {}
+
+ void runOnOperation() override {
+ auto funcOp = getOperation();
+ if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
+ return;
+
+ MLIRContext *context = funcOp.getContext();
+ RewritePatternSet patterns(context);
+ vector::populateVectorToVectorCanonicalizationPatterns(patterns);
+ // In a progressive lowering of vectors, this would be the 1st step.
+ if (options.contractionLowering) {
+ patterns.add<vector::ContractionOpToOuterProductOpLowering,
+ vector::ContractionOpToMatmulOpLowering,
+ vector::ContractionOpLowering>(
+ options.vectorTransformOptions, context);
+ vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
+ }
+ // In a progressive lowering of vectors, this would be the 2nd step.
+ if (options.multiReductionLowering) {
+ vector::populateVectorMultiReductionLoweringPatterns(
+ patterns,
+ options.vectorTransformOptions.vectorMultiReductionLowering);
+ }
+ // In a progressive lowering of vectors, this would be the 3rd step.
+ if (options.transferPartialRewrite) {
+ patterns.add<vector::VectorTransferFullPartialRewriter>(
+ context, options.vectorTransformOptions);
+ }
+ // In a progressive lowering of vectors, this would be the 4th step.
+ if (options.transferLowering) {
+ vector::populateVectorTransferLoweringPatterns(patterns,
+ options.maxTransferRank);
+ }
+ // In a progressive lowering of vectors, this would be the 5th step.
+ if (options.transferToSCFConversion) {
+ populateVectorToSCFConversionPatterns(
+ patterns, options.vectorTransferToSCFOptions.setTargetRank(
+ options.maxTransferRank));
+ }
+ // In a progressive lowering of vectors, this would be the 6th step.
+ if (options.shapeCastLowering) {
+ vector::populateVectorShapeCastLoweringPatterns(patterns);
+ }
+ // In a progressive lowering of vectors, this would be the 7th step.
+ if (options.transposeLowering) {
+ vector::populateVectorTransposeLoweringPatterns(
+ patterns, options.vectorTransformOptions);
+ if (options.avx2Lowering)
+ x86vector::avx2::populateSpecializedTransposeLoweringPatterns(
+ patterns, options.avx2LoweringOptions, /*benefit=*/10);
+ }
+ (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
+ }
+
+ linalg::LinalgVectorLoweringOptions options;
+ linalg::LinalgTransformationFilter filter;
+};
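+
+// Illustrative sketch (field names as read by this pass; direct assignment is
+// an assumption): enabling a subset of the staged lowerings could look like:
+//
+//   linalg::LinalgVectorLoweringOptions opts;
+//   opts.contractionLowering = true;    // step 1: contraction lowering
+//   opts.multiReductionLowering = true; // step 2: multi_reduction lowering
+//   opts.transferLowering = true;       // step 4: transfer op lowering
+//   pm.addPass(createLinalgStrategyLowerVectorsPass(opts, filter));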
+
+/// Configurable pass to remove the transformation filter markers.
+struct LinalgStrategyRemoveMarkersPass
+ : public LinalgStrategyRemoveMarkersPassBase<
+ LinalgStrategyRemoveMarkersPass> {
+
+ void runOnOperation() override {
+ auto funcOp = getOperation();
+ if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
+ return;
+ funcOp.walk([](linalg::LinalgOp op) {
+ op->removeAttr(linalg::LinalgTransforms::kLinalgTransformMarker);
+ });
+ }
+};
+} // namespace
+
+/// Create a LinalgStrategyTileAndFusePass.
+std::unique_ptr<OperationPass<func::FuncOp>>
+createLinalgStrategyTileAndFusePass(
+ StringRef opName, const linalg::LinalgTilingAndFusionOptions &options,
+ const linalg::LinalgTransformationFilter &filter) {
+ return std::make_unique<LinalgStrategyTileAndFusePass>(opName, options,
+ filter);
+}
+
+/// Create a LinalgStrategyTilePass.
+std::unique_ptr<OperationPass<func::FuncOp>>
+createLinalgStrategyTilePass(StringRef opName,
+ const linalg::LinalgTilingOptions &opt,
+ const linalg::LinalgTransformationFilter &filter) {
+ return std::make_unique<LinalgStrategyTilePass>(opName, opt, filter);
+}
+
+/// Create a LinalgStrategyPadPass.
+std::unique_ptr<OperationPass<func::FuncOp>>
+createLinalgStrategyPadPass(StringRef opName,
+ const linalg::LinalgPaddingOptions &opt,
+ const linalg::LinalgTransformationFilter &filter) {
+ return std::make_unique<LinalgStrategyPadPass>(opName, opt, filter);
+}
+
+/// Create a LinalgStrategyDecomposePass.
+// TODO: if/when we need finer control add an `opName` parameter.
+std::unique_ptr<OperationPass<func::FuncOp>> createLinalgStrategyDecomposePass(
+ const linalg::LinalgTransformationFilter &filter) {
+ return std::make_unique<LinalgStrategyDecomposePass>(filter);
+}
+
+/// Create a LinalgStrategyPeelPass.
+std::unique_ptr<OperationPass<func::FuncOp>>
+createLinalgStrategyPeelPass(StringRef opName,
+ const linalg::LinalgPeelOptions &opt,
+ const linalg::LinalgTransformationFilter &filter) {
+ return std::make_unique<LinalgStrategyPeelPass>(opName, opt, filter);
+}
+
+/// Create a LinalgStrategyVectorizePass.
+std::unique_ptr<OperationPass<func::FuncOp>> createLinalgStrategyVectorizePass(
+ StringRef opName, linalg::LinalgVectorizationOptions opt,
+ const linalg::LinalgTransformationFilter &filter, bool padVectorize) {
+ return std::make_unique<LinalgStrategyVectorizePass>(opName, opt, filter,
+ padVectorize);
+}
+
+/// Create a LinalgStrategyEnablePass.
+std::unique_ptr<OperationPass<func::FuncOp>> createLinalgStrategyEnablePass(
+ linalg::LinalgEnablingOptions opt,
+ const linalg::LinalgTransformationFilter &filter) {
+ return std::make_unique<LinalgStrategyEnablePass>(opt, filter);
+}
+
+/// Create a LinalgStrategyLowerVectorsPass.
+std::unique_ptr<OperationPass<func::FuncOp>>
+createLinalgStrategyLowerVectorsPass(
+ linalg::LinalgVectorLoweringOptions opt,
+ const linalg::LinalgTransformationFilter &filter) {
+ return std::make_unique<LinalgStrategyLowerVectorsPass>(opt, filter);
+}
+
+/// Create a LinalgStrategyRemoveMarkersPass.
+std::unique_ptr<OperationPass<func::FuncOp>>
+createLinalgStrategyRemoveMarkersPass() {
+ return std::make_unique<LinalgStrategyRemoveMarkersPass>();
+}
+
+} // namespace LinalgExt
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/StructuredTransformOpsExt.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/StructuredTransformOpsExt.cpp
index 850d9b2..f56d4a1 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/StructuredTransformOpsExt.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/StructuredTransformOpsExt.cpp
@@ -347,6 +347,7 @@
DBGS() << "different source linalg ops for replacing one op: \n"
<< sourceOp << "\n"
<< currentSourceOp << "\n");
+ return nullptr;
}
LLVM_DEBUG(DBGS() << "replacing linalg op with unknown non-linalg op:\n"
<< *value.getDefiningOp() << "\n");
@@ -366,6 +367,7 @@
}
LLVM_DEBUG(
DBGS() << "different source scf.for ops when replacing one op\n");
+ return nullptr;
}
LLVM_DEBUG(
@@ -376,6 +378,26 @@
return forOp;
}
+/// Find the op that defines all values in the range.
+static Operation *findSingleOpDefiningAll(ValueRange range) {
+ Operation *op = nullptr;
+ for (Value value : range) {
+ if (auto currentSourceOp = value.getDefiningOp()) {
+ if (!op || op == currentSourceOp) {
+ op = currentSourceOp;
+ continue;
+ }
+ LLVM_DEBUG(DBGS() << "different source op when replacing one op\n");
+ return nullptr;
+ }
+
+ LLVM_DEBUG(
+ DBGS() << "could not find a source op when replacing another op\n");
+ return nullptr;
+ }
+ return op;
+}
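+
+// Illustrative example: if an op with results %a and %b is replaced by the two
+// results of a single new op `%r:2 = "some.op"()`, both values share one
+// definer and the handle is transferred to "some.op"; mixed definers or a
+// block-argument replacement return nullptr and tracking is dropped.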
+
// Find a single op that defines all values in the range, optionally
// transitively through other operations in an op-specific way.
static Operation *findSingleDefiningOp(Operation *replacedOp,
@@ -387,7 +409,9 @@
.Case<scf::ForOp>([&](scf::ForOp) -> Operation * {
return findSingleForOpDefiningAll(range);
})
- .Default([](Operation *) -> Operation * { return nullptr; });
+ .Default([&](Operation *) -> Operation * {
+ return findSingleOpDefiningAll(range);
+ });
}
void mlir::TrackingListener::notifyOperationReplaced(Operation *op,
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/foreach-thread-to-async.mlir b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/foreach-thread-to-async.mlir
index dda41e1..45b26ef 100644
--- a/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/foreach-thread-to-async.mlir
+++ b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/foreach-thread-to-async.mlir
@@ -9,60 +9,49 @@
#map3 = affine_map<(d0, d1) -> (d0, d1)>
#map4 = affine_map<(d0) -> (d0)>
-module {
- // CHECK-LABEL: func.func @static_tile
- // CHECK-SAME: %[[CHUNK_SIZE:[0-9a-z]+]]: index
- // CHECK-SAME: %[[IN:[0-9a-z]+]]: memref<?xf32>
- // CHECK-SAME: %[[OUT:[0-9a-z]+]]: memref<?xf32>
- func.func @static_tile(%arg0: index, %arg1: memref<?xf32>, %arg2: memref<?xf32>) {
- %cst = arith.constant 4.200000e+01 : f32
- %c0 = arith.constant 0 : index
- %c1 = arith.constant 1 : index
- %0 = memref.dim %arg2, %c0 : memref<?xf32>
- %1 = affine.apply #map0(%0)[%arg0]
+// CHECK-LABEL: func.func @static_tile
+// CHECK-SAME: %[[CHUNK_SIZE:[0-9a-z]+]]: index
+// CHECK-SAME: %[[IN:[0-9a-z]+]]: memref<?xf32>
+// CHECK-SAME: %[[OUT:[0-9a-z]+]]: memref<?xf32>
+func.func @static_tile(%arg0: index, %arg1: memref<?xf32>, %arg2: memref<?xf32>) {
+ %cst = arith.constant 4.200000e+01 : f32
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %0 = memref.dim %arg2, %c0 : memref<?xf32>
+ %1 = affine.apply #map0(%0)[%arg0]
- // CHECK: %[[M:.*]] = memref.dim %{{.*}}, %{{.*}} : memref<?xf32>
- // CHECK: %[[group:.*]] = async.create_group {{.*}}: !async.group
- // CHECK: scf.for %[[IV:.*]] = {{.*}}
- // CHECK: %[[token:.*]] = async.execute {
- // CHECK: subview
- // CHECK: subview
- // CHECK: linalg.generic
- // CHECK: async.yield
- // CHECK: }
- // CHECK: async.add_to_group %[[token]], %[[group]] : !async.token
- // CHECK: }
- // CHECK: async.await_all %[[group]]
- scf.foreach_thread (%arg3) in (%1) shared_outs() -> () {
- %3 = affine.apply #map1(%arg3)[%arg0]
- %4 = affine.apply #map2(%0, %3)
- %5 = affine.min #map3(%4, %arg0)
+ // CHECK: %[[M:.*]] = memref.dim %{{.*}}, %{{.*}} : memref<?xf32>
+ // CHECK: %[[group:.*]] = async.create_group {{.*}}: !async.group
+ // CHECK: scf.for %[[IV:.*]] = {{.*}}
+ // CHECK: %[[token:.*]] = async.execute {
+ // CHECK: subview
+ // CHECK: subview
+ // CHECK: linalg.generic
+ // CHECK: async.yield
+ // CHECK: }
+ // CHECK: async.add_to_group %[[token]], %[[group]] : !async.token
+ // CHECK: }
+ // CHECK: async.await_all %[[group]]
+ scf.foreach_thread (%arg3) in (%1) shared_outs() -> () {
+ %3 = affine.apply #map1(%arg3)[%arg0]
+ %4 = affine.apply #map2(%0, %3)
+ %5 = affine.min #map3(%4, %arg0)
- %6 = memref.subview %arg2[%3] [%5] [%c1] : memref<?xf32> to memref<?xf32, strided<[?], offset:?>>
- %7 = memref.subview %arg1[%3] [%5] [1] : memref<?xf32> to memref<?xf32, strided<[?], offset:?>>
+ %6 = memref.subview %arg2[%3] [%5] [%c1] : memref<?xf32> to memref<?xf32, strided<[?], offset:?>>
+ %7 = memref.subview %arg1[%3] [%5] [1] : memref<?xf32> to memref<?xf32, strided<[?], offset:?>>
- linalg.generic {indexing_maps = [#map4, #map4], iterator_types = ["parallel"]}
- ins(%7 : memref<?xf32, strided<[?], offset:?>>) outs(%6 : memref<?xf32, strided<[?], offset:?>>) {
- ^bb0(%arg4: f32, %arg5: f32): // no predecessors
- %9 = arith.mulf %arg4, %cst : f32
- linalg.yield %9 : f32
- }
- }
- return
+ linalg.generic {indexing_maps = [#map4, #map4], iterator_types = ["parallel"]}
+ ins(%7 : memref<?xf32, strided<[?], offset:?>>) outs(%6 : memref<?xf32, strided<[?], offset:?>>) {
+ ^bb0(%arg4: f32, %arg5: f32): // no predecessors
+ %9 = arith.mulf %arg4, %cst : f32
+ linalg.yield %9 : f32
+ }
}
+ return
+}
- transform.with_pdl_patterns {
- ^bb0(%arg0: !pdl.operation):
- pdl.pattern @match_foreach_thread : benefit(1) {
- %0 = operands
- %1 = types
- %2 = operation "scf.foreach_thread"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
- rewrite %2 with "transform.dialect"
- }
- transform.structured.canonicalized_sequence %arg0 failures(propagate) {
- ^bb1(%arg1: !pdl.operation):
- %0 = pdl_match @match_foreach_thread in %arg1
- %1 = foreach_thread_to_async %0
- }
- }
+transform.structured.canonicalized_sequence failures(propagate) {
+^bb1(%module_op: !pdl.operation):
+ %0 = transform.structured.match ops{["scf.foreach_thread"]} in %module_op
+ %1 = foreach_thread_to_async %0
}
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/foreach-thread-to-scf-for.mlir b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/foreach-thread-to-scf-for.mlir
index d5a3704..a261e9e 100644
--- a/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/foreach-thread-to-scf-for.mlir
+++ b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/foreach-thread-to-scf-for.mlir
@@ -9,55 +9,43 @@
#map3 = affine_map<(d0, d1) -> (d0, d1)>
#map4 = affine_map<(d0) -> (d0)>
-module {
+// CHECK-LABEL: func.func @static_tile_buffers
+// CHECK-SAME: %[[CHUNK_SIZE:[0-9a-z]+]]: index
+// CHECK-SAME: %[[IN:[0-9a-z]+]]: memref<?xf32>
+// CHECK-SAME: %[[OUT:[0-9a-z]+]]: memref<?xf32>
+func.func @static_tile_buffers(%arg0: index, %arg1: memref<?xf32>, %arg2: memref<?xf32>) {
+ %cst = arith.constant 4.200000e+01 : f32
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %0 = memref.dim %arg2, %c0 : memref<?xf32>
+ %1 = affine.apply #map0(%0)[%arg0]
- // CHECK-LABEL: func.func @static_tile_buffers
- // CHECK-SAME: %[[CHUNK_SIZE:[0-9a-z]+]]: index
- // CHECK-SAME: %[[IN:[0-9a-z]+]]: memref<?xf32>
- // CHECK-SAME: %[[OUT:[0-9a-z]+]]: memref<?xf32>
- func.func @static_tile_buffers(%arg0: index, %arg1: memref<?xf32>, %arg2: memref<?xf32>) {
- %cst = arith.constant 4.200000e+01 : f32
- %c0 = arith.constant 0 : index
- %c1 = arith.constant 1 : index
- %0 = memref.dim %arg2, %c0 : memref<?xf32>
- %1 = affine.apply #map0(%0)[%arg0]
+ // CHECK: %[[C1:.*]] = arith.constant 1 : index
+ // CHECK: %[[M:.*]] = memref.dim %{{.*}}, %{{.*}} : memref<?xf32>
+ // CHECK: scf.for %[[IV:.*]] = {{.*}} step %[[C1]] {
+ scf.foreach_thread (%arg3) in (%1) shared_outs() -> () {
+ %3 = affine.apply #map1(%arg3)[%arg0]
+ %4 = affine.apply #map2(%0, %3)
+ %5 = affine.min #map3(%4, %arg0)
- // CHECK: %[[C1:.*]] = arith.constant 1 : index
- // CHECK: %[[M:.*]] = memref.dim %{{.*}}, %{{.*}} : memref<?xf32>
- // CHECK: scf.for %[[IV:.*]] = {{.*}} step %[[C1]] {
- scf.foreach_thread (%arg3) in (%1) shared_outs() -> () {
- %3 = affine.apply #map1(%arg3)[%arg0]
- %4 = affine.apply #map2(%0, %3)
- %5 = affine.min #map3(%4, %arg0)
+ %6 = memref.subview %arg2[%3] [%5] [%c1] : memref<?xf32> to memref<?xf32, strided<[?], offset:?>>
+ %7 = memref.subview %arg1[%3] [%5] [1] : memref<?xf32> to memref<?xf32, strided<[?], offset:?>>
- %6 = memref.subview %arg2[%3] [%5] [%c1] : memref<?xf32> to memref<?xf32, strided<[?], offset:?>>
- %7 = memref.subview %arg1[%3] [%5] [1] : memref<?xf32> to memref<?xf32, strided<[?], offset:?>>
+ linalg.generic {indexing_maps = [#map4, #map4], iterator_types = ["parallel"]}
+ ins(%7 : memref<?xf32, strided<[?], offset:?>>) outs(%6 : memref<?xf32, strided<[?], offset:?>>) {
+ ^bb0(%arg4: f32, %arg5: f32): // no predecessors
+ %9 = arith.mulf %arg4, %cst : f32
+ linalg.yield %9 : f32
+ }
- linalg.generic {indexing_maps = [#map4, #map4], iterator_types = ["parallel"]}
- ins(%7 : memref<?xf32, strided<[?], offset:?>>) outs(%6 : memref<?xf32, strided<[?], offset:?>>) {
- ^bb0(%arg4: f32, %arg5: f32): // no predecessors
- %9 = arith.mulf %arg4, %cst : f32
- linalg.yield %9 : f32
- }
-
- // Nothing is yielded, skip the terminator.
- // CHECK-NOT: scf.yield
- }
- return
+ // Nothing is yielded, skip the terminator.
+ // CHECK-NOT: scf.yield
}
+ return
+}
- transform.with_pdl_patterns {
- ^bb0(%arg0: !pdl.operation):
- pdl.pattern @match_foreach_thread : benefit(1) {
- %0 = operands
- %1 = types
- %2 = operation "scf.foreach_thread"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
- rewrite %2 with "transform.dialect"
- }
- transform.structured.canonicalized_sequence %arg0 failures(propagate) {
- ^bb1(%arg1: !pdl.operation):
- %0 = pdl_match @match_foreach_thread in %arg1
- %1 = foreach_thread_to_scf_for %0
- }
- }
+transform.structured.canonicalized_sequence failures(propagate) {
+^bb1(%module_op: !pdl.operation):
+ %0 = transform.structured.match ops{["scf.foreach_thread"]} in %module_op
+ %1 = foreach_thread_to_scf_for %0
}
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/bufferize.mlir b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/bufferize.mlir
index 0486119..5f4e74e 100644
--- a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/bufferize.mlir
+++ b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/bufferize.mlir
@@ -18,20 +18,7 @@
// CHECK: }
}
-transform.with_pdl_patterns {
-^bb0(%arg0: !pdl.operation):
- pdl.pattern @pdl_target : benefit(1) {
- %args = operands
- %results = types
- %0 = operation "linalg.matmul"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
- %1 = pdl.attribute = @matmul_tensors
- apply_native_constraint "nestedInFunc"(%0, %1 : !pdl.operation, !pdl.attribute)
- // TODO: we don't want this, but it is the required terminator for pdl.pattern
- rewrite %0 with "transform.dialect"
- }
-
- transform.structured.canonicalized_sequence %arg0 failures(propagate) {
- ^bb1(%arg1: !pdl.operation):
- bufferize
- }
+transform.structured.canonicalized_sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+ bufferize
}
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/double-tiling.mlir b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/double-tiling.mlir
deleted file mode 100644
index bfbab68..0000000
--- a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/double-tiling.mlir
+++ /dev/null
@@ -1,89 +0,0 @@
-// RUN: iree-dialects-opt --transform-dialect-interpreter --split-input-file %s | FileCheck %s
-
-// This test is verifying that a non-trivial 2*tiling+padding+vectorization transformation completes successfully
-
-// CHECK-LABEL: func.func @matmul_tensors(
-func.func @matmul_tensors(
- %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32> { linalg.inplaceable = true})
- -> tensor<128x128xf32> {
- // Padding and transpose should be folded into vectorization
- // CHECK-NOT: tensor.pad
- // CHECK-NOT: linalg.generic
-
- // CHECK: scf.for
- // CHECK: scf.for
- // CHECK: scf.for
- // CHECK: scf.for
- // CHECK: scf.for
- // CHECK-NOT: linalg.generic
- // CHECK: vector.contract
- %0 = linalg.matmul ins(%arg0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>)
- outs(%arg2: tensor<128x128xf32>)
- -> tensor<128x128xf32>
-
- return %0 : tensor<128x128xf32>
-}
-
-transform.with_pdl_patterns {
-^bb0(%arg0: !pdl.operation):
- pdl.pattern @pdl_target: benefit(1) {
- %args = operands
- %results = types
- %0 = operation "linalg.matmul"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
- %1 = pdl.attribute = @matmul_tensors
- apply_native_constraint "nestedInFunc"(%0, %1 : !pdl.operation, !pdl.attribute)
- rewrite %0 with "transform.dialect"
- }
-
- transform.structured.canonicalized_sequence %arg0 failures(propagate) {
- ^bb1(%arg1: !pdl.operation):
- %0 = pdl_match @pdl_target in %arg1
- %1, %loops1:3 = transform.structured.tile %0 [32, 32, 32] {interchange = [0, 2, 1]}
- %2, %loops2:3 = transform.structured.tile %1 [4, 4, 1] {interchange = [0, 1, 2]}
- %3 = transform.structured.pad %2 {padding_values=[0.0 : f32, 0.0 : f32, 0.0 : f32], pack_paddings = [1, 1, 1], hoist_paddings = [6, 6, 0], transpose_paddings = [[1, 0], [0, 1]]}
- %4 = transform.get_closest_isolated_parent %3
- transform.structured.vectorize %4 { vectorize_padding = true }
- }
-}
-
-// -----
-
-// CHECK-LABEL: func.func @matmul_tensors_pad(
-func.func @matmul_tensors_pad(
- %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32> { linalg.inplaceable = true})
- -> tensor<128x128xf32> {
- // CHECK: scf.for
- // CHECK: scf.for
- // CHECK: scf.for
- // CHECK: scf.for
- // CHECK: scf.for
- // CHECK: tensor.pad
- // CHECK: vector.contract
- %0 = linalg.matmul ins(%arg0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>)
- outs(%arg2: tensor<128x128xf32>)
- -> tensor<128x128xf32>
-
- return %0 : tensor<128x128xf32>
-}
-
-transform.with_pdl_patterns {
-^bb0(%arg0: !pdl.operation):
- pdl.pattern @pdl_target: benefit(1) {
- %args = operands
- %results = types
- %0 = operation "linalg.matmul"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
- %1 = pdl.attribute = @matmul_tensors_pad
- apply_native_constraint "nestedInFunc"(%0, %1 : !pdl.operation, !pdl.attribute)
- rewrite %0 with "transform.dialect"
- }
-
- transform.structured.canonicalized_sequence %arg0 failures(propagate) {
- ^bb1(%arg1: !pdl.operation):
- %0 = pdl_match @pdl_target in %arg1
- %1, %loops1:3 = transform.structured.tile %0 [32, 32, 32] {interchange = [0, 2, 1]}
- %2, %loops2:3 = transform.structured.tile %1 [4, 4, 1] {interchange = [0, 1, 2]}
- %3 = transform.structured.pad %2 {padding_values=[0.0 : f32, 0.0 : f32, 0.0 : f32], pack_paddings = [1, 1, 1], hoist_paddings = [6, 6, 0], transpose_paddings = [[1, 0], [0, 1]]}
- %4 = transform.get_closest_isolated_parent %3
- transform.structured.vectorize %4 { vectorize_padding = false }
- }
-}
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/expert.mlir b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/expert.mlir
index a191979..3dc056c 100644
--- a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/expert.mlir
+++ b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/expert.mlir
@@ -30,7 +30,7 @@
// EXPAND-NOT: expert apply
// EXPAND: %[[OP:.*]] = match @pdl_target
// EXPAND: %[[HANDLE:.*]], %{{.*}}:3 = tile %[[OP]] {sizes = [4, 4, 4]}
- // EXPAND: %[[HANDLE2:.*]] = vectorize %[[HANDLE]] {vectorize_padding = true}
+ // EXPAND: %[[HANDLE2:.*]] = vectorize %[[HANDLE]] vectorize_padding
// EXPAND: bufferize
// EXPAND: lower_vectors {multireduction_lowering = "innerreduce"}
// EXPAND: lower_to_llvm
@@ -114,7 +114,8 @@
// EXPAND: %[[OP:.*]] = match @pdl_target2
// EXPAND: %[[HANDLE:.*]], %{{.*}}:3 = tile %[[OP]] {sizes = [32, 8, 8]}
// EXPAND: %[[HANDLE2:.*]], %{{.*}}:3 = tile %[[HANDLE]] {sizes = [4, 4, 4]}
- // EXPAND: %[[HANDLE3:.*]] = vectorize %[[HANDLE2]] {vectorize_padding = false}
+ // EXPAND: %[[HANDLE3:.*]] = vectorize %[[HANDLE2]]
+ // EXPAND-NOT: vectorize_padding
// EXPAND: bufferize
// EXPAND: lower_vectors {multireduction_lowering = "innerparallel"}
// EXPAND: lower_to_llvm
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/fuse-and-peel.mlir b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/fuse-and-peel.mlir
deleted file mode 100644
index 62ad6cd..0000000
--- a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/fuse-and-peel.mlir
+++ /dev/null
@@ -1,40 +0,0 @@
-// RUN: iree-dialects-opt --transform-dialect-interpreter %s | FileCheck %s
-
-// CHECK-LABEL: func.func @fuse_unary
-func.func @fuse_unary(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
-
- // CHECK: scf.for
- // CHECK: scf.for
- // CHECK: linalg.elemwise_unary
- // CHECK: linalg.elemwise_binary
- // CHECK: scf.for
- // CHECK: scf.for
- // CHECK: linalg.elemwise_unary
- // CHECK: linalg.elemwise_binary
- %0 = linalg.elemwise_unary ins(%arg0 : tensor<?x?xf32>)
- outs(%arg1: tensor<?x?xf32>) -> tensor<?x?xf32>
- %1 = linalg.elemwise_binary ins(%0, %arg0 : tensor<?x?xf32>, tensor<?x?xf32>)
- outs(%arg1: tensor<?x?xf32>) -> tensor<?x?xf32>
- return %1 : tensor<?x?xf32>
-}
-
-transform.with_pdl_patterns {
-^bb0(%arg0: !pdl.operation):
- pdl.pattern @pdl_target : benefit(1) {
- %args = operands
- %results = types
- %0 = pdl.operation "linalg.elemwise_binary"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
- %1 = pdl.attribute = @fuse_unary
- apply_native_constraint "nestedInFunc"(%0, %1 : !pdl.operation, !pdl.attribute)
- // TODO: we don't want this, but it is the required terminator for pdl.pattern
- rewrite %0 with "transform.dialect"
- }
-
- transform.structured.canonicalized_sequence %arg0 failures(propagate) {
- ^bb1(%arg1: !pdl.operation):
- %0 = pdl_match @pdl_target in %arg1
- %1, %loops:2 = transform.structured.fuse %0 {tile_sizes = [32, 32], tile_interchange = [0, 1]}
-
- transform.loop.peel %loops#0
- }
-}
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/fuse.mlir b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/fuse.mlir
deleted file mode 100644
index 2fd69de..0000000
--- a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/fuse.mlir
+++ /dev/null
@@ -1,35 +0,0 @@
-// RUN: iree-dialects-opt --transform-dialect-interpreter %s | FileCheck %s
-
-
-// CHECK-LABEL: func.func @fuse_unary
-func.func @fuse_unary(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
-
- // CHECK: scf.for
- // CHECK: scf.for
- // CHECK: linalg.elemwise_unary
- // CHECK: linalg.elemwise_binary
- %0 = linalg.elemwise_unary ins(%arg0 : tensor<?x?xf32>)
- outs(%arg1: tensor<?x?xf32>) -> tensor<?x?xf32>
- %1 = linalg.elemwise_binary ins(%0, %arg0 : tensor<?x?xf32>, tensor<?x?xf32>)
- outs(%arg1: tensor<?x?xf32>) -> tensor<?x?xf32>
- return %1 : tensor<?x?xf32>
-}
-
-transform.with_pdl_patterns {
-^bb0(%arg0: !pdl.operation):
- pdl.pattern @pdl_target : benefit(1) {
- %args = operands
- %results = types
- %0 = pdl.operation "linalg.elemwise_binary"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
- %1 = pdl.attribute = @fuse_unary
- apply_native_constraint "nestedInFunc"(%0, %1 : !pdl.operation, !pdl.attribute)
- // TODO: we don't want this, but it is the required terminator for pdl.pattern
- rewrite %0 with "transform.dialect"
- }
-
- transform.structured.canonicalized_sequence %arg0 failures(propagate) {
- ^bb1(%arg1: !pdl.operation):
- %0 = pdl_match @pdl_target in %arg1
- %1, %loops:2 = transform.structured.fuse %0 {tile_sizes = [32, 32], tile_interchange = [0, 1]}
- }
-}
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/generalize.mlir b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/generalize.mlir
deleted file mode 100644
index e8ff91f..0000000
--- a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/generalize.mlir
+++ /dev/null
@@ -1,31 +0,0 @@
-// RUN: iree-dialects-opt --transform-dialect-interpreter %s | FileCheck %s
-
-
-// CHECK-LABEL: func.func @generalize_unary
-func.func @generalize_unary(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
-
- // CHECK-NOT: linalg.elemwise_unary
- // CHECK: linalg.generic
- %0 = linalg.elemwise_unary ins(%arg0 : tensor<?x?xf32>)
- outs(%arg1: tensor<?x?xf32>) -> tensor<?x?xf32>
- return %0 : tensor<?x?xf32>
-}
-
-transform.with_pdl_patterns {
-^bb0(%arg0: !pdl.operation):
- pdl.pattern @pdl_target : benefit(1) {
- %args = operands
- %results = types
- %0 = pdl.operation "linalg.elemwise_unary"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
- %1 = pdl.attribute = @generalize_unary
- apply_native_constraint "nestedInFunc"(%0, %1 : !pdl.operation, !pdl.attribute)
- // TODO: we don't want this, but it is the required terminator for pdl.pattern
- rewrite %0 with "transform.dialect"
- }
-
- transform.structured.canonicalized_sequence %arg0 failures(propagate) {
- ^bb1(%arg1: !pdl.operation):
- %0 = pdl_match @pdl_target in %arg1
- transform.structured.generalize %0
- }
-}
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/interchange.mlir b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/interchange.mlir
deleted file mode 100644
index d126809..0000000
--- a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/interchange.mlir
+++ /dev/null
@@ -1,38 +0,0 @@
-// RUN: iree-dialects-opt --transform-dialect-interpreter %s | FileCheck %s
-
-// CHECK: #[[$MAP:.*]] = affine_map<(d0, d1) -> (d1, d0)>
-
-// CHECK-LABEL: func.func @interchange_generic
-func.func @interchange_generic(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
-
- // CHECK: linalg.generic
- // CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP]]
- %0 = linalg.generic {
- indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
- iterator_types = ["parallel", "parallel"]
- } ins(%arg0 : tensor<?x?xf32>) outs(%arg1 : tensor<?x?xf32>) {
- ^bb0(%arg2: f32, %arg3: f32):
- %1 = math.exp %arg2 : f32
- linalg.yield %1 : f32
- } -> tensor<?x?xf32>
- return %0 : tensor<?x?xf32>
-}
-
-transform.with_pdl_patterns {
-^bb0(%root: !pdl.operation):
- pdl.pattern @pdl_target : benefit(1) {
- %args = operands
- %results = types
- %0 = pdl.operation "linalg.generic"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
- %1 = pdl.attribute = @interchange_generic
- apply_native_constraint "nestedInFunc"(%0, %1 : !pdl.operation, !pdl.attribute)
- // TODO: we don't want this, but it is the required terminator for pdl.pattern
- rewrite %0 with "transform.dialect"
- }
-
- transform.structured.canonicalized_sequence %root failures(propagate) {
- ^bb0(%arg0: !pdl.operation):
- %0 = pdl_match @pdl_target in %arg0
- transform.structured.interchange %0 {iterator_interchange = [1, 0]}
- }
-}
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/pad.mlir b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/pad.mlir
deleted file mode 100644
index 858a01d..0000000
--- a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/pad.mlir
+++ /dev/null
@@ -1,52 +0,0 @@
-// RUN: iree-dialects-opt --transform-dialect-interpreter %s | FileCheck %s
-
-#map = affine_map<()[s0] -> (-s0 + 12, 5)>
-
-// CHECK-LABEL: func.func @pad_unary
-func.func @pad_unary(%arg0: tensor<24x12xf32>,
- %arg1: tensor<24x12xf32>) -> tensor<24x12xf32> {
- // CHECK: %[[C0:.*]] = arith.constant 0 : index
- %c0 = arith.constant 0 : index
- %c12 = arith.constant 12 : index
- %c5 = arith.constant 5 : index
-
- // CHECK: scf.for
- // CHECK: tensor.pad
- // CHECK: linalg.generic
- // CHECK: scf.for
- %0 = scf.for %arg3 = %c0 to %c12 step %c5 iter_args(%arg2 = %arg1) -> (tensor<24x12xf32>) {
- %ts = affine.min #map()[%arg3]
- %1 = tensor.extract_slice %arg0[0, %arg3] [24, %ts] [1, 1] : tensor<24x12xf32> to tensor<24x?xf32>
- %2 = tensor.extract_slice %arg2[0, %arg3] [24, %ts] [1, 1] : tensor<24x12xf32> to tensor<24x?xf32>
-
- // CHECK: linalg.generic
- // CHECK: %[[WIDTH:.*]] = affine.apply
- // CHECK: tensor.pad
- // CHECK-SAME: high[%[[C0]], %[[WIDTH]]]
- // CHECK: linalg.elemwise_unary
- %3 = linalg.elemwise_unary ins(%1 : tensor<24x?xf32>)
- outs(%2: tensor<24x?xf32>) -> tensor<24x?xf32>
- %4 = tensor.insert_slice %3 into %arg2[0, %arg3] [24, %ts] [1, 1] : tensor<24x?xf32> into tensor<24x12xf32>
- scf.yield %4 : tensor<24x12xf32>
- }
- return %0 : tensor<24x12xf32>
-}
-
-transform.with_pdl_patterns {
-^bb0(%arg0: !pdl.operation):
- pdl.pattern @pdl_target : benefit(1) {
- %args = operands
- %results = types
- %0 = pdl.operation "linalg.elemwise_unary"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
- %1 = pdl.attribute = @pad_unary
- apply_native_constraint "nestedInFunc"(%0, %1 : !pdl.operation, !pdl.attribute)
- // TODO: we don't want this, but it is the required terminator for pdl.pattern
- rewrite %0 with "transform.dialect"
- }
-
- transform.structured.canonicalized_sequence %arg0 failures(propagate) {
- ^bb1(%arg1: !pdl.operation):
- %0 = pdl_match @pdl_target in %arg1
- %1 = transform.structured.pad %0 {padding_values=[0.0 : f32, 0.0 : f32], padding_dimensions=[1], pack_paddings=[1, 1], hoist_paddings=[1, 0], transpose_paddings=[[1, 0], [0, 1]]}
- }
-}
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/peel.mlir b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/peel.mlir
deleted file mode 100644
index c736220..0000000
--- a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/peel.mlir
+++ /dev/null
@@ -1,53 +0,0 @@
-// RUN: iree-dialects-opt --transform-dialect-interpreter %s | FileCheck %s
-
-
-// CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0, s1, s2] -> (s1 - (-s0 + s1) mod s2)>
-// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0)[s0] -> (-d0 + s0)>
-// CHECK: func.func @fully_dynamic_bounds(
-// CHECK-SAME: %[[LB:.*]]: index, %[[UB:.*]]: index, %[[STEP:.*]]: index
-// CHECK: %[[C0_I32:.*]] = arith.constant 0 : i32
-// CHECK: %[[NEW_UB:.*]] = affine.apply #[[MAP0]]()[%[[LB]], %[[UB]], %[[STEP]]]
-// CHECK: %[[CAST:.*]] = arith.index_cast %[[STEP]] : index to i32
-// CHECK: %[[LOOP:.*]] = scf.for %[[IV:.*]] = %[[LB]] to %[[NEW_UB]]
-// CHECK-SAME: step %[[STEP]] iter_args(%[[ACC:.*]] = %[[C0_I32]]) -> (i32) {
-// CHECK: %[[ADD:.*]] = arith.addi %[[ACC]], %[[CAST]] : i32
-// CHECK: scf.yield %[[ADD]]
-// CHECK: }
-// CHECK: %[[RESULT:.*]] = scf.for %[[IV2:.*]] = %[[NEW_UB]] to %[[UB]]
-// CHECK-SAME: step %[[STEP]] iter_args(%[[ACC2:.*]] = %[[LOOP]]) -> (i32) {
-// CHECK: %[[REM:.*]] = affine.apply #[[MAP1]](%[[IV2]])[%[[UB]]]
-// CHECK: %[[CAST2:.*]] = arith.index_cast %[[REM]]
-// CHECK: %[[ADD2:.*]] = arith.addi %[[ACC2]], %[[CAST2]]
-// CHECK: scf.yield %[[ADD2]]
-// CHECK: }
-// CHECK: return %[[RESULT]]
-#map = affine_map<(d0, d1)[s0] -> (s0, d0 - d1)>
-func.func @fully_dynamic_bounds(%lb : index, %ub: index, %step: index) -> i32 {
- %c0 = arith.constant 0 : i32
- %r = scf.for %iv = %lb to %ub step %step iter_args(%arg = %c0) -> i32 {
- %s = affine.min #map(%ub, %iv)[%step]
- %casted = arith.index_cast %s : index to i32
- %0 = arith.addi %arg, %casted : i32
- scf.yield %0 : i32
- }
- return %r : i32
-}
-
-transform.with_pdl_patterns {
-^bb0(%arg0: !pdl.operation):
- pdl.pattern @pdl_target : benefit(1) {
- %args = operands
- %results = types
- %0 = pdl.operation "scf.for"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
- %1 = pdl.attribute = @fully_dynamic_bounds
- apply_native_constraint "nestedInFunc"(%0, %1 : !pdl.operation, !pdl.attribute)
- // TODO: we don't want this, but it is the required terminator for pdl.pattern
- rewrite %0 with "transform.dialect"
- }
-
- transform.structured.canonicalized_sequence %arg0 failures(propagate) {
- ^bb1(%arg1: !pdl.operation):
- %0 = pdl_match @pdl_target in %arg1
- transform.loop.peel %0
- }
-}
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/roundtrip.mlir b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/roundtrip.mlir
index 547ab70..54b6dfd 100644
--- a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/roundtrip.mlir
+++ b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/roundtrip.mlir
@@ -11,8 +11,8 @@
%2, %loops2:3 = transform.structured.tile %1 [2, 2, 2]
// CHECK: %[[PADDED:.*]] = transform.structured.pad %[[TILED2]] {hoist_paddings = [], pack_paddings = [1, 1, 0], padding_dimensions = [], padding_values = [], transpose_paddings = []}
%3 = transform.structured.pad %2 {pack_paddings = [1, 1, 0]}
- // CHECK: %{{.*}} = transform.structured.vectorize %[[PADDED]] {vectorize_padding = true}
- %4 = transform.structured.vectorize %3 {vectorize_padding = true}
+ // CHECK: %{{.*}} = transform.structured.vectorize %[[PADDED]] {vectorize_padding}
+ %4 = transform.structured.vectorize %3 { vectorize_padding }
// CHECK: %[[OPS2:.*]] = pdl_match @{{.*}}
%5 = pdl_match @match2 in %arg0
// CHECK: transform.structured.vectorize %[[OPS2]]
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/scalarize.mlir b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/scalarize.mlir
deleted file mode 100644
index 9fc49c2..0000000
--- a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/scalarize.mlir
+++ /dev/null
@@ -1,37 +0,0 @@
-// TODO(#9510): Enable the test.
-// RUN: iree-dialects-opt --transform-dialect-interpreter %s | FileCheck %s
-
-func.func @fun_to_benchmark(%arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32>) ->
- tensor<128x128xf32> attributes {passthrough = ["noinline", ["target-cpu", "skylake-avx512"], ["prefer-vector-width", "512"]]} {
- // With scalarization we expect vectorization to still work albeit with a leading
- // `1` dimension.
- // CHECK: vector.contract {{.*}} : vector<1x32xf32>, vector<32x16xf32> into vector<1x16xf32>
- %0 = linalg.matmul ins(%arg0, %arg1 : tensor<128x128xf32>, tensor<128x128xf32>)
- outs(%arg2 : tensor<128x128xf32>) -> tensor<128x128xf32>
- return %0 : tensor<128x128xf32>
-}
-
-transform.with_pdl_patterns {
-^bb0(%arg0: !pdl.operation):
- pdl.pattern @isa_linalg.matmul : benefit(1) {
- %0 = operands
- %1 = types
- %2 = operation "linalg.matmul"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
- rewrite %2 with "transform.dialect"
- }
-
- transform.structured.canonicalized_sequence %arg0 failures(propagate) {
- ^bb1(%arg1: !pdl.operation):
- %0 = pdl_match @isa_linalg.matmul in %arg1
- %tiled_linalg_op, %loops:3 = transform.structured.tile %0 [6, 16, 32] {interchange = [1, 0, 2]}
- %1 = transform.loop.peel %loops#0
-
- %tiled_and_peeled_linalg_op = pdl_match @isa_linalg.matmul in %1
- // This test checks the proper handling of the scalarize dims attribute.
- // The first dimension does not divide but we can always scalarize a `?` into `1`
- // and enable vectorization of a lower-rank op this way.
- %tiled_and_peeled_linalg_op_0 = transform.structured.scalarize %tiled_and_peeled_linalg_op
- %parent = transform.get_closest_isolated_parent %tiled_and_peeled_linalg_op_0
- transform.structured.vectorize %parent {vectorize_padding = false}
- }
-}
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/single-tiling-full-script.mlir b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/single-tiling-full-script.mlir
index 45f889a..7b93cbe 100644
--- a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/single-tiling-full-script.mlir
+++ b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/single-tiling-full-script.mlir
@@ -13,27 +13,13 @@
return %0 : tensor<128x128xf32>
}
-
-transform.with_pdl_patterns {
-^bb0(%arg0: !pdl.operation):
- pdl.pattern @pdl_target : benefit(1) {
- %args = operands
- %results = types
- %0 = pdl.operation "linalg.matmul"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
- %1 = pdl.attribute = @matmul_tensors
- apply_native_constraint "nestedInFunc"(%0, %1 : !pdl.operation, !pdl.attribute)
- // TODO: we don't want this, but it is the required terminator for pdl.pattern
- rewrite %0 with "transform.dialect"
- }
-
- transform.structured.canonicalized_sequence %arg0 failures(propagate) {
- ^bb1(%arg1: !pdl.operation):
- %0 = pdl_match @pdl_target in %arg1
- %1, %loops:3 = transform.structured.tile %0 [4, 4, 4]
- %2 = get_closest_isolated_parent %1
- transform.structured.vectorize %2 {vectorize_padding = true}
- bufferize
- lower_vectors { multireduction_lowering = "innerreduce"}
- lower_to_llvm
- }
+transform.structured.canonicalized_sequence failures(propagate) {
+^bb1(%module_op: !pdl.operation):
+ %0 = transform.structured.match ops{["linalg.matmul"]} in %module_op
+ %1, %loops:3 = transform.structured.tile %0 [4, 4, 4]
+ %2 = get_closest_isolated_parent %1
+ transform.structured.vectorize %2 { vectorize_padding }
+ bufferize
+ lower_vectors { multireduction_lowering = "innerreduce"}
+ lower_to_llvm
}
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/tile-and-peel.mlir b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/tile-and-peel.mlir
deleted file mode 100644
index 68deb28..0000000
--- a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/tile-and-peel.mlir
+++ /dev/null
@@ -1,55 +0,0 @@
-// RUN: iree-dialects-opt --transform-dialect-interpreter %s | FileCheck %s
-
-// CHECK-LABEL: func.func @matmul_tensors(
-func.func @matmul_tensors(
- %arg0: tensor<126x127xf32>, %arg1: tensor<127x128xf32>, %arg2: tensor<126x128xf32> { linalg.inplaceable = true})
- -> tensor<126x128xf32> {
- // CHECK-DAG: %[[c4:.*]] = arith.constant 4 : index
- // CHECK-DAG: %[[c124:.*]] = arith.constant 124 : index
- // CHECK-DAG: %[[c128:.*]] = arith.constant 128 : index
-
- // CHECK: scf.for {{.*}} to %[[c124]] step %[[c4]]
- // CHECK: scf.for {{.*}} to %[[c128]] step %[[c4]]
- // CHECK: scf.for {{.*}} to %[[c124]] step %[[c4]]
- // CHECK: linalg.matmul ins({{.*}} : tensor<4x4xf32>, tensor<4x4xf32>) outs({{.*}} : tensor<4x4xf32>) -> tensor<4x4xf32>
- // CHECK: linalg.matmul ins({{.*}} : tensor<4x3xf32>, tensor<3x4xf32>) outs({{.*}} : tensor<4x4xf32>) -> tensor<4x4xf32>
- // CHECK: scf.for {{.*}} to %[[c128]] step %[[c4]]
- // CHECK: scf.for {{.*}} to %[[c124]] step %[[c4]]
- // CHECK: linalg.matmul ins({{.*}} : tensor<2x4xf32>, tensor<4x4xf32>) outs({{.*}} : tensor<2x4xf32>) -> tensor<2x4xf32>
- // CHECK: linalg.matmul ins({{.*}} : tensor<2x3xf32>, tensor<3x4xf32>) outs({{.*}} : tensor<2x4xf32>) -> tensor<2x4xf32>
- %0 = linalg.matmul ins(%arg0, %arg1: tensor<126x127xf32>, tensor<127x128xf32>)
- outs(%arg2: tensor<126x128xf32>)
- -> tensor<126x128xf32>
-
- return %0 : tensor<126x128xf32>
-}
-
-
-transform.with_pdl_patterns {
-^bb0(%arg0: !pdl.operation):
- pdl.pattern @pdl_target : benefit(1) {
- %args = operands
- %results = types
- %0 = operation "linalg.matmul"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
- %1 = pdl.attribute = @matmul_tensors
- apply_native_constraint "nestedInFunc"(%0, %1 : !pdl.operation, !pdl.attribute)
- // TODO: we don't want this, but it is the required terminator for pdl.pattern
- rewrite %0 with "transform.dialect"
- }
-
- transform.structured.canonicalized_sequence %arg0 failures(propagate) {
- ^bb1(%arg1: !pdl.operation):
- %0 = pdl_match @pdl_target in %arg1
- %linalg_op, %loops:3 = transform.structured.tile %0 [4, 4, 4]
-
- // Note: The order in which the loops are peeled is important. If %loop#2 is
- // peeled first, the partial iteration of %loop#0 also contains a peeled
- // version of %loop#2.
- // Peeling #0 first is currently not possible as it will invalidate all the
- // nested handles.
- // TODO: extra arguments to specify parts of IR that should not be
- // invalidated when we know that the transform updates in-place.
- transform.loop.peel %loops#2
- transform.loop.peel %loops#0
- }
-}
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/tile-interchange.mlir b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/tile-interchange.mlir
deleted file mode 100644
index 121dc1c..0000000
--- a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/tile-interchange.mlir
+++ /dev/null
@@ -1,83 +0,0 @@
-// RUN: iree-dialects-opt --transform-dialect-interpreter --split-input-file %s | FileCheck %s
-
-#map0 = affine_map<(d0, d1, d2) -> (d0, d2)>
-#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
-#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
-
-// Check that vectorization applies after interchange+tiling.
-
-// CHECK-LABEL: @matmul_021
-// CHECK-NOT: linalg.generic
-// CHECK: vector.contract
-func.func public @matmul_021(%arg0: tensor<39x154xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, %arg1: tensor<154x5xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, %arg2: tensor<39x5xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = true}) -> tensor<39x5xf32> attributes {passthrough = ["noinline", ["target-cpu", "skylake-avx512"], ["prefer-vector-width", "512"]]} {
- %0 = linalg.generic {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<39x154xf32>, tensor<154x5xf32>) outs(%arg2 : tensor<39x5xf32>) {
- ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
- %1 = arith.mulf %arg3, %arg4 : f32
- %2 = arith.addf %arg5, %1 : f32
- linalg.yield %2 : f32
- } -> tensor<39x5xf32>
- return %0 : tensor<39x5xf32>
-}
-
-transform.with_pdl_patterns {
-^bb0(%arg0: !pdl.operation):
- pdl.pattern @target_pattern : benefit(1) {
- %0 = operands
- %1 = types
- %2 = operation "linalg.generic"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
- %3 = pdl.attribute = @matmul_021
- apply_native_constraint "nestedInFunc"(%2, %3 : !pdl.operation, !pdl.attribute)
- rewrite %2 with "transform.dialect"
- }
-
- transform.structured.canonicalized_sequence %arg0 failures(propagate) {
- ^bb1(%arg1: !pdl.operation):
- %0 = pdl_match @target_pattern in %arg1
- %1, %loops1:3 = transform.structured.tile %0 [3, 5, 14] {interchange = [0, 2, 1]}
- %2, %loops2:3 = transform.structured.tile %1 [3, 5, 2]
- %3 = get_closest_isolated_parent %2
- transform.structured.vectorize %3 {vectorize_padding = true}
- }
-}
-
-// -----
-
-#map0 = affine_map<(d0, d1, d2) -> (d0, d2)>
-#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
-#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
-
-// Check that vectorization applies after interchange+tiling.
-
-// CHECK-LABEL: @matmul_210
-// CHECK-NOT: linalg.generic
-// CHECK: vector.contract
-func.func public @matmul_210(%arg0: tensor<39x154xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, %arg1: tensor<154x5xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, %arg2: tensor<39x5xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = true}) -> tensor<39x5xf32> attributes {passthrough = ["noinline", ["target-cpu", "skylake-avx512"], ["prefer-vector-width", "512"]]} {
- %0 = linalg.generic {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<39x154xf32>, tensor<154x5xf32>) outs(%arg2 : tensor<39x5xf32>) {
- ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
- %1 = arith.mulf %arg3, %arg4 : f32
- %2 = arith.addf %arg5, %1 : f32
- linalg.yield %2 : f32
- } -> tensor<39x5xf32>
- return %0 : tensor<39x5xf32>
-}
-
-transform.with_pdl_patterns {
-^bb0(%arg0: !pdl.operation):
- pdl.pattern @target_pattern : benefit(1) {
- %0 = operands
- %1 = types
- %2 = operation "linalg.generic"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
- %3 = pdl.attribute = @matmul_210
- apply_native_constraint "nestedInFunc"(%2, %3 : !pdl.operation, !pdl.attribute)
- rewrite %2 with "transform.dialect"
- }
-
- transform.structured.canonicalized_sequence %arg0 failures(propagate) {
- ^bb1(%arg1: !pdl.operation):
- %0 = pdl_match @target_pattern in %arg1
- %1, %loops1:3 = transform.structured.tile %0 [3, 5, 14] {interchange = [2, 1, 0]}
- %2, %loops2:3 = transform.structured.tile %1 [3, 5, 2]
- %3 = get_closest_isolated_parent %2
- transform.structured.vectorize %3 {vectorize_padding = true}
- }
-}
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/tile.mlir b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/tile.mlir
deleted file mode 100644
index 86b11ec..0000000
--- a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/tile.mlir
+++ /dev/null
@@ -1,49 +0,0 @@
-// RUN: iree-dialects-opt --transform-dialect-interpreter %s | FileCheck %s
-
-// CHECK-LABEL: func.func @matmul_tensors(
-// CHECK-SAME: %[[TA:[0-9a-z]+]]: tensor<128x128xf32>
-// CHECK-SAME: %[[TB:[0-9a-z]+]]: tensor<128x128xf32>
-// CHECK-SAME: %[[TC:[0-9a-z]+]]: tensor<128x128xf32>
-// CHECK-SAME: -> tensor<128x128xf32> {
-func.func @matmul_tensors(
- %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32> { linalg.inplaceable = true})
- -> tensor<128x128xf32> {
-// CHECK: %[[TD0:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC0:.*]] = %[[TC]]) -> (tensor<128x128xf32>) {
-// CHECK: %[[TD1:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC1:.*]] = %[[TC0]]) -> (tensor<128x128xf32>) {
-// CHECK: %[[TD2:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC2:.*]] = %[[TC1]]) -> (tensor<128x128xf32>) {
-// CHECK: %[[sTA:.*]] = tensor.extract_slice %[[TA]][{{.*}}] : tensor<128x128xf32> to tensor<4x4xf32>
-// CHECK: %[[sTB:.*]] = tensor.extract_slice %[[TB]][{{.*}}] : tensor<128x128xf32> to tensor<4x4xf32>
-// CHECK: %[[sTC:.*]] = tensor.extract_slice %[[TC2]][{{.*}}] : tensor<128x128xf32> to tensor<4x4xf32>
-// CHECK: %[[sTD:.*]] = linalg.matmul ins(%[[sTA]], %[[sTB]] : tensor<4x4xf32>, tensor<4x4xf32>)
-// CHECK-SAME: outs(%[[sTC]] : tensor<4x4xf32>) -> tensor<4x4xf32>
-// CHECK: %[[TD:.*]] = tensor.insert_slice %[[sTD]] into %[[TC2]][{{.*}}] : tensor<4x4xf32> into tensor<128x128xf32>
-// CHECK: scf.yield %[[TD]] : tensor<128x128xf32>
-// CHECK: scf.yield %[[TD2]] : tensor<128x128xf32>
-// CHECK: scf.yield %[[TD1]] : tensor<128x128xf32>
- %0 = linalg.matmul ins(%arg0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>)
- outs(%arg2: tensor<128x128xf32>)
- -> tensor<128x128xf32>
-
-// CHECK: return %[[TD0]] : tensor<128x128xf32>
- return %0 : tensor<128x128xf32>
-}
-
-transform.with_pdl_patterns {
-^bb0(%arg0: !pdl.operation):
- pdl.pattern @pdl_target : benefit(1) {
- %args = operands
- %results = types
- %0 = operation "linalg.matmul"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
- %1 = pdl.attribute = @matmul_tensors
- apply_native_constraint "nestedInFunc"(%0, %1 : !pdl.operation, !pdl.attribute)
- // TODO: we don't want this, but it is the required terminator for pdl.pattern
- rewrite %0 with "transform.dialect"
- }
-
- transform.structured.canonicalized_sequence %arg0 failures(propagate) {
- ^bb1(%arg1: !pdl.operation):
- %0 = pdl_match @pdl_target in %arg1
- %1, %loops:3 = transform.structured.tile %0 [4, 4, 4]
- print %1 {name = "Tiled"}
- }
-}
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/vectorize-transforms.mlir b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/vectorize-transforms.mlir
deleted file mode 100644
index a025273..0000000
--- a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/vectorize-transforms.mlir
+++ /dev/null
@@ -1,22 +0,0 @@
-// This test only checks the content of the file parses.
-// RUN: iree-dialects-opt %s
-
-transform.with_pdl_patterns {
-^bb0(%arg0: !pdl.operation):
- pdl.pattern @pdl_target : benefit(1) {
- %args = operands
- %results = types
- %0 = operation "linalg.matmul"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
- %1 = pdl.attribute = @matmul_tensors
- apply_native_constraint "nestedInFunc"(%0, %1 : !pdl.operation, !pdl.attribute)
- // TODO: we don't want this, but it is the required terminator for pdl.pattern
- rewrite %0 with "transform.dialect"
- }
-
- transform.structured.canonicalized_sequence %arg0 failures(propagate) {
- ^bb1(%arg1: !pdl.operation):
- %0 = pdl_match @pdl_target in %arg1
- %1 = get_closest_isolated_parent %0
- transform.structured.vectorize %1 {vectorize_padding = true}
- }
-}
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/vectorize.mlir b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/vectorize.mlir
deleted file mode 100644
index c932cdf..0000000
--- a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/vectorize.mlir
+++ /dev/null
@@ -1,43 +0,0 @@
-// RUN: iree-dialects-opt --transform-dialect-interpreter=transform-file-name=%p/vectorize-transforms.mlir %s | FileCheck %s
-
-// CHECK-LABEL: func.func @matmul_tensors(
-// CHECK-SAME: %[[TA:[0-9a-z]+]]: tensor<128x128xf32>
-// CHECK-SAME: %[[TB:[0-9a-z]+]]: tensor<128x128xf32>
-// CHECK-SAME: %[[TC:[0-9a-z]+]]: tensor<128x128xf32>
-// CHECK-SAME: -> tensor<128x128xf32> {
-func.func @matmul_tensors(
- %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32> { linalg.inplaceable = true})
- -> tensor<128x128xf32> {
- // CHECK: %[[VA:.*]] = vector.transfer_read %[[TA]]
- // CHECK: %[[VB:.*]] = vector.transfer_read %[[TB]]
- // CHECK: %[[VC:.*]] = vector.transfer_read %[[TC]]
- // CHECK: %[[VCU:.*]] = vector.contract {{.*}} %[[VA]], %[[VB]], %[[VC]]
- // CHECK: vector.transfer_write %[[VCU]], %[[TC]]
- %0 = linalg.matmul ins(%arg0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>)
- outs(%arg2: tensor<128x128xf32>)
- -> tensor<128x128xf32>
-
- return %0 : tensor<128x128xf32>
-}
-
-// Some dummy functions to exercise TSAN under parallelism.
-func.func @foo1() -> index {
- %0 = arith.constant 1 : index
- return %0 : index
-}
-func.func @foo2() -> index {
- %0 = arith.constant 2 : index
- return %0 : index
-}
-func.func @foo3() -> index {
- %0 = arith.constant 3 : index
- return %0 : index
-}
-func.func @foo4() -> index {
- %0 = arith.constant 4 : index
- return %0 : index
-}
-func.func @foo5() -> index {
- %0 = arith.constant 5 : index
- return %0 : index
-}
diff --git a/runtime/src/iree/base/internal/BUILD b/runtime/src/iree/base/internal/BUILD
index 52c9153..4e3331e 100644
--- a/runtime/src/iree/base/internal/BUILD
+++ b/runtime/src/iree/base/internal/BUILD
@@ -146,6 +146,7 @@
iree_runtime_cc_test(
name = "file_io_test",
srcs = ["file_io_test.cc"],
+ tags = ["requires-filesystem"],
deps = [
":file_io",
"//runtime/src/iree/base:cc",
@@ -212,6 +213,7 @@
iree_runtime_cc_test(
name = "fpu_state_test",
srcs = ["fpu_state_test.cc"],
+ tags = ["requires-dtz"],
deps = [
":fpu_state",
"//runtime/src/iree/testing:gtest",
diff --git a/runtime/src/iree/base/internal/CMakeLists.txt b/runtime/src/iree/base/internal/CMakeLists.txt
index 5e10a79..72c4b7c 100644
--- a/runtime/src/iree/base/internal/CMakeLists.txt
+++ b/runtime/src/iree/base/internal/CMakeLists.txt
@@ -150,6 +150,8 @@
iree::base::core_headers
iree::testing::gtest
iree::testing::gtest_main
+ LABELS
+ "requires-filesystem"
)
iree_cc_library(
@@ -225,6 +227,8 @@
::fpu_state
iree::testing::gtest
iree::testing::gtest_main
+ LABELS
+ "requires-dtz"
)
iree_cc_library(
diff --git a/runtime/src/iree/base/testing/CMakeLists.txt b/runtime/src/iree/base/testing/CMakeLists.txt
index ea2b4c1..5f3f797 100644
--- a/runtime/src/iree/base/testing/CMakeLists.txt
+++ b/runtime/src/iree/base/testing/CMakeLists.txt
@@ -42,4 +42,6 @@
iree::base::internal::file_io
iree::testing::gtest
iree::testing::gtest_main
+ LABELS
+ "requires-filesystem"
)
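
These labels/tags let environments that lack a writable filesystem, or that cannot toggle the FPU's DTZ (presumably denormals-treated-as-zero) mode, exclude the affected tests via standard filtering: ctest -LE requires-filesystem on the CMake side, or bazel test --test_tag_filters=-requires-filesystem,-requires-dtz on the Bazel side. The CI configuration that consumes these filters is outside this diff.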
diff --git a/runtime/src/iree/builtins/ukernel/BUILD b/runtime/src/iree/builtins/ukernel/BUILD
index 616dd2a..6b24498 100644
--- a/runtime/src/iree/builtins/ukernel/BUILD
+++ b/runtime/src/iree/builtins/ukernel/BUILD
@@ -12,31 +12,64 @@
licenses = ["notice"], # Apache 2.0
)
+# :types contains the type declarations used by both the entry points and the
+# internal implementation functions.
iree_runtime_cc_library(
- name = "ukernel",
- srcs = [
- "elementwise_generic.c",
- "elementwise_impl.c.inc",
- "mmt4d.c",
- "mmt4d_arm_64.c",
- "mmt4d_generic.c",
- ],
+ name = "types",
hdrs = [
"common.h",
- "elementwise.h",
- "mmt4d.h",
- "mmt4d_arm_64.h",
- "mmt4d_generic.h",
- ],
- copts = [
- # Placeholder for a real flag.
- "-DIREE_UKERNEL_PLATFORM_EXAMPLE_FLAG=1",
- ],
- defines = [
- "IREE_HAVE_UKERNEL_BUILTINS=1",
+ "mmt4d_types.h",
],
deps = [
"//runtime/src/iree/base:core_headers",
- "//runtime/src/iree/schemas:cpu_data",
+ ],
+)
+
+# :generic contains non-architecture-specific implementations.
+iree_runtime_cc_library(
+ name = "generic",
+ srcs = [
+ "mmt4d_tile_generic.c",
+ ],
+ hdrs = [
+ "mmt4d_tile_generic.h",
+ ],
+ deps = [
+ ":types",
+ ],
+)
+
+# Elementwise code is structured differently from other kernels. In fact it's
+# profoundly different: it carries its own custom shims. For now, we keep it
+# separate from the rest.
+iree_runtime_cc_library(
+ name = "elementwise",
+ srcs = [
+ "elementwise_generic.c",
+ "elementwise_impl.c.inc",
+ ],
+ hdrs = [
+ "elementwise.h",
+ ],
+ deps = [
+ ":types",
+ ],
+)
+
+# Entry points.
+iree_runtime_cc_library(
+ name = "ukernel",
+ srcs = [
+ "mmt4d.c",
+ ],
+ hdrs = [
+ "elementwise.h",
+ "mmt4d.h",
+ ],
+ deps = [
+ ":elementwise",
+ ":generic",
+ ":types",
+ "//runtime/src/iree/builtins/ukernel/arm_64:mmt4d_tile_arm_64",
],
)
diff --git a/runtime/src/iree/builtins/ukernel/CMakeLists.txt b/runtime/src/iree/builtins/ukernel/CMakeLists.txt
index f514b2f..1367b42 100644
--- a/runtime/src/iree/builtins/ukernel/CMakeLists.txt
+++ b/runtime/src/iree/builtins/ukernel/CMakeLists.txt
@@ -12,26 +12,53 @@
iree_cc_library(
NAME
- ukernel
- COPTS
- "-DIREE_UKERNEL_PLATFORM_EXAMPLE_FLAG=1"
+ types
HDRS
"common.h"
+ "mmt4d_types.h"
+ DEPS
+ iree::base::core_headers
+ PUBLIC
+)
+
+iree_cc_library(
+ NAME
+ generic
+ HDRS
+ "mmt4d_tile_generic.h"
+ SRCS
+ "mmt4d_tile_generic.c"
+ DEPS
+ ::types
+ PUBLIC
+)
+
+iree_cc_library(
+ NAME
+ elementwise
+ HDRS
"elementwise.h"
- "mmt4d.h"
- "mmt4d_arm_64.h"
- "mmt4d_generic.h"
SRCS
"elementwise_generic.c"
"elementwise_impl.c.inc"
- "mmt4d.c"
- "mmt4d_arm_64.c"
- "mmt4d_generic.c"
DEPS
- iree::base::core_headers
- iree::schemas::cpu_data
- DEFINES
- "IREE_HAVE_UKERNEL_BUILTINS=1"
+ ::types
+ PUBLIC
+)
+
+iree_cc_library(
+ NAME
+ ukernel
+ HDRS
+ "elementwise.h"
+ "mmt4d.h"
+ SRCS
+ "mmt4d.c"
+ DEPS
+ ::elementwise
+ ::generic
+ ::types
+ iree::builtins::ukernel::arm_64::mmt4d_tile_arm_64
PUBLIC
)
diff --git a/runtime/src/iree/builtins/ukernel/arm_64/BUILD b/runtime/src/iree/builtins/ukernel/arm_64/BUILD
new file mode 100644
index 0000000..03e1f87
--- /dev/null
+++ b/runtime/src/iree/builtins/ukernel/arm_64/BUILD
@@ -0,0 +1,20 @@
+# Copyright 2022 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+load("//build_tools/bazel:build_defs.oss.bzl", "iree_runtime_cc_library")
+
+package(
+ default_visibility = ["//visibility:public"],
+ features = ["layering_check"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+iree_runtime_cc_library(
+ name = "mmt4d_tile_arm_64",
+ hdrs = [
+ "mmt4d_tile_arm_64.h",
+ ],
+)
diff --git a/runtime/src/iree/builtins/ukernel/arm_64/CMakeLists.txt b/runtime/src/iree/builtins/ukernel/arm_64/CMakeLists.txt
new file mode 100644
index 0000000..d8aab27
--- /dev/null
+++ b/runtime/src/iree/builtins/ukernel/arm_64/CMakeLists.txt
@@ -0,0 +1,55 @@
+if ((CMAKE_SYSTEM_PROCESSOR STREQUAL aarch64) OR (CMAKE_SYSTEM_PROCESSOR STREQUAL arm64))
+
+check_cxx_compiler_flag("-march=armv8.2-a+dotprod" HAVE_FLAG_MARCH_ARMV8_2_A_DOTPROD)
+if(HAVE_FLAG_MARCH_ARMV8_2_A_DOTPROD)
+ iree_cc_library(
+ NAME
+ mmt4d_tile_arm_64_dotprod
+ SRCS
+ "mmt4d_tile_arm_64_dotprod.S"
+ COPTS
+ "-march=armv8.2-a+dotprod"
+ )
+ list(APPEND MMT4D_ARM_64_VARIANT_DEPS "iree::builtins::ukernel::arm_64::mmt4d_tile_arm_64_dotprod")
+endif()
+
+check_cxx_compiler_flag("-march=armv8.2-a+i8mm" HAVE_FLAG_MARCH_ARMV8_2_A_I8MM)
+if(HAVE_FLAG_MARCH_ARMV8_2_A_I8MM)
+ iree_cc_library(
+ NAME
+ mmt4d_tile_arm_64_i8mm
+ SRCS
+ "mmt4d_tile_arm_64_i8mm.S"
+ COPTS
+ "-march=armv8.2-a+i8mm"
+ )
+ list(APPEND MMT4D_ARM_64_VARIANT_DEPS "iree::builtins::ukernel::arm_64::mmt4d_tile_arm_64_i8mm")
+endif()
+
+configure_file(config.h.in config.h)
+
+iree_cc_library(
+ NAME
+ mmt4d_tile_arm_64
+ HDRS
+ "mmt4d_tile_arm_64.h"
+ SRCS
+ "mmt4d_tile_arm_64.c"
+ DEPS
+ iree::base::core_headers
+ iree::schemas::cpu_data
+ iree::builtins::ukernel::types
+ ${MMT4D_ARM_64_VARIANT_DEPS}
+ PUBLIC
+)
+
+else() # Not CMAKE_SYSTEM_PROCESSOR STREQUAL aarch64 / arm64
+
+iree_cc_library(
+ NAME
+ mmt4d_tile_arm_64
+ HDRS
+ "mmt4d_tile_arm_64.h"
+)
+
+endif()
diff --git a/runtime/src/iree/builtins/ukernel/arm_64/assembly.h b/runtime/src/iree/builtins/ukernel/arm_64/assembly.h
new file mode 100644
index 0000000..03537f8
--- /dev/null
+++ b/runtime/src/iree/builtins/ukernel/arm_64/assembly.h
@@ -0,0 +1,49 @@
+// Borrowed from XNNPACK's assembly.h (thanks!)
+// clang-format off
+#ifdef __wasm__
+ .macro BEGIN_FUNCTION name
+ .text
+ .section .text.\name,"",@
+ .hidden \name
+ .globl \name
+ .type \name,@function
+ \name:
+ .endm
+
+ .macro END_FUNCTION name
+ end_function
+ .endm
+#elif defined(__ELF__)
+ .macro BEGIN_FUNCTION name
+ .text
+ .p2align 4
+ .global \name
+ .hidden \name
+ .type \name, %function
+ \name:
+ .endm
+
+ .macro END_FUNCTION name
+ .size \name, .-\name
+ .endm
+#elif defined(__MACH__)
+ .macro BEGIN_FUNCTION name
+ .text
+ .p2align 4
+ .global _\name
+ .private_extern _\name
+ _\name:
+ .endm
+
+ .macro END_FUNCTION name
+ .endm
+#endif
+
+#ifdef __ELF__
+ .macro ALLOW_NON_EXECUTABLE_STACK
+ .section ".note.GNU-stack","",%progbits
+ .endm
+#else
+ .macro ALLOW_NON_EXECUTABLE_STACK
+ .endm
+#endif
diff --git a/runtime/src/iree/builtins/ukernel/arm_64/config.h.in b/runtime/src/iree/builtins/ukernel/arm_64/config.h.in
new file mode 100644
index 0000000..7dd0fd6
--- /dev/null
+++ b/runtime/src/iree/builtins/ukernel/arm_64/config.h.in
@@ -0,0 +1,2 @@
+#cmakedefine HAVE_FLAG_MARCH_ARMV8_2_A_DOTPROD
+#cmakedefine HAVE_FLAG_MARCH_ARMV8_2_A_I8MM
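
Each #cmakedefine line expands according to whether the matching variable was set by the check_cxx_compiler_flag calls in the CMakeLists.txt above. As a sketch, the generated config.h on a toolchain that accepts the dotprod flag but rejects i8mm would read:

    /* Defined because the -march=armv8.2-a+dotprod check succeeded: */
    #define HAVE_FLAG_MARCH_ARMV8_2_A_DOTPROD
    /* Commented-out marker because the i8mm check failed: */
    /* #undef HAVE_FLAG_MARCH_ARMV8_2_A_I8MM */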
diff --git a/runtime/src/iree/builtins/ukernel/arm_64/mmt4d_tile_arm_64.c b/runtime/src/iree/builtins/ukernel/arm_64/mmt4d_tile_arm_64.c
new file mode 100644
index 0000000..e4f1a1d
--- /dev/null
+++ b/runtime/src/iree/builtins/ukernel/arm_64/mmt4d_tile_arm_64.c
@@ -0,0 +1,67 @@
+// Copyright 2022 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/builtins/ukernel/arm_64/mmt4d_tile_arm_64.h"
+
+#include "iree/builtins/ukernel/arm_64/config.h"
+#include "iree/schemas/cpu_data.h"
+
+#if defined(IREE_UKERNEL_ARCH_ARM_64)
+
+IREE_UKERNEL_MMT4D_TILE_FUNC_DECL(
+ iree_ukernel_mmt4d_i8i8i32_tile_8x8x4_arm_64_dotprod)
+IREE_UKERNEL_MMT4D_TILE_FUNC_DECL(
+ iree_ukernel_mmt4d_i8i8i32_tile_8x8x8_arm_64_i8mm)
+
+static iree_ukernel_mmt4d_tile_func_t
+iree_ukernel_mmt4d_select_tile_func_arm_64_i8i8i32_8x8x8(
+ const iree_ukernel_mmt4d_params_t* params) {
+#ifdef HAVE_FLAG_MARCH_ARMV8_2_A_I8MM
+ if (params->cpu_data_field_0 & IREE_CPU_DATA_FIELD_0_AARCH64_HAVE_I8MM) {
+ return iree_ukernel_mmt4d_i8i8i32_tile_8x8x8_arm_64_i8mm;
+ }
+#else
+ (void)params;
+#endif
+ return 0;
+}
+
+static iree_ukernel_mmt4d_tile_func_t
+iree_ukernel_mmt4d_select_tile_func_arm_64_i8i8i32_8x8x4(
+ const iree_ukernel_mmt4d_params_t* params) {
+#ifdef HAVE_FLAG_MARCH_ARMV8_2_A_DOTPROD
+ if (params->cpu_data_field_0 & IREE_CPU_DATA_FIELD_0_AARCH64_HAVE_DOTPROD) {
+ return iree_ukernel_mmt4d_i8i8i32_tile_8x8x4_arm_64_dotprod;
+ }
+#else
+ (void)params;
+#endif
+ return 0;
+}
+
+static iree_ukernel_mmt4d_tile_func_t
+iree_ukernel_mmt4d_select_tile_func_arm_64_i8i8i32(
+ const iree_ukernel_mmt4d_params_t* params) {
+ if (params->M0 == 8 && params->N0 == 8 && params->K0 == 8) {
+ return iree_ukernel_mmt4d_select_tile_func_arm_64_i8i8i32_8x8x8(params);
+ }
+ if (params->M0 == 8 && params->N0 == 8 && params->K0 == 4) {
+ return iree_ukernel_mmt4d_select_tile_func_arm_64_i8i8i32_8x8x4(params);
+ }
+ return 0;
+}
+
+iree_ukernel_mmt4d_tile_func_t iree_ukernel_mmt4d_select_tile_func_arm_64(
+ const iree_ukernel_mmt4d_params_t* params) {
+ switch (params->type) {
+ case iree_ukernel_mmt4d_type_i8i8i32:
+ return iree_ukernel_mmt4d_select_tile_func_arm_64_i8i8i32(params);
+ default:
+ return 0;
+ }
+}
+
+#endif // IREE_UKERNEL_ARCH_ARM_64
diff --git a/runtime/src/iree/builtins/ukernel/arm_64/mmt4d_tile_arm_64.h b/runtime/src/iree/builtins/ukernel/arm_64/mmt4d_tile_arm_64.h
new file mode 100644
index 0000000..7ad6813
--- /dev/null
+++ b/runtime/src/iree/builtins/ukernel/arm_64/mmt4d_tile_arm_64.h
@@ -0,0 +1,22 @@
+// Copyright 2022 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_BUILTINS_UKERNEL_ARM_64_MMT4D_ARM_64_H_
+#define IREE_BUILTINS_UKERNEL_ARM_64_MMT4D_ARM_64_H_
+
+#include "iree/builtins/ukernel/mmt4d_types.h"
+
+#if defined(IREE_UKERNEL_ARCH_ARM_64)
+
+// Returns the arm64 tile function to use for the mmt4d with the given params,
+// or NULL if no suitable arm64 tile function exists for them, in which case
+// the caller may fall back to a generic tile function.
+iree_ukernel_mmt4d_tile_func_t iree_ukernel_mmt4d_select_tile_func_arm_64(
+ const iree_ukernel_mmt4d_params_t* params);
+
+#endif // IREE_UKERNEL_ARCH_ARM_64
+
+#endif // IREE_BUILTINS_UKERNEL_ARM_64_MMT4D_ARM_64_H_
diff --git a/runtime/src/iree/builtins/ukernel/arm_64/mmt4d_tile_arm_64_dotprod.S b/runtime/src/iree/builtins/ukernel/arm_64/mmt4d_tile_arm_64_dotprod.S
new file mode 100644
index 0000000..2cc9bc8
--- /dev/null
+++ b/runtime/src/iree/builtins/ukernel/arm_64/mmt4d_tile_arm_64_dotprod.S
@@ -0,0 +1,101 @@
+#include "assembly.h"
+
+// TODO: share these bits with C/C++.
+.equ ACCUMULATE_FLAG_BIT_POS,0
+
+// Parameters:
+// x0: int32_t* out_tile
+// x1: const int8_t* lhs_panel
+// x2: const int8_t* rhs_panel
+// w3: int32_t K
+// w4: uint32_t flags
+// x5: (UNUSED) params - the relevant fields, K and flags, are passed above.
+
+BEGIN_FUNCTION iree_ukernel_mmt4d_i8i8i32_tile_8x8x4_arm_64_dotprod
+
+ // Do we accumulate into or clear the accumulator tile?
+ tbnz w4, ACCUMULATE_FLAG_BIT_POS, 1f
+
+ 0:
+ // No-accumulate case. Clear the 8x8 accumulator tile.
+ movi v16.16b, 0
+ movi v17.16b, 0
+ movi v18.16b, 0
+ movi v19.16b, 0
+ movi v20.16b, 0
+ movi v21.16b, 0
+ movi v22.16b, 0
+ movi v23.16b, 0
+ movi v24.16b, 0
+ movi v25.16b, 0
+ movi v26.16b, 0
+ movi v27.16b, 0
+ movi v28.16b, 0
+ movi v29.16b, 0
+ movi v30.16b, 0
+ movi v31.16b, 0
+ b 2f
+
+ 1:
+ // Accumulate case. Load the 8x8 accumulator tile from row-major
+ // out_tile, into temporary registers v16--v31.
+ ldp q16, q17, [x0, 0]
+ ldp q18, q19, [x0, 32]
+ ldp q20, q21, [x0, 64]
+ ldp q22, q23, [x0, 96]
+ ldp q24, q25, [x0, 128]
+ ldp q26, q27, [x0, 160]
+ ldp q28, q29, [x0, 192]
+ ldp q30, q31, [x0, 224]
+
+ 2:
+ // Loop body. Decrement the loop counter K.
+ subs w3, w3, 1
+ // Load 8x4 LHS tile
+ ldp q0, q1, [x1, 0]
+ add x1, x1, 32
+ // Load 8x4 RHS tile
+ ldp q4, q5, [x2, 0]
+ add x2, x2, 32
+ // Multiply-accumulate, row 0.
+ sdot v16.4s, v4.16b, v0.4b[0]
+ sdot v17.4s, v5.16b, v0.4b[0]
+ // Multiply-accumulate, row 1.
+ sdot v18.4s, v4.16b, v0.4b[1]
+ sdot v19.4s, v5.16b, v0.4b[1]
+ // Multiply-accumulate, row 2.
+ sdot v20.4s, v4.16b, v0.4b[2]
+ sdot v21.4s, v5.16b, v0.4b[2]
+ // Multiply-accumulate, row 3.
+ sdot v22.4s, v4.16b, v0.4b[3]
+ sdot v23.4s, v5.16b, v0.4b[3]
+ // Multiply-accumulate, row 4.
+ sdot v24.4s, v4.16b, v1.4b[0]
+ sdot v25.4s, v5.16b, v1.4b[0]
+ // Multiply-accumulate, row 5.
+ sdot v26.4s, v4.16b, v1.4b[1]
+ sdot v27.4s, v5.16b, v1.4b[1]
+ // Multiply-accumulate, row 6.
+ sdot v28.4s, v4.16b, v1.4b[2]
+ sdot v29.4s, v5.16b, v1.4b[2]
+ // Multiply-accumulate, row 7.
+ sdot v30.4s, v4.16b, v1.4b[3]
+ sdot v31.4s, v5.16b, v1.4b[3]
+ // Loop if K != 0.
+ b.ne 2b
+
+ 3:
+ // Store the accumulator tile to the destination.
+ stp q16, q17, [x0, 0]
+ stp q18, q19, [x0, 32]
+ stp q20, q21, [x0, 64]
+ stp q22, q23, [x0, 96]
+ stp q24, q25, [x0, 128]
+ stp q26, q27, [x0, 160]
+ stp q28, q29, [x0, 192]
+ stp q30, q31, [x0, 224]
+ ret
+
+END_FUNCTION iree_ukernel_mmt4d_i8i8i32_tile_8x8x4_arm_64_dotprod
+
+ALLOW_NON_EXECUTABLE_STACK
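
For cross-reference with the C side: under the AAPCS64 calling convention the six parameters documented above land in x0/x1/x2/w3/w4/x5, matching a tile-function signature like the sketch below. The argument list matches the generic tile implementation later in this diff; placing the typedef in mmt4d_types.h is an assumption, as that header is not shown here.

    // Sketch of the tile function type these assembly kernels implement
    // (argument order mirrors the register comments above).
    typedef void (*iree_ukernel_mmt4d_tile_func_t)(
        void* out_tile,         // x0: int32_t* accumulator tile
        const void* lhs_panel,  // x1: packed LHS panel
        const void* rhs_panel,  // x2: packed RHS panel
        int32_t K,              // w3: reduction loop count
        uint32_t flags,         // w4: IREE_VMVX_MATMUL_FLAG_* bits
        const iree_ukernel_mmt4d_params_t* params);  // x5: unused here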
diff --git a/runtime/src/iree/builtins/ukernel/arm_64/mmt4d_tile_arm_64_i8mm.S b/runtime/src/iree/builtins/ukernel/arm_64/mmt4d_tile_arm_64_i8mm.S
new file mode 100644
index 0000000..108a61b
--- /dev/null
+++ b/runtime/src/iree/builtins/ukernel/arm_64/mmt4d_tile_arm_64_i8mm.S
@@ -0,0 +1,139 @@
+#include "assembly.h"
+
+// TODO: share these bits with C/C++.
+.equ ACCUMULATE_FLAG_BIT_POS,0
+
+// Parameters:
+// x0: int32_t* out_tile
+// x1: const int8_t* lhs_panel
+// x2: const int8_t* rhs_panel
+// w3: int32_t K
+// w4: uint32_t flags
+// x5: (UNUSED) params - the relevant fields, K and flags, are passed above.
+
+BEGIN_FUNCTION iree_ukernel_mmt4d_i8i8i32_tile_8x8x8_arm_64_i8mm
+
+ // Do we accumulate into or clear the accumulator tile?
+ tbnz w4, ACCUMULATE_FLAG_BIT_POS, 1f
+
+ 0:
+ // No-accumulate case. Clear the 8x8 accumulator tile.
+ movi v16.16b, 0
+ movi v17.16b, 0
+ movi v18.16b, 0
+ movi v19.16b, 0
+ movi v20.16b, 0
+ movi v21.16b, 0
+ movi v22.16b, 0
+ movi v23.16b, 0
+ movi v24.16b, 0
+ movi v25.16b, 0
+ movi v26.16b, 0
+ movi v27.16b, 0
+ movi v28.16b, 0
+ movi v29.16b, 0
+ movi v30.16b, 0
+ movi v31.16b, 0
+ b 2f
+
+ 1:
+ // Accumulate case. Load the 8x8 accumulator tile from row-major
+ // out_tile, into temporary registers v0--v15.
+ ldp q0, q1, [x0, 0]
+ ldp q2, q3, [x0, 32]
+ ldp q4, q5, [x0, 64]
+ ldp q6, q7, [x0, 96]
+ ldp q8, q9, [x0, 128]
+ ldp q10, q11, [x0, 160]
+ ldp q12, q13, [x0, 192]
+ ldp q14, q15, [x0, 224]
+ // Swizzle in 2x2 tiles for smmla, rows 0--1.
+ zip1 v16.2d, v0.2d, v2.2d
+ zip2 v17.2d, v0.2d, v2.2d
+ zip1 v18.2d, v1.2d, v3.2d
+ zip2 v19.2d, v1.2d, v3.2d
+ // Swizzle in 2x2 tiles for smmla, rows 2--3.
+ zip1 v20.2d, v4.2d, v6.2d
+ zip2 v21.2d, v4.2d, v6.2d
+ zip1 v22.2d, v5.2d, v7.2d
+ zip2 v23.2d, v5.2d, v7.2d
+ // Swizzle in 2x2 tiles for smmla, rows 4--5.
+ zip1 v24.2d, v8.2d, v10.2d
+ zip2 v25.2d, v8.2d, v10.2d
+ zip1 v26.2d, v9.2d, v11.2d
+ zip2 v27.2d, v9.2d, v11.2d
+ // Swizzle in 2x2 tiles for smmla, rows 6--7.
+ zip1 v28.2d, v12.2d, v14.2d
+ zip2 v29.2d, v12.2d, v14.2d
+ zip1 v30.2d, v13.2d, v15.2d
+ zip2 v31.2d, v13.2d, v15.2d
+
+ 2:
+ // Loop body. Decrement the loop counter K.
+ subs w3, w3, 1
+ // Load 8x8 LHS tile
+ ldp q0, q1, [x1, 0]
+ ldp q2, q3, [x1, 32]
+ add x1, x1, 64
+ // Load 8x8 RHS tile
+ ldp q4, q5, [x2, 0]
+ ldp q6, q7, [x2, 32]
+ add x2, x2, 64
+ // Multiply-accumulate, rows 0--1.
+ smmla v16.4s, v0.16b, v4.16b
+ smmla v17.4s, v0.16b, v5.16b
+ smmla v18.4s, v0.16b, v6.16b
+ smmla v19.4s, v0.16b, v7.16b
+ // Multiply-accumulate, rows 2--3.
+ smmla v20.4s, v1.16b, v4.16b
+ smmla v21.4s, v1.16b, v5.16b
+ smmla v22.4s, v1.16b, v6.16b
+ smmla v23.4s, v1.16b, v7.16b
+ // Multiply-accumulate, rows 4--5.
+ smmla v24.4s, v2.16b, v4.16b
+ smmla v25.4s, v2.16b, v5.16b
+ smmla v26.4s, v2.16b, v6.16b
+ smmla v27.4s, v2.16b, v7.16b
+ // Multiply-accumulate, rows 6--7.
+ smmla v28.4s, v3.16b, v4.16b
+ smmla v29.4s, v3.16b, v5.16b
+ smmla v30.4s, v3.16b, v6.16b
+ smmla v31.4s, v3.16b, v7.16b
+ // Loop if K != 0.
+ b.ne 2b
+
+ 3:
+ // Swizzle back to row-major, rows 0--1.
+ uzp1 v0.2d, v16.2d, v17.2d
+ uzp1 v1.2d, v18.2d, v19.2d
+ uzp2 v2.2d, v16.2d, v17.2d
+ uzp2 v3.2d, v18.2d, v19.2d
+ // Swizzle back to row-major, rows 2--3.
+ uzp1 v4.2d, v20.2d, v21.2d
+ uzp1 v5.2d, v22.2d, v23.2d
+ uzp2 v6.2d, v20.2d, v21.2d
+ uzp2 v7.2d, v22.2d, v23.2d
+ // Swizzle back to row-major, rows 4--5.
+ uzp1 v8.2d, v24.2d, v25.2d
+ uzp1 v9.2d, v26.2d, v27.2d
+ uzp2 v10.2d, v24.2d, v25.2d
+ uzp2 v11.2d, v26.2d, v27.2d
+ // Swizzle back to row-major, rows 6--7.
+ uzp1 v12.2d, v28.2d, v29.2d
+ uzp1 v13.2d, v30.2d, v31.2d
+ uzp2 v14.2d, v28.2d, v29.2d
+ uzp2 v15.2d, v30.2d, v31.2d
+ // Store the accumulator tile to the destination.
+ stp q0, q1, [x0, 0]
+ stp q2, q3, [x0, 32]
+ stp q4, q5, [x0, 64]
+ stp q6, q7, [x0, 96]
+ stp q8, q9, [x0, 128]
+ stp q10, q11, [x0, 160]
+ stp q12, q13, [x0, 192]
+ stp q14, q15, [x0, 224]
+ ret
+
+END_FUNCTION iree_ukernel_mmt4d_i8i8i32_tile_8x8x8_arm_64_i8mm
+
+ALLOW_NON_EXECUTABLE_STACK
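
A note on the zip/uzp swizzles above: smmla accumulates 2x2 blocks of int32 (lanes ordered r0c0, r0c1, r1c0, r1c1) rather than contiguous rows, so the accumulators are kept in a blocked layout across the loop. A purely illustrative scalar model of the row-major-to-blocked conversion that zip1/zip2 perform (uzp1/uzp2 invert it):

    #include <stdint.h>

    // One pair of rows (r, r+1) by 8 columns becomes 4 blocks of 2x2 int32.
    void to_blocked(const int32_t row0[8], const int32_t row1[8],
                    int32_t blocked[4][4]) {
      for (int b = 0; b < 4; ++b) {
        blocked[b][0] = row0[2 * b + 0];  // (r,   c)
        blocked[b][1] = row0[2 * b + 1];  // (r,   c+1)
        blocked[b][2] = row1[2 * b + 0];  // (r+1, c)
        blocked[b][3] = row1[2 * b + 1];  // (r+1, c+1)
      }
    }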
diff --git a/runtime/src/iree/builtins/ukernel/common.h b/runtime/src/iree/builtins/ukernel/common.h
index 0ed9455..30096f9 100644
--- a/runtime/src/iree/builtins/ukernel/common.h
+++ b/runtime/src/iree/builtins/ukernel/common.h
@@ -38,7 +38,6 @@
// These two headers are clean and do not include any other headers:
#include "iree/base/attributes.h"
#include "iree/base/target_platform.h"
-#include "iree/schemas/cpu_data.h"
#ifdef __cplusplus
extern "C" {
@@ -57,7 +56,7 @@
#define IREE_UKERNEL_SIZE_TYPE int32_t
#elif defined(IREE_UKERNEL_ARCH_GENERIC_64)
#define IREE_UKERNEL_SIZE_TYPE int64_t
-#elif defined(IREE_ARCH_ARM_64)
+#elif defined(IREE_ARCH_ARM_64) && !defined(__APPLE__)
#define IREE_UKERNEL_ARCH_ARM_64 1
#define IREE_UKERNEL_SIZE_TYPE int64_t
#else
@@ -116,7 +115,7 @@
#endif // !INT8_MIN
-// Use iree_mmt4d_size_t for all sizes that may need pointer width.
+// Use iree_ukernel_size_t for all sizes that may need pointer width.
// For any argument that is known to fit in a specific size prefer that to
// ensure this code operates well on systems with small/weird widths (x32/ilp32,
// etc).
diff --git a/runtime/src/iree/builtins/ukernel/mmt4d.c b/runtime/src/iree/builtins/ukernel/mmt4d.c
index 7ea816d..7dbe3ac 100644
--- a/runtime/src/iree/builtins/ukernel/mmt4d.c
+++ b/runtime/src/iree/builtins/ukernel/mmt4d.c
@@ -7,56 +7,116 @@
#include "iree/builtins/ukernel/mmt4d.h"
#if defined(IREE_UKERNEL_ARCH_ARM_64)
-#include "iree/builtins/ukernel/mmt4d_arm_64.h"
+#include "iree/builtins/ukernel/arm_64/mmt4d_tile_arm_64.h"
#endif
-#if defined(IREE_UKERNEL_ARCH_GENERIC_32) || \
- defined(IREE_UKERNEL_ARCH_GENERIC_64)
-#include "iree/builtins/ukernel/mmt4d_generic.h"
-#endif
+#include "iree/builtins/ukernel/mmt4d_tile_generic.h"
-IREE_UKERNEL_EXPORT int iree_ukernel_mmt4d_f32f32f32(
- const iree_ukernel_mmt4d_f32f32f32_params_t* params) {
+#define OUTSIDE_UINT_RANGE(value, bits) (((value) < 0) || ((value) >> (bits)))
+
+static iree_ukernel_mmt4d_status_t iree_ukernel_mmt4d_validate(
+ const iree_ukernel_mmt4d_params_t* params) {
if (params->flags & ~IREE_VMVX_MATMUL_FLAG_ACCUMULATE) {
- return IREE_UKERNEL_MMT4D_ERROR_BAD_FLAGS;
+ return iree_ukernel_mmt4d_status_bad_flags;
}
-
-#if defined(IREE_UKERNEL_ARCH_ARM_64)
- return iree_ukernel_mmt4d_f32f32f32_arm_64(params);
-#endif
-
-#if defined(IREE_UKERNEL_ARCH_GENERIC_32) || \
- defined(IREE_UKERNEL_ARCH_GENERIC_64)
- return iree_ukernel_mmt4d_f32f32f32_generic(params);
-#endif
-
- return IREE_UKERNEL_MMT4D_ERROR_UNIMPLEMENTED;
+ switch (params->type) {
+ case iree_ukernel_mmt4d_type_f32f32f32:
+ case iree_ukernel_mmt4d_type_i8i8i32:
+ break;
+ default:
+ return iree_ukernel_mmt4d_status_bad_type;
+ }
+ // Some implementations may wish to avoid supporting absurdly wide types. For
+  // instance, K is the innermost (i.e. hottest) loop bound, so some 32-bit
+ // targets may benefit from K being int32, not int64. We still let K be of
+ // type int64 to be future-proof, as types are hard to change later. But we
+ // enforce a narrower range here, as we can always relax that later as needed.
+  if (OUTSIDE_UINT_RANGE(params->M, 31) || OUTSIDE_UINT_RANGE(params->N, 31) ||
+      OUTSIDE_UINT_RANGE(params->K, 31) || OUTSIDE_UINT_RANGE(params->M0, 15) ||
+      OUTSIDE_UINT_RANGE(params->N0, 15) ||
+      OUTSIDE_UINT_RANGE(params->K0, 15)) {
+ return iree_ukernel_mmt4d_status_unsupported_huge_or_negative_dimension;
+ }
+ return iree_ukernel_mmt4d_status_ok;
}
-IREE_UKERNEL_EXPORT int iree_ukernel_mmt4d_i8i8i32(
- const iree_ukernel_mmt4d_i8i8i32_params_t* params) {
- if (params->flags & ~IREE_VMVX_MATMUL_FLAG_ACCUMULATE) {
- return IREE_UKERNEL_MMT4D_ERROR_BAD_FLAGS;
- }
-
+// On success, *out_tile_func is the tile function to use to perform the mmt4d
+// with the given *params.
+static iree_ukernel_mmt4d_status_t iree_ukernel_mmt4d_select_tile_func(
+ const iree_ukernel_mmt4d_params_t* params,
+ iree_ukernel_mmt4d_tile_func_t* out_tile_func) {
+ iree_ukernel_mmt4d_tile_func_t arch_tile_func = 0;
#if defined(IREE_UKERNEL_ARCH_ARM_64)
- return iree_ukernel_mmt4d_i8i8i32_arm_64(params);
+ arch_tile_func = iree_ukernel_mmt4d_select_tile_func_arm_64(params);
#endif
-
-#if defined(IREE_UKERNEL_ARCH_GENERIC_32) || \
- defined(IREE_UKERNEL_ARCH_GENERIC_64)
- return iree_ukernel_mmt4d_i8i8i32_generic(params);
-#endif
-
- return IREE_UKERNEL_MMT4D_ERROR_UNIMPLEMENTED;
+ if (arch_tile_func) {
+ *out_tile_func = arch_tile_func;
+ return iree_ukernel_mmt4d_status_ok;
+ }
+ return iree_ukernel_mmt4d_select_tile_func_generic(params, out_tile_func);
}
-const char* iree_ukernel_mmt4d_error_message(int retcode) {
- switch (retcode) {
- case IREE_UKERNEL_MMT4D_ERROR_UNIMPLEMENTED:
- return "hit unimplemented code path in mmt4d";
- case IREE_UKERNEL_MMT4D_ERROR_BAD_FLAGS:
+// General mmt4d implementation, shared among all cases. The idea is that the
+// only really performance-critical part is the inner-most loop, and that's
+// handled by the tile_func passed as argument here. Sharing the outer loops
+// across all cases yields roughly a 2x code shrink compared to emitting the
+// whole loop nest separately for each case.
+static void iree_ukernel_mmt4d_using_tile_func(
+ const iree_ukernel_mmt4d_params_t* params,
+ iree_ukernel_mmt4d_tile_func_t tile_func) {
+ const int32_t M = params->M;
+ const int32_t N = params->N;
+ const int32_t K = params->K;
+ const int16_t M0 = params->M0;
+ const int16_t N0 = params->N0;
+ const int16_t lhs_elem_size_log2 =
+ iree_ukernel_mmt4d_lhs_elem_size_log2(params->type);
+ const int16_t rhs_elem_size_log2 =
+ iree_ukernel_mmt4d_rhs_elem_size_log2(params->type);
+ const int16_t out_elem_size_log2 =
+ iree_ukernel_mmt4d_out_elem_size_log2(params->type);
+ char* out_tile_row = params->out_buffer;
+ const char* lhs_panel = params->lhs_buffer;
+ int32_t out_tile_size = (M0 * N0) << out_elem_size_log2;
+ iree_ukernel_size_t lhs_panel_stride = params->lhs_stride
+ << lhs_elem_size_log2;
+ iree_ukernel_size_t rhs_panel_stride = params->rhs_stride
+ << rhs_elem_size_log2;
+ iree_ukernel_size_t out_stride = params->out_stride << out_elem_size_log2;
+ for (int32_t i = 0; i < M; ++i) {
+ char* out_tile = out_tile_row;
+ const char* rhs_panel = params->rhs_buffer;
+ for (int32_t j = 0; j < N; ++j) {
+ tile_func(out_tile, lhs_panel, rhs_panel, K, params->flags, params);
+ out_tile += out_tile_size;
+ rhs_panel += rhs_panel_stride;
+ }
+ out_tile_row += out_stride;
+ lhs_panel += lhs_panel_stride;
+ }
+}
+
+IREE_UKERNEL_EXPORT iree_ukernel_mmt4d_status_t
+iree_ukernel_mmt4d(const iree_ukernel_mmt4d_params_t* params) {
+ IREE_UKERNEL_MMT4D_RETURN_IF_ERROR(iree_ukernel_mmt4d_validate(params));
+ iree_ukernel_mmt4d_tile_func_t tile_func;
+ IREE_UKERNEL_MMT4D_RETURN_IF_ERROR(
+ iree_ukernel_mmt4d_select_tile_func(params, &tile_func));
+ iree_ukernel_mmt4d_using_tile_func(params, tile_func);
+ return iree_ukernel_mmt4d_status_ok;
+}
+
+IREE_UKERNEL_EXPORT const char* iree_ukernel_mmt4d_status_message(
+ iree_ukernel_mmt4d_status_t status) {
+ switch (status) {
+ case iree_ukernel_mmt4d_status_bad_flags:
return "bad mmt4d flags";
+ case iree_ukernel_mmt4d_status_bad_type:
+ return "bad mmt4d type enum";
+ case iree_ukernel_mmt4d_status_unsupported_huge_or_negative_dimension:
+ return "unsupported huge or negative size in mmt4d";
+ case iree_ukernel_mmt4d_status_unsupported_generic_tile_size:
+ return "tile size too large for the generic tile implementation";
default:
return "unknown";
}
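
To make the new single entry point concrete, here is a minimal hypothetical caller: a 2x2 grid of 8x8 output tiles, i8*i8->i32 with K0=4. Field names follow the uses visible in this diff; the stride units (elements) and the exact buffer field types are assumptions, so treat this as a sketch rather than project code.

    #include <stdint.h>
    #include <stdio.h>
    #include "iree/builtins/ukernel/mmt4d.h"

    int main(void) {
      enum { M = 2, N = 2, K = 3, M0 = 8, N0 = 8, K0 = 4 };
      static int8_t lhs[M * K * M0 * K0];   // packed LHS panels
      static int8_t rhs[N * K * N0 * K0];   // packed RHS panels
      static int32_t out[M * N * M0 * N0];  // row-major grid of out tiles
      iree_ukernel_mmt4d_params_t params = {
          .type = iree_ukernel_mmt4d_type_i8i8i32,
          .flags = 0,  // set IREE_VMVX_MATMUL_FLAG_ACCUMULATE to accumulate
          .lhs_buffer = lhs,
          .rhs_buffer = rhs,
          .out_buffer = out,
          .lhs_stride = K * M0 * K0,  // elements between consecutive LHS panels
          .rhs_stride = K * N0 * K0,  // elements between consecutive RHS panels
          .out_stride = N * M0 * N0,  // elements between rows of output tiles
          .M = M, .N = N, .K = K, .M0 = M0, .N0 = N0, .K0 = K0,
          .cpu_data_field_0 = 0,  // no CPU features: selects the generic tile
      };
      iree_ukernel_mmt4d_status_t status = iree_ukernel_mmt4d(&params);
      if (status != iree_ukernel_mmt4d_status_ok) {
        fprintf(stderr, "mmt4d: %s\n",
                iree_ukernel_mmt4d_status_message(status));
        return 1;
      }
      return 0;
    }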
diff --git a/runtime/src/iree/builtins/ukernel/mmt4d.h b/runtime/src/iree/builtins/ukernel/mmt4d.h
index 407374a..51e512a 100644
--- a/runtime/src/iree/builtins/ukernel/mmt4d.h
+++ b/runtime/src/iree/builtins/ukernel/mmt4d.h
@@ -7,61 +7,19 @@
#ifndef IREE_BUILTINS_UKERNEL_MMT4D_H_
#define IREE_BUILTINS_UKERNEL_MMT4D_H_
-#include "iree/builtins/ukernel/common.h"
+#include "iree/builtins/ukernel/mmt4d_types.h"
#ifdef __cplusplus
extern "C" {
#endif // __cplusplus
-struct iree_ukernel_mmt4d_f32f32f32_params_t {
- const float* lhs_buffer;
- const float* rhs_buffer;
- float* out_buffer;
- iree_ukernel_size_t lhs_stride;
- iree_ukernel_size_t rhs_stride;
- iree_ukernel_size_t out_stride;
- iree_ukernel_size_t M;
- iree_ukernel_size_t N;
- iree_ukernel_size_t K;
- int32_t M0;
- int32_t N0;
- int32_t K0;
- uint32_t flags;
-};
+// Main entry point.
+IREE_UKERNEL_EXPORT iree_ukernel_mmt4d_status_t
+iree_ukernel_mmt4d(const iree_ukernel_mmt4d_params_t* params);
-struct iree_ukernel_mmt4d_i8i8i32_params_t {
- const int8_t* lhs_buffer;
- const int8_t* rhs_buffer;
- int32_t* out_buffer;
- iree_ukernel_size_t lhs_stride;
- iree_ukernel_size_t rhs_stride;
- iree_ukernel_size_t out_stride;
- iree_ukernel_size_t M;
- iree_ukernel_size_t N;
- iree_ukernel_size_t K;
- int32_t M0;
- int32_t N0;
- int32_t K0;
- uint32_t flags;
-};
-
-typedef struct iree_ukernel_mmt4d_f32f32f32_params_t
- iree_ukernel_mmt4d_f32f32f32_params_t;
-typedef struct iree_ukernel_mmt4d_i8i8i32_params_t
- iree_ukernel_mmt4d_i8i8i32_params_t;
-
-#define IREE_UKERNEL_MMT4D_ERROR_UNIMPLEMENTED 1
-#define IREE_UKERNEL_MMT4D_ERROR_BAD_FLAGS 2
-
-// TODO: move these flags to a header file shared with compiler/.
-#define IREE_VMVX_MATMUL_FLAG_ACCUMULATE 1
-
-IREE_UKERNEL_EXPORT int iree_ukernel_mmt4d_f32f32f32(
- const iree_ukernel_mmt4d_f32f32f32_params_t* params);
-IREE_UKERNEL_EXPORT int iree_ukernel_mmt4d_i8i8i32(
- const iree_ukernel_mmt4d_i8i8i32_params_t* params);
-
-IREE_UKERNEL_EXPORT const char* iree_ukernel_mmt4d_error_message(int retcode);
+// Convert a status code to a human-readable string.
+IREE_UKERNEL_EXPORT const char* iree_ukernel_mmt4d_status_message(
+ iree_ukernel_mmt4d_status_t status);
#ifdef __cplusplus
} // extern "C"
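
The IREE_UKERNEL_MMT4D_RETURN_IF_ERROR macro used in mmt4d.c above is presumably defined next to the status enum in mmt4d_types.h, which this diff does not show; a sketch of the conventional early-return shape it likely has:

    #define IREE_UKERNEL_MMT4D_RETURN_IF_ERROR(EXPR)      \
      do {                                                \
        iree_ukernel_mmt4d_status_t status_ = (EXPR);     \
        if (status_ != iree_ukernel_mmt4d_status_ok) {    \
          return status_;                                 \
        }                                                 \
      } while (0)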
diff --git a/runtime/src/iree/builtins/ukernel/mmt4d_arm_64.c b/runtime/src/iree/builtins/ukernel/mmt4d_arm_64.c
deleted file mode 100644
index ef077b6..0000000
--- a/runtime/src/iree/builtins/ukernel/mmt4d_arm_64.c
+++ /dev/null
@@ -1,26 +0,0 @@
-// Copyright 2022 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#include "iree/builtins/ukernel/mmt4d_arm_64.h"
-
-// TODO: once actual ARM64 code is implemented, we shouldn't need this anymore
-#include "iree/builtins/ukernel/mmt4d_generic.h"
-
-#if defined(IREE_UKERNEL_ARCH_ARM_64)
-
-int iree_ukernel_mmt4d_f32f32f32_arm_64(
- const iree_ukernel_mmt4d_f32f32f32_params_t* params) {
- // TODO: implement actual arm assembly kernels instead of calling _generic.
- return iree_ukernel_mmt4d_f32f32f32_generic(params);
-}
-
-int iree_ukernel_mmt4d_i8i8i32_arm_64(
- const iree_ukernel_mmt4d_i8i8i32_params_t* params) {
- // TODO: implement actual arm assembly kernels instead of calling _generic.
- return iree_ukernel_mmt4d_i8i8i32_generic(params);
-}
-
-#endif // IREE_UKERNEL_ARCH_ARM_64
diff --git a/runtime/src/iree/builtins/ukernel/mmt4d_arm_64.h b/runtime/src/iree/builtins/ukernel/mmt4d_arm_64.h
deleted file mode 100644
index 1ea08fa..0000000
--- a/runtime/src/iree/builtins/ukernel/mmt4d_arm_64.h
+++ /dev/null
@@ -1,21 +0,0 @@
-// Copyright 2022 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#ifndef IREE_BUILTINS_UKERNEL_MMT4D_ARM_64_H_
-#define IREE_BUILTINS_UKERNEL_MMT4D_ARM_64_H_
-
-#include "iree/builtins/ukernel/mmt4d.h"
-
-#if defined(IREE_UKERNEL_ARCH_ARM_64)
-
-int iree_ukernel_mmt4d_f32f32f32_arm_64(
- const iree_ukernel_mmt4d_f32f32f32_params_t* params);
-int iree_ukernel_mmt4d_i8i8i32_arm_64(
- const iree_ukernel_mmt4d_i8i8i32_params_t* params);
-
-#endif // IREE_UKERNEL_ARCH_ARM_64
-
-#endif // IREE_BUILTINS_UKERNEL_MMT4D_ARM_64_H_
diff --git a/runtime/src/iree/builtins/ukernel/mmt4d_generic.c b/runtime/src/iree/builtins/ukernel/mmt4d_generic.c
deleted file mode 100644
index cc9eeb4..0000000
--- a/runtime/src/iree/builtins/ukernel/mmt4d_generic.c
+++ /dev/null
@@ -1,80 +0,0 @@
-// Copyright 2022 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#include "iree/builtins/ukernel/mmt4d_generic.h"
-
-#include <stdbool.h>
-
-int iree_ukernel_mmt4d_f32f32f32_generic(
- const iree_ukernel_mmt4d_f32f32f32_params_t* params) {
- bool accumulate = params->flags & IREE_VMVX_MATMUL_FLAG_ACCUMULATE;
- iree_ukernel_size_t lhs_tile_size = params->M0 * params->K0;
- iree_ukernel_size_t rhs_tile_size = params->N0 * params->K0;
- iree_ukernel_size_t out_tile_size = params->M0 * params->N0;
- for (iree_ukernel_size_t i = 0; i < params->M; ++i) {
- for (iree_ukernel_size_t j = 0; j < params->N; ++j) {
- float* out_tile_ptr =
- params->out_buffer + i * params->out_stride + j * out_tile_size;
- const float* lhs_panel_ptr = params->lhs_buffer + i * params->lhs_stride;
- const float* rhs_panel_ptr = params->rhs_buffer + j * params->rhs_stride;
- for (iree_ukernel_size_t i0 = 0; i0 < params->M0; ++i0) {
- for (iree_ukernel_size_t j0 = 0; j0 < params->N0; ++j0) {
- const float* lhs_tile_ptr = lhs_panel_ptr;
- const float* rhs_tile_ptr = rhs_panel_ptr;
- float* out_ptr = out_tile_ptr + i0 * params->N0 + j0;
- float acc = accumulate ? *out_ptr : 0.f;
- for (iree_ukernel_size_t k = 0; k < params->K; ++k) {
- for (iree_ukernel_size_t k0 = 0; k0 < params->K0; ++k0) {
- float lhs_val = lhs_tile_ptr[i0 * params->K0 + k0];
- float rhs_val = rhs_tile_ptr[j0 * params->K0 + k0];
- acc += lhs_val * rhs_val;
- }
- lhs_tile_ptr += lhs_tile_size;
- rhs_tile_ptr += rhs_tile_size;
- }
- *out_ptr = acc;
- }
- }
- }
- }
- return 0;
-}
-
-int iree_ukernel_mmt4d_i8i8i32_generic(
- const iree_ukernel_mmt4d_i8i8i32_params_t* params) {
- bool accumulate = params->flags & IREE_VMVX_MATMUL_FLAG_ACCUMULATE;
- iree_ukernel_size_t lhs_tile_size = params->M0 * params->K0;
- iree_ukernel_size_t rhs_tile_size = params->N0 * params->K0;
- iree_ukernel_size_t out_tile_size = params->M0 * params->N0;
- for (iree_ukernel_size_t i = 0; i < params->M; ++i) {
- for (iree_ukernel_size_t j = 0; j < params->N; ++j) {
- int32_t* out_tile_ptr =
- params->out_buffer + i * params->out_stride + j * out_tile_size;
- const int8_t* lhs_panel_ptr = params->lhs_buffer + i * params->lhs_stride;
- const int8_t* rhs_panel_ptr = params->rhs_buffer + j * params->rhs_stride;
- for (iree_ukernel_size_t i0 = 0; i0 < params->M0; ++i0) {
- for (iree_ukernel_size_t j0 = 0; j0 < params->N0; ++j0) {
- const int8_t* lhs_tile_ptr = lhs_panel_ptr;
- const int8_t* rhs_tile_ptr = rhs_panel_ptr;
- int32_t* out_ptr = out_tile_ptr + i0 * params->N0 + j0;
- int32_t acc = accumulate ? *out_ptr : 0;
- for (iree_ukernel_size_t k = 0; k < params->K; ++k) {
- for (iree_ukernel_size_t k0 = 0; k0 < params->K0; ++k0) {
- // C's implicit promotion to int saves skin, but let's be explicit
- int32_t lhs_val_int32 = lhs_tile_ptr[i0 * params->K0 + k0];
- int32_t rhs_val_int32 = rhs_tile_ptr[j0 * params->K0 + k0];
- acc += lhs_val_int32 * rhs_val_int32;
- }
- lhs_tile_ptr += lhs_tile_size;
- rhs_tile_ptr += rhs_tile_size;
- }
- *out_ptr = acc;
- }
- }
- }
- }
- return 0;
-}
diff --git a/runtime/src/iree/builtins/ukernel/mmt4d_generic.h b/runtime/src/iree/builtins/ukernel/mmt4d_generic.h
deleted file mode 100644
index 5dc0b5d..0000000
--- a/runtime/src/iree/builtins/ukernel/mmt4d_generic.h
+++ /dev/null
@@ -1,17 +0,0 @@
-// Copyright 2022 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#ifndef IREE_BUILTINS_UKERNEL_MMT4D_GENERIC_H_
-#define IREE_BUILTINS_UKERNEL_MMT4D_GENERIC_H_
-
-#include "iree/builtins/ukernel/mmt4d.h"
-
-int iree_ukernel_mmt4d_f32f32f32_generic(
- const iree_ukernel_mmt4d_f32f32f32_params_t* params);
-int iree_ukernel_mmt4d_i8i8i32_generic(
- const iree_ukernel_mmt4d_i8i8i32_params_t* params);
-
-#endif // IREE_BUILTINS_UKERNEL_MMT4D_GENERIC_H_
diff --git a/runtime/src/iree/builtins/ukernel/mmt4d_tile_generic.c b/runtime/src/iree/builtins/ukernel/mmt4d_tile_generic.c
new file mode 100644
index 0000000..b932ce2
--- /dev/null
+++ b/runtime/src/iree/builtins/ukernel/mmt4d_tile_generic.c
@@ -0,0 +1,120 @@
+// Copyright 2022 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/builtins/ukernel/mmt4d_tile_generic.h"
+
+// In order to be helpful as a reference for future architecture-specific
+// kernels, the generic kernels here are structured like an actual optimized
+// kernel, using an "accumulator tile" that in this case is a stack array
+// (which would become a group of SIMD registers in an actual optimized kernel).
+// The downside of this approach is that we have to set a fixed max size for
+// the accumulator tile, but for now all known cases are comfortably far below
+// where trouble would happen. For reference:
+// - On ARM NEON, the entire register space is 512 bytes, so the accumulator
+// tile is less than that.
+// - On ARM SME, we will be working with an accumulator tile as large as 2048
+// bytes (IIUC).
+// - The smallest stack frame size limit that we know we may have to deal with
+// on certain targets is 16 kilobytes.
+// The size of architecture-specific tiles is relevant here because this
+// generic code is what will be run as a fallback if the device is found not to
+// support the CPU feature that the tile sizes were picked to target.
+enum { iree_ukernel_mmt4d_tile_generic_max_bytes = 2048 };
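+
+// Illustrative compile-time sanity check (the 8x8 i32 tile here is just an
+// example shape, not a shipped configuration): such an accumulator tile uses
+// 8*8*4 = 256 bytes, comfortably under the cap above.
+#if defined(__STDC_VERSION__) && __STDC_VERSION__ >= 201112L
+_Static_assert(8 * 8 * sizeof(int32_t) <=
+                   iree_ukernel_mmt4d_tile_generic_max_bytes,
+               "example 8x8 i32 accumulator tile must fit");
+#endif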
+
+// Generic implementation of matmul tile, i8*i8->i32 case.
+static void iree_ukernel_mmt4d_tile_i8i8i32_generic(
+ void* out_tile_untyped, const void* lhs_panel_untyped,
+ const void* rhs_panel_untyped, int32_t K, uint32_t flags,
+ const iree_ukernel_mmt4d_params_t* params) {
+ int32_t* out_tile = out_tile_untyped;
+ const int8_t* lhs_panel = lhs_panel_untyped;
+ const int8_t* rhs_panel = rhs_panel_untyped;
+ int16_t M0 = params->M0;
+ int16_t N0 = params->N0;
+ int16_t K0 = params->K0;
+ // Initialize the local accumulator tile.
+ int32_t acc[iree_ukernel_mmt4d_tile_generic_max_bytes / sizeof(*out_tile)];
+ if (flags & IREE_VMVX_MATMUL_FLAG_ACCUMULATE) {
+ for (int i = 0; i < M0 * N0; ++i) acc[i] = out_tile[i];
+ } else {
+ for (int i = 0; i < M0 * N0; ++i) acc[i] = 0;
+ }
+ // Accumulation loop.
+ for (iree_ukernel_size_t k = 0; k < K; ++k) {
+ for (iree_ukernel_size_t i0 = 0; i0 < M0; ++i0) {
+ for (iree_ukernel_size_t j0 = 0; j0 < N0; ++j0) {
+ for (iree_ukernel_size_t k0 = 0; k0 < K0; ++k0) {
+ int32_t lhs_val_int32 = lhs_panel[i0 * K0 + k0];
+ int32_t rhs_val_int32 = rhs_panel[j0 * K0 + k0];
+ acc[i0 * N0 + j0] += lhs_val_int32 * rhs_val_int32;
+ }
+ }
+ }
+ lhs_panel += M0 * K0;
+ rhs_panel += N0 * K0;
+ }
+ // Store the local accumulator tile to the destination.
+ for (int i = 0; i < M0 * N0; ++i) out_tile[i] = acc[i];
+}
+
+// Generic implementation of matmul tile, f32*f32->f32 case.
+static void iree_ukernel_mmt4d_tile_f32f32f32_generic(
+ void* out_tile_untyped, const void* lhs_panel_untyped,
+ const void* rhs_panel_untyped, int32_t K, uint32_t flags,
+ const iree_ukernel_mmt4d_params_t* params) {
+ float* out_tile = out_tile_untyped;
+ const float* lhs_panel = lhs_panel_untyped;
+ const float* rhs_panel = rhs_panel_untyped;
+ int16_t M0 = params->M0;
+ int16_t N0 = params->N0;
+ int16_t K0 = params->K0;
+ // Initialize the local accumulator tile.
+ float acc[iree_ukernel_mmt4d_tile_generic_max_bytes / sizeof(*out_tile)];
+ if (flags & IREE_VMVX_MATMUL_FLAG_ACCUMULATE) {
+ for (int i = 0; i < M0 * N0; ++i) acc[i] = out_tile[i];
+ } else {
+ for (int i = 0; i < M0 * N0; ++i) acc[i] = 0;
+ }
+ // Accumulation loop.
+ for (iree_ukernel_size_t k = 0; k < K; ++k) {
+ for (iree_ukernel_size_t i0 = 0; i0 < M0; ++i0) {
+ for (iree_ukernel_size_t j0 = 0; j0 < N0; ++j0) {
+ for (iree_ukernel_size_t k0 = 0; k0 < K0; ++k0) {
+ float lhs_val = lhs_panel[i0 * K0 + k0];
+ float rhs_val = rhs_panel[j0 * K0 + k0];
+ acc[i0 * N0 + j0] += lhs_val * rhs_val;
+ }
+ }
+ }
+ lhs_panel += M0 * K0;
+ rhs_panel += N0 * K0;
+ }
+ // Store the local accumulator tile to the destination.
+ for (int i = 0; i < M0 * N0; ++i) out_tile[i] = acc[i];
+}
+
+// Selects the generic tile function to use for the given params. Fails if the
+// M0xN0 accumulator tile would exceed the generic max size above.
+iree_ukernel_mmt4d_status_t iree_ukernel_mmt4d_select_tile_func_generic(
+ const iree_ukernel_mmt4d_params_t* params,
+ iree_ukernel_mmt4d_tile_func_t* out_tile_func) {
+ int tile_elems = params->M0 * params->N0;
+ int tile_bytes = tile_elems
+ << iree_ukernel_mmt4d_out_elem_size_log2(params->type);
+ if (tile_bytes > iree_ukernel_mmt4d_tile_generic_max_bytes) {
+ return iree_ukernel_mmt4d_status_unsupported_generic_tile_size;
+ }
+ switch (params->type) {
+ case iree_ukernel_mmt4d_type_f32f32f32:
+ *out_tile_func = iree_ukernel_mmt4d_tile_f32f32f32_generic;
+ return iree_ukernel_mmt4d_status_ok;
+ case iree_ukernel_mmt4d_type_i8i8i32:
+ *out_tile_func = iree_ukernel_mmt4d_tile_i8i8i32_generic;
+ return iree_ukernel_mmt4d_status_ok;
+ default:
+ // shouldn't happen, validated earlier.
+ return iree_ukernel_mmt4d_status_bad_type;
+ }
+}
diff --git a/runtime/src/iree/builtins/ukernel/mmt4d_tile_generic.h b/runtime/src/iree/builtins/ukernel/mmt4d_tile_generic.h
new file mode 100644
index 0000000..685d2b7
--- /dev/null
+++ b/runtime/src/iree/builtins/ukernel/mmt4d_tile_generic.h
@@ -0,0 +1,19 @@
+// Copyright 2022 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_BUILTINS_UKERNEL_MMT4D_TILE_GENERIC_H_
+#define IREE_BUILTINS_UKERNEL_MMT4D_TILE_GENERIC_H_
+
+#include "iree/builtins/ukernel/mmt4d_types.h"
+
+// On success, *out_tile_func is the generic tile function to use to perform the
+// mmt4d with the given *params. The caller may want to first try to get an
+// optimized architecture-specific tile function before falling back on this.
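+//
+// Usage sketch (iree_ukernel_mmt4d_select_tile_func_arch is a hypothetical
+// architecture-specific selector, named here only for illustration):
+//
+//   iree_ukernel_mmt4d_tile_func_t tile_func = NULL;
+//   if (iree_ukernel_mmt4d_select_tile_func_arch(params, &tile_func) !=
+//           iree_ukernel_mmt4d_status_ok) {
+//     IREE_UKERNEL_MMT4D_RETURN_IF_ERROR(
+//         iree_ukernel_mmt4d_select_tile_func_generic(params, &tile_func));
+//   }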
+iree_ukernel_mmt4d_status_t iree_ukernel_mmt4d_select_tile_func_generic(
+ const iree_ukernel_mmt4d_params_t* params,
+ iree_ukernel_mmt4d_tile_func_t* out_tile_func);
+
+#endif  // IREE_BUILTINS_UKERNEL_MMT4D_TILE_GENERIC_H_
diff --git a/runtime/src/iree/builtins/ukernel/mmt4d_types.h b/runtime/src/iree/builtins/ukernel/mmt4d_types.h
new file mode 100644
index 0000000..209a318
--- /dev/null
+++ b/runtime/src/iree/builtins/ukernel/mmt4d_types.h
@@ -0,0 +1,131 @@
+// Copyright 2022 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_BUILTINS_UKERNEL_MMT4D_TYPES_H_
+#define IREE_BUILTINS_UKERNEL_MMT4D_TYPES_H_
+
+#include "iree/builtins/ukernel/common.h"
+
+// Supported combinations of data types (order: LHS, RHS, OUT).
+enum iree_ukernel_mmt4d_type_t {
+ iree_ukernel_mmt4d_type_none = 0,
+ iree_ukernel_mmt4d_type_f32f32f32,
+ iree_ukernel_mmt4d_type_i8i8i32,
+};
+
+typedef enum iree_ukernel_mmt4d_type_t iree_ukernel_mmt4d_type_t;
+
+// Parameters for a mmt4d operation.
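+// M/N/K count tiles (the outer dimensions) while M0/N0/K0 give the tile
+// shape; strides are in elements of the respective buffer's element type, not
+// bytes, and span one full row of tiles.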
+struct iree_ukernel_mmt4d_params_t {
+ iree_ukernel_mmt4d_type_t type;
+ uint32_t flags;
+ const void* lhs_buffer;
+ const void* rhs_buffer;
+ void* out_buffer;
+ iree_ukernel_size_t lhs_stride;
+ iree_ukernel_size_t rhs_stride;
+ iree_ukernel_size_t out_stride;
+ iree_ukernel_size_t M;
+ iree_ukernel_size_t N;
+ iree_ukernel_size_t K;
+ int32_t M0;
+ int32_t N0;
+ int32_t K0;
+ uint64_t cpu_data_field_0;
+};
+
+typedef struct iree_ukernel_mmt4d_params_t iree_ukernel_mmt4d_params_t;
+
+// Status codes returned by a mmt4d operation.
+enum iree_ukernel_mmt4d_status_t {
+ iree_ukernel_mmt4d_status_ok = 0,
+ iree_ukernel_mmt4d_status_bad_type,
+ iree_ukernel_mmt4d_status_bad_flags,
+ iree_ukernel_mmt4d_status_unsupported_huge_or_negative_dimension,
+ iree_ukernel_mmt4d_status_unsupported_generic_tile_size,
+};
+
+typedef enum iree_ukernel_mmt4d_status_t iree_ukernel_mmt4d_status_t;
+
+// TODO: move these flags to a header file shared with compiler/.
+#define IREE_VMVX_MATMUL_FLAG_ACCUMULATE 1
+
+#define IREE_UKERNEL_MMT4D_RETURN_IF_ERROR(X) \
+ do { \
+ iree_ukernel_mmt4d_status_t status = (X); \
+ if (status != iree_ukernel_mmt4d_status_ok) { \
+ return status; \
+ } \
+ } while (0)
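+// (The do/while(0) wrapper makes the macro expand to a single statement, so
+// it composes safely with unbraced if/else.)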
+
+// Function pointer type for tile functions: typically architecture-specific
+// functions computing one M0xN0 tile of the output matrix, i.e. the inner-most
+// loop of the matmul - the thing we would call the "micro kernel" if that name
+// weren't already taken by the higher-level builtin.
+//
+// The 'params' argument is only used by generic kernels. Actual optimized
+// kernels are already specialized for a given tile shape (M0xN0xK0), so the
+// first five arguments here are the only information that they need. Not
+// having to address 'params' struct fields in the middle of assembly kernels
+// is good, because it's hard to get the struct field offsets right in assembly
+// and keep that in sync with future struct changes.
+typedef void (*iree_ukernel_mmt4d_tile_func_t)(
+ void* /*out_tile*/, const void* /*lhs_panel*/, const void* /*rhs_panel*/,
+ int32_t /*K*/, uint32_t /*flags*/,
+ const iree_ukernel_mmt4d_params_t* /*params*/);
+
+// Tile kernel declarations. Prototype matches iree_ukernel_mmt4d_tile_func_t.
+#define IREE_UKERNEL_MMT4D_TILE_FUNC_DECL(NAME) \
+ void NAME(void* out_tile, const void* lhs_panel, const void* rhs_panel, \
+ int32_t K, uint32_t flags, \
+ const iree_ukernel_mmt4d_params_t* params);
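+// For example (hypothetical kernel name, purely for illustration):
+//   IREE_UKERNEL_MMT4D_TILE_FUNC_DECL(
+//       iree_ukernel_mmt4d_tile_i8i8i32_8x8x4_arm_64_dotprod)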
+
+// Log2 of size of LHS matrix element type, e.g. f32 --> size=4 --> log2=2
+static inline int iree_ukernel_mmt4d_lhs_elem_size_log2(
+ iree_ukernel_mmt4d_type_t type) {
+ switch (type) {
+ case iree_ukernel_mmt4d_type_f32f32f32:
+ return 2;
+ default:
+ return 0;
+ }
+}
+
+static inline int iree_ukernel_mmt4d_lhs_elem_size(
+ iree_ukernel_mmt4d_type_t type) {
+ return 1 << iree_ukernel_mmt4d_lhs_elem_size_log2(type);
+}
+
+// Log2 of size of RHS matrix element type, e.g. f32 --> size=4 --> log2=2
+static inline int iree_ukernel_mmt4d_rhs_elem_size_log2(
+ iree_ukernel_mmt4d_type_t type) {
+ return iree_ukernel_mmt4d_lhs_elem_size_log2(type); // for now it's the same
+}
+
+static inline int iree_ukernel_mmt4d_rhs_elem_size(
+ iree_ukernel_mmt4d_type_t type) {
+ return 1 << iree_ukernel_mmt4d_rhs_elem_size_log2(type);
+}
+
+// Log2 of size of OUT matrix element type, e.g. f32 --> size=4 --> log2=2
+static inline int iree_ukernel_mmt4d_out_elem_size_log2(
+ iree_ukernel_mmt4d_type_t type) {
+ switch (type) {
+ case iree_ukernel_mmt4d_type_f32f32f32:
+ case iree_ukernel_mmt4d_type_i8i8i32:
+ return 2;
+ default:
+ return 0;
+ }
+}
+
+static inline int iree_ukernel_mmt4d_out_elem_size(
+ iree_ukernel_mmt4d_type_t type) {
+ return 1 << iree_ukernel_mmt4d_out_elem_size_log2(type);
+}
+
+#endif // IREE_BUILTINS_UKERNEL_MMT4D_TYPES_H_
diff --git a/runtime/src/iree/builtins/ukernel/tools/BUILD b/runtime/src/iree/builtins/ukernel/tools/BUILD
index ac1940a..21c4c87 100644
--- a/runtime/src/iree/builtins/ukernel/tools/BUILD
+++ b/runtime/src/iree/builtins/ukernel/tools/BUILD
@@ -13,11 +13,24 @@
licenses = ["notice"], # Apache 2.0
)
+cc_library(
+ name = "mmt4d_test_utils",
+ srcs = ["mmt4d_test_utils.cc"],
+ hdrs = ["mmt4d_test_utils.h"],
+ deps = [
+ "//runtime/src/iree/base",
+ "//runtime/src/iree/builtins/ukernel:types",
+ "//runtime/src/iree/schemas:cpu_data",
+ ],
+)
+
cc_binary_benchmark(
name = "mmt4d_benchmark",
srcs = ["mmt4d_benchmark.c"],
deps = [
+ ":mmt4d_test_utils",
"//runtime/src/iree/base",
+ "//runtime/src/iree/base/internal:cpu",
"//runtime/src/iree/base/internal:flags",
"//runtime/src/iree/builtins/ukernel",
"//runtime/src/iree/testing:benchmark",
@@ -28,10 +41,11 @@
name = "mmt4d_test",
srcs = ["mmt4d_test.cc"],
deps = [
+ ":mmt4d_test_utils",
"//runtime/src/iree/base",
+ "//runtime/src/iree/base/internal:cpu",
"//runtime/src/iree/base/internal:flags",
"//runtime/src/iree/builtins/ukernel",
"//runtime/src/iree/testing:gtest",
- "//runtime/src/iree/testing:gtest_main",
],
)
diff --git a/runtime/src/iree/builtins/ukernel/tools/CMakeLists.txt b/runtime/src/iree/builtins/ukernel/tools/CMakeLists.txt
index 3e8f6d4..4b4e455 100644
--- a/runtime/src/iree/builtins/ukernel/tools/CMakeLists.txt
+++ b/runtime/src/iree/builtins/ukernel/tools/CMakeLists.txt
@@ -10,13 +10,29 @@
iree_add_all_subdirs()
+iree_cc_library(
+ NAME
+ mmt4d_test_utils
+ HDRS
+ "mmt4d_test_utils.h"
+ SRCS
+ "mmt4d_test_utils.cc"
+ DEPS
+ iree::base
+ iree::builtins::ukernel::types
+ iree::schemas::cpu_data
+ PUBLIC
+)
+
iree_cc_binary_benchmark(
NAME
mmt4d_benchmark
SRCS
"mmt4d_benchmark.c"
DEPS
+ ::mmt4d_test_utils
iree::base
+ iree::base::internal::cpu
iree::base::internal::flags
iree::builtins::ukernel
iree::testing::benchmark
@@ -29,11 +45,12 @@
SRCS
"mmt4d_test.cc"
DEPS
+ ::mmt4d_test_utils
iree::base
+ iree::base::internal::cpu
iree::base::internal::flags
iree::builtins::ukernel
iree::testing::gtest
- iree::testing::gtest_main
)
### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/runtime/src/iree/builtins/ukernel/tools/mmt4d_benchmark.c b/runtime/src/iree/builtins/ukernel/tools/mmt4d_benchmark.c
index ee16afd..60f4826 100644
--- a/runtime/src/iree/builtins/ukernel/tools/mmt4d_benchmark.c
+++ b/runtime/src/iree/builtins/ukernel/tools/mmt4d_benchmark.c
@@ -4,39 +4,160 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-// THIS IS STILL JUST A PLACEHOLDER - NOT AN ACTUAL TEST YET.
+// clang-format off
+#include <stdint.h> // include before ukernel/common.h to keep standard types
+// clang-format on
-#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include "iree/base/api.h"
+#include "iree/base/internal/cpu.h"
#include "iree/base/internal/flags.h"
#include "iree/builtins/ukernel/mmt4d.h"
+#include "iree/builtins/ukernel/tools/mmt4d_test_utils.h"
#include "iree/testing/benchmark.h"
-// Example flag; not really useful:
-IREE_FLAG(int32_t, batch_count, 64, "Ops to run per benchmark iteration.");
+IREE_FLAG(int32_t, batch_count, 1000, "Ops to run per benchmark iteration.");
+IREE_FLAG(int32_t, m_size, 1,
+ "M-dimension of mmt4d ops. The overall number of rows of the "
+ "accumulator is that times the M0 tile size.");
+IREE_FLAG(int32_t, n_size, 1,
+ "N-dimension of mmt4d ops. The overall number of columns of the "
+ "accumulator is that times the N0 tile size.");
+IREE_FLAG(
+ int32_t, k_size, 256,
+ "K-dimension of mmt4d ops. That's the number of iterations of the inner "
+ "loop. The overall accumulation depth is that times the K0 tile size.");
+IREE_FLAG(bool, accumulate, false,
+ "Whether the kernel should accumulate into the existing accumulator "
+ "tile values, or zero the accumulator tile.");
-static iree_status_t iree_mmt4d_example_matmul_f32_benchmark(
+struct iree_mmt4d_benchmark_user_data_t {
+ iree_ukernel_mmt4d_type_t type;
+ int M0;
+ int N0;
+ int K0;
+ uint64_t cpu_data_field_0;
+};
+
+typedef struct iree_mmt4d_benchmark_user_data_t
+ iree_mmt4d_benchmark_user_data_t;
+
+static iree_status_t iree_mmt4d_benchmark(
const iree_benchmark_def_t* benchmark_def,
iree_benchmark_state_t* benchmark_state) {
+ const iree_mmt4d_benchmark_user_data_t* user_data = benchmark_def->user_data;
+ iree_ukernel_mmt4d_params_t params;
+ memset(¶ms, 0, sizeof params);
+ params.type = user_data->type;
+ params.flags = FLAG_accumulate ? IREE_VMVX_MATMUL_FLAG_ACCUMULATE : 0;
+ params.M = FLAG_m_size;
+ params.N = FLAG_n_size;
+ params.K = FLAG_k_size;
+ params.M0 = user_data->M0;
+ params.N0 = user_data->N0;
+ params.K0 = user_data->K0;
+ params.cpu_data_field_0 = user_data->cpu_data_field_0;
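+ // Tight packing: strides are in elements and span one row of tiles. E.g.
+ // with K=256, M0=8, K0=4, the LHS stride is 256*8*4 = 8192 elements.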
+ params.lhs_stride = params.K * params.M0 * params.K0;
+ params.rhs_stride = params.K * params.N0 * params.K0;
+ params.out_stride = params.N * params.M0 * params.N0;
+ iree_ukernel_size_t lhs_buffer_size =
+ iree_ukernel_mmt4d_lhs_buffer_size(¶ms);
+ iree_ukernel_size_t rhs_buffer_size =
+ iree_ukernel_mmt4d_rhs_buffer_size(¶ms);
+ iree_ukernel_size_t out_buffer_size =
+ iree_ukernel_mmt4d_out_buffer_size(¶ms);
+ void* lhs_buffer = malloc(lhs_buffer_size);
+ void* rhs_buffer = malloc(rhs_buffer_size);
+ void* out_buffer = malloc(out_buffer_size);
+ iree_mmt4d_scalar_type_t lhs_type = iree_ukernel_mmt4d_lhs_type(¶ms);
+ iree_mmt4d_scalar_type_t rhs_type = iree_ukernel_mmt4d_rhs_type(¶ms);
+ iree_mmt4d_scalar_type_t out_type = iree_ukernel_mmt4d_out_type(¶ms);
+ iree_mmt4d_test_random_engine_t* engine =
+ iree_mmt4d_test_random_engine_create();
+ // It's just about plausible that on some platform, for some number type,
+ // performance might be different on zero buffers vs random buffers. But it
+ // shouldn't matter that we recreate the random engine every time, getting
+ // the same random values again.
+ write_random_buffer(lhs_buffer, lhs_buffer_size, lhs_type, engine);
+ write_random_buffer(rhs_buffer, rhs_buffer_size, rhs_type, engine);
+ write_random_buffer(out_buffer, out_buffer_size, out_type, engine);
+ iree_mmt4d_test_random_engine_destroy(engine);
+ params.lhs_buffer = lhs_buffer;
+ params.rhs_buffer = rhs_buffer;
+ params.out_buffer = out_buffer;
+ int64_t total_iterations = 0;
while (iree_benchmark_keep_running(benchmark_state,
/*batch_count=*/FLAG_batch_count)) {
for (int i = 0; i < FLAG_batch_count; ++i) {
- iree_ukernel_mmt4d_f32f32f32_params_t params;
- memset(¶ms, 0, sizeof params);
- int ukernel_retcode = iree_ukernel_mmt4d_f32f32f32(¶ms);
- if (0 != iree_ukernel_mmt4d_f32f32f32(¶ms)) {
- fprintf(stderr, "FATAL: iree_ukernel_mmt4d_f32f32f32 failed: %s\n",
- iree_ukernel_mmt4d_error_message(ukernel_retcode));
+ iree_ukernel_mmt4d_status_t status = iree_ukernel_mmt4d(¶ms);
+ if (status != iree_ukernel_mmt4d_status_ok) {
+ fprintf(stderr, "FATAL: iree_ukernel_mmt4d failed: %s\n",
+ iree_ukernel_mmt4d_status_message(status));
abort();
}
}
+ total_iterations += FLAG_batch_count;
}
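+ // Count 2 ops per multiply-accumulate (one mul, one add), the usual
+ // convention for matmul throughput.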
+ iree_benchmark_set_items_processed(
+ benchmark_state, total_iterations * 2 * params.M * params.N * params.K *
+ params.M0 * params.N0 * params.K0);
+ free(lhs_buffer);
+ free(rhs_buffer);
+ free(out_buffer);
return iree_ok_status();
}
+static void iree_mmt4d_benchmark_register(
+ const iree_mmt4d_benchmark_user_data_t* user_data, const char* name) {
+ // Does this benchmark require an optional CPU feature?
+ if (user_data->cpu_data_field_0) {
+ if ((iree_cpu_data_field(0) & user_data->cpu_data_field_0) !=
+ user_data->cpu_data_field_0) {
+ // The CPU does not meet this benchmark's requirements. The builtin
+ // would fall back on generic code. We don't need more generic benchmark
+ // results.
+ return;
+ }
+ }
+
+ // benchmark_def does not need to be static; it will be cloned.
+ const iree_benchmark_def_t benchmark_def = {
+ .flags = IREE_BENCHMARK_FLAG_USE_REAL_TIME,
+ .time_unit = IREE_BENCHMARK_UNIT_MICROSECOND,
+ .minimum_duration_ns = 0,
+ .iteration_count = 0,
+ .run = iree_mmt4d_benchmark,
+ .user_data = user_data,
+ };
+ iree_benchmark_register(IREE_SV(name), &benchmark_def);
+}
+
+#define IREE_MMT4D_BENCHMARK_REGISTER(_type, _m0, _n0, _k0, _cpu_data_field_0, \
+ _label) \
+ do { \
+ static const iree_mmt4d_benchmark_user_data_t user_data = { \
+ .type = iree_ukernel_mmt4d_type_##_type, \
+ .M0 = _m0, \
+ .N0 = _n0, \
+ .K0 = _k0, \
+ .cpu_data_field_0 = _cpu_data_field_0, \
+ }; \
+ iree_mmt4d_benchmark_register(&user_data, \
+ "iree_ukernel_mmt4d_" #_type "_" #_m0 \
+ "x" #_n0 "x" #_k0 "_" #_label); \
+ } while (0)
+
+#define IREE_MMT4D_BENCHMARK_REGISTER_GENERIC(_type, _m0, _n0, _k0) \
+ IREE_MMT4D_BENCHMARK_REGISTER(_type, _m0, _n0, _k0, 0, GENERIC)
+
+#define IREE_MMT4D_BENCHMARK_REGISTER_ARM_64(_type, _m0, _n0, _k0, \
+ _cpu_feature) \
+ IREE_MMT4D_BENCHMARK_REGISTER( \
+ _type, _m0, _n0, _k0, IREE_CPU_DATA_FIELD_0_AARCH64_HAVE_##_cpu_feature, \
+ arm_64_##_cpu_feature)
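+
+// For example, IREE_MMT4D_BENCHMARK_REGISTER_ARM_64(i8i8i32, 8, 8, 4, DOTPROD)
+// registers "iree_ukernel_mmt4d_i8i8i32_8x8x4_arm_64_DOTPROD", skipped at
+// runtime when the CPU lacks the dotprod feature.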
+
int main(int argc, char** argv) {
iree_flags_set_usage(
"mmt4d_benchmark",
@@ -45,22 +166,21 @@
iree_flags_parse_checked(IREE_FLAGS_PARSE_MODE_UNDEFINED_OK, &argc, &argv);
iree_benchmark_initialize(&argc, argv);
+ iree_cpu_initialize(iree_allocator_system());
- // TODO: always add _generic variants to have a baseline vs reference?
+ // Generic code paths - not what production ends up running on CPUs that
+ // have optimized kernels, but interesting to get a sense of how slow
+ // generic code goes vs decent SIMD kernels, and to compare generic float vs
+ // int arithmetic.
+ IREE_MMT4D_BENCHMARK_REGISTER_GENERIC(f32f32f32, 4, 4, 1);
+ IREE_MMT4D_BENCHMARK_REGISTER_GENERIC(i8i8i32, 4, 4, 1);
- {
- static const iree_benchmark_def_t benchmark_def = {
- .flags = IREE_BENCHMARK_FLAG_MEASURE_PROCESS_CPU_TIME |
- IREE_BENCHMARK_FLAG_USE_REAL_TIME,
- .time_unit = IREE_BENCHMARK_UNIT_NANOSECOND,
- .minimum_duration_ns = 0,
- .iteration_count = 0,
- .run = iree_mmt4d_example_matmul_f32_benchmark,
- .user_data = NULL,
- };
- iree_benchmark_register(IREE_SV("iree_mmt4d_example_matmul_f32"),
- &benchmark_def);
- }
+// ARM_64 benchmarks.
+#if defined(IREE_UKERNEL_ARCH_ARM_64)
+
+ IREE_MMT4D_BENCHMARK_REGISTER_ARM_64(i8i8i32, 8, 8, 4, DOTPROD);
+ IREE_MMT4D_BENCHMARK_REGISTER_ARM_64(i8i8i32, 8, 8, 8, I8MM);
+
+#endif // defined(IREE_UKERNEL_ARCH_ARM_64)
iree_benchmark_run_specified();
return 0;
diff --git a/runtime/src/iree/builtins/ukernel/tools/mmt4d_test.cc b/runtime/src/iree/builtins/ukernel/tools/mmt4d_test.cc
index c3094e2..0dab85e 100644
--- a/runtime/src/iree/builtins/ukernel/tools/mmt4d_test.cc
+++ b/runtime/src/iree/builtins/ukernel/tools/mmt4d_test.cc
@@ -4,22 +4,301 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-// THIS IS STILL JUST A PLACEHOLDER - NOT AN ACTUAL TEST YET.
+// Design rationale and code creep warning!
+//
+// Summary:
+//
+// The goal of this test is to provide 100% coverage across all
+// internal kernel variants, which is not convenient to do in e2e tests.
+// Resist the temptation to reimplement here all the niceties of the e2e test.
+// Stick to guaranteeing that if the test succeeds, then the mmt4d builtin,
+// with all its asm code path variants, is correct. In case of failure, the
+// user is expected to be happy to jump into a debugger.
+//
+// Longer story:
+//
+// It is said by an ancient prophecy that all matrix multiplication tests grow
+// to be thousands of lines of code.
+//
+// In fact, we already have one, it's the end-to-end matmul test under
+// iree/tests/e2e/matmul. That one is needed anyway, and needs to be large
+// anyway, being end-to-end and applying to all target backends, including those
+// where device!=host. And so it makes sense for that one to have extra bells
+// and whistles such as fuzzy comparisons, pretty-printing of numerical errors
+// to aid debugging, and yet more special logic to make numerical errors easier
+// to debug.
+//
+// Let's not duplicate all that here! Note also that, tempting as it would
+// be to borrow the matrix-pretty-printing stuff from e2e/matmul, that applies
+// to plain row-major 2D matrices, while here we are dealing with 4D arrays /
+// tiled-layout matrices. Trying to bridge over that difference would bring yet
+// more complexity.
+//
+// Instead, let us keep a sharp focus on why we need this separate micro test.
+// The motivation is not the usual "because micro tests are easier to debug than
+// e2e" but rather because it would be difficult to have 100% code coverage in
+// e2e. There are many variants of mmt4d builtin ukernels for various CPU
+// features and tuned for various CPU models. We have to iterate over all these
+// variants. Trying to do so in e2e tests would require exposing knobs for
+// things that we would otherwise prefer to keep internal in the mmt4d builtin
+// implementation, and would make e2e/matmul tests even more expensive.
-#include <stdint.h>
+// clang-format off
+#include <stdint.h> // include before ukernel/common.h to keep standard types
+// clang-format on
-// Include in expected order with stdint and other system headers first.
-// See the note in mmt4d.h about stdint.h. This won't be an issue in most uses
-// but clang-format really likes to put the mmt4d.h above the system headers
-// due to this _test.cc file naming.
+#include "iree/builtins/ukernel/mmt4d.h"
+
+#include <vector>
#include "iree/base/api.h"
-#include "iree/builtins/ukernel/mmt4d.h"
+#include "iree/base/internal/cpu.h"
+#include "iree/builtins/ukernel/tools/mmt4d_test_utils.h"
#include "iree/testing/gtest.h"
#include "iree/testing/status_matchers.h"
-TEST(MMT4DTest, iree_mmt4d_example_matmul_f32) {
- iree_ukernel_mmt4d_f32f32f32_params_t params;
+template <typename lhs_t, typename rhs_t, typename out_t>
+static void iree_mmt4d_reference(const iree_ukernel_mmt4d_params_t& params) {
+ bool accumulate = params.flags & IREE_VMVX_MATMUL_FLAG_ACCUMULATE;
+ iree_ukernel_size_t lhs_tile_size = params.M0 * params.K0;
+ iree_ukernel_size_t rhs_tile_size = params.N0 * params.K0;
+ iree_ukernel_size_t out_tile_size = params.M0 * params.N0;
+ for (iree_ukernel_size_t i = 0; i < params.M; ++i) {
+ for (iree_ukernel_size_t j = 0; j < params.N; ++j) {
+ out_t* out_tile_ptr = ((out_t*)params.out_buffer) +
+ i * params.out_stride + j * out_tile_size;
+ const lhs_t* lhs_panel_ptr =
+ ((const lhs_t*)params.lhs_buffer) + i * params.lhs_stride;
+ const rhs_t* rhs_panel_ptr =
+ ((const rhs_t*)params.rhs_buffer) + j * params.rhs_stride;
+ for (iree_ukernel_size_t i0 = 0; i0 < params.M0; ++i0) {
+ for (iree_ukernel_size_t j0 = 0; j0 < params.N0; ++j0) {
+ const lhs_t* lhs_tile_ptr = lhs_panel_ptr;
+ const rhs_t* rhs_tile_ptr = rhs_panel_ptr;
+ out_t* out_ptr = out_tile_ptr + i0 * params.N0 + j0;
+ out_t acc = accumulate ? *out_ptr : out_t(0);
+ for (iree_ukernel_size_t k = 0; k < params.K; ++k) {
+ for (iree_ukernel_size_t k0 = 0; k0 < params.K0; ++k0) {
+ out_t lhs_val = lhs_tile_ptr[i0 * params.K0 + k0];
+ out_t rhs_val = rhs_tile_ptr[j0 * params.K0 + k0];
+ acc += lhs_val * rhs_val;
+ }
+ lhs_tile_ptr += lhs_tile_size;
+ rhs_tile_ptr += rhs_tile_size;
+ }
+ *out_ptr = acc;
+ }
+ }
+ }
+ }
+}
+
+static void iree_mmt4d_reference(const iree_ukernel_mmt4d_params_t& params) {
+ switch (params.type) {
+ case iree_ukernel_mmt4d_type_f32f32f32:
+ iree_mmt4d_reference<float, float, float>(params);
+ break;
+ case iree_ukernel_mmt4d_type_i8i8i32:
+ iree_mmt4d_reference<int8_t, int8_t, int32_t>(params);
+ break;
+ default:
+ assert(false && "unknown type");
+ }
+}
+
+static void test_one_matmul_using_given_lhs_rhs(
+ const iree_ukernel_mmt4d_params_t& shared_params,
+ iree_mmt4d_test_random_engine_t* engine) {
+ assert(!shared_params.out_buffer);
+
+ iree_ukernel_mmt4d_params_t reference_params;
+ memcpy(&reference_params, &shared_params, sizeof shared_params);
+ iree_ukernel_size_t out_buffer_size =
+ iree_ukernel_mmt4d_out_buffer_size(&shared_params);
+ reference_params.out_buffer = malloc(out_buffer_size);
+ iree_mmt4d_scalar_type_t out_type =
+ iree_ukernel_mmt4d_out_type(&shared_params);
+ write_random_buffer(reference_params.out_buffer, out_buffer_size, out_type,
+ engine);
+
+ iree_ukernel_mmt4d_params_t actual_params;
+ memcpy(&actual_params, &shared_params, sizeof shared_params);
+ actual_params.out_buffer = malloc(out_buffer_size);
+ memcpy(actual_params.out_buffer, reference_params.out_buffer,
+ out_buffer_size);
+
+ iree_mmt4d_reference(reference_params);
+ iree_ukernel_mmt4d_status_t status = iree_ukernel_mmt4d(&actual_params);
+ if (status != iree_ukernel_mmt4d_status_ok) {
+ fprintf(stderr, "FATAL: iree_ukernel_mmt4d failed: %s\n",
+ iree_ukernel_mmt4d_status_message(status));
+ abort();
+ }
+
+ // For now we use exact comparisons, even for float, even though the reference
+ // code accumulates in a different order compared to the actual code. This
+ // relies on picking input test matrix elements so that all intermediate
+ // values are exactly representable - i.e. small integers (values in
+ // [-16, 15], whose products and partial sums stay well below 2^24, where
+ // float32 starts losing exactness). This will become problematic when we do
+ // float16. See the comment at the top of this file explaining how we refrain
+ // from letting this grow into a 1000-line-long fully-featured test.
+ if (memcmp(actual_params.out_buffer, reference_params.out_buffer,
+ out_buffer_size)) {
+ const auto& p = actual_params;
+ fprintf(stderr, "mmt4d test failure with the following params:\n");
+ fprintf(stderr, " type=%s\n", get_mmt4d_type_str(&p));
+ fprintf(stderr, " flags: accumulate=%d\n",
+ (int)(p.flags & IREE_VMVX_MATMUL_FLAG_ACCUMULATE));
+ fprintf(stderr, " M=%d, N=%d, K=%d\n", (int)p.M, (int)p.N, (int)p.K);
+ fprintf(stderr, " M0=%d, N0=%d, K0=%d\n", (int)p.M0, (int)p.N0, (int)p.K0);
+ fprintf(stderr, " lhs_stride=%zu, rhs_stride=%zu, out_stride=%zu\n",
+ (size_t)p.lhs_stride, (size_t)p.rhs_stride, (size_t)p.out_stride);
+ fprintf(stderr, " cpu features: %s\n", get_cpu_features_str(&p));
+ // Don't even try to pretty-print matrices. See the comment at the top of
+ // this file. Don't try to use GTest primitives to show expected vs actual
+ // since that would require dispatching to type-specific code paths.
+ // Also, at this point it's easy for the user to rerun this test
+ // in a debugger and manually inspect values.
+ //
+ // We want fatal here - that is what the user running this in a debugger
+ // wants us to do, so they can inspect values while they exist in memory.
+ // What's the GTest-sanctioned fatal error? GTEST_FAIL() has a comment that
+ // says that it's fatal, but that's a lie at least here on Android.
+ abort();
+ }
+
+ free(reference_params.out_buffer);
+ free(actual_params.out_buffer);
+}
+
+static void test_one_matmul_creating_lhs_rhs_for_given_shape(
+ const iree_ukernel_mmt4d_params_t& shared_params,
+ iree_mmt4d_test_random_engine_t* engine) {
+ iree_ukernel_mmt4d_params_t params;
+ memcpy(¶ms, &shared_params, sizeof params);
+ assert(!params.lhs_buffer);
+ assert(!params.rhs_buffer);
+ assert(!params.out_buffer);
+ assert(!params.lhs_stride);
+ assert(!params.rhs_stride);
+ assert(!params.out_stride);
+ // Populate strides first - they are read by the *_buffer_size helpers.
+ // Randomly make strides either tight or slightly padded, to exercise both
+ // cases.
+ params.lhs_stride = params.K * params.M0 * params.K0 +
+ iree_mmt4d_test_random_engine_get_0_or_1(engine);
+ params.rhs_stride = params.K * params.N0 * params.K0 +
+ iree_mmt4d_test_random_engine_get_0_or_1(engine);
+ params.out_stride = params.N * params.M0 * params.N0 +
+ iree_mmt4d_test_random_engine_get_0_or_1(engine);
+ iree_ukernel_size_t lhs_buffer_size =
+ iree_ukernel_mmt4d_lhs_buffer_size(¶ms);
+ iree_ukernel_size_t rhs_buffer_size =
+ iree_ukernel_mmt4d_rhs_buffer_size(¶ms);
+ iree_mmt4d_scalar_type_t lhs_type = iree_ukernel_mmt4d_lhs_type(¶ms);
+ iree_mmt4d_scalar_type_t rhs_type = iree_ukernel_mmt4d_rhs_type(¶ms);
+ void* lhs_buffer = malloc(lhs_buffer_size);
+ void* rhs_buffer = malloc(rhs_buffer_size);
+ write_random_buffer(lhs_buffer, lhs_buffer_size, lhs_type, engine);
+ write_random_buffer(rhs_buffer, rhs_buffer_size, rhs_type, engine);
+ params.lhs_buffer = lhs_buffer;
+ params.rhs_buffer = rhs_buffer;
+ test_one_matmul_using_given_lhs_rhs(params, engine);
+ free(lhs_buffer);
+ free(rhs_buffer);
+}
+
+static void test_matmuls_for_various_MNK_shapes_and_flags(
+ const iree_ukernel_mmt4d_params_t& shared_params,
+ iree_mmt4d_test_random_engine_t* engine) {
+ iree_ukernel_mmt4d_params_t params;
+ memcpy(¶ms, &shared_params, sizeof params);
+ assert(params.M == 0);
+ assert(params.N == 0);
+ assert(params.K == 0);
+ assert(params.flags == 0);
+ struct shape_mnk_t {
+ int m, n, k;
+ };
+ std::vector<shape_mnk_t> shapes{
+ {1, 1, 1}, {1, 1, 2}, {1, 1, 10}, {1, 1, 1000},
+ {2, 1, 1}, {1, 2, 1}, {2, 2, 2}, {5, 7, 13},
+ };
+ for (shape_mnk_t shape : shapes) {
+ params.M = shape.m;
+ params.N = shape.n;
+ params.K = shape.k;
+ for (bool accumulate : {false, true}) {
+ params.flags = accumulate ? IREE_VMVX_MATMUL_FLAG_ACCUMULATE : 0;
+ test_one_matmul_creating_lhs_rhs_for_given_shape(params, engine);
+ }
+ }
+}
+
+// Tests mmt4d with the specific data type and specific M0xN0xK0 tile format.
+// If cpu_data_field_0_bit is nonzero, it must then be a single bit (power of 2)
+// and if the CPU supports the corresponding feature, the mmt4d tests are run a
+// second time with that CPU feature enabled.
+static void mmt4d_test(iree_ukernel_mmt4d_type_t type, int M0, int N0, int K0,
+ uint64_t cpu_data_field_0_bit) {
+ // Letting each test create its own engine makes them independent: a testcase
+ // succeeds or fails the same way if we isolate it or reorder it. The
+ // potential downside of repeating the same pseudorandom sequence is OK
+ // because any pseudorandom sequence should be equally good at coverage, and
+ // different testcases tend to use different tile shapes anyway.
+ iree_mmt4d_test_random_engine_t* engine =
+ iree_mmt4d_test_random_engine_create();
+ iree_ukernel_mmt4d_params_t params;
memset(¶ms, 0, sizeof params);
- EXPECT_EQ(0, iree_ukernel_mmt4d_f32f32f32(¶ms));
+ params.type = type;
+ params.M0 = M0;
+ params.N0 = N0;
+ params.K0 = K0;
+ // First try without any optional CPU feature. This matters even when the
+ // feature is supported by the CPU because we want to test the fallback to
+ // architecture-default or generic code.
+ test_matmuls_for_various_MNK_shapes_and_flags(params, engine);
+ // If this is nonzero, we are asked to test again with this CPU feature.
+ if (cpu_data_field_0_bit) {
+ // Check if the CPU supports the feature; executing instructions that the
+ // CPU does not support would crash.
+ params.cpu_data_field_0 = cpu_data_field_0_bit;
+ bool supported = iree_cpu_data_field(0) & params.cpu_data_field_0;
+ if (supported) {
+ // Run with the optional CPU feature.
+ fprintf(stderr, "Device supports CPU feature: %s\n",
+ get_cpu_features_str(¶ms));
+ test_matmuls_for_various_MNK_shapes_and_flags(params, engine);
+ } else {
+ fprintf(stderr, "Skipped: device does not support CPU feature: %s\n",
+ get_cpu_features_str(¶ms));
+ }
+ }
+ iree_mmt4d_test_random_engine_destroy(engine);
+}
+
+#define MMT4D_TEST(type, M0, N0, K0, test_suffix, feature_bit) \
+ TEST(Mmt4dTest, type##_tile_##M0##x##N0##x##K0##_##test_suffix) { \
+ mmt4d_test(iree_ukernel_mmt4d_type_##type, M0, N0, K0, feature_bit); \
+ }
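+
+// For example, MMT4D_TEST(f32f32f32, 3, 5, 7, generic, 0) below defines
+// TEST(Mmt4dTest, f32f32f32_tile_3x5x7_generic), exercising that tile shape
+// with no required CPU feature bit.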
+
+// Generic tests, not matching any particular CPU feature. This is the place to
+// test weird M0, N0, K0 to ensure e.g. that we haven't unwittingly baked in a
+// power-of-two assumption.
+MMT4D_TEST(f32f32f32, 3, 5, 7, generic, 0)
+MMT4D_TEST(i8i8i32, 9, 6, 3, generic, 0)
+
+// ARM_64 tests.
+#if defined(IREE_UKERNEL_ARCH_ARM_64)
+
+#define MMT4D_ARM_64_TEST(type, M0, N0, K0, FEATURE) \
+ MMT4D_TEST(type, M0, N0, K0, arm_64_##FEATURE, \
+ IREE_CPU_DATA_FIELD_0_AARCH64_HAVE_##FEATURE)
+
+MMT4D_ARM_64_TEST(i8i8i32, 8, 8, 4, DOTPROD)
+MMT4D_ARM_64_TEST(i8i8i32, 8, 8, 8, I8MM)
+#endif // defined(IREE_UKERNEL_ARCH_ARM_64)
+
+int main(int argc, char** argv) {
+ ::testing::InitGoogleTest(&argc, argv);
+ iree_cpu_initialize(iree_allocator_system());
+ return RUN_ALL_TESTS();
}
diff --git a/runtime/src/iree/builtins/ukernel/tools/mmt4d_test_utils.cc b/runtime/src/iree/builtins/ukernel/tools/mmt4d_test_utils.cc
new file mode 100644
index 0000000..0a9f970
--- /dev/null
+++ b/runtime/src/iree/builtins/ukernel/tools/mmt4d_test_utils.cc
@@ -0,0 +1,162 @@
+// Copyright 2022 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/builtins/ukernel/tools/mmt4d_test_utils.h"
+
+#include <cassert>
+#include <random>
+
+#include "iree/schemas/cpu_data.h"
+
+iree_mmt4d_scalar_type_t iree_ukernel_mmt4d_lhs_type(
+ const iree_ukernel_mmt4d_params_t* params) {
+ switch (params->type) {
+ case iree_ukernel_mmt4d_type_f32f32f32:
+ return iree_mmt4d_scalar_type_f32;
+ case iree_ukernel_mmt4d_type_i8i8i32:
+ return iree_mmt4d_scalar_type_i8;
+ default:
+ assert(false && "unknown type");
+ return iree_mmt4d_scalar_type_unknown;
+ }
+}
+
+iree_mmt4d_scalar_type_t iree_ukernel_mmt4d_rhs_type(
+ const iree_ukernel_mmt4d_params_t* params) {
+ // same for now
+ return iree_ukernel_mmt4d_lhs_type(params);
+}
+
+iree_mmt4d_scalar_type_t iree_ukernel_mmt4d_out_type(
+ const iree_ukernel_mmt4d_params_t* params) {
+ switch (params->type) {
+ case iree_ukernel_mmt4d_type_f32f32f32:
+ return iree_mmt4d_scalar_type_f32;
+ case iree_ukernel_mmt4d_type_i8i8i32:
+ return iree_mmt4d_scalar_type_i32;
+ default:
+ assert(false && "unknown type");
+ return iree_mmt4d_scalar_type_unknown;
+ }
+}
+
+iree_ukernel_size_t iree_ukernel_mmt4d_lhs_buffer_size(
+ const iree_ukernel_mmt4d_params_t* params) {
+ return params->M * params->lhs_stride *
+ iree_ukernel_mmt4d_lhs_elem_size(params->type);
+}
+
+iree_ukernel_size_t iree_ukernel_mmt4d_rhs_buffer_size(
+ const iree_ukernel_mmt4d_params_t* params) {
+ return params->N * params->rhs_stride *
+ iree_ukernel_mmt4d_rhs_elem_size(params->type);
+}
+
+iree_ukernel_size_t iree_ukernel_mmt4d_out_buffer_size(
+ const iree_ukernel_mmt4d_params_t* params) {
+ return params->M * params->out_stride *
+ iree_ukernel_mmt4d_out_elem_size(params->type);
+}
+
+struct iree_mmt4d_test_random_engine_t {
+ std::minstd_rand cpp_random_engine;
+};
+
+iree_mmt4d_test_random_engine_t* iree_mmt4d_test_random_engine_create() {
+ return new iree_mmt4d_test_random_engine_t;
+}
+
+void iree_mmt4d_test_random_engine_destroy(iree_mmt4d_test_random_engine_t* e) {
+ delete e;
+}
+
+static int iree_mmt4d_test_random_engine_get_in_uint16_range(
+ iree_mmt4d_test_random_engine_t* e) {
+ uint32_t v = e->cpp_random_engine();
+ // Return 16 bits taken starting at the second-least-significant byte of the
+ // 32 bits of state. This avoids some mild issues with the least-significant
+ // and most-significant bytes.
+ return (v >> 8) & 0xffff;
+}
+
+int iree_mmt4d_test_random_engine_get_0_or_1(
+ iree_mmt4d_test_random_engine_t* e) {
+ int v = iree_mmt4d_test_random_engine_get_in_uint16_range(e);
+ return v & 1;
+}
+
+int iree_mmt4d_test_random_engine_get_between_minus16_and_plus15(
+ iree_mmt4d_test_random_engine_t* e) {
+ int v = iree_mmt4d_test_random_engine_get_in_uint16_range(e);
+ return (v % 32) - 16;
+}
+
+template <typename T>
+static void write_random_buffer(T* buffer, iree_ukernel_size_t size_in_bytes,
+ iree_mmt4d_test_random_engine_t* engine) {
+ iree_ukernel_size_t size_in_elems = size_in_bytes / sizeof(T);
+ assert(size_in_elems * sizeof(T) == size_in_bytes && "bad size");
+ for (iree_ukernel_size_t i = 0; i < size_in_elems; ++i) {
+ // Small integers work for all the types we currently have and enable exact
+ // float arithmetic, which keeps the tests simple for now. Watch out for
+ // when we do float16!
+ T random_val =
+ iree_mmt4d_test_random_engine_get_between_minus16_and_plus15(engine);
+ buffer[i] = random_val;
+ }
+}
+
+void write_random_buffer(void* buffer, iree_ukernel_size_t size_in_bytes,
+ iree_mmt4d_scalar_type_t type,
+ iree_mmt4d_test_random_engine_t* engine) {
+ switch (type) {
+ case iree_mmt4d_scalar_type_f32:
+ write_random_buffer(static_cast<float*>(buffer), size_in_bytes, engine);
+ return;
+ case iree_mmt4d_scalar_type_i32:
+ write_random_buffer(static_cast<int32_t*>(buffer), size_in_bytes, engine);
+ return;
+ case iree_mmt4d_scalar_type_i8:
+ write_random_buffer(static_cast<int8_t*>(buffer), size_in_bytes, engine);
+ return;
+ default:
+ assert(false && "unknown type");
+ }
+}
+
+const char* get_mmt4d_type_str(const iree_ukernel_mmt4d_params_t* params) {
+ switch (params->type) {
+#define GET_MMT4D_TYPE_STR_CASE(x) \
+ case x: \
+ return #x;
+ GET_MMT4D_TYPE_STR_CASE(iree_ukernel_mmt4d_type_f32f32f32);
+ GET_MMT4D_TYPE_STR_CASE(iree_ukernel_mmt4d_type_i8i8i32);
+ default:
+ assert(false && "unknown type");
+ return "unknown type";
+ }
+}
+
+const char* get_cpu_features_str(const iree_ukernel_mmt4d_params_t* params) {
+ // We set only one feature bit at a time in this test --- not an actual
+ // detected cpu data field. This might have to change in the future if some
+ // code path relies on the combination of two features.
+ // For now, asserting only one bit set, and taking advantage of that to work
+ // with plain string literals.
+ assert(0 == (params->cpu_data_field_0 & (params->cpu_data_field_0 - 1)));
+ if (params->cpu_data_field_0 == 0) {
+ return "(none)";
+ }
+#if defined(IREE_UKERNEL_ARCH_ARM_64)
+ if (params->cpu_data_field_0 & IREE_CPU_DATA_FIELD_0_AARCH64_HAVE_I8MM) {
+ return "i8mm";
+ }
+ if (params->cpu_data_field_0 & IREE_CPU_DATA_FIELD_0_AARCH64_HAVE_DOTPROD) {
+ return "dotprod";
+ }
+#endif // defined(IREE_UKERNEL_ARCH_ARM_64)
+ assert(false && "unknown CPU feature");
+ return "unknown CPU feature";
+}
diff --git a/runtime/src/iree/builtins/ukernel/tools/mmt4d_test_utils.h b/runtime/src/iree/builtins/ukernel/tools/mmt4d_test_utils.h
new file mode 100644
index 0000000..0f45f0e
--- /dev/null
+++ b/runtime/src/iree/builtins/ukernel/tools/mmt4d_test_utils.h
@@ -0,0 +1,63 @@
+// Copyright 2022 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_BUILTINS_UKERNEL_TOOLS_MMT4D_TEST_UTILS_H_
+#define IREE_BUILTINS_UKERNEL_TOOLS_MMT4D_TEST_UTILS_H_
+
+// clang-format off
+#include <stdint.h> // include before ukernel/common.h to keep standard types
+// clang-format on
+
+#include "iree/builtins/ukernel/mmt4d_types.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+enum iree_mmt4d_scalar_type_t {
+ iree_mmt4d_scalar_type_unknown,
+ iree_mmt4d_scalar_type_i8,
+ iree_mmt4d_scalar_type_i32,
+ iree_mmt4d_scalar_type_f32,
+};
+
+typedef enum iree_mmt4d_scalar_type_t iree_mmt4d_scalar_type_t;
+
+iree_mmt4d_scalar_type_t iree_ukernel_mmt4d_lhs_type(
+ const iree_ukernel_mmt4d_params_t* params);
+iree_mmt4d_scalar_type_t iree_ukernel_mmt4d_rhs_type(
+ const iree_ukernel_mmt4d_params_t* params);
+iree_mmt4d_scalar_type_t iree_ukernel_mmt4d_out_type(
+ const iree_ukernel_mmt4d_params_t* params);
+
+iree_ukernel_size_t iree_ukernel_mmt4d_lhs_buffer_size(
+ const iree_ukernel_mmt4d_params_t* params);
+iree_ukernel_size_t iree_ukernel_mmt4d_rhs_buffer_size(
+ const iree_ukernel_mmt4d_params_t* params);
+iree_ukernel_size_t iree_ukernel_mmt4d_out_buffer_size(
+ const iree_ukernel_mmt4d_params_t* params);
+
+struct iree_mmt4d_test_random_engine_t;
+typedef struct iree_mmt4d_test_random_engine_t iree_mmt4d_test_random_engine_t;
+iree_mmt4d_test_random_engine_t* iree_mmt4d_test_random_engine_create(void);
+void iree_mmt4d_test_random_engine_destroy(iree_mmt4d_test_random_engine_t* e);
+int iree_mmt4d_test_random_engine_get_0_or_1(
+ iree_mmt4d_test_random_engine_t* e);
+int iree_mmt4d_test_random_engine_get_between_minus16_and_plus15(
+ iree_mmt4d_test_random_engine_t* e);
+
+void write_random_buffer(void* buffer, iree_ukernel_size_t size_in_bytes,
+ iree_mmt4d_scalar_type_t type,
+ iree_mmt4d_test_random_engine_t* engine);
+
+const char* get_mmt4d_type_str(const iree_ukernel_mmt4d_params_t* params);
+const char* get_cpu_features_str(const iree_ukernel_mmt4d_params_t* params);
+
+#ifdef __cplusplus
+} // extern "C"
+#endif
+
+#endif // IREE_BUILTINS_UKERNEL_TOOLS_MMT4D_TEST_UTILS_H_
diff --git a/runtime/src/iree/hal/drivers/local_task/task_semaphore.c b/runtime/src/iree/hal/drivers/local_task/task_semaphore.c
index ca2a298..9d809fa 100644
--- a/runtime/src/iree/hal/drivers/local_task/task_semaphore.c
+++ b/runtime/src/iree/hal/drivers/local_task/task_semaphore.c
@@ -238,6 +238,7 @@
&cmd->timepoint.base);
}
iree_event_pool_release(cmd->semaphore->event_pool, 1, &cmd->timepoint.event);
+ iree_hal_semaphore_release((iree_hal_semaphore_t*)cmd->semaphore);
}
iree_status_t iree_hal_task_semaphore_enqueue_timepoint(
@@ -271,6 +272,7 @@
iree_hal_task_semaphore_wait_cmd_cleanup);
iree_task_set_completion_task(&cmd->task.header, issue_task);
cmd->semaphore = semaphore;
+ iree_hal_semaphore_retain(base_semaphore);
iree_task_submission_enqueue(submission, &cmd->task.header);
}
}
diff --git a/runtime/src/iree/modules/hal/exports.inl b/runtime/src/iree/modules/hal/exports.inl
index 7aa0191..f52c94d 100644
--- a/runtime/src/iree/modules/hal/exports.inl
+++ b/runtime/src/iree/modules/hal/exports.inl
@@ -36,7 +36,7 @@
EXPORT_FN("buffer_view.assert", iree_hal_module_buffer_view_assert, rriiCID, v)
EXPORT_FN("buffer_view.buffer", iree_hal_module_buffer_view_buffer, r, r)
-EXPORT_FN("buffer_view.create", iree_hal_module_buffer_view_create, riiCID, r)
+EXPORT_FN("buffer_view.create", iree_hal_module_buffer_view_create, rIIiiCID, r)
EXPORT_FN("buffer_view.dim", iree_hal_module_buffer_view_dim, ri, I)
EXPORT_FN("buffer_view.element_type", iree_hal_module_buffer_view_element_type, r, i)
EXPORT_FN("buffer_view.encoding_type", iree_hal_module_buffer_view_encoding_type, r, i)
@@ -70,17 +70,12 @@
EXPORT_FN("executable.create", iree_hal_module_executable_create, rrrrCrD, r)
EXPORT_FN("fence.await", iree_hal_module_fence_await, iCrD, i)
-EXPORT_FN("fence.create", iree_hal_module_fence_create, CrID, r)
+EXPORT_FN("fence.create", iree_hal_module_fence_create, ri, r)
EXPORT_FN("fence.fail", iree_hal_module_fence_signal, ri, v)
EXPORT_FN("fence.join", iree_hal_module_fence_join, CrD, r)
+EXPORT_FN("fence.query", iree_hal_module_fence_query, r, i)
EXPORT_FN("fence.signal", iree_hal_module_fence_signal, r, v)
EXPORT_FN("pipeline_layout.create", iree_hal_module_pipeline_layout_create, riCrD, r)
-EXPORT_FN("semaphore.await", iree_hal_module_semaphore_await, rI, i)
-EXPORT_FN("semaphore.create", iree_hal_module_semaphore_create, rI, r)
-EXPORT_FN("semaphore.fail", iree_hal_module_semaphore_fail, r, i)
-EXPORT_FN("semaphore.query", iree_hal_module_semaphore_query, r, iI)
-EXPORT_FN("semaphore.signal", iree_hal_module_semaphore_signal, rI, v)
-
// clang-format on
diff --git a/runtime/src/iree/modules/hal/inline/exports.inl b/runtime/src/iree/modules/hal/inline/exports.inl
index 40f80ce..c45a5f8 100644
--- a/runtime/src/iree/modules/hal/inline/exports.inl
+++ b/runtime/src/iree/modules/hal/inline/exports.inl
@@ -33,7 +33,7 @@
EXPORT_FN("buffer_view.assert", iree_hal_inline_module_buffer_view_assert, rriiCID, v)
EXPORT_FN("buffer_view.buffer", iree_hal_inline_module_buffer_view_buffer, r, r)
-EXPORT_FN("buffer_view.create", iree_hal_inline_module_buffer_view_create, riiCID, r)
+EXPORT_FN("buffer_view.create", iree_hal_inline_module_buffer_view_create, rIIiiCID, r)
EXPORT_FN("buffer_view.dim", iree_hal_inline_module_buffer_view_dim, ri, I)
EXPORT_FN("buffer_view.element_type", iree_hal_inline_module_buffer_view_element_type, r, i)
EXPORT_FN("buffer_view.encoding_type", iree_hal_inline_module_buffer_view_encoding_type, r, i)
diff --git a/runtime/src/iree/modules/hal/inline/module.c b/runtime/src/iree/modules/hal/inline/module.c
index c5c1ba6..d643a52 100644
--- a/runtime/src/iree/modules/hal/inline/module.c
+++ b/runtime/src/iree/modules/hal/inline/module.c
@@ -390,20 +390,37 @@
IREE_VM_ABI_EXPORT(iree_hal_inline_module_buffer_view_create, //
iree_hal_inline_module_state_t, //
- riiCID, r) {
+ rIIiiCID, r) {
iree_hal_buffer_t* source_buffer = NULL;
IREE_RETURN_IF_ERROR(iree_hal_buffer_check_deref(args->r0, &source_buffer));
- iree_hal_element_type_t element_type = (iree_hal_element_type_t)args->i1;
- iree_hal_encoding_type_t encoding_type = (iree_hal_encoding_type_t)args->i2;
+ iree_device_size_t source_offset = iree_hal_cast_device_size(args->i1);
+ iree_device_size_t source_length = iree_hal_cast_device_size(args->i2);
+ iree_hal_element_type_t element_type = (iree_hal_element_type_t)args->i3;
+ iree_hal_encoding_type_t encoding_type = (iree_hal_encoding_type_t)args->i4;
iree_host_size_t shape_rank = 0;
iree_hal_dim_t* shape_dims = NULL;
// TODO(benvanik): avoid the cast/alloca if not required.
- IREE_VM_ABI_VLA_STACK_CAST(args, a3_count, a3, iree_hal_dim_t, 128,
+ IREE_VM_ABI_VLA_STACK_CAST(args, a5_count, a5, iree_hal_dim_t, 128,
&shape_rank, &shape_dims);
+
+ iree_hal_buffer_t* subspan_buffer = NULL;
+ if (source_offset != 0 ||
+ source_length != iree_hal_buffer_byte_length(source_buffer)) {
+ IREE_RETURN_IF_ERROR(
+ iree_hal_buffer_subspan(source_buffer, source_offset, source_length,
+ &subspan_buffer),
+ "invalid subspan of an existing buffer (source_offset=%" PRIdsz
+ ", length=%" PRIdsz ")",
+ source_offset, source_length);
+ }
+
iree_hal_buffer_view_t* buffer_view = NULL;
IREE_RETURN_IF_ERROR(iree_hal_buffer_view_create(
- source_buffer, shape_rank, shape_dims, element_type, encoding_type,
- state->host_allocator, &buffer_view));
+ subspan_buffer ? subspan_buffer : source_buffer, shape_rank, shape_dims,
+ element_type, encoding_type, state->host_allocator, &buffer_view));
+
+ iree_hal_buffer_release(subspan_buffer);
+
rets->r0 = iree_hal_buffer_view_move_ref(buffer_view);
return iree_ok_status();
}
diff --git a/runtime/src/iree/modules/hal/module.c b/runtime/src/iree/modules/hal/module.c
index d758312..1217b34 100644
--- a/runtime/src/iree/modules/hal/module.c
+++ b/runtime/src/iree/modules/hal/module.c
@@ -420,21 +420,37 @@
IREE_VM_ABI_EXPORT(iree_hal_module_buffer_view_create, //
iree_hal_module_state_t, //
- riiCID, r) {
+ rIIiiCID, r) {
iree_hal_buffer_t* source_buffer = NULL;
IREE_RETURN_IF_ERROR(iree_hal_buffer_check_deref(args->r0, &source_buffer));
- iree_hal_element_type_t element_type = (iree_hal_element_type_t)args->i1;
- iree_hal_encoding_type_t encoding_type = (iree_hal_encoding_type_t)args->i2;
+ iree_device_size_t source_offset = iree_hal_cast_device_size(args->i1);
+ iree_device_size_t source_length = iree_hal_cast_device_size(args->i2);
+ iree_hal_element_type_t element_type = (iree_hal_element_type_t)args->i3;
+ iree_hal_encoding_type_t encoding_type = (iree_hal_encoding_type_t)args->i4;
iree_host_size_t shape_rank = 0;
iree_hal_dim_t* shape_dims = NULL;
// TODO(benvanik): avoid the cast/alloca if not required.
- IREE_VM_ABI_VLA_STACK_CAST(args, a3_count, a3, iree_hal_dim_t, 128,
+ IREE_VM_ABI_VLA_STACK_CAST(args, a5_count, a5, iree_hal_dim_t, 128,
&shape_rank, &shape_dims);
+ iree_hal_buffer_t* subspan_buffer = NULL;
+ if (source_offset != 0 ||
+ source_length != iree_hal_buffer_byte_length(source_buffer)) {
+ IREE_RETURN_IF_ERROR(
+ iree_hal_buffer_subspan(source_buffer, source_offset, source_length,
+ &subspan_buffer),
+ "invalid subspan of an existing buffer (source_offset=%" PRIdsz
+ ", length=%" PRIdsz ")",
+ source_offset, source_length);
+ }
+
iree_hal_buffer_view_t* buffer_view = NULL;
IREE_RETURN_IF_ERROR(iree_hal_buffer_view_create(
- source_buffer, shape_rank, shape_dims, element_type, encoding_type,
- state->host_allocator, &buffer_view));
+ subspan_buffer ? subspan_buffer : source_buffer, shape_rank, shape_dims,
+ element_type, encoding_type, state->host_allocator, &buffer_view));
+
+ iree_hal_buffer_release(subspan_buffer);
+
rets->r0 = iree_hal_buffer_view_move_ref(buffer_view);
return iree_ok_status();
}
@@ -984,28 +1000,27 @@
IREE_VM_ABI_EXPORT(iree_hal_module_fence_create, //
iree_hal_module_state_t, //
- CrID, r) {
- // Create fence with enough capacity to store all the timepoints.
- // The count may end up lower if some are deduplicated.
- iree_hal_fence_t* fence = NULL;
- IREE_RETURN_IF_ERROR(
- iree_hal_fence_create(args->a0_count, state->host_allocator, &fence));
+ ri, r) {
+ iree_hal_device_t* device = NULL;
+ IREE_RETURN_IF_ERROR(iree_hal_device_check_deref(args->r0, &device));
+ uint32_t fence_flags = args->i1;
+ (void)fence_flags;
- // Insert each timepoint into the fence.
- // This will deduplicate the semaphores and max their values.
- // The compiler already does this but we want to ensure that invariant is met
- // and don't trust the user code - at the point we'd have to verify
- // correctness it's easier just to use the same code path as insertion.
- iree_status_t status = iree_ok_status();
- for (iree_host_size_t i = 0; i < args->a0_count; ++i) {
- iree_hal_semaphore_t* semaphore = NULL;
- status = iree_hal_semaphore_check_deref(args->a0[i].r0, &semaphore);
- if (!iree_status_is_ok(status)) break;
- uint64_t min_value = args->a0[i].i1;
- status = iree_hal_fence_insert(fence, semaphore, min_value);
- if (!iree_status_is_ok(status)) break;
+ // TODO(benvanik): hide semaphores from the API.
+ // This should be reworked to just create the fence.
+
+ iree_hal_semaphore_t* semaphore = NULL;
+ IREE_RETURN_IF_ERROR(iree_hal_semaphore_create(device, 0ull, &semaphore));
+
+ // Create fence with room for our single semaphore.
+ iree_hal_fence_t* fence = NULL;
+ iree_status_t status =
+ iree_hal_fence_create(1, state->host_allocator, &fence);
+ if (iree_status_is_ok(status)) {
+ status = iree_hal_fence_insert(fence, semaphore, 1ull);
}
+ iree_hal_semaphore_release(semaphore);
if (iree_status_is_ok(status)) {
rets->r0 = iree_hal_fence_move_ref(fence);
} else {
@@ -1030,6 +1045,19 @@
return iree_ok_status();
}
+IREE_VM_ABI_EXPORT(iree_hal_module_fence_query, //
+ iree_hal_module_state_t, //
+ r, i) {
+ iree_hal_fence_t* fence = NULL;
+ IREE_RETURN_IF_ERROR(iree_hal_fence_check_deref(args->r0, &fence));
+
+ iree_status_t query_status = iree_hal_fence_query(fence);
+ rets->i0 = iree_status_consume_code(query_status);
+ iree_status_ignore(query_status);
+
+ return iree_ok_status();
+}
+
IREE_VM_ABI_EXPORT(iree_hal_module_fence_signal, //
iree_hal_module_state_t, //
r, v) {
@@ -1272,152 +1300,6 @@
}
//===----------------------------------------------------------------------===//
-// iree_hal_semaphore_t
-//===----------------------------------------------------------------------===//
-
-IREE_VM_ABI_EXPORT(iree_hal_module_semaphore_create, //
- iree_hal_module_state_t, //
- rI, r) {
- iree_hal_device_t* device = NULL;
- IREE_RETURN_IF_ERROR(iree_hal_device_check_deref(args->r0, &device));
- uint64_t initial_value = (uint64_t)args->i1;
- iree_hal_semaphore_t* semaphore = NULL;
- IREE_RETURN_IF_ERROR(
- iree_hal_semaphore_create(device, initial_value, &semaphore));
- rets->r0 = iree_hal_semaphore_move_ref(semaphore);
- return iree_ok_status();
-}
-
-IREE_VM_ABI_EXPORT(iree_hal_module_semaphore_query, //
- iree_hal_module_state_t, //
- r, iI) {
- iree_hal_semaphore_t* semaphore = NULL;
- IREE_RETURN_IF_ERROR(iree_hal_semaphore_check_deref(args->r0, &semaphore));
- uint64_t current_value = 0;
- iree_status_t query_status =
- iree_hal_semaphore_query(semaphore, ¤t_value);
- rets->i0 = iree_status_consume_code(query_status);
- rets->i1 = current_value;
- iree_status_ignore(query_status);
- return iree_ok_status();
-}
-
-IREE_VM_ABI_EXPORT(iree_hal_module_semaphore_signal, //
- iree_hal_module_state_t, //
- rI, v) {
- iree_hal_semaphore_t* semaphore = NULL;
- IREE_RETURN_IF_ERROR(iree_hal_semaphore_check_deref(args->r0, &semaphore));
- uint64_t new_value = (uint64_t)args->i1;
- return iree_hal_semaphore_signal(semaphore, new_value);
-}
-
-IREE_VM_ABI_EXPORT(iree_hal_module_semaphore_fail, //
- iree_hal_module_state_t, //
- ri, v) {
- iree_hal_semaphore_t* semaphore = NULL;
- IREE_RETURN_IF_ERROR(iree_hal_semaphore_check_deref(args->r0, &semaphore));
- iree_status_code_t status_code =
- (iree_status_code_t)(args->i1 & IREE_STATUS_CODE_MASK);
- iree_hal_semaphore_fail(semaphore, iree_make_status(status_code));
- return iree_ok_status();
-}
-
-// PC for iree_hal_module_semaphore_await.
-enum iree_hal_module_semaphore_await_pc_e {
- // Initial entry point that will try to either wait inline or yield to the
- // scheduler with a wait-all operation.
- IREE_HAL_MODULE_SEMAPHORE_AWAIT_PC_BEGIN = 0,
- // Resume entry point after the scheduler wait has resolved (successfully or
- // otherwise).
- IREE_HAL_MODULE_SEMAPHORE_AWAIT_PC_RESUME,
-};
-
-IREE_VM_ABI_EXPORT(iree_hal_module_semaphore_await, //
- iree_hal_module_state_t, //
- rI, i) {
- // On entry we either perform the wait or begin a coroutine yield operation.
- // After resuming we check to see if the timepoint has been reached and
- // propagate the result.
- iree_vm_stack_frame_t* current_frame = iree_vm_stack_top(stack);
- iree_zone_id_t zone_id = 0;
- iree_status_t wait_status = iree_ok_status();
- if (current_frame->pc == IREE_HAL_MODULE_SEMAPHORE_AWAIT_PC_BEGIN) {
- iree_hal_semaphore_t* semaphore = NULL;
- IREE_RETURN_IF_ERROR(iree_hal_semaphore_check_deref(args->r0, &semaphore));
- uint64_t new_value = (uint64_t)args->i1;
-
- IREE_TRACE_ZONE_BEGIN(z0);
- zone_id = z0;
-
- // TODO(benvanik): take timeout as an argument.
- // Capture absolute timeout so that regardless of how long it takes us to
- // wait the user-perceived wait time remains the same.
- iree_timeout_t timeout = iree_infinite_timeout();
- iree_convert_timeout_to_absolute(&timeout);
-
- if (iree_all_bits_set(state->flags, IREE_HAL_MODULE_FLAG_SYNCHRONOUS)) {
- // Block the native thread until the fence is reached or the deadline is
- // exceeded.
- wait_status = iree_hal_semaphore_wait(semaphore, new_value, timeout);
- } else {
- // Quick check inline before yielding to the scheduler. This avoids a
- // round-trip through the scheduling stack for cases where we complete
- // synchronously.
- //
- // The query may fail, indicating that the semaphore is in a failure
- // state; we propagate the failure status to the waiter.
- //
- // It's possible to race here if we get back an older value and then,
- // before we wait, the target is reached, but that's ok: the wait will
- // always be correctly ordered.
- uint64_t current_value = 0ull;
- wait_status = iree_hal_semaphore_query(semaphore, &current_value);
- if (iree_status_is_ok(wait_status) && current_value < new_value) {
- // Enter a wait frame and yield execution back to the scheduler.
- // When the wait handle resolves we'll resume at the RESUME PC.
- iree_vm_wait_frame_t* wait_frame = NULL;
- IREE_RETURN_AND_END_ZONE_IF_ERROR(
- zone_id, iree_vm_stack_wait_enter(stack, IREE_VM_WAIT_ALL, 1,
- timeout, zone_id, &wait_frame));
- wait_frame->wait_sources[0] =
- iree_hal_semaphore_await(semaphore, new_value);
- current_frame->pc = IREE_HAL_MODULE_SEMAPHORE_AWAIT_PC_RESUME;
- wait_status = iree_status_from_code(IREE_STATUS_DEFERRED);
- zone_id = 0; // ownership transferred to wait frame
- }
- }
- } else {
- // Resume by leaving the wait frame and storing the result.
- iree_vm_wait_result_t wait_result;
- IREE_RETURN_IF_ERROR(iree_vm_stack_wait_leave(stack, &wait_result));
- wait_status = wait_result.status;
- IREE_TRACE(zone_id = wait_result.trace_zone);
- }
-
- iree_status_t status = iree_ok_status();
- if (iree_status_is_ok(wait_status)) {
- // Successful wait.
- rets->i0 = 0;
- } else if (iree_status_is_deferred(wait_status)) {
- // Yielding; resume required.
- // NOTE: zone not ended as it's reserved on the stack.
- status = wait_status;
- } else if (iree_status_is_deadline_exceeded(wait_status)) {
- // Propagate deadline exceeded back to the VM.
- rets->i0 = (int32_t)iree_status_consume_code(wait_status);
- iree_status_ignore(wait_status);
- } else {
- // Fail the invocation.
- status = wait_status;
- }
-
- IREE_TRACE({
- if (zone_id) IREE_TRACE_ZONE_END(zone_id);
- });
- return status;
-}
-
-//===----------------------------------------------------------------------===//
// VM module interface implementation
//===----------------------------------------------------------------------===//
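The removed semaphore await export above is a re-entrant coroutine: on first entry (PC_BEGIN) it tries to satisfy the wait inline, and if it cannot it registers a wait source, flips the frame PC to PC_RESUME, and returns IREE_STATUS_DEFERRED so the scheduler can park the invocation; once the wait resolves the VM re-enters the export on the resume path. A standalone sketch of that control flow, with hypothetical names:

typedef enum { PC_BEGIN = 0, PC_RESUME } await_pc_t;
typedef struct { await_pc_t pc; } await_frame_t;

// Returns 1 when the caller must park the invocation and call again after
// the wait resolves (the IREE_STATUS_DEFERRED case); 0 when complete.
static int await_step(await_frame_t* frame, int already_signaled) {
  if (frame->pc == PC_BEGIN) {
    if (already_signaled) return 0;  // fast path: completed inline
    frame->pc = PC_RESUME;           // re-enter here after the wait
    return 1;                        // yield to the scheduler
  }
  // PC_RESUME: the scheduler wait resolved; propagate the result.
  return 0;
}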
diff --git a/runtime/src/iree/modules/vmvx/BUILD b/runtime/src/iree/modules/vmvx/BUILD
index 05ba13a..c4f44c5 100644
--- a/runtime/src/iree/modules/vmvx/BUILD
+++ b/runtime/src/iree/modules/vmvx/BUILD
@@ -29,6 +29,7 @@
deps = [
"//runtime/src/iree/base",
"//runtime/src/iree/base:tracing",
+ "//runtime/src/iree/base/internal:cpu",
"//runtime/src/iree/builtins/ukernel",
"//runtime/src/iree/vm",
],
diff --git a/runtime/src/iree/modules/vmvx/CMakeLists.txt b/runtime/src/iree/modules/vmvx/CMakeLists.txt
index 97d5239..82cdb8e 100644
--- a/runtime/src/iree/modules/vmvx/CMakeLists.txt
+++ b/runtime/src/iree/modules/vmvx/CMakeLists.txt
@@ -24,6 +24,7 @@
iree::base
iree::base::tracing
iree::builtins::ukernel
+ iree::base::internal::cpu
iree::vm
${_VMVX_OPTIONAL_DEPS}
PUBLIC
diff --git a/runtime/src/iree/modules/vmvx/module.c b/runtime/src/iree/modules/vmvx/module.c
index 9ae4b6f..9310d34 100644
--- a/runtime/src/iree/modules/vmvx/module.c
+++ b/runtime/src/iree/modules/vmvx/module.c
@@ -17,6 +17,7 @@
// Include the ukernel support library so that we can use its implementations
// as fixed-function components of the runtime.
+#include "iree/base/internal/cpu.h"
#include "iree/builtins/ukernel/elementwise.h"
#include "iree/builtins/ukernel/mmt4d.h"
@@ -104,54 +105,53 @@
return (iree_host_size_t)value;
}
-#define BUFFER_2D_DECLS(name, dtype, offset, stride0, stride1, size0, size1) \
- uint64_t name##_overflow = 0; \
- iree_host_size_t name##_size0 = \
- iree_vmvx_cast_host_size(size0, &name##_overflow); \
- iree_host_size_t name##_size1 = \
- iree_vmvx_cast_host_size(size1, &name##_overflow); \
- iree_host_size_t name##_stride0 = \
- iree_vmvx_cast_host_size(stride0, &name##_overflow); \
- iree_host_size_t name##_stride1 = \
- iree_vmvx_cast_host_size(stride1, &name##_overflow); \
- iree_host_size_t name##_length_bound = iree_vmvx_2d_length_bound( \
- sizeof(dtype), name##_size0, name##_size1, name##_stride0, \
- name##_stride1, &name##_overflow); \
- iree_host_size_t name##_offset = \
- sizeof(dtype) * iree_vmvx_cast_host_size(offset, &name##_overflow); \
- if (name##_overflow) { \
- IREE_TRACE_ZONE_END(z0); \
- return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, \
- "buffer overflow for " #name); \
+#define BUFFER_2D_DECLS(name, dtype_size, offset, stride0, stride1, size0, \
+ size1) \
+ uint64_t name##_overflow = 0; \
+ iree_host_size_t name##_size0 = \
+ iree_vmvx_cast_host_size(size0, &name##_overflow); \
+ iree_host_size_t name##_size1 = \
+ iree_vmvx_cast_host_size(size1, &name##_overflow); \
+ iree_host_size_t name##_stride0 = \
+ iree_vmvx_cast_host_size(stride0, &name##_overflow); \
+ iree_host_size_t name##_stride1 = \
+ iree_vmvx_cast_host_size(stride1, &name##_overflow); \
+ iree_host_size_t name##_length_bound = iree_vmvx_2d_length_bound( \
+ dtype_size, name##_size0, name##_size1, name##_stride0, name##_stride1, \
+ &name##_overflow); \
+ iree_host_size_t name##_offset = \
+ dtype_size * iree_vmvx_cast_host_size(offset, &name##_overflow); \
+ if (name##_overflow) { \
+ IREE_TRACE_ZONE_END(z0); \
+ return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, \
+ "buffer overflow for " #name); \
}
-#define MAP_BUFFER_2D_RO(name, dtype, buffer_ref, offset, stride0, stride1, \
- size0, size1) \
- iree_vm_buffer_t* name##_buffer; \
- iree_const_byte_span_t name##_span; \
- BUFFER_2D_DECLS(name, dtype, offset, stride0, stride1, size0, size1); \
- IREE_RETURN_AND_END_ZONE_IF_ERROR( \
- z0, iree_vm_buffer_check_deref(buffer_ref, &name##_buffer)) \
- IREE_RETURN_AND_END_ZONE_IF_ERROR( \
- z0, iree_vm_buffer_map_ro(name##_buffer, /*offset=*/ \
- name##_offset, /*length=*/ \
- name##_length_bound, /*alignment=*/ \
- sizeof(dtype), &name##_span)); \
- const dtype* name = (dtype*)name##_span.data
+#define MAP_BUFFER_2D_IMPL(mode, ptr_type, span_type, name, dtype_size, \
+ buffer_ref, offset, stride0, stride1, size0, size1) \
+ iree_vm_buffer_t* name##_buffer; \
+ span_type name##_span; \
+ BUFFER_2D_DECLS(name, dtype_size, offset, stride0, stride1, size0, size1); \
+ IREE_RETURN_AND_END_ZONE_IF_ERROR( \
+ z0, iree_vm_buffer_check_deref(buffer_ref, &name##_buffer)) \
+ IREE_RETURN_AND_END_ZONE_IF_ERROR( \
+ z0, iree_vm_buffer_map_##mode(name##_buffer, /*offset=*/ \
+ name##_offset, /*length=*/ \
+ name##_length_bound, /*alignment=*/ \
+ dtype_size, &name##_span)); \
+ ptr_type name = (ptr_type)name##_span.data
-#define MAP_BUFFER_2D_RW(name, dtype, buffer_ref, offset, stride0, stride1, \
- size0, size1) \
- iree_vm_buffer_t* name##_buffer; \
- iree_byte_span_t name##_span; \
- BUFFER_2D_DECLS(name, dtype, offset, stride0, stride1, size0, size1); \
- IREE_RETURN_AND_END_ZONE_IF_ERROR( \
- z0, iree_vm_buffer_check_deref(buffer_ref, &name##_buffer)); \
- IREE_RETURN_AND_END_ZONE_IF_ERROR( \
- z0, iree_vm_buffer_map_rw(name##_buffer, /*offset=*/ \
- name##_offset, /*length=*/ \
- name##_length_bound, \
- /*alignment=*/sizeof(dtype), &name##_span)); \
- dtype* name = (dtype*)name##_span.data
+#define MAP_BUFFER_2D_UNTYPED_RO(name, dtype_size, ...) \
+ MAP_BUFFER_2D_IMPL(ro, const void*, iree_const_byte_span_t, name, \
+ dtype_size, __VA_ARGS__)
+#define MAP_BUFFER_2D_UNTYPED_RW(name, dtype_size, ...) \
+ MAP_BUFFER_2D_IMPL(rw, void*, iree_byte_span_t, name, dtype_size, __VA_ARGS__)
+#define MAP_BUFFER_2D_RO(name, dtype, ...) \
+ MAP_BUFFER_2D_IMPL(ro, const dtype*, iree_const_byte_span_t, name, \
+ sizeof(dtype), __VA_ARGS__)
+#define MAP_BUFFER_2D_RW(name, dtype, ...) \
+ MAP_BUFFER_2D_IMPL(rw, dtype*, iree_byte_span_t, name, sizeof(dtype), \
+ __VA_ARGS__)
//===----------------------------------------------------------------------===//
// Shared argument shims
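For illustration, the consolidated mapper can be exercised in both its typed and untyped forms from inside an export that has opened trace zone z0; a hypothetical call site (invented names) might read:

// Typed form: alignment/element size is sizeof(float); `in` is a const float*.
MAP_BUFFER_2D_RO(in, float,
                 /*buffer_ref=*/in_ref, /*offset=*/0,
                 /*stride0=*/width, /*stride1=*/1,
                 /*size0=*/height, /*size1=*/width);
// Untyped form: element size is a runtime value; `out` is a plain void*.
MAP_BUFFER_2D_UNTYPED_RW(out, /*dtype_size=*/elem_size,
                         /*buffer_ref=*/out_ref, /*offset=*/0,
                         /*stride0=*/width, /*stride1=*/1,
                         /*size0=*/height, /*size1=*/width);

Both expand through MAP_BUFFER_2D_IMPL, so the overflow checks in BUFFER_2D_DECLS and the alignment handed to iree_vm_buffer_map_ro/rw stay identical across all four entry points.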
@@ -636,7 +636,8 @@
});
IREE_VMVX_ABI_DEFINE_SHIM(mmt4d, v);
-IREE_VMVX_ABI_EXPORT(iree_vmvx_mmt4d_f32f32f32, mmt4d, v) {
+static iree_status_t iree_vmvx_mmt4d(iree_ukernel_mmt4d_type_t type,
+ const iree_vm_abi_mmt4d_t* args) {
IREE_TRACE_ZONE_BEGIN(z0);
iree_host_size_t M = (iree_host_size_t)args->m;
iree_host_size_t N = (iree_host_size_t)args->n;
@@ -647,33 +648,39 @@
iree_host_size_t lhs_tile_size = M0 * K0;
iree_host_size_t rhs_tile_size = N0 * K0;
iree_host_size_t out_tile_size = M0 * N0;
+ int lhs_elem_size = iree_ukernel_mmt4d_lhs_elem_size(type);
+ int rhs_elem_size = iree_ukernel_mmt4d_rhs_elem_size(type);
+ int out_elem_size = iree_ukernel_mmt4d_out_elem_size(type);
// Here we are abusing the 2D-specific macros MAP_BUFFER_2D_* to query 4D
// arrays. Thanks to the requirement that all dimensions but the outer-most
// one are contiguous row-major, the outer-most stride is the only nontrivial
// stride, so we can correctly coalesce the inner 3 dimensions without
// changing the mapped span.
- MAP_BUFFER_2D_RO(lhs, float,
- /*buffer_ref=*/args->lhs_ref,
- /*offset=*/args->lhs_offset,
- /*stride0=*/args->lhs_row_stride,
- /*stride1=*/1,
- /*size0=*/M,
- /*size1=*/K* lhs_tile_size);
- MAP_BUFFER_2D_RO(rhs, float,
- /*buffer_ref=*/args->rhs_ref,
- /*offset=*/args->rhs_offset,
- /*stride0=*/args->rhs_row_stride,
- /*stride1=*/1,
- /*size0=*/N,
- /*size1=*/K* rhs_tile_size);
- MAP_BUFFER_2D_RW(out, float,
- /*buffer_ref=*/args->out_ref,
- /*offset=*/args->out_offset,
- /*stride0=*/args->out_row_stride,
- /*stride1=*/1,
- /*size0=*/M,
- /*size1=*/N* out_tile_size);
- iree_ukernel_mmt4d_f32f32f32_params_t ukernel_params = {
+ MAP_BUFFER_2D_UNTYPED_RO(lhs,
+ /*dtype_size=*/lhs_elem_size,
+ /*buffer_ref=*/args->lhs_ref,
+ /*offset=*/args->lhs_offset,
+ /*stride0=*/args->lhs_row_stride,
+ /*stride1=*/1,
+ /*size0=*/M,
+ /*size1=*/K * lhs_tile_size);
+ MAP_BUFFER_2D_UNTYPED_RO(rhs, /*dtype_size=*/rhs_elem_size,
+ /*buffer_ref=*/args->rhs_ref,
+ /*offset=*/args->rhs_offset,
+ /*stride0=*/args->rhs_row_stride,
+ /*stride1=*/1,
+ /*size0=*/N,
+ /*size1=*/K * rhs_tile_size);
+ MAP_BUFFER_2D_UNTYPED_RW(out, /*dtype_size=*/out_elem_size,
+ /*buffer_ref=*/args->out_ref,
+ /*offset=*/args->out_offset,
+ /*stride0=*/args->out_row_stride,
+ /*stride1=*/1,
+ /*size0=*/M,
+ /*size1=*/N * out_tile_size);
+ iree_ukernel_mmt4d_params_t ukernel_params = {
+ .type = type,
+ .flags = args->flags,
.lhs_buffer = lhs,
.rhs_buffer = rhs,
.out_buffer = out,
@@ -686,76 +693,23 @@
.M0 = M0,
.N0 = N0,
.K0 = K0,
- .flags = args->flags,
+ .cpu_data_field_0 = iree_cpu_data_field(0),
};
- int ukernel_retcode = iree_ukernel_mmt4d_f32f32f32(&ukernel_params);
+ iree_ukernel_mmt4d_status_t status = iree_ukernel_mmt4d(&ukernel_params);
IREE_TRACE_ZONE_END(z0);
- if (ukernel_retcode) {
+ if (status != iree_ukernel_mmt4d_status_ok) {
return iree_make_status(IREE_STATUS_INTERNAL,
- iree_ukernel_mmt4d_error_message(ukernel_retcode));
+ iree_ukernel_mmt4d_status_message(status));
}
return iree_ok_status();
}
+IREE_VMVX_ABI_EXPORT(iree_vmvx_mmt4d_f32f32f32, mmt4d, v) {
+ return iree_vmvx_mmt4d(iree_ukernel_mmt4d_type_f32f32f32, args);
+}
+
IREE_VMVX_ABI_EXPORT(iree_vmvx_mmt4d_i8i8i32, mmt4d, v) {
- IREE_TRACE_ZONE_BEGIN(z0);
- iree_host_size_t M = (iree_host_size_t)args->m;
- iree_host_size_t N = (iree_host_size_t)args->n;
- iree_host_size_t K = (iree_host_size_t)args->k;
- iree_host_size_t M0 = (iree_host_size_t)args->m0;
- iree_host_size_t N0 = (iree_host_size_t)args->n0;
- iree_host_size_t K0 = (iree_host_size_t)args->k0;
- iree_host_size_t lhs_tile_size = M0 * K0;
- iree_host_size_t rhs_tile_size = N0 * K0;
- iree_host_size_t out_tile_size = M0 * N0;
- // Here we are abusing the 2D-specific macros MAP_BUFFER_2D_* to query 4D
- // arrays. Thanks to the requirement that all dimensions but the outer-most
- // one are contiguous row-major, the outer-most stride is the only nontrivial
- // stride, so we can correctly coalesce the inner 3 dimensions without
- // changing the mapped span.
- MAP_BUFFER_2D_RO(lhs, int8_t,
- /*buffer_ref=*/args->lhs_ref,
- /*offset=*/args->lhs_offset,
- /*stride0=*/args->lhs_row_stride,
- /*stride1=*/1,
- /*size0=*/M,
- /*size1=*/K * lhs_tile_size);
- MAP_BUFFER_2D_RO(rhs, int8_t,
- /*buffer_ref=*/args->rhs_ref,
- /*offset=*/args->rhs_offset,
- /*stride0=*/args->rhs_row_stride,
- /*stride1=*/1,
- /*size0=*/N,
- /*size1=*/K * rhs_tile_size);
- MAP_BUFFER_2D_RW(out, int32_t,
- /*buffer_ref=*/args->out_ref,
- /*offset=*/args->out_offset,
- /*stride0=*/args->out_row_stride,
- /*stride1=*/1,
- /*size0=*/M,
- /*size1=*/N * out_tile_size);
- iree_ukernel_mmt4d_i8i8i32_params_t ukernel_params = {
- .lhs_buffer = lhs,
- .rhs_buffer = rhs,
- .out_buffer = out,
- .lhs_stride = lhs_stride0,
- .rhs_stride = rhs_stride0,
- .out_stride = out_stride0,
- .M = M,
- .N = N,
- .K = K,
- .M0 = M0,
- .N0 = N0,
- .K0 = K0,
- .flags = args->flags,
- };
- int ukernel_retcode = iree_ukernel_mmt4d_i8i8i32(&ukernel_params);
- IREE_TRACE_ZONE_END(z0);
- if (ukernel_retcode) {
- return iree_make_status(IREE_STATUS_INTERNAL,
- iree_ukernel_mmt4d_error_message(ukernel_retcode));
- }
- return iree_ok_status();
+ return iree_vmvx_mmt4d(iree_ukernel_mmt4d_type_i8i8i32, args);
}
//===----------------------------------------------------------------------===//
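To make the coalescing comment in the mmt4d shim concrete, a worked example with illustrative numbers: an LHS of M=4 row-tiles by K=8 column-tiles, each tile M0=8 x K0=1 elements with the inner three dimensions contiguous row-major, maps as a 2D view with size0 = M = 4, size1 = K * M0 * K0 = 64, stride0 = lhs_row_stride, and stride1 = 1. Walking that 2D index space touches exactly the bytes of the full 4D [M][K][M0][K0] space, so the mapped span (and therefore the overflow bound) is unchanged by the coalescing.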
diff --git a/runtime/src/iree/testing/benchmark_full.cc b/runtime/src/iree/testing/benchmark_full.cc
index c01abf0..69185ab 100644
--- a/runtime/src/iree/testing/benchmark_full.cc
+++ b/runtime/src/iree/testing/benchmark_full.cc
@@ -135,7 +135,7 @@
}
if (benchmark_def->minimum_duration_ns != 0) {
- instance->MinTime((double)benchmark_def->minimum_duration_ns / 1e-9);
+ instance->MinTime((double)benchmark_def->minimum_duration_ns * 1e-9);
} else if (benchmark_def->iteration_count != 0) {
instance->Iterations(benchmark_def->iteration_count);
}
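The benchmark_full.cc change above is a unit-conversion fix: MinTime() expects seconds, so a nanosecond count must be scaled down by 1e-9 rather than divided by it. As a quick sanity check with illustrative numbers, a 500 ms minimum duration is 5e8 ns, and 5e8 * 1e-9 = 0.5 s; the old expression 5e8 / 1e-9 would have requested 5e17 seconds.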
diff --git a/runtime/src/iree/tooling/BUILD b/runtime/src/iree/tooling/BUILD
index 4f3196e..e793590 100644
--- a/runtime/src/iree/tooling/BUILD
+++ b/runtime/src/iree/tooling/BUILD
@@ -61,6 +61,7 @@
iree_runtime_cc_test(
name = "numpy_io_test",
srcs = ["numpy_io_test.cc"],
+ tags = ["requires-filesystem"],
deps = [
":device_util",
":numpy_io",
diff --git a/runtime/src/iree/tooling/CMakeLists.txt b/runtime/src/iree/tooling/CMakeLists.txt
index 5d7b4dd..22a901b 100644
--- a/runtime/src/iree/tooling/CMakeLists.txt
+++ b/runtime/src/iree/tooling/CMakeLists.txt
@@ -77,6 +77,8 @@
iree::testing::gtest
iree::testing::gtest_main
iree::tooling::testdata::npy
+ LABELS
+ "requires-filesystem"
)
iree_cc_library(
diff --git a/runtime/src/iree/vm/bytecode_disasm.c b/runtime/src/iree/vm/bytecode_disasm.c
index 4a331c5..0b03914 100644
--- a/runtime/src/iree/vm/bytecode_disasm.c
+++ b/runtime/src/iree/vm/bytecode_disasm.c
@@ -2009,7 +2009,7 @@
uint16_t result_reg = VM_ParseResultRegI32("result");
EMIT_I32_REG_NAME(result_reg);
IREE_RETURN_IF_ERROR(
- iree_string_builder_append_cstring(b, " = vm.bitcast.f32.if32 "));
+ iree_string_builder_append_cstring(b, " = vm.bitcast.f32.i32 "));
EMIT_F32_REG_NAME(operand_reg);
EMIT_OPTIONAL_VALUE_F32(regs->i32[operand_reg]);
break;
diff --git a/runtime/src/iree/vm/bytecode_module_impl.h b/runtime/src/iree/vm/bytecode_module_impl.h
index 1988181..9916bb9 100644
--- a/runtime/src/iree/vm/bytecode_module_impl.h
+++ b/runtime/src/iree/vm/bytecode_module_impl.h
@@ -33,7 +33,7 @@
// Major bytecode version; mismatches on this will fail in either direction.
// This allows coarse versioning of completely incompatible versions.
// Matches BytecodeEncoder::kVersionMajor in the compiler.
-#define IREE_VM_BYTECODE_VERSION_MAJOR 10
+#define IREE_VM_BYTECODE_VERSION_MAJOR 12
// Minor bytecode version; lower versions are allowed to enable newer runtimes
// to load older serialized files when there are backwards-compatible changes.
// Higher versions are disallowed as they occur when new ops are added that
diff --git a/runtime/src/iree/vm/shims.c b/runtime/src/iree/vm/shims.c
index d00e9ef..e7305cf 100644
--- a/runtime/src/iree/vm/shims.c
+++ b/runtime/src/iree/vm/shims.c
@@ -27,7 +27,7 @@
IREE_VM_ABI_DEFINE_SHIM(rI, r);
IREE_VM_ABI_DEFINE_SHIM(rI, v);
IREE_VM_ABI_DEFINE_SHIM(riCiD, r);
-IREE_VM_ABI_DEFINE_SHIM(riiCID, r);
+IREE_VM_ABI_DEFINE_SHIM(rIIiiCID, r);
IREE_VM_ABI_DEFINE_SHIM(riCiiiD, r);
IREE_VM_ABI_DEFINE_SHIM(riCrD, r);
IREE_VM_ABI_DEFINE_SHIM(rIi, i);
diff --git a/runtime/src/iree/vm/shims.h b/runtime/src/iree/vm/shims.h
index 85bb37c..ebeb5f0 100644
--- a/runtime/src/iree/vm/shims.h
+++ b/runtime/src/iree/vm/shims.h
@@ -426,12 +426,14 @@
iree_vm_abi_i_t a2[0];
});
-IREE_VM_ABI_VLA_STRUCT(riiCID, a3_count, a3, {
+IREE_VM_ABI_VLA_STRUCT(rIIiiCID, a5_count, a5, {
iree_vm_ref_t r0;
- int32_t i1;
- int32_t i2;
- iree_vm_size_t a3_count;
- iree_vm_abi_I_t a3[0];
+ int64_t i1;
+ int64_t i2;
+ int32_t i3;
+ int32_t i4;
+ iree_vm_size_t a5_count;
+ iree_vm_abi_I_t a5[0];
});
IREE_VM_ABI_VLA_STRUCT(rriiCID, a4_count, a4, {
@@ -559,7 +561,7 @@
IREE_VM_ABI_DECLARE_SHIM(rI, r);
IREE_VM_ABI_DECLARE_SHIM(rI, v);
IREE_VM_ABI_DECLARE_SHIM(riCiD, r);
-IREE_VM_ABI_DECLARE_SHIM(riiCID, r);
+IREE_VM_ABI_DECLARE_SHIM(rIIiiCID, r);
IREE_VM_ABI_DECLARE_SHIM(riCiiiD, r);
IREE_VM_ABI_DECLARE_SHIM(riCrD, r);
IREE_VM_ABI_DECLARE_SHIM(rIi, i);
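For readers decoding the shim rename: judging from the struct layouts above, each letter of the mnemonic encodes one argument slot — r a ref, I an i64, i an i32, and C...D a variadic span of the bracketed type. The old riiCID thus read (ref, i32, i32, variadic i64 list), while the widened rIIiiCID reads (ref, i64, i64, i32, i32, variadic i64 list), matching the new int64_t i1/i2 and int32_t i3/i4 fields plus the a5[] span.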
diff --git a/samples/colab/edge_detection.ipynb b/samples/colab/edge_detection.ipynb
index 10fa776..6a0f9a2 100644
--- a/samples/colab/edge_detection.ipynb
+++ b/samples/colab/edge_detection.ipynb
@@ -330,7 +330,9 @@
"# application, we would probably want to freeze the version of IREE used and\n",
"# compile as completely as possible ahead of time, then use some other scheme\n",
"# to load the module into the application at runtime.\n",
- "compiler_module = tfc.compile_module(EdgeDetectionModule(), import_only=True)\n",
+ "compiler_module = tfc.compile_module(\n",
+ " EdgeDetectionModule(), import_only=True,\n",
+ " import_extra_args=[\"--output-format=mlir-ir\"])\n",
"print(\"Edge Detection MLIR: \", compiler_module.decode('utf-8'))\n",
"\n",
"edge_detection_mlir_path = os.path.join(ARTIFACTS_DIR, \"edge_detection.mlir\")\n",
diff --git a/samples/colab/tflite_text_classification.ipynb b/samples/colab/tflite_text_classification.ipynb
index b4a94d2..3d5dfb3 100644
--- a/samples/colab/tflite_text_classification.ipynb
+++ b/samples/colab/tflite_text_classification.ipynb
@@ -64,7 +64,7 @@
"import tflite_runtime.interpreter as tflite\n",
"\n",
"from iree import runtime as iree_rt\n",
- "from iree.compiler import compile_str\n",
+ "from iree.compiler import compile_file, compile_str\n",
"from iree.tools import tflite as iree_tflite\n",
"\n",
"ARTIFACTS_DIR = pathlib.Path(tempfile.gettempdir(), \"iree\", \"colab_artifacts\")\n",
@@ -320,14 +320,12 @@
},
"outputs": [],
"source": [
- "# Convert TFLite model to TOSA MLIR with IREE's import tool.\n",
+ "# Convert TFLite model to TOSA MLIR (bytecode) with IREE's import tool.\n",
"IREE_TFLITE_TOOL = iree_tflite.get_tool('iree-import-tflite')\n",
- "!{IREE_TFLITE_TOOL} {ARTIFACTS_DIR}/text_classification.tflite --o={ARTIFACTS_DIR}/text_classification.mlir\n",
+ "tosa_mlirbc_file = ARTIFACTS_DIR.joinpath(\"text_classification.mlirbc\")\n",
+ "!{IREE_TFLITE_TOOL} {ARTIFACTS_DIR}/text_classification.tflite --o={tosa_mlirbc_file}\n",
"\n",
- "with open(ARTIFACTS_DIR.joinpath(\"text_classification.mlir\")) as mlir_file:\n",
- " tosa_mlir = mlir_file.read()\n",
- "\n",
- "# The generated .mlir file could now be saved and used outside of Python, with\n",
+ "# The generated .mlirbc file could now be saved and used outside of Python, with\n",
"# IREE native tools or in apps, etc."
]
},
@@ -348,16 +346,16 @@
"text": [
"module {\n",
" func.func @main(%arg0: tensor<1x256xi32> {iree.identifier = \"input_5\"}) -> (tensor<1x2xf32> {iree.identifier = \"Identity\"}) {\n",
- " %0 = \"tosa.const\"() {value = dense<3.906250e-03> : tensor<1x1xf32>} : () -> tensor<1x1xf32>\n",
- " %1 = \"tosa.const\"() {value = opaque<\"elided_large_const\", \"0xDEADBEEF\"> : tensor<1x10003x16xf32>} : () -> tensor<1x10003x16xf32>\n",
- " %2 = \"tosa.const\"() {value = opaque<\"elided_large_const\", \"0xDEADBEEF\"> : tensor<16x16xf32>} : () -> tensor<16x16xf32>\n",
+ " %0 = \"tosa.const\"() {value = dense_resource<__elided__> : tensor<1x10003x16xf32>} : () -> tensor<1x10003x16xf32>\n",
+ " %1 = \"tosa.const\"() {value = dense<3.906250e-03> : tensor<1x1xf32>} : () -> tensor<1x1xf32>\n",
+ " %2 = \"tosa.const\"() {value = dense_resource<__elided__> : tensor<16x16xf32>} : () -> tensor<16x16xf32>\n",
" %3 = \"tosa.const\"() {value = dense<[-0.00698487554, 0.0294856895, 0.0699710473, 0.130019352, -0.0490558445, 0.0987673401, 0.0744077861, 0.0948959812, -0.010937131, 0.0931261852, 0.0711835548, -0.0385615043, 9.962780e-03, 0.00283221388, 0.112116851, 0.0134318024]> : tensor<16xf32>} : () -> tensor<16xf32>\n",
" %4 = \"tosa.const\"() {value = dense<[[0.091361463, -1.23269629, 1.33242488, 0.92142266, -0.445623249, 0.849273681, -1.27237022, 1.28574562, 0.436188251, -0.963210225, 0.745473146, -0.255745709, -1.4491415, -1.4687326, 0.900665163, -1.36293614], [-0.0968776941, 0.771379471, -1.36363328, -1.1110599, -0.304591209, -1.05579722, 0.795746565, -1.3122592, 0.352218777, 1.04682362, -1.18796027, -0.0409261398, 1.05883229, 1.48620188, -1.13325548, 1.03072512]]> : tensor<2x16xf32>} : () -> tensor<2x16xf32>\n",
" %5 = \"tosa.const\"() {value = dense<[0.043447677, -0.0434476472]> : tensor<2xf32>} : () -> tensor<2xf32>\n",
- " %6 = \"tosa.gather\"(%1, %arg0) : (tensor<1x10003x16xf32>, tensor<1x256xi32>) -> tensor<1x256x16xf32>\n",
+ " %6 = \"tosa.gather\"(%0, %arg0) : (tensor<1x10003x16xf32>, tensor<1x256xi32>) -> tensor<1x256x16xf32>\n",
" %7 = \"tosa.reduce_sum\"(%6) {axis = 1 : i64} : (tensor<1x256x16xf32>) -> tensor<1x1x16xf32>\n",
" %8 = \"tosa.reshape\"(%7) {new_shape = [1, 16]} : (tensor<1x1x16xf32>) -> tensor<1x16xf32>\n",
- " %9 = \"tosa.mul\"(%8, %0) {shift = 0 : i32} : (tensor<1x16xf32>, tensor<1x1xf32>) -> tensor<1x16xf32>\n",
+ " %9 = \"tosa.mul\"(%8, %1) {shift = 0 : i32} : (tensor<1x16xf32>, tensor<1x1xf32>) -> tensor<1x16xf32>\n",
" %10 = \"tosa.fully_connected\"(%9, %2, %3) : (tensor<1x16xf32>, tensor<16x16xf32>, tensor<16xf32>) -> tensor<1x16xf32>\n",
" %11 = \"tosa.clamp\"(%10) {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<1x16xf32>) -> tensor<1x16xf32>\n",
" %12 = \"tosa.fully_connected\"(%11, %4, %5) : (tensor<1x16xf32>, tensor<2x16xf32>, tensor<2xf32>) -> tensor<1x2xf32>\n",
@@ -374,7 +372,7 @@
],
"source": [
"# The model contains very large constants, so recompile a truncated version to print.\n",
- "!{IREE_TFLITE_TOOL} {ARTIFACTS_DIR}/text_classification.tflite --o={ARTIFACTS_DIR}/text_classification_truncated.mlir --mlir-elide-elementsattrs-if-larger=50\n",
+ "!{IREE_TFLITE_TOOL} {ARTIFACTS_DIR}/text_classification.tflite --o={ARTIFACTS_DIR}/text_classification_truncated.mlir --output-format=mlir-ir --mlir-elide-elementsattrs-if-larger=50\n",
"\n",
"with open(ARTIFACTS_DIR.joinpath(\"text_classification_truncated.mlir\")) as truncated_mlir_file:\n",
" truncated_tosa_mlir = truncated_mlir_file.read()\n",
@@ -390,7 +388,7 @@
"outputs": [],
"source": [
"# Compile the TOSA MLIR into a VM module.\n",
- "compiled_flatbuffer = compile_str(tosa_mlir, input_type=\"tosa\", target_backends=[\"vmvx\"])\n",
+ "compiled_flatbuffer = compile_file(tosa_mlirbc_file, input_type=\"tosa\", target_backends=[\"vmvx\"])\n",
"\n",
"# Register the module with a runtime context.\n",
"config = iree_rt.Config(\"local-task\")\n",
diff --git a/samples/dynamic_shapes/dynamic_shapes.ipynb b/samples/dynamic_shapes/dynamic_shapes.ipynb
index b2d04e4..257713e 100644
--- a/samples/dynamic_shapes/dynamic_shapes.ipynb
+++ b/samples/dynamic_shapes/dynamic_shapes.ipynb
@@ -164,7 +164,8 @@
"\n",
"compiler_module = tfc.compile_module(\n",
" DynamicShapesModule(), import_only=True, \n",
- " output_mlir_debuginfo=False)\n",
+ " output_mlir_debuginfo=False,\n",
+ " import_extra_args=[\"--output-format=mlir-ir\"])\n",
"clear_output() # Skip over TensorFlow's output.\n",
"\n",
"# Print the imported MLIR to see how the compiler views this program.\n",
@@ -184,45 +185,33 @@
"text": [
"Dynamic Shapes MLIR:\n",
"```\n",
- "\"builtin.module\"() ({\n",
- " \"func.func\"() ({\n",
- " ^bb0(%arg0: !iree_input.buffer_view):\n",
- " %0 = \"iree_input.cast.buffer_view_to_tensor\"(%arg0) : (!iree_input.buffer_view) -> tensor<?xi32>\n",
- " %1 = \"func.call\"(%0) {callee = @__inference_add_one_70} : (tensor<?xi32>) -> tensor<?xi32>\n",
- " %2 = \"iree_input.cast.tensor_to_buffer_view\"(%1) : (tensor<?xi32>) -> !iree_input.buffer_view\n",
- " \"func.return\"(%2) : (!iree_input.buffer_view) -> ()\n",
- " }) {function_type = (!iree_input.buffer_view) -> !iree_input.buffer_view, iree.abi = \"{\\22a\\22:[[\\22ndarray\\22,\\22i32\\22,1,null]],\\22r\\22:[[\\22ndarray\\22,\\22i32\\22,1,null]],\\22v\\22:1}\", sym_name = \"add_one\"} : () -> ()\n",
- " \"func.func\"() ({\n",
- " ^bb0(%arg0: tensor<?xi32>):\n",
- " %0 = \"mhlo.constant\"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>\n",
- " %1 = \"chlo.broadcast_add\"(%arg0, %0) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<?xi32>, tensor<i32>) -> tensor<?xi32>\n",
- " \"func.return\"(%1) : (tensor<?xi32>) -> ()\n",
- " }) {arg_attrs = [{tf._user_specified_name = \"values\"}], function_type = (tensor<?xi32>) -> tensor<?xi32>, sym_name = \"__inference_add_one_70\", sym_visibility = \"private\", tf._construction_context = \"kEagerRuntime\", tf._input_shapes = [#tf_type.shape<?>]} : () -> ()\n",
- " \"func.func\"() ({\n",
- " ^bb0(%arg0: !iree_input.buffer_view):\n",
- " %0 = \"mhlo.constant\"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>\n",
- " %1 = \"iree_input.cast.buffer_view_to_tensor\"(%arg0) : (!iree_input.buffer_view) -> tensor<?xi32>\n",
- " %2 = \"mhlo.reduce\"(%1, %0) ({\n",
- " ^bb0(%arg1: tensor<i32>, %arg2: tensor<i32>):\n",
- " %4 = \"mhlo.add\"(%arg1, %arg2) : (tensor<i32>, tensor<i32>) -> tensor<i32>\n",
- " \"mhlo.return\"(%4) : (tensor<i32>) -> ()\n",
- " }) {dimensions = dense<0> : tensor<1xi64>} : (tensor<?xi32>, tensor<i32>) -> tensor<i32>\n",
- " %3 = \"iree_input.cast.tensor_to_buffer_view\"(%2) : (tensor<i32>) -> !iree_input.buffer_view\n",
- " \"func.return\"(%3) : (!iree_input.buffer_view) -> ()\n",
- " }) {function_type = (!iree_input.buffer_view) -> !iree_input.buffer_view, iree.abi = \"{\\22a\\22:[[\\22ndarray\\22,\\22i32\\22,1,null]],\\22r\\22:[[\\22ndarray\\22,\\22i32\\22,0]],\\22v\\22:1}\", sym_name = \"reduce_sum_1d\"} : () -> ()\n",
- " \"func.func\"() ({\n",
- " ^bb0(%arg0: !iree_input.buffer_view):\n",
- " %0 = \"mhlo.constant\"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>\n",
- " %1 = \"iree_input.cast.buffer_view_to_tensor\"(%arg0) : (!iree_input.buffer_view) -> tensor<?x3xi32>\n",
- " %2 = \"mhlo.reduce\"(%1, %0) ({\n",
- " ^bb0(%arg1: tensor<i32>, %arg2: tensor<i32>):\n",
- " %4 = \"mhlo.add\"(%arg1, %arg2) : (tensor<i32>, tensor<i32>) -> tensor<i32>\n",
- " \"mhlo.return\"(%4) : (tensor<i32>) -> ()\n",
- " }) {dimensions = dense<0> : tensor<1xi64>} : (tensor<?x3xi32>, tensor<i32>) -> tensor<3xi32>\n",
- " %3 = \"iree_input.cast.tensor_to_buffer_view\"(%2) : (tensor<3xi32>) -> !iree_input.buffer_view\n",
- " \"func.return\"(%3) : (!iree_input.buffer_view) -> ()\n",
- " }) {function_type = (!iree_input.buffer_view) -> !iree_input.buffer_view, iree.abi = \"{\\22a\\22:[[\\22ndarray\\22,\\22i32\\22,2,null,3]],\\22r\\22:[[\\22ndarray\\22,\\22i32\\22,1,3]],\\22v\\22:1}\", sym_name = \"reduce_sum_2d\"} : () -> ()\n",
- "}) : () -> ()\n",
+ "module {\n",
+ " func.func @add_one(%arg0: !iree_input.buffer_view) -> !iree_input.buffer_view attributes {iree.abi = \"{\\22a\\22:[[\\22ndarray\\22,\\22i32\\22,1,null]],\\22r\\22:[[\\22ndarray\\22,\\22i32\\22,1,null]],\\22v\\22:1}\"} {\n",
+ " %0 = iree_input.cast.buffer_view_to_tensor %arg0 : !iree_input.buffer_view -> tensor<?xi32>\n",
+ " %1 = call @__inference_add_one_70(%0) : (tensor<?xi32>) -> tensor<?xi32>\n",
+ " %2 = iree_input.cast.tensor_to_buffer_view %1 : tensor<?xi32> -> !iree_input.buffer_view\n",
+ " return %2 : !iree_input.buffer_view\n",
+ " }\n",
+ " func.func private @__inference_add_one_70(%arg0: tensor<?xi32> {tf._user_specified_name = \"values\"}) -> tensor<?xi32> attributes {tf._construction_context = \"kEagerRuntime\", tf._input_shapes = [#tf_type.shape<?>]} {\n",
+ " %0 = mhlo.constant dense<1> : tensor<i32>\n",
+ " %1 = chlo.broadcast_add %arg0, %0 {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<?xi32>, tensor<i32>) -> tensor<?xi32>\n",
+ " return %1 : tensor<?xi32>\n",
+ " }\n",
+ " func.func @reduce_sum_1d(%arg0: !iree_input.buffer_view) -> !iree_input.buffer_view attributes {iree.abi = \"{\\22a\\22:[[\\22ndarray\\22,\\22i32\\22,1,null]],\\22r\\22:[[\\22ndarray\\22,\\22i32\\22,0]],\\22v\\22:1}\"} {\n",
+ " %0 = mhlo.constant dense<0> : tensor<i32>\n",
+ " %1 = iree_input.cast.buffer_view_to_tensor %arg0 : !iree_input.buffer_view -> tensor<?xi32>\n",
+ " %2 = mhlo.reduce(%1 init: %0) applies mhlo.add across dimensions = [0] : (tensor<?xi32>, tensor<i32>) -> tensor<i32>\n",
+ " %3 = iree_input.cast.tensor_to_buffer_view %2 : tensor<i32> -> !iree_input.buffer_view\n",
+ " return %3 : !iree_input.buffer_view\n",
+ " }\n",
+ " func.func @reduce_sum_2d(%arg0: !iree_input.buffer_view) -> !iree_input.buffer_view attributes {iree.abi = \"{\\22a\\22:[[\\22ndarray\\22,\\22i32\\22,2,null,3]],\\22r\\22:[[\\22ndarray\\22,\\22i32\\22,1,3]],\\22v\\22:1}\"} {\n",
+ " %0 = mhlo.constant dense<0> : tensor<i32>\n",
+ " %1 = iree_input.cast.buffer_view_to_tensor %arg0 : !iree_input.buffer_view -> tensor<?x3xi32>\n",
+ " %2 = mhlo.reduce(%1 init: %0) applies mhlo.add across dimensions = [0] : (tensor<?x3xi32>, tensor<i32>) -> tensor<3xi32>\n",
+ " %3 = iree_input.cast.tensor_to_buffer_view %2 : tensor<3xi32> -> !iree_input.buffer_view\n",
+ " return %3 : !iree_input.buffer_view\n",
+ " }\n",
+ "}\n",
"\n",
"```\n",
"\n",
diff --git a/samples/variables_and_state/variables_and_state.ipynb b/samples/variables_and_state/variables_and_state.ipynb
index 5c0ca86..ed4f70b 100644
--- a/samples/variables_and_state/variables_and_state.ipynb
+++ b/samples/variables_and_state/variables_and_state.ipynb
@@ -168,7 +168,9 @@
"from iree.compiler import tf as tfc\n",
"\n",
"compiler_module = tfc.compile_module(\n",
- " CounterModule(), import_only=True, output_mlir_debuginfo=False)\n",
+ " CounterModule(), import_only=True,\n",
+ " output_mlir_debuginfo=False,\n",
+ " import_extra_args=[\"--output-format=mlir-ir\"])\n",
"clear_output() # Skip over TensorFlow's output.\n",
"\n",
"# Print the imported MLIR to see how the compiler views this TensorFlow program.\n",
@@ -189,55 +191,47 @@
"text": [
"Counter MLIR:\n",
"```\n",
- "\"builtin.module\"() ({\n",
- " \"iree_input.global\"() {initial_value = dense<0> : tensor<i32>, is_mutable, sym_name = \"counter\", sym_visibility = \"private\", type = tensor<i32>} : () -> ()\n",
- " \"func.func\"() ({\n",
- " ^bb0(%arg0: !iree_input.buffer_view):\n",
- " %0 = \"iree_input.cast.buffer_view_to_tensor\"(%arg0) : (!iree_input.buffer_view) -> tensor<i32>\n",
- " \"func.call\"(%0) {callee = @__inference_add_to_value_100} : (tensor<i32>) -> ()\n",
- " \"func.return\"() : () -> ()\n",
- " }) {function_type = (!iree_input.buffer_view) -> (), iree.abi = \"{\\22a\\22:[[\\22ndarray\\22,\\22i32\\22,0]],\\22r\\22:[],\\22v\\22:1}\", sym_name = \"add_to_value\"} : () -> ()\n",
- " \"func.func\"() ({\n",
- " ^bb0(%arg0: tensor<i32>):\n",
- " %0 = \"iree_input.global.address\"() {global = @counter} : () -> !iree_input.ptr<tensor<i32>>\n",
- " %1 = \"iree_input.global.load.indirect\"(%0) : (!iree_input.ptr<tensor<i32>>) -> tensor<i32>\n",
- " %2 = \"chlo.broadcast_add\"(%1, %arg0) : (tensor<i32>, tensor<i32>) -> tensor<i32>\n",
- " \"iree_input.global.store.indirect\"(%2, %0) : (tensor<i32>, !iree_input.ptr<tensor<i32>>) -> ()\n",
- " \"func.return\"() : () -> ()\n",
- " }) {arg_attrs = [{tf._user_specified_name = \"x\"}], function_type = (tensor<i32>) -> (), sym_name = \"__inference_add_to_value_100\", sym_visibility = \"private\", tf._construction_context = \"kEagerRuntime\", tf._input_shapes = [#tf_type.shape<>, #tf_type.shape<>], tf.signature.is_stateful} : () -> ()\n",
- " \"func.func\"() ({\n",
- " %0 = \"func.call\"() {callee = @__inference_get_value_160} : () -> tensor<i32>\n",
- " %1 = \"iree_input.cast.tensor_to_buffer_view\"(%0) : (tensor<i32>) -> !iree_input.buffer_view\n",
- " \"func.return\"(%1) : (!iree_input.buffer_view) -> ()\n",
- " }) {function_type = () -> !iree_input.buffer_view, iree.abi = \"{\\22a\\22:[],\\22r\\22:[[\\22ndarray\\22,\\22i32\\22,0]],\\22v\\22:1}\", sym_name = \"get_value\"} : () -> ()\n",
- " \"func.func\"() ({\n",
- " %0 = \"iree_input.global.address\"() {global = @counter} : () -> !iree_input.ptr<tensor<i32>>\n",
- " %1 = \"iree_input.global.load.indirect\"(%0) : (!iree_input.ptr<tensor<i32>>) -> tensor<i32>\n",
- " \"func.return\"(%1) : (tensor<i32>) -> ()\n",
- " }) {function_type = () -> tensor<i32>, sym_name = \"__inference_get_value_160\", sym_visibility = \"private\", tf._construction_context = \"kEagerRuntime\", tf._input_shapes = [#tf_type.shape<>], tf.signature.is_stateful} : () -> ()\n",
- " \"func.func\"() ({\n",
- " \"func.call\"() {callee = @__inference_reset_value_270} : () -> ()\n",
- " \"func.return\"() : () -> ()\n",
- " }) {function_type = () -> (), iree.abi = \"{\\22a\\22:[],\\22r\\22:[],\\22v\\22:1}\", sym_name = \"reset_value\"} : () -> ()\n",
- " \"func.func\"() ({\n",
- " %0 = \"mhlo.constant\"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>\n",
- " %1 = \"iree_input.global.address\"() {global = @counter} : () -> !iree_input.ptr<tensor<i32>>\n",
- " \"iree_input.global.store.indirect\"(%0, %1) : (tensor<i32>, !iree_input.ptr<tensor<i32>>) -> ()\n",
- " \"func.return\"() : () -> ()\n",
- " }) {function_type = () -> (), sym_name = \"__inference_reset_value_270\", sym_visibility = \"private\", tf._construction_context = \"kEagerRuntime\", tf._input_shapes = [#tf_type.shape<>], tf.signature.is_stateful} : () -> ()\n",
- " \"func.func\"() ({\n",
- " ^bb0(%arg0: !iree_input.buffer_view):\n",
- " %0 = \"iree_input.cast.buffer_view_to_tensor\"(%arg0) : (!iree_input.buffer_view) -> tensor<i32>\n",
- " \"func.call\"(%0) {callee = @__sm_exported___inference_set_value_230} : (tensor<i32>) -> ()\n",
- " \"func.return\"() : () -> ()\n",
- " }) {function_type = (!iree_input.buffer_view) -> (), iree.abi = \"{\\22a\\22:[[\\22ndarray\\22,\\22i32\\22,0]],\\22r\\22:[],\\22v\\22:1}\", sym_name = \"set_value\"} : () -> ()\n",
- " \"func.func\"() ({\n",
- " ^bb0(%arg0: tensor<i32>):\n",
- " %0 = \"iree_input.global.address\"() {global = @counter} : () -> !iree_input.ptr<tensor<i32>>\n",
- " \"iree_input.global.store.indirect\"(%arg0, %0) : (tensor<i32>, !iree_input.ptr<tensor<i32>>) -> ()\n",
- " \"func.return\"() : () -> ()\n",
- " }) {arg_attrs = [{tf._user_specified_name = \"new_value\"}], function_type = (tensor<i32>) -> (), sym_name = \"__sm_exported___inference_set_value_230\", sym_visibility = \"private\", tf._construction_context = \"kEagerRuntime\", tf._input_shapes = [#tf_type.shape<>, #tf_type.shape<>], tf.signature.is_stateful} : () -> ()\n",
- "}) : () -> ()\n",
+ "module {\n",
+ " ml_program.global private mutable @counter(dense<0> : tensor<i32>) : tensor<i32>\n",
+ " func.func @add_to_value(%arg0: !iree_input.buffer_view) attributes {iree.abi = \"{\\22a\\22:[[\\22ndarray\\22,\\22i32\\22,0]],\\22r\\22:[],\\22v\\22:1}\"} {\n",
+ " %0 = iree_input.cast.buffer_view_to_tensor %arg0 : !iree_input.buffer_view -> tensor<i32>\n",
+ " call @__inference_add_to_value_100(%0) : (tensor<i32>) -> ()\n",
+ " return\n",
+ " }\n",
+ " func.func private @__inference_add_to_value_100(%arg0: tensor<i32> {tf._user_specified_name = \"x\"}) attributes {tf._construction_context = \"kEagerRuntime\", tf._input_shapes = [#tf_type.shape<>, #tf_type.shape<>], tf.signature.is_stateful} {\n",
+ " %0 = ml_program.global_load @counter : tensor<i32>\n",
+ " %1 = chlo.broadcast_add %0, %arg0 : (tensor<i32>, tensor<i32>) -> tensor<i32>\n",
+ " ml_program.global_store @counter = %1 : tensor<i32>\n",
+ " return\n",
+ " }\n",
+ " func.func @get_value() -> !iree_input.buffer_view attributes {iree.abi = \"{\\22a\\22:[],\\22r\\22:[[\\22ndarray\\22,\\22i32\\22,0]],\\22v\\22:1}\"} {\n",
+ " %0 = call @__inference_get_value_160() : () -> tensor<i32>\n",
+ " %1 = iree_input.cast.tensor_to_buffer_view %0 : tensor<i32> -> !iree_input.buffer_view\n",
+ " return %1 : !iree_input.buffer_view\n",
+ " }\n",
+ " func.func private @__inference_get_value_160() -> tensor<i32> attributes {tf._construction_context = \"kEagerRuntime\", tf._input_shapes = [#tf_type.shape<>], tf.signature.is_stateful} {\n",
+ " %0 = ml_program.global_load @counter : tensor<i32>\n",
+ " return %0 : tensor<i32>\n",
+ " }\n",
+ " func.func @reset_value() attributes {iree.abi = \"{\\22a\\22:[],\\22r\\22:[],\\22v\\22:1}\"} {\n",
+ " call @__inference_reset_value_270() : () -> ()\n",
+ " return\n",
+ " }\n",
+ " func.func private @__inference_reset_value_270() attributes {tf._construction_context = \"kEagerRuntime\", tf._input_shapes = [#tf_type.shape<>], tf.signature.is_stateful} {\n",
+ " %0 = mhlo.constant dense<0> : tensor<i32>\n",
+ " ml_program.global_store @counter = %0 : tensor<i32>\n",
+ " return\n",
+ " }\n",
+ " func.func @set_value(%arg0: !iree_input.buffer_view) attributes {iree.abi = \"{\\22a\\22:[[\\22ndarray\\22,\\22i32\\22,0]],\\22r\\22:[],\\22v\\22:1}\"} {\n",
+ " %0 = iree_input.cast.buffer_view_to_tensor %arg0 : !iree_input.buffer_view -> tensor<i32>\n",
+ " call @__sm_exported___inference_set_value_230(%0) : (tensor<i32>) -> ()\n",
+ " return\n",
+ " }\n",
+ " func.func private @__sm_exported___inference_set_value_230(%arg0: tensor<i32> {tf._user_specified_name = \"new_value\"}) attributes {tf._construction_context = \"kEagerRuntime\", tf._input_shapes = [#tf_type.shape<>, #tf_type.shape<>], tf.signature.is_stateful} {\n",
+ " ml_program.global_store @counter = %arg0 : tensor<i32>\n",
+ " return\n",
+ " }\n",
+ "}\n",
"\n",
"```\n",
"\n",
diff --git a/tests/e2e/linalg_transform/transform_dialect_codegen_spec.mlir b/tests/e2e/linalg_transform/transform_dialect_codegen_spec.mlir
index 24c98ce..ea6606b 100644
--- a/tests/e2e/linalg_transform/transform_dialect_codegen_spec.mlir
+++ b/tests/e2e/linalg_transform/transform_dialect_codegen_spec.mlir
@@ -1,7 +1,4 @@
-transform.with_pdl_patterns {
-^bb0(%arg0: !pdl.operation):
- transform.structured.canonicalized_sequence %arg0 failures(propagate) {
- ^bb1(%variant_op: !pdl.operation):
- transform.iree.bufferize %variant_op
- }
+transform.structured.canonicalized_sequence failures(propagate) {
+^bb1(%variant_op: !pdl.operation):
+ transform.iree.bufferize %variant_op
}
diff --git a/tests/e2e/linalg_transform/transform_dialect_dispatch_spec.mlir b/tests/e2e/linalg_transform/transform_dialect_dispatch_spec.mlir
index 633eeb5..f140aff 100644
--- a/tests/e2e/linalg_transform/transform_dialect_dispatch_spec.mlir
+++ b/tests/e2e/linalg_transform/transform_dialect_dispatch_spec.mlir
@@ -1,9 +1,6 @@
-transform.with_pdl_patterns {
-^bb0(%arg0: !pdl.operation):
- transform.structured.canonicalized_sequence %arg0 failures(propagate) {
- ^bb1(%arg1: !pdl.operation):
- %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
- %foreach_op, %tiled_op = transform.structured.tile_to_foreach_thread_op %0 num_threads [13, 33]
- %dispatch_op = transform.iree.foreach_thread_to_flow %foreach_op
- }
+transform.structured.canonicalized_sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
+ %foreach_op, %tiled_op = transform.structured.tile_to_foreach_thread_op %0 num_threads [13, 33]
+ %dispatch_op = transform.iree.foreach_thread_to_flow %foreach_op
}
diff --git a/tests/e2e/matmul/BUILD b/tests/e2e/matmul/BUILD
index 743a048..f2e8e73 100644
--- a/tests/e2e/matmul/BUILD
+++ b/tests/e2e/matmul/BUILD
@@ -114,12 +114,12 @@
"small",
]]
-# Test VMVX+ukernel
+# Test VMVX+ukernel, direct (not mmt4d)
[iree_generated_trace_runner_test(
- name = "e2e_matmul_%s_%s_small_ukernel" % (strategy, lhs_rhs_type),
+ name = "e2e_matmul_direct_%s_small_ukernel" % lhs_rhs_type,
compiler_flags = [
"--iree-vmvx-enable-microkernels",
- ] + (["--iree-flow-mmt4d-target-options=enable_generic_slow #pass_options_variant#"] if (strategy == "mmt4d") else []),
+ ],
generator = ":generate_e2e_matmul_tests",
generator_args = [
"--lhs_rhs_type=%s" % lhs_rhs_type,
@@ -129,10 +129,33 @@
("vmvx", "local-task"),
],
trace_runner = "//tools:iree-e2e-matmul-test",
-) for strategy in [
- "direct",
- "mmt4d",
-] for lhs_rhs_type in [
+) for lhs_rhs_type in [
+ "i8",
+ "f32",
+]]
+
+# Test VMVX+ukernel, mmt4d, with target CPU features variants relevant to each
+# lhs_rhs_type.
+[iree_generated_trace_runner_test(
+ name = "e2e_matmul_mmt4d_%s_small_ukernel" % lhs_rhs_type,
+ compiler_flags = [
+ "--iree-vmvx-enable-microkernels",
+ "--iree-flow-mmt4d-target-options=enable_generic_slow #pass_options_variant#",
+ ],
+ generator = ":generate_e2e_matmul_tests",
+ generator_args = [
+ "--lhs_rhs_type=%s" % lhs_rhs_type,
+ "--shapes=small",
+ ],
+ target_backends_and_drivers = [
+ ("vmvx", "local-task"),
+ ],
+ target_cpu_features_variants = ["default"] + ([
+ "aarch64:+dotprod",
+ "aarch64:+i8mm",
+ ] if lhs_rhs_type == "i8" else []),
+ trace_runner = "//tools:iree-e2e-matmul-test",
+) for lhs_rhs_type in [
"i8",
"f32",
]]
diff --git a/tests/e2e/matmul/CMakeLists.txt b/tests/e2e/matmul/CMakeLists.txt
index 52b6a39..e779924 100644
--- a/tests/e2e/matmul/CMakeLists.txt
+++ b/tests/e2e/matmul/CMakeLists.txt
@@ -223,6 +223,10 @@
COMPILER_FLAGS
"--iree-vmvx-enable-microkernels"
"--iree-flow-mmt4d-target-options=enable_generic_slow #pass_options_variant#"
+ TARGET_CPU_FEATURES_VARIANTS
+ "default"
+ "aarch64:+dotprod"
+ "aarch64:+i8mm"
)
iree_generated_trace_runner_test(
@@ -242,6 +246,8 @@
COMPILER_FLAGS
"--iree-vmvx-enable-microkernels"
"--iree-flow-mmt4d-target-options=enable_generic_slow #pass_options_variant#"
+ TARGET_CPU_FEATURES_VARIANTS
+ "default"
)
iree_generated_trace_runner_test(
diff --git a/tests/e2e/models/fullyconnected.mlir b/tests/e2e/models/fullyconnected.mlir
index d589c10..8977355 100644
--- a/tests/e2e/models/fullyconnected.mlir
+++ b/tests/e2e/models/fullyconnected.mlir
@@ -1,4 +1,5 @@
// RUN: iree-run-mlir --iree-input-type=mhlo --iree-hal-target-backends=llvm-cpu %s --function_input=1x5xf32=1,-2,-3,4,-5 --function_input=1x5x3x1xf32=15,14,13,12,11,10,9,8,7,6,5,4,3,2,1 | FileCheck %s
+// RUN: iree-run-mlir --iree-flow-dispatch-via-region-ops --iree-input-type=mhlo --iree-hal-target-backends=llvm-cpu %s --function_input=1x5xf32=1,-2,-3,4,-5 --function_input=1x5x3x1xf32=15,14,13,12,11,10,9,8,7,6,5,4,3,2,1 | FileCheck %s
// RUN: [[ $IREE_VULKAN_DISABLE == 1 ]] || (iree-run-mlir --iree-input-type=mhlo --iree-hal-target-backends=vulkan-spirv %s --function_input=1x5xf32=1,-2,-3,4,-5 --function_input=1x5x3x1xf32=15,14,13,12,11,10,9,8,7,6,5,4,3,2,1 | FileCheck %s)
// CHECK-LABEL: EXEC @main
diff --git a/tests/transform_dialect/cpu/matmul.mlir b/tests/transform_dialect/cpu/matmul.mlir
index 641c7f7..eed20be 100644
--- a/tests/transform_dialect/cpu/matmul.mlir
+++ b/tests/transform_dialect/cpu/matmul.mlir
@@ -21,7 +21,7 @@
// At the moment the 3rd flow.dispatch.tensor.load shows as readonly instead of readwrite.
// DISPATCH: flow.executable private @matmul_static_dispatch_0 {
-// DISPATCH: flow.executable.export public @matmul_static_dispatch_0_matmul_3x3x5
+// DISPATCH: flow.executable.export public @matmul_static_dispatch_0_matmul_3x3x5
// DISPATCH: builtin.module {
// DISPATCH: func.func @matmul_static_dispatch_0_matmul_3x3x5
// DISPATCH: flow.dispatch.tensor.load {{.*}}, offsets = [0, 0], sizes = [3, 5], strides = [1, 1] : !flow.dispatch.tensor<readonly:3x5xf32> -> tensor<3x5xf32>
@@ -32,11 +32,17 @@
// DISPATCH: return
// RUN: iree-opt %s --iree-hal-target-backends=llvm-cpu \
-// RUN: --iree-abi-transformation-pipeline \
-// RUN: --iree-flow-transformation-pipeline \
-// RUN: --iree-flow-dispatch-use-transform-dialect=%p/matmul_dispatch_spec.mlir \
-// RUN: --iree-stream-transformation-pipeline \
-// RUN: --iree-hal-configuration-pipeline | \
+// RUN: --iree-abi-transformation-pipeline --iree-flow-transformation-pipeline --iree-flow-dispatch-use-transform-dialect=%p/matmul_dispatch_spec.mlir \
+// RUN: --iree-stream-transformation-pipeline --iree-hal-configuration-pipeline | \
+// RUN: iree-opt --pass-pipeline='hal.executable(hal.executable.variant(iree-llvmcpu-lower-executable-target))' \
+// RUN: --iree-codegen-llvmcpu-use-transform-dialect=%p/matmul_codegen_spec.mlir | \
+// RUN: FileCheck %s --check-prefixes=CODEGEN
+
+// Run with C++ dispatch region formation but transform-dialect codegen.
+// RUN: iree-opt %s --iree-hal-target-backends=llvm-cpu \
+// RUN: --iree-abi-transformation-pipeline --iree-flow-transformation-pipeline \
+// RUN: --iree-flow-dispatch-via-region-ops --iree-flow-dispatch-via-region-ops-generate-workload-region=false \
+// RUN: --iree-stream-transformation-pipeline --iree-hal-configuration-pipeline | \
// RUN: iree-opt --pass-pipeline='hal.executable(hal.executable.variant(iree-llvmcpu-lower-executable-target))' \
// RUN: --iree-codegen-llvmcpu-use-transform-dialect=%p/matmul_codegen_spec.mlir | \
// RUN: FileCheck %s --check-prefixes=CODEGEN
@@ -45,13 +51,13 @@
// CODEGEN: hal.executable.variant public @embedded_elf_x86_64, target = #executable_target_embedded_elf_x86_64_ {
//
// The signature of the hal.executable.export region is subject to conventions
-// at the flow level. These conventions are materialized in IR e.g. into
+// at the flow level. These conventions are materialized in IR e.g. into
// stream.cmd.dispatch before codegen gets invoked.
-// As a consequence, the tile_size/num_threads/workgroup_count passed to
+// As a consequence, the tile_size/num_threads/workgroup_count passed to
// transform.tile_to_foreach_thread needs to be aware of this convention.
// For now we use our own convention that sizes are static and no other bbArg
// than !hal.device is present.
-//
+//
// CODEGEN: hal.executable.export public @matmul_static_dispatch_0_matmul_3x3x5 ordinal(0) layout(#{{.*}}) attributes {translation_info = #translation} {
// CODEGEN: ^bb0(%{{.*}}: !hal.device):
// CODEGEN: arith.constant 2 : index
diff --git a/tests/transform_dialect/cpu/matmul_codegen_spec.mlir b/tests/transform_dialect/cpu/matmul_codegen_spec.mlir
index 1435bfb..db7b922 100644
--- a/tests/transform_dialect/cpu/matmul_codegen_spec.mlir
+++ b/tests/transform_dialect/cpu/matmul_codegen_spec.mlir
@@ -1,17 +1,14 @@
// RUN: iree-opt %s
-transform.with_pdl_patterns {
-^bb0(%arg0: !pdl.operation):
- transform.structured.canonicalized_sequence %arg0 failures(propagate) {
- ^bb1(%variant_op: !pdl.operation):
- %0 = transform.structured.match ops{["linalg.matmul"]} in %variant_op
+transform.structured.canonicalized_sequence failures(propagate) {
+^bb1(%variant_op: !pdl.operation):
+ %0 = transform.structured.match ops{["linalg.matmul"]} in %variant_op
- %foreach_thread, %tiled_generic =
- transform.structured.tile_to_foreach_thread_op %0 num_threads [2]
+ %foreach_thread, %tiled_generic =
+ transform.structured.tile_to_foreach_thread_op %0 num_threads [2]
- %1 = transform.iree.bufferize %variant_op
+ %1 = transform.iree.bufferize %variant_op
- %func = transform.structured.match ops{["func.func"]} in %1
- transform.iree.foreach_thread_to_workgroup %func
- }
+ %func = transform.structured.match ops{["func.func"]} in %1
+ transform.iree.foreach_thread_to_workgroup %func
}
diff --git a/tests/transform_dialect/cpu/matmul_tiled_dispatch_spec.mlir b/tests/transform_dialect/cpu/matmul_tiled_dispatch_spec.mlir
index 80b735d..41dc2dd 100644
--- a/tests/transform_dialect/cpu/matmul_tiled_dispatch_spec.mlir
+++ b/tests/transform_dialect/cpu/matmul_tiled_dispatch_spec.mlir
@@ -1,9 +1,6 @@
-transform.with_pdl_patterns {
-^bb0(%arg0: !pdl.operation):
- transform.structured.canonicalized_sequence %arg0 failures(propagate) {
- ^bb1(%arg1: !pdl.operation):
- %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
- %foreach_op, %tiled_op = transform.structured.tile_to_foreach_thread_op %0 num_threads [10, 20]
- %dispatch_op = transform.iree.foreach_thread_to_flow %foreach_op
- }
+transform.structured.canonicalized_sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
+ %foreach_op, %tiled_op = transform.structured.tile_to_foreach_thread_op %0 num_threads [10, 20]
+ %dispatch_op = transform.iree.foreach_thread_to_flow %foreach_op
}
diff --git a/tests/transform_dialect/cuda/BUILD b/tests/transform_dialect/cuda/BUILD
index 5c3321f..d2a71ac 100644
--- a/tests/transform_dialect/cuda/BUILD
+++ b/tests/transform_dialect/cuda/BUILD
@@ -26,7 +26,7 @@
iree_lit_test_suite(
name = "lit",
srcs = [
- "reduction.mlir",
+ # "reduction.mlir", // see #10398
"softmax.mlir",
],
cfg = "//tests:lit.cfg.py",
@@ -37,6 +37,7 @@
"reduction_dispatch_spec.mlir",
"softmax_codegen_spec.mlir",
"softmax_dispatch_spec.mlir",
+ "softmax_fused_codegen_spec.mlir",
],
tags = [
# CUDA cuInit fails with sanitizer on.
diff --git a/tests/transform_dialect/cuda/CMakeLists.txt b/tests/transform_dialect/cuda/CMakeLists.txt
index c947020..5905026 100644
--- a/tests/transform_dialect/cuda/CMakeLists.txt
+++ b/tests/transform_dialect/cuda/CMakeLists.txt
@@ -18,7 +18,6 @@
NAME
lit
SRCS
- "reduction.mlir"
"softmax.mlir"
TOOLS
FileCheck
@@ -30,6 +29,7 @@
reduction_dispatch_spec.mlir
softmax_codegen_spec.mlir
softmax_dispatch_spec.mlir
+ softmax_fused_codegen_spec.mlir
LABELS
"noasan"
"nomsan"
diff --git a/tests/transform_dialect/cuda/reduction_codegen_spec.mlir b/tests/transform_dialect/cuda/reduction_codegen_spec.mlir
index b2e834d..ca2ee11 100644
--- a/tests/transform_dialect/cuda/reduction_codegen_spec.mlir
+++ b/tests/transform_dialect/cuda/reduction_codegen_spec.mlir
@@ -1,43 +1,38 @@
// RUN: iree-opt %s
-transform.with_pdl_patterns {
-^bb0(%arg0: !pdl.operation):
- transform.structured.canonicalized_sequence %arg0 failures(propagate) {
- ^bb1(%variant_op: !pdl.operation):
- %0 = transform.structured.match ops{["linalg.generic"]} in %variant_op
- %fused_fill = transform.structured.match ops{["linalg.fill"]} in %variant_op
- // Note: split by 32 to vector-distribute the tail combiner_op, but
- // split by 2 to vector-distribute the meaty %more_parallel_op
- %init_or_alloc_op, %fill_op, %more_parallel_op, %combiner_op =
- transform.structured.split_reduction %0
- { split_factor = 2, insert_split_dimension = 1, use_alloc }
+transform.structured.canonicalized_sequence failures(propagate) {
+^bb1(%variant_op: !pdl.operation):
+ %0 = transform.structured.match ops{["linalg.generic"]} in %variant_op
+ %fused_fill = transform.structured.match ops{["linalg.fill"]} in %variant_op
+ // Note: split by 32 to vector-distribute the tail combiner_op, but
+ // split by 2 to vector-distribute the meaty %more_parallel_op
+ %init_or_alloc_op, %fill_op, %more_parallel_op, %combiner_op =
+ transform.structured.split_reduction %0
+ { split_factor = 2, insert_split_dimension = 1, use_alloc }
- %1 = transform.structured.match ops{["linalg.generic"]} in %variant_op
- %foreach_thread_1, %tiled_fill =
- transform.structured.tile_to_foreach_thread_op %fill_op num_threads [4, 2] (mapped to dims [2, 1, 0])
- %foreach_thread_2, %tiled_more_parallel_op =
- transform.structured.tile_to_foreach_thread_op %more_parallel_op num_threads [4, 2] (mapped to dims [2, 1, 0])
- %foreach_thread_3, %tiled_combiner_op =
- transform.structured.tile_to_foreach_thread_op %combiner_op num_threads [4] (mapped to dims [2, 1, 0])
- %foreach_thread_4, %tiled_fused_fill_op =
- transform.structured.tile_to_foreach_thread_op %fused_fill num_threads [4] (mapped to dims [2, 1, 0])
+ %1 = transform.structured.match ops{["linalg.generic"]} in %variant_op
+ %foreach_thread_1, %tiled_fill =
+ transform.structured.tile_to_foreach_thread_op %fill_op num_threads [4, 2] (mapped to dims [2, 1, 0])
+ %foreach_thread_2, %tiled_more_parallel_op =
+ transform.structured.tile_to_foreach_thread_op %more_parallel_op num_threads [4, 2] (mapped to dims [2, 1, 0])
+ %foreach_thread_3, %tiled_combiner_op =
+ transform.structured.tile_to_foreach_thread_op %combiner_op num_threads [4] (mapped to dims [2, 1, 0])
+ %foreach_thread_4, %tiled_fused_fill_op =
+ transform.structured.tile_to_foreach_thread_op %fused_fill num_threads [4] (mapped to dims [2, 1, 0])
- %isolated_handle_1 = transform.get_closest_isolated_parent %foreach_thread_2
- %isolated_handle_2 = transform.structured.vectorize %isolated_handle_1
- %isolated_handle_3 = transform.iree.apply_patterns %isolated_handle_2 { rank_reducing }
+ %isolated_handle_1 = transform.get_closest_isolated_parent %foreach_thread_2
+ %isolated_handle_2 = transform.structured.vectorize %isolated_handle_1
+ %isolated_handle_3 = transform.iree.apply_patterns %isolated_handle_2 { rank_reducing }
- %variant_op_2 = transform.iree.bufferize { target_gpu } %variant_op
+ %variant_op_2 = transform.iree.bufferize { target_gpu } %variant_op
- %funcop = transform.structured.match ops{["func.func"]} in %variant_op_2
- %isolated_handle_4 =
- transform.iree.foreach_thread_to_gpu_and_translation_info %funcop
- { workgroup_size = [32, 2, 4] }
+ %funcop = transform.structured.match ops{["func.func"]} in %variant_op_2
+ %isolated_handle_4 =
+ transform.iree.foreach_thread_to_gpu_and_translation_info %funcop
+ { workgroup_size = [32, 2, 4] }
- // Vector distribution needs to happen on buffers.
- %if_op = transform.structured.match ops{["scf.if"]} in %variant_op_2
- %warp = transform.iree.vector.to_warp_execute_on_lane_0 %if_op { warp_size = 32 }
- transform.iree.vector.warp_distribute %isolated_handle_4
-
- // transform.print { name = "after codegen"}
- }
+ // Vector distribution needs to happen on buffers.
+ %if_op = transform.structured.match ops{["scf.if"]} in %variant_op_2
+ %warp = transform.iree.vector.to_warp_execute_on_lane_0 %if_op { warp_size = 32 }
+ transform.iree.vector.warp_distribute %isolated_handle_4
}
diff --git a/tests/transform_dialect/cuda/reduction_dispatch_spec.mlir b/tests/transform_dialect/cuda/reduction_dispatch_spec.mlir
index 353342e..cfc4600 100644
--- a/tests/transform_dialect/cuda/reduction_dispatch_spec.mlir
+++ b/tests/transform_dialect/cuda/reduction_dispatch_spec.mlir
@@ -1,11 +1,8 @@
// RUN: iree-opt %s
-transform.with_pdl_patterns {
-^bb0(%arg0: !pdl.operation):
- transform.structured.canonicalized_sequence %arg0 failures(propagate) {
- ^bb1(%arg1: !pdl.operation):
- %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
- %foreach_thread, %tiled_generic = transform.structured.tile_to_foreach_thread_op %0 num_threads [2]
- transform.iree.foreach_thread_to_flow %foreach_thread
- }
+transform.structured.canonicalized_sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+ %foreach_thread, %tiled_generic = transform.structured.tile_to_foreach_thread_op %0 num_threads [2]
+ transform.iree.foreach_thread_to_flow %foreach_thread
}
diff --git a/tests/transform_dialect/cuda/softmax.mlir b/tests/transform_dialect/cuda/softmax.mlir
index df1586d..b6e4fed 100644
--- a/tests/transform_dialect/cuda/softmax.mlir
+++ b/tests/transform_dialect/cuda/softmax.mlir
@@ -1,24 +1,55 @@
+
+// RUN: iree-opt %s --iree-hal-target-backends=cuda \
+// RUN: --iree-abi-transformation-pipeline \
+// RUN: --iree-flow-transformation-pipeline \
+// RUN: --iree-flow-dispatch-use-transform-dialect=%p/softmax_dispatch_spec.mlir \
+// RUN: --iree-stream-transformation-pipeline \
+// RUN: --iree-hal-configuration-pipeline | \
+// RUN: iree-opt --pass-pipeline='hal.executable(hal.executable.variant(iree-llvmgpu-lower-executable-target-pass))' \
+// RUN: --iree-codegen-llvmgpu-use-transform-dialect=%p/softmax_codegen_spec.mlir | \
+// RUN: FileCheck %s --check-prefix=CHECK-SHUFFLE
+
// RUN: iree-compile %s --iree-hal-target-backends=cuda \
// RUN: --iree-flow-dispatch-use-transform-dialect=%p/softmax_dispatch_spec.mlir \
// RUN: --iree-codegen-llvmgpu-use-transform-dialect=%p/softmax_codegen_spec.mlir | \
// RUN: iree-run-module --entry_function=max_sub_exp --device=cuda | \
// RUN: FileCheck %s
+// RUN: iree-opt %s --iree-hal-target-backends=cuda \
+// RUN: --iree-abi-transformation-pipeline \
+// RUN: --iree-flow-transformation-pipeline \
+// RUN: --iree-flow-dispatch-use-transform-dialect=%p/softmax_dispatch_spec.mlir \
+// RUN: --iree-stream-transformation-pipeline \
+// RUN: --iree-hal-configuration-pipeline | \
+// RUN: iree-opt --pass-pipeline='hal.executable(hal.executable.variant(iree-llvmgpu-lower-executable-target-pass))' \
+// RUN: --iree-codegen-llvmgpu-use-transform-dialect=%p/softmax_fused_codegen_spec.mlir | \
+// RUN: FileCheck %s --check-prefix=CHECK-SHUFFLE
+
+// RUN: iree-compile %s --iree-hal-target-backends=cuda \
+// RUN: --iree-flow-dispatch-use-transform-dialect=%p/softmax_dispatch_spec.mlir \
+// RUN: --iree-codegen-llvmgpu-use-transform-dialect=%p/softmax_fused_codegen_spec.mlir | \
+// RUN: iree-run-module --entry_function=max_sub_exp --device=cuda | \
+// RUN: FileCheck %s
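+
+// The RUN lines above cover both codegen specs (softmax_codegen_spec and
+// softmax_fused_codegen_spec), each in two ways: an iree-opt pipeline that
+// FileChecks the lowered IR for warp shuffles (CHECK-SHUFFLE), and an
+// iree-compile + iree-run-module invocation that checks execution (CHECK).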
+
// TODO: make this test drop transform dialect usage at the flow level and use:
// --iree-flow-transformation-pipeline --iree-flow-convert-region-to-workgroups
-!tmp_tensor_t = tensor<12x128xf32>
-!out_tensor_t = tensor<12x128x128xf32>
+!tmp_tensor_t = tensor<16x128xf32>
+!out_tensor_t = tensor<16x128x128xf32>
+
+// Compilation checks that shuffles are produced.
+// CHECK-SHUFFLE: gpu.shuffle xor
// Execution only checks that @max_sub_exp runs.
// CHECK: EXEC @max_sub_exp
+
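+// The payload below computes a max-reduction of a 16x128x128 tensor along
+// its last (d2) dimension and, judging by the entry-point name, a manually
+// fused subtract-and-exponentiate: the opening steps of a softmax.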
func.func @max_sub_exp() {
%cst = arith.constant -3.40282347E+38 : f32
%cst_0 = arith.constant dense<1.000000e+00> : !out_tensor_t
%cst_1 = arith.constant dense<5.000000e+00> : !out_tensor_t
%0 = util.do_not_optimize(%cst_1) : !out_tensor_t
- %1 = linalg.init_tensor [12, 128] : !tmp_tensor_t
+ %1 = linalg.init_tensor [16, 128] : !tmp_tensor_t
%2 = linalg.fill ins(%cst : f32) outs(%1 : !tmp_tensor_t) -> !tmp_tensor_t
%3 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
affine_map<(d0, d1, d2) -> (d0, d1)>],
@@ -30,7 +61,7 @@
} -> !tmp_tensor_t
// This has been fused manually to avoid the fusion on tensors pass and reduce noise atm.
- %4 = linalg.init_tensor [12, 128, 128] : !out_tensor_t
+ %4 = linalg.init_tensor [16, 128, 128] : !out_tensor_t
%5 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
affine_map<(d0, d1, d2) -> (d0, d1)>,
affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
diff --git a/tests/transform_dialect/cuda/softmax_codegen_spec.mlir b/tests/transform_dialect/cuda/softmax_codegen_spec.mlir
index 6582475..8495114 100644
--- a/tests/transform_dialect/cuda/softmax_codegen_spec.mlir
+++ b/tests/transform_dialect/cuda/softmax_codegen_spec.mlir
@@ -16,7 +16,7 @@
%foreach_thread, %tiled_generic =
transform.structured.tile_to_foreach_thread_op %root tile_sizes [1, 4]
transform.structured.fuse_into_containing_op %not_root into %foreach_thread
-
+
// Second level of tiling + fusion parallelizes to threads.
// Leaving the reduction untiled on threadIdx.x makes it sequential on
// threadIdx.x. After distribution, predication by if (threadIdx.x == 0) is
@@ -54,7 +54,8 @@
// That is still not good enough because we need to predicate this in order
// to enable the parallel reduction on warps.
%func = transform.structured.match ops{["func.func"]} in %variant_op
- %func_2 = transform.structured.vectorize %func
+ %funcx = transform.iree.apply_patterns %func { rank_reducing }
+ transform.structured.vectorize %funcx
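+ // Rank-reducing patterns run before vectorization so that the unit dims
+ // left by the [1, 4] tiling are dropped first; the expectation (an
+ // assumption here) is that the vectorizer then emits vector<4xf32> ops
+ // instead of vector<1x4xf32> ones.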
// Bufferization is necessary for:
// 1. lowering scf.foreach_thread to workgroup (block level parallelism)
@@ -63,9 +64,9 @@
// warp_execute_on_lane_0 and later vector distribution.
%variant_op_2 = transform.iree.bufferize { target_gpu } %variant_op
- %func_3 = transform.structured.match ops{["func.func"]} in %variant_op_2
- %func_4 = transform.iree.foreach_thread_to_workgroup %func_3
- transform.iree.foreach_thread_to_gpu_and_translation_info %func_4
+ %func_2 = transform.structured.match ops{["func.func"]} in %variant_op_2
+ %func_3 = transform.iree.foreach_thread_to_workgroup %func_2
+ transform.iree.foreach_thread_to_gpu_and_translation_info %func_3
{ workgroup_size = [32, 4, 1] }
%end_func = transform.structured.match ops{["func.func"]} in %variant_op_2
diff --git a/tests/transform_dialect/cuda/softmax_dispatch_spec.mlir b/tests/transform_dialect/cuda/softmax_dispatch_spec.mlir
index 8986ab1..e0ac46b 100644
--- a/tests/transform_dialect/cuda/softmax_dispatch_spec.mlir
+++ b/tests/transform_dialect/cuda/softmax_dispatch_spec.mlir
@@ -1,21 +1,18 @@
// RUN: iree-opt %s
// Dispatch softmax.
-transform.with_pdl_patterns {
-^bb0(%arg0: !pdl.operation):
- transform.structured.canonicalized_sequence %arg0 failures(propagate){
- ^bb1(%arg1: !pdl.operation):
- %root = transform.structured.match interface{LinalgOp}
- attributes{iterator_types = ["parallel", "parallel", "parallel"]} in %arg1
- %fill = transform.structured.match ops{["linalg.fill"]} in %arg1
- %red = transform.structured.match interface{LinalgOp}
- attributes{iterator_types = ["parallel", "parallel", "reduction"]} in %arg1
+transform.structured.canonicalized_sequence failures(propagate){
+^bb1(%arg1: !pdl.operation):
+ %root = transform.structured.match interface{LinalgOp}
+ attributes{iterator_types = ["parallel", "parallel", "parallel"]} in %arg1
+ %fill = transform.structured.match ops{["linalg.fill"]} in %arg1
+ %red = transform.structured.match interface{LinalgOp}
+ attributes{iterator_types = ["parallel", "parallel", "reduction"]} in %arg1
- // TODO: this could be replaced by a C++ only version.
- // Atm the IR produced is not the same so all pieces do not connect.
- %region_op = transform.iree.wrap_in_dispatch_region %root
- %region_op_2 = transform.iree.clone_preceding_op_into_dispatch_region %red into %region_op
- %region_op_3 = transform.iree.clone_preceding_op_into_dispatch_region %fill into %region_op_2
- transform.iree.region_to_workgroups %region_op_3
- }
+ // TODO: this could be replaced by a C++-only version.
+ // At the moment the IR produced is not the same, so the pieces do not all connect.
+ %region_op = transform.iree.wrap_in_dispatch_region %root
+ %region_op_2 = transform.iree.move_preceding_op_into_dispatch_region %red into %region_op
+ %region_op_3 = transform.iree.move_preceding_op_into_dispatch_region %fill into %region_op_2
+ transform.iree.region_to_workgroups %region_op_3
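+ // Unlike the previous clone_preceding_op_into_dispatch_region, the move_*
+ // variant is understood to relocate the matched producer into the region,
+ // leaving no duplicate of the fill or reduction outside the dispatch.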
}
diff --git a/tests/transform_dialect/cuda/softmax_fused_codegen_spec.mlir b/tests/transform_dialect/cuda/softmax_fused_codegen_spec.mlir
new file mode 100644
index 0000000..3a92244
--- /dev/null
+++ b/tests/transform_dialect/cuda/softmax_fused_codegen_spec.mlir
@@ -0,0 +1,57 @@
+// RUN: iree-opt %s
+
+// Codegen
+transform.structured.canonicalized_sequence failures(propagate) {
+^bb1(%variant_op: !pdl.operation):
+ // First level of tiling + fusion parallelizes to blocks.
+ // The mapping to block ids can only happen after bufferization at the moment.
+ %root = transform.structured.match interface{LinalgOp}
+ attributes{iterator_types = ["parallel", "parallel", "parallel"]} in %variant_op
+ %fill = transform.structured.match ops{["linalg.fill"]} in %variant_op
+ %red = transform.structured.match interface{LinalgOp}
+ attributes{iterator_types = ["parallel", "parallel", "reduction"]} in %variant_op
+ %not_root = merge_handles %fill, %red
+ %foreach_thread, %tiled_generic =
+ transform.structured.tile_to_foreach_thread_op %root tile_sizes [1, 1]
+ (mapped to dims [0, 1, 2])
+ transform.structured.fuse_into_containing_op %not_root into %foreach_thread
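+ // A sketch of the expected IR after this step (an assumption, not
+ // verified output):
+ //   scf.foreach_thread (%i, %j) in (16, 128) {
+ //     %0 = linalg.fill ...        // fused producer
+ //     %1 = linalg.generic ...     // fused reduction
+ //     %2 = linalg.generic ...     // 1x1x128 tile of the root
+ //   }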
+
+ // Second level of tiling + fusion parallelizes to threads.
+ // Leaving the reduction untiled on threadIdx.x makes it sequential on
+ // threadIdx.x. After distribution, predication by if (threadIdx.x == 0) is
+ // introduced and opportunities for distributing vector ops across warps
+ // appear.
+ %fill_linalg = transform.structured.match ops{["linalg.fill"]} in %variant_op
+ %reduction_linalg = transform.structured.match ops{["linalg.generic"]}
+ attributes{iterator_types = ["parallel", "parallel", "reduction"]} in %variant_op
+ %not_root_2 = merge_handles %fill_linalg, %reduction_linalg
+ %parallel_linalg = transform.structured.match ops{["linalg.generic"]}
+ attributes{iterator_types = ["parallel", "parallel", "parallel"]} in %variant_op
+ %foreach_thread_2, %parallel_linalg_2 =
+ transform.structured.tile_to_foreach_thread_op %parallel_linalg tile_sizes [1, 1, 0]
+ (mapped to dims [2, 1, 0])
+ transform.structured.fuse_into_containing_op %not_root_2 into %foreach_thread_2
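+ // With tile sizes [1, 1, 0] and (mapped to dims [2, 1, 0]), the two tiled
+ // parallel loops land on threadIdx.z and threadIdx.y; threadIdx.x, the
+ // 32-wide warp dimension in the workgroup_size below, is deliberately left
+ // for the reduction and later warp distribution.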
+
+ // Rank-reduce and vectorize.
+ %funcx = transform.structured.match ops{["func.func"]} in %variant_op
+ %funcx_2 = transform.iree.apply_patterns %funcx { rank_reducing }
+ transform.structured.vectorize %funcx_2
+
+ // Bufferization is necessary for:
+ // 1. lowering scf.foreach_thread to workgroup (block level parallelism)
+ // 2. lowering scf.foreach_thread to gpu (thread level parallelism)
+ // 3. introducing predication (due to 1. + 2.) which enables rewriting to
+ // warp_execute_on_lane_0 and later vector distribution.
+ %variant_op_2 = transform.iree.bufferize { target_gpu } %variant_op
+ %func = transform.structured.match ops{["func.func"]} in %variant_op_2
+ %func_2 = transform.iree.foreach_thread_to_workgroup %func
+ transform.iree.foreach_thread_to_gpu_and_translation_info %func_2
+ { workgroup_size = [32, 1, 1] }
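+ // After this mapping, the untiled reduction is expected to sit under
+ // predication roughly of the form (illustrative sketch only):
+ //   %tid = gpu.thread_id x
+ //   %c0 = arith.constant 0 : index
+ //   %cond = arith.cmpi eq, %tid, %c0 : index
+ //   scf.if %cond { ... }   // matched as %if_op below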
+
+ // Vector distribution needs to happen on buffers.
+ %end_func = transform.structured.match ops{["func.func"]} in %variant_op_2
+ %if_op = transform.structured.match ops{["scf.if"]} in %variant_op_2
+ %warp = transform.iree.vector.to_warp_execute_on_lane_0 %if_op { warp_size = 32 }
+ transform.iree.vector.warp_distribute %end_func
+}
diff --git a/third_party/llvm-project b/third_party/llvm-project
index b8a5ce6..80002f6 160000
--- a/third_party/llvm-project
+++ b/third_party/llvm-project
@@ -1 +1 @@
-Subproject commit b8a5ce609648a79b29a6aede5a717c718b28058b
+Subproject commit 80002f63e84c8290db9c1164761eb1ffaf492e81
diff --git a/third_party/mlir-hlo b/third_party/mlir-hlo
index 4fbf18d..03a48d8 160000
--- a/third_party/mlir-hlo
+++ b/third_party/mlir-hlo
@@ -1 +1 @@
-Subproject commit 4fbf18d6d70ac21941d15572618407a33affc93a
+Subproject commit 03a48d808e38dd0cb4f3421a166b66035a67f0c2