Merge pull request #7793 from NatashaKnk:main-to-google
PiperOrigin-RevId: 413527353
diff --git a/bindings/python/iree/runtime/function.py b/bindings/python/iree/runtime/function.py
index d641cc8..826fac9 100644
--- a/bindings/python/iree/runtime/function.py
+++ b/bindings/python/iree/runtime/function.py
@@ -110,15 +110,15 @@
args = list(args)
len_delta = self._max_named_arg_index - len(args) + 1
if len_delta > 0:
- args.extend([NotImplemented] * len_delta)
+ # Fill in MissingArgument placeholders before arranging kwarg input.
+ # Any remaining placeholders will fail arity checks later on.
+ args.extend([MissingArgument] * len_delta)
+
for kwarg_key, kwarg_value in kwargs.items():
try:
kwarg_index = self._named_arg_indices[kwarg_key]
except KeyError:
raise ArgumentError(f"specified kwarg '{kwarg_key}' is unknown")
- len_delta = kwarg_index - len(args) + 1
- if len_delta <= 0:
- args.extend([NotImplemented] * len_delta)
args[kwarg_index] = kwarg_value
arg_list = VmVariantList(len(args))
@@ -202,10 +202,6 @@
# desc: The ABI descriptor list (or None if in dynamic mode).
-def _missing_argument(inv: Invocation, t: VmVariantList, x, desc):
- _raise_argument_error(inv, f"a required argument was not specified")
-
-
def _bool_to_vm(inv: Invocation, t: VmVariantList, x, desc):
_int_to_vm(inv, t, int(x), desc)
@@ -307,8 +303,16 @@
return _ndarray_to_vm(inv, t, np.asarray(x), desc)
+class _MissingArgument:
+ """Placeholder for missing kwargs in the function input."""
+
+ def __repr__(self):
+    return "<missing argument>"
+
+
+MissingArgument = _MissingArgument()
+
PYTHON_TO_VM_CONVERTERS = {
- NotImplemented.__class__: _missing_argument,
bool: _bool_to_vm,
int: _int_to_vm,
float: _float_to_vm,
@@ -494,10 +498,14 @@
# For dynamic mode, just assume we have the right arity.
if descs is None:
descs = [None] * len(py_list)
- elif len(py_list) != len(descs):
- _raise_argument_error(
- inv, f"mismatched call arity: expected {len(descs)} arguments but got "
- f"{len(py_list)}. Expected signature=\n{descs}\nfor input=\n{py_list}")
+ else:
+ len_py_list = sum([1 for x in py_list if x is not MissingArgument])
+ if len(py_list) != len_py_list:
+ _raise_argument_error(
+ inv,
+ f"mismatched call arity: expected {len(descs)} arguments but got "
+ f"{len_py_list}. Expected signature=\n{descs}\nfor input=\n{py_list}")
+
for py_value, desc in zip(py_list, descs):
inv.current_arg = py_value
inv.current_desc = desc
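The function.py hunk above replaces the old `NotImplemented` placeholder (and its `_missing_argument` converter) with a dedicated `MissingArgument` sentinel, deferring the detection of unfilled slots to a single arity comparison. The sketch below isolates that pattern under simplified assumptions; the helper names `place_kwargs` and `check_arity` are illustrative only and are not part of the IREE runtime API.

```python
# Minimal, standalone sketch of the placeholder pattern: a unique sentinel
# marks argument slots that were never filled, and a later arity check
# catches any that remain. Helper names here are hypothetical.

class _MissingArgument:
  """Placeholder for argument slots that were never supplied."""

  def __repr__(self):
    return "<missing argument>"


MissingArgument = _MissingArgument()


def place_kwargs(args, kwargs, named_indices, max_named_index):
  """Pads args with placeholders, then drops each kwarg into its slot."""
  args = list(args)
  len_delta = max_named_index - len(args) + 1
  if len_delta > 0:
    args.extend([MissingArgument] * len_delta)
  for key, value in kwargs.items():
    try:
      index = named_indices[key]
    except KeyError:
      raise ValueError(f"specified kwarg '{key}' is unknown")
    args[index] = value
  return args


def check_arity(args, expected):
  """Counts non-placeholder arguments against the expected arity."""
  provided = sum(1 for a in args if a is not MissingArgument)
  if provided != expected:
    raise ValueError(f"mismatched call arity: expected {expected} arguments "
                     f"but got {provided}")


# A 3-argument function whose named args 'a' and 'b' sit in slots 1 and 2.
args = place_kwargs([-1], {"a": 10}, {"a": 1, "b": 2}, max_named_index=2)
print(args)  # [-1, 10, <missing argument>]
try:
  check_arity(args, expected=3)
except ValueError as e:
  print(e)  # mismatched call arity: expected 3 arguments but got 2
```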
diff --git a/bindings/python/iree/runtime/function_test.py b/bindings/python/iree/runtime/function_test.py
index 90ccda2..df5e222 100644
--- a/bindings/python/iree/runtime/function_test.py
+++ b/bindings/python/iree/runtime/function_test.py
@@ -128,7 +128,7 @@
})
invoker = FunctionInvoker(vm_context, self.device, vm_function, tracer=None)
result = invoker()
- self.assertEqual((3, [{'bar': 100, 'foo': 200}, 6]), result)
+ self.assertEqual((3, [{"bar": 100, "foo": 200}, 6]), result)
def testMissingPositional(self):
@@ -149,9 +149,30 @@
})
})
invoker = FunctionInvoker(vm_context, self.device, vm_function, tracer=None)
- with self.assertRaisesRegexp(ValueError,
- "a required argument was not specified"):
- result = invoker(a=1, b=2)
+ with self.assertRaisesRegex(ValueError, "mismatched call arity:"):
+ result = invoker(a=1, b=1)
+
+ def testMissingPositionalNdarray(self):
+
+ def invoke(arg_list, ret_list):
+ ret_list.push_int(3)
+
+ vm_context = MockVmContext(invoke)
+ vm_function = MockVmFunction(
+ reflection={
+ "iree.abi":
+ json.dumps({
+ "a": [
+ ["ndarray", "i32", 1, 1],
+ ["named", "a", ["ndarray", "i32", 1, 1]],
+ ["named", "b", ["ndarray", "i32", 1, 1]],
+ ],
+ "r": ["i32",],
+ })
+ })
+ invoker = FunctionInvoker(vm_context, self.device, vm_function, tracer=None)
+ with self.assertRaisesRegex(ValueError, "mismatched call arity:"):
+ result = invoker(a=1, b=1)
def testMissingKeyword(self):
@@ -172,8 +193,29 @@
})
})
invoker = FunctionInvoker(vm_context, self.device, vm_function, tracer=None)
- with self.assertRaisesRegexp(ValueError,
- "a required argument was not specified"):
+ with self.assertRaisesRegex(ValueError, "mismatched call arity:"):
+ result = invoker(-1, a=1)
+
+ def testMissingKeywordNdArray(self):
+
+ def invoke(arg_list, ret_list):
+ ret_list.push_int(3)
+
+ vm_context = MockVmContext(invoke)
+ vm_function = MockVmFunction(
+ reflection={
+ "iree.abi":
+ json.dumps({
+ "a": [
+ ["ndarray", "i32", 1, 1],
+ ["named", "a", ["ndarray", "i32", 1, 1]],
+ ["named", "b", ["ndarray", "i32", 1, 1]],
+ ],
+ "r": ["i32",],
+ })
+ })
+ invoker = FunctionInvoker(vm_context, self.device, vm_function, tracer=None)
+ with self.assertRaisesRegex(ValueError, "mismatched call arity:"):
result = invoker(-1, a=1)
def testExtraKeyword(self):
@@ -195,7 +237,7 @@
})
})
invoker = FunctionInvoker(vm_context, self.device, vm_function, tracer=None)
- with self.assertRaisesRegexp(ValueError, "specified kwarg 'c' is unknown"):
+ with self.assertRaisesRegex(ValueError, "specified kwarg 'c' is unknown"):
result = invoker(-1, a=1, b=2, c=3)
# TODO: Fill out all return types.
diff --git a/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py b/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py
index fe5df1b..1c23aa7 100644
--- a/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py
+++ b/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py
@@ -13,51 +13,67 @@
"//build_tools:dl": ["${CMAKE_DL_LIBS}"],
# IREE llvm-external-projects
- "//llvm-external-projects/iree-dialects:IREEInputDialect": [
- "IREEInputDialect"
- ],
- "//llvm-external-projects/iree-dialects:IREELinalgExtDialect": [
- "IREELinalgExtDialect"
- ],
- "//llvm-external-projects/iree-dialects:IREELinalgExtTransforms": [
- "IREELinalgExtPasses"
- ],
+ "//llvm-external-projects/iree-dialects:IREEInputDialect":
+ ["IREEInputDialect"],
+ "//llvm-external-projects/iree-dialects:IREELinalgExtDialect":
+ ["IREELinalgExtDialect"],
+ "//llvm-external-projects/iree-dialects:IREELinalgExtTransforms":
+ ["IREELinalgExtPasses"],
"//llvm-external-projects/iree-dialects:IREEPyDMDialect": [
"IREEPyDMDialect"
],
- "//llvm-external-projects/iree-dialects:IREEPyDMTransforms": [
- "IREEPyDMPasses"
- ],
+ "//llvm-external-projects/iree-dialects:IREEPyDMTransforms":
+ ["IREEPyDMPasses"],
# LLVM
"@llvm-project//llvm:IPO": ["LLVMipo"],
"@llvm-project//lld": ["lld"],
# MLIR
- "@llvm-project//mlir:AllPassesAndDialects": ["MLIRAllDialects"],
- "@llvm-project//mlir:AffineToStandardTransforms": ["MLIRAffineToStandard"],
+ "@llvm-project//mlir:AllPassesAndDialects": [
+ "MLIRAllDialects"
+ ],
+ "@llvm-project//mlir:AffineToStandardTransforms": [
+ "MLIRAffineToStandard"
+ ],
"@llvm-project//mlir:CFGTransforms": ["MLIRSCFToStandard"],
"@llvm-project//mlir:ComplexDialect": ["MLIRComplex"],
"@llvm-project//mlir:DialectUtils": [""],
- "@llvm-project//mlir:ExecutionEngineUtils": ["MLIRExecutionEngine"],
+ "@llvm-project//mlir:ExecutionEngineUtils": [
+ "MLIRExecutionEngine"
+ ],
"@llvm-project//mlir:GPUDialect": ["MLIRGPUOps"],
"@llvm-project//mlir:GPUTransforms": ["MLIRGPUTransforms"],
"@llvm-project//mlir:LinalgInterfaces": ["MLIRLinalg"],
"@llvm-project//mlir:LinalgOps": ["MLIRLinalg"],
"@llvm-project//mlir:LLVMDialect": ["MLIRLLVMIR"],
- "@llvm-project//mlir:LLVMTransforms": ["MLIRStandardToLLVM"],
+ "@llvm-project//mlir:LLVMTransforms": [
+ "MLIRStandardToLLVM"
+ ],
"@llvm-project//mlir:MathDialect": ["MLIRMath"],
- "@llvm-project//mlir:ArithmeticDialect": ["MLIRArithmetic"],
- "@llvm-project//mlir:BufferizationDialect": ["MLIRBufferization"],
+ "@llvm-project//mlir:ArithmeticDialect": [
+ "MLIRArithmetic"
+ ],
+ "@llvm-project//mlir:BufferizationDialect": [
+ "MLIRBufferization"
+ ],
"@llvm-project//mlir:MemRefDialect": ["MLIRMemRef"],
"@llvm-project//mlir:SCFToGPUPass": ["MLIRSCFToGPU"],
"@llvm-project//mlir:SCFDialect": ["MLIRSCF"],
"@llvm-project//mlir:StandardOps": ["MLIRStandard"],
- "@llvm-project//mlir:ShapeTransforms": ["MLIRShapeOpsTransforms"],
- "@llvm-project//mlir:SideEffects": ["MLIRSideEffectInterfaces"],
+ "@llvm-project//mlir:ShapeTransforms": [
+ "MLIRShapeOpsTransforms"
+ ],
+ "@llvm-project//mlir:SideEffects": [
+ "MLIRSideEffectInterfaces"
+ ],
"@llvm-project//mlir:SPIRVDialect": ["MLIRSPIRV"],
"@llvm-project//mlir:TosaDialect": ["MLIRTosa"],
- "@llvm-project//mlir:ToLLVMIRTranslation": ["MLIRTargetLLVMIRExport"],
- "@llvm-project//mlir:mlir_c_runner_utils": ["MLIRExecutionEngine"],
+ "@llvm-project//mlir:ToLLVMIRTranslation": [
+ "MLIRTargetLLVMIRExport"
+ ],
+ "@llvm-project//mlir:mlir_c_runner_utils": [
+ "MLIRExecutionEngine"
+ ],
"@llvm-project//mlir:mlir-translate": ["mlir-translate"],
"@llvm-project//mlir:MlirTableGenMain": ["MLIRTableGen"],
"@llvm-project//mlir:MlirOptLib": ["MLIROptLib"],
@@ -65,6 +81,70 @@
"@llvm-project//mlir:TensorDialect": ["MLIRTensor"],
"@llvm-project//mlir:NVVMDialect": ["MLIRNVVMIR"],
"@llvm-project//mlir:ROCDLDialect": ["MLIRROCDLIR"],
+ # MHLO.
+ # TODO: Rework this upstream so that Bazel and CMake rules match up
+ # better.
+ # All of these have to depend on tensorflow::external_mhlo_includes to
+ # ensure that include directories are inherited.
+ "@mlir-hlo//:chlo_legalize_to_hlo": [
+ "tensorflow::external_mhlo_includes",
+ "ChloPasses",
+ ],
+ "@mlir-hlo//:hlo": [
+ "tensorflow::external_mhlo_includes",
+ "ChloDialect",
+ "MhloDialect",
+ "MLIRMhloUtils",
+ ],
+ "@mlir-hlo//:legalize_control_flow": [
+ "tensorflow::external_mhlo_includes",
+ "MhloToStandard",
+ ],
+ "@mlir-hlo//:legalize_einsum_to_dot_general":
+ [
+ "tensorflow::external_mhlo_includes",
+ "MhloPasses",
+ ],
+ "@mlir-hlo//:legalize_gather_to_torch_index_select":
+ [
+ "tensorflow::external_mhlo_includes",
+ "MhloPasses",
+ ],
+ "@mlir-hlo//:legalize_to_linalg": [
+ "tensorflow::external_mhlo_includes",
+ "MhloLhloToLinalg",
+ ],
+ "@mlir-hlo//:legalize_to_standard": [
+ "tensorflow::external_mhlo_includes",
+ "MhloToStandard",
+ ],
+ "@mlir-hlo//:map_lmhlo_to_scalar_op": [
+ "tensorflow::external_mhlo_includes",
+ "LmhloDialect", # Unfortunate.
+ "MhloDialect",
+ ],
+ "@mlir-hlo//:map_mhlo_to_scalar_op": [
+ "tensorflow::external_mhlo_includes",
+ "MhloDialect",
+ ],
+ "@mlir-hlo//:materialize_broadcasts": [
+ "tensorflow::external_mhlo_includes",
+ "MhloPasses",
+ ],
+ "@mlir-hlo//:mhlo_control_flow_to_scf": [
+ "tensorflow::external_mhlo_includes",
+ "MhloToStandard",
+ ],
+ "@mlir-hlo//:mhlo_to_mhlo_lowering_patterns":
+ [
+ "tensorflow::external_mhlo_includes",
+ "MhloPasses",
+ ],
+ "@mlir-hlo//:unfuse_batch_norm": [
+ "tensorflow::external_mhlo_includes",
+ "MhloPasses",
+ ],
+
# Vulkan
"@vulkan_headers": ["Vulkan::Headers"],
# Cuda
@@ -80,7 +160,8 @@
"@com_google_googletest//:gtest": ["gmock", "gtest"],
"@spirv_cross//:spirv_cross_lib": ["spirv-cross-msl"],
"@cpuinfo": ["cpuinfo"],
- "@vulkan_memory_allocator//:impl_header_only": ["vulkan_memory_allocator"],
+ "@vulkan_memory_allocator//:impl_header_only":
+ ["vulkan_memory_allocator"],
}
@@ -131,8 +212,5 @@
return _convert_llvm_target(target)
if target.startswith("@llvm-project//mlir"):
return _convert_mlir_target(target)
- if target.startswith("@mlir-hlo//"):
- # All Bazel targets map to a single CMake target.
- return ["tensorflow::mlir_hlo"]
raise KeyError(f"No conversion found for target '{target}'")
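With the blanket `@mlir-hlo//` fallback removed from `bazel_to_cmake_targets.py`, only the explicitly listed MHLO targets convert; everything else now falls through to the final `KeyError`. A rough Python sketch of that lookup order follows; `TARGET_MAPPING`, `convert_external_target`, and the crude prefix fallbacks are placeholder names, not the file's real helpers.

```python
# Rough sketch of the lookup order after this change. Only the
# @mlir-hlo//:hlo entry mirrors the real mapping added above; the rest
# is a simplified stand-in.
TARGET_MAPPING = {
    "@mlir-hlo//:hlo": [
        "tensorflow::external_mhlo_includes",
        "ChloDialect",
        "MhloDialect",
        "MLIRMhloUtils",
    ],
}


def convert_external_target(target):
  if target in TARGET_MAPPING:
    return TARGET_MAPPING[target]
  if target.startswith("@llvm-project//llvm"):
    return ["LLVM" + target.rsplit(":", 1)[-1]]  # crude stand-in fallback
  if target.startswith("@llvm-project//mlir"):
    return ["MLIR" + target.rsplit(":", 1)[-1]]  # crude stand-in fallback
  # @mlir-hlo// targets no longer have a catch-all mapping to
  # tensorflow::mlir_hlo, so unlisted ones land here.
  raise KeyError(f"No conversion found for target '{target}'")


print(convert_external_target("@mlir-hlo//:hlo"))
try:
  convert_external_target("@mlir-hlo//:some_unlisted_target")
except KeyError as e:
  print(e)
```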
diff --git a/build_tools/buildkite/cmake/build_configurations.yml b/build_tools/buildkite/cmake/build_configurations.yml
index 7b3acec..ddb1d9c 100644
--- a/build_tools/buildkite/cmake/build_configurations.yml
+++ b/build_tools/buildkite/cmake/build_configurations.yml
@@ -23,6 +23,15 @@
agents:
- "queue=build"
+ - label: ":pinching_hand: Build the size-optimized runtime only"
+ commands:
+ - "./scripts/git/submodule_versions.py init"
+ - "docker run --user=$(id -u):$(id -g) --volume=\\$PWD:\\$IREE_DOCKER_WORKDIR --workdir=\\$IREE_DOCKER_WORKDIR --rm gcr.io/iree-oss/base@sha256:b8d9863c6ac913f167c6fab319d7cd883ab099312488709ee30b29976d63eb22 ./build_tools/cmake/build_runtime_small.sh"
+ env:
+ IREE_DOCKER_WORKDIR: "/usr/src/github/iree"
+ agents:
+ - "queue=build"
+
- label: ":gnu: Build with GCC"
key: "build-gcc"
commands:
diff --git a/build_tools/cmake/build_runtime_small.sh b/build_tools/cmake/build_runtime_small.sh
new file mode 100755
index 0000000..287131c
--- /dev/null
+++ b/build_tools/cmake/build_runtime_small.sh
@@ -0,0 +1,35 @@
+#!/bin/bash
+# Copyright 2021 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+# Build IREE's runtime using CMake. Designed for CI, but can be run manually.
+# This uses previously cached build results and does not clear build
+# directories.
+
+set -e
+set -x
+
+ROOT_DIR=$(git rev-parse --show-toplevel)
+cd ${ROOT_DIR?}
+
+CMAKE_BIN=${CMAKE_BIN:-$(which cmake)}
+"${CMAKE_BIN?}" --version
+ninja --version
+
+if [ -d "build-runtime-small" ]
+then
+ echo "build-runtime-small directory already exists. Will use cached results there."
+else
+ echo "build-runtime-small directory does not already exist. Creating a new one."
+ mkdir build-runtime-small
+fi
+cd build-runtime-small
+
+"${CMAKE_BIN?}" -G Ninja .. \
+ -DCMAKE_BUILD_TYPE=MinSizeRel \
+ -DIREE_SIZE_OPTIMIZED=ON \
+ -DIREE_BUILD_COMPILER=OFF
+"${CMAKE_BIN?}" --build .
diff --git a/build_tools/cmake/iree_copts.cmake b/build_tools/cmake/iree_copts.cmake
index 5826210..15232e6 100644
--- a/build_tools/cmake/iree_copts.cmake
+++ b/build_tools/cmake/iree_copts.cmake
@@ -331,25 +331,15 @@
# Size-optimized build flags
#-------------------------------------------------------------------------------
- # TODO(#898): add a dedicated size-constrained configuration.
+# TODO(#898): add a dedicated size-constrained configuration.
if(${IREE_SIZE_OPTIMIZED})
iree_select_compiler_opts(IREE_SIZE_OPTIMIZED_DEFAULT_COPTS
- CLANG_OR_GCC
- "-DIREE_STATUS_MODE=0"
- "-DIREE_HAL_MODULE_STRING_UTIL_ENABLE=0"
- "-DIREE_VM_EXT_I64_ENABLE=0"
- "-DIREE_VM_EXT_F32_ENABLE=0"
MSVC_OR_CLANG_CL
"/GS-"
"/GL"
"/Gw"
"/Gy"
"/DNDEBUG"
- "/DIREE_STATUS_MODE=0"
- "/DIREE_FLAGS_ENABLE_CLI=0"
- "/DIREE_HAL_MODULE_STRING_UTIL_ENABLE=0"
- "/DIREE_VM_EXT_I64_ENABLE=0"
- "/DIREE_VM_EXT_F32_ENABLE=0"
"/Os"
"/Oy"
"/Zi"
@@ -362,12 +352,23 @@
"-opt:ref,icf"
)
# TODO(#898): make this only impact the runtime (IREE_RUNTIME_DEFAULT_...).
+ # These flags come from iree/base/config.h:
set(IREE_DEFAULT_COPTS
"${IREE_DEFAULT_COPTS}"
- "${IREE_SIZE_OPTIMIZED_DEFAULT_COPTS}")
+ "${IREE_SIZE_OPTIMIZED_DEFAULT_COPTS}"
+ "-DIREE_STATUS_MODE=0"
+ "-DIREE_STATISTICS_ENABLE=0"
+ "-DIREE_HAL_MODULE_STRING_UTIL_ENABLE=0"
+ "-DIREE_HAL_COMMAND_BUFFER_VALIDATION_ENABLE=0"
+ "-DIREE_VM_BACKTRACE_ENABLE=0"
+ "-DIREE_VM_EXT_I64_ENABLE=0"
+ "-DIREE_VM_EXT_F32_ENABLE=0"
+ "-DIREE_VM_EXT_F64_ENABLE=0"
+ )
set(IREE_DEFAULT_LINKOPTS
"${IREE_DEFAULT_LINKOPTS}"
- "${IREE_SIZE_OPTIMIZED_DEFAULT_LINKOPTS}")
+ "${IREE_SIZE_OPTIMIZED_DEFAULT_LINKOPTS}"
+ )
endif()
#-------------------------------------------------------------------------------
diff --git a/build_tools/third_party/mlir-hlo/CMakeLists.txt b/build_tools/third_party/mlir-hlo/CMakeLists.txt
index 0ee1640..62ed31b 100644
--- a/build_tools/third_party/mlir-hlo/CMakeLists.txt
+++ b/build_tools/third_party/mlir-hlo/CMakeLists.txt
@@ -15,18 +15,9 @@
PACKAGE
tensorflow
NAME
- mlir_hlo
+ external_mhlo_includes
ROOT
${TF_MLIR_HLO_SOURCE_DIR}
- DEPS
- MhloDialect
- MhloInferShapeEqualityOpInterface
- LmhloDialect
- ChloPasses
- MhloToStandard
- MhloPasses
- MhloLhloToLinalg
- MLIRMhloUtils
INCLUDES
"${TF_MLIR_HLO_SOURCE_DIR}/"
"${TF_MLIR_HLO_SOURCE_DIR}/include/"
diff --git a/docs/website/docs/deployment-configurations/bare-metal.md b/docs/website/docs/deployment-configurations/bare-metal.md
index 498e8ba..945a983 100644
--- a/docs/website/docs/deployment-configurations/bare-metal.md
+++ b/docs/website/docs/deployment-configurations/bare-metal.md
@@ -27,9 +27,10 @@
The model can be compiled with the following command from the IREE compiler
build directory
-``` shell hl_lines="3 4 5"
+```shell
iree/tools/iree-translate \
-iree-mlir-to-vm-bytecode-module \
+ -iree-stream-partitioning-favor=min-peak-memory \
-iree-hal-target-backends=dylib-llvm-aot \
-iree-llvm-target-triple=x86_64-pc-linux-elf \
-iree-llvm-debug-symbols=false \
@@ -40,12 +41,15 @@
In which
-* `iree-hal-target-backends=dylib-llvm-aot`: Build the model for the dynamic
-library CPU HAL driver
-* `iree-llvm-target-triple`: Use the `<arch>-pc-linux-elf` LLVM target triple so
-the artifact has a fixed ABI to be rendered by the
-[elf_module library](https://github.com/google/iree/tree/main/iree/hal/local/elf)
-* `iree-llvm-debug-symbols=false`: To reduce the artifact size
+* `-iree-stream-partitioning-favor=min-peak-memory`: Optimize for minimum peak
+  memory usage at the cost of concurrency; include this flag when targeting
+  single-threaded execution to reduce memory consumption.
+* `iree-hal-target-backends=dylib-llvm-aot`: Build the model for the dynamic
+ library CPU HAL driver
+* `iree-llvm-target-triple`: Use the `<arch>-pc-linux-elf` LLVM target triple
+ so the artifact has a fixed ABI to be rendered by the
+ [elf_module library](https://github.com/google/iree/tree/main/iree/hal/local/elf)
+* `iree-llvm-debug-symbols=false`: To reduce the artifact size
See [generate.sh](https://github.com/google/iree/blob/main/iree/hal/local/elf/testdata/generate.sh)
for example command-line instructions of some common architectures
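For reference, the documented invocation can also be scripted. The sketch below just shells out with the flags shown in the doc above; the `iree-translate` path and the input/output file names are hypothetical placeholders, since the full command extends past this hunk.

```python
# Hypothetical wrapper around the documented compile command. Only the flags
# come from the doc above; the binary path and file names are placeholders.
import subprocess

IREE_TRANSLATE = "iree/tools/iree-translate"  # run from the compiler build dir


def compile_for_bare_metal(input_mlir, output_module):
  cmd = [
      IREE_TRANSLATE,
      "-iree-mlir-to-vm-bytecode-module",
      "-iree-stream-partitioning-favor=min-peak-memory",
      "-iree-hal-target-backends=dylib-llvm-aot",
      "-iree-llvm-target-triple=x86_64-pc-linux-elf",
      "-iree-llvm-debug-symbols=false",
      input_mlir,
      "-o", output_module,
  ]
  subprocess.run(cmd, check=True)


# compile_for_bare_metal("model.mlir", "model.vmfb")
```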
diff --git a/integrations/tensorflow/build_tools/overlay/mlir-hlo/BUILD.bazel b/integrations/tensorflow/build_tools/overlay/mlir-hlo/BUILD.bazel
index 2232b6f..26cefa3 100644
--- a/integrations/tensorflow/build_tools/overlay/mlir-hlo/BUILD.bazel
+++ b/integrations/tensorflow/build_tools/overlay/mlir-hlo/BUILD.bazel
@@ -27,10 +27,12 @@
"legalize_einsum_to_dot_general",
"legalize_gather_to_torch_index_select",
"legalize_to_linalg",
+ "legalize_to_standard",
"lhlo",
"map_lmhlo_to_scalar_op",
"map_mhlo_to_scalar_op",
"materialize_broadcasts",
+ "mhlo_control_flow_to_scf",
"mhlo_to_mhlo_lowering_patterns",
"unfuse_batch_norm",
]
diff --git a/integrations/tensorflow/iree_tf_compiler/MHLO/BUILD b/integrations/tensorflow/iree_tf_compiler/MHLO/BUILD
index e5143ef..a955af4 100644
--- a/integrations/tensorflow/iree_tf_compiler/MHLO/BUILD
+++ b/integrations/tensorflow/iree_tf_compiler/MHLO/BUILD
@@ -14,7 +14,6 @@
name = "MHLO",
srcs = [
"EmitDefaultIREEABI.cpp",
- "FlattenTuplesInCFG.cpp",
"Passes.cpp",
],
hdrs = [
@@ -27,7 +26,6 @@
"@iree//iree/compiler/Codegen:PassHeaders",
"@iree//iree/compiler/Dialect/Flow/IR",
"@iree//iree/compiler/Dialect/Flow/Transforms",
- "@iree//iree/compiler/Dialect/Shape/Transforms",
"@iree//iree/compiler/Dialect/Util/IR",
"@iree//iree/compiler/InputConversion/MHLO",
"@llvm-project//llvm:Support",
diff --git a/integrations/tensorflow/iree_tf_compiler/MHLO/Passes.cpp b/integrations/tensorflow/iree_tf_compiler/MHLO/Passes.cpp
index b945bae..440eafe 100644
--- a/integrations/tensorflow/iree_tf_compiler/MHLO/Passes.cpp
+++ b/integrations/tensorflow/iree_tf_compiler/MHLO/Passes.cpp
@@ -8,7 +8,6 @@
#include "iree/compiler/Codegen/Passes.h"
#include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
-#include "iree/compiler/Dialect/Shape/Transforms/Passes.h"
#include "iree/compiler/InputConversion/MHLO/Passes.h"
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
@@ -23,26 +22,17 @@
namespace MHLO {
void buildMHLOImportPassPipeline(OpPassManager &pm) {
- //----------------------------------------------------------------------------
- // Convert control flow and flatten tuples (like tuple<tensor<...>, ...>)
- //----------------------------------------------------------------------------
- // NOTE: FlattenTuplesInCFGPass requires inlining to have run and has some
- // sensitivity to structured control flow ops.
- // SCF would be ideal as a target (as that matches our other IREE inputs) but
- // the current HLO to SCF pass is extremely basic and doesn't handle anything
- // but tf.while for less-than comparisons from 0. Since those are common we
- // still try to pull those out here but then fall back on the full conversion
- // to CFG form.
+ // We run the inliner for legacy reasons. It shouldn't be necessary anymore,
+ // but this entire pipeline will soon be deleted and it isn't worth
+ // removing now.
pm.addPass(mlir::createInlinerPass());
- pm.addNestedPass<FuncOp>(mhlo::createControlFlowToScfPass());
- pm.addNestedPass<FuncOp>(mhlo::createLegalizeControlFlowPass());
- pm.addNestedPass<FuncOp>(mlir::createLowerToCFGPass());
- pm.addPass(createFlattenTuplesInCFGPass());
- pm.addNestedPass<FuncOp>(mlir::createCanonicalizerPass());
+
+ // Drop to CFG and eliminate tuples.
+ mlir::iree_compiler::MHLO::buildXLACleanupPassPipeline(pm);
// Mostly delegate to the IREE side MHLO legalization pipeline, now that we
// have handled the weird that comes from legacy HLO clients.
- mlir::iree_compiler::buildMHLOInputConversionPassPipeline(pm);
+ mlir::iree_compiler::MHLO::buildMHLOInputConversionPassPipeline(pm);
// Import pipelines should end with canonicalization because they may have
// access to dialects and patterns that the core compiler does not.
diff --git a/integrations/tensorflow/iree_tf_compiler/MHLO/Passes.h b/integrations/tensorflow/iree_tf_compiler/MHLO/Passes.h
index 8572e50..670104a 100644
--- a/integrations/tensorflow/iree_tf_compiler/MHLO/Passes.h
+++ b/integrations/tensorflow/iree_tf_compiler/MHLO/Passes.h
@@ -29,9 +29,6 @@
// of MHLO and is suitable for such programs.
std::unique_ptr<OperationPass<FuncOp>> createEmitDefaultIREEABIPass();
-// Flattens tuple values in function signatures and blocks.
-std::unique_ptr<OperationPass<ModuleOp>> createFlattenTuplesInCFGPass();
-
//===----------------------------------------------------------------------===//
// Registration
//===----------------------------------------------------------------------===//
@@ -40,7 +37,6 @@
registerMHLOImportPassPipeline();
createEmitDefaultIREEABIPass();
- createFlattenTuplesInCFGPass();
}
} // namespace MHLO
diff --git a/integrations/tensorflow/iree_tf_compiler/TF/BUILD b/integrations/tensorflow/iree_tf_compiler/TF/BUILD
index e090ffa..4ac22f2 100644
--- a/integrations/tensorflow/iree_tf_compiler/TF/BUILD
+++ b/integrations/tensorflow/iree_tf_compiler/TF/BUILD
@@ -36,7 +36,6 @@
"@iree//iree/compiler/Dialect/Flow/Transforms",
"@iree//iree/compiler/Dialect/HAL/IR",
"@iree//iree/compiler/Dialect/HAL/IR:HALDialect",
- "@iree//iree/compiler/Dialect/Shape/Transforms",
"@iree//iree/compiler/Dialect/Util/IR",
"@iree//iree/compiler/Utils",
"@llvm-project//llvm:Support",
diff --git a/integrations/tensorflow/iree_tf_compiler/TF/Passes.cpp b/integrations/tensorflow/iree_tf_compiler/TF/Passes.cpp
index 8cb4d3d..a4b1ca9 100644
--- a/integrations/tensorflow/iree_tf_compiler/TF/Passes.cpp
+++ b/integrations/tensorflow/iree_tf_compiler/TF/Passes.cpp
@@ -7,7 +7,6 @@
#include "iree_tf_compiler/TF/Passes.h"
#include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
-#include "iree/compiler/Dialect/Shape/Transforms/Passes.h"
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/Shape/Transforms/Passes.h"
diff --git a/integrations/tensorflow/iree_tf_compiler/TF/SavedModelToIreeABI.cpp b/integrations/tensorflow/iree_tf_compiler/TF/SavedModelToIreeABI.cpp
index 0bd0c61..f415839 100644
--- a/integrations/tensorflow/iree_tf_compiler/TF/SavedModelToIreeABI.cpp
+++ b/integrations/tensorflow/iree_tf_compiler/TF/SavedModelToIreeABI.cpp
@@ -210,8 +210,8 @@
assert(!callArgs[valueIndex] && "duplicate argument bindings");
auto value = thisValue;
if (value.getType().isa<IREE::HAL::BufferViewType>()) {
- value = builder.createOrFold<IREE::HAL::TensorCastOp>(loc, valueType,
- thisValue);
+ value = builder.createOrFold<IREE::HAL::TensorImportOp>(loc, valueType,
+ thisValue);
}
callArgs[valueIndex] = value;
return;
@@ -250,7 +250,7 @@
"mismatched number of call returns");
Value value = callReturns[valueIndex];
if (valueType.isa<TensorType>()) {
- value = builder.createOrFold<IREE::HAL::TensorCastOp>(
+ value = builder.createOrFold<IREE::HAL::TensorExportOp>(
loc, getIrType(builder), value);
}
return value;
@@ -305,8 +305,8 @@
// TODO: Null check, etc. How does that work if returning a tensor? Need
// to box somehow?
if (itemValue.getType().isa<IREE::HAL::BufferViewType>()) {
- itemValue = builder.createOrFold<IREE::HAL::TensorCastOp>(loc, valueType,
- itemValue);
+ itemValue = builder.createOrFold<IREE::HAL::TensorImportOp>(
+ loc, valueType, itemValue);
}
return itemValue;
}
diff --git a/integrations/tensorflow/iree_tf_compiler/TF/test/saved_model_to_iree_abi.mlir b/integrations/tensorflow/iree_tf_compiler/TF/test/saved_model_to_iree_abi.mlir
index 9195f59..cf7eab2 100644
--- a/integrations/tensorflow/iree_tf_compiler/TF/test/saved_model_to_iree_abi.mlir
+++ b/integrations/tensorflow/iree_tf_compiler/TF/test/saved_model_to_iree_abi.mlir
@@ -4,11 +4,11 @@
// Should just be a pass through.
// CHECK: func @binary_func
// CHECK-SAME{LITERAL}: iree.abi = "{\22a\22:[[\22ndarray\22,\22f32\22,1,16],[\22ndarray\22,\22f32\22,1,16]],\22r\22:[[\22stuple\22,[\22ndarray\22,\22f32\22,1,16],[\22ndarray\22,\22f32\22,1,16]]],\22v\22:1}"
-// CHECK: %[[ARG0_TENSOR:.*]] = hal.tensor.cast %arg0 : !hal.buffer_view -> tensor<16xf32>
-// CHECK: %[[ARG1_TENSOR:.*]] = hal.tensor.cast %arg1 : !hal.buffer_view -> tensor<16xf32>
+// CHECK: %[[ARG0_TENSOR:.*]] = hal.tensor.import %arg0 : !hal.buffer_view -> tensor<16xf32>
+// CHECK: %[[ARG1_TENSOR:.*]] = hal.tensor.import %arg1 : !hal.buffer_view -> tensor<16xf32>
// CHECK: %[[R:.*]]:2 = call @__inference_binary_func_70(%[[ARG0_TENSOR]], %[[ARG1_TENSOR]])
-// CHECK: %[[R0_BV:.*]] = hal.tensor.cast %[[R]]#0 : tensor<16xf32> -> !hal.buffer_view
-// CHECK: %[[R1_BV:.*]] = hal.tensor.cast %[[R]]#1 : tensor<16xf32> -> !hal.buffer_view
+// CHECK: %[[R0_BV:.*]] = hal.tensor.export %[[R]]#0 : tensor<16xf32> -> !hal.buffer_view
+// CHECK: %[[R1_BV:.*]] = hal.tensor.export %[[R]]#1 : tensor<16xf32> -> !hal.buffer_view
// CHECK: return %[[R0_BV]], %[[R1_BV]] : !hal.buffer_view, !hal.buffer_view
// CHECK: func private @__inference_binary_func_70
// CHECK-NOT: tf_saved_model
@@ -24,9 +24,9 @@
// CHECK-LABEL: module @unary_func
// CHECK: func @unary_func
// CHECK-SAME{LITERAL}: iree.abi = "{\22a\22:[[\22ndarray\22,\22f32\22,1,16]],\22r\22:[[\22ndarray\22,\22f32\22,1,16]],\22v\22:1}"
-// CHECK: %[[ARG0_TENSOR:.*]] = hal.tensor.cast %arg0 : !hal.buffer_view -> tensor<16xf32>
+// CHECK: %[[ARG0_TENSOR:.*]] = hal.tensor.import %arg0 : !hal.buffer_view -> tensor<16xf32>
// CHECK: %[[R:.*]] = call @__inference_unary_func_240(%[[ARG0_TENSOR]])
-// CHECK: %[[R0_BV:.*]] = hal.tensor.cast %[[R]] : tensor<16xf32> -> !hal.buffer_view
+// CHECK: %[[R0_BV:.*]] = hal.tensor.export %[[R]] : tensor<16xf32> -> !hal.buffer_view
// CHECK: return %[[R0_BV]] : !hal.buffer_view
// CHECK: func private @__inference_unary_func_240
// CHECK-NOT: tf_saved_model
@@ -41,11 +41,11 @@
// CHECK-LABEL: module @return_list
// CHECK: func @return_list
// CHECK-SAME{LITERAL}: iree.abi = "{\22a\22:[[\22ndarray\22,\22f32\22,1,16],[\22ndarray\22,\22f32\22,1,16]],\22r\22:[[\22stuple\22,[\22ndarray\22,\22f32\22,1,16],[\22ndarray\22,\22f32\22,1,16]]],\22v\22:1}"
-// CHECK: %[[ARG0_TENSOR:.*]] = hal.tensor.cast %arg0 : !hal.buffer_view -> tensor<16xf32>
-// CHECK: %[[ARG1_TENSOR:.*]] = hal.tensor.cast %arg1 : !hal.buffer_view -> tensor<16xf32>
+// CHECK: %[[ARG0_TENSOR:.*]] = hal.tensor.import %arg0 : !hal.buffer_view -> tensor<16xf32>
+// CHECK: %[[ARG1_TENSOR:.*]] = hal.tensor.import %arg1 : !hal.buffer_view -> tensor<16xf32>
// CHECK: %[[R:.+]]:2 = call @__inference_return_list_260(%[[ARG0_TENSOR]], %[[ARG1_TENSOR]])
-// CHECK: %[[R0_BV:.*]] = hal.tensor.cast %[[R]]#0 : tensor<16xf32> -> !hal.buffer_view
-// CHECK: %[[R1_BV:.*]] = hal.tensor.cast %[[R]]#1 : tensor<16xf32> -> !hal.buffer_view
+// CHECK: %[[R0_BV:.*]] = hal.tensor.export %[[R]]#0 : tensor<16xf32> -> !hal.buffer_view
+// CHECK: %[[R1_BV:.*]] = hal.tensor.export %[[R]]#1 : tensor<16xf32> -> !hal.buffer_view
// CHECK: return %[[R0_BV]], %[[R1_BV]] : !hal.buffer_view, !hal.buffer_view
// CHECK: func private @__inference_return_list_260
// CHECK-NOT: tf_saved_model
@@ -65,36 +65,36 @@
// CHECK: %[[L0:.+]] = util.list.get %arg0[%[[c0]]] : !util.list<?> -> !util.list<?>
// CHECK: %[[c0_0:.+]] = arith.constant 0 : index
// CHECK: %[[L1:.+]] = util.list.get %[[L0]][%[[c0_0]]] : !util.list<?> -> !hal.buffer_view
-// CHECK: %[[L1_TENSOR:.+]] = hal.tensor.cast %[[L1]] : !hal.buffer_view -> tensor<16xf32>
+// CHECK: %[[L1_TENSOR:.+]] = hal.tensor.import %[[L1]] : !hal.buffer_view -> tensor<16xf32>
// CHECK: %[[c1:.+]] = arith.constant 1 : index
// CHECK: %[[L2:.+]] = util.list.get %[[L0]][%[[c1]]] : !util.list<?> -> !hal.buffer_view
-// CHECK: %[[L2_TENSOR:.+]] = hal.tensor.cast %[[L2]] : !hal.buffer_view -> tensor<16xf32>
+// CHECK: %[[L2_TENSOR:.+]] = hal.tensor.import %[[L2]] : !hal.buffer_view -> tensor<16xf32>
// CHECK: %[[c1_1:.+]] = arith.constant 1 : index
// CHECK: %[[L3:.+]] = util.list.get %arg0[%[[c1_1]]] : !util.list<?> -> !util.list<?>
// CHECK: %[[c0_2:.+]] = arith.constant 0 : index
// CHECK: %[[L4:.+]] = util.list.get %[[L3]][%[[c0_2]]] : !util.list<?> -> !hal.buffer_view
-// CHECK: %[[L4_TENSOR:.+]] = hal.tensor.cast %[[L4]] : !hal.buffer_view -> tensor<16xf32>
+// CHECK: %[[L4_TENSOR:.+]] = hal.tensor.import %[[L4]] : !hal.buffer_view -> tensor<16xf32>
// CHECK: %[[c1_3:.+]] = arith.constant 1 : index
// CHECK: %[[L5:.+]] = util.list.get %[[L3]][%[[c1_3]]] : !util.list<?> -> !hal.buffer_view
-// CHECK: %[[L5_TENSOR:.+]] = hal.tensor.cast %[[L5]] : !hal.buffer_view -> tensor<16xf32>
-// CHECK: %[[ARG1_TENSOR:.+]] = hal.tensor.cast %arg1 : !hal.buffer_view -> tensor<f32>
+// CHECK: %[[L5_TENSOR:.+]] = hal.tensor.import %[[L5]] : !hal.buffer_view -> tensor<16xf32>
+// CHECK: %[[ARG1_TENSOR:.+]] = hal.tensor.import %arg1 : !hal.buffer_view -> tensor<f32>
// CHECK: %[[RESULT:.+]]:4 = call @__inference_dict_nest_190(%[[L1_TENSOR]], %[[L2_TENSOR]], %[[L4_TENSOR]], %[[L5_TENSOR]], %[[ARG1_TENSOR]]) : (tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<f32>) -> (tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>)
// CHECK: %[[c2:.+]] = arith.constant 2 : index
// CHECK: %[[R7:.+]] = util.list.create %[[c2]] : !util.list<?>
// CHECK: util.list.resize %[[R7]], %[[c2]]
-// CHECK: %[[R0_BV:.+]] = hal.tensor.cast %[[RESULT]]#0 : tensor<16xf32> -> !hal.buffer_view
+// CHECK: %[[R0_BV:.+]] = hal.tensor.export %[[RESULT]]#0 : tensor<16xf32> -> !hal.buffer_view
// CHECK: %[[c0_4:.+]] = arith.constant 0 : index
// CHECK: util.list.set %[[R7]][%[[c0_4]]], %[[R0_BV]] : !hal.buffer_view -> !util.list<?>
-// CHECK: %[[R1_BV:.+]] = hal.tensor.cast %[[RESULT]]#1 : tensor<16xf32> -> !hal.buffer_view
+// CHECK: %[[R1_BV:.+]] = hal.tensor.export %[[RESULT]]#1 : tensor<16xf32> -> !hal.buffer_view
// CHECK: %[[c1_5:.+]] = arith.constant 1 : index
// CHECK: util.list.set %[[R7]][%[[c1_5]]], %[[R1_BV]] : !hal.buffer_view -> !util.list<?>
// CHECK: %[[c2_8:.+]] = arith.constant 2 : index
// CHECK: %[[R9:.+]] = util.list.create %[[c2_8]] : !util.list<?>
// CHECK: util.list.resize %[[R9]], %[[c2_8]]
-// CHECK: %[[R2_BV:.+]] = hal.tensor.cast %[[RESULT]]#2 : tensor<16xf32> -> !hal.buffer_view
+// CHECK: %[[R2_BV:.+]] = hal.tensor.export %[[RESULT]]#2 : tensor<16xf32> -> !hal.buffer_view
// CHECK: %[[c0_9:.+]] = arith.constant 0 : index
// CHECK: util.list.set %[[R9]][%[[c0_9]]], %[[R2_BV]] : !hal.buffer_view -> !util.list<?>
-// CHECK: %[[R3_BV:.+]] = hal.tensor.cast %[[RESULT]]#3 : tensor<16xf32> -> !hal.buffer_view
+// CHECK: %[[R3_BV:.+]] = hal.tensor.export %[[RESULT]]#3 : tensor<16xf32> -> !hal.buffer_view
// CHECK: %[[c1_10:.+]] = arith.constant 1 : index
// CHECK: util.list.set %[[R9]][%[[c1_10]]], %[[R3_BV]] : !hal.buffer_view -> !util.list<?>
// return %[[R7]], %[[R8]] : !util.list<?>, !util.list<?>
diff --git a/integrations/tensorflow/iree_tf_compiler/TFL/BUILD b/integrations/tensorflow/iree_tf_compiler/TFL/BUILD
index bfff2f9..4d657ff 100644
--- a/integrations/tensorflow/iree_tf_compiler/TFL/BUILD
+++ b/integrations/tensorflow/iree_tf_compiler/TFL/BUILD
@@ -48,8 +48,6 @@
deps = [
":PassesIncGen",
"@iree//iree/compiler/Dialect/Flow/IR",
- "@iree//iree/compiler/Dialect/Shape/IR",
- "@iree//iree/compiler/Dialect/Shape/Transforms",
"@iree//iree/compiler/Dialect/Util/IR",
"@iree//iree/compiler/Utils",
"@llvm-project//llvm:Support",
diff --git a/integrations/tensorflow/iree_tf_compiler/TFL/Passes.cpp b/integrations/tensorflow/iree_tf_compiler/TFL/Passes.cpp
index 3d18060..ed9a78a 100644
--- a/integrations/tensorflow/iree_tf_compiler/TFL/Passes.cpp
+++ b/integrations/tensorflow/iree_tf_compiler/TFL/Passes.cpp
@@ -6,7 +6,6 @@
#include "iree_tf_compiler/TFL/Passes.h"
-#include "iree/compiler/Dialect/Shape/Transforms/Passes.h"
#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"
#include "mlir/Dialect/Shape/Transforms/Passes.h"
#include "mlir/Pass/PassManager.h"
diff --git a/integrations/tensorflow/iree_tf_compiler/iree-tf-opt-main.cpp b/integrations/tensorflow/iree_tf_compiler/iree-tf-opt-main.cpp
index a27579c..5b32646 100644
--- a/integrations/tensorflow/iree_tf_compiler/iree-tf-opt-main.cpp
+++ b/integrations/tensorflow/iree_tf_compiler/iree-tf-opt-main.cpp
@@ -41,8 +41,8 @@
// Select IREE input passes.
mlir::iree_compiler::registerCommonInputConversionPasses();
- mlir::iree_compiler::registerMHLOConversionPasses();
mlir::iree_compiler::registerTOSAConversionPasses();
+ mlir::iree_compiler::MHLO::registerMHLOConversionPasses();
// TensorFlow integration passes.
mlir::RegisterAllTensorFlowDialects(registry);
diff --git a/iree/base/attributes.h b/iree/base/attributes.h
index d419927..bd396a9 100644
--- a/iree/base/attributes.h
+++ b/iree/base/attributes.h
@@ -172,4 +172,23 @@
#define IREE_ATTRIBUTE_PACKED
#endif // IREE_HAVE_ATTRIBUTE(packed)
+//===----------------------------------------------------------------------===//
+// IREE_ATTRIBUTE_UNUSED
+//===----------------------------------------------------------------------===//
+
+// Hints that a variable is _maybe_ unused. This is primarily to quiet
+// diagnostic messages about unused variables that crop up around variables
+// passed to assert/logging/etc that gets stripped in certain configurations.
+//
+// Example:
+// int some_info IREE_ATTRIBUTE_UNUSED = compute_debug_info();
+// assert(some_info > 0); // stripped in NDEBUG
+#if IREE_HAVE_ATTRIBUTE(maybe_unused) && defined(__clang__)
+#define IREE_ATTRIBUTE_UNUSED __attribute__((maybe_unused))
+#elif IREE_HAVE_ATTRIBUTE(unused) || (defined(__GNUC__) && !defined(__clang__))
+#define IREE_ATTRIBUTE_UNUSED __attribute__((unused))
+#else
+#define IREE_ATTRIBUTE_UNUSED
+#endif // IREE_HAVE_ATTRIBUTE(maybe_unused / unused)
+
#endif // IREE_BASE_ATTRIBUTES_H_
diff --git a/iree/base/status.c b/iree/base/status.c
index a65a400..4942eb6 100644
--- a/iree/base/status.c
+++ b/iree/base/status.c
@@ -645,7 +645,8 @@
*out_buffer_length = 0;
// Grab storage which may have a message and zero or more payloads.
- iree_status_storage_t* storage = iree_status_storage(status);
+ iree_status_storage_t* storage IREE_ATTRIBUTE_UNUSED =
+ iree_status_storage(status);
// Prefix with source location and status code string (may be 'OK').
iree_host_size_t buffer_length = 0;
diff --git a/iree/compiler/Bindings/Native/Transforms/BUILD b/iree/compiler/Bindings/Native/Transforms/BUILD
index 3ebd30e..2f8f98c 100644
--- a/iree/compiler/Bindings/Native/Transforms/BUILD
+++ b/iree/compiler/Bindings/Native/Transforms/BUILD
@@ -22,8 +22,6 @@
deps = [
"//iree/compiler/Dialect/Flow/IR",
"//iree/compiler/Dialect/HAL/IR",
- "//iree/compiler/Dialect/Shape/IR",
- "//iree/compiler/Dialect/Shape/Transforms",
"//iree/compiler/Dialect/Util/IR",
"//iree/compiler/Utils",
"@llvm-project//llvm:Support",
diff --git a/iree/compiler/Bindings/Native/Transforms/CMakeLists.txt b/iree/compiler/Bindings/Native/Transforms/CMakeLists.txt
index 6441af1..8bb85ce 100644
--- a/iree/compiler/Bindings/Native/Transforms/CMakeLists.txt
+++ b/iree/compiler/Bindings/Native/Transforms/CMakeLists.txt
@@ -32,8 +32,6 @@
MLIRTransforms
iree::compiler::Dialect::Flow::IR
iree::compiler::Dialect::HAL::IR
- iree::compiler::Dialect::Shape::IR
- iree::compiler::Dialect::Shape::Transforms
iree::compiler::Dialect::Util::IR
iree::compiler::Utils
PUBLIC
diff --git a/iree/compiler/Bindings/Native/Transforms/Passes.cpp b/iree/compiler/Bindings/Native/Transforms/Passes.cpp
index cc86fb7..4ec5152 100644
--- a/iree/compiler/Bindings/Native/Transforms/Passes.cpp
+++ b/iree/compiler/Bindings/Native/Transforms/Passes.cpp
@@ -30,7 +30,7 @@
void registerTransformPassPipeline() {
PassPipelineRegistration<> transformPassPipeline(
- "iree-abi-transform-pipeline",
+ "iree-abi-transformation-pipeline",
"Runs the IREE native ABI bindings support pipeline",
[](OpPassManager &passManager) {
buildTransformPassPipeline(passManager);
diff --git a/iree/compiler/Bindings/Native/Transforms/WrapEntryPoints.cpp b/iree/compiler/Bindings/Native/Transforms/WrapEntryPoints.cpp
index 4e102d2..832829a 100644
--- a/iree/compiler/Bindings/Native/Transforms/WrapEntryPoints.cpp
+++ b/iree/compiler/Bindings/Native/Transforms/WrapEntryPoints.cpp
@@ -136,7 +136,7 @@
for (auto arg : llvm::enumerate(entryBlock->getArguments())) {
auto oldType = entryFuncType.getInput(arg.index());
if (oldType.isa<TensorType>()) {
- arguments.push_back(entryBuilder.create<IREE::HAL::TensorCastOp>(
+ arguments.push_back(entryBuilder.create<IREE::HAL::TensorImportOp>(
entryFuncOp.getLoc(), oldType, arg.value()));
} else {
arguments.push_back(arg.value());
@@ -153,7 +153,7 @@
auto oldType = entryFuncType.getResult(result.index());
auto newType = wrapperFuncType.getResult(result.index());
if (oldType.isa<TensorType>()) {
- results.push_back(entryBuilder.createOrFold<IREE::HAL::TensorCastOp>(
+ results.push_back(entryBuilder.createOrFold<IREE::HAL::TensorExportOp>(
entryFuncOp.getLoc(), newType, result.value()));
} else {
results.push_back(result.value());
diff --git a/iree/compiler/Bindings/Native/Transforms/test/wrap_entry_points.mlir b/iree/compiler/Bindings/Native/Transforms/test/wrap_entry_points.mlir
index 78e47bb..1762358 100644
--- a/iree/compiler/Bindings/Native/Transforms/test/wrap_entry_points.mlir
+++ b/iree/compiler/Bindings/Native/Transforms/test/wrap_entry_points.mlir
@@ -8,14 +8,14 @@
// CHECK-SAME: iree.abi.stub
// CHECK-SAME: } {
// CHECK-NEXT: %[[ARG0_DIM0:.+]] = hal.buffer_view.dim<%[[ARG0]] : !hal.buffer_view>[0] : index
-// CHECK-NEXT: %[[ARG0_TENSOR:.+]] = hal.tensor.cast %[[ARG0]] : !hal.buffer_view -> tensor<?x8x8x3xf32>{%[[ARG0_DIM0]]}
+// CHECK-NEXT: %[[ARG0_TENSOR:.+]] = hal.tensor.import %[[ARG0]] : !hal.buffer_view -> tensor<?x8x8x3xf32>{%[[ARG0_DIM0]]}
// CHECK-NEXT: %[[ARG1_DIM0:.+]] = hal.buffer_view.dim<%[[ARG1]] : !hal.buffer_view>[0] : index
-// CHECK-NEXT: %[[ARG1_TENSOR:.+]] = hal.tensor.cast %[[ARG1]] : !hal.buffer_view -> tensor<?x8x8x3xf32>{%[[ARG1_DIM0]]}
+// CHECK-NEXT: %[[ARG1_TENSOR:.+]] = hal.tensor.import %[[ARG1]] : !hal.buffer_view -> tensor<?x8x8x3xf32>{%[[ARG1_DIM0]]}
// CHECK-NEXT: %[[RET_TENSOR:.+]]:2 = call @_dynamicEntry(%[[ARG0_TENSOR]], %[[ARG1_TENSOR]])
// CHECK: %[[RET0_DIM0:.+]] = tensor.dim %[[RET_TENSOR]]#0, %c0{{.*}} : tensor<?x8x8x3xf32>
-// CHECK-NEXT: %[[RET0_VIEW:.+]] = hal.tensor.cast %[[RET_TENSOR]]#0 : tensor<?x8x8x3xf32>{%[[RET0_DIM0]]} -> !hal.buffer_view
+// CHECK-NEXT: %[[RET0_VIEW:.+]] = hal.tensor.export %[[RET_TENSOR]]#0 : tensor<?x8x8x3xf32>{%[[RET0_DIM0]]} -> !hal.buffer_view
// CHECK: %[[RET1_DIM0:.+]] = tensor.dim %[[RET_TENSOR]]#1, %c0{{.*}} : tensor<?x8x8x3xf32>
-// CHECK-NEXT: %[[RET1_VIEW:.+]] = hal.tensor.cast %[[RET_TENSOR]]#1 : tensor<?x8x8x3xf32>{%[[RET1_DIM0]]} -> !hal.buffer_view
+// CHECK-NEXT: %[[RET1_VIEW:.+]] = hal.tensor.export %[[RET_TENSOR]]#1 : tensor<?x8x8x3xf32>{%[[RET1_DIM0]]} -> !hal.buffer_view
// CHECK-NEXT: return %[[RET0_VIEW]], %[[RET1_VIEW]] : !hal.buffer_view, !hal.buffer_view
// CHECK-NEXT: }
diff --git a/iree/compiler/Bindings/TFLite/Transforms/BUILD b/iree/compiler/Bindings/TFLite/Transforms/BUILD
index 0fee820..8482307 100644
--- a/iree/compiler/Bindings/TFLite/Transforms/BUILD
+++ b/iree/compiler/Bindings/TFLite/Transforms/BUILD
@@ -22,6 +22,7 @@
deps = [
"//iree/compiler/Dialect/Flow/IR",
"//iree/compiler/Dialect/HAL/IR",
+ "//iree/compiler/Dialect/HAL/IR:HALDialect",
"//iree/compiler/Dialect/Util/IR",
"//iree/compiler/Utils",
"@llvm-project//llvm:Support",
diff --git a/iree/compiler/Bindings/TFLite/Transforms/CMakeLists.txt b/iree/compiler/Bindings/TFLite/Transforms/CMakeLists.txt
index 9ceda22..2e8ee00 100644
--- a/iree/compiler/Bindings/TFLite/Transforms/CMakeLists.txt
+++ b/iree/compiler/Bindings/TFLite/Transforms/CMakeLists.txt
@@ -31,6 +31,7 @@
MLIRTransforms
iree::compiler::Dialect::Flow::IR
iree::compiler::Dialect::HAL::IR
+ iree::compiler::Dialect::HAL::IR::HALDialect
iree::compiler::Dialect::Util::IR
iree::compiler::Utils
PUBLIC
diff --git a/iree/compiler/Bindings/TFLite/Transforms/WrapEntryPoints.cpp b/iree/compiler/Bindings/TFLite/Transforms/WrapEntryPoints.cpp
index ae54240..d17c0bb 100644
--- a/iree/compiler/Bindings/TFLite/Transforms/WrapEntryPoints.cpp
+++ b/iree/compiler/Bindings/TFLite/Transforms/WrapEntryPoints.cpp
@@ -225,7 +225,7 @@
auto inputPlaceholder =
recalculateBuilder.createOrFold<IREE::Util::NullOp>(loc, bufferType);
auto dynamicDims = inputDynamicDims.loadDynamicDims(recalculateBuilder);
- auto castOp = recalculateBuilder.create<IREE::HAL::TensorCastOp>(
+ auto castOp = recalculateBuilder.create<IREE::HAL::TensorImportOp>(
loc, inputValue.getType(), inputPlaceholder, dynamicDims);
inputValue.replaceAllUsesWith(castOp.target());
}
@@ -523,7 +523,7 @@
dynamicDims.push_back(entryBuilder.create<IREE::Util::GlobalLoadOp>(
arg.getLoc(), globalOp));
}
- callOperands.push_back(entryBuilder.create<IREE::HAL::TensorCastOp>(
+ callOperands.push_back(entryBuilder.create<IREE::HAL::TensorImportOp>(
arg.getLoc(), inputDynamicDims.tensorType, arg, dynamicDims));
}
auto callOp = entryBuilder.create<mlir::CallOp>(entryFuncOp.getLoc(),
@@ -539,7 +539,7 @@
entryBuilder.create<tensor::DimOp>(result.getLoc(), result, i));
}
}
- callResults.push_back(entryBuilder.create<IREE::HAL::TensorCastOp>(
+ callResults.push_back(entryBuilder.create<IREE::HAL::TensorExportOp>(
result.getLoc(), bufferType, result, dynamicDims));
for (auto it : llvm::zip(dynamicDims, outputDynamicDims.globalOps)) {
auto dynamicDim = std::get<0>(it);
diff --git a/iree/compiler/Bindings/TFLite/Transforms/test/wrap_entry_points.mlir b/iree/compiler/Bindings/TFLite/Transforms/test/wrap_entry_points.mlir
index 8f9fb2d..c630055 100644
--- a/iree/compiler/Bindings/TFLite/Transforms/test/wrap_entry_points.mlir
+++ b/iree/compiler/Bindings/TFLite/Transforms/test/wrap_entry_points.mlir
@@ -25,11 +25,11 @@
// Tie input0 shapes.
// CHECK-NEXT: %[[IN0_DIM0:.+]] = util.global.load @_tflite_dynamicEntry_input0_shape_dim0 : index
-// CHECK-NEXT: %[[IN0:.+]] = hal.tensor.cast %[[NULL]] : !hal.buffer -> tensor<?x8x8x3xf32>{%[[IN0_DIM0]]}
+// CHECK-NEXT: %[[IN0:.+]] = hal.tensor.import %[[NULL]] : !hal.buffer -> tensor<?x8x8x3xf32>{%[[IN0_DIM0]]}
// Tie input1 shapes.
// CHECK-NEXT: %[[IN1_DIM0:.+]] = util.global.load @_tflite_dynamicEntry_input1_shape_dim0 : index
-// CHECK-NEXT: %[[IN1:.+]] = hal.tensor.cast %[[NULL]] : !hal.buffer -> tensor<?x8x8x3xf32>{%[[IN1_DIM0]]}
+// CHECK-NEXT: %[[IN1:.+]] = hal.tensor.import %[[NULL]] : !hal.buffer -> tensor<?x8x8x3xf32>{%[[IN1_DIM0]]}
// The actual model code used to (eventually) compute shapes.
// CHECK-NEXT: %[[OUT0:.+]] = mhlo.add %[[IN0]], %[[IN1]]
@@ -165,23 +165,23 @@
// Cast input0 buffer to a shaped tensor.
// CHECK: %[[IN0_DIM0:.+]] = util.global.load @_tflite_dynamicEntry_input0_shape_dim0 : index
-// CHECK-NEXT: %[[IN0:.+]] = hal.tensor.cast %[[IN0_BUFFER]] : !hal.buffer -> tensor<?x8x8x3xf32>{%[[IN0_DIM0]]}
+// CHECK-NEXT: %[[IN0:.+]] = hal.tensor.import %[[IN0_BUFFER]] : !hal.buffer -> tensor<?x8x8x3xf32>{%[[IN0_DIM0]]}
// Cast input1 buffer to a shaped tensor.
// CHECK: %[[IN1_DIM0:.+]] = util.global.load @_tflite_dynamicEntry_input1_shape_dim0 : index
-// CHECK-NEXT: %[[IN1:.+]] = hal.tensor.cast %[[IN1_BUFFER]] : !hal.buffer -> tensor<?x8x8x3xf32>{%[[IN1_DIM0]]}
+// CHECK-NEXT: %[[IN1:.+]] = hal.tensor.import %[[IN1_BUFFER]] : !hal.buffer -> tensor<?x8x8x3xf32>{%[[IN1_DIM0]]}
// Call the original function with tensor arguments.
// CHECK: %[[OUT:.+]]:2 = call @dynamicEntry(%[[IN0]], %[[IN1]]) : (tensor<?x8x8x3xf32>, tensor<?x8x8x3xf32>) -> (tensor<?x8x8x3xf32>, tensor<?x8x8x3xf32>)
// Query output0 shape and get the HAL buffer to return.
// CHECK: %[[OUT0_DIM0:.+]] = tensor.dim %[[OUT]]#0, %c0 : tensor<?x8x8x3xf32>
-// CHECK-NEXT: %[[OUT0_BUFFER:.+]] = hal.tensor.cast %[[OUT]]#0 : tensor<?x8x8x3xf32>{%[[OUT0_DIM0]]} -> !hal.buffer
+// CHECK-NEXT: %[[OUT0_BUFFER:.+]] = hal.tensor.export %[[OUT]]#0 : tensor<?x8x8x3xf32>{%[[OUT0_DIM0]]} -> !hal.buffer
// CHECK-NEXT: util.global.store %[[OUT0_DIM0]], @_tflite_dynamicEntry_output0_shape_dim0 : index
// Query output1 shape and get the HAL buffer to return.
// CHECK: %[[OUT1_DIM0:.+]] = tensor.dim %[[OUT]]#1, %c0 : tensor<?x8x8x3xf32>
-// CHECK-NEXT: %[[OUT1_BUFFER:.+]] = hal.tensor.cast %[[OUT]]#1 : tensor<?x8x8x3xf32>{%[[OUT1_DIM0]]} -> !hal.buffer
+// CHECK-NEXT: %[[OUT1_BUFFER:.+]] = hal.tensor.export %[[OUT]]#1 : tensor<?x8x8x3xf32>{%[[OUT1_DIM0]]} -> !hal.buffer
// CHECK-NEXT: util.global.store %[[OUT1_DIM0]], @_tflite_dynamicEntry_output1_shape_dim0 : index
// Clear shape dirty bit as we've updated the shapes unconditionally.
diff --git a/iree/compiler/Codegen/Common/BUILD b/iree/compiler/Codegen/Common/BUILD
index c4efb22..758db7f 100644
--- a/iree/compiler/Codegen/Common/BUILD
+++ b/iree/compiler/Codegen/Common/BUILD
@@ -56,11 +56,11 @@
"//iree/compiler/Codegen:PassHeaders",
"//iree/compiler/Codegen/Common:FoldTensorExtractOpIncGen",
"//iree/compiler/Codegen/Dialect:IREECodegenDialect",
+ "//iree/compiler/Codegen/Interfaces:BufferizationInterfaces",
"//iree/compiler/Codegen/Transforms",
"//iree/compiler/Codegen/Utils",
"//iree/compiler/Dialect/Flow/IR",
"//iree/compiler/Dialect/HAL/IR",
- "//iree/compiler/Dialect/Shape/IR",
"//iree/compiler/Dialect/Util/IR",
"//llvm-external-projects/iree-dialects:IREELinalgExtDialect",
"//llvm-external-projects/iree-dialects:IREELinalgExtTransforms",
diff --git a/iree/compiler/Codegen/Common/BufferizationAnalysis.cpp b/iree/compiler/Codegen/Common/BufferizationAnalysis.cpp
index 3826d92..8cdce87 100644
--- a/iree/compiler/Codegen/Common/BufferizationAnalysis.cpp
+++ b/iree/compiler/Codegen/Common/BufferizationAnalysis.cpp
@@ -17,7 +17,6 @@
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
#include "iree/compiler/Dialect/Flow/IR/FlowTypes.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
-#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
#include "llvm/ADT/TypeSwitch.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
diff --git a/iree/compiler/Codegen/Common/CMakeLists.txt b/iree/compiler/Codegen/Common/CMakeLists.txt
index 4466194..27484a0 100644
--- a/iree/compiler/Codegen/Common/CMakeLists.txt
+++ b/iree/compiler/Codegen/Common/CMakeLists.txt
@@ -76,12 +76,12 @@
MLIRVectorBufferizableOpInterfaceImpl
iree::compiler::Codegen::Common::FoldTensorExtractOpIncGen
iree::compiler::Codegen::Dialect::IREECodegenDialect
+ iree::compiler::Codegen::Interfaces::BufferizationInterfaces
iree::compiler::Codegen::PassHeaders
iree::compiler::Codegen::Transforms
iree::compiler::Codegen::Utils
iree::compiler::Dialect::Flow::IR
iree::compiler::Dialect::HAL::IR
- iree::compiler::Dialect::Shape::IR
iree::compiler::Dialect::Util::IR
PUBLIC
)
diff --git a/iree/compiler/Codegen/Common/CleanupBufferAllocViewPass.cpp b/iree/compiler/Codegen/Common/CleanupBufferAllocViewPass.cpp
index 912c0f4..5484102 100644
--- a/iree/compiler/Codegen/Common/CleanupBufferAllocViewPass.cpp
+++ b/iree/compiler/Codegen/Common/CleanupBufferAllocViewPass.cpp
@@ -121,7 +121,10 @@
FoldReshapeIntoInterfaceTensorLoad<linalg::TensorCollapseShapeOp>,
FoldReshapeIntoInterfaceTensorLoad<linalg::TensorExpandShapeOp>,
RemoveDeadMemAllocs>(&getContext());
- (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+ if (failed(applyPatternsAndFoldGreedily(getOperation(),
+ std::move(patterns)))) {
+ return signalPassFailure();
+ }
}
};
diff --git a/iree/compiler/Codegen/Common/ConvertToDestinationPassingStylePass.cpp b/iree/compiler/Codegen/Common/ConvertToDestinationPassingStylePass.cpp
index a70a215..1ca48ce 100644
--- a/iree/compiler/Codegen/Common/ConvertToDestinationPassingStylePass.cpp
+++ b/iree/compiler/Codegen/Common/ConvertToDestinationPassingStylePass.cpp
@@ -21,7 +21,6 @@
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
#include "iree/compiler/Dialect/Flow/IR/FlowTypes.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
-#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
#include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
#include "llvm/ADT/DenseSet.h"
@@ -143,6 +142,15 @@
return result;
}
}
+ if (auto yieldOp = dyn_cast<scf::YieldOp>(operand.getOwner())) {
+ Operation *parentOp = yieldOp->getParentOp();
+ if (isa<scf::ForOp, scf::IfOp>(parentOp)) {
+ Value result = parentOp->getResult(operand.getOperandNumber());
+ if (plan.isEquivalent(result, operand.get())) {
+ return result;
+ }
+ }
+ }
return nullptr;
}
@@ -157,9 +165,6 @@
llvm::DenseSet<Value> &processed) {
Operation *user = nullptr;
while (value.hasOneUse()) {
- assert(!processed.count(value) &&
- "unexpected traversal through already traversed value during "
- "conversion to destination passing style");
processed.insert(value);
OpOperand &use = *value.use_begin();
user = use.getOwner();
@@ -214,15 +219,13 @@
"failed walk of uses to get to flow.dispatch.tensor.store op");
}
- // For now we need to change the usage to use destination passing style only
- // if the start of the use-def chain is an init_tensor op. Might adapt this as
- // we go along.
- if (!isa<linalg::InitTensorOp>(resultValueOp)) {
- return success();
- }
-
OpBuilder::InsertionGuard g(b);
- b.setInsertionPoint(resultValueOp);
+ b.setInsertionPointToStart(storeOp->getBlock());
+ if (auto sourceDefiningOp = storeOp.target().getDefiningOp()) {
+ if (sourceDefiningOp->getBlock() == storeOp->getBlock()) {
+ b.setInsertionPointAfter(sourceDefiningOp);
+ }
+ }
Value resultBuffer = getTensorLoadOpForTensorStoreOp(b, storeOp);
// Now replay the instructions that are essentially doing type-conversion, in
@@ -270,9 +273,9 @@
}
llvm::DenseSet<Value> processed;
- auto walkResult =
- funcOp.walk<WalkOrder::PreOrder>([&](Operation *op) -> WalkResult {
- for (auto result : op->getResults()) {
+ auto walkResult = funcOp.walk<WalkOrder::PreOrder>(
+ [&](linalg::InitTensorOp initTensorOp) -> WalkResult {
+ for (auto result : initTensorOp->getResults()) {
if (!result.getType().isa<RankedTensorType>()) continue;
if (plan.isInStoreSet(result) && !processed.count(result)) {
return modifyResultToUseStoreBuffer(b, result, plan, processed);
diff --git a/iree/compiler/Codegen/Common/FlattenMemRefSubspanPass.cpp b/iree/compiler/Codegen/Common/FlattenMemRefSubspanPass.cpp
index 62bfb04..f08c3ff 100644
--- a/iree/compiler/Codegen/Common/FlattenMemRefSubspanPass.cpp
+++ b/iree/compiler/Codegen/Common/FlattenMemRefSubspanPass.cpp
@@ -36,8 +36,6 @@
#include "iree/compiler/Codegen/PassDetail.h"
#include "iree/compiler/Codegen/Passes.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
-#include "iree/compiler/Dialect/Shape/IR/ShapeDialect.h"
-#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
@@ -266,12 +264,9 @@
Operation *sourceOp = sourceValue.getDefiningOp();
SmallVector<Value, 4> dims;
dims.reserve(rank);
- if (auto shapeCarryOp = dyn_cast<ShapeCarryingInterface>(sourceOp)) {
- Value shapeOp =
- shapeCarryOp.buildResultValueRankedShape(sourceValue, builder);
- for (int i = 0; i < rank; ++i) {
- dims.push_back(builder.create<Shape::RankedDimOp>(loc, shapeOp, i));
- }
+ if (auto shapeAwareOp =
+ dyn_cast<IREE::Util::ShapeAwareOpInterface>(sourceOp)) {
+ dims = shapeAwareOp.buildResultValueShape(sourceValue, builder);
} else {
auto getDimValues = [&](MemRefType type, ValueRange dynamicDims) {
auto shape = type.getShape();
@@ -559,7 +554,7 @@
FlattenMemRefSubspanPass(const FlattenMemRefSubspanPass &pass) {}
void getDependentDialects(DialectRegistry ®istry) const override {
- registry.insert<AffineDialect, memref::MemRefDialect, ShapeDialect>();
+ registry.insert<AffineDialect, memref::MemRefDialect>();
}
void runOnOperation() override {
@@ -622,7 +617,10 @@
foldPatterns.add<FoldSubspanOffsetIntoLoadStore<memref::LoadOp>,
FoldSubspanOffsetIntoLoadStore<memref::StoreOp>>(&context);
- (void)applyPatternsAndFoldGreedily(getOperation(), std::move(foldPatterns));
+ if (failed(applyPatternsAndFoldGreedily(getOperation(),
+ std::move(foldPatterns)))) {
+ return signalPassFailure();
+ }
}
};
diff --git a/iree/compiler/Codegen/Common/FoldAffineMinInDistributedLoops.cpp b/iree/compiler/Codegen/Common/FoldAffineMinInDistributedLoops.cpp
index bf6ded3..d455909 100644
--- a/iree/compiler/Codegen/Common/FoldAffineMinInDistributedLoops.cpp
+++ b/iree/compiler/Codegen/Common/FoldAffineMinInDistributedLoops.cpp
@@ -179,7 +179,13 @@
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
populateFoldAffineMinInDistributedLoopsPatterns(patterns);
- (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+ if (failed(applyPatternsAndFoldGreedily(getOperation(),
+ std::move(patterns)))) {
+ // TODO(#4759): Terrifyingly, this fails. Errors here were ignored for a
+ // long time and now tests for this pass actually fail if we propagate the
+ // failure correctly. Fix this.
+ // return signalPassFailure();
+ }
}
};
} // namespace
diff --git a/iree/compiler/Codegen/Common/ForOpCanonicalizationPass.cpp b/iree/compiler/Codegen/Common/ForOpCanonicalizationPass.cpp
index 4bbb3e2..11094ab 100644
--- a/iree/compiler/Codegen/Common/ForOpCanonicalizationPass.cpp
+++ b/iree/compiler/Codegen/Common/ForOpCanonicalizationPass.cpp
@@ -226,7 +226,9 @@
OwningRewritePatternList patterns(&getContext());
patterns.insert<CanonicalizeForOpInductionVarShape,
PackForOpInductionVarVector>(fn.getContext());
- (void)applyPatternsAndFoldGreedily(fn, std::move(patterns));
+ if (failed(applyPatternsAndFoldGreedily(fn, std::move(patterns)))) {
+ return signalPassFailure();
+ }
}
};
diff --git a/iree/compiler/Codegen/Common/IREEComprehensiveBufferizePass.cpp b/iree/compiler/Codegen/Common/IREEComprehensiveBufferizePass.cpp
index 7f814b3..8b2ac5f 100644
--- a/iree/compiler/Codegen/Common/IREEComprehensiveBufferizePass.cpp
+++ b/iree/compiler/Codegen/Common/IREEComprehensiveBufferizePass.cpp
@@ -11,6 +11,7 @@
//===----------------------------------------------------------------------===//
#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "iree/compiler/Codegen/Common/BufferizationAnalysis.h"
+#include "iree/compiler/Codegen/Interfaces/BufferizationInterfaces.h"
#include "iree/compiler/Codegen/PassDetail.h"
#include "iree/compiler/Codegen/Passes.h"
#include "iree/compiler/Codegen/Transforms/Transforms.h"
@@ -19,7 +20,6 @@
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
#include "iree/compiler/Dialect/Flow/IR/FlowTypes.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
-#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
#include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
#include "llvm/ADT/EquivalenceClasses.h"
@@ -30,15 +30,7 @@
#include "llvm/Support/ErrorHandling.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
-#include "mlir/Dialect/Linalg/ComprehensiveBufferize/AffineInterfaceImpl.h"
-#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.h"
-#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h"
-#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h"
-#include "mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h"
-#include "mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h"
-#include "mlir/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.h"
-#include "mlir/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
@@ -62,186 +54,32 @@
namespace mlir {
namespace iree_compiler {
-//===----------------------------------------------------------------------===//
-// Pass that interfaces with ComprehensiveBufferization in core.
-//===----------------------------------------------------------------------===//
-
-template <typename TensorType>
-static MemRefType getMemrefTypeForTensor(TensorType tensorType,
- MemRefLayoutAttrInterface layout = {},
- Attribute memorySpace = {}) {
- return MemRefType::get(tensorType.getShape(), tensorType.getElementType(),
- layout, memorySpace);
-}
-
using linalg::comprehensive_bufferize::BufferizableOpInterface;
using linalg::comprehensive_bufferize::BufferizationAliasInfo;
using linalg::comprehensive_bufferize::BufferizationState;
-Value getSubspanBuffer(Value tensor, OpBuilder &b, BufferizationState &state) {
- if (!state.isMapped(tensor)) {
- OpBuilder::InsertionGuard g(b);
- auto subspanOp =
- tensor.getDefiningOp<IREE::HAL::InterfaceBindingSubspanOp>();
- assert(subspanOp && "expected LoadOp/StoreOp source/target is SubspanOp");
-
- auto shapedType = subspanOp.getResult()
- .getType()
- .dyn_cast<IREE::Flow::DispatchTensorType>();
- assert(shapedType && shapedType.hasRank());
-
- b.setInsertionPoint(subspanOp);
- // Just change the result type of the InterfaceBindingSubspanOp.
- auto memRefType = getMemrefTypeForTensor(shapedType);
- auto baseBuffer = b.create<IREE::HAL::InterfaceBindingSubspanOp>(
- subspanOp->getLoc(), memRefType, subspanOp.binding(),
- subspanOp.byte_offset(), subspanOp.byte_length(),
- subspanOp.dynamic_dims(), subspanOp.alignmentAttr());
- state.mapValue(subspanOp, baseBuffer);
- state.aliasInfo.createAliasInfoEntry(subspanOp.result());
- }
-
- return state.lookupValue(tensor);
-}
-
namespace {
-struct DispatchTensorLoadOpInterface
- : public BufferizableOpInterface::ExternalModel<
- DispatchTensorLoadOpInterface, IREE::Flow::DispatchTensorLoadOp> {
- SmallVector<OpOperand *> getAliasingOpOperand(Operation *op,
- OpResult opResult) const {
- return {};
- }
-
- bool isWritable(Operation *op, Value value) const {
- auto loadOp = cast<IREE::Flow::DispatchTensorLoadOp>(op);
- auto shapedType =
- loadOp.source().getType().dyn_cast<IREE::Flow::DispatchTensorType>();
- assert(shapedType && "unexpected source type");
- return shapedType.getAccess() != IREE::Flow::TensorAccess::ReadOnly;
- }
-
- LogicalResult bufferize(Operation *op, OpBuilder &b,
- BufferizationState &state) const {
- OpBuilder::InsertionGuard g(b);
- b.setInsertionPoint(op);
- auto loadOp = cast<IREE::Flow::DispatchTensorLoadOp>(op);
- Value source = getSubspanBuffer(loadOp.source(), b, state);
-
- // Bufferize to subview.
- Value subView = b.create<memref::SubViewOp>(
- loadOp->getLoc(), source, loadOp.getMixedOffsets(),
- loadOp.getMixedSizes(), loadOp.getMixedStrides());
- state.mapBuffer(loadOp.result(), subView);
-
- return success();
- }
-};
-
-/// Return true if the value of a `storeOp` bufferizes to an equivalent
-/// DispatchTensorLoadOp result that bufferizes inplace.
-static bool isValueEquivalentToAnInplaceTensorLoadOp(
- const BufferizationAliasInfo &aliasInfo,
- IREE::Flow::DispatchTensorStoreOp storeOp) {
- bool foundOp = false;
- aliasInfo.applyOnEquivalenceClass(storeOp.value(), [&](Value value) {
- auto loadOp = value.getDefiningOp<IREE::Flow::DispatchTensorLoadOp>();
- // TODO: Assert that offsets, sizes and strides are the same.
- if (loadOp &&
- aliasInfo.areEquivalentBufferizedValues(loadOp.result(),
- storeOp.value()) &&
- loadOp.source() == storeOp.target()) {
- foundOp = true;
- }
- });
-
- return foundOp;
-}
-
-struct DispatchTensorStoreOpInterface
- : public BufferizableOpInterface::ExternalModel<
- DispatchTensorStoreOpInterface, IREE::Flow::DispatchTensorStoreOp> {
- bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
- return true;
- }
-
- bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const {
- return false;
- }
-
- OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const {
- return OpResult();
- }
-
- LogicalResult bufferize(Operation *op, OpBuilder &b,
- BufferizationState &state) const {
- OpBuilder::InsertionGuard g(b);
- b.setInsertionPoint(op);
- auto storeOp = cast<IREE::Flow::DispatchTensorStoreOp>(op);
-
- // If everything bufferized inplace, no copy is needed. We wrote to the
- // target buffer already.
- if (!isValueEquivalentToAnInplaceTensorLoadOp(state.aliasInfo, storeOp)) {
- Value target = getSubspanBuffer(storeOp.target(), b, state);
- Value subView = b.create<memref::SubViewOp>(
- storeOp->getLoc(), target, storeOp.getMixedOffsets(),
- storeOp.getMixedSizes(), storeOp.getMixedStrides());
- Value srcMemref = state.lookupBuffer(storeOp.value());
- state.allocationFns.memCpyFn(b, storeOp->getLoc(), srcMemref, subView);
- }
-
- state.markOpObsolete(storeOp);
- return success();
- }
-};
-
-using mlir::linalg::comprehensive_bufferize::linalg_ext::
- InitTensorEliminationStep;
-
-/// Try to eliminate InitTensorOps that are eventually fed into a
-/// DispatchTensorStoreOp. Such InitTensorOps are replaced with matching
-/// DispatchTensorLoadOps. Two conditions must be met:
-///
-/// * The target must be a "readwrite" tensor.
-/// * All ops along the reverse SSA use-def chain from the
-/// DispatchTensorStoreOp to the InitTensorOp must have bufferized in-place.
-struct StoreTensorOpAnchoredInitTensorEliminationStep
- : public InitTensorEliminationStep {
- LogicalResult run(FuncOp funcOp, BufferizationState &state,
- SmallVector<Operation *> &newOps) override {
- return eliminateInitTensors(
- funcOp, state,
- /*anchorMatchFunc=*/
- [&](OpOperand &operand) {
- return isa<IREE::Flow::DispatchTensorStoreOp>(operand.getOwner());
- },
- /*rewriteFunc=*/
- [](OpBuilder &b, Location loc, OpOperand &operand) {
- auto storeOp =
- cast<IREE::Flow::DispatchTensorStoreOp>(operand.getOwner());
- auto loadOp = b.create<IREE::Flow::DispatchTensorLoadOp>(
- loc, storeOp.value().getType().cast<RankedTensorType>(),
- storeOp.target(), storeOp.target_dims(),
- storeOp.getMixedOffsets(), storeOp.getMixedSizes(),
- storeOp.getMixedStrides());
- return loadOp.result();
- },
- newOps);
- }
-};
-
/// Pass to convert from tensor based ops to memref based ops.
class IREEComprehensiveBufferizePass
: public IREEComprehensiveBufferizeBase<IREEComprehensiveBufferizePass> {
public:
explicit IREEComprehensiveBufferizePass(
std::unique_ptr<linalg::comprehensive_bufferize::AllocationCallbacks>
- allocationFn)
- : allocationFn(std::move(allocationFn)) {}
+ allocationFn) {
+ options.allocationFns = std::move(allocationFn);
+ options.testAnalysisOnly = false;
+ addPostAnalysisTransformations(options);
+ }
IREEComprehensiveBufferizePass(const IREEComprehensiveBufferizePass &other) {
- llvm_unreachable("pass cannot be copied");
+ options.allocationFns =
+ std::make_unique<linalg::comprehensive_bufferize::AllocationCallbacks>(
+ other.options.allocationFns->allocationFn,
+ other.options.allocationFns->deallocationFn,
+ other.options.allocationFns->memCpyFn);
+ options.testAnalysisOnly = other.options.testAnalysisOnly;
+ addPostAnalysisTransformations(options);
}
void getDependentDialects(DialectRegistry ®istry) const override {
@@ -250,34 +88,12 @@
scf::SCFDialect, StandardOpsDialect, tensor::TensorDialect,
vector::VectorDialect, AffineDialect,
IREE::Flow::FlowDialect>();
-
- // TODO: Find a better place to register external models.
- // Registers operations of other dialects.
- linalg::comprehensive_bufferize::affine_ext::
- registerBufferizableOpInterfaceExternalModels(registry);
- linalg::comprehensive_bufferize::arith_ext::
- registerBufferizableOpInterfaceExternalModels(registry);
- linalg::comprehensive_bufferize::linalg_ext::
- registerBufferizableOpInterfaceExternalModels(registry);
- linalg::comprehensive_bufferize::scf_ext::
- registerBufferizableOpInterfaceExternalModels(registry);
- linalg::comprehensive_bufferize::tensor_ext::
- registerBufferizableOpInterfaceExternalModels(registry);
- linalg::comprehensive_bufferize::vector_ext::
- registerBufferizableOpInterfaceExternalModels(registry);
-
- // Register IREE operations.
- registry.addOpInterface<IREE::Flow::DispatchTensorLoadOp,
- DispatchTensorLoadOpInterface>();
- registry.addOpInterface<IREE::Flow::DispatchTensorStoreOp,
- DispatchTensorStoreOpInterface>();
}
void runOnOperation() override;
private:
- std::unique_ptr<linalg::comprehensive_bufferize::AllocationCallbacks>
- allocationFn;
+ linalg::comprehensive_bufferize::BufferizationOptions options;
};
} // namespace
@@ -286,30 +102,35 @@
/// Run comprehensive bufferize.
void IREEComprehensiveBufferizePass::runOnOperation() {
ModuleOp moduleOp = getOperation();
-
- linalg::comprehensive_bufferize::BufferizationOptions options;
- options.testAnalysisOnly = false;
- // Enable InitTensorOp elimination.
- options.addPostAnalysisStep<StoreTensorOpAnchoredInitTensorEliminationStep>();
- options.addPostAnalysisStep<linalg::comprehensive_bufferize::tensor_ext::
- InplaceInsertSliceOpAnalysis>();
- // TODO: Use allocationFn.
-
- if (failed(runComprehensiveBufferize(moduleOp, options))) signalPassFailure();
+ if (failed(runComprehensiveBufferize(moduleOp, options))) {
+ return signalPassFailure();
+ }
}
-// TODO: pass this to comprehensive bufferize.
-static Value defaultAllocationFn(OpBuilder &builder, Location loc,
- ArrayRef<int64_t> staticShape,
- Type elementType,
- ArrayRef<Value> dynamicSizes) {
- auto allocationType = MemRefType::get(staticShape, elementType);
- return builder.create<memref::AllocOp>(loc, allocationType, dynamicSizes);
+// Default allocation functions.
+static Optional<Value> defaultAllocationFn(OpBuilder &builder, Location loc,
+ MemRefType allocationType,
+ ArrayRef<Value> dynamicSizes) {
+ return builder.create<memref::AllocOp>(loc, allocationType, dynamicSizes)
+ .getResult();
+}
+static void defaultDeallocationFn(OpBuilder &builder, Location loc,
+ Value allocation) {
+ builder.create<memref::DeallocOp>(loc, allocation);
+}
+static void defaultMemCpyFn(OpBuilder &builder, Location loc, Value from,
+ Value to) {
+ builder.create<linalg::CopyOp>(loc, from, to);
}
std::unique_ptr<OperationPass<ModuleOp>> createIREEComprehensiveBufferizePass(
std::unique_ptr<linalg::comprehensive_bufferize::AllocationCallbacks>
allocationFns) {
+ if (!allocationFns) {
+ allocationFns =
+ std::make_unique<linalg::comprehensive_bufferize::AllocationCallbacks>(
+ defaultAllocationFn, defaultDeallocationFn, defaultMemCpyFn);
+ }
return std::make_unique<IREEComprehensiveBufferizePass>(
std::move(allocationFns));
}
@@ -318,6 +139,8 @@
OpPassManager &passManager,
std::unique_ptr<linalg::comprehensive_bufferize::AllocationCallbacks>
allocationFns) {
+ passManager.addNestedPass<FuncOp>(
+ createConvertToDestinationPassingStylePass());
passManager.addPass(
createIREEComprehensiveBufferizePass(std::move(allocationFns)));
passManager.addPass(memref::createResolveShapedTypeResultDimsPass());
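After this restructuring the allocation callbacks live in the pass's BufferizationOptions, and the factory fills in alloc/dealloc/copy defaults when a caller passes none. A sketch of how a backend could hand its own callbacks to the new pipeline helper instead; the lambda bodies are illustrative (stack allocation plus linalg.copy, mirroring the defaults above), and passManager stands in for an existing OpPassManager:

// Sketch: wiring backend-specific allocation callbacks into the pipeline.
auto callbacks =
    std::make_unique<linalg::comprehensive_bufferize::AllocationCallbacks>(
        /*allocationFn=*/
        [](OpBuilder &b, Location loc, MemRefType type,
           ArrayRef<Value> dynamicSizes) -> Optional<Value> {
          return b.create<memref::AllocaOp>(loc, type, dynamicSizes)
              .getResult();
        },
        /*deallocationFn=*/
        [](OpBuilder &b, Location loc, Value allocation) {
          // Stack allocations need no explicit dealloc.
        },
        /*memCpyFn=*/
        [](OpBuilder &b, Location loc, Value from, Value to) {
          b.create<linalg::CopyOp>(loc, from, to);
        });
addIREEComprehensiveBufferizePasses(passManager, std::move(callbacks));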
diff --git a/iree/compiler/Codegen/Common/LinalgBufferizePass.cpp b/iree/compiler/Codegen/Common/LinalgBufferizePass.cpp
index 38d43ea..7171eb0 100644
--- a/iree/compiler/Codegen/Common/LinalgBufferizePass.cpp
+++ b/iree/compiler/Codegen/Common/LinalgBufferizePass.cpp
@@ -45,7 +45,6 @@
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
#include "iree/compiler/Dialect/Flow/IR/FlowTypes.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
-#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
#include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
#include "llvm/ADT/EquivalenceClasses.h"
diff --git a/iree/compiler/Codegen/Common/OptimizeVectorTransferPass.cpp b/iree/compiler/Codegen/Common/OptimizeVectorTransferPass.cpp
index 1de9c92..9e06edd 100644
--- a/iree/compiler/Codegen/Common/OptimizeVectorTransferPass.cpp
+++ b/iree/compiler/Codegen/Common/OptimizeVectorTransferPass.cpp
@@ -97,7 +97,9 @@
patterns.add<TransposeUnitDimToShapeCast>(&getContext());
mlir::vector::populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
patterns);
- (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
+ if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
+ return signalPassFailure();
+ }
// Workaround: run loop-invariant code motion before hoisting redundant vector
// transfers to work around a bug upstream.
// TODO(thomasraoux): Remove it once the fix is merged.
diff --git a/iree/compiler/Codegen/Common/RemoveTrivialLoops.cpp b/iree/compiler/Codegen/Common/RemoveTrivialLoops.cpp
index 3527287..4a2eb57 100644
--- a/iree/compiler/Codegen/Common/RemoveTrivialLoops.cpp
+++ b/iree/compiler/Codegen/Common/RemoveTrivialLoops.cpp
@@ -121,9 +121,9 @@
return numWorkgroups;
}
-static void removeOneTripTiledLoops(FuncOp funcOp,
- ArrayRef<int64_t> workgroupSize,
- ArrayRef<int64_t> numWorkgroups) {
+static LogicalResult removeOneTripTiledLoops(FuncOp funcOp,
+ ArrayRef<int64_t> workgroupSize,
+ ArrayRef<int64_t> numWorkgroups) {
auto getWorkgroupRangeFn = [numWorkgroups, workgroupSize](
Value processorValue,
SmallVectorImpl<Value> &dims,
@@ -133,7 +133,7 @@
};
OwningRewritePatternList patterns(funcOp.getContext());
populateRemoveSingleIterationLoopPattern(patterns, getWorkgroupRangeFn);
- (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
+ return applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}
namespace {
@@ -148,7 +148,9 @@
SmallVector<int64_t> workgroupSize = getWorkgroupSize(entryPointOp);
SmallVector<int64_t> numWorkgroups = getNumWorkgroup(funcOp, entryPointOp);
- removeOneTripTiledLoops(funcOp, workgroupSize, numWorkgroups);
+ if (failed(removeOneTripTiledLoops(funcOp, workgroupSize, numWorkgroups))) {
+ return signalPassFailure();
+ }
}
};
} // namespace
diff --git a/iree/compiler/Codegen/Common/SetNumWorkgroupsPass.cpp b/iree/compiler/Codegen/Common/SetNumWorkgroupsPass.cpp
index 332c29a..bb4dc4b 100644
--- a/iree/compiler/Codegen/Common/SetNumWorkgroupsPass.cpp
+++ b/iree/compiler/Codegen/Common/SetNumWorkgroupsPass.cpp
@@ -150,7 +150,10 @@
populateAffineMinSCFCanonicalizationPattern(canonicalization);
IREE::Flow::populateFlowDispatchCanonicalizationPatterns(canonicalization,
context);
- (void)applyPatternsAndFoldGreedily(module, std::move(canonicalization));
+ if (failed(
+ applyPatternsAndFoldGreedily(module, std::move(canonicalization)))) {
+ return signalPassFailure();
+ }
}
std::unique_ptr<OperationPass<IREE::HAL::ExecutableVariantOp>>
diff --git a/iree/compiler/Codegen/Common/VectorizeConv.cpp b/iree/compiler/Codegen/Common/VectorizeConv.cpp
index 164662a..1d725f6 100644
--- a/iree/compiler/Codegen/Common/VectorizeConv.cpp
+++ b/iree/compiler/Codegen/Common/VectorizeConv.cpp
@@ -367,7 +367,10 @@
MLIRContext *context = &getContext();
OwningRewritePatternList patterns(&getContext());
patterns.insert<VectorizeLinalgConv, VectorizeLinalgDepthwiseConv>(context);
- (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+ if (failed(applyPatternsAndFoldGreedily(getOperation(),
+ std::move(patterns)))) {
+ return signalPassFailure();
+ }
}
};
diff --git a/iree/compiler/Codegen/Common/VectorizeMMT4d.cpp b/iree/compiler/Codegen/Common/VectorizeMMT4d.cpp
index 0686234..27e1952 100644
--- a/iree/compiler/Codegen/Common/VectorizeMMT4d.cpp
+++ b/iree/compiler/Codegen/Common/VectorizeMMT4d.cpp
@@ -148,7 +148,10 @@
MLIRContext *context = &getContext();
OwningRewritePatternList patterns(&getContext());
patterns.insert<VectorizeMMT4DOp>(context);
- (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+ if (failed(applyPatternsAndFoldGreedily(getOperation(),
+ std::move(patterns)))) {
+ return signalPassFailure();
+ }
}
};
diff --git a/iree/compiler/Codegen/Common/test/convert_to_destination_passing_style.mlir b/iree/compiler/Codegen/Common/test/convert_to_destination_passing_style.mlir
index f37345b..1712106 100644
--- a/iree/compiler/Codegen/Common/test/convert_to_destination_passing_style.mlir
+++ b/iree/compiler/Codegen/Common/test/convert_to_destination_passing_style.mlir
@@ -217,9 +217,9 @@
// CHECK: func @reshape_fused_source()
// CHECK-DAG: %[[ARG0:.+]] = hal.interface.binding.subspan @io::@arg0
// CHECK-DAG: %[[RET0:.+]] = hal.interface.binding.subspan @io::@ret0
+// CHECK: %[[TARGET:.+]] = flow.dispatch.tensor.load %[[RET0]]
// CHECK: %[[SOURCE:.+]] = flow.dispatch.tensor.load %[[ARG0]]
// CHECK: %[[RESHAPE:.+]] = linalg.tensor_expand_shape %[[SOURCE]]
-// CHECK: %[[TARGET:.+]] = flow.dispatch.tensor.load %[[RET0]]
// CHECK: %[[GENERIC:.+]] = linalg.generic
// CHECK-SAME: ins(%[[RESHAPE]]
// CHECK-SAME: outs(%[[TARGET]]
@@ -259,10 +259,10 @@
// CHECK: func @reshape_fused_source_and_copyout()
// CHECK-DAG: %[[ARG0:.+]] = hal.interface.binding.subspan @io::@arg0
// CHECK-DAG: %[[RET0:.+]] = hal.interface.binding.subspan @io::@ret0
+// CHECK-DAG: %[[TARGET:.+]] = flow.dispatch.tensor.load %[[RET0]]
// CHECK-DAG: %[[RET1:.+]] = hal.interface.binding.subspan @io::@ret1
// CHECK: %[[SOURCE:.+]] = flow.dispatch.tensor.load %[[ARG0]]
// CHECK: %[[RESHAPE:.+]] = linalg.tensor_expand_shape %[[SOURCE]]
-// CHECK: %[[TARGET:.+]] = flow.dispatch.tensor.load %[[RET0]]
// CHECK: %[[GENERIC:.+]] = linalg.generic
// CHECK-SAME: ins(%[[RESHAPE]]
// CHECK-SAME: outs(%[[TARGET]]
@@ -302,7 +302,7 @@
// CHECK-DAG: %[[RET0:.+]] = hal.interface.binding.subspan @io::@ret0
// CHECK-DAG: %[[SOURCE:.+]] = flow.dispatch.tensor.load %[[ARG0]]
// CHECK-DAG: %[[TARGET:.+]] = flow.dispatch.tensor.load %[[RET0]]
-// CHECK: %[[RESHAPE_EXPAND:.+]] = linalg.tensor_expand_shape %[[TARGET]] {{\[}}[0, 1]{{\]}}
+// CHECK-DAG: %[[RESHAPE_EXPAND:.+]] = linalg.tensor_expand_shape %[[TARGET]] {{\[}}[0, 1]{{\]}}
// CHECK: %[[GENERIC:.+]] = linalg.generic
// CHECK-SAME: ins(%[[SOURCE]]
// CHECK-SAME: outs(%[[RESHAPE_EXPAND]]
diff --git a/iree/compiler/Codegen/Common/test/iree_comprehensive_bufferize.mlir b/iree/compiler/Codegen/Common/test/iree_comprehensive_bufferize.mlir
index b90ac0e..b5bdf06 100644
--- a/iree/compiler/Codegen/Common/test/iree_comprehensive_bufferize.mlir
+++ b/iree/compiler/Codegen/Common/test/iree_comprehensive_bufferize.mlir
@@ -66,14 +66,13 @@
// CHECK-DAG: %[[LHS_TILE:.+]] = memref.subview %[[LHS]][%[[IV0]], 0] [%[[TILESIZE_Y]], %[[K]]]
// CHECK-DAG: %[[RHS_TILE:.+]] = memref.subview %[[RHS]][0, %[[IV1]]] [%[[K]], %[[TILESIZE_X]]]
// CHECK-DAG: %[[INIT_TILE:.+]] = memref.subview %[[INIT]][%[[IV0]], %[[IV1]]] [%[[TILESIZE_Y]], %[[TILESIZE_X]]]
-// CHECK-DAG: %[[ALLOC:.+]] = memref.alloc(%[[TILESIZE_Y]], %[[TILESIZE_X]]) {alignment = 128 : i64}
-// CHECK: %[[ALLOC_CASTED:.+]] = memref.cast %[[ALLOC]] : memref<?x?xf32> to memref<?x?xf32, #[[MAP2]]>
-// CHECK: memref.copy %[[INIT_TILE]], %[[ALLOC_CASTED]]
+// CHECK: %[[ALLOC:.+]] = memref.alloc(%[[TILESIZE_Y]], %[[TILESIZE_X]])
+// CHECK: linalg.copy(%[[INIT_TILE]], %[[ALLOC]])
// CHECK: linalg.matmul
// CHECK-SAME: ins(%[[LHS_TILE]], %[[RHS_TILE]]
// CHECK-SAME: outs(%[[ALLOC]]
// CHECK: %[[RESULT_TILE:.+]] = memref.subview %[[RESULT]][%[[IV0]], %[[IV1]]] [%[[TILESIZE_Y]], %[[TILESIZE_X]]]
-// CHECK: memref.copy %[[ALLOC_CASTED]], %[[RESULT_TILE]]
+// CHECK: linalg.copy(%[[ALLOC]], %[[RESULT_TILE]])
// CHECK: memref.dealloc %[[ALLOC]]
@@ -149,3 +148,4 @@
// CHECK: linalg.matmul
// CHECK-SAME: ins(%[[LHS_TILE]], %[[RHS_TILE]]
// CHECK-SAME: outs(%[[RESULT_TILE]]
+
diff --git a/iree/compiler/Codegen/Common/test/linalg_bufferize.mlir b/iree/compiler/Codegen/Common/test/linalg_bufferize.mlir
index 64e28b0..bbc80fe 100644
--- a/iree/compiler/Codegen/Common/test/linalg_bufferize.mlir
+++ b/iree/compiler/Codegen/Common/test/linalg_bufferize.mlir
@@ -992,12 +992,11 @@
}
// CHECK-LABEL: func @subtensor_insert()
// CHECK-DAG: %[[C0:.+]] = arith.constant 0
-// CHECK-DAG: %[[C1:.+]] = arith.constant 1
// CHECK-DAG: %[[ARG0:.+]] = hal.interface.binding.subspan @io::@arg0
// CHECK-DAG: %[[ARG1:.+]] = hal.interface.binding.subspan @io::@arg1
// CHECK-DAG: %[[RET0:.+]] = hal.interface.binding.subspan @io::@ret0
-// CHECK-DAG: %[[D0:.+]] = memref.dim %[[ARG0]], %[[C0]]
-// CHECK-DAG: %[[D1:.+]] = memref.dim %[[ARG0]], %[[C1]]
+// CHECK-DAG: %[[D0:.+]] = hal.interface.load.constant offset = 0 : index
+// CHECK-DAG: %[[D1:.+]] = hal.interface.load.constant offset = 1 : index
// CHECK: linalg.copy(%[[ARG1]], %[[RET0]])
// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[RET0]][3, 4] [%[[D0]], %[[D1]]] [1, 1]
// CHECK: linalg.copy(%[[ARG0]], %[[SUBVIEW]])
diff --git a/iree/compiler/Codegen/Dialect/BUILD b/iree/compiler/Codegen/Dialect/BUILD
index c10392f..9c6008b 100644
--- a/iree/compiler/Codegen/Dialect/BUILD
+++ b/iree/compiler/Codegen/Dialect/BUILD
@@ -17,7 +17,6 @@
"IREECodegenAttributes.td",
"IREECodegenDialect.td",
"LoweringConfig.td",
- "ProcessorOpInterfaces.td",
])
td_library(
@@ -27,7 +26,6 @@
"IREECodegenAttributes.td",
"IREECodegenDialect.td",
"LoweringConfig.td",
- "ProcessorOpInterfaces.td",
],
include = ["*.td"],
),
@@ -66,26 +64,6 @@
],
)
-cc_library(
- name = "ProcessorOpInterfaces",
- srcs = [
- "ProcessorOpInterfaces.cpp",
- ],
- hdrs = [
- "ProcessorOpInterfaces.h",
- ],
- textual_hdrs = [
- "ProcessorOpInterfaces.cpp.inc",
- "ProcessorOpInterfaces.h.inc",
- ],
- deps = [
- ":ProcessorOpInterfaceGen",
- "//iree/compiler/Dialect/HAL/IR",
- "@llvm-project//mlir:GPUDialect",
- "@llvm-project//mlir:IR",
- ],
-)
-
gentbl_cc_library(
name = "IREECodegenDialectGen",
tbl_outs = [
@@ -127,20 +105,3 @@
td_file = "LoweringConfig.td",
deps = [":td_files"],
)
-
-gentbl_cc_library(
- name = "ProcessorOpInterfaceGen",
- tbl_outs = [
- (
- ["-gen-op-interface-decls"],
- "ProcessorOpInterfaces.h.inc",
- ),
- (
- ["-gen-op-interface-defs"],
- "ProcessorOpInterfaces.cpp.inc",
- ),
- ],
- tblgen = "@llvm-project//mlir:mlir-tblgen",
- td_file = "ProcessorOpInterfaces.td",
- deps = [":td_files"],
-)
diff --git a/iree/compiler/Codegen/Dialect/CMakeLists.txt b/iree/compiler/Codegen/Dialect/CMakeLists.txt
index 63c4306..16f6826 100644
--- a/iree/compiler/Codegen/Dialect/CMakeLists.txt
+++ b/iree/compiler/Codegen/Dialect/CMakeLists.txt
@@ -37,24 +37,6 @@
PUBLIC
)
-iree_cc_library(
- NAME
- ProcessorOpInterfaces
- HDRS
- "ProcessorOpInterfaces.h"
- TEXTUAL_HDRS
- "ProcessorOpInterfaces.cpp.inc"
- "ProcessorOpInterfaces.h.inc"
- SRCS
- "ProcessorOpInterfaces.cpp"
- DEPS
- ::ProcessorOpInterfaceGen
- MLIRGPUOps
- MLIRIR
- iree::compiler::Dialect::HAL::IR
- PUBLIC
-)
-
iree_tablegen_library(
NAME
IREECodegenDialectGen
@@ -77,14 +59,4 @@
-gen-enum-defs LoweringConfigEnums.cpp.inc
)
-iree_tablegen_library(
- NAME
- ProcessorOpInterfaceGen
- TD_FILE
- "ProcessorOpInterfaces.td"
- OUTS
- -gen-op-interface-decls ProcessorOpInterfaces.h.inc
- -gen-op-interface-defs ProcessorOpInterfaces.cpp.inc
-)
-
### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/iree/compiler/Codegen/Interfaces/BUILD b/iree/compiler/Codegen/Interfaces/BUILD
new file mode 100644
index 0000000..2757e4d
--- /dev/null
+++ b/iree/compiler/Codegen/Interfaces/BUILD
@@ -0,0 +1,105 @@
+# Copyright 2021 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library")
+load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob")
+
+package(
+ default_visibility = ["//visibility:public"],
+ features = ["layering_check"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+exports_files([
+ "ProcessorOpInterfaces.td",
+])
+
+td_library(
+ name = "td_files",
+ srcs = enforce_glob(
+ [
+ "ProcessorOpInterfaces.td",
+ ],
+ include = ["*.td"],
+ ),
+ deps = [
+ "@llvm-project//mlir:OpBaseTdFiles",
+ ],
+)
+
+cc_library(
+ name = "Interfaces",
+ srcs = [
+ "Interfaces.cpp",
+ ],
+ hdrs = [
+ "Interfaces.h",
+ ],
+ deps = [
+ ":BufferizationInterfaces",
+ ":ProcessorOpInterfaces",
+ ],
+)
+
+cc_library(
+ name = "BufferizationInterfaces",
+ srcs = [
+ "BufferizationInterfaces.cpp",
+ ],
+ hdrs = [
+ "BufferizationInterfaces.h",
+ ],
+ deps = [
+ "//iree/compiler/Dialect/Flow/IR",
+ "//iree/compiler/Dialect/HAL/IR",
+ "@llvm-project//mlir:AffineBufferizableOpInterfaceImpl",
+ "@llvm-project//mlir:ArithBufferizableOpInterfaceImpl",
+ "@llvm-project//mlir:BufferizableOpInterface",
+ "@llvm-project//mlir:ComprehensiveBufferize",
+ "@llvm-project//mlir:LinalgBufferizableOpInterfaceImpl",
+ "@llvm-project//mlir:MemRefDialect",
+ "@llvm-project//mlir:SCFBufferizableOpInterfaceImpl",
+ "@llvm-project//mlir:TensorBufferizableOpInterfaceImpl",
+ "@llvm-project//mlir:VectorBufferizableOpInterfaceImpl",
+ ],
+)
+
+cc_library(
+ name = "ProcessorOpInterfaces",
+ srcs = [
+ "ProcessorOpInterfaces.cpp",
+ ],
+ hdrs = [
+ "ProcessorOpInterfaces.h",
+ ],
+ textual_hdrs = [
+ "ProcessorOpInterfaces.cpp.inc",
+ "ProcessorOpInterfaces.h.inc",
+ ],
+ deps = [
+ ":ProcessorOpInterfaceGen",
+ "//iree/compiler/Dialect/HAL/IR",
+ "@llvm-project//mlir:GPUDialect",
+ "@llvm-project//mlir:IR",
+ ],
+)
+
+gentbl_cc_library(
+ name = "ProcessorOpInterfaceGen",
+ tbl_outs = [
+ (
+ ["-gen-op-interface-decls"],
+ "ProcessorOpInterfaces.h.inc",
+ ),
+ (
+ ["-gen-op-interface-defs"],
+ "ProcessorOpInterfaces.cpp.inc",
+ ),
+ ],
+ tblgen = "@llvm-project//mlir:mlir-tblgen",
+ td_file = "ProcessorOpInterfaces.td",
+ deps = [":td_files"],
+)
diff --git a/iree/compiler/Codegen/Interfaces/BufferizationInterfaces.cpp b/iree/compiler/Codegen/Interfaces/BufferizationInterfaces.cpp
new file mode 100644
index 0000000..65618ef
--- /dev/null
+++ b/iree/compiler/Codegen/Interfaces/BufferizationInterfaces.cpp
@@ -0,0 +1,233 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Codegen/Interfaces/BufferizationInterfaces.h"
+
+#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
+#include "iree/compiler/Dialect/Flow/IR/FlowTypes.h"
+#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
+#include "mlir/Dialect/Linalg/ComprehensiveBufferize/AffineInterfaceImpl.h"
+#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.h"
+#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h"
+#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h"
+#include "mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h"
+#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h"
+#include "mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h"
+#include "mlir/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.h"
+#include "mlir/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+
+using mlir::linalg::comprehensive_bufferize::BufferizableOpInterface;
+using mlir::linalg::comprehensive_bufferize::BufferizationAliasInfo;
+using mlir::linalg::comprehensive_bufferize::BufferizationState;
+using mlir::linalg::comprehensive_bufferize::linalg_ext::
+ InitTensorEliminationStep;
+
+namespace mlir {
+namespace iree_compiler {
+
+//===----------------------------------------------------------------------===//
+// Utility functions.
+//===----------------------------------------------------------------------===//
+
+template <typename TensorType>
+static MemRefType getMemrefTypeForTensor(TensorType tensorType,
+ MemRefLayoutAttrInterface layout = {},
+ Attribute memorySpace = {}) {
+ return MemRefType::get(tensorType.getShape(), tensorType.getElementType(),
+ layout, memorySpace);
+}
+
+static Value getSubspanBuffer(Value tensor, OpBuilder &b,
+ BufferizationState &state) {
+ if (!state.isMapped(tensor)) {
+ OpBuilder::InsertionGuard g(b);
+ auto subspanOp =
+ tensor.getDefiningOp<IREE::HAL::InterfaceBindingSubspanOp>();
+ assert(subspanOp && "expected LoadOp/StoreOp source/target is SubspanOp");
+
+ auto shapedType = subspanOp.getResult()
+ .getType()
+ .dyn_cast<IREE::Flow::DispatchTensorType>();
+ assert(shapedType && shapedType.hasRank());
+
+ b.setInsertionPoint(subspanOp);
+ // Just change the result type of the InterfaceBindingSubspanOp.
+ auto memRefType = getMemrefTypeForTensor(shapedType);
+ auto baseBuffer = b.create<IREE::HAL::InterfaceBindingSubspanOp>(
+ subspanOp->getLoc(), memRefType, subspanOp.binding(),
+ subspanOp.byte_offset(), subspanOp.byte_length(),
+ subspanOp.dynamic_dims(), subspanOp.alignmentAttr());
+ state.mapValue(subspanOp, baseBuffer);
+ state.aliasInfo.createAliasInfoEntry(subspanOp.result());
+ }
+
+ return state.lookupValue(tensor);
+}
+
+namespace {
+
+//===----------------------------------------------------------------------===//
+// IREE specific External models for BufferizableOpInterface.
+//===----------------------------------------------------------------------===//
+
+struct DispatchTensorLoadOpInterface
+ : public BufferizableOpInterface::ExternalModel<
+ DispatchTensorLoadOpInterface, IREE::Flow::DispatchTensorLoadOp> {
+ SmallVector<OpOperand *> getAliasingOpOperand(Operation *op,
+ OpResult opResult) const {
+ return {};
+ }
+
+ bool isWritable(Operation *op, Value value) const {
+ auto loadOp = cast<IREE::Flow::DispatchTensorLoadOp>(op);
+ auto shapedType =
+ loadOp.source().getType().dyn_cast<IREE::Flow::DispatchTensorType>();
+ assert(shapedType && "unexpected source type");
+ return shapedType.getAccess() != IREE::Flow::TensorAccess::ReadOnly;
+ }
+
+ LogicalResult bufferize(Operation *op, OpBuilder &b,
+ BufferizationState &state) const {
+ OpBuilder::InsertionGuard g(b);
+ b.setInsertionPoint(op);
+ auto loadOp = cast<IREE::Flow::DispatchTensorLoadOp>(op);
+ Value source = getSubspanBuffer(loadOp.source(), b, state);
+
+ // Bufferize to subview.
+ Value subView = b.create<memref::SubViewOp>(
+ loadOp->getLoc(), source, loadOp.getMixedOffsets(),
+ loadOp.getMixedSizes(), loadOp.getMixedStrides());
+ state.mapBuffer(loadOp.result(), subView);
+
+ return success();
+ }
+};
+
+/// Returns true if the value of a `storeOp` bufferizes to an equivalent
+/// DispatchTensorLoadOp result that bufferizes inplace.
+static bool isValueEquivalentToAnInplaceTensorLoadOp(
+ const BufferizationAliasInfo &aliasInfo,
+ IREE::Flow::DispatchTensorStoreOp storeOp) {
+ bool foundOp = false;
+ aliasInfo.applyOnEquivalenceClass(storeOp.value(), [&](Value value) {
+ auto loadOp = value.getDefiningOp<IREE::Flow::DispatchTensorLoadOp>();
+ // TODO: Assert that offsets, sizes and strides are the same.
+ if (loadOp &&
+ aliasInfo.areEquivalentBufferizedValues(loadOp.result(),
+ storeOp.value()) &&
+ loadOp.source() == storeOp.target()) {
+ foundOp = true;
+ }
+ });
+
+ return foundOp;
+}
+
+struct DispatchTensorStoreOpInterface
+ : public BufferizableOpInterface::ExternalModel<
+ DispatchTensorStoreOpInterface, IREE::Flow::DispatchTensorStoreOp> {
+ bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
+ return true;
+ }
+
+ bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const {
+ return false;
+ }
+
+ OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const {
+ return OpResult();
+ }
+
+ LogicalResult bufferize(Operation *op, OpBuilder &b,
+ BufferizationState &state) const {
+ OpBuilder::InsertionGuard g(b);
+ b.setInsertionPoint(op);
+ auto storeOp = cast<IREE::Flow::DispatchTensorStoreOp>(op);
+
+ // If everything bufferized inplace, no copy is needed. We wrote to the
+ // target buffer already.
+ if (!isValueEquivalentToAnInplaceTensorLoadOp(state.aliasInfo, storeOp)) {
+ Value target = getSubspanBuffer(storeOp.target(), b, state);
+ Value subView = b.create<memref::SubViewOp>(
+ storeOp->getLoc(), target, storeOp.getMixedOffsets(),
+ storeOp.getMixedSizes(), storeOp.getMixedStrides());
+ Value srcMemref = state.lookupBuffer(storeOp.value());
+ state.allocationFns.memCpyFn(b, storeOp->getLoc(), srcMemref, subView);
+ }
+
+ state.markOpObsolete(storeOp);
+ return success();
+ }
+};
+
+//===----------------------------------------------------------------------===//
+// IREE specific post analysis transformations.
+//===----------------------------------------------------------------------===//
+
+/// Try to eliminate InitTensorOps that are eventually fed into a
+/// DispatchTensorStoreOp. Such InitTensorOps are replaced with matching
+/// DispatchTensorLoadOps. Two conditions must be met:
+///
+/// * The target must be a "readwrite" tensor.
+/// * All ops along the reverse SSA use-def chain from the
+/// DispatchTensorStoreOp to the InitTensorOp must have bufferized in-place.
+struct StoreTensorOpAnchoredInitTensorEliminationStep
+ : public InitTensorEliminationStep {
+ LogicalResult run(FuncOp funcOp, BufferizationState &state,
+ SmallVector<Operation *> &newOps) override {
+ return eliminateInitTensors(
+ funcOp, state,
+ /*anchorMatchFunc=*/
+ [&](OpOperand &operand) {
+ return isa<IREE::Flow::DispatchTensorStoreOp>(operand.getOwner());
+ },
+ /*rewriteFunc=*/
+ [](OpBuilder &b, Location loc, OpOperand &operand) {
+ auto storeOp =
+ cast<IREE::Flow::DispatchTensorStoreOp>(operand.getOwner());
+ auto loadOp = b.create<IREE::Flow::DispatchTensorLoadOp>(
+ loc, storeOp.value().getType().cast<RankedTensorType>(),
+ storeOp.target(), storeOp.target_dims(),
+ storeOp.getMixedOffsets(), storeOp.getMixedSizes(),
+ storeOp.getMixedStrides());
+ return loadOp.result();
+ },
+ newOps);
+ }
+};
+} // namespace
+
+void registerBufferizationInterfaces(DialectRegistry ®istry) {
+ linalg::comprehensive_bufferize::affine_ext::
+ registerBufferizableOpInterfaceExternalModels(registry);
+ linalg::comprehensive_bufferize::arith_ext::
+ registerBufferizableOpInterfaceExternalModels(registry);
+ linalg::comprehensive_bufferize::linalg_ext::
+ registerBufferizableOpInterfaceExternalModels(registry);
+ linalg::comprehensive_bufferize::scf_ext::
+ registerBufferizableOpInterfaceExternalModels(registry);
+ linalg::comprehensive_bufferize::tensor_ext::
+ registerBufferizableOpInterfaceExternalModels(registry);
+ linalg::comprehensive_bufferize::vector_ext::
+ registerBufferizableOpInterfaceExternalModels(registry);
+
+ // Register IREE operations.
+ registry.addOpInterface<IREE::Flow::DispatchTensorLoadOp,
+ DispatchTensorLoadOpInterface>();
+ registry.addOpInterface<IREE::Flow::DispatchTensorStoreOp,
+ DispatchTensorStoreOpInterface>();
+}
+
+void addPostAnalysisTransformations(
+ linalg::comprehensive_bufferize::BufferizationOptions &options) {
+ options.addPostAnalysisStep<StoreTensorOpAnchoredInitTensorEliminationStep>();
+ options.addPostAnalysisStep<linalg::comprehensive_bufferize::tensor_ext::
+ InplaceInsertSliceOpAnalysis>();
+}
+
+} // namespace iree_compiler
+} // namespace mlir
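This new file centralizes the IREE-side bufferization hooks that were previously embedded in the pass: registerBufferizationInterfaces attaches the upstream and IREE-specific BufferizableOpInterface external models to a DialectRegistry, and addPostAnalysisTransformations installs the IREE post-analysis steps on a BufferizationOptions. A sketch of the intended call sequence, mirroring how the pass above uses them; `registry` and `options` stand in for whatever the caller already has:

// Sketch: consuming the two entry points defined in this file.
linalg::comprehensive_bufferize::BufferizationOptions options;
options.testAnalysisOnly = false;
// Installs init_tensor elimination anchored on flow.dispatch.tensor.store and
// the in-place insert_slice analysis.
addPostAnalysisTransformations(options);

DialectRegistry registry;
// Upstream external models plus the Flow dispatch tensor load/store models.
registerBufferizationInterfaces(registry);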
diff --git a/iree/compiler/Codegen/Interfaces/BufferizationInterfaces.h b/iree/compiler/Codegen/Interfaces/BufferizationInterfaces.h
new file mode 100644
index 0000000..729a93d
--- /dev/null
+++ b/iree/compiler/Codegen/Interfaces/BufferizationInterfaces.h
@@ -0,0 +1,26 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_COMPILER_CODEGEN_INTERFACES_BUFFERIZATIONINTERFACES_H_
+#define IREE_COMPILER_CODEGEN_INTERFACES_BUFFERIZATIONINTERFACES_H_
+
+#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h"
+#include "mlir/IR/Dialect.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+// Register all interfaces needed for bufferization.
+void registerBufferizationInterfaces(DialectRegistry ®istry);
+
+// Adds all post-analysis transformation steps needed for bufferization.
+void addPostAnalysisTransformations(
+ linalg::comprehensive_bufferize::BufferizationOptions &options);
+
+} // namespace iree_compiler
+} // namespace mlir
+
+#endif // IREE_COMPILER_CODEGEN_INTERFACES_BUFFERIZATIONINTERFACES_H_
diff --git a/iree/compiler/Codegen/Interfaces/CMakeLists.txt b/iree/compiler/Codegen/Interfaces/CMakeLists.txt
new file mode 100644
index 0000000..5ac2788
--- /dev/null
+++ b/iree/compiler/Codegen/Interfaces/CMakeLists.txt
@@ -0,0 +1,76 @@
+################################################################################
+# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from #
+# iree/compiler/Codegen/Interfaces/BUILD #
+# #
+# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary #
+# CMake-only content. #
+# #
+# To disable autogeneration for this file entirely, delete this header. #
+################################################################################
+
+iree_add_all_subdirs()
+
+iree_cc_library(
+ NAME
+ Interfaces
+ HDRS
+ "Interfaces.h"
+ SRCS
+ "Interfaces.cpp"
+ DEPS
+ ::BufferizationInterfaces
+ ::ProcessorOpInterfaces
+ PUBLIC
+)
+
+iree_cc_library(
+ NAME
+ BufferizationInterfaces
+ HDRS
+ "BufferizationInterfaces.h"
+ SRCS
+ "BufferizationInterfaces.cpp"
+ DEPS
+ MLIRAffineBufferizableOpInterfaceImpl
+ MLIRArithBufferizableOpInterfaceImpl
+ MLIRBufferizableOpInterface
+ MLIRComprehensiveBufferize
+ MLIRLinalgBufferizableOpInterfaceImpl
+ MLIRMemRef
+ MLIRSCFBufferizableOpInterfaceImpl
+ MLIRTensorBufferizableOpInterfaceImpl
+ MLIRVectorBufferizableOpInterfaceImpl
+ iree::compiler::Dialect::Flow::IR
+ iree::compiler::Dialect::HAL::IR
+ PUBLIC
+)
+
+iree_cc_library(
+ NAME
+ ProcessorOpInterfaces
+ HDRS
+ "ProcessorOpInterfaces.h"
+ TEXTUAL_HDRS
+ "ProcessorOpInterfaces.cpp.inc"
+ "ProcessorOpInterfaces.h.inc"
+ SRCS
+ "ProcessorOpInterfaces.cpp"
+ DEPS
+ ::ProcessorOpInterfaceGen
+ MLIRGPUOps
+ MLIRIR
+ iree::compiler::Dialect::HAL::IR
+ PUBLIC
+)
+
+iree_tablegen_library(
+ NAME
+ ProcessorOpInterfaceGen
+ TD_FILE
+ "ProcessorOpInterfaces.td"
+ OUTS
+ -gen-op-interface-decls ProcessorOpInterfaces.h.inc
+ -gen-op-interface-defs ProcessorOpInterfaces.cpp.inc
+)
+
+### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/iree/compiler/Codegen/Interfaces/Interfaces.cpp b/iree/compiler/Codegen/Interfaces/Interfaces.cpp
new file mode 100644
index 0000000..64a1698
--- /dev/null
+++ b/iree/compiler/Codegen/Interfaces/Interfaces.cpp
@@ -0,0 +1,21 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Codegen/Interfaces/Interfaces.h"
+
+#include "iree/compiler/Codegen/Interfaces/BufferizationInterfaces.h"
+#include "iree/compiler/Codegen/Interfaces/ProcessorOpInterfaces.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+void registerCodegenInterfaces(DialectRegistry ®istry) {
+ registerProcessorOpInterfaceExternalModels(registry);
+ registerBufferizationInterfaces(registry);
+}
+
+} // namespace iree_compiler
+} // namespace mlir
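registerCodegenInterfaces bundles the processor-op and bufferization interface registrations behind a single call so that tools only need one hook. A sketch of such a hook; the surrounding function and dialect list are hypothetical, only registerCodegenInterfaces comes from this change:

// Hypothetical tool-side registration hook.
void registerAllForTool(DialectRegistry &registry) {
  registry.insert<linalg::LinalgDialect, memref::MemRefDialect>();
  mlir::iree_compiler::registerCodegenInterfaces(registry);
}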
diff --git a/iree/compiler/Codegen/Interfaces/Interfaces.h b/iree/compiler/Codegen/Interfaces/Interfaces.h
new file mode 100644
index 0000000..9c777b6
--- /dev/null
+++ b/iree/compiler/Codegen/Interfaces/Interfaces.h
@@ -0,0 +1,21 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_COMPILER_CODEGEN_INTERFACES_INTERFACES_H_
+#define IREE_COMPILER_CODEGEN_INTERFACES_INTERFACES_H_
+
+#include "mlir/IR/Dialect.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+/// Register all codegen related interfaces.
+void registerCodegenInterfaces(DialectRegistry ®istry);
+
+} // namespace iree_compiler
+} // namespace mlir
+
+#endif // IREE_COMPILER_CODEGEN_INTERFACES_INTERFACES_H_
diff --git a/iree/compiler/Codegen/Dialect/ProcessorOpInterfaces.cpp b/iree/compiler/Codegen/Interfaces/ProcessorOpInterfaces.cpp
similarity index 94%
rename from iree/compiler/Codegen/Dialect/ProcessorOpInterfaces.cpp
rename to iree/compiler/Codegen/Interfaces/ProcessorOpInterfaces.cpp
index 47c92be..f1e64a7 100644
--- a/iree/compiler/Codegen/Dialect/ProcessorOpInterfaces.cpp
+++ b/iree/compiler/Codegen/Interfaces/ProcessorOpInterfaces.cpp
@@ -4,13 +4,13 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-#include "iree/compiler/Codegen/Dialect/ProcessorOpInterfaces.h"
+#include "iree/compiler/Codegen/Interfaces/ProcessorOpInterfaces.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
/// Include the generated interface definitions.
-#include "iree/compiler/Codegen/Dialect/ProcessorOpInterfaces.cpp.inc"
+#include "iree/compiler/Codegen/Interfaces/ProcessorOpInterfaces.cpp.inc"
namespace mlir {
namespace iree_compiler {
diff --git a/iree/compiler/Codegen/Dialect/ProcessorOpInterfaces.h b/iree/compiler/Codegen/Interfaces/ProcessorOpInterfaces.h
similarity index 66%
rename from iree/compiler/Codegen/Dialect/ProcessorOpInterfaces.h
rename to iree/compiler/Codegen/Interfaces/ProcessorOpInterfaces.h
index ec3e0c4..ef16895 100644
--- a/iree/compiler/Codegen/Dialect/ProcessorOpInterfaces.h
+++ b/iree/compiler/Codegen/Interfaces/ProcessorOpInterfaces.h
@@ -4,14 +4,14 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-#ifndef IREE_COMPILER_CODEGEN_DIALECT_PROCESSOROPINTERFACES_H_
-#define IREE_COMPILER_CODEGEN_DIALECT_PROCESSOROPINTERFACES_H_
+#ifndef IREE_COMPILER_CODEGEN_INTERFACES_PROCESSOROPINTERFACES_H_
+#define IREE_COMPILER_CODEGEN_INTERFACES_PROCESSOROPINTERFACES_H_
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
/// Include the generated interface declarations.
-#include "iree/compiler/Codegen/Dialect/ProcessorOpInterfaces.h.inc" // IWYU pragma: export
+#include "iree/compiler/Codegen/Interfaces/ProcessorOpInterfaces.h.inc" // IWYU pragma: export
namespace mlir {
namespace iree_compiler {
@@ -22,4 +22,4 @@
} // namespace iree_compiler
} // namespace mlir
-#endif // IREE_COMPILER_CODEGEN_DIALECT_PROCESSOROPINTERFACES_H_
+#endif // IREE_COMPILER_CODEGEN_INTERFACES_PROCESSOROPINTERFACES_H_
diff --git a/iree/compiler/Codegen/Dialect/ProcessorOpInterfaces.td b/iree/compiler/Codegen/Interfaces/ProcessorOpInterfaces.td
similarity index 100%
rename from iree/compiler/Codegen/Dialect/ProcessorOpInterfaces.td
rename to iree/compiler/Codegen/Interfaces/ProcessorOpInterfaces.td
diff --git a/iree/compiler/Codegen/LLVMCPU/BUILD b/iree/compiler/Codegen/LLVMCPU/BUILD
index 953be7c..868f68e 100644
--- a/iree/compiler/Codegen/LLVMCPU/BUILD
+++ b/iree/compiler/Codegen/LLVMCPU/BUILD
@@ -35,8 +35,6 @@
"//iree/compiler/Dialect/Flow/IR",
"//iree/compiler/Dialect/HAL/IR",
"//iree/compiler/Dialect/HAL/IR:HALDialect",
- "//iree/compiler/Dialect/Shape/IR",
- "//iree/compiler/Dialect/Shape/Transforms",
"//iree/compiler/Dialect/Util/IR",
"//llvm-external-projects/iree-dialects:IREELinalgExtDialect",
"//llvm-external-projects/iree-dialects:IREELinalgExtTransforms",
diff --git a/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt b/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt
index 88a96d2..b86fc40 100644
--- a/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt
+++ b/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt
@@ -67,8 +67,6 @@
iree::compiler::Dialect::Flow::IR
iree::compiler::Dialect::HAL::IR
iree::compiler::Dialect::HAL::IR::HALDialect
- iree::compiler::Dialect::Shape::IR
- iree::compiler::Dialect::Shape::Transforms
iree::compiler::Dialect::Util::IR
PUBLIC
)
diff --git a/iree/compiler/Codegen/LLVMCPU/ConvertToLLVM.cpp b/iree/compiler/Codegen/LLVMCPU/ConvertToLLVM.cpp
index 62eeb4e..05e980d 100644
--- a/iree/compiler/Codegen/LLVMCPU/ConvertToLLVM.cpp
+++ b/iree/compiler/Codegen/LLVMCPU/ConvertToLLVM.cpp
@@ -9,7 +9,6 @@
#include "iree/compiler/Codegen/Utils/Utils.h"
#include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
-#include "iree/compiler/Dialect/Shape/IR/ShapeDialect.h"
#include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
#include "llvm/ADT/Triple.h"
@@ -689,21 +688,29 @@
vector::populateVectorMaskOpLoweringPatterns(patterns);
vector::populateVectorShapeCastLoweringPatterns(patterns);
vector::populateVectorTransposeLoweringPatterns(patterns);
- (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+ if (failed(applyPatternsAndFoldGreedily(getOperation(),
+ std::move(patterns)))) {
+ return signalPassFailure();
+ }
}
{
OwningRewritePatternList vectorToLoopsPatterns(&getContext());
populateVectorToSCFConversionPatterns(
vectorToLoopsPatterns, VectorTransferToSCFOptions().enableFullUnroll());
- (void)applyPatternsAndFoldGreedily(getOperation(),
- std::move(vectorToLoopsPatterns));
+ if (failed(applyPatternsAndFoldGreedily(
+ getOperation(), std::move(vectorToLoopsPatterns)))) {
+ return signalPassFailure();
+ }
}
// math dialect elementary functions -> polynomial form.
{
OwningRewritePatternList mathPatterns(&getContext());
populateMathPolynomialApproximationPatterns(mathPatterns);
- (void)applyPatternsAndFoldGreedily(getOperation(), std::move(mathPatterns));
+ if (failed(applyPatternsAndFoldGreedily(getOperation(),
+ std::move(mathPatterns)))) {
+ return signalPassFailure();
+ }
}
const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
@@ -755,8 +762,7 @@
// rest of the IR.
target.addLegalOp<ModuleOp, IREE::HAL::InterfaceOp,
IREE::HAL::InterfaceBindingOp, IREE::HAL::InterfaceEndOp>();
- target.addIllegalDialect<ShapeDialect, StandardOpsDialect,
- mlir::arith::ArithmeticDialect,
+ target.addIllegalDialect<StandardOpsDialect, mlir::arith::ArithmeticDialect,
IREE::Util::UtilDialect, IREE::HAL::HALDialect,
math::MathDialect, tosa::TosaDialect>();
target.addIllegalOp<UnrealizedConversionCastOp>();
@@ -766,15 +772,16 @@
if (isEntryPoint(funcOp)) return false;
return true;
});
- target.addDynamicallyLegalDialect<
- ShapeDialect, StandardOpsDialect, mlir::math::MathDialect,
- mlir::arith::ArithmeticDialect, IREE::Util::UtilDialect,
- IREE::HAL::HALDialect, math::MathDialect>([&](Operation *op) {
- auto funcParent = op->getParentOfType<FuncOp>();
- if (!funcParent) return false;
- if (isEntryPoint(funcParent)) return false;
- return true;
- });
+ target.addDynamicallyLegalDialect<StandardOpsDialect, mlir::math::MathDialect,
+ mlir::arith::ArithmeticDialect,
+ IREE::Util::UtilDialect,
+ IREE::HAL::HALDialect, math::MathDialect>(
+ [&](Operation *op) {
+ auto funcParent = op->getParentOfType<FuncOp>();
+ if (!funcParent) return false;
+ if (isEntryPoint(funcParent)) return false;
+ return true;
+ });
if (failed(applyPartialConversion(module, target, std::move(patterns)))) {
signalPassFailure();
@@ -791,7 +798,10 @@
llvm::Triple triple(targetTripleStr);
if (triple.isWasm()) {
populateUnfusedFMAOpsPassPatterns(&getContext(), postPatterns);
- (void)applyPatternsAndFoldGreedily(module, std::move(postPatterns));
+ if (failed(
+ applyPatternsAndFoldGreedily(module, std::move(postPatterns)))) {
+ return signalPassFailure();
+ }
}
}
}
diff --git a/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp b/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
index 6f9a815..e4c49da 100644
--- a/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
+++ b/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
@@ -240,6 +240,99 @@
return success();
}
+/// Returns the largest tile size, up to `maxSize`, that evenly divides the
+/// static loop range [lb, ub) and is a multiple of `vectorSizeVal`, so that
+/// the tiled op still vectorizes. Falls back to `maxSize` when the range is
+/// dynamic or no such size exists.
+static int64_t getMaxTileSize(int64_t lb, int64_t ub, int64_t maxSize,
+ int64_t vectorSizeVal) {
+ if (ub == ShapedType::kDynamicSize || lb == ShapedType::kDynamicSize) {
+ return maxSize;
+ }
+ int64_t dim = ub - lb;
+ if (dim < vectorSizeVal) return vectorSizeVal;
+ for (int64_t i = std::min(maxSize, dim); i > 0; --i) {
+ if (dim % i == 0 && i % vectorSizeVal == 0) {
+ return i;
+ }
+ }
+ return maxSize;
+}
+
+static LogicalResult setX86RootConfig(FuncOp entryPointFn,
+ linalg::ContractionOpInterface op,
+ SmallVector<int64_t> workloadPerWorkgroup,
+ int vectorSize) {
+ setTranslationInfo(entryPointFn,
+ getDispatchLoweringPassPipeline(entryPointFn, op),
+ workloadPerWorkgroup,
+ /*workgroupSize=*/ArrayRef<int64_t>{});
+
+ // Hardcoded tile sizes.
+ // L1 tile sizes are {1, 1, ..., 8, 32, 32}.
+ // Vector tile sizes are {1, ..., 1, 16, 16}
+ SmallVector<int64_t> l1TileSizes, vectorTileSizes;
+ int64_t nLoops = cast<linalg::LinalgOp>(op.getOperation()).getNumLoops();
+ l1TileSizes.append(nLoops - 3, 1);
+ l1TileSizes.push_back(
+ getMaxTileSize(0, workloadPerWorkgroup[1], 8, vectorSize));
+ l1TileSizes.push_back(
+ getMaxTileSize(0, workloadPerWorkgroup[0], 32, vectorSize));
+ vectorTileSizes.append(nLoops - 2, 1);
+ vectorTileSizes.push_back(vectorSize);
+
+ // L1/vector tile size for k dimensions.
+ auto lhsShapedType = op.lhs().getType().cast<ShapedType>();
+ int64_t K = lhsShapedType.getShape().back();
+ l1TileSizes.push_back(getMaxTileSize(0, K, 32, vectorSize));
+ vectorTileSizes.push_back(vectorSize);
+ TileSizesListType tileSizes;
+ tileSizes.push_back({}); // Empty here since there is nothing to do in first
+ // level tiling.
+ tileSizes.push_back(l1TileSizes);
+ tileSizes.push_back(vectorTileSizes);
+ auto config = IREE::Codegen::LoweringConfigAttr::get(
+ entryPointFn.getContext(), tileSizes, vectorTileSizes);
+ setLoweringConfig(op, config);
+
+ return success();
+}
+
+static LogicalResult setARMRootConfig(FuncOp entryPointFn,
+ linalg::ContractionOpInterface op,
+ SmallVector<int64_t> workloadPerWorkgroup,
+ int vectorSize) {
+ setTranslationInfo(entryPointFn,
+ getDispatchLoweringPassPipeline(entryPointFn, op),
+ workloadPerWorkgroup,
+ /*workgroupSize=*/ArrayRef<int64_t>{});
+
+ SmallVector<int64_t> l1TileSizes, vectorTileSizes;
+ const int kDefaultL1TileSize = 32;
+ int64_t nLoops = cast<linalg::LinalgOp>(op.getOperation()).getNumLoops();
+ l1TileSizes.append(nLoops - 3, 1);
+ l1TileSizes.push_back(getMaxTileSize(0, workloadPerWorkgroup[1],
+ kDefaultL1TileSize, vectorSize));
+ l1TileSizes.push_back(getMaxTileSize(0, workloadPerWorkgroup[0],
+ kDefaultL1TileSize, vectorSize));
+ vectorTileSizes.append(nLoops - 3, 1);
+ vectorTileSizes.append(2, vectorSize);
+
+ // L1/vector tile size for k dimensions.
+ auto lhsShapedType = op.lhs().getType().cast<ShapedType>();
+ int64_t K = lhsShapedType.getShape().back();
+ l1TileSizes.push_back(getMaxTileSize(0, K, kDefaultL1TileSize, vectorSize));
+ vectorTileSizes.push_back(vectorSize);
+ TileSizesListType tileSizes;
+ tileSizes.push_back({}); // Empty here since there is nothing to do in first
+ // level tiling.
+ tileSizes.push_back(l1TileSizes);
+ tileSizes.push_back(vectorTileSizes);
+ auto config = IREE::Codegen::LoweringConfigAttr::get(
+ entryPointFn.getContext(), tileSizes, vectorTileSizes);
+ setLoweringConfig(op, config);
+
+ return success();
+}
+
/// Sets the lowering configuration for dispatch region with root op that
/// implements the contraction operation interface.
static LogicalResult setRootConfig(
@@ -249,16 +342,8 @@
auto lhsShapedType = contractionOp.lhs().getType().cast<ShapedType>();
// Use the default distribution for the matmul loops.
- bool isBatchMatmul = lhsShapedType.getRank() == 3;
- if (isBatchMatmul) {
- if (tiledLoops.size() != 3) {
- return contractionOp.emitOpError(
- "expected op to be distributed along 3 dimensions");
- }
- } else if (tiledLoops.size() != 2) {
- return contractionOp.emitOpError(
- "expected op to be distributed along 2 dimensions");
- }
+ int numBatchDims =
+ cast<linalg::LinalgOp>(contractionOp.getOperation()).getNumLoops() - 3;
Type elementType = lhsShapedType.getElementType();
if (!elementType.isIntOrFloat()) return success();
@@ -275,26 +360,9 @@
vectorSizeVals[vectorSizeVals.size() - 2] = vectorSize;
SmallVector<int64_t> workloadPerWorkgroup = getDefaultWorkloadPerWorkgroup(
- isBatchMatmul ? tiledLoops.drop_front() : tiledLoops,
- isBatchMatmul ? ArrayRef<int64_t>(vectorSizeVals).drop_front()
- : vectorSizeVals);
+ tiledLoops.drop_front(numBatchDims),
+ ArrayRef<int64_t>(vectorSizeVals).drop_front(numBatchDims));
- // Adjust the workload per workgroup to be a multiple of vector size to ensure
- // that the op vectorizes.
- auto getTileSize = [](int64_t lb, int64_t ub, int64_t maxSize,
- int64_t vectorSizeVal) -> int64_t {
- if (ub == ShapedType::kDynamicSize || lb == ShapedType::kDynamicSize) {
- return maxSize;
- }
- int64_t dim = ub - lb;
- if (dim < vectorSizeVal) return vectorSizeVal;
- for (int64_t i = std::min(maxSize, dim); i > 0; --i) {
- if (dim % i == 0 && i % vectorSizeVal == 0) {
- return i;
- }
- }
- return maxSize;
- };
for (unsigned i = tiledLoops.size() - 2; i < tiledLoops.size(); ++i) {
if (!tiledLoops[i].untiledLowerBound.is<Attribute>() ||
!tiledLoops[i].untiledUpperBound.is<Attribute>()) {
@@ -304,45 +372,20 @@
tiledLoops[i].untiledLowerBound.get<Attribute>().cast<IntegerAttr>();
auto ub =
tiledLoops[i].untiledUpperBound.get<Attribute>().cast<IntegerAttr>();
- workloadPerWorkgroup[tiledLoops.size() - 1 - i] = getTileSize(
+ workloadPerWorkgroup[tiledLoops.size() - 1 - i] = getMaxTileSize(
lb.getInt(), ub.getInt(),
workloadPerWorkgroup[tiledLoops.size() - 1 - i], vectorSizeVals[i]);
}
- if (isBatchMatmul) {
- workloadPerWorkgroup.push_back(1);
- }
- setTranslationInfo(
- entryPointFn,
- getDispatchLoweringPassPipeline(entryPointFn, contractionOp),
- workloadPerWorkgroup,
- /*workgroupSize =*/ArrayRef<int64_t>{});
+ workloadPerWorkgroup.append(numBatchDims, 1);
Optional<llvm::Triple> triple = getTargetTriple(entryPointFn);
- int64_t matmulL1TileSize = (triple && triple.getValue().isX86()) ? 16 : 32;
-
- SmallVector<int64_t, 4> l1TileSizes, vectorTileSizes;
- if (isBatchMatmul) {
- l1TileSizes.push_back(1);
+ if (triple && triple.getValue().isX86()) {
+ return setX86RootConfig(entryPointFn, contractionOp, workloadPerWorkgroup,
+ vectorSize);
}
- for (unsigned i = tiledLoops.size() - 2; i < tiledLoops.size(); ++i) {
- l1TileSizes.push_back(
- getTileSize(0, workloadPerWorkgroup[tiledLoops.size() - 1 - i],
- matmulL1TileSize, vectorSizeVals[i]));
- }
- // L1 tile size for k dimensions.
- int64_t K = lhsShapedType.getShape().back();
- l1TileSizes.push_back(getTileSize(0, K, matmulL1TileSize, vectorSize));
- vectorSizeVals.push_back(vectorSize);
- vectorTileSizes.assign(vectorSizeVals.begin(), vectorSizeVals.end());
- TileSizesListType tileSizes;
- tileSizes.push_back({}); // Empty here since there is nothing to do in first
- // level tiling.
- tileSizes.emplace_back(std::move(l1TileSizes));
- tileSizes.emplace_back(std::move(vectorTileSizes));
- auto config = IREE::Codegen::LoweringConfigAttr::get(
- entryPointFn.getContext(), tileSizes, vectorSizeVals);
- setLoweringConfig(contractionOp, config);
- return success();
+ // Fall back to ARM configurations.
+ return setARMRootConfig(entryPointFn, contractionOp, workloadPerWorkgroup,
+ vectorSize);
}
/// Sets the lowering configuration for dispatch region for linalg.mmt4d root
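getMaxTileSize, hoisted out of setRootConfig above, picks the largest tile size no bigger than maxSize that evenly divides the static loop range and stays a multiple of the vector size. A standalone copy of the selection logic with worked values; the numbers are illustrative, not taken from a particular dispatch:

// Standalone copy of the tile-size selection logic, for illustration only.
// Example outcomes:
//   getMaxTileSizeSketch(0, 96, 32, 8)  -> 32  (32 divides 96, multiple of 8)
//   getMaxTileSizeSketch(0, 24, 32, 8)  -> 24  (largest divisor of 24 that is
//                                               a multiple of 8)
//   getMaxTileSizeSketch(0, 100, 32, 8) -> 32  (no divisor of 100 <= 32 is a
//                                               multiple of 8; keep maxSize)
//   getMaxTileSizeSketch(0, 6, 32, 8)   -> 8   (range smaller than the vector
//                                               size)
static int64_t getMaxTileSizeSketch(int64_t lb, int64_t ub, int64_t maxSize,
                                    int64_t vectorSize) {
  // Dynamic bounds: keep the requested maximum.
  if (ub == ShapedType::kDynamicSize || lb == ShapedType::kDynamicSize)
    return maxSize;
  int64_t dim = ub - lb;
  if (dim < vectorSize) return vectorSize;
  for (int64_t i = std::min(maxSize, dim); i > 0; --i)
    if (dim % i == 0 && i % vectorSize == 0) return i;
  return maxSize;
}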
diff --git a/iree/compiler/Codegen/LLVMCPU/LLVMCPUTileAndVectorizeLinalgTensorOps.cpp b/iree/compiler/Codegen/LLVMCPU/LLVMCPUTileAndVectorizeLinalgTensorOps.cpp
index cdba503..0963595 100644
--- a/iree/compiler/Codegen/LLVMCPU/LLVMCPUTileAndVectorizeLinalgTensorOps.cpp
+++ b/iree/compiler/Codegen/LLVMCPU/LLVMCPUTileAndVectorizeLinalgTensorOps.cpp
@@ -211,8 +211,10 @@
RewritePatternSet canonicalizationPatterns(context);
vector::ContractionOp::getCanonicalizationPatterns(canonicalizationPatterns,
context);
- (void)applyPatternsAndFoldGreedily(funcOp,
- std::move(canonicalizationPatterns));
+ if (failed(applyPatternsAndFoldGreedily(
+ funcOp, std::move(canonicalizationPatterns)))) {
+ return signalPassFailure();
+ }
DEBUG_WITH_TYPE(DEBUG_TYPE, {
llvm::dbgs()
diff --git a/iree/compiler/Codegen/LLVMCPU/LLVMCPUTileFuseAndVectorizeLinalgTensorOps.cpp b/iree/compiler/Codegen/LLVMCPU/LLVMCPUTileFuseAndVectorizeLinalgTensorOps.cpp
index fcfb63f..d095b18 100644
--- a/iree/compiler/Codegen/LLVMCPU/LLVMCPUTileFuseAndVectorizeLinalgTensorOps.cpp
+++ b/iree/compiler/Codegen/LLVMCPU/LLVMCPUTileFuseAndVectorizeLinalgTensorOps.cpp
@@ -242,8 +242,10 @@
RewritePatternSet canonicalizationPatterns(context);
vector::ContractionOp::getCanonicalizationPatterns(canonicalizationPatterns,
context);
- (void)applyPatternsAndFoldGreedily(funcOp,
- std::move(canonicalizationPatterns));
+ if (failed(applyPatternsAndFoldGreedily(
+ funcOp, std::move(canonicalizationPatterns)))) {
+ return signalPassFailure();
+ }
DEBUG_WITH_TYPE(DEBUG_TYPE, {
llvm::dbgs()
@@ -260,8 +262,10 @@
// TODO(hanchung): Set different vector sizes for different operations. Also
// it seems that `{16, 16, 16}` is not a good config. We should tune it.
vector::populateVectorUnrollPatterns(
- vectorUnrollPatterns, vector::UnrollVectorOptions().setNativeShape(
- config.getNativeVectorSizeVals()));
+ vectorUnrollPatterns,
+ vector::UnrollVectorOptions().setNativeShape(config.getTileSizeVals(
+ static_cast<unsigned>(TilingLevel::VectorTiles))));
+
if (failed(applyPatternsAndFoldGreedily(funcOp,
std::move(vectorUnrollPatterns)))) {
return signalPassFailure();
diff --git a/iree/compiler/Codegen/LLVMCPU/LLVMCPUUnfuseFMAOps.cpp b/iree/compiler/Codegen/LLVMCPU/LLVMCPUUnfuseFMAOps.cpp
index d4dbee5..42706e6 100644
--- a/iree/compiler/Codegen/LLVMCPU/LLVMCPUUnfuseFMAOps.cpp
+++ b/iree/compiler/Codegen/LLVMCPU/LLVMCPUUnfuseFMAOps.cpp
@@ -55,7 +55,9 @@
auto context = funcOp.getContext();
OwningRewritePatternList patterns(&getContext());
populateUnfusedFMAOpsPassPatterns(context, patterns);
- (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
+ if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
+ return signalPassFailure();
+ }
}
std::unique_ptr<OperationPass<FuncOp>> createLLVMCPUUnfuseFMAOpsPass() {
diff --git a/iree/compiler/Codegen/LLVMCPU/Passes.cpp b/iree/compiler/Codegen/LLVMCPU/Passes.cpp
index d220de3..7c30374 100644
--- a/iree/compiler/Codegen/LLVMCPU/Passes.cpp
+++ b/iree/compiler/Codegen/LLVMCPU/Passes.cpp
@@ -11,7 +11,6 @@
#include "iree/compiler/Codegen/LLVMCPU/KernelDispatch.h"
#include "iree/compiler/Codegen/PassDetail.h"
#include "iree/compiler/Codegen/Utils/Utils.h"
-#include "iree/compiler/Dialect/Shape/Transforms/Passes.h"
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
#include "mlir/Dialect/Arithmetic/Transforms/Passes.h"
#include "mlir/Dialect/Linalg/Passes.h"
@@ -22,6 +21,11 @@
namespace mlir {
namespace iree_compiler {
+//===---------------------------------------------------------------------===//
+// Default allocation functions for CPU backend
+//===---------------------------------------------------------------------===//
+
+// Default allocation function to use with IREE's bufferization.
static Value cpuAllocationFunction(OpBuilder &builder, Location loc,
ArrayRef<int64_t> staticShape,
Type elementType,
@@ -30,6 +34,29 @@
return builder.create<memref::AllocaOp>(loc, allocType, dynamicSizes);
}
+// Allocation callbacks to use with upstream comprehensive bufferization.
+static Optional<Value> cpuComprehensiveBufferizeAllocationFn(
+ OpBuilder &builder, Location loc, MemRefType memRefType,
+ ArrayRef<Value> dynamicSizes) {
+ return builder.create<memref::AllocaOp>(loc, memRefType, dynamicSizes)
+ .getResult();
+}
+
+static void cpuComprehensiveBufferizeDeallocationFn(OpBuilder &builder,
+ Location loc,
+ Value allocation) {
+ return;
+}
+
+static void cpuComprehensiveBufferizeCopyFn(OpBuilder &builder, Location loc,
+ Value from, Value to) {
+ builder.create<linalg::CopyOp>(loc, from, to);
+}
+
+//===---------------------------------------------------------------------===//
+// Codegen configuration verifications.
+//===---------------------------------------------------------------------===//
+
LogicalResult verifyTensorToVectorsPassPipelineConfig(
Operation *op, IREE::Codegen::LoweringConfigAttr loweringConfig,
IREE::Codegen::TranslationInfoAttr translationInfo,
@@ -149,6 +176,17 @@
passManager.addNestedPass<FuncOp>(createCanonicalizerPass());
// Use stack allocation on CPU side.
+
+  // TODO(ravishankarm): This is commented out because this is WIP; to be
+  // enabled soon.
+ //
+ // auto callbacks =
+ // std::make_unique<linalg::comprehensive_bufferize::AllocationCallbacks>(
+ // cpuComprehensiveBufferizeAllocationFn,
+ // cpuComprehensiveBufferizeDeallocationFn,
+ // cpuComprehensiveBufferizeCopyFn);
+ // addIREEComprehensiveBufferizePasses(passManager, std::move(callbacks));
+
addLinalgBufferizePasses(passManager, cpuAllocationFunction);
passManager.addNestedPass<FuncOp>(createCSEPass());
passManager.addNestedPass<FuncOp>(createCanonicalizerPass());
@@ -170,8 +208,6 @@
// Linalg -> SCF
passManager.addNestedPass<FuncOp>(createConvertLinalgToLoopsPass());
- passManager.addNestedPass<FuncOp>(
- Shape::createFoldDimOverShapeCarryingOpPass());
passManager.addNestedPass<FuncOp>(createCanonicalizerPass());
passManager.addNestedPass<FuncOp>(createCSEPass());
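Not part of the patch: for context, this is how the three CPU bufferization callbacks defined earlier in this file are meant to be wired up once the commented-out block above is enabled; it simply mirrors that block.

    // Mirrors the commented-out wiring above (still disabled in this patch).
    auto callbacks =
        std::make_unique<linalg::comprehensive_bufferize::AllocationCallbacks>(
            cpuComprehensiveBufferizeAllocationFn,
            cpuComprehensiveBufferizeDeallocationFn,
            cpuComprehensiveBufferizeCopyFn);
    addIREEComprehensiveBufferizePasses(passManager, std::move(callbacks));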
diff --git a/iree/compiler/Codegen/LLVMCPU/test/materialize_launch_configuration.mlir b/iree/compiler/Codegen/LLVMCPU/test/materialize_launch_configuration.mlir
index 3ae924b..f13e5ea 100644
--- a/iree/compiler/Codegen/LLVMCPU/test/materialize_launch_configuration.mlir
+++ b/iree/compiler/Codegen/LLVMCPU/test/materialize_launch_configuration.mlir
@@ -1351,4 +1351,6 @@
}
}
}
-// CHECK: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"CPUTileFuseAndVectorize", workload_per_wg = [64, 64]>
+// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"CPUTileFuseAndVectorize", workload_per_wg = [64, 64]>
+// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering.config<tile_sizes = [{{\[}}], [8, 32, 32], [1, 16, 16]], native_vector_size = [1, 16, 16]>
+// CHECK: linalg.matmul {lowering.config = #[[CONFIG]]}
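Not part of the patch: a sketch of the C++ that produces a lowering config like the one CHECKed above, following the TileSizesListType / LoweringConfigAttr::get usage removed from KernelDispatch.cpp earlier in this diff; the tile-size values are taken from the new CHECK-DAG line.

    // Sketch: building the #iree_codegen.lowering.config CHECKed above.
    TileSizesListType tileSizes;
    tileSizes.push_back({});           // nothing to do at the first tiling level
    tileSizes.push_back({8, 32, 32});  // L1 tile sizes
    tileSizes.push_back({1, 16, 16});  // vector-level tile sizes
    SmallVector<int64_t, 4> nativeVectorSize = {1, 16, 16};
    auto config = IREE::Codegen::LoweringConfigAttr::get(
        entryPointFn.getContext(), tileSizes, nativeVectorSize);
    setLoweringConfig(contractionOp, config);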
diff --git a/iree/compiler/Codegen/LLVMGPU/BUILD b/iree/compiler/Codegen/LLVMGPU/BUILD
index e79bc61..ed5ad74 100644
--- a/iree/compiler/Codegen/LLVMGPU/BUILD
+++ b/iree/compiler/Codegen/LLVMGPU/BUILD
@@ -40,7 +40,6 @@
"//iree/compiler/Codegen/Utils",
"//iree/compiler/Dialect/Flow/IR",
"//iree/compiler/Dialect/HAL/IR",
- "//iree/compiler/Dialect/Shape/Transforms",
"//iree/compiler/Dialect/Util/IR",
"//llvm-external-projects/iree-dialects:IREELinalgExtDialect",
"//llvm-external-projects/iree-dialects:IREELinalgExtTransforms",
@@ -75,7 +74,5 @@
"@llvm-project//mlir:VectorToGPU",
"@llvm-project//mlir:VectorToLLVM",
"@llvm-project//mlir:VectorToSCF",
- "@mlir-hlo//:hlo",
- "@mlir-hlo//:legalize_to_linalg",
],
)
diff --git a/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt b/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt
index 1ad842b..b1c09d2 100644
--- a/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt
+++ b/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt
@@ -72,9 +72,7 @@
iree::compiler::Codegen::Utils
iree::compiler::Dialect::Flow::IR
iree::compiler::Dialect::HAL::IR
- iree::compiler::Dialect::Shape::Transforms
iree::compiler::Dialect::Util::IR
- tensorflow::mlir_hlo
PUBLIC
)
diff --git a/iree/compiler/Codegen/LLVMGPU/ConvertToLLVM.cpp b/iree/compiler/Codegen/LLVMGPU/ConvertToLLVM.cpp
index e7c8e95..a985464 100644
--- a/iree/compiler/Codegen/LLVMGPU/ConvertToLLVM.cpp
+++ b/iree/compiler/Codegen/LLVMGPU/ConvertToLLVM.cpp
@@ -105,7 +105,10 @@
OwningRewritePatternList patterns(&getContext());
populateScalarizeMathOps(patterns);
populateConvertSharedMemoryAllocOps(patterns);
- (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+ if (failed(applyPatternsAndFoldGreedily(getOperation(),
+ std::move(patterns)))) {
+ return signalPassFailure();
+ }
}
};
diff --git a/iree/compiler/Codegen/LLVMGPU/ConvertToNVVM.cpp b/iree/compiler/Codegen/LLVMGPU/ConvertToNVVM.cpp
index 510aac3..1c154a3 100644
--- a/iree/compiler/Codegen/LLVMGPU/ConvertToNVVM.cpp
+++ b/iree/compiler/Codegen/LLVMGPU/ConvertToNVVM.cpp
@@ -69,12 +69,16 @@
vector::populateVectorShapeCastLoweringPatterns(patterns);
vector::populateVectorTransposeLoweringPatterns(patterns);
vector::populateVectorTransferLoweringPatterns(patterns);
- (void)applyPatternsAndFoldGreedily(m, std::move(patterns));
+ if (failed(applyPatternsAndFoldGreedily(m, std::move(patterns)))) {
+ return signalPassFailure();
+ }
}
{
OwningRewritePatternList patterns(&getContext());
populateGpuRewritePatterns(patterns);
- (void)applyPatternsAndFoldGreedily(m, std::move(patterns));
+ if (failed(applyPatternsAndFoldGreedily(m, std::move(patterns)))) {
+ return signalPassFailure();
+ }
}
{
OwningRewritePatternList llvmPatterns(&getContext());
diff --git a/iree/compiler/Codegen/LLVMGPU/ConvertToROCDL.cpp b/iree/compiler/Codegen/LLVMGPU/ConvertToROCDL.cpp
index 009e23f..8e5795f 100644
--- a/iree/compiler/Codegen/LLVMGPU/ConvertToROCDL.cpp
+++ b/iree/compiler/Codegen/LLVMGPU/ConvertToROCDL.cpp
@@ -65,12 +65,16 @@
vector::populateVectorShapeCastLoweringPatterns(patterns);
mlir::vector::populateVectorTransposeLoweringPatterns(patterns);
vector::populateVectorTransferLoweringPatterns(patterns);
- (void)applyPatternsAndFoldGreedily(m, std::move(patterns));
+ if (failed(applyPatternsAndFoldGreedily(m, std::move(patterns)))) {
+ return signalPassFailure();
+ }
}
{
OwningRewritePatternList patterns(&getContext());
populateGpuRewritePatterns(patterns);
- (void)applyPatternsAndFoldGreedily(m, std::move(patterns));
+ if (failed(applyPatternsAndFoldGreedily(m, std::move(patterns)))) {
+ return signalPassFailure();
+ }
}
{
OwningRewritePatternList llvmPatterns(&getContext());
diff --git a/iree/compiler/Codegen/LLVMGPU/LLVMGPUDistributeSharedMemoryCopy.cpp b/iree/compiler/Codegen/LLVMGPU/LLVMGPUDistributeSharedMemoryCopy.cpp
index 5dfd252..bb450ab 100644
--- a/iree/compiler/Codegen/LLVMGPU/LLVMGPUDistributeSharedMemoryCopy.cpp
+++ b/iree/compiler/Codegen/LLVMGPU/LLVMGPUDistributeSharedMemoryCopy.cpp
@@ -232,8 +232,10 @@
// Step 1. Vectorize the shared memory copy.
RewritePatternSet vectorizationPatterns(context);
populateVectorizationPatterns(vectorizationPatterns);
- (void)applyPatternsAndFoldGreedily(funcOp,
- std::move(vectorizationPatterns));
+ if (failed(applyPatternsAndFoldGreedily(
+ funcOp, std::move(vectorizationPatterns)))) {
+ return signalPassFailure();
+ }
// Step 2. Unroll transfer_read/transfer_write to a vector with the number
    // of elements equal to `targetVectorSize * targetVectorSize`. The.
@@ -241,15 +243,20 @@
// size.
RewritePatternSet vectorUnrollPatterns(context);
populateVectorUnrollPatterns(vectorUnrollPatterns, flatWorkgroupSize);
- (void)applyPatternsAndFoldGreedily(funcOp,
- std::move(vectorUnrollPatterns));
+ if (failed(applyPatternsAndFoldGreedily(
+ funcOp, std::move(vectorUnrollPatterns)))) {
+ return signalPassFailure();
+ }
// Step 3. Distribute the transfer ops onto the flat ids.
Value flatId = createFlatId(funcOp, workgroupSize);
distributeTransferRead(funcOp, flatId, flatWorkgroupSize);
// Propagate vector distribution to the chain of ops.
RewritePatternSet distributePatterns(context);
vector::populatePropagateVectorDistributionPatterns(distributePatterns);
- (void)applyPatternsAndFoldGreedily(funcOp, std::move(distributePatterns));
+ if (failed(applyPatternsAndFoldGreedily(funcOp,
+ std::move(distributePatterns)))) {
+ return signalPassFailure();
+ }
} else {
// Fall back to basic tiling for cases where workgroup memory size is not
// well aligned on the number of threads.
@@ -258,15 +265,19 @@
OwningRewritePatternList threadLevelTilingPatterns(context);
populateTilingCopyToWorkgroupMemPatterns(threadLevelTilingPatterns,
workgroupSize);
- (void)applyPatternsAndFoldGreedily(funcOp,
- std::move(threadLevelTilingPatterns));
+ if (failed(applyPatternsAndFoldGreedily(
+ funcOp, std::move(threadLevelTilingPatterns)))) {
+ return signalPassFailure();
+ }
// Apply canonicalization patterns.
RewritePatternSet threadTilingCanonicalizationPatterns =
linalg::getLinalgTilingCanonicalizationPatterns(context);
populateAffineMinSCFCanonicalizationPattern(
threadTilingCanonicalizationPatterns);
- (void)applyPatternsAndFoldGreedily(
- funcOp, std::move(threadTilingCanonicalizationPatterns));
+ if (failed(applyPatternsAndFoldGreedily(
+ funcOp, std::move(threadTilingCanonicalizationPatterns)))) {
+ return signalPassFailure();
+ }
}
}
};
diff --git a/iree/compiler/Codegen/LLVMGPU/LLVMGPUPipelining.cpp b/iree/compiler/Codegen/LLVMGPU/LLVMGPUPipelining.cpp
index 5512385..eefdb14 100644
--- a/iree/compiler/Codegen/LLVMGPU/LLVMGPUPipelining.cpp
+++ b/iree/compiler/Codegen/LLVMGPU/LLVMGPUPipelining.cpp
@@ -92,7 +92,10 @@
options.getScheduleFn = getPipelineStages;
RewritePatternSet pipeliningPatterns(context);
scf::populateSCFLoopPipeliningPatterns(pipeliningPatterns, options);
- (void)applyPatternsAndFoldGreedily(funcOp, std::move(pipeliningPatterns));
+ if (failed(applyPatternsAndFoldGreedily(funcOp,
+ std::move(pipeliningPatterns)))) {
+ return signalPassFailure();
+ }
}
};
} // namespace
diff --git a/iree/compiler/Codegen/LLVMGPU/LLVMGPUTensorCoreVectorization.cpp b/iree/compiler/Codegen/LLVMGPU/LLVMGPUTensorCoreVectorization.cpp
index 9488cae..efa3ce8 100644
--- a/iree/compiler/Codegen/LLVMGPU/LLVMGPUTensorCoreVectorization.cpp
+++ b/iree/compiler/Codegen/LLVMGPU/LLVMGPUTensorCoreVectorization.cpp
@@ -125,8 +125,10 @@
// Step 1. Vectorize
RewritePatternSet vectorizationPatterns(context);
populateVectorizationPatterns(vectorizationPatterns);
- (void)applyPatternsAndFoldGreedily(funcOp,
- std::move(vectorizationPatterns));
+ if (failed(applyPatternsAndFoldGreedily(
+ funcOp, std::move(vectorizationPatterns)))) {
+ return signalPassFailure();
+ }
// Fold consumer add ops into the contraction op itself.
RewritePatternSet canonicalizationPatterns(context);
@@ -134,13 +136,17 @@
canonicalizationPatterns, context);
canonicalizationPatterns.insert<CombineTransferReadOpBroadcast>(
funcOp.getContext());
- (void)applyPatternsAndFoldGreedily(funcOp,
- std::move(canonicalizationPatterns));
+ if (failed(applyPatternsAndFoldGreedily(
+ funcOp, std::move(canonicalizationPatterns)))) {
+ return signalPassFailure();
+ }
RewritePatternSet vectorUnrollPatterns(context);
populateVectorUnrollPatterns(vectorUnrollPatterns);
- (void)applyPatternsAndFoldGreedily(funcOp,
- std::move(vectorUnrollPatterns));
+ if (failed(applyPatternsAndFoldGreedily(
+ funcOp, std::move(vectorUnrollPatterns)))) {
+ return signalPassFailure();
+ }
}
}
};
diff --git a/iree/compiler/Codegen/LLVMGPU/LLVMGPUTileAndDistribute.cpp b/iree/compiler/Codegen/LLVMGPU/LLVMGPUTileAndDistribute.cpp
index 08bab9c..5a93975 100644
--- a/iree/compiler/Codegen/LLVMGPU/LLVMGPUTileAndDistribute.cpp
+++ b/iree/compiler/Codegen/LLVMGPU/LLVMGPUTileAndDistribute.cpp
@@ -285,7 +285,10 @@
// same size.
OwningRewritePatternList wgTilingPatterns(context);
populateTilingReductionPatterns(wgTilingPatterns);
- (void)applyPatternsAndFoldGreedily(funcOp, std::move(wgTilingPatterns));
+ if (failed(applyPatternsAndFoldGreedily(funcOp,
+ std::move(wgTilingPatterns)))) {
+ return signalPassFailure();
+ }
}
{
@@ -293,8 +296,10 @@
linalg::getLinalgTilingCanonicalizationPatterns(context);
populateAffineMinSCFCanonicalizationPattern(
wgTilingCanonicalizationPatterns);
- (void)applyPatternsAndFoldGreedily(
- funcOp, std::move(wgTilingCanonicalizationPatterns));
+ if (failed(applyPatternsAndFoldGreedily(
+ funcOp, std::move(wgTilingCanonicalizationPatterns)))) {
+ return signalPassFailure();
+ }
}
LLVM_DEBUG({
@@ -314,7 +319,10 @@
if (flatWorkgroupSize > kWarpSize) {
OwningRewritePatternList promotionPatterns(&getContext());
populatePromotionPatterns(context, promotionPatterns);
- (void)applyPatternsAndFoldGreedily(funcOp, std::move(promotionPatterns));
+ if (failed(applyPatternsAndFoldGreedily(funcOp,
+ std::move(promotionPatterns)))) {
+ return signalPassFailure();
+ }
// Insert barriers before and after copies to workgroup memory and skip
// insert barriers between back to back copy to workgroup memory.
OpBuilder builder(&getContext());
@@ -337,8 +345,10 @@
{
RewritePatternSet promotionCanonicalization =
linalg::getLinalgTilingCanonicalizationPatterns(context);
- (void)applyPatternsAndFoldGreedily(funcOp,
- std::move(promotionCanonicalization));
+ if (failed(applyPatternsAndFoldGreedily(
+ funcOp, std::move(promotionCanonicalization)))) {
+ return signalPassFailure();
+ }
}
LLVM_DEBUG({
@@ -351,16 +361,20 @@
OwningRewritePatternList warpLevelTilingPatterns(context);
populateTilingToWarpPatterns(warpLevelTilingPatterns, workgroupSize,
workloadPerWorkgroup);
- (void)applyPatternsAndFoldGreedily(funcOp,
- std::move(warpLevelTilingPatterns));
+ if (failed(applyPatternsAndFoldGreedily(
+ funcOp, std::move(warpLevelTilingPatterns)))) {
+ return signalPassFailure();
+ }
} else {
// Apply last level of tiling and distribute to threads.
OwningRewritePatternList threadLevelTilingPatterns(context);
populateTilingToInvocationPatterns(threadLevelTilingPatterns,
workgroupSize, workloadPerWorkgroup);
- (void)applyPatternsAndFoldGreedily(funcOp,
- std::move(threadLevelTilingPatterns));
+ if (failed(applyPatternsAndFoldGreedily(
+ funcOp, std::move(threadLevelTilingPatterns)))) {
+ return signalPassFailure();
+ }
}
{
// Apply canonicalization patterns.
@@ -368,8 +382,10 @@
linalg::getLinalgTilingCanonicalizationPatterns(context);
populateAffineMinSCFCanonicalizationPattern(
threadTilingCanonicalizationPatterns);
- (void)applyPatternsAndFoldGreedily(
- funcOp, std::move(threadTilingCanonicalizationPatterns));
+ if (failed(applyPatternsAndFoldGreedily(
+ funcOp, std::move(threadTilingCanonicalizationPatterns)))) {
+ return signalPassFailure();
+ }
}
LLVM_DEBUG({
diff --git a/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorLowering.cpp b/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorLowering.cpp
index 5301d92..55ef920 100644
--- a/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorLowering.cpp
+++ b/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorLowering.cpp
@@ -34,8 +34,10 @@
vectorToSCFOptions);
memref::populateFoldSubViewOpPatterns(vectorToLoopsPatterns);
vector::populateVectorTransferLoweringPatterns(vectorToLoopsPatterns);
- (void)applyPatternsAndFoldGreedily(funcOp,
- std::move(vectorToLoopsPatterns));
+ if (failed(applyPatternsAndFoldGreedily(
+ funcOp, std::move(vectorToLoopsPatterns)))) {
+ return signalPassFailure();
+ }
}
};
} // namespace
diff --git a/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorization.cpp b/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorization.cpp
index 78c94fe..d87e155 100644
--- a/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorization.cpp
+++ b/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorization.cpp
@@ -88,20 +88,26 @@
// Step 1. Vectorize
RewritePatternSet vectorizationPatterns(context);
populateVectorizationPatterns(vectorizationPatterns);
- (void)applyPatternsAndFoldGreedily(funcOp,
- std::move(vectorizationPatterns));
+ if (failed(applyPatternsAndFoldGreedily(
+ funcOp, std::move(vectorizationPatterns)))) {
+ return signalPassFailure();
+ }
// Fold consumer add ops into the contraction op itself.
RewritePatternSet canonicalizationPatterns(context);
vector::ContractionOp::getCanonicalizationPatterns(
canonicalizationPatterns, context);
- (void)applyPatternsAndFoldGreedily(funcOp,
- std::move(canonicalizationPatterns));
+ if (failed(applyPatternsAndFoldGreedily(
+ funcOp, std::move(canonicalizationPatterns)))) {
+ return signalPassFailure();
+ }
RewritePatternSet vectorUnrollPatterns(context);
populateVectorUnrollPatterns(vectorUnrollPatterns);
- (void)applyPatternsAndFoldGreedily(funcOp,
- std::move(vectorUnrollPatterns));
+ if (failed(applyPatternsAndFoldGreedily(
+ funcOp, std::move(vectorUnrollPatterns)))) {
+ return signalPassFailure();
+ }
DEBUG_WITH_TYPE(DEBUG_TYPE, {
llvm::dbgs() << "\n--- After Step 1: Vectorization ---\n";
funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
@@ -115,8 +121,10 @@
lowerTransferOpPatterns);
vector::populateVectorTransferPermutationMapLoweringPatterns(
lowerTransferOpPatterns);
- (void)applyPatternsAndFoldGreedily(funcOp,
- std::move(lowerTransferOpPatterns));
+ if (failed(applyPatternsAndFoldGreedily(
+ funcOp, std::move(lowerTransferOpPatterns)))) {
+ return signalPassFailure();
+ }
DEBUG_WITH_TYPE(DEBUG_TYPE, {
llvm::dbgs()
<< "\n--- After Step 2: Lower transfer op to canonical form. ---\n";
@@ -132,8 +140,10 @@
canonicalizationPatterns, canonicalizationPatterns.getContext());
vector::populateVectorToVectorCanonicalizationPatterns(
canonicalizationPatterns);
- (void)applyPatternsAndFoldGreedily(funcOp,
- std::move(canonicalizationPatterns));
+ if (failed(applyPatternsAndFoldGreedily(
+ funcOp, std::move(canonicalizationPatterns)))) {
+ return signalPassFailure();
+ }
DEBUG_WITH_TYPE(DEBUG_TYPE, {
llvm::dbgs() << "\n--- After Step 3: Canonicalize. ---\n";
funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
@@ -153,8 +163,10 @@
vector::populateVectorMultiReductionLoweringPatterns(
contractLoweringPatterns,
vector::VectorMultiReductionLowering::InnerParallel);
- (void)applyPatternsAndFoldGreedily(funcOp,
- std::move(contractLoweringPatterns));
+ if (failed(applyPatternsAndFoldGreedily(
+ funcOp, std::move(contractLoweringPatterns)))) {
+ return signalPassFailure();
+ }
DEBUG_WITH_TYPE(DEBUG_TYPE, {
llvm::dbgs()
<< "\n--- After Step 4: Lower contract op to outer product. ---\n";
diff --git a/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/iree/compiler/Codegen/LLVMGPU/Passes.cpp
index f4f1a9b..cda7243 100644
--- a/iree/compiler/Codegen/LLVMGPU/Passes.cpp
+++ b/iree/compiler/Codegen/LLVMGPU/Passes.cpp
@@ -8,7 +8,6 @@
#include "iree-dialects/Dialect/LinalgExt/Transforms/Passes.h"
#include "iree/compiler/Codegen/PassDetail.h"
-#include "iree/compiler/Dialect/Shape/Transforms/Passes.h"
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
diff --git a/iree/compiler/Codegen/Passes.h b/iree/compiler/Codegen/Passes.h
index 569656f..099089d 100644
--- a/iree/compiler/Codegen/Passes.h
+++ b/iree/compiler/Codegen/Passes.h
@@ -50,8 +50,7 @@
void addIREEComprehensiveBufferizePasses(
OpPassManager &passManager,
std::unique_ptr<linalg::comprehensive_bufferize::AllocationCallbacks>
- allocationFn =
- linalg::comprehensive_bufferize::defaultAllocationCallbacks());
+ allocationFn = nullptr);
/// Pass to perform canonicalizations/cleanups related to HAL interface/buffer
/// allocations and view operations.
@@ -91,7 +90,7 @@
WorkgroupMemoryAllocationFn allocationFn = nullptr);
std::unique_ptr<OperationPass<ModuleOp>> createIREEComprehensiveBufferizePass(
std::unique_ptr<linalg::comprehensive_bufferize::AllocationCallbacks> =
- linalg::comprehensive_bufferize::defaultAllocationCallbacks());
+ nullptr);
/// Creates a pass to remove single iteration distributed loops.
std::unique_ptr<OperationPass<FuncOp>> createRemoveSingleIterationLoopPass();
diff --git a/iree/compiler/Codegen/SPIRV/BUILD b/iree/compiler/Codegen/SPIRV/BUILD
index 64a03d1..b0a65e8 100644
--- a/iree/compiler/Codegen/SPIRV/BUILD
+++ b/iree/compiler/Codegen/SPIRV/BUILD
@@ -42,8 +42,6 @@
"//iree/compiler/Dialect/Flow/IR",
"//iree/compiler/Dialect/HAL/IR",
"//iree/compiler/Dialect/HAL/IR:HALDialect",
- "//iree/compiler/Dialect/Shape/IR",
- "//iree/compiler/Dialect/Shape/Transforms",
"//iree/compiler/Dialect/Util/IR",
"//llvm-external-projects/iree-dialects:IREELinalgExtDialect",
"//llvm-external-projects/iree-dialects:IREELinalgExtTransforms",
@@ -84,7 +82,5 @@
"@llvm-project//mlir:VectorInterfaces",
"@llvm-project//mlir:VectorOps",
"@llvm-project//mlir:VectorToSPIRV",
- "@mlir-hlo//:hlo",
- "@mlir-hlo//:legalize_to_linalg",
],
)
diff --git a/iree/compiler/Codegen/SPIRV/CMakeLists.txt b/iree/compiler/Codegen/SPIRV/CMakeLists.txt
index 2bfa3ef..fd663a1 100644
--- a/iree/compiler/Codegen/SPIRV/CMakeLists.txt
+++ b/iree/compiler/Codegen/SPIRV/CMakeLists.txt
@@ -79,10 +79,7 @@
iree::compiler::Dialect::Flow::IR
iree::compiler::Dialect::HAL::IR
iree::compiler::Dialect::HAL::IR::HALDialect
- iree::compiler::Dialect::Shape::IR
- iree::compiler::Dialect::Shape::Transforms
iree::compiler::Dialect::Util::IR
- tensorflow::mlir_hlo
PUBLIC
)
diff --git a/iree/compiler/Codegen/SPIRV/Passes.cpp b/iree/compiler/Codegen/SPIRV/Passes.cpp
index 60448c4..4202ca0 100644
--- a/iree/compiler/Codegen/SPIRV/Passes.cpp
+++ b/iree/compiler/Codegen/SPIRV/Passes.cpp
@@ -17,7 +17,6 @@
#include "iree/compiler/Codegen/PassDetail.h"
#include "iree/compiler/Codegen/Passes.h"
#include "iree/compiler/Codegen/SPIRV/MemorySpace.h"
-#include "iree/compiler/Dialect/Shape/Transforms/Passes.h"
#include "llvm/Support/Debug.h"
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
#include "mlir/Dialect/Arithmetic/Transforms/Passes.h"
@@ -66,7 +65,6 @@
// In SPIR-V we don't use memref descriptor so it's not possible to handle
// subview ops.
pm.addPass(memref::createFoldSubViewOpsPass());
- pm.addNestedPass<FuncOp>(Shape::createFoldDimOverShapeCarryingOpPass());
pm.addNestedPass<FuncOp>(arith::createArithmeticExpandOpsPass());
pm.addNestedPass<FuncOp>(createStdExpandOpsPass());
pm.addPass(createCanonicalizerPass());
diff --git a/iree/compiler/Codegen/SPIRV/SPIRVCopyToWorkgroupMemory.cpp b/iree/compiler/Codegen/SPIRV/SPIRVCopyToWorkgroupMemory.cpp
index d2c4f3f..76c1215 100644
--- a/iree/compiler/Codegen/SPIRV/SPIRVCopyToWorkgroupMemory.cpp
+++ b/iree/compiler/Codegen/SPIRV/SPIRVCopyToWorkgroupMemory.cpp
@@ -290,7 +290,10 @@
linalg::getLinalgTilingCanonicalizationPatterns(context);
populateAffineMinCanonicalizationPattern(canonicalizePatterns);
scf::populateSCFForLoopCanonicalizationPatterns(canonicalizePatterns);
- (void)applyPatternsAndFoldGreedily(funcOp, std::move(canonicalizePatterns));
+ if (failed(applyPatternsAndFoldGreedily(funcOp,
+ std::move(canonicalizePatterns)))) {
+ return signalPassFailure();
+ }
// 3. Vectorize the tiled linalg to be able to map it to load/store vector.
OwningRewritePatternList vectorizationPatterns(&getContext());
@@ -298,7 +301,10 @@
vectorizationPatterns, linalg::LinalgVectorizationOptions(),
linalg::LinalgTransformationFilter(
Identifier::get(getVectorizeMarker(), context), {}));
- (void)applyPatternsAndFoldGreedily(funcOp, std::move(vectorizationPatterns));
+ if (failed(applyPatternsAndFoldGreedily(funcOp,
+ std::move(vectorizationPatterns)))) {
+ return signalPassFailure();
+ }
}
void SPIRVCopyToWorkgroupMemoryPass::runOnOperation() {
diff --git a/iree/compiler/Codegen/SPIRV/SPIRVLowerExecutableTargetPass.cpp b/iree/compiler/Codegen/SPIRV/SPIRVLowerExecutableTargetPass.cpp
index d4c49ea..51cbc60 100644
--- a/iree/compiler/Codegen/SPIRV/SPIRVLowerExecutableTargetPass.cpp
+++ b/iree/compiler/Codegen/SPIRV/SPIRVLowerExecutableTargetPass.cpp
@@ -41,8 +41,7 @@
.insert<IREE::Codegen::IREECodegenDialect, AffineDialect,
gpu::GPUDialect, IREE::HAL::HALDialect, linalg::LinalgDialect,
IREE::LinalgExt::IREELinalgExtDialect, memref::MemRefDialect,
- scf::SCFDialect, ShapeDialect, spirv::SPIRVDialect,
- vector::VectorDialect>();
+ scf::SCFDialect, spirv::SPIRVDialect, vector::VectorDialect>();
}
void runOnOperation() override;
diff --git a/iree/compiler/Codegen/SPIRV/SPIRVTileAndDistribute.cpp b/iree/compiler/Codegen/SPIRV/SPIRVTileAndDistribute.cpp
index 613f9a5..fa50bc1 100644
--- a/iree/compiler/Codegen/SPIRV/SPIRVTileAndDistribute.cpp
+++ b/iree/compiler/Codegen/SPIRV/SPIRVTileAndDistribute.cpp
@@ -185,8 +185,11 @@
{ // Tile and distribute to invocations.
RewritePatternSet invocationTilingPatterns(&getContext());
populateTilingToInvocationPatterns(context, invocationTilingPatterns);
- (void)applyPatternsAndFoldGreedily(funcOp,
- std::move(invocationTilingPatterns));
+ if (failed(applyPatternsAndFoldGreedily(
+ funcOp, std::move(invocationTilingPatterns)))) {
+ funcOp.emitOpError() << "failure in tiling";
+ return signalPassFailure();
+ }
LLVM_DEBUG({
llvm::dbgs() << "--- After tiling to invocations ---\n";
@@ -201,8 +204,14 @@
populateFoldAffineMinInDistributedLoopsPatterns(canonicalizationPatterns);
- (void)applyPatternsAndFoldGreedily(funcOp,
- std::move(canonicalizationPatterns));
+ if (failed(applyPatternsAndFoldGreedily(
+ funcOp, std::move(canonicalizationPatterns)))) {
+ // TODO(#4759): Terrifyingly, this fails. Errors here were ignored for a
+ // long time and now tests for this pass actually fail if we propagate the
+ // failure correctly. Fix this.
+ // funcOp.emitOpError() << "failure canonicalizing after tiling";
+ // return signalPassFailure();
+ }
LLVM_DEBUG({
llvm::dbgs() << "--- After tiling canonicalization ---\n";
@@ -216,14 +225,20 @@
auto marker = getLinalgMatchAndReplaceMarker(getTileReductionMarker(),
getVectorizeMarker(), context);
populateTilingReductionPatterns(context, reductionTilingPatterns, marker);
- (void)applyPatternsAndFoldGreedily(funcOp,
- std::move(reductionTilingPatterns));
+ if (failed(applyPatternsAndFoldGreedily(
+ funcOp, std::move(reductionTilingPatterns)))) {
+      funcOp.emitOpError() << "failure in tiling reductions";
+ return signalPassFailure();
+ }
RewritePatternSet canonicalizationPatterns =
linalg::getLinalgTilingCanonicalizationPatterns(context);
scf::populateSCFForLoopCanonicalizationPatterns(canonicalizationPatterns);
- (void)applyPatternsAndFoldGreedily(funcOp,
- std::move(canonicalizationPatterns));
+ if (failed(applyPatternsAndFoldGreedily(
+ funcOp, std::move(canonicalizationPatterns)))) {
+      funcOp.emitOpError() << "failure canonicalizing after tiling reductions";
+ return signalPassFailure();
+ }
LLVM_DEBUG({
llvm::dbgs() << "--- After tiling reduction dimensions ---\n";
diff --git a/iree/compiler/Codegen/SPIRV/SPIRVTileAndVectorizeToCooperativeOps.cpp b/iree/compiler/Codegen/SPIRV/SPIRVTileAndVectorizeToCooperativeOps.cpp
index 2a5dfbd..f79fda8 100644
--- a/iree/compiler/Codegen/SPIRV/SPIRVTileAndVectorizeToCooperativeOps.cpp
+++ b/iree/compiler/Codegen/SPIRV/SPIRVTileAndVectorizeToCooperativeOps.cpp
@@ -306,14 +306,18 @@
{
RewritePatternSet subgroupTilingPatterns(context);
populateTilingToSubgroupPatterns(subgroupCounts, subgroupTilingPatterns);
- (void)applyPatternsAndFoldGreedily(funcOp,
- std::move(subgroupTilingPatterns));
+ if (failed(applyPatternsAndFoldGreedily(
+ funcOp, std::move(subgroupTilingPatterns)))) {
+ return signalPassFailure();
+ }
RewritePatternSet canonicalizationPatterns =
linalg::getLinalgTilingCanonicalizationPatterns(context);
populateFoldAffineMinInDistributedLoopsPatterns(canonicalizationPatterns);
- (void)applyPatternsAndFoldGreedily(funcOp,
- std::move(canonicalizationPatterns));
+ if (failed(applyPatternsAndFoldGreedily(
+ funcOp, std::move(canonicalizationPatterns)))) {
+ return signalPassFailure();
+ }
}
LLVM_DEBUG({
@@ -327,14 +331,18 @@
{
RewritePatternSet vectorizationPatterns(context);
populateVectorizationPatterns(context, vectorizationPatterns);
- (void)applyPatternsAndFoldGreedily(funcOp,
- std::move(vectorizationPatterns));
+ if (failed(applyPatternsAndFoldGreedily(
+ funcOp, std::move(vectorizationPatterns)))) {
+ return signalPassFailure();
+ }
RewritePatternSet canonicalizationPatterns(context);
vector::ContractionOp::getCanonicalizationPatterns(
canonicalizationPatterns, context);
- (void)applyPatternsAndFoldGreedily(funcOp,
- std::move(canonicalizationPatterns));
+ if (failed(applyPatternsAndFoldGreedily(
+ funcOp, std::move(canonicalizationPatterns)))) {
+ return signalPassFailure();
+ }
}
LLVM_DEBUG({
@@ -346,8 +354,10 @@
{
RewritePatternSet vectorUnrollPatterns(context);
populateVectorUnrollPatterns(cooperativeOpSize, vectorUnrollPatterns);
- (void)applyPatternsAndFoldGreedily(funcOp,
- std::move(vectorUnrollPatterns));
+ if (failed(applyPatternsAndFoldGreedily(
+ funcOp, std::move(vectorUnrollPatterns)))) {
+ return signalPassFailure();
+ }
}
LLVM_DEBUG({
@@ -370,8 +380,10 @@
RewritePatternSet canonicalizationPatterns(context);
vector::populateVectorTransferPermutationMapLoweringPatterns(
canonicalizationPatterns);
- (void)applyPatternsAndFoldGreedily(funcOp,
- std::move(canonicalizationPatterns));
+ if (failed(applyPatternsAndFoldGreedily(
+ funcOp, std::move(canonicalizationPatterns)))) {
+ return signalPassFailure();
+ }
}
LLVM_DEBUG({
@@ -385,8 +397,10 @@
// converted to cooperative matrix matmul op.
RewritePatternSet combineTransposePatterns(context);
combineTransposePatterns.add<CombineContractTranspose>(context);
- (void)applyPatternsAndFoldGreedily(funcOp,
- std::move(combineTransposePatterns));
+ if (failed(applyPatternsAndFoldGreedily(
+ funcOp, std::move(combineTransposePatterns)))) {
+ return signalPassFailure();
+ }
LLVM_DEBUG({
llvm::dbgs() << "--- After handling transposes ---\n";
diff --git a/iree/compiler/Codegen/SPIRV/SPIRVVectorize.cpp b/iree/compiler/Codegen/SPIRV/SPIRVVectorize.cpp
index cec5d1a..0707c7b 100644
--- a/iree/compiler/Codegen/SPIRV/SPIRVVectorize.cpp
+++ b/iree/compiler/Codegen/SPIRV/SPIRVVectorize.cpp
@@ -33,21 +33,27 @@
namespace iree_compiler {
namespace {
-Optional<SmallVector<int64_t, 4>> getSPIRVNativeVectorSize(Operation *op) {
+int getNativeVectorSize(int64_t size) {
+ // Try to use 4 first, and then 2, and then 1.
+ return size % 4 == 0 ? 4 : (size % 2 == 0 ? 2 : 1);
+}
+
+Optional<SmallVector<int64_t, 4>> getNativeVectorShape(Operation *op) {
if (OpTrait::hasElementwiseMappableTraits(op) && op->getNumResults() == 1) {
if (auto vecType = op->getResultTypes()[0].dyn_cast<VectorType>()) {
- // Use 4-element vectors for elementwise ops.
SmallVector<int64_t, 4> nativeSize(vecType.getRank(), 1);
- nativeSize.back() = 4;
+ nativeSize.back() = getNativeVectorSize(vecType.getShape().back());
return nativeSize;
}
} else if (auto vtOp = dyn_cast<VectorTransferOpInterface>(op)) {
- auto rank = vtOp.getVectorType().getRank();
- SmallVector<int64_t, 4> nativeSize(rank, 1);
+ auto vecType = vtOp.getVectorType();
+ SmallVector<int64_t, 4> nativeSize(vecType.getRank(), 1);
for (auto dim : llvm::enumerate(vtOp.permutation_map().getResults())) {
if (auto dimExpr = dim.value().dyn_cast<AffineDimExpr>()) {
- if (dimExpr.getPosition() == vtOp.permutation_map().getNumDims() - 1)
- nativeSize[dim.index()] = 4;
+ if (dimExpr.getPosition() == vtOp.permutation_map().getNumDims() - 1) {
+ nativeSize[dim.index()] =
+ getNativeVectorSize(vecType.getShape()[dim.index()]);
+ }
}
}
return nativeSize;
@@ -57,8 +63,7 @@
if (isParallelIterator(it.value())) lastParalleldim = it.index();
}
SmallVector<int64_t, 4> nativeSize(contractOp.iterator_types().size(), 1);
- nativeSize[lastParalleldim] = 4;
- // Map to vec4 fma operations.
+ nativeSize[lastParalleldim] = 4; // Map to vec4 fma operations.
return nativeSize;
}
return llvm::None;
@@ -81,7 +86,7 @@
RewritePatternSet &patterns) {
vector::populateVectorUnrollPatterns(
patterns,
- vector::UnrollVectorOptions().setNativeShapeFn(getSPIRVNativeVectorSize));
+ vector::UnrollVectorOptions().setNativeShapeFn(getNativeVectorShape));
}
/// Vectorizes Linalg ops on buffer semantics.
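Not part of the patch: the selection above is just "largest of {4, 2, 1} that divides the dimension". A standalone check of that arithmetic against the matmul tests added later in this diff:

    #include <cassert>
    #include <cstdint>

    // Same rule as getNativeVectorSize above: prefer 4, then 2, then 1.
    static int getNativeVectorSize(int64_t size) {
      return size % 4 == 0 ? 4 : (size % 2 == 0 ? 2 : 1);
    }

    int main() {
      assert(getNativeVectorSize(8) == 4);  // M/N = 8: unroll by 4 (vector<1x4xf32>)
      assert(getNativeVectorSize(2) == 2);  // K = 2: vector<1x2xf32> reads in the new test
      assert(getNativeVectorSize(1) == 1);  // K = 1: vector<1x1xf32> reads in the new test
      assert(getNativeVectorSize(6) == 2);  // not divisible by 4, falls back to 2
      return 0;
    }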
@@ -103,15 +108,19 @@
populateVectorizationPatterns(context, vectorizationPatterns);
populateLinalgToVectorVectorizeConvPatterns(context,
vectorizationPatterns);
- (void)applyPatternsAndFoldGreedily(funcOp,
- std::move(vectorizationPatterns));
+ if (failed(applyPatternsAndFoldGreedily(
+ funcOp, std::move(vectorizationPatterns)))) {
+ return signalPassFailure();
+ }
// Fold consumer add ops into the contraction op itself.
RewritePatternSet canonicalizationPatterns(context);
vector::ContractionOp::getCanonicalizationPatterns(
canonicalizationPatterns, context);
- (void)applyPatternsAndFoldGreedily(funcOp,
- std::move(canonicalizationPatterns));
+ if (failed(applyPatternsAndFoldGreedily(
+ funcOp, std::move(canonicalizationPatterns)))) {
+ return signalPassFailure();
+ }
}
LLVM_DEBUG({
@@ -123,8 +132,10 @@
{
RewritePatternSet vectorUnrollPatterns(funcOp.getContext());
populateVectorUnrollPatterns(funcOp.getContext(), vectorUnrollPatterns);
- (void)applyPatternsAndFoldGreedily(funcOp,
- std::move(vectorUnrollPatterns));
+ if (failed(applyPatternsAndFoldGreedily(
+ funcOp, std::move(vectorUnrollPatterns)))) {
+ return signalPassFailure();
+ }
}
LLVM_DEBUG({
@@ -147,8 +158,10 @@
canonicalizationPatterns, context);
vector::populateVectorTransferPermutationMapLoweringPatterns(
canonicalizationPatterns);
- (void)applyPatternsAndFoldGreedily(funcOp,
- std::move(canonicalizationPatterns));
+ if (failed(applyPatternsAndFoldGreedily(
+ funcOp, std::move(canonicalizationPatterns)))) {
+ return signalPassFailure();
+ }
}
LLVM_DEBUG({
@@ -164,8 +177,10 @@
contractLoweringPatterns,
vector::VectorTransformsOptions().setVectorTransformsOptions(
vector::VectorContractLowering::OuterProduct));
- (void)applyPatternsAndFoldGreedily(funcOp,
- std::move(contractLoweringPatterns));
+ if (failed(applyPatternsAndFoldGreedily(
+ funcOp, std::move(contractLoweringPatterns)))) {
+ return signalPassFailure();
+ }
}
LLVM_DEBUG({
diff --git a/iree/compiler/Codegen/SPIRV/SPIRVVectorizeLoadStore.cpp b/iree/compiler/Codegen/SPIRV/SPIRVVectorizeLoadStore.cpp
index b7b87ce..a8c36be 100644
--- a/iree/compiler/Codegen/SPIRV/SPIRVVectorizeLoadStore.cpp
+++ b/iree/compiler/Codegen/SPIRV/SPIRVVectorizeLoadStore.cpp
@@ -119,8 +119,8 @@
public:
explicit MemRefUsageAnalysis(mlir::Operation *);
- // Returns true if the memref should be converted to a vector of memref.
- bool vectorizeMemRef(Value value) const {
+ // Returns true if the memref should be converted to a memref of vectors.
+ bool shouldVectorizeMemRef(Value value) const {
return valueToVectorBitsMap.count(value);
}
@@ -130,12 +130,17 @@
}
// Returns true if the transfer operation needs to be updated during memref
// vectorization.
- bool transferConvert(Operation *op) const { return transferOps.count(op); }
+ bool shouldConvertTransfer(Operation *op) const {
+ return transferOps.count(op);
+ }
private:
void analyzeMemRefValue(Value value);
+ // The mapping from a MemRef value to the number of bits of the vector this
+ // MemRef value should be vectorized into.
llvm::DenseMap<Value, unsigned> valueToVectorBitsMap;
+ // A list of transfer ops that should be adjusted for memref vectorization.
llvm::DenseSet<Operation *> transferOps;
};
@@ -174,9 +179,10 @@
const MemRefUsageAnalysis &memrefUsageAnalysis;
};
-class ProcessFuncArg final : public MemRefConversionPattern<FuncOp> {
+class ProcessFunctionArgument final : public MemRefConversionPattern<FuncOp> {
public:
- using MemRefConversionPattern<FuncOp>::MemRefConversionPattern;
+ using MemRefConversionPattern::MemRefConversionPattern;
+
LogicalResult matchAndRewrite(
FuncOp funcOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
@@ -185,12 +191,12 @@
class ProcessTransferRead final
: public MemRefConversionPattern<vector::TransferReadOp> {
public:
- using MemRefConversionPattern<
- vector::TransferReadOp>::MemRefConversionPattern;
+ using MemRefConversionPattern::MemRefConversionPattern;
+
LogicalResult matchAndRewrite(
vector::TransferReadOp read, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- if (!memrefUsageAnalysis.transferConvert(read)) {
+ if (!memrefUsageAnalysis.shouldConvertTransfer(read)) {
return rewriter.notifyMatchFailure(
read, "cannot be vectorized per memref usage analysis");
}
@@ -211,8 +217,7 @@
return failure();
unsigned ratio = *vectorMemrefElemSize / *scalarMemrefElemSize;
- SmallVector<Value, 4> indices(adaptor.indices().begin(),
- adaptor.indices().end());
+ auto indices = llvm::to_vector<4>(adaptor.indices());
indices.back() = rewriter.create<arith::DivSIOp>(
loc, indices.back(),
rewriter.create<arith::ConstantIndexOp>(loc, ratio));
@@ -241,12 +246,12 @@
class ProcessTransferWrite final
: public MemRefConversionPattern<vector::TransferWriteOp> {
public:
- using MemRefConversionPattern<
- vector::TransferWriteOp>::MemRefConversionPattern;
+ using MemRefConversionPattern::MemRefConversionPattern;
+
LogicalResult matchAndRewrite(
vector::TransferWriteOp write, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- if (!memrefUsageAnalysis.transferConvert(write)) {
+ if (!memrefUsageAnalysis.shouldConvertTransfer(write)) {
return rewriter.notifyMatchFailure(
write, "cannot be vectorized per memref usage analysis");
}
@@ -302,32 +307,36 @@
template <typename OpTy>
Optional<MemRefType> MemRefConversionPattern<OpTy>::getVectorizedMemRefType(
ConversionPatternRewriter &rewriter, Value memRefValue) const {
- unsigned vecSizeInBits =
- memrefUsageAnalysis.getMemRefVectorSizeInBits(memRefValue);
MemRefType type = memRefValue.getType().cast<MemRefType>();
- unsigned elemSize = type.getElementTypeBitWidth();
- unsigned numElements = vecSizeInBits / elemSize;
- Type elemType = type.getElementType();
+ unsigned vectorNumBits =
+ memrefUsageAnalysis.getMemRefVectorSizeInBits(memRefValue);
+
+ Type scalarType = type.getElementType();
+ unsigned scalarNumBits = type.getElementTypeBitWidth();
+ unsigned vectorNumElements = vectorNumBits / scalarNumBits;
  // If the vector we need to generate is bigger than the max vector size
  // allowed for loads, use a larger element type.
- if (numElements > kMaxVectorNumElements) {
- elemType = elemType.isa<IntegerType>() ? rewriter.getI32Type().cast<Type>()
- : rewriter.getF32Type().cast<Type>();
- elemSize = elemType.getIntOrFloatBitWidth();
- numElements = vecSizeInBits / elemSize;
+ if (vectorNumElements > kMaxVectorNumElements) {
+ scalarType = scalarType.isa<IntegerType>()
+ ? rewriter.getI32Type().cast<Type>()
+ : rewriter.getF32Type().cast<Type>();
+ scalarNumBits = scalarType.getIntOrFloatBitWidth();
+ vectorNumElements = vectorNumBits / scalarNumBits;
}
- Type vecType = VectorType::get(numElements, elemType);
- SmallVector<int64_t, 2> newShape(type.getShape().begin(),
- type.getShape().end());
- unsigned ratio = vecSizeInBits / type.getElementTypeBitWidth();
+
+ Type vectorType = VectorType::get(vectorNumElements, scalarType);
+ auto newShape = llvm::to_vector<2>(type.getShape());
+ unsigned ratio = vectorNumBits / type.getElementTypeBitWidth();
if (newShape.back() % ratio != 0) return {};
newShape.back() = newShape.back() / ratio;
- return MemRefType::get(newShape, vecType, {}, type.getMemorySpaceAsInt());
+
+ return MemRefType::get(newShape, vectorType, {}, type.getMemorySpaceAsInt());
}
class ProcessAlloc final : public MemRefConversionPattern<memref::AllocOp> {
public:
- using MemRefConversionPattern<memref::AllocOp>::MemRefConversionPattern;
+ using MemRefConversionPattern::MemRefConversionPattern;
+
LogicalResult matchAndRewrite(
memref::AllocOp alloc, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
@@ -342,8 +351,7 @@
class ProcessInterfaceBinding final
: public MemRefConversionPattern<IREE::HAL::InterfaceBindingSubspanOp> {
public:
- using MemRefConversionPattern<
- IREE::HAL::InterfaceBindingSubspanOp>::MemRefConversionPattern;
+ using MemRefConversionPattern::MemRefConversionPattern;
LogicalResult matchAndRewrite(
IREE::HAL::InterfaceBindingSubspanOp bindingOp, OpAdaptor adaptor,
@@ -355,7 +363,7 @@
assert(memrefType.getRank() > 0 &&
!ShapedType::isDynamic(memrefType.getShape().back()));
- auto vecMemRef = getVectorizedMemRefType(rewriter, bindingOp.getResult());
+ auto vecMemRef = getVectorizedMemRefType(rewriter, bindingOp.result());
if (!vecMemRef) {
return rewriter.notifyMatchFailure(bindingOp,
"cannot get vectorized memref type");
@@ -390,6 +398,7 @@
struct ScalarizeVectorTransferRead final
: public OpRewritePattern<vector::TransferReadOp> {
using OpRewritePattern::OpRewritePattern;
+
LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
PatternRewriter &rewriter) const override {
VectorType vectorType = readOp.getType();
@@ -402,8 +411,7 @@
Value newVector = rewriter.create<arith::ConstantOp>(
loc, vectorType, rewriter.getZeroAttr(vectorType));
for (int i = 0; i < vectorType.getDimSize(0); ++i) {
- SmallVector<Value, 4> indices(readOp.indices().begin(),
- readOp.indices().end());
+ auto indices = llvm::to_vector<4>(readOp.indices());
indices.back() = rewriter.createOrFold<arith::AddIOp>(
loc, indices.back(),
rewriter.createOrFold<arith::ConstantIndexOp>(loc, i));
@@ -419,6 +427,7 @@
struct ScalarizeVectorTransferWrite final
: public OpRewritePattern<vector::TransferWriteOp> {
using OpRewritePattern::OpRewritePattern;
+
LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
PatternRewriter &rewriter) const override {
VectorType vectorType = writeOp.getVectorType();
@@ -429,8 +438,7 @@
Location loc = writeOp.getLoc();
for (int i = 0; i < vectorType.getDimSize(0); ++i) {
- SmallVector<Value, 4> indices(writeOp.indices().begin(),
- writeOp.indices().end());
+ auto indices = llvm::to_vector<4>(writeOp.indices());
indices.back() = rewriter.createOrFold<arith::AddIOp>(
loc, indices.back(),
rewriter.createOrFold<arith::ConstantIndexOp>(loc, i));
@@ -456,13 +464,13 @@
};
} // namespace
-LogicalResult ProcessFuncArg::matchAndRewrite(
+LogicalResult ProcessFunctionArgument::matchAndRewrite(
FuncOp funcOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
TypeConverter::SignatureConversion signatureConverter(
funcOp.getType().getNumInputs());
for (const auto &arg : llvm::enumerate(funcOp.getArguments())) {
- if (memrefUsageAnalysis.vectorizeMemRef(arg.value())) {
+ if (memrefUsageAnalysis.shouldVectorizeMemRef(arg.value())) {
if (auto memrefType = getVectorizedMemRefType(rewriter, arg.value())) {
signatureConverter.addInputs(arg.index(), *memrefType);
continue;
@@ -491,7 +499,7 @@
RewritePatternSet conversionPatterns(context);
conversionPatterns
- .add<ProcessFuncArg, ProcessTransferRead, ProcessTransferWrite,
+ .add<ProcessFunctionArgument, ProcessTransferRead, ProcessTransferWrite,
ProcessAlloc, ProcessInterfaceBinding>(context,
*memrefUsageAnalysis);
conversionPatterns.add<PassThroughConversion<memref::DeallocOp>>(context);
@@ -499,21 +507,21 @@
ConversionTarget target(*context);
target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
return llvm::all_of(op.getArguments(), [&](Value arg) {
- return !memrefUsageAnalysis->vectorizeMemRef(arg);
+ return !memrefUsageAnalysis->shouldVectorizeMemRef(arg);
});
});
target.addDynamicallyLegalOp<memref::AllocOp>([&](memref::AllocOp alloc) {
- return !memrefUsageAnalysis->vectorizeMemRef(alloc);
+ return !memrefUsageAnalysis->shouldVectorizeMemRef(alloc);
});
target.addDynamicallyLegalOp<IREE::HAL::InterfaceBindingSubspanOp>(
[&](IREE::HAL::InterfaceBindingSubspanOp bindingOp) {
- return !memrefUsageAnalysis->vectorizeMemRef(bindingOp);
+ return !memrefUsageAnalysis->shouldVectorizeMemRef(bindingOp);
});
target.markUnknownOpDynamicallyLegal([&](Operation *op) {
if (isa<vector::TransferWriteOp, vector::TransferReadOp>(op))
- return !memrefUsageAnalysis->transferConvert(op);
+ return !memrefUsageAnalysis->shouldConvertTransfer(op);
if (auto dealloc = dyn_cast<memref::DeallocOp>(op))
- return !memrefUsageAnalysis->vectorizeMemRef(dealloc.memref());
+ return !memrefUsageAnalysis->shouldVectorizeMemRef(dealloc.memref());
return true;
});
if (failed(applyPartialConversion(module, target,
@@ -526,7 +534,10 @@
.add<ScalarizeVectorTransferRead, ScalarizeVectorTransferWrite>(
context);
- (void)applyPatternsAndFoldGreedily(func, std::move(rewritingPatterns));
+ if (failed(
+ applyPatternsAndFoldGreedily(func, std::move(rewritingPatterns)))) {
+ return signalPassFailure();
+ }
}
}
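Not part of the patch: the renamed getVectorizedMemRefType boils down to widening the innermost dimension by vectorNumBits / scalarNumBits. A standalone check that this reproduces the shapes in the updated vectorize_load_store test below, assuming the analysis picked 128-bit vectors (vector<4xf32>) for those memrefs and ignoring the element-type widening fallback:

    #include <cassert>
    #include <cstdint>
    #include <vector>

    // Mirrors the core arithmetic of getVectorizedMemRefType: divide the
    // innermost dimension by vectorNumBits / scalarNumBits when it divides evenly.
    static bool vectorizeShape(std::vector<int64_t> &shape, unsigned vectorNumBits,
                               unsigned scalarNumBits) {
      unsigned ratio = vectorNumBits / scalarNumBits;
      if (shape.back() % ratio != 0) return false;
      shape.back() /= ratio;
      return true;
    }

    int main() {
      // memref<4096x4096xf32> -> memref<4096x1024xvector<4xf32>>
      std::vector<int64_t> arg = {4096, 4096};
      assert(vectorizeShape(arg, /*vectorNumBits=*/128, /*scalarNumBits=*/32));
      assert(arg[0] == 4096 && arg[1] == 1024);

      // memref<128x32xf32, 3> -> memref<128x8xvector<4xf32>, 3>
      std::vector<int64_t> alloc = {128, 32};
      assert(vectorizeShape(alloc, /*vectorNumBits=*/128, /*scalarNumBits=*/32));
      assert(alloc[0] == 128 && alloc[1] == 8);
      return 0;
    }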
diff --git a/iree/compiler/Codegen/SPIRV/test/vectorize_load_store.mlir b/iree/compiler/Codegen/SPIRV/test/vectorize_load_store.mlir
index 36e1236..2645075 100644
--- a/iree/compiler/Codegen/SPIRV/test/vectorize_load_store.mlir
+++ b/iree/compiler/Codegen/SPIRV/test/vectorize_load_store.mlir
@@ -1,6 +1,6 @@
// RUN: iree-opt -split-input-file -iree-spirv-vectorize-load-store -canonicalize %s | IreeFileCheck %s
-// CHECK-LABEL: func @copy
+// CHECK-LABEL: func @alloc_copy
// CHECK-SAME: (%[[ARG0:.+]]: memref<4096x1024xvector<4xf32>>
// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<128x8xvector<4xf32>, 3>
// CHECK: %[[V:.+]] = memref.load %[[ARG0]][%{{.*}}, %{{.*}}] : memref<4096x1024xvector<4xf32>>
@@ -8,7 +8,7 @@
// CHECK: %[[MAT:.+]] = vector.transfer_read %[[ARG0]][%{{.*}}, %{{.*}}], %{{.*}} : memref<4096x1024xvector<4xf32>>, vector<32x8xf32>
// CHECK: vector.transfer_write %[[MAT]], %[[ALLOC]][%{{.*}}, %{{.*}}] : vector<32x8xf32>, memref<128x8xvector<4xf32>, 3>
// CHECK: memref.dealloc %[[ALLOC]] : memref<128x8xvector<4xf32>, 3>
-func @copy(%arg0: memref<4096x4096xf32>, %x: index, %y: index) {
+func @alloc_copy(%arg0: memref<4096x4096xf32>, %x: index, %y: index) {
%cst = arith.constant 0.000000e+00 : f32
%0 = memref.alloc() : memref<128x32xf32, 3>
%v = vector.transfer_read %arg0[%x, %y], %cst : memref<4096x4096xf32>, vector<1x4xf32>
@@ -22,9 +22,10 @@
// -----
// Test that the memref is not vectorized if used by scalar load or store.
-// CHECK-LABEL: func @copy
+
+// CHECK-LABEL: func @alloc_copy
// CHECK-SAME: %[[ARG0:.+]]: memref<4096x4096xf32>
-func @copy(%arg0: memref<4096x4096xf32>, %x: index, %y: index) {
+func @alloc_copy(%arg0: memref<4096x4096xf32>, %x: index, %y: index) {
%cst = arith.constant 0.000000e+00 : f32
%0 = memref.alloc() : memref<128x32xf32, 3>
%s = memref.load %arg0[%x, %y] : memref<4096x4096xf32>
diff --git a/iree/compiler/Codegen/SPIRV/test/vectorize_matmul.mlir b/iree/compiler/Codegen/SPIRV/test/vectorize_matmul.mlir
index d94eb39..07c490e 100644
--- a/iree/compiler/Codegen/SPIRV/test/vectorize_matmul.mlir
+++ b/iree/compiler/Codegen/SPIRV/test/vectorize_matmul.mlir
@@ -81,3 +81,37 @@
// CHECK: %[[FMA_3:.+]] = vector.fma %[[LHS_3_VECTOR]], %[[RHS_3_VECTOR]], %[[FMA_2]] : vector<4xf32>
// CHECK: %[[INSERT:.+]] = vector.insert %[[FMA_3]], %[[ZERO]] [0]
// CHECK: vector.transfer_write %[[INSERT]], %[[ACC_TILE]][%[[C0]], %[[C0]]]
+
+// -----
+
+// Check that we can vectorize shape dimensions not divisible by 4 but divisible by 2.
+
+func @matmul_8x8x2(%lhs: memref<8x2xf32>, %rhs: memref<2x8xf32>, %output: memref<8x8xf32>) {
+ linalg.matmul {__internal_linalg_transform__ = "vectorize"} ins(%lhs, %rhs: memref<8x2xf32>, memref<2x8xf32>) outs(%output: memref<8x8xf32>)
+ return
+}
+
+// CHECK-LABEL: func @matmul_8x8x2
+
+// CHECK-COUNT-8: vector.transfer_read {{.*}} : memref<8x2xf32>, vector<1x2xf32>
+// CHECK-COUNT-4: vector.transfer_read {{.*}} : memref<2x8xf32>, vector<1x4xf32>
+// CHECK-COUNT-16: vector.transfer_read {{.*}} : memref<8x8xf32>, vector<1x4xf32>
+// CHECK-COUNT-16: vector.fma
+// CHECK-COUNT-16: vector.transfer_write {{.*}} : vector<1x4xf32>, memref<8x8xf32>
+
+// -----
+
+// Check that we fall back to unit vectors for shape dimensions divisible by neither 4 nor 2.
+
+func @matmul_8x8x1(%lhs: memref<8x1xf32>, %rhs: memref<1x8xf32>, %output: memref<8x8xf32>) {
+ linalg.matmul {__internal_linalg_transform__ = "vectorize"} ins(%lhs, %rhs: memref<8x1xf32>, memref<1x8xf32>) outs(%output: memref<8x8xf32>)
+ return
+}
+
+// CHECK-LABEL: func @matmul_8x8x1
+
+// CHECK-COUNT-8: vector.transfer_read {{.*}} : memref<8x1xf32>, vector<1x1xf32>
+// CHECK-COUNT-2: vector.transfer_read {{.*}} : memref<1x8xf32>, vector<1x4xf32>
+// CHECK-COUNT-16: vector.transfer_read {{.*}} : memref<8x8xf32>, vector<1x4xf32>
+// CHECK-COUNT-16: vector.fma
+// CHECK-COUNT-16: vector.transfer_write {{.*}} : vector<1x4xf32>, memref<8x8xf32>
diff --git a/iree/compiler/Codegen/Utils/BUILD b/iree/compiler/Codegen/Utils/BUILD
index ab7ec98..9a58e65 100644
--- a/iree/compiler/Codegen/Utils/BUILD
+++ b/iree/compiler/Codegen/Utils/BUILD
@@ -23,7 +23,7 @@
"Utils.h",
],
deps = [
- "//iree/compiler/Codegen/Dialect:ProcessorOpInterfaces",
+ "//iree/compiler/Codegen/Interfaces:ProcessorOpInterfaces",
"//iree/compiler/Dialect/Flow/IR",
"//iree/compiler/Dialect/HAL/IR",
"//llvm-external-projects/iree-dialects:IREELinalgExtDialect",
diff --git a/iree/compiler/Codegen/Utils/CMakeLists.txt b/iree/compiler/Codegen/Utils/CMakeLists.txt
index 5ff2816..1211b11 100644
--- a/iree/compiler/Codegen/Utils/CMakeLists.txt
+++ b/iree/compiler/Codegen/Utils/CMakeLists.txt
@@ -26,7 +26,7 @@
MLIRLinalg
MLIRLinalgTransforms
MLIRSupport
- iree::compiler::Codegen::Dialect::ProcessorOpInterfaces
+ iree::compiler::Codegen::Interfaces::ProcessorOpInterfaces
iree::compiler::Dialect::Flow::IR
iree::compiler::Dialect::HAL::IR
PUBLIC
diff --git a/iree/compiler/Codegen/Utils/Utils.cpp b/iree/compiler/Codegen/Utils/Utils.cpp
index 263d581..e6c36ed 100644
--- a/iree/compiler/Codegen/Utils/Utils.cpp
+++ b/iree/compiler/Codegen/Utils/Utils.cpp
@@ -7,7 +7,7 @@
#include "iree/compiler/Codegen/Utils/Utils.h"
#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
-#include "iree/compiler/Codegen/Dialect/ProcessorOpInterfaces.h"
+#include "iree/compiler/Codegen/Interfaces/ProcessorOpInterfaces.h"
#include "iree/compiler/Codegen/Utils/MarkerUtils.h"
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
diff --git a/iree/compiler/Dialect/Flow/IR/BUILD b/iree/compiler/Dialect/Flow/IR/BUILD
index 5466c08..a63e27d 100644
--- a/iree/compiler/Dialect/Flow/IR/BUILD
+++ b/iree/compiler/Dialect/Flow/IR/BUILD
@@ -25,7 +25,6 @@
include = ["*.td"],
),
deps = [
- "//iree/compiler/Dialect/Shape/IR:td_files",
"//iree/compiler/Dialect/Util/IR:td_files",
"@llvm-project//mlir:ControlFlowInterfacesTdFiles",
"@llvm-project//mlir:InferTypeOpInterfaceTdFiles",
@@ -64,7 +63,6 @@
":FlowInterfacesGen",
":FlowOpsGen",
":FlowTypesGen",
- "//iree/compiler/Dialect/Shape/IR",
"//iree/compiler/Dialect/Util/IR",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:ControlFlowInterfaces",
diff --git a/iree/compiler/Dialect/Flow/IR/CMakeLists.txt b/iree/compiler/Dialect/Flow/IR/CMakeLists.txt
index 5a8860a..bad9bd3 100644
--- a/iree/compiler/Dialect/Flow/IR/CMakeLists.txt
+++ b/iree/compiler/Dialect/Flow/IR/CMakeLists.txt
@@ -48,7 +48,6 @@
MLIRSupport
MLIRTensor
MLIRTransformUtils
- iree::compiler::Dialect::Shape::IR
iree::compiler::Dialect::Util::IR
PUBLIC
)
diff --git a/iree/compiler/Dialect/Flow/IR/FlowBase.td b/iree/compiler/Dialect/Flow/IR/FlowBase.td
index 59dd818..5370f6e 100644
--- a/iree/compiler/Dialect/Flow/IR/FlowBase.td
+++ b/iree/compiler/Dialect/Flow/IR/FlowBase.td
@@ -9,7 +9,6 @@
include "iree/compiler/Dialect/Flow/IR/FlowInterfaces.td"
include "iree/compiler/Dialect/Util/IR/UtilBase.td"
-include "iree/compiler/Dialect/Shape/IR/ShapeBase.td"
//===----------------------------------------------------------------------===//
// IREE execution flow dialect
diff --git a/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp b/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
index efad188..b3080da 100644
--- a/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
+++ b/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
@@ -9,7 +9,6 @@
#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
-#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
#include "iree/compiler/Dialect/Util/IR/ClosureOpUtils.h"
#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
diff --git a/iree/compiler/Dialect/Flow/IR/FlowOps.cpp b/iree/compiler/Dialect/Flow/IR/FlowOps.cpp
index 34b269f..85456cc 100644
--- a/iree/compiler/Dialect/Flow/IR/FlowOps.cpp
+++ b/iree/compiler/Dialect/Flow/IR/FlowOps.cpp
@@ -6,7 +6,6 @@
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
-#include "iree/compiler/Dialect/Shape/IR/Builders.h"
#include "iree/compiler/Dialect/Util/IR/ClosureOpUtils.h"
#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
@@ -341,18 +340,6 @@
return success();
}
-Value DispatchWorkgroupsOp::buildOperandRankedShape(unsigned idx,
- OpBuilder &builder) {
- return Shape::buildRankedShapeForValueInList(getLoc(), idx, getOperands(),
- operand_dims(), builder);
-}
-
-Value DispatchWorkgroupsOp::buildResultRankedShape(unsigned idx,
- OpBuilder &builder) {
- return Shape::buildRankedShapeForValueInList(getLoc(), idx, getResults(),
- result_dims(), builder);
-}
-
Operation::operand_range DispatchWorkgroupsOp::getClosureOperands() {
return operands();
}
@@ -693,16 +680,6 @@
return success();
}
-Value DispatchOp::buildOperandRankedShape(unsigned idx, OpBuilder &builder) {
- return Shape::buildRankedShapeForValueInList(getLoc(), idx, getOperands(),
- operand_dims(), builder);
-}
-
-Value DispatchOp::buildResultRankedShape(unsigned idx, OpBuilder &builder) {
- return Shape::buildRankedShapeForValueInList(getLoc(), idx, getResults(),
- result_dims(), builder);
-}
-
std::pair<unsigned, unsigned> DispatchOp::getTiedOperandsIndexAndLength() {
return getODSOperandIndexAndLength(1); // $operands
}
@@ -711,18 +688,6 @@
// flow.tensor.reshape
//===----------------------------------------------------------------------===//
-Value TensorReshapeOp::buildOperandRankedShape(unsigned idx,
- OpBuilder &builder) {
- return Shape::buildRankedShapeForValue(getLoc(), source(), source_dims(),
- builder);
-}
-
-Value TensorReshapeOp::buildResultRankedShape(unsigned idx,
- OpBuilder &builder) {
- return Shape::buildRankedShapeForValue(getLoc(), result(), result_dims(),
- builder);
-}
-
Value TensorReshapeOp::getTiedResult(unsigned resultIndex) {
return IREE::Util::TiedOpInterface::findTiedBaseValue(source());
}
@@ -737,58 +702,6 @@
}
//===----------------------------------------------------------------------===//
-// flow.tensor.*
-//===----------------------------------------------------------------------===//
-
-Value TensorLoadOp::buildOperandRankedShape(unsigned idx, OpBuilder &builder) {
- return Shape::buildRankedShapeForValue(getLoc(), source(), source_dims(),
- builder);
-}
-
-Value TensorLoadOp::buildResultRankedShape(unsigned idx, OpBuilder &builder) {
- return {};
-}
-
-Value TensorStoreOp::buildOperandRankedShape(unsigned idx, OpBuilder &builder) {
- return Shape::buildRankedShapeForValue(getLoc(), target(), target_dims(),
- builder);
-}
-
-Value TensorStoreOp::buildResultRankedShape(unsigned idx, OpBuilder &builder) {
- return Shape::buildRankedShapeForValue(getLoc(), result(), target_dims(),
- builder);
-}
-
-Value TensorSplatOp::buildOperandRankedShape(unsigned idx, OpBuilder &builder) {
- return {};
-}
-
-Value TensorSplatOp::buildResultRankedShape(unsigned idx, OpBuilder &builder) {
- return Shape::buildRankedShapeForValue(getLoc(), result(), result_dims(),
- builder);
-}
-
-Value TensorCloneOp::buildOperandRankedShape(unsigned idx, OpBuilder &builder) {
- return Shape::buildRankedShapeForValue(getLoc(), operand(), operand_dims(),
- builder);
-}
-
-Value TensorCloneOp::buildResultRankedShape(unsigned idx, OpBuilder &builder) {
- return Shape::buildRankedShapeForValue(getLoc(), result(), operand_dims(),
- builder);
-}
-
-Value TensorSliceOp::buildOperandRankedShape(unsigned idx, OpBuilder &builder) {
- return Shape::buildRankedShapeForValue(getLoc(), source(), source_dims(),
- builder);
-}
-
-Value TensorSliceOp::buildResultRankedShape(unsigned idx, OpBuilder &builder) {
- return Shape::buildRankedShapeForValue(getLoc(), result(), result_dims(),
- builder);
-}
-
-//===----------------------------------------------------------------------===//
// flow.tensor.update
//===----------------------------------------------------------------------===//
@@ -796,9 +709,9 @@
Value target, ValueRange startIndices,
Value update) {
auto targetDims =
- Shape::buildOrFindDynamicDimsForValue(state.location, target, builder);
+ IREE::Util::buildDynamicDimsForValue(state.location, target, builder);
auto updateDims =
- Shape::buildOrFindDynamicDimsForValue(state.location, update, builder);
+ IREE::Util::buildDynamicDimsForValue(state.location, update, builder);
build(builder, state, target.getType(), target, targetDims, startIndices,
update, updateDims, builder.getIndexArrayAttr({0}));
}
@@ -811,22 +724,6 @@
return success();
}
-Value TensorUpdateOp::buildOperandRankedShape(unsigned idx,
- OpBuilder &builder) {
- if (idx == 0) {
- return Shape::buildRankedShapeForValueInList(getLoc(), idx, getOperands(),
- target_dims(), builder);
- } else {
- return Shape::buildRankedShapeForValueInList(getLoc(), idx, getOperands(),
- update_dims(), builder);
- }
-}
-
-Value TensorUpdateOp::buildResultRankedShape(unsigned idx, OpBuilder &builder) {
- return Shape::buildRankedShapeForValue(getLoc(), target(), target_dims(),
- builder);
-}
-
Value TensorUpdateOp::getTiedResult(unsigned resultIndex) {
return IREE::Util::TiedOpInterface::findTiedBaseValue(target());
}
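
For readers following the raw hunks, the flow.tensor.update builder above reads as follows once the patch is applied; there is no behavior change beyond sourcing dynamic dimensions from the Util dialect instead of the removed Shape builders. This is a readability sketch only: the enclosing signature line is reconstructed from the surrounding context, and the usual FlowOps.cpp namespaces/includes are assumed.

    // Post-patch form of the builder shown in the hunk above (sketch for
    // readability only; signature reconstructed from surrounding context).
    void TensorUpdateOp::build(OpBuilder &builder, OperationState &state,
                               Value target, ValueRange startIndices,
                               Value update) {
      // Materialize index values for each dynamic dim via the Util dialect
      // helper that replaces Shape::buildOrFindDynamicDimsForValue.
      auto targetDims =
          IREE::Util::buildDynamicDimsForValue(state.location, target, builder);
      auto updateDims =
          IREE::Util::buildDynamicDimsForValue(state.location, update, builder);
      build(builder, state, target.getType(), target, targetDims, startIndices,
            update, updateDims, builder.getIndexArrayAttr({0}));
    }
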
diff --git a/iree/compiler/Dialect/Flow/IR/FlowOps.h b/iree/compiler/Dialect/Flow/IR/FlowOps.h
index a455cad..5a964c1 100644
--- a/iree/compiler/Dialect/Flow/IR/FlowOps.h
+++ b/iree/compiler/Dialect/Flow/IR/FlowOps.h
@@ -11,8 +11,6 @@
#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
#include "iree/compiler/Dialect/Flow/IR/FlowTypes.h"
-#include "iree/compiler/Dialect/Shape/IR/ShapeDialect.h"
-#include "iree/compiler/Dialect/Shape/IR/ShapeTypes.h"
#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
#include "iree/compiler/Dialect/Util/IR/UtilTraits.h"
#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
diff --git a/iree/compiler/Dialect/Flow/IR/FlowOps.td b/iree/compiler/Dialect/Flow/IR/FlowOps.td
index 1763666..af19305 100644
--- a/iree/compiler/Dialect/Flow/IR/FlowOps.td
+++ b/iree/compiler/Dialect/Flow/IR/FlowOps.td
@@ -10,7 +10,6 @@
include "iree/compiler/Dialect/Flow/IR/FlowBase.td"
include "iree/compiler/Dialect/Flow/IR/FlowInterfaces.td"
include "iree/compiler/Dialect/Util/IR/UtilInterfaces.td"
-include "iree/compiler/Dialect/Shape/IR/ShapeInterfaces.td"
include "mlir/IR/OpAsmInterface.td"
include "mlir/IR/SymbolInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
@@ -33,7 +32,6 @@
DeclareOpInterfaceMethods<Util_TiedOpInterface, [
"getTiedOperandsIndexAndLength",
]>,
- DeclareOpInterfaceMethods<Shape_ShapeCarryingOpInterface>,
Util_ShapeAwareOp,
]> {
let summary = [{a dispatch of workgroups across an n-dimension grid}];
@@ -324,7 +322,7 @@
];
let extraClassDeclaration = [{
- /// Return the expected rank of each of the`static_offsets`, `static_sizes`
+ /// Return the expected rank of each of the `static_offsets`, `static_sizes`
/// and `static_strides` attributes.
std::array<unsigned, 3> getArrayAttrMaxRanks() {
unsigned sourceRank = source().getType().cast<DispatchTensorType>().asTensorType().getRank();
@@ -398,7 +396,7 @@
];
let extraClassDeclaration = [{
- /// Return the expected rank of each of the`static_offsets`, `static_sizes`
+ /// Return the expected rank of each of the `static_offsets`, `static_sizes`
/// and `static_strides` attributes.
std::array<unsigned, 3> getArrayAttrMaxRanks() {
unsigned resultRank = target().getType().cast<DispatchTensorType>().asTensorType().getRank();
@@ -530,7 +528,6 @@
DeclareOpInterfaceMethods<Util_TiedOpInterface, [
"getTiedOperandsIndexAndLength",
]>,
- DeclareOpInterfaceMethods<Shape_ShapeCarryingOpInterface>,
Util_ShapeAwareOp,
]> {
let summary = [{a dispatch of workgroups across an n-dimension grid}];
@@ -638,7 +635,6 @@
"getTiedResultOperandIndex",
"getTiedResultOperandIndices",
]>,
- DeclareOpInterfaceMethods<Shape_ShapeCarryingOpInterface>,
Util_ShapeAwareOp,
]> {
let summary = [{reshapes a tensor}];
@@ -669,7 +665,7 @@
build($_builder, $_state,
result_type,
source,
- Shape::buildOrFindDynamicDimsForValue($_state.location, source, $_builder),
+ IREE::Util::buildDynamicDimsForValue($_state.location, source, $_builder),
target_dims);
}]>,
];
@@ -691,7 +687,6 @@
"source", "result",
"$_self.cast<ShapedType>().getElementType()">,
AttrSizedOperandSegments,
- DeclareOpInterfaceMethods<Shape_ShapeCarryingOpInterface>,
Util_ShapeAwareOp,
]> {
let summary = [{loads a value from a tensor element}];
@@ -721,7 +716,7 @@
build($_builder, $_state,
result_type,
source,
- Shape::buildOrFindDynamicDimsForValue($_state.location, source, $_builder),
+ IREE::Util::buildDynamicDimsForValue($_state.location, source, $_builder),
indices);
}]>,
];
@@ -742,7 +737,6 @@
"target", "value",
"$_self.cast<ShapedType>().getElementType()">,
AttrSizedOperandSegments,
- DeclareOpInterfaceMethods<Shape_ShapeCarryingOpInterface>,
Util_ShapeAwareOp,
]> {
let summary = [{stores a value into a tensor element}];
@@ -774,7 +768,7 @@
target.getType(),
value,
target,
- Shape::buildOrFindDynamicDimsForValue($_state.location, target, $_builder),
+ IREE::Util::buildDynamicDimsForValue($_state.location, target, $_builder),
indices);
}]>,
];
@@ -792,7 +786,6 @@
TypesMatchWith<"value type matches element type of result",
"result", "value",
"$_self.cast<ShapedType>().getElementType()">,
- DeclareOpInterfaceMethods<Shape_ShapeCarryingOpInterface>,
Util_ShapeAwareOp,
]> {
let summary = [{splats a value into a shaped tensor}];
@@ -827,7 +820,6 @@
def FLOW_TensorCloneOp : FLOW_PureOp<"tensor.clone", [
FLOW_StreamableOp,
AllTypesMatch<["operand", "result"]>,
- DeclareOpInterfaceMethods<Shape_ShapeCarryingOpInterface>,
Util_ShapeAwareOp,
]> {
let summary = [{performs a full tensor clone operation}];
@@ -854,7 +846,7 @@
build($_builder, $_state,
operand.getType(),
operand,
- Shape::buildOrFindDynamicDimsForValue($_state.location, operand, $_builder));
+ IREE::Util::buildDynamicDimsForValue($_state.location, operand, $_builder));
}]>,
];
@@ -875,7 +867,6 @@
AllRanksMatch<["source", "result"]>,
AllElementTypesMatch<["source", "result"]>,
AttrSizedOperandSegments,
- DeclareOpInterfaceMethods<Shape_ShapeCarryingOpInterface>,
Util_ShapeAwareOp,
]> {
let summary = [{slices out a subregion of a tensor}];
@@ -924,7 +915,6 @@
"getTiedResultOperandIndex",
"getTiedResultOperandIndices",
]>,
- DeclareOpInterfaceMethods<Shape_ShapeCarryingOpInterface>,
Util_ShapeAwareOp,
]> {
let summary = [{updates a tensor with the contents of another tensor}];
diff --git a/iree/compiler/Dialect/Flow/IR/FlowTypes.h b/iree/compiler/Dialect/Flow/IR/FlowTypes.h
index af95021..69d731f 100644
--- a/iree/compiler/Dialect/Flow/IR/FlowTypes.h
+++ b/iree/compiler/Dialect/Flow/IR/FlowTypes.h
@@ -8,7 +8,6 @@
#define IREE_COMPILER_DIALECT_FLOW_IR_FLOWTYPES_H_
#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
-#include "iree/compiler/Dialect/Shape/IR/ShapeTypes.h"
#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
#include "llvm/ADT/DenseMapInfo.h"
#include "llvm/ADT/SmallVector.h"
@@ -135,10 +134,6 @@
TensorType asTensorType() const {
return RankedTensorType::get(getShape(), getElementType());
}
-
- Shape::RankedShapeType asRankedShapeType() const {
- return Shape::RankedShapeType::get(getShape(), getContext());
- }
};
void printType(DispatchTensorType &type, DialectAsmPrinter &p);
diff --git a/iree/compiler/Dialect/Flow/Transforms/BUILD b/iree/compiler/Dialect/Flow/Transforms/BUILD
index 0a4dbe8..e52a14d 100644
--- a/iree/compiler/Dialect/Flow/Transforms/BUILD
+++ b/iree/compiler/Dialect/Flow/Transforms/BUILD
@@ -64,8 +64,6 @@
"//iree/compiler/Dialect/Flow/Conversion/TensorToFlow",
"//iree/compiler/Dialect/Flow/IR",
"//iree/compiler/Dialect/HAL/IR",
- "//iree/compiler/Dialect/Shape/IR",
- "//iree/compiler/Dialect/Shape/Transforms",
"//iree/compiler/Dialect/Util/IR",
"//iree/compiler/Dialect/Util/Transforms",
"//iree/compiler/Utils",
diff --git a/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt b/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
index 0c57300..e381f09 100644
--- a/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
+++ b/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
@@ -71,8 +71,6 @@
iree::compiler::Dialect::Flow::Conversion::TensorToFlow
iree::compiler::Dialect::Flow::IR
iree::compiler::Dialect::HAL::IR
- iree::compiler::Dialect::Shape::IR
- iree::compiler::Dialect::Shape::Transforms
iree::compiler::Dialect::Util::IR
iree::compiler::Dialect::Util::Transforms
iree::compiler::Utils
diff --git a/iree/compiler/Dialect/Flow/Transforms/ConvertConv2D1x1ToMatmulPass.cpp b/iree/compiler/Dialect/Flow/Transforms/ConvertConv2D1x1ToMatmulPass.cpp
index 0305b4a..ccf1568 100644
--- a/iree/compiler/Dialect/Flow/Transforms/ConvertConv2D1x1ToMatmulPass.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/ConvertConv2D1x1ToMatmulPass.cpp
@@ -102,7 +102,10 @@
MLIRContext *context = &getContext();
OwningRewritePatternList patterns(&getContext());
patterns.insert<Convert1x1ConvolutionMatmulOp>(context);
- (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+ if (failed(applyPatternsAndFoldGreedily(getOperation(),
+ std::move(patterns)))) {
+ return signalPassFailure();
+ }
}
};
} // namespace
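
The same mechanical change recurs in most of the pass files below: applyPatternsAndFoldGreedily returns a LogicalResult (failure means the greedy driver did not converge), and the patch now propagates that result instead of casting it to void. A minimal sketch of the idiom in isolation, with hypothetical pass/pattern names and namespaces elided:

    // Sketch only; MyPass/MyPattern are hypothetical stand-ins for the passes
    // touched in this patch (OwningRewritePatternList-era MLIR API).
    #include "mlir/Pass/Pass.h"
    #include "mlir/Transforms/GreedyPatternRewriteDriver.h"

    struct MyPass : public PassWrapper<MyPass, OperationPass<FuncOp>> {
      void runOnOperation() override {
        MLIRContext *context = &getContext();
        OwningRewritePatternList patterns(context);
        patterns.insert<MyPattern>(context);
        // Surface non-convergence as a pass failure rather than ignoring it.
        if (failed(applyPatternsAndFoldGreedily(getOperation(),
                                                std::move(patterns)))) {
          return signalPassFailure();
        }
      }
    };
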
diff --git a/iree/compiler/Dialect/Flow/Transforms/ConvertConv2DToImg2ColPass.cpp b/iree/compiler/Dialect/Flow/Transforms/ConvertConv2DToImg2ColPass.cpp
index 4fca00e..353bca7 100644
--- a/iree/compiler/Dialect/Flow/Transforms/ConvertConv2DToImg2ColPass.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/ConvertConv2DToImg2ColPass.cpp
@@ -350,7 +350,10 @@
OwningRewritePatternList patterns(&getContext());
patterns.insert<Conv2DImg2ColMatmulConversion,
DepthwiseConv2DNHWCHWCImg2ColMatmulConversion>(context);
- (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+ if (failed(applyPatternsAndFoldGreedily(getOperation(),
+ std::move(patterns)))) {
+ return signalPassFailure();
+ }
}
};
diff --git a/iree/compiler/Dialect/Flow/Transforms/ConvertLinalgMatmulToMmt4D.cpp b/iree/compiler/Dialect/Flow/Transforms/ConvertLinalgMatmulToMmt4D.cpp
index b490187..bf1e315 100644
--- a/iree/compiler/Dialect/Flow/Transforms/ConvertLinalgMatmulToMmt4D.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/ConvertLinalgMatmulToMmt4D.cpp
@@ -360,7 +360,10 @@
OwningRewritePatternList patterns(&getContext());
patterns.insert<LinalgMatmulOpToLinalgMmt4DOpPattern>(context, M0, K0,
N0);
- (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+ if (failed(applyPatternsAndFoldGreedily(getOperation(),
+ std::move(patterns)))) {
+ return signalPassFailure();
+ }
}
// Canonicalization.
{
@@ -368,7 +371,10 @@
linalg::TensorExpandShapeOp::getCanonicalizationPatterns(patterns,
context);
patterns.insert<FoldFillGenericOpPattern>(context);
- (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+ if (failed(applyPatternsAndFoldGreedily(getOperation(),
+ std::move(patterns)))) {
+ return signalPassFailure();
+ }
}
}
};
diff --git a/iree/compiler/Dialect/Flow/Transforms/DestructiveUpdateUtils.cpp b/iree/compiler/Dialect/Flow/Transforms/DestructiveUpdateUtils.cpp
index 441c732..bc54d30 100644
--- a/iree/compiler/Dialect/Flow/Transforms/DestructiveUpdateUtils.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/DestructiveUpdateUtils.cpp
@@ -368,9 +368,8 @@
// affineminscf and others as needed.
OwningRewritePatternList canonicalizationPatterns(context);
scf::ForOp::getCanonicalizationPatterns(canonicalizationPatterns, context);
- (void)applyPatternsAndFoldGreedily(dispatchOp,
- std::move(canonicalizationPatterns));
- return success();
+ return applyPatternsAndFoldGreedily(dispatchOp,
+ std::move(canonicalizationPatterns));
}
} // namespace Flow
diff --git a/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp b/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp
index a6d932f..b6a0dee 100644
--- a/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp
@@ -985,7 +985,7 @@
void getDependentDialects(DialectRegistry &registry) const override {
registry
.insert<AffineDialect, IREE::Flow::FlowDialect, linalg::LinalgDialect,
- scf::SCFDialect, ShapeDialect, tensor::TensorDialect>();
+ scf::SCFDialect, tensor::TensorDialect>();
}
DispatchLinalgOnTensorsPass() = default;
DispatchLinalgOnTensorsPass(const DispatchLinalgOnTensorsPass &pass) {}
@@ -1109,7 +1109,9 @@
// update subtensor_insert ops will be turned into flow dispatch output
// store ops.
tensor::InsertSliceOp::getCanonicalizationPatterns(patterns, context);
- (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
+ if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
+ return failure();
+ }
}
// After outlining in dispatch region we can rewrite the dispatch ops with
diff --git a/iree/compiler/Dialect/Flow/Transforms/FusionOfTensorOps.cpp b/iree/compiler/Dialect/Flow/Transforms/FusionOfTensorOps.cpp
index a969b19..f6f902b 100644
--- a/iree/compiler/Dialect/Flow/Transforms/FusionOfTensorOps.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/FusionOfTensorOps.cpp
@@ -127,8 +127,10 @@
.setControlFoldingReshapes(foldReshapeBetweenLinalgFn)
.setControlElementwiseOpsFusionFn(controlFn));
- (void)applyPatternsAndFoldGreedily(op->getRegions(),
- std::move(fusionPatterns));
+ if (failed(applyPatternsAndFoldGreedily(op->getRegions(),
+ std::move(fusionPatterns)))) {
+ return signalPassFailure();
+ }
OwningRewritePatternList reshapeCanonicalizations(&getContext());
linalg::populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
@@ -137,8 +139,10 @@
reshapeCanonicalizations, context);
linalg::TensorExpandShapeOp::getCanonicalizationPatterns(
reshapeCanonicalizations, context);
- (void)applyPatternsAndFoldGreedily(op->getRegions(),
- std::move(reshapeCanonicalizations));
+ if (failed(applyPatternsAndFoldGreedily(
+ op->getRegions(), std::move(reshapeCanonicalizations)))) {
+ return signalPassFailure();
+ }
// Push the remaining reshapes down the graphs.
OwningRewritePatternList pushReshapePatterns(&getContext());
@@ -147,8 +151,10 @@
pushReshapePatterns, context);
linalg::TensorExpandShapeOp::getCanonicalizationPatterns(
pushReshapePatterns, context);
- (void)applyPatternsAndFoldGreedily(op->getRegions(),
- std::move(pushReshapePatterns));
+ if (failed(applyPatternsAndFoldGreedily(op->getRegions(),
+ std::move(pushReshapePatterns)))) {
+ return signalPassFailure();
+ }
}
};
diff --git a/iree/compiler/Dialect/Flow/Transforms/InterchangeGenericOps.cpp b/iree/compiler/Dialect/Flow/Transforms/InterchangeGenericOps.cpp
index b091d01..caead1a 100644
--- a/iree/compiler/Dialect/Flow/Transforms/InterchangeGenericOps.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/InterchangeGenericOps.cpp
@@ -62,7 +62,10 @@
void runOnOperation() override {
OwningRewritePatternList patterns(&getContext());
patterns.add<GenericOpInterchangePattern>(&getContext());
- (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+ if (failed(applyPatternsAndFoldGreedily(getOperation(),
+ std::move(patterns)))) {
+ return signalPassFailure();
+ }
}
};
diff --git a/iree/compiler/Dialect/Flow/Transforms/OutlineDispatchRegions.cpp b/iree/compiler/Dialect/Flow/Transforms/OutlineDispatchRegions.cpp
index 8c63a11..8117c55 100644
--- a/iree/compiler/Dialect/Flow/Transforms/OutlineDispatchRegions.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/OutlineDispatchRegions.cpp
@@ -9,9 +9,6 @@
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
#include "iree/compiler/Dialect/Flow/Transforms/PassDetail.h"
#include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
-#include "iree/compiler/Dialect/Shape/IR/Builders.h"
-#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
-#include "iree/compiler/Dialect/Shape/IR/ShapeTypes.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BlockAndValueMapping.h"
diff --git a/iree/compiler/Dialect/Flow/Transforms/PadLinalgOps.cpp b/iree/compiler/Dialect/Flow/Transforms/PadLinalgOps.cpp
index 356e268..e9b5514 100644
--- a/iree/compiler/Dialect/Flow/Transforms/PadLinalgOps.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/PadLinalgOps.cpp
@@ -139,7 +139,10 @@
MLIRContext *context = &getContext();
OwningRewritePatternList patterns(context);
patterns.insert<PadMatmulOp>(context, paddingSize);
- (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+ if (failed(applyPatternsAndFoldGreedily(getOperation(),
+ std::move(patterns)))) {
+ return signalPassFailure();
+ }
}
private:
diff --git a/iree/compiler/Dialect/Flow/Transforms/Passes.cpp b/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
index f0b42c6..e59f394 100644
--- a/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
@@ -8,7 +8,6 @@
#include <memory>
-#include "iree/compiler/Dialect/Shape/Transforms/Passes.h"
#include "iree/compiler/Dialect/Util/Transforms/Passes.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
diff --git a/iree/compiler/Dialect/Flow/Transforms/PromoteI1ToI8Pass.cpp b/iree/compiler/Dialect/Flow/Transforms/PromoteI1ToI8Pass.cpp
index 37f8d60..7919e68 100644
--- a/iree/compiler/Dialect/Flow/Transforms/PromoteI1ToI8Pass.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/PromoteI1ToI8Pass.cpp
@@ -103,7 +103,10 @@
void runOnOperation() override {
OwningRewritePatternList patterns(&getContext());
patterns.insert<ConvertBoolConstantPattern>(&getContext());
- (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+ if (failed(applyPatternsAndFoldGreedily(getOperation(),
+ std::move(patterns)))) {
+ return signalPassFailure();
+ }
}
};
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir b/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir
index d986c48..fa7e415 100644
--- a/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir
+++ b/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir
@@ -855,7 +855,7 @@
%6 = linalg.matmul ins(%0, %1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%5 : tensor<?x?xf32>) -> tensor<?x?xf32>
%7 = tensor.dim %6, %c0 : tensor<?x?xf32>
%8 = tensor.dim %6, %c1 : tensor<?x?xf32>
- %9 = hal.tensor.cast %6 : tensor<?x?xf32>{%7, %8} -> !hal.buffer_view
+ %9 = hal.tensor.export %6 : tensor<?x?xf32>{%7, %8} -> !hal.buffer_view
return %9 : !hal.buffer_view
}
// CHECK-LABEL: func @dynamic_dot()
diff --git a/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/BUILD b/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/BUILD
index fcd85f5..27740f7 100644
--- a/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/BUILD
+++ b/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/BUILD
@@ -25,7 +25,6 @@
"//iree/compiler/Dialect/HAL/IR",
"//iree/compiler/Dialect/HAL/IR:HALDialect",
"//iree/compiler/Dialect/HAL/Target",
- "//iree/compiler/Dialect/Shape/IR",
"//iree/compiler/Dialect/Util/Conversion",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
diff --git a/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/CMakeLists.txt b/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/CMakeLists.txt
index 80028c5..d777f0b 100644
--- a/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/CMakeLists.txt
+++ b/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/CMakeLists.txt
@@ -32,7 +32,6 @@
iree::compiler::Dialect::HAL::IR
iree::compiler::Dialect::HAL::IR::HALDialect
iree::compiler::Dialect::HAL::Target
- iree::compiler::Dialect::Shape::IR
iree::compiler::Dialect::Util::Conversion
PUBLIC
)
diff --git a/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/ConvertStreamToHAL.cpp b/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/ConvertStreamToHAL.cpp
index 13169bb..033bed2 100644
--- a/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/ConvertStreamToHAL.cpp
+++ b/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/ConvertStreamToHAL.cpp
@@ -1100,6 +1100,62 @@
}
};
+struct TimepointImportOpPattern
+ : public StreamConversionPattern<IREE::Stream::TimepointImportOp> {
+ using StreamConversionPattern::StreamConversionPattern;
+ LogicalResult matchAndRewrite(
+ IREE::Stream::TimepointImportOp importOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ // Only handle imports from HAL semaphores.
+ auto operands = adaptor.operands();
+ if (operands.size() != 2 ||
+ !operands[0].getType().isa<IREE::HAL::SemaphoreType>() ||
+ !operands[1].getType().isIntOrIndex()) {
+ return rewriter.notifyMatchFailure(importOp,
+ "only imports from HAL semaphore + "
+ "sequence value tuples are supported");
+ }
+
+ // TODO(benvanik): model timepoints as semaphores.
+ // For now we just block on the semaphore.
+ auto awaitOp = rewriter.create<IREE::HAL::SemaphoreAwaitOp>(
+ importOp.getLoc(), rewriter.getI32Type(), operands[0], operands[1]);
+ rewriter.create<IREE::Util::StatusCheckOkOp>(
+ importOp.getLoc(), awaitOp.status(),
+ "failed to wait on imported semaphore");
+ rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(importOp, 0);
+ return success();
+ }
+};
+
+struct TimepointExportOpPattern
+ : public StreamConversionPattern<IREE::Stream::TimepointExportOp> {
+ using StreamConversionPattern::StreamConversionPattern;
+ LogicalResult matchAndRewrite(
+ IREE::Stream::TimepointExportOp exportOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ // Only handle exports into HAL semaphores.
+ if (exportOp.getNumResults() != 2 ||
+ !exportOp.getResult(0).getType().isa<IREE::HAL::SemaphoreType>() ||
+ !exportOp.getResult(1).getType().isIntOrIndex()) {
+ return rewriter.notifyMatchFailure(exportOp,
+ "only exports to HAL semaphore + "
+ "sequence value tuples are supported");
+ }
+
+ auto loc = exportOp.getLoc();
+ auto device = lookupDeviceFor(exportOp, rewriter);
+
+ // TODO(benvanik): model timepoints as semaphores.
+ // For now we just create a signaled semaphore.
+ auto exportValue = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ auto exportSemaphore = rewriter.create<IREE::HAL::SemaphoreCreateOp>(
+ loc, rewriter.getType<IREE::HAL::SemaphoreType>(), device, exportValue);
+ rewriter.replaceOp(exportOp, {exportSemaphore, exportValue});
+ return success();
+ }
+};
+
struct TimepointJoinOpPattern
: public StreamConversionPattern<IREE::Stream::TimepointJoinOp> {
using StreamConversionPattern::StreamConversionPattern;
@@ -1197,7 +1253,8 @@
CmdFillOpPattern, CmdCopyOpPattern, CmdDispatchOpPattern,
CmdExecuteOpPattern, CmdSerialOpPattern, CmdConcurrentOpPattern>(
mapping, typeConverter, context);
- patterns.insert<TimepointImmediateOpPattern, TimepointJoinOpPattern,
+ patterns.insert<TimepointImmediateOpPattern, TimepointImportOpPattern,
+ TimepointExportOpPattern, TimepointJoinOpPattern,
TimepointAwaitOpPattern>(mapping, typeConverter, context);
patterns.insert<ElideYieldOpPattern>(mapping, typeConverter, context);
}
diff --git a/iree/compiler/Dialect/HAL/Conversion/TypeConverter.cpp b/iree/compiler/Dialect/HAL/Conversion/TypeConverter.cpp
index e7796b4..950d91d 100644
--- a/iree/compiler/Dialect/HAL/Conversion/TypeConverter.cpp
+++ b/iree/compiler/Dialect/HAL/Conversion/TypeConverter.cpp
@@ -42,7 +42,7 @@
ValueRange inputs, Location loc) -> Value {
assert(inputs.size() == 1);
if (inputs[0].getType().isa<TensorType>()) {
- return builder.create<IREE::HAL::TensorCastOp>(loc, type, inputs[0]);
+ return builder.create<IREE::HAL::TensorExportOp>(loc, type, inputs[0]);
} else if (inputs[0].getType().isa<IREE::HAL::BufferViewType>()) {
return builder.create<IREE::HAL::BufferViewBufferOp>(loc, type,
inputs[0]);
@@ -59,7 +59,7 @@
auto inputValue = inputs[0];
auto inputType = inputValue.getType();
if (inputType.isa<TensorType>()) {
- return builder.create<IREE::HAL::TensorCastOp>(loc, type, inputValue);
+ return builder.create<IREE::HAL::TensorExportOp>(loc, type, inputValue);
} else if (inputType.isa<IREE::HAL::BufferType>()) {
// Look for the buffer view this buffer came from, if any.
// If we don't have the origin buffer view then we can't know the shape
diff --git a/iree/compiler/Dialect/HAL/Conversion/TypeConverter.h b/iree/compiler/Dialect/HAL/Conversion/TypeConverter.h
index e6c555a..4c57d9f 100644
--- a/iree/compiler/Dialect/HAL/Conversion/TypeConverter.h
+++ b/iree/compiler/Dialect/HAL/Conversion/TypeConverter.h
@@ -24,7 +24,7 @@
// TODO(benvanik): signature conversion for output buffers.
// Returns true if the given |type| maps to !hal.buffer_view by default.
- // hal.tensor.cast can be used by frontends to map to other types.
+ // hal.tensor.import/export can be used by frontends to map to other types.
static bool shouldConvertToBufferView(Type type) {
if (auto tensorType = type.template dyn_cast<TensorType>()) {
return tensorType.getElementType().isIntOrFloat();
diff --git a/iree/compiler/Dialect/HAL/IR/BUILD b/iree/compiler/Dialect/HAL/IR/BUILD
index 56c1ffe..5bd260c 100644
--- a/iree/compiler/Dialect/HAL/IR/BUILD
+++ b/iree/compiler/Dialect/HAL/IR/BUILD
@@ -31,7 +31,6 @@
include = ["*.td"],
),
deps = [
- "//iree/compiler/Dialect/Shape/IR:td_files",
"//iree/compiler/Dialect/Util/IR:td_files",
"@llvm-project//mlir:OpBaseTdFiles",
"@llvm-project//mlir:StdOpsTdFiles",
@@ -73,7 +72,6 @@
":HALOpsGen",
":HALStructsGen",
":HALTypesGen",
- "//iree/compiler/Dialect/Shape/IR",
"//iree/compiler/Dialect/Util/IR",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
@@ -95,7 +93,6 @@
":IR",
"//iree/compiler/Dialect/HAL:hal_imports",
"//iree/compiler/Dialect/HAL/Conversion/HALToVM",
- "//iree/compiler/Dialect/Shape/IR",
"//iree/compiler/Dialect/Util/IR",
"//iree/compiler/Dialect/VM/Conversion",
"@llvm-project//llvm:Support",
diff --git a/iree/compiler/Dialect/HAL/IR/CMakeLists.txt b/iree/compiler/Dialect/HAL/IR/CMakeLists.txt
index 75d123a..be77535 100644
--- a/iree/compiler/Dialect/HAL/IR/CMakeLists.txt
+++ b/iree/compiler/Dialect/HAL/IR/CMakeLists.txt
@@ -51,7 +51,6 @@
MLIRSupport
MLIRTransformUtils
MLIRViewLikeInterface
- iree::compiler::Dialect::Shape::IR
iree::compiler::Dialect::Util::IR
PUBLIC
)
@@ -73,7 +72,6 @@
MLIRTransformUtils
iree::compiler::Dialect::HAL::Conversion::HALToVM
iree::compiler::Dialect::HAL::hal_imports
- iree::compiler::Dialect::Shape::IR
iree::compiler::Dialect::Util::IR
iree::compiler::Dialect::VM::Conversion
PUBLIC
diff --git a/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp b/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp
index f1833e6..cd3eda8 100644
--- a/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp
+++ b/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp
@@ -25,23 +25,24 @@
namespace HAL {
//===----------------------------------------------------------------------===//
-// hal.tensor.cast
+// hal.tensor.import/export
//===----------------------------------------------------------------------===//
-OpFoldResult TensorCastOp::fold(ArrayRef<Attribute> operands) {
- if (source().getType() == target().getType()) {
- return source();
+OpFoldResult TensorImportOp::fold(ArrayRef<Attribute> operands) {
+ if (auto exportOp = source().getDefiningOp<TensorExportOp>()) {
+ if (exportOp.source().getType() == target().getType()) {
+ return exportOp.source();
+ }
}
+ return {};
+}
- // Cast of a cast can use the defining op's source.
- // This can apply recursively and may bottom out at source == target type.
- if (auto castOp = source().getDefiningOp<TensorCastOp>()) {
- auto mutableSource = sourceMutable();
- mutableSource.clear();
- mutableSource.append(castOp.source());
- return getResult();
+OpFoldResult TensorExportOp::fold(ArrayRef<Attribute> operands) {
+ if (auto importOp = source().getDefiningOp<TensorImportOp>()) {
+ if (importOp.source().getType() == target().getType()) {
+ return importOp.source();
+ }
}
-
return {};
}
diff --git a/iree/compiler/Dialect/HAL/IR/HALOps.cpp b/iree/compiler/Dialect/HAL/IR/HALOps.cpp
index 2257845..db46ee3 100644
--- a/iree/compiler/Dialect/HAL/IR/HALOps.cpp
+++ b/iree/compiler/Dialect/HAL/IR/HALOps.cpp
@@ -7,7 +7,6 @@
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
-#include "iree/compiler/Dialect/Shape/IR/Builders.h"
#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
#include "llvm/ADT/Hashing.h"
#include "llvm/ADT/STLExtras.h"
@@ -180,74 +179,56 @@
}
//===----------------------------------------------------------------------===//
-// hal.tensor.cast
+// hal.tensor.import/export
//===----------------------------------------------------------------------===//
-void TensorCastOp::build(OpBuilder &builder, OperationState &result,
- Type resultType, Value source,
- ArrayRef<NamedAttribute> attrs) {
+void TensorImportOp::build(OpBuilder &builder, OperationState &result,
+ Type resultType, Value source) {
+ auto shapedType = resultType.cast<ShapedType>();
+ assert((source.getType().isa<IREE::HAL::BufferViewType>() ||
+ shapedType.hasStaticShape()) &&
+ "can only use this constructor for buffer views when shape "
+ "information is required");
SmallVector<Value> dynamicDims;
- if (source.getType().isa<IREE::HAL::BufferViewType>()) {
- auto shapedType = resultType.cast<ShapedType>();
- for (int64_t i = 0; i < shapedType.getRank(); ++i) {
- if (!shapedType.isDynamicDim(i)) continue;
- dynamicDims.push_back(builder.createOrFold<IREE::HAL::BufferViewDimOp>(
- result.location, builder.getIndexType(), source,
- builder.getIndexAttr(i)));
- }
- } else {
- dynamicDims =
- Shape::buildOrFindDynamicDimsForValue(result.location, source, builder);
+ for (int64_t i = 0; i < shapedType.getRank(); ++i) {
+ if (!shapedType.isDynamicDim(i)) continue;
+ dynamicDims.push_back(builder.createOrFold<IREE::HAL::BufferViewDimOp>(
+ result.location, builder.getIndexType(), source,
+ builder.getIndexAttr(i)));
}
- build(builder, result, resultType, source, dynamicDims, attrs);
+ build(builder, result, resultType, source, dynamicDims);
}
-void TensorCastOp::build(OpBuilder &builder, OperationState &result,
- Type resultType, Value source, ValueRange dynamicDims,
- ArrayRef<NamedAttribute> attrs) {
- result.addTypes({resultType});
- result.addOperands({source});
- result.addOperands({dynamicDims});
- result.addAttributes(attrs);
- result.addAttribute(
- "operand_segment_sizes",
- builder.getI32VectorAttr({
- static_cast<int32_t>(1),
- static_cast<int32_t>(
- source.getType().isa<TensorType>() ? dynamicDims.size() : 0),
- static_cast<int32_t>(resultType.isa<TensorType>() ? dynamicDims.size()
- : 0),
- }));
-}
-
-Value TensorCastOp::buildOperandRankedShape(unsigned idx, OpBuilder &builder) {
- if (source().getType().isa<TensorType>()) {
- return Shape::buildRankedShapeForValue(getLoc(), source(), source_dims(),
- builder);
- } else {
- return buildResultRankedShape(idx, builder);
- }
-}
-
-Value TensorCastOp::buildResultRankedShape(unsigned idx, OpBuilder &builder) {
- if (target().getType().isa<TensorType>()) {
- return Shape::buildRankedShapeForValue(getLoc(), target(), target_dims(),
- builder);
- } else {
- return buildOperandRankedShape(idx, builder);
- }
-}
-
-Value TensorCastOp::getTiedResult(unsigned resultIndex) {
+Value TensorImportOp::getTiedResult(unsigned resultIndex) {
return IREE::Util::TiedOpInterface::findTiedBaseValue(source());
}
-::llvm::Optional<unsigned> TensorCastOp::getTiedResultOperandIndex(
+::llvm::Optional<unsigned> TensorImportOp::getTiedResultOperandIndex(
unsigned resultIndex) {
return {0}; // source
}
-SmallVector<int64_t, 4> TensorCastOp::getTiedResultOperandIndices() {
+SmallVector<int64_t, 4> TensorImportOp::getTiedResultOperandIndices() {
+ return {0}; // source
+}
+
+void TensorExportOp::build(OpBuilder &builder, OperationState &result,
+ Type resultType, Value source) {
+ auto dynamicDims =
+ IREE::Util::buildDynamicDimsForValue(result.location, source, builder);
+ build(builder, result, resultType, source, dynamicDims);
+}
+
+Value TensorExportOp::getTiedResult(unsigned resultIndex) {
+ return IREE::Util::TiedOpInterface::findTiedBaseValue(source());
+}
+
+::llvm::Optional<unsigned> TensorExportOp::getTiedResultOperandIndex(
+ unsigned resultIndex) {
+ return {0}; // source
+}
+
+SmallVector<int64_t, 4> TensorExportOp::getTiedResultOperandIndices() {
return {0}; // source
}
@@ -926,17 +907,6 @@
SymbolTable::lookupNearestSymbolFrom(getOperation(), binding()));
}
-Value InterfaceBindingSubspanOp::buildOperandRankedShape(unsigned idx,
- OpBuilder &builder) {
- return {};
-}
-
-Value InterfaceBindingSubspanOp::buildResultRankedShape(unsigned idx,
- OpBuilder &builder) {
- return Shape::buildRankedShapeForValue(getLoc(), result(), dynamic_dims(),
- builder);
-}
-
// TODO(benvanik): share with align op folder and analysis.
// May need an interface for querying the alignment from ops that can carry it.
diff --git a/iree/compiler/Dialect/HAL/IR/HALOps.h b/iree/compiler/Dialect/HAL/IR/HALOps.h
index c16c8a7..b417181 100644
--- a/iree/compiler/Dialect/HAL/IR/HALOps.h
+++ b/iree/compiler/Dialect/HAL/IR/HALOps.h
@@ -11,8 +11,6 @@
#include "iree/compiler/Dialect/HAL/IR/HALTraits.h"
#include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
-#include "iree/compiler/Dialect/Shape/IR/ShapeDialect.h"
-#include "iree/compiler/Dialect/Shape/IR/ShapeTypes.h"
#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
#include "iree/compiler/Dialect/Util/IR/UtilTraits.h"
#include "llvm/Support/Alignment.h"
diff --git a/iree/compiler/Dialect/HAL/IR/HALOps.td b/iree/compiler/Dialect/HAL/IR/HALOps.td
index 51f37cd..7a5e182 100644
--- a/iree/compiler/Dialect/HAL/IR/HALOps.td
+++ b/iree/compiler/Dialect/HAL/IR/HALOps.td
@@ -11,7 +11,6 @@
include "iree/compiler/Dialect/HAL/IR/HALInterfaces.td"
include "iree/compiler/Dialect/Util/IR/UtilAttrs.td"
include "iree/compiler/Dialect/Util/IR/UtilInterfaces.td"
-include "iree/compiler/Dialect/Shape/IR/ShapeInterfaces.td"
include "mlir/IR/OpAsmInterface.td"
include "mlir/IR/SymbolInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
@@ -59,56 +58,94 @@
// Pseudo ops for conversion support
//===----------------------------------------------------------------------===//
-def HAL_TensorCastOp : HAL_PureOp<"tensor.cast", [
- AttrSizedOperandSegments,
- DeclareOpInterfaceMethods<Util_TiedOpInterface, [
- "getTiedResult",
- "getTiedResultOperandIndex",
- "getTiedResultOperandIndices",
- ]>,
- DeclareOpInterfaceMethods<Shape_ShapeCarryingOpInterface>,
- Util_ShapeAwareOp,
- ]> {
- let summary = [{conversion placeholder for HAL<->tensor type conversion}];
+def HAL_TensorImportOp : HAL_PureOp<"tensor.import", [
+ DeclareOpInterfaceMethods<Util_TiedOpInterface, [
+ "getTiedResult",
+ "getTiedResultOperandIndex",
+ "getTiedResultOperandIndices",
+ ]>,
+ Util_ShapeAwareOp,
+]> {
+ let summary = [{imports a tensor from a HAL buffer view}];
let description = [{
- Defines a conversion from a higher-level dialect type such as `tensor` that
- is resolved during lowering into the HAL. This can be used to interoperate
- between levels of the stack that require specifying HAL types and those that
- prior to lowering do not handle them.
+ Defines an import of an external HAL buffer view into an SSA-form tensor.
+ An optional semaphore timepoint can be specified indicating when the
+ buffer view is available for use. If no semaphore timepoint is provided,
+ the buffer view is assumed to be immediately available.
}];
let arguments = (ins
- AnyType:$source,
- HAL_ShapeDynamicDims:$source_dims,
+ AnyTypeOf<[HAL_Buffer, HAL_BufferView]>:$source,
HAL_ShapeDynamicDims:$target_dims
);
let results = (outs
- AnyType:$target
+ AnyTensor:$target
);
let assemblyFormat = [{
- $source `:`
- type($source) (`{` $source_dims^ `}`)? `->`
+ $source `:` type($source)
+ `->`
type($target) (`{` $target_dims^ `}`)?
attr-dict-with-keyword
}];
- let skipDefaultBuilders = 1;
let builders = [
OpBuilder<(ins
"Type":$resultType,
- "Value":$source,
- CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>,
+ "Value":$source
+ )>,
+ ];
+
+ let extraClassDeclaration = [{
+ ValueRange getOperandDynamicDims(unsigned idx) { return {}; }
+ ValueRange getResultDynamicDims(unsigned idx) { return target_dims(); }
+ }];
+
+ let hasFolder = 1;
+}
+
+def HAL_TensorExportOp : HAL_PureOp<"tensor.export", [
+ DeclareOpInterfaceMethods<Util_TiedOpInterface, [
+ "getTiedResult",
+ "getTiedResultOperandIndex",
+ "getTiedResultOperandIndices",
+ ]>,
+ Util_ShapeAwareOp,
+]> {
+ let summary = [{exports a tensor to a HAL buffer view}];
+ let description = [{
+ Defines an export of an SSA-form tensor to an external HAL buffer view.
+ An optional semaphore timepoint can be specified indicating when the
+ buffer view is available for use. If no semaphore timepoint is requested,
+ execution is assumed to block until the buffer view is available.
+ }];
+
+ let arguments = (ins
+ AnyTensor:$source,
+ HAL_ShapeDynamicDims:$source_dims
+ );
+ let results = (outs
+ AnyTypeOf<[HAL_Buffer, HAL_BufferView]>:$target
+ );
+
+ let assemblyFormat = [{
+ $source `:`
+ type($source) (`{` $source_dims^ `}`)?
+ `->`
+ type($target)
+ attr-dict-with-keyword
+ }];
+
+ let builders = [
OpBuilder<(ins
"Type":$resultType,
- "Value":$source,
- "ValueRange":$dynamicDims,
- CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>,
+ "Value":$source
+ )>,
];
let extraClassDeclaration = [{
ValueRange getOperandDynamicDims(unsigned idx) { return source_dims(); }
- ValueRange getResultDynamicDims(unsigned idx) { return target_dims(); }
+ ValueRange getResultDynamicDims(unsigned idx) { return {}; }
}];
let hasFolder = 1;
@@ -1919,7 +1956,6 @@
def HAL_InterfaceBindingSubspanOp : HAL_Op<"interface.binding.subspan", [
AttrSizedOperandSegments, MemoryEffects<[MemAlloc]>,
- DeclareOpInterfaceMethods<Shape_ShapeCarryingOpInterface>,
Util_ShapeAwareOp,
]> {
let summary = [{returns an alias to a subspan of interface binding data}];
@@ -2125,8 +2161,6 @@
}];
}
-// TODO(benvanik): rework this to make it a terminator with branch targets
-// for ^success and ^fail(status).
def HAL_SemaphoreAwaitOp : HAL_Op<"semaphore.await", [YieldPoint]> {
let summary = [{asynchronous semaphore wait operation}];
let description = [{
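
For downstream conversion code, a short sketch (not part of this patch; placeholder function names) of how the split pseudo ops that replace hal.tensor.cast are created with the single-value builders defined above. On import the builder derives dynamic dims from hal.buffer_view.dim ops; on export it uses IREE::Util::buildDynamicDimsForValue.

    // Sketch only: creating the split ops via the builders added above.
    Value importToTensor(OpBuilder &builder, Location loc, Value bufferView,
                         RankedTensorType tensorType) {
      // !hal.buffer_view -> tensor<...>
      return builder.create<IREE::HAL::TensorImportOp>(loc, tensorType,
                                                       bufferView);
    }

    Value exportToBufferView(OpBuilder &builder, Location loc, Value tensor) {
      // tensor<...> -> !hal.buffer_view
      return builder.create<IREE::HAL::TensorExportOp>(
          loc, builder.getType<IREE::HAL::BufferViewType>(), tensor);
    }
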
diff --git a/iree/compiler/Dialect/HAL/IR/test/tensor_op_folding.mlir b/iree/compiler/Dialect/HAL/IR/test/tensor_op_folding.mlir
index e68d76a..8b2dd07 100644
--- a/iree/compiler/Dialect/HAL/IR/test/tensor_op_folding.mlir
+++ b/iree/compiler/Dialect/HAL/IR/test/tensor_op_folding.mlir
@@ -1,41 +1,39 @@
// RUN: iree-opt -split-input-file -canonicalize -cse %s | iree-opt -split-input-file | IreeFileCheck %s
-// CHECK-LABEL: @tensorCastMatchingTypeFolds
-func @tensorCastMatchingTypeFolds(%arg0: !hal.buffer_view) -> !hal.buffer_view {
- // CHECK-NOT: hal.tensor.cast
+// CHECK-LABEL: @foldTensorImportExport
+func @foldTensorImportExport(%arg0: !hal.buffer_view) -> !hal.buffer_view {
+ // CHECK-NOT: hal.tensor.import
+ %0 = hal.tensor.import %arg0 : !hal.buffer_view -> tensor<5xi32>
+ // CHECK-NOT: hal.tensor.export
+ %1 = hal.tensor.export %0 : tensor<5xi32> -> !hal.buffer_view
// CHECK: return %arg0 : !hal.buffer_view
- %0 = hal.tensor.cast %arg0 : !hal.buffer_view -> !hal.buffer_view
- return %0 : !hal.buffer_view
-}
-
-// -----
-
-// CHECK-LABEL: @tensorCastPassthroughFolds
-func @tensorCastPassthroughFolds(%arg0: !hal.buffer_view) -> !hal.buffer_view {
- // CHECK-NOT: hal.tensor.cast
- // CHECK: return %arg0 : !hal.buffer_view
- %0 = hal.tensor.cast %arg0 : !hal.buffer_view -> tensor<5xi32>
- %1 = hal.tensor.cast %0 : tensor<5xi32> -> !hal.buffer_view
return %1 : !hal.buffer_view
}
// -----
-// CHECK-LABEL: @tensorCastThroughDifferentTypesFolds
-func @tensorCastThroughDifferentTypesFolds(%arg0: !hal.buffer_view) -> !hal.buffer {
- // CHECK: %0 = hal.tensor.cast %arg0 : !hal.buffer_view -> !hal.buffer
- // CHECK: return %0 : !hal.buffer
- %0 = hal.tensor.cast %arg0 : !hal.buffer_view -> tensor<5xi32>
- %1 = hal.tensor.cast %0 : tensor<5xi32> -> !hal.buffer
+// TODO(benvanik): add a canonicalizer to take buffer_view -> buffer and turn
+// it into a hal.buffer_view.buffer op and buffer -> buffer_view into a
+// hal.buffer_view.create.
+// For now we just don't fold.
+
+// CHECK-LABEL: @foldTensorImportExportTypeMismatch
+func @foldTensorImportExportTypeMismatch(%arg0: !hal.buffer_view) -> !hal.buffer {
+ // CHECK: hal.tensor.import
+ %0 = hal.tensor.import %arg0 : !hal.buffer_view -> tensor<5xi32>
+ // CHECK: hal.tensor.export
+ %1 = hal.tensor.export %0 : tensor<5xi32> -> !hal.buffer
return %1 : !hal.buffer
}
// -----
-// CHECK-LABEL: @tensorCastFoldingPreservesDims
-func @tensorCastFoldingPreservesDims(%arg0: !hal.buffer_view, %arg1 : index) -> tensor<?x3xi32> {
- // CHECK: hal.tensor.cast %arg0 : !hal.buffer_view -> tensor<?x3xi32>{%arg1}
- %0 = hal.tensor.cast %arg0 : !hal.buffer_view -> !hal.buffer
- %1 = hal.tensor.cast %0 : !hal.buffer -> tensor<?x3xi32>{%arg1}
- return %1 : tensor<?x3xi32>
+// CHECK-LABEL: @foldTensorExportImport
+func @foldTensorExportImport(%arg0: tensor<5xi32>) -> tensor<5xi32> {
+ // CHECK-NOT: hal.tensor.export
+ %0 = hal.tensor.export %arg0 : tensor<5xi32> -> !hal.buffer_view
+ // CHECK-NOT: hal.tensor.import
+ %1 = hal.tensor.import %0 : !hal.buffer_view -> tensor<5xi32>
+ // CHECK: return %arg0 : tensor<5xi32>
+ return %1 : tensor<5xi32>
}
diff --git a/iree/compiler/Dialect/HAL/IR/test/tensor_ops.mlir b/iree/compiler/Dialect/HAL/IR/test/tensor_ops.mlir
index 8d396db..70dc2cb 100644
--- a/iree/compiler/Dialect/HAL/IR/test/tensor_ops.mlir
+++ b/iree/compiler/Dialect/HAL/IR/test/tensor_ops.mlir
@@ -1,22 +1,26 @@
// RUN: iree-opt -split-input-file %s | iree-opt -split-input-file | IreeFileCheck %s
-// CHECK-LABEL: @tensorCastStatic
-func @tensorCastStatic(%arg0: !hal.buffer_view) -> tensor<5xi32> {
- // CHECK: hal.tensor.cast %arg0 : !hal.buffer_view -> tensor<5xi32>
- %0 = hal.tensor.cast %arg0 : !hal.buffer_view -> tensor<5xi32>
+// CHECK-LABEL: @tensorImportStatic
+func @tensorImportStatic(%arg0: !hal.buffer_view) -> tensor<5xi32> {
+ // CHECK: hal.tensor.import %arg0 : !hal.buffer_view -> tensor<5xi32>
+ %0 = hal.tensor.import %arg0 : !hal.buffer_view -> tensor<5xi32>
return %0 : tensor<5xi32>
}
-// CHECK-LABEL: @tensorCastDynamicInput
-func @tensorCastDynamicInput(%arg0: tensor<?x3xi32>, %arg1 : index) -> !hal.buffer_view {
- // CHECK: hal.tensor.cast %arg0 : tensor<?x3xi32>{%arg1} -> !hal.buffer_view
- %0 = hal.tensor.cast %arg0 : tensor<?x3xi32>{%arg1} -> !hal.buffer_view
- return %0 : !hal.buffer_view
+// -----
+
+// CHECK-LABEL: @tensorImportDynamic
+func @tensorImportDynamic(%arg0: !hal.buffer_view, %arg1 : index) -> tensor<?x3xi32> {
+ // CHECK: hal.tensor.import %arg0 : !hal.buffer_view -> tensor<?x3xi32>{%arg1}
+ %0 = hal.tensor.import %arg0 : !hal.buffer_view -> tensor<?x3xi32>{%arg1}
+ return %0 : tensor<?x3xi32>
}
-// CHECK-LABEL: @tensorCastDynamicOutput
-func @tensorCastDynamicOutput(%arg0: !hal.buffer_view, %arg1 : index) -> tensor<?x3xi32> {
- // CHECK: hal.tensor.cast %arg0 : !hal.buffer_view -> tensor<?x3xi32>{%arg1}
- %0 = hal.tensor.cast %arg0 : !hal.buffer_view -> tensor<?x3xi32>{%arg1}
- return %0 : tensor<?x3xi32>
+// -----
+
+// CHECK-LABEL: @tensorExportDynamic
+func @tensorExportDynamic(%arg0: tensor<?x3xi32>, %arg1 : index) -> !hal.buffer_view {
+ // CHECK: hal.tensor.export %arg0 : tensor<?x3xi32>{%arg1} -> !hal.buffer_view
+ %0 = hal.tensor.export %arg0 : tensor<?x3xi32>{%arg1} -> !hal.buffer_view
+ return %0 : !hal.buffer_view
}
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/LLVMAOTTarget.cpp b/iree/compiler/Dialect/HAL/Target/LLVM/LLVMAOTTarget.cpp
index 2838be9..6e5bb5f 100644
--- a/iree/compiler/Dialect/HAL/Target/LLVM/LLVMAOTTarget.cpp
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/LLVMAOTTarget.cpp
@@ -103,10 +103,11 @@
// Verifies builtin bitcode is loaded correctly and appends it to |linker|.
//
// Example:
-// if (failed(linkBuiltinLibrary(loc, linker, "libfoo", loadLibFoo(...))))
+// if (failed(linkBuiltinLibrary(loc, linker, linkerFlag, targetMachine,
+// "libfoo", loadLibFoo(...))))
static LogicalResult linkBuiltinLibrary(
- Location loc, llvm::Linker &linker, llvm::TargetMachine *targetMachine,
- StringRef name,
+ Location loc, llvm::Linker &linker, llvm::Linker::Flags linkerFlag,
+ llvm::TargetMachine *targetMachine, StringRef name,
llvm::Expected<std::unique_ptr<llvm::Module>> bitcodeModuleValue) {
// Ensure the bitcode loaded correctly. It may fail if the LLVM version is
// incompatible.
@@ -122,9 +123,7 @@
// Link the bitcode into the base module. This will merge in any required
// symbols and override declarations that may exist.
- if (linker.linkInModule(
- std::move(bitcodeModule),
- llvm::Linker::OverrideFromSrc /*| llvm::Linker::LinkOnlyNeeded*/)) {
+ if (linker.linkInModule(std::move(bitcodeModule), linkerFlag)) {
return mlir::emitError(loc) << "failed to link " << name << " bitcode";
}
@@ -340,16 +339,20 @@
// Note that if producing a static library then the symbols we add must be
// weak such that we don't trigger ODR issues.
llvm::Linker moduleLinker(*llvmModule);
+
+ llvm::Linker::Flags linkerFlag = llvm::Linker::OverrideFromSrc;
+ if (options_.linkStatic) linkerFlag = llvm::Linker::LinkOnlyNeeded;
+
if (failed(linkBuiltinLibrary(
- variantOp.getLoc(), moduleLinker, targetMachine.get(), "libdevice",
- loadDeviceBitcode(targetMachine.get(), context)))) {
+ variantOp.getLoc(), moduleLinker, linkerFlag, targetMachine.get(),
+ "libdevice", loadDeviceBitcode(targetMachine.get(), context)))) {
return mlir::emitError(variantOp.getLoc())
<< "failed linking in builtin library for target triple '"
<< options_.targetTriple << "'";
}
if (failed(linkBuiltinLibrary(
- variantOp.getLoc(), moduleLinker, targetMachine.get(), "libmusl",
- loadMuslBitcode(targetMachine.get(), context)))) {
+ variantOp.getLoc(), moduleLinker, linkerFlag, targetMachine.get(),
+ "libmusl", loadMuslBitcode(targetMachine.get(), context)))) {
return mlir::emitError(variantOp.getLoc())
<< "failed linking in builtin library for target triple '"
<< options_.targetTriple << "'";
diff --git a/iree/compiler/Dialect/HAL/Transforms/ResolveEntryPointOrdinals.cpp b/iree/compiler/Dialect/HAL/Transforms/ResolveEntryPointOrdinals.cpp
index 50849dc..f69c1dd 100644
--- a/iree/compiler/Dialect/HAL/Transforms/ResolveEntryPointOrdinals.cpp
+++ b/iree/compiler/Dialect/HAL/Transforms/ResolveEntryPointOrdinals.cpp
@@ -87,7 +87,10 @@
OwningRewritePatternList patterns(&getContext());
patterns.insert<ResolveCommandBufferDispatchOrdinals>(context);
patterns.insert<ResolveCommandBufferDispatchIndirectOrdinals>(context);
- (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+ if (failed(applyPatternsAndFoldGreedily(getOperation(),
+ std::move(patterns)))) {
+ return signalPassFailure();
+ }
}
};
diff --git a/iree/compiler/Dialect/Modules/VMVX/Conversion/StandardToVMVX/BUILD b/iree/compiler/Dialect/Modules/VMVX/Conversion/StandardToVMVX/BUILD
index 07089d3..680fb46 100644
--- a/iree/compiler/Dialect/Modules/VMVX/Conversion/StandardToVMVX/BUILD
+++ b/iree/compiler/Dialect/Modules/VMVX/Conversion/StandardToVMVX/BUILD
@@ -21,7 +21,6 @@
deps = [
"//iree/compiler/Dialect/Modules/VMVX/IR",
"//iree/compiler/Dialect/Modules/VMVX/IR:VMVXDialect",
- "//iree/compiler/Dialect/Shape/IR",
"//iree/compiler/Dialect/Util/IR",
"@llvm-project//mlir:Affine",
"@llvm-project//mlir:IR",
diff --git a/iree/compiler/Dialect/Modules/VMVX/Conversion/StandardToVMVX/CMakeLists.txt b/iree/compiler/Dialect/Modules/VMVX/Conversion/StandardToVMVX/CMakeLists.txt
index b945209..0b0833b 100644
--- a/iree/compiler/Dialect/Modules/VMVX/Conversion/StandardToVMVX/CMakeLists.txt
+++ b/iree/compiler/Dialect/Modules/VMVX/Conversion/StandardToVMVX/CMakeLists.txt
@@ -29,7 +29,6 @@
MLIRTransforms
iree::compiler::Dialect::Modules::VMVX::IR
iree::compiler::Dialect::Modules::VMVX::IR::VMVXDialect
- iree::compiler::Dialect::Shape::IR
iree::compiler::Dialect::Util::IR
PUBLIC
)
diff --git a/iree/compiler/Dialect/Modules/VMVX/Conversion/StandardToVMVX/ConvertStandardToVMVX.cpp b/iree/compiler/Dialect/Modules/VMVX/Conversion/StandardToVMVX/ConvertStandardToVMVX.cpp
index e15e4b8..b7a96fb 100644
--- a/iree/compiler/Dialect/Modules/VMVX/Conversion/StandardToVMVX/ConvertStandardToVMVX.cpp
+++ b/iree/compiler/Dialect/Modules/VMVX/Conversion/StandardToVMVX/ConvertStandardToVMVX.cpp
@@ -9,9 +9,6 @@
#include "iree/compiler/Dialect/Modules/VMVX/IR/VMVXDialect.h"
#include "iree/compiler/Dialect/Modules/VMVX/IR/VMVXOps.h"
#include "iree/compiler/Dialect/Modules/VMVX/IR/VMVXTypes.h"
-#include "iree/compiler/Dialect/Shape/IR/Builders.h"
-#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
-#include "iree/compiler/Dialect/Shape/IR/ShapeTypes.h"
#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
diff --git a/iree/compiler/Dialect/Modules/VMVX/IR/BUILD b/iree/compiler/Dialect/Modules/VMVX/IR/BUILD
index 7a4bac4..47d1dd5 100644
--- a/iree/compiler/Dialect/Modules/VMVX/IR/BUILD
+++ b/iree/compiler/Dialect/Modules/VMVX/IR/BUILD
@@ -26,7 +26,6 @@
include = ["*.td"],
),
deps = [
- "//iree/compiler/Dialect/Shape/IR:td_files",
"//iree/compiler/Dialect/Util/IR:td_files",
"@llvm-project//mlir:OpBaseTdFiles",
"@llvm-project//mlir:StdOpsTdFiles",
@@ -55,7 +54,6 @@
":VMVXEnumsGen",
":VMVXOpInterfaceGen",
":VMVXOpsGen",
- "//iree/compiler/Dialect/Shape/IR",
"//iree/compiler/Dialect/Util/IR",
"//iree/compiler/Dialect/VM/IR",
"@llvm-project//llvm:Support",
diff --git a/iree/compiler/Dialect/Modules/VMVX/IR/CMakeLists.txt b/iree/compiler/Dialect/Modules/VMVX/IR/CMakeLists.txt
index 69f9df7..d48bdf9 100644
--- a/iree/compiler/Dialect/Modules/VMVX/IR/CMakeLists.txt
+++ b/iree/compiler/Dialect/Modules/VMVX/IR/CMakeLists.txt
@@ -37,7 +37,6 @@
MLIRSupport
MLIRTransformUtils
MLIRTranslation
- iree::compiler::Dialect::Shape::IR
iree::compiler::Dialect::Util::IR
iree::compiler::Dialect::VM::IR
PUBLIC
diff --git a/iree/compiler/Dialect/Modules/VMVX/Transforms/BUILD b/iree/compiler/Dialect/Modules/VMVX/Transforms/BUILD
index 0656d11..00b113b 100644
--- a/iree/compiler/Dialect/Modules/VMVX/Transforms/BUILD
+++ b/iree/compiler/Dialect/Modules/VMVX/Transforms/BUILD
@@ -29,8 +29,6 @@
"//iree/compiler/Dialect/Modules/VMVX/Conversion/StandardToVMVX",
"//iree/compiler/Dialect/Modules/VMVX/IR",
"//iree/compiler/Dialect/Modules/VMVX/IR:VMVXDialect",
- "//iree/compiler/Dialect/Shape/IR",
- "//iree/compiler/Dialect/Shape/Transforms",
"//iree/compiler/Dialect/Util/IR",
"//iree/compiler/Dialect/Util/Transforms",
"//iree/compiler/Dialect/VM/IR",
diff --git a/iree/compiler/Dialect/Modules/VMVX/Transforms/CMakeLists.txt b/iree/compiler/Dialect/Modules/VMVX/Transforms/CMakeLists.txt
index 9c100fe..226c34f 100644
--- a/iree/compiler/Dialect/Modules/VMVX/Transforms/CMakeLists.txt
+++ b/iree/compiler/Dialect/Modules/VMVX/Transforms/CMakeLists.txt
@@ -52,8 +52,6 @@
iree::compiler::Dialect::Modules::VMVX::Conversion::StandardToVMVX
iree::compiler::Dialect::Modules::VMVX::IR
iree::compiler::Dialect::Modules::VMVX::IR::VMVXDialect
- iree::compiler::Dialect::Shape::IR
- iree::compiler::Dialect::Shape::Transforms
iree::compiler::Dialect::Util::IR
iree::compiler::Dialect::Util::Transforms
iree::compiler::Dialect::VM::IR
diff --git a/iree/compiler/Dialect/Modules/VMVX/Transforms/Passes.cpp b/iree/compiler/Dialect/Modules/VMVX/Transforms/Passes.cpp
index 5456721..7b17fc9 100644
--- a/iree/compiler/Dialect/Modules/VMVX/Transforms/Passes.cpp
+++ b/iree/compiler/Dialect/Modules/VMVX/Transforms/Passes.cpp
@@ -11,7 +11,6 @@
#include "iree-dialects/Dialect/LinalgExt/Transforms/Passes.h"
#include "iree/compiler/Codegen/Passes.h"
#include "iree/compiler/Dialect/HAL/Transforms/Passes.h"
-#include "iree/compiler/Dialect/Shape/Transforms/Passes.h"
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
@@ -64,8 +63,6 @@
// Flatten and cleanup memrefs.
nestedModulePM.addNestedPass<FuncOp>(memref::createFoldSubViewOpsPass());
- nestedModulePM.addNestedPass<FuncOp>(
- Shape::createFoldDimOverShapeCarryingOpPass());
nestedModulePM.addPass(createCanonicalizerPass());
nestedModulePM.addPass(createCSEPass());
nestedModulePM.addPass(createFlattenMemRefSubspanPass());
diff --git a/iree/compiler/Dialect/Shape/BUILD b/iree/compiler/Dialect/Shape/BUILD
deleted file mode 100644
index 236a474..0000000
--- a/iree/compiler/Dialect/Shape/BUILD
+++ /dev/null
@@ -1,11 +0,0 @@
-# Copyright 2020 The IREE Authors
-#
-# Licensed under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-package(
- default_visibility = ["//visibility:public"],
- features = ["layering_check"],
- licenses = ["notice"], # Apache 2.0
-)
diff --git a/iree/compiler/Dialect/Shape/CMakeLists.txt b/iree/compiler/Dialect/Shape/CMakeLists.txt
deleted file mode 100644
index 4bd4fb4..0000000
--- a/iree/compiler/Dialect/Shape/CMakeLists.txt
+++ /dev/null
@@ -1,13 +0,0 @@
-################################################################################
-# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from #
-# iree/compiler/Dialect/Shape/BUILD #
-# #
-# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary #
-# CMake-only content. #
-# #
-# To disable autogeneration for this file entirely, delete this header. #
-################################################################################
-
-iree_add_all_subdirs()
-
-### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/iree/compiler/Dialect/Shape/IR/BUILD b/iree/compiler/Dialect/Shape/IR/BUILD
deleted file mode 100644
index ad79a52..0000000
--- a/iree/compiler/Dialect/Shape/IR/BUILD
+++ /dev/null
@@ -1,123 +0,0 @@
-# Copyright 2019 The IREE Authors
-#
-# Licensed under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-load("//build_tools/bazel:iree_tablegen_doc.bzl", "iree_tablegen_doc")
-load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library")
-load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob")
-
-package(
- default_visibility = ["//visibility:public"],
- features = ["layering_check"],
- licenses = ["notice"], # Apache 2.0
-)
-
-exports_files(["ShapeBase.td"])
-
-td_library(
- name = "td_files",
- srcs = enforce_glob(
- [
- "ShapeBase.td",
- "ShapeInterfaces.td",
- "ShapeOps.td",
- ],
- include = ["*.td"],
- ),
- deps = [
- "//iree/compiler/Dialect/Util/IR:td_files",
- "@llvm-project//mlir:InferTypeOpInterfaceTdFiles",
- "@llvm-project//mlir:OpBaseTdFiles",
- "@llvm-project//mlir:SideEffectTdFiles",
- "@llvm-project//mlir:StdOpsTdFiles",
- "@llvm-project//mlir:ViewLikeInterfaceTdFiles",
- ],
-)
-
-cc_library(
- name = "IR",
- srcs = [
- "Builders.cpp",
- "Folders.cpp",
- "ShapeDialect.cpp",
- "ShapeInterfaces.cpp.inc",
- "ShapeOps.cpp",
- "ShapeOps.cpp.inc",
- "ShapeTypes.cpp",
- ],
- hdrs = [
- "Builders.h",
- "ShapeDialect.h",
- "ShapeInterfaces.h.inc",
- "ShapeOps.h",
- "ShapeOps.h.inc",
- "ShapeTypes.h",
- ],
- deps = [
- ":ShapeInterfacesGen",
- ":ShapeOpsGen",
- "//iree/compiler/Dialect/Util/IR",
- "//iree/compiler/Utils",
- "@llvm-project//llvm:Support",
- "@llvm-project//mlir:ControlFlowInterfaces",
- "@llvm-project//mlir:IR",
- "@llvm-project//mlir:InferTypeOpInterface",
- "@llvm-project//mlir:MemRefDialect",
- "@llvm-project//mlir:Parser",
- "@llvm-project//mlir:SideEffects",
- "@llvm-project//mlir:StandardOps",
- "@llvm-project//mlir:Support",
- "@llvm-project//mlir:TensorDialect",
- "@llvm-project//mlir:Transforms",
- "@llvm-project//mlir:ViewLikeInterface",
- ],
-)
-
-gentbl_cc_library(
- name = "ShapeInterfacesGen",
- tbl_outs = [
- (
- ["-gen-op-interface-decls"],
- "ShapeInterfaces.h.inc",
- ),
- (
- ["-gen-op-interface-defs"],
- "ShapeInterfaces.cpp.inc",
- ),
- ],
- tblgen = "@llvm-project//mlir:mlir-tblgen",
- td_file = "ShapeInterfaces.td",
- deps = [":td_files"],
-)
-
-gentbl_cc_library(
- name = "ShapeOpsGen",
- tbl_outs = [
- (
- ["-gen-op-decls"],
- "ShapeOps.h.inc",
- ),
- (
- ["-gen-op-defs"],
- "ShapeOps.cpp.inc",
- ),
- ],
- tblgen = "@llvm-project//mlir:mlir-tblgen",
- td_file = "ShapeOps.td",
- deps = [":td_files"],
-)
-
-iree_tablegen_doc(
- name = "ShapeDialectDocGen",
- tbl_outs = [
- (
- ["-gen-dialect-doc"],
- "ShapeDialect.md",
- ),
- ],
- tblgen = "@llvm-project//mlir:mlir-tblgen",
- td_file = "ShapeOps.td",
- deps = [":td_files"],
-)
diff --git a/iree/compiler/Dialect/Shape/IR/Builders.cpp b/iree/compiler/Dialect/Shape/IR/Builders.cpp
deleted file mode 100644
index ca4c6bf..0000000
--- a/iree/compiler/Dialect/Shape/IR/Builders.cpp
+++ /dev/null
@@ -1,138 +0,0 @@
-// Copyright 2020 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#include "iree/compiler/Dialect/Shape/IR/Builders.h"
-
-#include "iree/compiler/Dialect/Shape/IR/ShapeDialect.h"
-#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
-#include "iree/compiler/Dialect/Shape/IR/ShapeTypes.h"
-#include "mlir/Dialect/StandardOps/IR/Ops.h"
-#include "mlir/Dialect/Tensor/IR/Tensor.h"
-#include "mlir/IR/Diagnostics.h"
-
-namespace mlir {
-namespace iree_compiler {
-namespace Shape {
-
-static Value getRankedShapeFromOpResult(Operation *op, Value resultValue,
- OpBuilder &builder) {
- if (!op) return nullptr;
- if (auto carryingOp = dyn_cast<ShapeCarryingInterface>(op)) {
- return carryingOp.buildResultValueRankedShape(resultValue, builder);
- } else {
- return nullptr;
- }
-}
-
-static Value getRankedShapeFromOpOperand(Operation *op, unsigned idx,
- OpBuilder &builder) {
- auto carryingOp = dyn_cast_or_null<ShapeCarryingInterface>(op);
- if (!carryingOp) {
- auto value = op->getOperand(idx);
- auto definingOp = value.getDefiningOp();
- if (!definingOp) return nullptr;
- return getRankedShapeFromOpResult(definingOp, value, builder);
- }
- return carryingOp.buildOperandRankedShape(idx, builder);
-}
-
-static Value findRankedShapeFromUse(Value value, OpBuilder &builder) {
- Value rs = getRankedShapeFromOpResult(value.getDefiningOp(), value, builder);
- if (rs) return rs;
- for (auto &use : value.getUses()) {
- rs = getRankedShapeFromOpOperand(use.getOwner(), use.getOperandNumber(),
- builder);
- if (rs) return rs;
- }
- return nullptr;
-}
-
-Value buildRankedShapeForValue(Location loc, Value shapedValue,
- ValueRange dynamicDims, OpBuilder &builder) {
- auto shapedType = shapedValue.getType().dyn_cast<ShapedType>();
- assert(shapedType && "only valid to call on shaped types");
- return builder.createOrFold<Shape::MakeRankedShapeOp>(
- loc, Shape::RankedShapeType::get(shapedType), dynamicDims);
-}
-
-// Slices out a range of |dynamicDims| corresponding to the value at |index|.
-static ValueRange sliceDynamicDims(unsigned index, ValueRange values,
- ValueRange dynamicDims) {
- auto valueType = values[index].getType().dyn_cast<ShapedType>();
- assert(valueType && "must be a shaped type to get dims");
- unsigned dimsIndex = 0;
- for (unsigned i = 0; i < index; ++i) {
- if (auto shapedType = values[i].getType().dyn_cast<ShapedType>()) {
- dimsIndex += shapedType.getNumDynamicDims();
- }
- }
- return dynamicDims.slice(dimsIndex, valueType.getNumDynamicDims());
-}
-
-Value buildRankedShapeForValueInList(Location loc, unsigned index,
- ValueRange flatValues,
- ValueRange flatDynamicDims,
- OpBuilder &builder) {
- auto dynamicDims = sliceDynamicDims(index, flatValues, flatDynamicDims);
- return buildRankedShapeForValue(loc, flatValues[index], dynamicDims, builder);
-}
-
-SmallVector<Value, 4> buildOrFindDynamicDimsForValue(Location loc, Value value,
- OpBuilder &builder) {
- auto valueSt = value.getType().dyn_cast<ShapedType>();
- if (!valueSt) {
- builder.getContext()->getDiagEngine().emit(loc, DiagnosticSeverity::Error)
- << "cannot construct shape for non shaped value: " << value.getType();
- return {};
- }
-
- // Bail if all dimensions are static.
- if (valueSt.hasStaticShape()) {
- return {};
- }
-
- // TODO(benvanik): delete this entire dialect.
- // This is the first step on the path: we are going to gradually start
- // removing the implementation of the ShapeCarryingInterface on ops and use
- // the new ShapeAwareOpInterface.
- auto dynamicDims = IREE::Util::findDynamicDims(value, builder.getBlock(),
- builder.getInsertionPoint());
- if (dynamicDims.hasValue()) {
- return llvm::to_vector<4>(dynamicDims.getValue());
- }
-
- // Dynamic - walk the uses to find a tie_shape op (either this op or an
- // immediate use).
- SmallVector<Value, 4> result;
- Value rs = findRankedShapeFromUse(value, builder);
- if (rs) {
- auto rsType = rs.getType().dyn_cast<RankedShapeType>();
- if (!rsType) {
- builder.getContext()->getDiagEngine().emit(loc, DiagnosticSeverity::Error)
- << "dynamically shaped value is not ranked (which is not yet "
- << "supported)";
- return {};
- }
- for (unsigned i = 0; i < rsType.getRank(); ++i) {
- if (rsType.isDimDynamic(i)) {
- result.push_back(builder.createOrFold<Shape::RankedDimOp>(loc, rs, i));
- }
- }
- } else {
- // No tie information - insert tensor.dim ops that may later be used and
- // hopefully converted to ranked shape types.
- for (unsigned i = 0; i < valueSt.getRank(); ++i) {
- if (valueSt.isDynamicDim(i)) {
- result.push_back(builder.createOrFold<tensor::DimOp>(loc, value, i));
- }
- }
- }
- return result;
-}
-
-} // namespace Shape
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/iree/compiler/Dialect/Shape/IR/Builders.h b/iree/compiler/Dialect/Shape/IR/Builders.h
deleted file mode 100644
index 1421c5a..0000000
--- a/iree/compiler/Dialect/Shape/IR/Builders.h
+++ /dev/null
@@ -1,42 +0,0 @@
-// Copyright 2020 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#ifndef IREE_COMPILER_DIALECT_SHAPE_IR_BUILDERS_H_
-#define IREE_COMPILER_DIALECT_SHAPE_IR_BUILDERS_H_
-
-#include "iree/compiler/Dialect/Shape/IR/ShapeTypes.h"
-#include "mlir/Dialect/MemRef/IR/MemRef.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/Location.h"
-#include "mlir/IR/Operation.h"
-
-namespace mlir {
-namespace iree_compiler {
-namespace Shape {
-
-// Builds a ranked_shape for the given |shapedValue| with zero or more dynamic
-// dims with the values taken from |dynamicDims|.
-Value buildRankedShapeForValue(Location loc, Value shapedValue,
- ValueRange dynamicDims, OpBuilder &builder);
-
-// As with buildRankedShapeForValue but by selecting out the appropriate dims
-// from a flattened set of values and dynamic dims.
-Value buildRankedShapeForValueInList(Location loc, unsigned index,
- ValueRange flatValues,
- ValueRange flatDynamicDims,
- OpBuilder &builder);
-
-// Returns dimension values for each dynamic dimension of the given |value|.
-// |value| must be a ShapedType and may optionally have a ranked_shape tied.
-// The returned value range will be empty if the shape is fully static.
-SmallVector<Value, 4> buildOrFindDynamicDimsForValue(Location loc, Value value,
- OpBuilder &builder);
-
-} // namespace Shape
-} // namespace iree_compiler
-} // namespace mlir
-
-#endif // IREE_COMPILER_DIALECT_SHAPE_IR_BUILDERS_H_
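For context on the removed builder helpers documented above, here is a minimal usage sketch. It is hypothetical (the wrapper function and its name are not from this patch) and assumes the pre-removal IREE/MLIR tree where these headers still existed.

// Hypothetical caller: materialize a !shapex.ranked_shape for a shaped value
// by first collecting one SSA value per dynamic dimension.
#include "iree/compiler/Dialect/Shape/IR/Builders.h"

using namespace mlir;
using namespace mlir::iree_compiler;

static Value makeRankedShapeFor(Location loc, Value shapedValue,
                                OpBuilder &builder) {
  // Empty when the shape is fully static (or the value is not a ShapedType).
  SmallVector<Value, 4> dims =
      Shape::buildOrFindDynamicDimsForValue(loc, shapedValue, builder);
  // Ties the dynamic dims into a shapex.make_ranked_shape (or a folded form).
  return Shape::buildRankedShapeForValue(loc, shapedValue, dims, builder);
}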
diff --git a/iree/compiler/Dialect/Shape/IR/CMakeLists.txt b/iree/compiler/Dialect/Shape/IR/CMakeLists.txt
deleted file mode 100644
index 7707eb5..0000000
--- a/iree/compiler/Dialect/Shape/IR/CMakeLists.txt
+++ /dev/null
@@ -1,80 +0,0 @@
-################################################################################
-# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from #
-# iree/compiler/Dialect/Shape/IR/BUILD #
-# #
-# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary #
-# CMake-only content. #
-# #
-# To disable autogeneration for this file entirely, delete this header. #
-################################################################################
-
-iree_add_all_subdirs()
-
-iree_cc_library(
- NAME
- IR
- HDRS
- "Builders.h"
- "ShapeDialect.h"
- "ShapeInterfaces.h.inc"
- "ShapeOps.h"
- "ShapeOps.h.inc"
- "ShapeTypes.h"
- SRCS
- "Builders.cpp"
- "Folders.cpp"
- "ShapeDialect.cpp"
- "ShapeInterfaces.cpp.inc"
- "ShapeOps.cpp"
- "ShapeOps.cpp.inc"
- "ShapeTypes.cpp"
- DEPS
- ::ShapeInterfacesGen
- ::ShapeOpsGen
- LLVMSupport
- MLIRControlFlowInterfaces
- MLIRIR
- MLIRInferTypeOpInterface
- MLIRMemRef
- MLIRParser
- MLIRSideEffectInterfaces
- MLIRStandard
- MLIRSupport
- MLIRTensor
- MLIRTransforms
- MLIRViewLikeInterface
- iree::compiler::Dialect::Util::IR
- iree::compiler::Utils
- PUBLIC
-)
-
-iree_tablegen_library(
- NAME
- ShapeInterfacesGen
- TD_FILE
- "ShapeInterfaces.td"
- OUTS
- -gen-op-interface-decls ShapeInterfaces.h.inc
- -gen-op-interface-defs ShapeInterfaces.cpp.inc
-)
-
-iree_tablegen_library(
- NAME
- ShapeOpsGen
- TD_FILE
- "ShapeOps.td"
- OUTS
- -gen-op-decls ShapeOps.h.inc
- -gen-op-defs ShapeOps.cpp.inc
-)
-
-iree_tablegen_doc(
- NAME
- ShapeDialectDocGen
- TD_FILE
- "ShapeOps.td"
- OUTS
- -gen-dialect-doc ShapeDialect.md
-)
-
-### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/iree/compiler/Dialect/Shape/IR/Folders.cpp b/iree/compiler/Dialect/Shape/IR/Folders.cpp
deleted file mode 100644
index 70dddb9..0000000
--- a/iree/compiler/Dialect/Shape/IR/Folders.cpp
+++ /dev/null
@@ -1,164 +0,0 @@
-// Copyright 2020 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
-#include "iree/compiler/Dialect/Shape/IR/ShapeTypes.h"
-#include "iree/compiler/Utils/PatternUtils.h"
-#include "llvm/Support/Debug.h"
-#include "mlir/Dialect/MemRef/IR/MemRef.h"
-#include "mlir/Dialect/StandardOps/IR/Ops.h"
-#include "mlir/Dialect/Tensor/IR/Tensor.h"
-#include "mlir/IR/Matchers.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/Support/LogicalResult.h"
-#include "mlir/Transforms/DialectConversion.h"
-
-namespace mlir {
-namespace iree_compiler {
-namespace Shape {
-
-//===----------------------------------------------------------------------===//
-// Canonicalization
-//===----------------------------------------------------------------------===//
-
-static LogicalResult identityMakeRankedShapePattern(
- MakeRankedShapeOp op, MakeRankedShapeOp::Adaptor operands,
- PatternRewriter &rewriter) {
- if (operands.dynamic_dimensions().empty()) {
- // Do not match static shapes.
- return failure();
- }
-
- // Detects make_ranked_shape ops whose dynamic dimensions are provided by
- // ranked_dim ops that extract dimensions from an identical ranked_shape.
- auto rankedShape = op.getRankedShapeType();
- RankedDimOp commonRankedDimOp;
- unsigned previousProvidingIndex = 0;
- for (auto providingDim : operands.dynamic_dimensions()) {
- auto rankedDimOp =
- llvm::dyn_cast_or_null<RankedDimOp>(providingDim.getDefiningOp());
- if (!rankedDimOp) return failure();
-
- // Shapes must match and refer to a dynamic index.
- unsigned providingIndex = rankedDimOp.getIndex();
- if (rankedDimOp.getRankedShapeType() != rankedShape ||
- !rankedShape.isDimDynamic(providingIndex)) {
- return failure();
- }
-
- if (commonRankedDimOp) {
- // Not first dim: verify same providing shape and indexes into next
- // dynamic dim.
- if (rankedDimOp.shape() != commonRankedDimOp.shape() ||
- providingIndex <= previousProvidingIndex) {
- return failure();
- }
- }
-
- commonRankedDimOp = rankedDimOp;
- previousProvidingIndex = rankedDimOp.getIndex();
- }
-
- // Fall-through: this op produces a shape identical to that of
- // commonRankedDimOp.
- assert(commonRankedDimOp &&
- "dynamic ranked_shape did not find a common provider");
-
- rewriter.replaceOp(op, commonRankedDimOp.shape());
- return success();
-}
-
-// TODO(silvasean): Better handling of "erase unused ops for legality".
-// Currently, we legalize !shapex.ranked_shape into individual SSA values per
-// dimension by iteratively reducing other ops to
-// shapex.ranked_dim/shapex.ranked_dims and shapex.make_ranked_shape, and then
-// applying patterns that resolve those shapex.ranked_dim/shapex.ranked_dims
-// ops to scalar values by looking through the shapex.make_ranked_shape ops.
-// The eventual goal is for the shapex.make_ranked_shape op itself to have no
-// uses, with the main computation flow using the individual SSA values
-// instead. This naturally produces a lot of unused shapex.make_ranked_shape
-// ops which we need to delete for legality reasons. This pattern allows
-// conversions to erase those ops.
-static LogicalResult eraseUnusedMakeRankedShapeOp(
- MakeRankedShapeOp op, MakeRankedShapeOp::Adaptor operands,
- PatternRewriter &rewriter) {
- if (!op.getResult().use_empty())
- return rewriter.notifyMatchFailure(op, "op has uses");
- rewriter.eraseOp(op);
- return success();
-}
-
-static LogicalResult dynamicMakeRankedShapeDimPattern(
- RankedDimOp op, RankedDimOp::Adaptor operands, PatternRewriter &rewriter) {
- // If the immediate predecessor is a MakeRankedShapeOp, then this op can be
- // erased in favor of the corresponding input to that op.
- auto shapeInput = operands.shape();
- auto makeRsOp =
- dyn_cast_or_null<MakeRankedShapeOp>(shapeInput.getDefiningOp());
- if (!makeRsOp) return failure();
-
- RankedShapeType rsType = shapeInput.getType().cast<RankedShapeType>();
- unsigned index = op.getIndex();
- auto allDims = rsType.getAllDims();
- assert(index < allDims.size());
- if (allDims[index] >= 0) {
- // Not dynamic.
- return failure();
- }
-
- // Map the overall index to the dynamic dim index.
- int dynamicDimIndex = 0;
- for (unsigned i = 0; i < index; ++i) {
- if (allDims[i] < 0) dynamicDimIndex++;
- }
-
- assert(dynamicDimIndex < makeRsOp.dynamic_dimensions().size());
- rewriter.replaceOp(op, makeRsOp.dynamic_dimensions()[dynamicDimIndex]);
- return success();
-}
-
-//===----------------------------------------------------------------------===//
-// shapex.make_ranked_shape
-//===----------------------------------------------------------------------===//
-
-void MakeRankedShapeOp::getCanonicalizationPatterns(
- OwningRewritePatternList &patterns, MLIRContext *context) {
- insertGreedyPattern(patterns, context, identityMakeRankedShapePattern);
-}
-
-//===----------------------------------------------------------------------===//
-// shapex.ranked_dim
-//===----------------------------------------------------------------------===//
-
-OpFoldResult RankedDimOp::fold(ArrayRef<Attribute> operand) {
- auto rsType = shape().getType().cast<RankedShapeType>();
- int index = getIndex();
- if (!rsType.isDimDynamic(index)) {
- auto dimSize = rsType.getStaticDim(index);
- return IntegerAttr::get(getType(), dimSize);
- }
- return {};
-}
-
-void RankedDimOp::getCanonicalizationPatterns(
- OwningRewritePatternList &patterns, MLIRContext *context) {
- insertGreedyPattern(patterns, context, dynamicMakeRankedShapeDimPattern);
-}
-
-//===----------------------------------------------------------------------===//
-// Standard folding and canonicalization conversion patterns.
-//===----------------------------------------------------------------------===//
-
-void populateFoldConversionPatterns(MLIRContext *context,
- OwningRewritePatternList &patterns) {
- insertConversionPattern(patterns, context, eraseUnusedMakeRankedShapeOp);
- insertConversionPattern(patterns, context, dynamicMakeRankedShapeDimPattern);
- insertConversionPattern(patterns, context, identityMakeRankedShapePattern);
-}
-
-} // namespace Shape
-} // namespace iree_compiler
-} // namespace mlir
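A hedged sketch of how the conversion-time patterns registered by populateFoldConversionPatterns above were typically driven; the driver function and the legality choices here are assumptions, not code from this patch.

// Hypothetical conversion driver that erases make_ranked_shape producers once
// all ranked_dim consumers have been resolved to scalar SSA values.
#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;
using namespace mlir::iree_compiler;

static LogicalResult foldAwayRankedShapes(Operation *root) {
  MLIRContext *context = root->getContext();
  OwningRewritePatternList patterns(context);
  Shape::populateFoldConversionPatterns(context, patterns);

  ConversionTarget target(*context);
  // Unused shapex.make_ranked_shape ops must be erased for legality.
  target.addIllegalOp<Shape::MakeRankedShapeOp>();
  target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
  return applyPartialConversion(root, target, std::move(patterns));
}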
diff --git a/iree/compiler/Dialect/Shape/IR/ShapeBase.td b/iree/compiler/Dialect/Shape/IR/ShapeBase.td
deleted file mode 100644
index 0455a99..0000000
--- a/iree/compiler/Dialect/Shape/IR/ShapeBase.td
+++ /dev/null
@@ -1,47 +0,0 @@
-// Copyright 2019 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#ifndef IREE_DIALECT_SHAPE_BASE
-#define IREE_DIALECT_SHAPE_BASE
-
-include "mlir/IR/OpBase.td"
-
-//===----------------------------------------------------------------------===//
-// Shape dialect
-//===----------------------------------------------------------------------===//
-
-// TODO(b/143787186): rename when old dialects are removed.
-def Shape_Dialect : Dialect {
- let name = "shapex";
- let cppNamespace = "::mlir::iree_compiler::Shape";
-
- let summary = [{
- A dialect of helper ops for shapifying computations.
- }];
-}
-
-//===----------------------------------------------------------------------===//
-// General types and helpers
-//===----------------------------------------------------------------------===//
-
-def Shape_RankedShape :
- Type<CPred<"$_self.isa<::mlir::iree_compiler::Shape::RankedShapeType>()">,
- "Ranked shape type">;
-
-// TODO(silvasean): Investigate the layering aspects of allowing non-index types
-// here. There seem to be two primary motivators right now, both of which are
-// not obviously ideal long-term:
-//
-// 1. mhlo dialect uses i64 in many places that index should be used.
-// This is understood to be a bug.
-// 2. VMLA gets these values as i32 directly.
-//
-// Both cases can be resolved by inserting casts if we decide to take a firmer
-// stance on only allowing index type. But retaining the flexibility might
-// be a useful feature.
-def Shape_DimType : AnyTypeOf<[Index, AnySignlessInteger]>;
-
-#endif // IREE_DIALECT_SHAPE_BASE
diff --git a/iree/compiler/Dialect/Shape/IR/ShapeDialect.cpp b/iree/compiler/Dialect/Shape/IR/ShapeDialect.cpp
deleted file mode 100644
index 0a0f34d..0000000
--- a/iree/compiler/Dialect/Shape/IR/ShapeDialect.cpp
+++ /dev/null
@@ -1,177 +0,0 @@
-// Copyright 2019 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#include "iree/compiler/Dialect/Shape/IR/ShapeDialect.h"
-
-#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
-#include "iree/compiler/Dialect/Shape/IR/ShapeTypes.h"
-#include "llvm/ADT/STLExtras.h"
-#include "llvm/ADT/SmallVector.h"
-#include "llvm/Support/SourceMgr.h"
-#include "mlir/Dialect/StandardOps/IR/Ops.h"
-#include "mlir/IR/Attributes.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/DialectImplementation.h"
-#include "mlir/IR/OpDefinition.h"
-#include "mlir/IR/OpImplementation.h"
-#include "mlir/Interfaces/FoldInterfaces.h"
-#include "mlir/Parser.h"
-#include "mlir/Transforms/InliningUtils.h"
-
-namespace mlir {
-namespace iree_compiler {
-
-#include "iree/compiler/Dialect/Shape/IR/ShapeInterfaces.cpp.inc"
-
-// Used to control inlining behavior.
-struct ShapeInlinerInterface : public DialectInlinerInterface {
- using DialectInlinerInterface::DialectInlinerInterface;
-
- // Allow all call operations to be inlined.
- bool isLegalToInline(Operation* call, Operation* callable,
- bool wouldBeCloned) const final {
- return true;
- }
- bool isLegalToInline(Operation* op, Region* dest, bool wouldBeCloned,
- BlockAndValueMapping& valueMapping) const final {
- return true;
- }
-};
-
-// Used to control constant folding behavior as a fallback on the dialect when
-// individual op folder does not match.
-struct ShapeConstantFoldInterface : public DialectFoldInterface {
- using DialectFoldInterface::DialectFoldInterface;
-
- LogicalResult fold(Operation* op, ArrayRef<Attribute> operands,
- SmallVectorImpl<OpFoldResult>& results) const final {
- bool foundConstantRankedShape = false;
- for (Value result : op->getResults()) {
- auto rankedShape = result.getType().dyn_cast<Shape::RankedShapeType>();
- if (rankedShape && rankedShape.isFullyStatic()) {
- foundConstantRankedShape = true;
- results.push_back(TypeAttr::get(rankedShape));
- } else {
- results.push_back(nullptr);
- }
- }
- return success(foundConstantRankedShape);
- }
-};
-
-ShapeDialect::ShapeDialect(MLIRContext* context)
- : Dialect(getDialectNamespace(), context, TypeID::get<ShapeDialect>()) {
- registerTypes();
- addInterfaces<ShapeConstantFoldInterface, ShapeInlinerInterface>();
-#define GET_OP_LIST
- addOperations<
-#include "iree/compiler/Dialect/Shape/IR/ShapeOps.cpp.inc"
- >();
-}
-
-Operation* ShapeDialect::materializeConstant(OpBuilder& builder,
- Attribute value, Type type,
- Location loc) {
- if (auto typeAttr = value.dyn_cast<TypeAttr>()) {
- auto rankedShape = typeAttr.getValue().cast<Shape::RankedShapeType>();
- return builder.create<Shape::ConstRankedShapeOp>(loc, rankedShape);
- }
- if (arith::ConstantOp::isBuildableWith(value, type))
- return builder.create<arith::ConstantOp>(loc, type, value);
- if (mlir::ConstantOp::isBuildableWith(value, type))
- return builder.create<mlir::ConstantOp>(loc, type, value);
- return nullptr;
-}
-
-//===----------------------------------------------------------------------===//
-// Type parsing and printing
-//===----------------------------------------------------------------------===//
-
-static Type parseRankedShape(DialectAsmParser& parser) {
- llvm::SmallVector<int64_t, 7> dims;
- Type dimType;
- // Parse: ranked_shape<[
- if (failed(parser.parseLess()) || failed(parser.parseLSquare()))
- return nullptr;
-
- // Parse list of comma-separated dims, where each dim is an integer >= 0
- // or ?.
- for (bool first = true;; first = false) {
- if (!first) {
- if (failed(parser.parseOptionalComma())) break;
- }
-
- int64_t dim;
- OptionalParseResult optionalInteger = parser.parseOptionalInteger(dim);
- if (optionalInteger.hasValue()) {
- if (dim < 0) {
- parser.emitError(parser.getNameLoc(), "expected dim >= 0 or '?'");
- return nullptr;
- }
- } else if (succeeded(parser.parseOptionalQuestion())) {
- dim = -1;
- } else if (first) {
- // It is fine to not have a first dim.
- break;
- } else {
- parser.emitError(parser.getNameLoc(), "expected shape dim");
- return nullptr;
- }
- dims.push_back(dim);
- }
- if (failed(parser.parseRSquare())) return nullptr;
-
- // Parse optional: , type
- if (succeeded(parser.parseOptionalComma())) {
- if (failed(parser.parseType(dimType))) {
- return nullptr;
- }
- } else {
- dimType = parser.getBuilder().getIndexType();
- }
- if (failed(parser.parseGreater())) {
- parser.emitError(parser.getNameLoc(), "expected terminating '>'");
- return nullptr;
- }
-
- return Shape::RankedShapeType::getChecked(
- dims, parser.getEncodedSourceLoc(parser.getNameLoc()));
-}
-
-static void printRankedShape(Shape::RankedShapeType type,
- DialectAsmPrinter& printer) {
- auto dims = type.getAllDims();
- printer << "ranked_shape<[";
- interleave(
- dims, printer,
- [&](int64_t dim) {
- if (dim < 0)
- printer << "?";
- else
- printer << dim;
- },
- ",");
- printer << "]>";
-}
-
-Type ShapeDialect::parseType(DialectAsmParser& parser) const {
- Location loc = parser.getEncodedSourceLoc(parser.getNameLoc());
- llvm::StringRef spec = parser.getFullSymbolSpec();
- if (succeeded(parser.parseOptionalKeyword("ranked_shape"))) {
- return parseRankedShape(parser);
- }
- emitError(loc, "unknown Shape type: ") << spec;
- return Type();
-}
-
-void ShapeDialect::printType(Type type, DialectAsmPrinter& os) const {
- if (auto rankedShapeTy = type.dyn_cast<Shape::RankedShapeType>())
- return printRankedShape(type.cast<Shape::RankedShapeType>(), os);
- llvm_unreachable("unhandled Shape type");
-}
-
-} // namespace iree_compiler
-} // namespace mlir
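A small sketch, for orientation only, of loading the dialect removed above; the helper function is hypothetical, but loadDialect is the standard MLIR entry point and it is what enabled the custom type parser and materializeConstant hook defined in this file.

// Hypothetical setup: make !shapex.ranked_shape parseable and allow
// shapex.const_ranked_shape to be materialized during folding.
#include "iree/compiler/Dialect/Shape/IR/ShapeDialect.h"
#include "mlir/IR/MLIRContext.h"

static void loadShapexDialect(mlir::MLIRContext &context) {
  context.loadDialect<mlir::iree_compiler::ShapeDialect>();
}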
diff --git a/iree/compiler/Dialect/Shape/IR/ShapeDialect.h b/iree/compiler/Dialect/Shape/IR/ShapeDialect.h
deleted file mode 100644
index 1c96300..0000000
--- a/iree/compiler/Dialect/Shape/IR/ShapeDialect.h
+++ /dev/null
@@ -1,38 +0,0 @@
-// Copyright 2019 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#ifndef IREE_COMPILER_DIALECT_SHAPE_IR_IREEDIALECT_H_
-#define IREE_COMPILER_DIALECT_SHAPE_IR_IREEDIALECT_H_
-
-#include "mlir/IR/Dialect.h"
-#include "mlir/IR/OpDefinition.h"
-
-namespace mlir {
-namespace iree_compiler {
-
-#include "iree/compiler/Dialect/Shape/IR/ShapeInterfaces.h.inc"
-
-class ShapeDialect : public Dialect {
- public:
- explicit ShapeDialect(MLIRContext* context);
- // TODO(b/143787186): rename to iree.
- static StringRef getDialectNamespace() { return "shapex"; }
-
- Type parseType(DialectAsmParser& parser) const override;
- void printType(Type type, DialectAsmPrinter& os) const override;
-
- Operation* materializeConstant(OpBuilder& builder, Attribute value, Type type,
- Location loc) override;
-
- private:
- /// Register the types of this dialect.
- void registerTypes();
-};
-
-} // namespace iree_compiler
-} // namespace mlir
-
-#endif // IREE_COMPILER_DIALECT_SHAPE_IR_IREEDIALECT_H_
diff --git a/iree/compiler/Dialect/Shape/IR/ShapeInterfaces.td b/iree/compiler/Dialect/Shape/IR/ShapeInterfaces.td
deleted file mode 100644
index af2971d..0000000
--- a/iree/compiler/Dialect/Shape/IR/ShapeInterfaces.td
+++ /dev/null
@@ -1,50 +0,0 @@
-// Copyright 2021 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#ifndef IREE_DIALECT_SHAPE_INTERFACES
-#define IREE_DIALECT_SHAPE_INTERFACES
-
-include "mlir/IR/OpBase.td"
-
-//===----------------------------------------------------------------------===//
-// Op interfaces
-//===----------------------------------------------------------------------===//
-
-def Shape_ShapeCarryingOpInterface : OpInterface<"ShapeCarryingInterface"> {
- let description = [{
- Interface for ops that interact with dynamically shaped inputs and outputs.
- Such ops are able to materialize RankedShapes on demand for any operand or
- result that derives from ShapedType.
- }];
-
- let methods = [
- StaticInterfaceMethod<
- /*desc=*/[{Returns a RankedShape for the given shaped result value.}],
- /*retTy=*/"Value",
- /*methodName=*/"buildResultValueRankedShape",
- /*args=*/(ins "Value":$result, "OpBuilder &":$builder),
- /*methodBody=*/[{
- auto carryingOp = dyn_cast<ShapeCarryingInterface>(result.getDefiningOp());
- return carryingOp.buildResultRankedShape(
- result.cast<mlir::OpResult>().getResultNumber(), builder);
- }]
- >,
- InterfaceMethod<
- /*desc=*/[{Returns a RankedShape for the given shaped operand index.}],
- /*retTy=*/"Value",
- /*methodName=*/"buildOperandRankedShape",
- /*args=*/(ins "unsigned":$idx, "OpBuilder &":$builder)
- >,
- InterfaceMethod<
- /*desc=*/[{Returns a RankedShape for the given shaped result index.}],
- /*retTy=*/"Value",
- /*methodName=*/"buildResultRankedShape",
- /*args=*/(ins "unsigned":$idx, "OpBuilder &":$builder)
- >,
- ];
-}
-
-#endif // IREE_DIALECT_SHAPE_INTERFACES
diff --git a/iree/compiler/Dialect/Shape/IR/ShapeOps.cpp b/iree/compiler/Dialect/Shape/IR/ShapeOps.cpp
deleted file mode 100644
index b058e49..0000000
--- a/iree/compiler/Dialect/Shape/IR/ShapeOps.cpp
+++ /dev/null
@@ -1,137 +0,0 @@
-// Copyright 2019 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
-
-#include "iree/compiler/Dialect/Shape/IR/ShapeTypes.h"
-#include "llvm/ADT/STLExtras.h"
-#include "llvm/ADT/SmallString.h"
-#include "llvm/ADT/SmallVector.h"
-#include "llvm/Support/Casting.h"
-#include "llvm/Support/SMLoc.h"
-#include "mlir/Dialect/StandardOps/IR/Ops.h"
-#include "mlir/IR/Attributes.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/Diagnostics.h"
-#include "mlir/IR/Matchers.h"
-#include "mlir/IR/OpImplementation.h"
-#include "mlir/IR/OperationSupport.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/IR/TypeUtilities.h"
-#include "mlir/IR/Value.h"
-#include "mlir/Support/LLVM.h"
-#include "mlir/Support/LogicalResult.h"
-
-namespace mlir {
-namespace iree_compiler {
-namespace Shape {
-
-//===----------------------------------------------------------------------===//
-// shapex.const_ranked_shape
-//===----------------------------------------------------------------------===//
-
-void ConstRankedShapeOp::build(OpBuilder &builder, OperationState &result,
- Type type) {
- assert(type.cast<RankedShapeType>().isFullyStatic());
- result.types.push_back(type);
-}
-
-static LogicalResult verifyConstRankedShapeOp(ConstRankedShapeOp op) {
- auto rsType = op.result().getType().dyn_cast<RankedShapeType>();
- if (!rsType || !rsType.isFullyStatic()) {
- return op.emitOpError("must be a fully static ranked_shape");
- }
- return success();
-}
-
-void ConstRankedShapeOp::getAsmResultNames(
- function_ref<void(Value, StringRef)> setNameFn) {
- auto rankedShape = result().getType().cast<RankedShapeType>();
- SmallString<32> buffer;
- llvm::raw_svector_ostream os(buffer);
- os << "rs";
- interleave(
- rankedShape.getAllDims(), os, [&](int64_t dim) { os << dim; }, "_");
- setNameFn(getResult(), os.str());
-}
-
-//===----------------------------------------------------------------------===//
-// shapex.make_ranked_shape
-//===----------------------------------------------------------------------===//
-
-static LogicalResult verifyMakeRankedShapeOp(MakeRankedShapeOp op) {
- if (op.getRankedShapeType().getNumDynamicDims() != op.getNumOperands()) {
- return op.emitError()
- << "number of dynamic dims doesn't match number of operands";
- }
- return success();
-}
-
-//===----------------------------------------------------------------------===//
-// shapex.ranked_dim
-//===----------------------------------------------------------------------===//
-
-void RankedDimOp::build(OpBuilder &builder, OperationState &result,
- Type dimType, Value shape, int index) {
- result.addOperands(shape);
- result.addAttribute("index",
- builder.getIntegerAttr(builder.getIndexType(), index));
- result.addTypes(dimType);
-}
-
-void RankedDimOp::build(OpBuilder &builder, OperationState &result, Value shape,
- int index) {
- RankedDimOp::build(builder, result, builder.getIndexType(), shape, index);
-}
-
-ParseResult parseRankedDimOp(OpAsmParser &parser, OperationState &state) {
- OpAsmParser::OperandType operand;
- Type operandType;
- IntegerAttr indexAttr;
- Type indexType = parser.getBuilder().getIndexType();
- SmallVector<Type, 1> resultTypes;
- if (parser.parseOperand(operand) || parser.parseLSquare() ||
- parser.parseAttribute(indexAttr, indexType, "index", state.attributes) ||
- parser.parseRSquare() || parser.parseColonType(operandType) ||
- parser.parseArrowTypeList(resultTypes) || resultTypes.empty() ||
- parser.resolveOperand(operand, operandType, state.operands)) {
- return failure();
- }
-
- auto rsType = operandType.dyn_cast<RankedShapeType>();
- if (!rsType) {
- return parser.emitError(parser.getNameLoc());
- }
- state.types.push_back(resultTypes[0]);
- return success();
-}
-
-static void printRankedDimOp(OpAsmPrinter &p, RankedDimOp op) {
- p << " ";
- p.printOperand(op.shape());
- p << "[" << op.getIndex() << "]";
- p << " : ";
- p.printType(op.shape().getType());
- p << " -> ";
- p.printType(op.getType());
-}
-
-static LogicalResult verifyRankedDimOp(RankedDimOp op) {
- auto rsType = op.shape().getType().dyn_cast<RankedShapeType>();
- auto index = static_cast<int64_t>(op.getIndex());
- if (index < 0 || index >= rsType.getRank()) {
- return op.emitOpError() << "index out of bounds of shape";
- }
- return success();
-}
-
-} // namespace Shape
-} // namespace iree_compiler
-} // namespace mlir
-
-#define GET_OP_CLASSES
-#include "iree/compiler/Dialect/Shape/IR/ShapeOps.cpp.inc"
diff --git a/iree/compiler/Dialect/Shape/IR/ShapeOps.h b/iree/compiler/Dialect/Shape/IR/ShapeOps.h
deleted file mode 100644
index 47acfcd..0000000
--- a/iree/compiler/Dialect/Shape/IR/ShapeOps.h
+++ /dev/null
@@ -1,38 +0,0 @@
-// Copyright 2019 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#ifndef IREE_COMPILER_DIALECT_SHAPE_IR_SHAPEOPS_H_
-#define IREE_COMPILER_DIALECT_SHAPE_IR_SHAPEOPS_H_
-
-#include "iree/compiler/Dialect/Shape/IR/ShapeDialect.h"
-#include "iree/compiler/Dialect/Shape/IR/ShapeTypes.h"
-#include "mlir/IR/Attributes.h"
-#include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/Dialect.h"
-#include "mlir/IR/OpDefinition.h"
-#include "mlir/IR/OpImplementation.h"
-#include "mlir/Interfaces/InferTypeOpInterface.h"
-#include "mlir/Interfaces/SideEffectInterfaces.h"
-#include "mlir/Interfaces/ViewLikeInterface.h"
-
-namespace mlir {
-namespace iree_compiler {
-namespace Shape {
-
-// Populates conversion patterns that perform folding and canonicalization of
-// shape ops. These patterns are intended to be used with the dialect conversion
-// framework.
-void populateFoldConversionPatterns(MLIRContext *context,
- OwningRewritePatternList &patterns);
-
-} // namespace Shape
-} // namespace iree_compiler
-} // namespace mlir
-
-#define GET_OP_CLASSES
-#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h.inc"
-
-#endif // IREE_COMPILER_DIALECT_SHAPE_IR_SHAPEOPS_H_
diff --git a/iree/compiler/Dialect/Shape/IR/ShapeOps.td b/iree/compiler/Dialect/Shape/IR/ShapeOps.td
deleted file mode 100644
index 73d0432..0000000
--- a/iree/compiler/Dialect/Shape/IR/ShapeOps.td
+++ /dev/null
@@ -1,119 +0,0 @@
-// Copyright 2019 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#ifndef IREE_DIALECT_SHAPE_OPS
-#define IREE_DIALECT_SHAPE_OPS
-
-include "iree/compiler/Dialect/Shape/IR/ShapeBase.td"
-include "iree/compiler/Dialect/Shape/IR/ShapeInterfaces.td"
-include "mlir/Interfaces/InferTypeOpInterface.td"
-include "mlir/Interfaces/SideEffectInterfaces.td"
-include "mlir/Interfaces/ViewLikeInterface.td"
-include "mlir/IR/OpAsmInterface.td"
-
-//===----------------------------------------------------------------------===//
-// Op types
-//===----------------------------------------------------------------------===//
-
-class Shape_Op<string mnemonic, list<OpTrait> traits = []> :
- Op<Shape_Dialect, mnemonic, traits> {
- let parser = [{ return parse$cppClass(parser, result); }];
- let printer = [{ print$cppClass(p, *this); }];
-}
-
-class Shape_PureOp<string mnemonic, list<OpTrait> traits = []> :
- Shape_Op<mnemonic, !listconcat(traits, [NoSideEffect])>;
-
-//===----------------------------------------------------------------------===//
-// RankedShapeType manipulation
-//===----------------------------------------------------------------------===//
-
-def Shape_ConstRankedShapeOp : Shape_PureOp<"const_ranked_shape",
- [ConstantLike, DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>]> {
- let summary = "A constant ranked_shape.";
- let description = [{
- Holds a RankedShape value. Note that it is only legal to store a constant
- RankedShape that is fully static, as anything more specific should be
- in the type, not have dims represented as const SSA values.
-
- Usage:
- %0 = shapex.const_ranked_shape : !shapex.ranked_shape<[1,2]>
- }];
-
- let arguments = (ins);
- let results = (outs Shape_RankedShape:$result);
-
- let assemblyFormat = "attr-dict `:` type($result)";
-
- let skipDefaultBuilders = 1;
- let builders = [
- OpBuilder<(ins "Type":$type)>,
- ];
- let verifier = [{ return verify$cppClass(*this); }];
-}
-
-def Shape_MakeRankedShapeOp : Shape_PureOp<"make_ranked_shape"> {
- let summary = "Makes a ranked_shape from individual dims.";
- let description = [{
- Given a list of SSA values holding compatible dims, makes a corresponding
- ranked_shape.
-
- Usage:
- %0 = shapex.make_ranked_shape %dim0, %dim1 : (i32, i32) ->
- !shapex.ranked_shape<[?,?,128]>
-
- Note that the type of the dims is implied by the dim type of the result.
- }];
-
- let arguments = (ins Variadic<Shape_DimType>:$dynamic_dimensions);
- let results = (outs Shape_RankedShape:$shape);
-
- let assemblyFormat = "$dynamic_dimensions `:` functional-type($dynamic_dimensions, $shape) attr-dict";
-
- let extraClassDeclaration = [{
- RankedShapeType getRankedShapeType() {
- return shape().getType().cast<RankedShapeType>();
- }
- }];
- let verifier = [{ return verify$cppClass(*this); }];
- let hasCanonicalizer = 1;
-}
-
-def Shape_RankedDimOp : Shape_PureOp<"ranked_dim"> {
- let summary = "Gets a dimension value from a ranked_shape.";
- let description = [{
- Static dimensions will fold to constants.
-
- Usage:
- %0 = shapex.const_ranked_shape : !shapex.ranked_shape<[1,2]>
- %1 = shapex.ranked_dim %0[0] : !shapex.ranked_shape<[1,2]> -> i32
- }];
-
- let arguments = (ins Shape_RankedShape:$shape,
- APIntAttr:$index);
- let results = (outs Shape_DimType:$result);
- let verifier = [{ return verify$cppClass(*this); }];
-
- let builders = [
- OpBuilder<(ins "Type":$dimType, "Value":$shape, "int":$index)>,
- // dimType is defaulted to IndexType.
- OpBuilder<(ins "Value":$shape, "int":$index)>,
- ];
-
- let extraClassDeclaration = [{
- RankedShapeType getRankedShapeType() {
- return shape().getType().cast<RankedShapeType>();
- }
- unsigned getIndex() {
- auto index = getOperation()->getAttrOfType<IntegerAttr>("index");
- return index.getValue().getZExtValue();
- }
- }];
- let hasFolder = 1;
- let hasCanonicalizer = 1;
-}
-
-#endif // IREE_DIALECT_SHAPE_OPS
diff --git a/iree/compiler/Dialect/Shape/IR/ShapeTypes.cpp b/iree/compiler/Dialect/Shape/IR/ShapeTypes.cpp
deleted file mode 100644
index d1b784c..0000000
--- a/iree/compiler/Dialect/Shape/IR/ShapeTypes.cpp
+++ /dev/null
@@ -1,128 +0,0 @@
-// Copyright 2020 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#include "iree/compiler/Dialect/Shape/IR/ShapeTypes.h"
-
-#include "iree/compiler/Dialect/Shape/IR/ShapeDialect.h"
-#include "llvm/ADT/Twine.h"
-#include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/Diagnostics.h"
-#include "mlir/IR/MLIRContext.h"
-#include "mlir/IR/TypeSupport.h"
-
-namespace mlir {
-namespace iree_compiler {
-namespace Shape {
-namespace detail {
-
-struct RankedShapeTypeStorage : public TypeStorage {
- struct KeyTy {
- KeyTy(ArrayRef<int64_t> dims) : dims(dims) {}
- bool operator==(const KeyTy &other) const {
- return dims.equals(other.dims);
- }
- unsigned getHashValue() const {
- return llvm::hash_combine_range(dims.begin(), dims.end());
- }
- ArrayRef<int64_t> dims;
- };
-
- RankedShapeTypeStorage(const KeyTy &key) : key(key) {}
- static RankedShapeTypeStorage *construct(TypeStorageAllocator &allocator,
- KeyTy key) {
- key.dims = allocator.copyInto(key.dims);
- return new (allocator.allocate<RankedShapeTypeStorage>())
- RankedShapeTypeStorage(key);
- }
-
- bool operator==(const KeyTy &otherKey) const { return key == otherKey; }
- static unsigned hashKey(const KeyTy &key) { return key.getHashValue(); }
-
- KeyTy key;
-};
-
-} // namespace detail
-} // namespace Shape
-} // namespace iree_compiler
-} // namespace mlir
-
-using namespace mlir;
-using namespace mlir::iree_compiler::Shape;
-
-//===----------------------------------------------------------------------===//
-// RankedShapeType
-//===----------------------------------------------------------------------===//
-
-RankedShapeType RankedShapeType::get(ArrayRef<int64_t> dims,
- MLIRContext *context) {
- return Base::get(context, dims);
-}
-
-RankedShapeType RankedShapeType::getChecked(ArrayRef<int64_t> dims,
- Location loc) {
- return Base::getChecked(loc, loc.getContext(), dims);
-}
-
-RankedShapeType RankedShapeType::getChecked(
- function_ref<InFlightDiagnostic()> emitError, MLIRContext *context,
- ArrayRef<int64_t> dims) {
- return Base::getChecked(emitError, context, dims);
-}
-
-RankedShapeType RankedShapeType::get(ShapedType shapedType) {
- return Base::get(shapedType.getContext(), shapedType.getShape());
-}
-
-LogicalResult RankedShapeType::verify(
- function_ref<InFlightDiagnostic()> emitError, ArrayRef<int64_t> dims) {
- for (auto dim : dims) {
- if (dim < 0 && dim != -1) {
- return emitError() << "dims must be -1 for dynamic";
- }
- }
- return success();
-}
-
-int64_t RankedShapeType::getRank() const { return getImpl()->key.dims.size(); }
-
-bool RankedShapeType::isFullyStatic() const {
- for (auto dim : getImpl()->key.dims) {
- if (dim < 0) return false;
- }
- return true;
-}
-
-ArrayRef<int64_t> RankedShapeType::getAllDims() const {
- return getImpl()->key.dims;
-}
-
-unsigned RankedShapeType::getNumDynamicDims() const {
- auto allDims = getAllDims();
- return std::count_if(allDims.begin(), allDims.end(),
- [](int64_t dim) { return dim < 0; });
-}
-
-bool RankedShapeType::isDimDynamic(int allDimsIndex) const {
- assert(allDimsIndex >= 0 && allDimsIndex < getImpl()->key.dims.size());
- return getImpl()->key.dims[allDimsIndex] < 0;
-}
-
-int64_t RankedShapeType::getStaticDim(int allDimsIndex) const {
- assert(allDimsIndex >= 0 && allDimsIndex < getRank());
- auto dim = getAllDims()[allDimsIndex];
- assert(dim >= 0 && "getStaticDim() called on dynamic dimension");
- return dim;
-}
-
-//===----------------------------------------------------------------------===//
-// ShapeDialect
-//===----------------------------------------------------------------------===//
-
-namespace mlir {
-namespace iree_compiler {
-void ShapeDialect::registerTypes() { addTypes<Shape::RankedShapeType>(); }
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/iree/compiler/Dialect/Shape/IR/ShapeTypes.h b/iree/compiler/Dialect/Shape/IR/ShapeTypes.h
deleted file mode 100644
index 3a7ca07..0000000
--- a/iree/compiler/Dialect/Shape/IR/ShapeTypes.h
+++ /dev/null
@@ -1,75 +0,0 @@
-// Copyright 2019 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#ifndef IREE_COMPILER_DIALECT_SHAPE_IR_IREETYPES_H_
-#define IREE_COMPILER_DIALECT_SHAPE_IR_IREETYPES_H_
-
-#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
-#include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/Diagnostics.h"
-#include "mlir/IR/Location.h"
-#include "mlir/IR/TypeSupport.h"
-#include "mlir/IR/Types.h"
-
-namespace mlir {
-namespace iree_compiler {
-namespace Shape {
-
-namespace detail {
-struct RankedShapeTypeStorage;
-} // namespace detail
-
-// A shape with a fixed rank and a mixture of static and dynamic dimensions
-// which can express partially shaped values in the tensor domain and be
-// easily lowered to the memref domain (only retaining the dynamic dims upon
-// conversion).
-class RankedShapeType : public Type::TypeBase<RankedShapeType, Type,
- detail::RankedShapeTypeStorage> {
- public:
- using Base::Base;
-
- // Gets an instance of a RankedShapeType given an array of dimensions.
- // Any dynamic dim should be -1.
- static RankedShapeType get(ArrayRef<int64_t> dims, MLIRContext *context);
- static RankedShapeType getChecked(ArrayRef<int64_t> dims, Location loc);
- static RankedShapeType getChecked(
- function_ref<InFlightDiagnostic()> emitError, MLIRContext *context,
- ArrayRef<int64_t> dims);
-
- // Derives a RankedShapeType from a ShapedType.
- static RankedShapeType get(ShapedType shapedType);
-
- // Verifies construction invariants and issues errors/warnings.
- static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
- ArrayRef<int64_t> dims);
-
- // Gets the rank (counting all dims, static and dynamic).
- int64_t getRank() const;
-
- // Whether the shape is fully static.
- bool isFullyStatic() const;
-
- // Gets all dims of this shape, where dynamic dims are represented by -1.
- // The size of the dims vector will be the same as reported by getRank().
- ArrayRef<int64_t> getAllDims() const;
-
- // Gets the number of dynamic dims.
- unsigned getNumDynamicDims() const;
-
- // Returns whether the indexed dimension is dynamic.
- bool isDimDynamic(int allDimsIndex) const;
-
- // Returns the static dimension at the overall shape index.
- // It is an error to request a static index for which isDimDynamic() is
- // true.
- int64_t getStaticDim(int allDimsIndex) const;
-};
-
-} // namespace Shape
-} // namespace iree_compiler
-} // namespace mlir
-
-#endif // IREE_COMPILER_DIALECT_SHAPE_IR_IREETYPES_H_
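To make the commented RankedShapeType API above concrete, a hypothetical query sketch follows; the expected values in the asserts are derived from the deleted implementation in ShapeTypes.cpp, and the function itself is illustrative only.

#include <cassert>

#include "iree/compiler/Dialect/Shape/IR/ShapeTypes.h"

using namespace mlir;
using namespace mlir::iree_compiler::Shape;

static void inspectRankedShape(MLIRContext *context) {
  // Dynamic dims are encoded as -1, mirroring the ShapedType convention.
  RankedShapeType rs = RankedShapeType::get({-1, 8, -1, 16}, context);
  assert(rs.getRank() == 4);
  assert(rs.getNumDynamicDims() == 2);
  assert(rs.isDimDynamic(0) && !rs.isDimDynamic(1));
  assert(rs.getStaticDim(3) == 16);
  (void)rs;  // Avoid an unused-variable warning in NDEBUG builds.
}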
diff --git a/iree/compiler/Dialect/Shape/IR/test/BUILD b/iree/compiler/Dialect/Shape/IR/test/BUILD
deleted file mode 100644
index a37206f..0000000
--- a/iree/compiler/Dialect/Shape/IR/test/BUILD
+++ /dev/null
@@ -1,31 +0,0 @@
-# Copyright 2019 The IREE Authors
-#
-# Licensed under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-load("//iree:lit_test.bzl", "iree_lit_test_suite")
-load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob")
-
-package(
- default_visibility = ["//visibility:public"],
- features = ["layering_check"],
- licenses = ["notice"], # Apache 2.0
-)
-
-iree_lit_test_suite(
- name = "lit",
- srcs = enforce_glob(
- [
- "canonicalize.mlir",
- "op_verification.mlir",
- "parse_print.mlir",
- "ranked_shape_type.mlir",
- ],
- include = ["*.mlir"],
- ),
- data = [
- "//iree/tools:IreeFileCheck",
- "//iree/tools:iree-opt",
- ],
-)
diff --git a/iree/compiler/Dialect/Shape/IR/test/CMakeLists.txt b/iree/compiler/Dialect/Shape/IR/test/CMakeLists.txt
deleted file mode 100644
index fa8db2b..0000000
--- a/iree/compiler/Dialect/Shape/IR/test/CMakeLists.txt
+++ /dev/null
@@ -1,26 +0,0 @@
-################################################################################
-# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from #
-# iree/compiler/Dialect/Shape/IR/test/BUILD #
-# #
-# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary #
-# CMake-only content. #
-# #
-# To disable autogeneration for this file entirely, delete this header. #
-################################################################################
-
-iree_add_all_subdirs()
-
-iree_lit_test_suite(
- NAME
- lit
- SRCS
- "canonicalize.mlir"
- "op_verification.mlir"
- "parse_print.mlir"
- "ranked_shape_type.mlir"
- DATA
- iree::tools::IreeFileCheck
- iree::tools::iree-opt
-)
-
-### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/iree/compiler/Dialect/Shape/IR/test/canonicalize.mlir b/iree/compiler/Dialect/Shape/IR/test/canonicalize.mlir
deleted file mode 100644
index af00bb8..0000000
--- a/iree/compiler/Dialect/Shape/IR/test/canonicalize.mlir
+++ /dev/null
@@ -1,94 +0,0 @@
-// RUN: iree-opt -split-input-file -verify-diagnostics -canonicalize %s | IreeFileCheck %s
-
-// -----
-// CHECK-LABEL: @foldStaticRankedDim
-// CHECK-SAME: %[[SHAPE:[^:[:space:]]+]]: !shapex.ranked_shape<[1,?,2,?]>
-func @foldStaticRankedDim(%arg0: !shapex.ranked_shape<[1,?,2,?]>) -> (i32, i32) {
- // CHECK-DAG: %[[D2:.+]] = arith.constant 2 : i32
- %0 = shapex.ranked_dim %arg0[2] : !shapex.ranked_shape<[1,?,2,?]> -> i32
- // CHECK-DAG: %[[D1:.+]] = shapex.ranked_dim %[[SHAPE]][1]
- %1 = shapex.ranked_dim %arg0[1] : !shapex.ranked_shape<[1,?,2,?]> -> i32
- // CHECK: return %[[D2]], %[[D1]]
- return %0, %1 : i32, i32
-}
-
-// -----
-// CHECK-LABEL: @dynamicMakeRankedShapeDim
-// CHECK-SAME: %[[DD0:[^:[:space:]]+]]: index
-// CHECK-SAME: %[[DD1:[^:[:space:]]+]]: index
-func @dynamicMakeRankedShapeDim(%arg0: index, %arg1 : index) -> (index, index, index, index) {
- // CHECK-NOT: make_ranked_shape
- // CHECK-NOT: ranked_dim
- %rs = shapex.make_ranked_shape %arg0, %arg1 : (index, index) -> !shapex.ranked_shape<[?,8,?,16]>
- %d0 = shapex.ranked_dim %rs[0] : !shapex.ranked_shape<[?,8,?,16]> -> index
- %d1 = shapex.ranked_dim %rs[1] : !shapex.ranked_shape<[?,8,?,16]> -> index
- %d2 = shapex.ranked_dim %rs[2] : !shapex.ranked_shape<[?,8,?,16]> -> index
- %d3 = shapex.ranked_dim %rs[3] : !shapex.ranked_shape<[?,8,?,16]> -> index
- // CHECK-DAG: %[[C8:.+]] = arith.constant 8 : index
- // CHECK-DAG: %[[C16:.+]] = arith.constant 16 : index
- // CHECK-DAG: return %[[DD0]], %[[C8]], %[[DD1]], %[[C16]]
- return %d0, %d1, %d2, %d3 : index, index, index, index
-}
-
-//===----------------------------------------------------------------------===//
-// IdentityMakeRankedShapePattern tests
-//===----------------------------------------------------------------------===//
-
-// -----
-// CHECK-LABEL: @identityMakeRankedShape_match_1dim
-// CHECK-SAME: %[[ARGRS:[^:[:space:]]+]]: !shapex.ranked_shape
-func @identityMakeRankedShape_match_1dim(%arg0 : !shapex.ranked_shape<[?,16]>) -> !shapex.ranked_shape<[?,16]> {
- // CHECK-NOT: shapex.make_ranked_shape
- %0 = shapex.ranked_dim %arg0[0] : !shapex.ranked_shape<[?,16]> -> index
- %1 = shapex.make_ranked_shape %0 : (index) -> !shapex.ranked_shape<[?,16]>
- // CHECK: return %[[ARGRS]]
- return %1 : !shapex.ranked_shape<[?,16]>
-}
-
-// -----
-// CHECK-LABEL: @identityMakeRankedShape_match_2dim
-// CHECK-SAME: %[[ARGRS:[^:[:space:]]+]]: !shapex.ranked_shape
-func @identityMakeRankedShape_match_2dim(%arg0 : !shapex.ranked_shape<[?,16,?]>) -> !shapex.ranked_shape<[?,16,?]> {
- // CHECK-NOT: shapex.make_ranked_shape
- %0 = shapex.ranked_dim %arg0[0] : !shapex.ranked_shape<[?,16,?]> -> index
- %1 = shapex.ranked_dim %arg0[2] : !shapex.ranked_shape<[?,16,?]> -> index
- %2 = shapex.make_ranked_shape %0, %1 : (index, index) -> !shapex.ranked_shape<[?,16,?]>
- // CHECK: return %[[ARGRS]]
- return %2 : !shapex.ranked_shape<[?,16,?]>
-}
-
-// -----
-// CHECK-LABEL: @identityMakeRankedShape_nomatch_swap_dims
-// CHECK-SAME: %[[ARGRS:[^:[:space:]]+]]: !shapex.ranked_shape
-func @identityMakeRankedShape_nomatch_swap_dims(%arg0 : !shapex.ranked_shape<[?,16,?]>) -> !shapex.ranked_shape<[?,16,?]> {
- %0 = shapex.ranked_dim %arg0[2] : !shapex.ranked_shape<[?,16,?]> -> index
- %1 = shapex.ranked_dim %arg0[0] : !shapex.ranked_shape<[?,16,?]> -> index
- %2 = shapex.make_ranked_shape %0, %1 : (index, index) -> !shapex.ranked_shape<[?,16,?]>
- // CHECK: %[[RS:.+]] = shapex.make_ranked_shape
- // CHECK: return %[[RS]]
- return %2 : !shapex.ranked_shape<[?,16,?]>
-}
-
-// -----
-// CHECK-LABEL: @identityMakeRankedShape_nomatch_static_dim
-// CHECK-SAME: %[[ARGRS:[^:[:space:]]+]]: !shapex.ranked_shape
-func @identityMakeRankedShape_nomatch_static_dim(%arg0 : !shapex.ranked_shape<[?,16,?]>) -> !shapex.ranked_shape<[?,16,?]> {
- %0 = shapex.ranked_dim %arg0[1] : !shapex.ranked_shape<[?,16,?]> -> index
- %1 = shapex.ranked_dim %arg0[2] : !shapex.ranked_shape<[?,16,?]> -> index
- %2 = shapex.make_ranked_shape %0, %1 : (index, index) -> !shapex.ranked_shape<[?,16,?]>
- // CHECK: %[[RS:.+]] = shapex.make_ranked_shape
- // CHECK: return %[[RS]]
- return %2 : !shapex.ranked_shape<[?,16,?]>
-}
-
-// CHECK-LABEL: @identityMakeRankedShape_nomatch_different_shape
-// CHECK-SAME: %[[ARGRS1:[^:[:space:]]+]]: !shapex.ranked_shape
-// CHECK-SAME: %[[ARGRS2:[^:[:space:]]+]]: !shapex.ranked_shape
-func @identityMakeRankedShape_nomatch_different_shape(%arg0 : !shapex.ranked_shape<[?,16,?]>, %arg1 : !shapex.ranked_shape<[?,16,?]>) -> !shapex.ranked_shape<[?,16,?]> {
- %0 = shapex.ranked_dim %arg0[0] : !shapex.ranked_shape<[?,16,?]> -> index
- %1 = shapex.ranked_dim %arg1[2] : !shapex.ranked_shape<[?,16,?]> -> index
- %2 = shapex.make_ranked_shape %0, %1 : (index, index) -> !shapex.ranked_shape<[?,16,?]>
- // CHECK: %[[RS:.+]] = shapex.make_ranked_shape
- // CHECK: return %[[RS]]
- return %2 : !shapex.ranked_shape<[?,16,?]>
-}
diff --git a/iree/compiler/Dialect/Shape/IR/test/op_verification.mlir b/iree/compiler/Dialect/Shape/IR/test/op_verification.mlir
deleted file mode 100644
index 580cdf3..0000000
--- a/iree/compiler/Dialect/Shape/IR/test/op_verification.mlir
+++ /dev/null
@@ -1,22 +0,0 @@
-// RUN: iree-opt -split-input-file -verify-diagnostics %s
-
-// -----
-func @const_ranked_shape_wrong_type() {
- // expected-error @+1 {{result #0 must be Ranked shape type, but got 'i32'}}
- %0 = shapex.const_ranked_shape : i32
- return
-}
-
-// -----
-func @const_ranked_shape_not_static() {
- // expected-error @+1 {{must be a fully static ranked_shape}}
- %0 = shapex.const_ranked_shape : !shapex.ranked_shape<[2,?,4]>
- return
-}
-
-// -----
-func @ranked_dim_out_of_range(%arg0 : !shapex.ranked_shape<[2,4]>) {
- // expected-error @+1 {{index out of bounds of shape}}
- %0 = shapex.ranked_dim %arg0[2] : !shapex.ranked_shape<[2,4]> -> index
- return
-}
diff --git a/iree/compiler/Dialect/Shape/IR/test/parse_print.mlir b/iree/compiler/Dialect/Shape/IR/test/parse_print.mlir
deleted file mode 100644
index 61af782..0000000
--- a/iree/compiler/Dialect/Shape/IR/test/parse_print.mlir
+++ /dev/null
@@ -1,29 +0,0 @@
-// RUN: iree-opt -split-input-file %s | iree-opt -split-input-file | IreeFileCheck %s
-
-// -----
-// CHECK-LABEL: @const_ranked_shape
-func @const_ranked_shape() -> !shapex.ranked_shape<[2,4]> {
- // CHECK: %rs2_4 = shapex.const_ranked_shape : !shapex.ranked_shape<[2,4]>
- %0 = shapex.const_ranked_shape : !shapex.ranked_shape<[2,4]>
- // CHECK: %rs = shapex.const_ranked_shape : !shapex.ranked_shape<[]>
- %1 = shapex.const_ranked_shape : !shapex.ranked_shape<[]>
- // CHECK: %rs5_6 = shapex.const_ranked_shape : !shapex.ranked_shape<[5,6]>
- %2 = shapex.const_ranked_shape : !shapex.ranked_shape<[5,6]>
- return %0 : !shapex.ranked_shape<[2,4]>
-}
-
-// -----
-// CHECK-LABEL: @ranked_dim
-func @ranked_dim(%arg0 : !shapex.ranked_shape<[2,4]>) {
- // CHECK: shapex.ranked_dim %arg0[1] : !shapex.ranked_shape<[2,4]> -> index
- %0 = shapex.ranked_dim %arg0[1] : !shapex.ranked_shape<[2,4]> -> index
- return
-}
-
-// -----
-// CHECK-LABEL: @make_ranked_shape
-func @make_ranked_shape(%arg0 : index, %arg1 : index) -> (!shapex.ranked_shape<[1,?,?,16]>) {
- // CHECK: shapex.make_ranked_shape %arg0, %arg1 : (index, index) -> !shapex.ranked_shape<[1,?,?,16]>
- %0 = shapex.make_ranked_shape %arg0, %arg1 : (index, index) -> !shapex.ranked_shape<[1,?,?,16]>
- return %0 : !shapex.ranked_shape<[1,?,?,16]>
-}
diff --git a/iree/compiler/Dialect/Shape/IR/test/ranked_shape_type.mlir b/iree/compiler/Dialect/Shape/IR/test/ranked_shape_type.mlir
deleted file mode 100644
index 887a28b..0000000
--- a/iree/compiler/Dialect/Shape/IR/test/ranked_shape_type.mlir
+++ /dev/null
@@ -1,34 +0,0 @@
-// RUN: iree-opt -split-input-file -verify-diagnostics %s | IreeFileCheck %s
-
-// CHECK-LABEL: @parseScalarShapeIndex
-// CHECK: !shapex.ranked_shape<[]>
-func @parseScalarShapeIndex(%arg0 : !shapex.ranked_shape<[]>) {
- return
-}
-
-// -----
-// CHECK-LABEL: @parseStaticShapeIndex
-// CHECK: !shapex.ranked_shape<[1,2]>
-func @parseStaticShapeIndex(%arg0 : !shapex.ranked_shape<[1, 2]>) {
- return
-}
-
-// CHECK-LABEL: @parseScalarShape
-// CHECK: !shapex.ranked_shape<[]>
-func @parseScalarShape(%arg0 : !shapex.ranked_shape<[]>) {
- return
-}
-
-// -----
-// CHECK-LABEL: @parseStaticShape
-// CHECK: !shapex.ranked_shape<[1,2]>
-func @parseStaticShape(%arg0 : !shapex.ranked_shape<[1, 2]>) {
- return
-}
-
-// -----
-// CHECK-LABEL: @parseDynamicShape
-// CHECK: !shapex.ranked_shape<[1,?,2,?]>
-func @parseDynamicShape(%arg0 : !shapex.ranked_shape<[1,?,2,?]>) {
- return
-}
diff --git a/iree/compiler/Dialect/Shape/Transforms/BUILD b/iree/compiler/Dialect/Shape/Transforms/BUILD
deleted file mode 100644
index b2824e2..0000000
--- a/iree/compiler/Dialect/Shape/Transforms/BUILD
+++ /dev/null
@@ -1,34 +0,0 @@
-# Copyright 2019 The IREE Authors
-#
-# Licensed under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-package(
- default_visibility = ["//visibility:public"],
- features = ["layering_check"],
- licenses = ["notice"], # Apache 2.0
-)
-
-cc_library(
- name = "Transforms",
- srcs = [
- "FoldDimOverShapeCarryingOpPass.cpp",
- ],
- hdrs = [
- "Passes.h",
- ],
- deps = [
- "//iree/compiler/Dialect/Shape/IR",
- "//iree/compiler/Utils",
- "@llvm-project//llvm:Support",
- "@llvm-project//mlir:Analysis",
- "@llvm-project//mlir:IR",
- "@llvm-project//mlir:MemRefDialect",
- "@llvm-project//mlir:Pass",
- "@llvm-project//mlir:StandardOps",
- "@llvm-project//mlir:Support",
- "@llvm-project//mlir:TensorDialect",
- "@llvm-project//mlir:Transforms",
- ],
-)
diff --git a/iree/compiler/Dialect/Shape/Transforms/FoldDimOverShapeCarryingOpPass.cpp b/iree/compiler/Dialect/Shape/Transforms/FoldDimOverShapeCarryingOpPass.cpp
deleted file mode 100644
index 667c01e..0000000
--- a/iree/compiler/Dialect/Shape/Transforms/FoldDimOverShapeCarryingOpPass.cpp
+++ /dev/null
@@ -1,80 +0,0 @@
-// Copyright 2021 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#include "iree/compiler/Dialect/Shape/IR/ShapeDialect.h"
-#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
-#include "iree/compiler/Dialect/Shape/Transforms/Passes.h"
-#include "mlir/Dialect/MemRef/IR/MemRef.h"
-#include "mlir/Dialect/Tensor/IR/Tensor.h"
-#include "mlir/IR/BuiltinAttributes.h"
-#include "mlir/IR/MLIRContext.h"
-#include "mlir/IR/Matchers.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/Pass/PassRegistry.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-
-namespace mlir {
-namespace iree_compiler {
-namespace Shape {
-
-namespace {
-
-template <typename DimOp>
-class FoldDimOp : public OpRewritePattern<DimOp> {
- using OpRewritePattern<DimOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(DimOp op,
- PatternRewriter &rewriter) const override {
- auto shapeCarryingOp =
- dyn_cast<ShapeCarryingInterface>(op.source().getDefiningOp());
- if (!shapeCarryingOp) return failure();
-
- IntegerAttr index;
- if (!matchPattern(op.index(), m_Constant(&index))) return failure();
-
- auto shapeOp =
- shapeCarryingOp.buildResultValueRankedShape(op.source(), rewriter);
- rewriter.replaceOpWithNewOp<RankedDimOp>(op, op.getType(), shapeOp, index);
- return success();
- }
-};
-
-class FoldDimOverShapeCarryingOpPass
- : public PassWrapper<FoldDimOverShapeCarryingOpPass, FunctionPass> {
- StringRef getArgument() const override {
- return "iree-fold-dim-over-shape-carrying-op";
- }
-
- StringRef getDescription() const override {
- return "Fold tensor.dim/memref.dim ops taking shape carrying ops as "
- "operands";
- }
-
- void getDependentDialects(DialectRegistry &registry) const override {
- registry.insert<ShapeDialect>();
- }
-
- void runOnFunction() override {
- OwningRewritePatternList patterns(&getContext());
- patterns.insert<FoldDimOp<memref::DimOp>, FoldDimOp<tensor::DimOp>>(
- &getContext());
- (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
- }
-};
-
-} // namespace
-
-// For any function which contains dynamic dims in its inputs or results,
-// rewrites it so that the dynamic dims are passed in/out.
-std::unique_ptr<OperationPass<FuncOp>> createFoldDimOverShapeCarryingOpPass() {
- return std::make_unique<Shape::FoldDimOverShapeCarryingOpPass>();
-}
-
-static PassRegistration<Shape::FoldDimOverShapeCarryingOpPass> pass;
-
-} // namespace Shape
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/iree/compiler/Dialect/Shape/Transforms/Passes.h b/iree/compiler/Dialect/Shape/Transforms/Passes.h
deleted file mode 100644
index 8c6199e..0000000
--- a/iree/compiler/Dialect/Shape/Transforms/Passes.h
+++ /dev/null
@@ -1,29 +0,0 @@
-// Copyright 2020 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#ifndef IREE_COMPILER_DIALECT_SHAPE_TRANSFORMS_PASSES_H_
-#define IREE_COMPILER_DIALECT_SHAPE_TRANSFORMS_PASSES_H_
-
-#include <memory>
-
-#include "mlir/Pass/Pass.h"
-#include "mlir/Pass/PassManager.h"
-
-namespace mlir {
-namespace iree_compiler {
-namespace Shape {
-
-// Folds tensor.dim/memref.dim ops taking shape carrying ops as operands.
-std::unique_ptr<OperationPass<FuncOp>> createFoldDimOverShapeCarryingOpPass();
-
-// Register all Passes
-inline void registerShapePasses() { createFoldDimOverShapeCarryingOpPass(); }
-
-} // namespace Shape
-} // namespace iree_compiler
-} // namespace mlir
-
-#endif // IREE_COMPILER_DIALECT_SHAPE_TRANSFORMS_PASSES_H_
diff --git a/iree/compiler/Dialect/Shape/Transforms/test/fold_dim_over_shape_carrying_op.mlir b/iree/compiler/Dialect/Shape/Transforms/test/fold_dim_over_shape_carrying_op.mlir
deleted file mode 100644
index 6aa8a66..0000000
--- a/iree/compiler/Dialect/Shape/Transforms/test/fold_dim_over_shape_carrying_op.mlir
+++ /dev/null
@@ -1,37 +0,0 @@
-// RUN: iree-opt -split-input-file -iree-fold-dim-over-shape-carrying-op -canonicalize %s | IreeFileCheck %s
-
-// CHECK: func @memrefDim
-// CHECK-SAME: (%[[DIM0:.+]]: index, %[[DIM1:.+]]: index)
-func @memrefDim(%d0: index, %d1: index) -> (index, index, index) {
- %c0 = arith.constant 0 : index
- %c1 = arith.constant 1 : index
- %c2 = arith.constant 2 : index
- %subspan = hal.interface.binding.subspan @io::@s0b0_ro_constant[%c0] : memref<?x7x?xf32>{%d0, %d1}
- %dim0 = memref.dim %subspan, %c0 : memref<?x7x?xf32>
- %dim1 = memref.dim %subspan, %c1 : memref<?x7x?xf32>
- %dim2 = memref.dim %subspan, %c2 : memref<?x7x?xf32>
- // CHECK: %[[C7:.+]] = arith.constant 7 : index
- // CHECK: return %[[DIM0]], %[[C7]], %[[DIM1]]
- return %dim0, %dim1, %dim2 : index, index, index
-}
-
-hal.interface @io attributes {sym_visibility = "private"} {
- hal.interface.binding @s0b0_ro_constant, set=0, binding=0, type="StorageBuffer"
-}
-
-// -----
-
-// CHECK: func @tensorDim
-// CHECK-SAME: (%{{.+}}: f32, %[[DIM0:.+]]: index, %[[DIM1:.+]]: index)
-func @tensorDim(%value: f32, %d0: index, %d1: index) -> (index, index, index) {
- %c0 = arith.constant 0 : index
- %c1 = arith.constant 1 : index
- %c2 = arith.constant 2 : index
- %splat = flow.tensor.splat %value : tensor<?x8x?xf32>{%d0, %d1}
- %dim0 = tensor.dim %splat, %c0 : tensor<?x8x?xf32>
- %dim1 = tensor.dim %splat, %c1 : tensor<?x8x?xf32>
- %dim2 = tensor.dim %splat, %c2 : tensor<?x8x?xf32>
- // CHECK: %[[C8:.+]] = arith.constant 8 : index
- // CHECK: return %[[DIM0]], %[[C8]], %[[DIM1]]
- return %dim0, %dim1, %dim2 : index, index, index
-}
diff --git a/iree/compiler/Dialect/Stream/Analysis/Partitioning/ReferencePartitioning.cpp b/iree/compiler/Dialect/Stream/Analysis/Partitioning/ReferencePartitioning.cpp
index 7b5f1bf..6e43954 100644
--- a/iree/compiler/Dialect/Stream/Analysis/Partitioning/ReferencePartitioning.cpp
+++ b/iree/compiler/Dialect/Stream/Analysis/Partitioning/ReferencePartitioning.cpp
@@ -218,8 +218,7 @@
IREE::Stream::PartitioningConfigAttr config, Block *block) {
PartitionSet waveSet;
- auto favor = config ? config.getFavor().getValue()
- : IREE::Stream::Favor::MinPeakMemory;
+ auto favor = config.getFavor().getValue();
if (favor == IREE::Stream::Favor::Debug) {
// Disable partitioning when favoring debugability.
return waveSet;
@@ -298,7 +297,7 @@
opInfo.membership.resize(builders.size(), /*t=*/false);
// No consumers - if there's any candidate then we'll go into that.
- int firstCandidateOrdinal = favor == IREE::Stream::Favor::MinPeakMemory
+ int firstCandidateOrdinal = favor == IREE::Stream::Favor::MaxConcurrency
? candidates.find_first()
: candidates.find_last();
if (firstCandidateOrdinal != -1) {
diff --git a/iree/compiler/Dialect/Stream/Conversion/FlowToStream/BUILD b/iree/compiler/Dialect/Stream/Conversion/FlowToStream/BUILD
index d317206..025e0d1 100644
--- a/iree/compiler/Dialect/Stream/Conversion/FlowToStream/BUILD
+++ b/iree/compiler/Dialect/Stream/Conversion/FlowToStream/BUILD
@@ -20,8 +20,6 @@
],
deps = [
"//iree/compiler/Dialect/Flow/IR",
- "//iree/compiler/Dialect/HAL/IR",
- "//iree/compiler/Dialect/Shape/IR",
"//iree/compiler/Dialect/Stream/Conversion",
"//iree/compiler/Dialect/Stream/IR",
"@llvm-project//mlir:IR",
diff --git a/iree/compiler/Dialect/Stream/Conversion/FlowToStream/CMakeLists.txt b/iree/compiler/Dialect/Stream/Conversion/FlowToStream/CMakeLists.txt
index d3b2d39..c5651f7 100644
--- a/iree/compiler/Dialect/Stream/Conversion/FlowToStream/CMakeLists.txt
+++ b/iree/compiler/Dialect/Stream/Conversion/FlowToStream/CMakeLists.txt
@@ -22,8 +22,6 @@
MLIRStandard
MLIRTensor
iree::compiler::Dialect::Flow::IR
- iree::compiler::Dialect::HAL::IR
- iree::compiler::Dialect::Shape::IR
iree::compiler::Dialect::Stream::Conversion
iree::compiler::Dialect::Stream::IR
PUBLIC
diff --git a/iree/compiler/Dialect/Stream/Conversion/FlowToStream/ConvertFlowToStream.cpp b/iree/compiler/Dialect/Stream/Conversion/FlowToStream/ConvertFlowToStream.cpp
index 57e2b96..bbd01c6 100644
--- a/iree/compiler/Dialect/Stream/Conversion/FlowToStream/ConvertFlowToStream.cpp
+++ b/iree/compiler/Dialect/Stream/Conversion/FlowToStream/ConvertFlowToStream.cpp
@@ -8,9 +8,6 @@
#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
-#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
-#include "iree/compiler/Dialect/Shape/IR/Builders.h"
-#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
#include "iree/compiler/Dialect/Stream/Conversion/PatternUtils.h"
#include "iree/compiler/Dialect/Stream/IR/StreamDialect.h"
#include "iree/compiler/Dialect/Stream/IR/StreamOps.h"
@@ -35,66 +32,6 @@
dynamicDims, /*affinity=*/nullptr);
}
-// hal.tensor.cast is inserted by frontends to ensure that ABI types are HAL
-// buffer views. We need to map those to the stream import/export equivalents as
-// the cast has special meaning when we are dealing with asynchronous values.
-//
-// %1 = hal.tensor.cast %0 : !hal.buffer_view -> tensor<4xf32>
-// ->
-// %1 = stream.tensor.import %0 : !hal.buffer_view ->
-// tensor<4xf32> in !stream.resource<*>
-struct ConvertHALTensorCastOp
- : public OpConversionPattern<IREE::HAL::TensorCastOp> {
- using OpConversionPattern::OpConversionPattern;
- LogicalResult matchAndRewrite(
- IREE::HAL::TensorCastOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- auto sourceType = op.source().getType();
- auto targetType = op.target().getType();
- if (sourceType.isa<IREE::HAL::BufferType>() ||
- sourceType.isa<IREE::HAL::BufferViewType>()) {
- // Import (buffer view to stream resource).
- auto resultType = rewriter.getType<IREE::Stream::ResourceType>(
- IREE::Stream::Lifetime::External);
- auto resultSize = buildResultSizeOf(op.getLoc(), op.target(),
- adaptor.target_dims(), rewriter);
- auto newOp = rewriter.create<IREE::Stream::TensorImportOp>(
- op.getLoc(), resultType, adaptor.source(), TypeAttr::get(targetType),
- adaptor.target_dims(), resultSize,
- /*affinity=*/nullptr);
-
- auto unknownType = rewriter.getType<IREE::Stream::ResourceType>();
- rewriter.replaceOpWithNewOp<IREE::Stream::AsyncTransferOp>(
- op, unknownType, newOp.result(), resultSize, resultSize,
- /*source_affinity=*/nullptr,
- /*result_affinity=*/nullptr);
- } else if (targetType.isa<IREE::HAL::BufferType>() ||
- targetType.isa<IREE::HAL::BufferViewType>()) {
- auto source =
- consumeTensorOperand(op.getLoc(), adaptor.source(), rewriter);
- auto externalType = rewriter.getType<IREE::Stream::ResourceType>(
- IREE::Stream::Lifetime::External);
- auto exportSource = adaptor.source();
- if (source.resource.getType() != externalType) {
- exportSource = rewriter.create<IREE::Stream::AsyncTransferOp>(
- op.getLoc(), externalType, source.resource, source.resourceSize,
- source.resourceSize,
- /*source_affinity=*/nullptr,
- /*result_affinity=*/nullptr);
- }
-
- // Export (stream resource to buffer view).
- rewriter.replaceOpWithNewOp<IREE::Stream::TensorExportOp>(
- op, targetType, exportSource, TypeAttr::get(op.source().getType()),
- adaptor.source_dims(), source.resourceSize,
- /*affinity=*/nullptr);
- } else {
- return rewriter.notifyMatchFailure(op, "unsupported HAL cast conversion");
- }
- return success();
- }
-};
-
// Reshapes become clones here to preserve shape information (which may become
// actual transfers depending on source/target shape) - they'll be elided if not
// needed.
@@ -277,7 +214,7 @@
/*source_affinity=*/nullptr,
/*result_affinity=*/nullptr);
}
- auto dynamicDims = Shape::buildOrFindDynamicDimsForValue(
+ auto dynamicDims = IREE::Util::buildDynamicDimsForValue(
op.getLoc(), tensorOperand, rewriter);
exportedTensors.push_back(rewriter.create<IREE::Stream::TensorExportOp>(
op.getLoc(), tensorOperand.getType(), exportSource,
@@ -332,7 +269,7 @@
resultSizes.push_back(operandSizes[operandIndex]);
resultTypes.push_back(dispatchOperands[operandIndex].getType());
} else {
- auto resultDynamicDims = Shape::buildOrFindDynamicDimsForValue(
+ auto resultDynamicDims = IREE::Util::buildDynamicDimsForValue(
op.getLoc(), result.value(), rewriter);
resultSizes.push_back(buildResultSizeOf(op.getLoc(), result.value(),
resultDynamicDims, rewriter));
@@ -434,10 +371,6 @@
void populateFlowToStreamConversionPatterns(
MLIRContext *context, TypeConverter &typeConverter,
OwningRewritePatternList &patterns) {
- typeConverter.addConversion(
- [](IREE::HAL::BufferViewType type) { return type; });
- patterns.insert<ConvertHALTensorCastOp>(typeConverter, context);
-
patterns
.insert<ConvertTensorReshapeOp, ConvertTensorSplatOp,
ConvertTensorCloneOp, ConvertTensorSliceOp, ConvertTensorUpdateOp,
@@ -459,12 +392,6 @@
conversionTarget.addLegalOp<IREE::Stream::ExecutableOp>();
conversionTarget.markOpRecursivelyLegal<IREE::Stream::ExecutableOp>();
- conversionTarget.addDynamicallyLegalOp<IREE::HAL::TensorCastOp>(
- [&](IREE::HAL::TensorCastOp op) {
- return typeConverter.isLegal(op.source().getType()) &&
- typeConverter.isLegal(op.target().getType());
- });
-
populateFlowToStreamConversionPatterns(context, typeConverter, patterns);
}
diff --git a/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/BUILD b/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/BUILD
index 1f3d342..e73559d 100644
--- a/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/BUILD
+++ b/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/BUILD
@@ -17,7 +17,6 @@
name = "lit",
srcs = enforce_glob(
[
- "cast_ops.mlir",
"dispatch_ops.mlir",
"tensor_ops.mlir",
],
diff --git a/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/CMakeLists.txt b/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/CMakeLists.txt
index 1974566..cb5a3aa 100644
--- a/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/CMakeLists.txt
+++ b/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/CMakeLists.txt
@@ -14,7 +14,6 @@
NAME
lit
SRCS
- "cast_ops.mlir"
"dispatch_ops.mlir"
"tensor_ops.mlir"
DATA
diff --git a/iree/compiler/Dialect/Stream/Conversion/HALToStream/BUILD b/iree/compiler/Dialect/Stream/Conversion/HALToStream/BUILD
new file mode 100644
index 0000000..b00816d
--- /dev/null
+++ b/iree/compiler/Dialect/Stream/Conversion/HALToStream/BUILD
@@ -0,0 +1,28 @@
+# Copyright 2021 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+package(
+ default_visibility = ["//visibility:public"],
+ features = ["layering_check"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+cc_library(
+ name = "HALToStream",
+ srcs = [
+ "ConvertHALToStream.cpp",
+ ],
+ hdrs = [
+ "ConvertHALToStream.h",
+ ],
+ deps = [
+ "//iree/compiler/Dialect/HAL/IR",
+ "//iree/compiler/Dialect/Stream/Conversion",
+ "//iree/compiler/Dialect/Stream/IR",
+ "@llvm-project//mlir:IR",
+ "@llvm-project//mlir:StandardOps",
+ ],
+)
diff --git a/iree/compiler/Dialect/Shape/Transforms/CMakeLists.txt b/iree/compiler/Dialect/Stream/Conversion/HALToStream/CMakeLists.txt
similarity index 72%
rename from iree/compiler/Dialect/Shape/Transforms/CMakeLists.txt
rename to iree/compiler/Dialect/Stream/Conversion/HALToStream/CMakeLists.txt
index 27f4e8f..51c682f 100644
--- a/iree/compiler/Dialect/Shape/Transforms/CMakeLists.txt
+++ b/iree/compiler/Dialect/Stream/Conversion/HALToStream/CMakeLists.txt
@@ -1,6 +1,6 @@
################################################################################
# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from #
-# iree/compiler/Dialect/Shape/Transforms/BUILD #
+# iree/compiler/Dialect/Stream/Conversion/HALToStream/BUILD #
# #
# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary #
# CMake-only content. #
@@ -12,23 +12,17 @@
iree_cc_library(
NAME
- Transforms
+ HALToStream
HDRS
- "Passes.h"
+ "ConvertHALToStream.h"
SRCS
- "FoldDimOverShapeCarryingOpPass.cpp"
+ "ConvertHALToStream.cpp"
DEPS
- LLVMSupport
- MLIRAnalysis
MLIRIR
- MLIRMemRef
- MLIRPass
MLIRStandard
- MLIRSupport
- MLIRTensor
- MLIRTransforms
- iree::compiler::Dialect::Shape::IR
- iree::compiler::Utils
+ iree::compiler::Dialect::HAL::IR
+ iree::compiler::Dialect::Stream::Conversion
+ iree::compiler::Dialect::Stream::IR
PUBLIC
)
diff --git a/iree/compiler/Dialect/Stream/Conversion/HALToStream/ConvertHALToStream.cpp b/iree/compiler/Dialect/Stream/Conversion/HALToStream/ConvertHALToStream.cpp
new file mode 100644
index 0000000..792805f
--- /dev/null
+++ b/iree/compiler/Dialect/Stream/Conversion/HALToStream/ConvertHALToStream.cpp
@@ -0,0 +1,126 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Dialect/Stream/Conversion/HALToStream/ConvertHALToStream.h"
+
+#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
+#include "iree/compiler/Dialect/Stream/Conversion/PatternUtils.h"
+#include "iree/compiler/Dialect/Stream/IR/StreamDialect.h"
+#include "iree/compiler/Dialect/Stream/IR/StreamOps.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+namespace {
+
+// %1 = hal.tensor.import %0 : !hal.buffer_view -> tensor<4xf32>
+// ->
+// %1 = stream.tensor.import %0 : !hal.buffer_view ->
+// tensor<4xf32> in !stream.resource<*>
+struct ConvertTensorImportOp
+ : public OpConversionPattern<IREE::HAL::TensorImportOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult matchAndRewrite(
+ IREE::HAL::TensorImportOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto sourceType = op.source().getType();
+ auto targetType = op.target().getType();
+ if (!sourceType.isa<IREE::HAL::BufferType>() &&
+ !sourceType.isa<IREE::HAL::BufferViewType>()) {
+ return rewriter.notifyMatchFailure(op, "unsupported HAL cast conversion");
+ }
+
+ // Import (buffer view to stream resource).
+ auto resultType = rewriter.getType<IREE::Stream::ResourceType>(
+ IREE::Stream::Lifetime::External);
+ auto resultSize = rewriter.createOrFold<IREE::Stream::TensorSizeOfOp>(
+ op.getLoc(), rewriter.getIndexType(),
+ TypeAttr::get(op.target().getType()), adaptor.target_dims(),
+ /*affinity=*/nullptr);
+ auto newOp = rewriter.create<IREE::Stream::TensorImportOp>(
+ op.getLoc(), resultType, adaptor.source(), TypeAttr::get(targetType),
+ adaptor.target_dims(), resultSize,
+ /*affinity=*/nullptr);
+
+ auto unknownType = rewriter.getType<IREE::Stream::ResourceType>();
+ rewriter.replaceOpWithNewOp<IREE::Stream::AsyncTransferOp>(
+ op, unknownType, newOp.result(), resultSize, resultSize,
+ /*source_affinity=*/nullptr,
+ /*result_affinity=*/nullptr);
+ return success();
+ }
+};
+
+// %1 = hal.tensor.export %0 : tensor<4xf32> -> !hal.buffer_view
+// ->
+// %1 = stream.tensor.export %0 : tensor<4xf32> in !stream.resource<*> ->
+// !hal.buffer_view
+struct ConvertTensorExportOp
+ : public OpConversionPattern<IREE::HAL::TensorExportOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult matchAndRewrite(
+ IREE::HAL::TensorExportOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto sourceType = op.source().getType();
+ auto targetType = op.target().getType();
+ if (!targetType.isa<IREE::HAL::BufferType>() &&
+ !targetType.isa<IREE::HAL::BufferViewType>()) {
+ return rewriter.notifyMatchFailure(op, "unsupported HAL cast conversion");
+ }
+
+ auto source = consumeTensorOperand(op.getLoc(), adaptor.source(), rewriter);
+ auto externalType = rewriter.getType<IREE::Stream::ResourceType>(
+ IREE::Stream::Lifetime::External);
+ auto exportSource = adaptor.source();
+ if (source.resource.getType() != externalType) {
+ exportSource = rewriter.create<IREE::Stream::AsyncTransferOp>(
+ op.getLoc(), externalType, source.resource, source.resourceSize,
+ source.resourceSize,
+ /*source_affinity=*/nullptr,
+ /*result_affinity=*/nullptr);
+ }
+
+ // Export (stream resource to buffer view).
+ rewriter.replaceOpWithNewOp<IREE::Stream::TensorExportOp>(
+ op, targetType, exportSource, TypeAttr::get(sourceType),
+ adaptor.source_dims(), source.resourceSize,
+ /*affinity=*/nullptr);
+ return success();
+ }
+};
+
+} // namespace
+
+void populateHALToStreamConversionPatterns(MLIRContext *context,
+ TypeConverter &typeConverter,
+ OwningRewritePatternList &patterns) {
+ typeConverter.addConversion(
+ [](IREE::HAL::BufferViewType type) { return type; });
+ patterns.insert<ConvertTensorImportOp>(typeConverter, context);
+ patterns.insert<ConvertTensorExportOp>(typeConverter, context);
+}
+
+void populateHALToStreamConversionPatterns(MLIRContext *context,
+ ConversionTarget &conversionTarget,
+ TypeConverter &typeConverter,
+ OwningRewritePatternList &patterns) {
+ conversionTarget.addDynamicallyLegalOp<IREE::HAL::TensorImportOp>(
+ [&](IREE::HAL::TensorImportOp op) {
+ return typeConverter.isLegal(op.source().getType()) &&
+ typeConverter.isLegal(op.target().getType());
+ });
+ conversionTarget.addDynamicallyLegalOp<IREE::HAL::TensorExportOp>(
+ [&](IREE::HAL::TensorExportOp op) {
+ return typeConverter.isLegal(op.source().getType()) &&
+ typeConverter.isLegal(op.target().getType());
+ });
+
+ populateHALToStreamConversionPatterns(context, typeConverter, patterns);
+}
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/iree/compiler/Dialect/Stream/Conversion/HALToStream/ConvertHALToStream.h b/iree/compiler/Dialect/Stream/Conversion/HALToStream/ConvertHALToStream.h
new file mode 100644
index 0000000..71c2bff
--- /dev/null
+++ b/iree/compiler/Dialect/Stream/Conversion/HALToStream/ConvertHALToStream.h
@@ -0,0 +1,31 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_COMPILER_DIALECT_STREAM_CONVERSION_HALTOSTREAM_CONVERTHALTOSTREAM_H_
+#define IREE_COMPILER_DIALECT_STREAM_CONVERSION_HALTOSTREAM_CONVERTHALTOSTREAM_H_
+
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/OperationSupport.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+// Populates conversion patterns that perform hal->stream conversion.
+// These patterns ensure that nested types are run through the provided
+// |typeConverter|.
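+//
+// Illustrative wiring (a sketch of how the ConvertToStream pass in this change
+// uses these helpers; not a standalone example):
+//   populateHALToStreamConversionPatterns(context, conversionTarget,
+//                                         typeConverter, patterns);
+// The ConversionTarget overload also marks hal.tensor.import/export as
+// dynamically legal based on |typeConverter| before adding the patterns.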
+void populateHALToStreamConversionPatterns(MLIRContext *context,
+ TypeConverter &typeConverter,
+ OwningRewritePatternList &patterns);
+void populateHALToStreamConversionPatterns(MLIRContext *context,
+ ConversionTarget &conversionTarget,
+ TypeConverter &typeConverter,
+ OwningRewritePatternList &patterns);
+
+} // namespace iree_compiler
+} // namespace mlir
+
+#endif // IREE_COMPILER_DIALECT_STREAM_CONVERSION_HALTOSTREAM_CONVERTHALTOSTREAM_H_
diff --git a/iree/compiler/Dialect/Shape/Transforms/test/BUILD b/iree/compiler/Dialect/Stream/Conversion/HALToStream/test/BUILD
similarity index 88%
rename from iree/compiler/Dialect/Shape/Transforms/test/BUILD
rename to iree/compiler/Dialect/Stream/Conversion/HALToStream/test/BUILD
index 415ce3d..5fc9bdb 100644
--- a/iree/compiler/Dialect/Shape/Transforms/test/BUILD
+++ b/iree/compiler/Dialect/Stream/Conversion/HALToStream/test/BUILD
@@ -1,4 +1,4 @@
-# Copyright 2020 The IREE Authors
+# Copyright 2021 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
@@ -17,7 +17,7 @@
name = "lit",
srcs = enforce_glob(
[
- "fold_dim_over_shape_carrying_op.mlir",
+ "abi_ops.mlir",
],
include = ["*.mlir"],
),
diff --git a/iree/compiler/Dialect/Shape/Transforms/test/CMakeLists.txt b/iree/compiler/Dialect/Stream/Conversion/HALToStream/test/CMakeLists.txt
similarity index 88%
rename from iree/compiler/Dialect/Shape/Transforms/test/CMakeLists.txt
rename to iree/compiler/Dialect/Stream/Conversion/HALToStream/test/CMakeLists.txt
index 28c3aaa..a02b6cb 100644
--- a/iree/compiler/Dialect/Shape/Transforms/test/CMakeLists.txt
+++ b/iree/compiler/Dialect/Stream/Conversion/HALToStream/test/CMakeLists.txt
@@ -1,6 +1,6 @@
################################################################################
# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from #
-# iree/compiler/Dialect/Shape/Transforms/test/BUILD #
+# iree/compiler/Dialect/Stream/Conversion/HALToStream/test/BUILD #
# #
# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary #
# CMake-only content. #
@@ -14,7 +14,7 @@
NAME
lit
SRCS
- "fold_dim_over_shape_carrying_op.mlir"
+ "abi_ops.mlir"
DATA
iree::tools::IreeFileCheck
iree::tools::iree-opt
diff --git a/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/cast_ops.mlir b/iree/compiler/Dialect/Stream/Conversion/HALToStream/test/abi_ops.mlir
similarity index 90%
rename from iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/cast_ops.mlir
rename to iree/compiler/Dialect/Stream/Conversion/HALToStream/test/abi_ops.mlir
index 17f9a6a..5b7a98f 100644
--- a/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/cast_ops.mlir
+++ b/iree/compiler/Dialect/Stream/Conversion/HALToStream/test/abi_ops.mlir
@@ -13,7 +13,7 @@
// CHECK-SAME: tensor<?x?x4xf32>{%[[DIM0]], %[[DIM1]]} in !stream.resource<external>{%[[SIZE]]}
// CHECK-NEXT: %[[RESULT:.+]] = stream.async.transfer %[[RESOURCE]] :
// CHECK-SAME: !stream.resource<external>{%[[SIZE]]} -> !stream.resource<*>{%[[SIZE]]}
- %0 = hal.tensor.cast %view : !hal.buffer_view -> tensor<?x?x4xf32>{%dim0, %dim1}
+ %0 = hal.tensor.import %view : !hal.buffer_view -> tensor<?x?x4xf32>{%dim0, %dim1}
// CHECK: return %[[RESULT]], %[[SIZE]] : !stream.resource<*>, index
return %0 : tensor<?x?x4xf32>
}
@@ -28,7 +28,7 @@
// CHECK-NEXT: %[[RESULT:.+]] = stream.tensor.export %[[VIEW]] :
// CHECK-SAME: tensor<?x?x4xf32>{%[[DIM0]], %[[DIM1]]} in !stream.resource<external>{%[[SIZE]]}
// CHECK-SAME: -> !hal.buffer_view
- %0 = hal.tensor.cast %tensor : tensor<?x?x4xf32>{%dim0, %dim1} -> !hal.buffer_view
+ %0 = hal.tensor.export %tensor : tensor<?x?x4xf32>{%dim0, %dim1} -> !hal.buffer_view
// CHECK: return %[[RESULT]]
return %0 : !hal.buffer_view
}
diff --git a/iree/compiler/Dialect/Stream/Conversion/UtilToStream/test/global_ops.mlir b/iree/compiler/Dialect/Stream/Conversion/UtilToStream/test/global_ops.mlir
index cc8ee40..04d7198 100644
--- a/iree/compiler/Dialect/Stream/Conversion/UtilToStream/test/global_ops.mlir
+++ b/iree/compiler/Dialect/Stream/Conversion/UtilToStream/test/global_ops.mlir
@@ -63,7 +63,7 @@
// CHECK: %[[SIZE:.+]] = stream.tensor.sizeof tensor<?x4xf32>{%[[DIM0]]} : index
// CHECK: %[[IMPORT:.+]] = stream.tensor.import %arg0 : !hal.buffer_view -> tensor<?x4xf32>{%[[DIM0]]} in !stream.resource<external>{%[[SIZE]]}
// CHECK: %[[T:.+]] = stream.async.transfer %[[IMPORT]] : !stream.resource<external>{%[[SIZE]]} -> !stream.resource<*>{%[[SIZE]]}
- %0 = hal.tensor.cast %arg0 : !hal.buffer_view -> tensor<?x4xf32>{%dim0}
+ %0 = hal.tensor.import %arg0 : !hal.buffer_view -> tensor<?x4xf32>{%dim0}
// CHECK: %[[VAR:.+]] = stream.async.transfer %[[T]] : !stream.resource<*>{%[[SIZE]]} -> !stream.resource<variable>{%[[SIZE]]}
// CHECK: util.global.store %[[VAR]], @var_with_buffer_view_store : !stream.resource<variable>
// CHECK: util.global.store %[[SIZE]], @var_with_buffer_view_store__size : index
@@ -80,7 +80,7 @@
// util.global public mutable @var_indirect_with_buffer_view_store : tensor<i32>
// func @globalStoreFromExternalIndirect(%arg0: !hal.buffer_view) {
// %0 = util.global.address @var_indirect_with_buffer_view_store : !util.ptr<tensor<i32>>
-// %1 = hal.tensor.cast %arg0 : !hal.buffer_view -> tensor<i32>
+// %1 = hal.tensor.import %arg0 : !hal.buffer_view -> tensor<i32>
// util.global.store.indirect %1, %0 : tensor<i32> -> !util.ptr<tensor<i32>>
// return
// }
diff --git a/iree/compiler/Dialect/Stream/IR/BUILD b/iree/compiler/Dialect/Stream/IR/BUILD
index 39655c5..47c6758 100644
--- a/iree/compiler/Dialect/Stream/IR/BUILD
+++ b/iree/compiler/Dialect/Stream/IR/BUILD
@@ -25,7 +25,6 @@
include = ["*.td"],
),
deps = [
- "//iree/compiler/Dialect/Shape/IR:td_files",
"//iree/compiler/Dialect/Util/IR:td_files",
"@llvm-project//mlir:InferTypeOpInterfaceTdFiles",
"@llvm-project//mlir:OpBaseTdFiles",
@@ -65,7 +64,6 @@
":StreamInterfacesGen",
":StreamOpsGen",
":StreamTypesGen",
- "//iree/compiler/Dialect/Shape/IR",
"//iree/compiler/Dialect/Util/IR",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:ArithmeticDialect",
diff --git a/iree/compiler/Dialect/Stream/IR/CMakeLists.txt b/iree/compiler/Dialect/Stream/IR/CMakeLists.txt
index 6b00ba4..fe0e7ca 100644
--- a/iree/compiler/Dialect/Stream/IR/CMakeLists.txt
+++ b/iree/compiler/Dialect/Stream/IR/CMakeLists.txt
@@ -49,7 +49,6 @@
MLIRSupport
MLIRTensor
MLIRTransformUtils
- iree::compiler::Dialect::Shape::IR
iree::compiler::Dialect::Util::IR
PUBLIC
)
diff --git a/iree/compiler/Dialect/Stream/IR/StreamBase.td b/iree/compiler/Dialect/Stream/IR/StreamBase.td
index 4be92fd..c694de0 100644
--- a/iree/compiler/Dialect/Stream/IR/StreamBase.td
+++ b/iree/compiler/Dialect/Stream/IR/StreamBase.td
@@ -10,7 +10,6 @@
include "iree/compiler/Dialect/Stream/IR/StreamInterfaces.td"
include "iree/compiler/Dialect/Util/IR/UtilBase.td"
include "iree/compiler/Dialect/Util/IR/UtilInterfaces.td"
-include "iree/compiler/Dialect/Shape/IR/ShapeBase.td"
include "mlir/IR/SubElementInterfaces.td"
//===----------------------------------------------------------------------===//
@@ -173,13 +172,13 @@
}
def Stream_Favor_Debug : I32EnumAttrCase<"Debug", 0, "debug">;
-def Stream_Favor_Concurrency : I32EnumAttrCase<"Concurrency", 1, "concurrency">;
-def Stream_Favor_MinPeakMemory : I32EnumAttrCase<"MinPeakMemory", 2, "min-peak-memory">;
+def Stream_Favor_MinPeakMemory : I32EnumAttrCase<"MinPeakMemory", 1, "min-peak-memory">;
+def Stream_Favor_MaxConcurrency : I32EnumAttrCase<"MaxConcurrency", 2, "max-concurrency">;
def Stream_FavorAttr :
I32EnumAttr<"Favor", "IREE partitioning bias", [
Stream_Favor_Debug,
- Stream_Favor_Concurrency,
Stream_Favor_MinPeakMemory,
+ Stream_Favor_MaxConcurrency,
]> {
let cppNamespace = "::mlir::iree_compiler::IREE::Stream";
}
diff --git a/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp b/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp
index 0599c5c..0e2edfb 100644
--- a/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp
+++ b/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp
@@ -7,7 +7,6 @@
#include <algorithm>
#include <numeric>
-#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
#include "iree/compiler/Dialect/Stream/IR/StreamDialect.h"
#include "iree/compiler/Dialect/Stream/IR/StreamOps.h"
#include "iree/compiler/Dialect/Util/IR/ClosureOpUtils.h"
@@ -1977,6 +1976,24 @@
}
//===----------------------------------------------------------------------===//
+// stream.timepoint.export
+//===----------------------------------------------------------------------===//
+
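+// Illustrative fold (mirrored by test/timepoint_folding.mlir in this change):
+//   %t = stream.timepoint.import %sem, %idx : (!hal.semaphore, index) => !stream.timepoint
+//   %0:2 = stream.timepoint.export %t => (!hal.semaphore, index)
+// folds so the export results become %sem and %idx directly; an export whose
+// result types do not match the import operands is left untouched.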
+LogicalResult TimepointExportOp::fold(ArrayRef<Attribute> operands,
+ SmallVectorImpl<OpFoldResult> &results) {
+ // If the source timepoint comes from an import op we can fold - but only if
+ // the types match.
+ if (auto importOp = dyn_cast_or_null<TimepointImportOp>(
+ await_timepoint().getDefiningOp())) {
+ if (llvm::equal(importOp.getOperandTypes(), getResultTypes())) {
+ llvm::append_range(results, importOp.operands());
+ return success();
+ }
+ }
+ return failure();
+}
+
+//===----------------------------------------------------------------------===//
// stream.timepoint.join
//===----------------------------------------------------------------------===//
diff --git a/iree/compiler/Dialect/Stream/IR/StreamOps.cpp b/iree/compiler/Dialect/Stream/IR/StreamOps.cpp
index df5cf54..9db3886 100644
--- a/iree/compiler/Dialect/Stream/IR/StreamOps.cpp
+++ b/iree/compiler/Dialect/Stream/IR/StreamOps.cpp
@@ -6,7 +6,6 @@
#include "iree/compiler/Dialect/Stream/IR/StreamOps.h"
-#include "iree/compiler/Dialect/Shape/IR/Builders.h"
#include "iree/compiler/Dialect/Util/IR/ClosureOpUtils.h"
#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
@@ -1823,17 +1822,6 @@
return success();
}
-Value BindingSubspanOp::buildOperandRankedShape(unsigned idx,
- OpBuilder &builder) {
- return {};
-}
-
-Value BindingSubspanOp::buildResultRankedShape(unsigned idx,
- OpBuilder &builder) {
- return Shape::buildRankedShapeForValue(getLoc(), result(), dynamic_dims(),
- builder);
-}
-
//===----------------------------------------------------------------------===//
// stream.yield
//===----------------------------------------------------------------------===//
diff --git a/iree/compiler/Dialect/Stream/IR/StreamOps.h b/iree/compiler/Dialect/Stream/IR/StreamOps.h
index a57fea1..632b8a0 100644
--- a/iree/compiler/Dialect/Stream/IR/StreamOps.h
+++ b/iree/compiler/Dialect/Stream/IR/StreamOps.h
@@ -9,8 +9,6 @@
#include <cstdint>
-#include "iree/compiler/Dialect/Shape/IR/ShapeDialect.h"
-#include "iree/compiler/Dialect/Shape/IR/ShapeTypes.h"
#include "iree/compiler/Dialect/Stream/IR/StreamDialect.h"
#include "iree/compiler/Dialect/Stream/IR/StreamTraits.h"
#include "iree/compiler/Dialect/Stream/IR/StreamTypes.h"
diff --git a/iree/compiler/Dialect/Stream/IR/StreamOps.td b/iree/compiler/Dialect/Stream/IR/StreamOps.td
index 2cdb2fe..f5310ea 100644
--- a/iree/compiler/Dialect/Stream/IR/StreamOps.td
+++ b/iree/compiler/Dialect/Stream/IR/StreamOps.td
@@ -7,7 +7,6 @@
#ifndef IREE_DIALECT_STREAM_OPS
#define IREE_DIALECT_STREAM_OPS
-include "iree/compiler/Dialect/Shape/IR/ShapeInterfaces.td"
include "iree/compiler/Dialect/Stream/IR/StreamBase.td"
include "iree/compiler/Dialect/Stream/IR/StreamInterfaces.td"
include "iree/compiler/Dialect/Util/IR/UtilInterfaces.td"
@@ -2427,6 +2426,65 @@
let hasFolder = 1;
}
+
+def Stream_TimepointImportOp : Stream_PureOp<"timepoint.import", [
+ Stream_AffinityOp,
+]> {
+ let summary = [{imports a timepoint from an external dialect type}];
+ let description = [{
+ Defines a conversion from an external dialect type such as `hal.semaphore`
+ that is resolved during lowering into the stream dialect. This can be used
+    to interoperate between levels of the stack that require stream types and
+    those that do not yet handle them prior to lowering.
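+
+    Example (illustrative, matching the assembly format below):
+
+    %t = stream.timepoint.import %semaphore, %value
+        : (!hal.semaphore, index) => !stream.timepoint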
+ }];
+
+ let arguments = (ins
+ Variadic<AnyType>:$operands,
+ OptionalAttr<Stream_AffinityAttr>:$affinity
+ );
+ let results = (outs
+ Stream_Timepoint:$result_timepoint
+ );
+
+ let assemblyFormat = [{
+ (`on` `(` $affinity^ `)`)?
+ $operands `:` `(` type($operands) `)`
+ `=` `` `>`
+ type($result_timepoint)
+ attr-dict-with-keyword
+ }];
+}
+
+def Stream_TimepointExportOp : Stream_PureOp<"timepoint.export", [
+ Stream_AffinityOp,
+]> {
+ let summary = [{exports a timepoint to an external dialect type}];
+ let description = [{
+ Defines a conversion to an external dialect type such as `hal.semaphore`
+ that is resolved during lowering into the stream dialect. This can be used
+    to interoperate between levels of the stack that require stream types and
+    those that do not yet handle them prior to lowering.
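+
+    Example (illustrative, matching the assembly format below):
+
+    %0:2 = stream.timepoint.export %t => (!hal.semaphore, index)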
+ }];
+
+ let arguments = (ins
+ Stream_Timepoint:$await_timepoint,
+ OptionalAttr<Stream_AffinityAttr>:$affinity
+ );
+ let results = (outs
+ Variadic<AnyType>:$results
+ );
+
+ let assemblyFormat = [{
+ (`on` `(` $affinity^ `)`)?
+ $await_timepoint
+ `=` `` `>`
+ `(` type($results) `)`
+ attr-dict-with-keyword
+ }];
+
+ let hasFolder = 1;
+}
+
def Stream_TimepointJoinOp : Stream_PureOp<"timepoint.join", [
Stream_TimelineOp,
]> {
@@ -2607,7 +2665,6 @@
}
def Stream_BindingSubspanOp : Stream_PureOp<"binding.subspan", [
- DeclareOpInterfaceMethods<Shape_ShapeCarryingOpInterface>,
Util_ShapeAwareOp,
]> {
let summary = [{returns an alias to a subspan of interface binding data}];
diff --git a/iree/compiler/Dialect/Stream/IR/StreamTypes.cpp b/iree/compiler/Dialect/Stream/IR/StreamTypes.cpp
index 8b22c16..38135cf 100644
--- a/iree/compiler/Dialect/Stream/IR/StreamTypes.cpp
+++ b/iree/compiler/Dialect/Stream/IR/StreamTypes.cpp
@@ -8,6 +8,7 @@
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/CommandLine.h"
#include "mlir/IR/DialectImplementation.h"
// clang-format off: must be included after all LLVM/MLIR headers.
@@ -23,6 +24,20 @@
namespace IREE {
namespace Stream {
+static llvm::cl::opt<Favor> partitioningFavor(
+ "iree-stream-partitioning-favor",
+ llvm::cl::desc("Default stream partitioning favor configuration."),
+ llvm::cl::init(Favor::MaxConcurrency),
+ llvm::cl::values(
+ clEnumValN(Favor::Debug, "debug",
+ "Force debug partitioning (no concurrency or pipelining)."),
+ clEnumValN(Favor::MinPeakMemory, "min-peak-memory",
+                   "Favor minimizing memory consumption at the cost of "
+                   "reduced concurrency."),
+ clEnumValN(Favor::MaxConcurrency, "max-concurrency",
+ "Favor maximizing concurrency at the cost of additional "
+ "memory consumption.")));
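+// Illustrative (hypothetical) usage from a tool linking this library:
+//   iree-opt -iree-stream-partitioning-favor=min-peak-memory ...
+// An explicit #stream.partitioning_config attribute on an op still takes
+// precedence over this default (see the config lookup further below).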
+
//===----------------------------------------------------------------------===//
// #stream.resource_config<...>
//===----------------------------------------------------------------------===//
@@ -219,7 +234,9 @@
if (attr) return attr;
op = op->getParentOp();
}
- return {}; // No config found; let caller decide what to do.
+ // No config found; use defaults.
+ auto favorAttr = FavorAttr::get(attrId.getContext(), partitioningFavor);
+ return PartitioningConfigAttr::get(favorAttr);
}
//===----------------------------------------------------------------------===//
@@ -292,6 +309,9 @@
#include "iree/compiler/Dialect/Stream/IR/StreamTypeInterfaces.cpp.inc" // IWYU pragma: keep
void StreamDialect::registerAttributes() {
+ // Register command line flags:
+ (void)partitioningFavor;
+
addAttributes<
#define GET_ATTRDEF_LIST
#include "iree/compiler/Dialect/Stream/IR/StreamAttrs.cpp.inc" // IWYU pragma: keep
diff --git a/iree/compiler/Dialect/Stream/IR/StreamTypes.h b/iree/compiler/Dialect/Stream/IR/StreamTypes.h
index 50e272c..82dea8d 100644
--- a/iree/compiler/Dialect/Stream/IR/StreamTypes.h
+++ b/iree/compiler/Dialect/Stream/IR/StreamTypes.h
@@ -7,7 +7,6 @@
#ifndef IREE_COMPILER_DIALECT_STREAM_IR_STREAMTYPES_H_
#define IREE_COMPILER_DIALECT_STREAM_IR_STREAMTYPES_H_
-#include "iree/compiler/Dialect/Shape/IR/ShapeTypes.h"
#include "iree/compiler/Dialect/Stream/IR/StreamDialect.h"
#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
#include "llvm/ADT/DenseMapInfo.h"
diff --git a/iree/compiler/Dialect/Stream/IR/test/timepoint_folding.mlir b/iree/compiler/Dialect/Stream/IR/test/timepoint_folding.mlir
index 4d0e50f..4e2fdbc 100644
--- a/iree/compiler/Dialect/Stream/IR/test/timepoint_folding.mlir
+++ b/iree/compiler/Dialect/Stream/IR/test/timepoint_folding.mlir
@@ -1,5 +1,28 @@
// RUN: iree-opt -split-input-file -canonicalize %s | IreeFileCheck %s
+// CHECK-LABEL: @FoldTimepointExport
+func @FoldTimepointExport(%arg0: !hal.semaphore, %arg1: index) -> (!hal.semaphore, index) {
+ // CHECK-NOT: stream.timepoint.import
+ %0 = stream.timepoint.import %arg0, %arg1 : (!hal.semaphore, index) => !stream.timepoint
+ // CHECK-NOT: stream.timepoint.export
+ %1:2 = stream.timepoint.export %0 => (!hal.semaphore, index)
+ // CHECK: return %arg0, %arg1
+ return %1#0, %1#1 : !hal.semaphore, index
+}
+
+// -----
+
+// CHECK-LABEL: @DontFoldTimepointExportMismatch
+func @DontFoldTimepointExportMismatch(%arg0: !hal.semaphore, %arg1: index) -> (!hal.semaphore, i32) {
+ // CHECK: stream.timepoint.import
+ %0 = stream.timepoint.import %arg0, %arg1 : (!hal.semaphore, index) => !stream.timepoint
+ // CHECK-NEXT: stream.timepoint.export
+ %1:2 = stream.timepoint.export %0 => (!hal.semaphore, i32)
+ return %1#0, %1#1 : !hal.semaphore, i32
+}
+
+// -----
+
// CHECK-LABEL: @FoldTimepointJoinOp
func @FoldTimepointJoinOp(%arg0: !stream.timepoint) -> !stream.timepoint {
// CHECK-NOT: stream.timepoint.join
diff --git a/iree/compiler/Dialect/Stream/IR/test/timepoint_ops.mlir b/iree/compiler/Dialect/Stream/IR/test/timepoint_ops.mlir
index fcc96d8..d49ecac 100644
--- a/iree/compiler/Dialect/Stream/IR/test/timepoint_ops.mlir
+++ b/iree/compiler/Dialect/Stream/IR/test/timepoint_ops.mlir
@@ -9,6 +9,24 @@
// -----
+// CHECK-LABEL: @timepointImport
+func @timepointImport(%arg0: !hal.semaphore, %arg1: index) -> !stream.timepoint {
+ // CHECK: = stream.timepoint.import %arg0, %arg1 : (!hal.semaphore, index) => !stream.timepoint
+ %0 = stream.timepoint.import %arg0, %arg1 : (!hal.semaphore, index) => !stream.timepoint
+ return %0 : !stream.timepoint
+}
+
+// -----
+
+// CHECK-LABEL: @timepointExport
+func @timepointExport(%arg0: !stream.timepoint) -> (!hal.semaphore, index) {
+ // CHECK: = stream.timepoint.export %arg0 => (!hal.semaphore, index)
+ %0:2 = stream.timepoint.export %arg0 => (!hal.semaphore, index)
+ return %0#0, %0#1 : !hal.semaphore, index
+}
+
+// -----
+
// CHECK-LABEL: @timepointJoin
func @timepointJoin(%arg0: !stream.timepoint, %arg1: !stream.timepoint) -> !stream.timepoint {
// CHECK: = stream.timepoint.join max(%arg0, %arg1) => !stream.timepoint
diff --git a/iree/compiler/Dialect/Stream/Transforms/BUILD b/iree/compiler/Dialect/Stream/Transforms/BUILD
index 18afde8..d1adeef 100644
--- a/iree/compiler/Dialect/Stream/Transforms/BUILD
+++ b/iree/compiler/Dialect/Stream/Transforms/BUILD
@@ -44,11 +44,10 @@
deps = [
":PassesIncGen",
"//iree/compiler/Dialect/Flow/IR",
- "//iree/compiler/Dialect/Shape/IR",
- "//iree/compiler/Dialect/Shape/Transforms",
"//iree/compiler/Dialect/Stream/Analysis",
"//iree/compiler/Dialect/Stream/Conversion",
"//iree/compiler/Dialect/Stream/Conversion/FlowToStream",
+ "//iree/compiler/Dialect/Stream/Conversion/HALToStream",
"//iree/compiler/Dialect/Stream/Conversion/StandardToStream",
"//iree/compiler/Dialect/Stream/Conversion/UtilToStream",
"//iree/compiler/Dialect/Stream/IR",
diff --git a/iree/compiler/Dialect/Stream/Transforms/CMakeLists.txt b/iree/compiler/Dialect/Stream/Transforms/CMakeLists.txt
index 9751f78..045281b 100644
--- a/iree/compiler/Dialect/Stream/Transforms/CMakeLists.txt
+++ b/iree/compiler/Dialect/Stream/Transforms/CMakeLists.txt
@@ -53,11 +53,10 @@
MLIRTransformUtils
MLIRTransforms
iree::compiler::Dialect::Flow::IR
- iree::compiler::Dialect::Shape::IR
- iree::compiler::Dialect::Shape::Transforms
iree::compiler::Dialect::Stream::Analysis
iree::compiler::Dialect::Stream::Conversion
iree::compiler::Dialect::Stream::Conversion::FlowToStream
+ iree::compiler::Dialect::Stream::Conversion::HALToStream
iree::compiler::Dialect::Stream::Conversion::StandardToStream
iree::compiler::Dialect::Stream::Conversion::UtilToStream
iree::compiler::Dialect::Stream::IR
diff --git a/iree/compiler/Dialect/Stream/Transforms/ConvertToStream.cpp b/iree/compiler/Dialect/Stream/Transforms/ConvertToStream.cpp
index fa9ac65..c8821d7 100644
--- a/iree/compiler/Dialect/Stream/Transforms/ConvertToStream.cpp
+++ b/iree/compiler/Dialect/Stream/Transforms/ConvertToStream.cpp
@@ -5,10 +5,8 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
-#include "iree/compiler/Dialect/Shape/IR/Builders.h"
-#include "iree/compiler/Dialect/Shape/IR/ShapeDialect.h"
-#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
#include "iree/compiler/Dialect/Stream/Conversion/FlowToStream/ConvertFlowToStream.h"
+#include "iree/compiler/Dialect/Stream/Conversion/HALToStream/ConvertHALToStream.h"
#include "iree/compiler/Dialect/Stream/Conversion/PatternUtils.h"
#include "iree/compiler/Dialect/Stream/Conversion/StandardToStream/ConvertStandardToStream.h"
#include "iree/compiler/Dialect/Stream/Conversion/UtilToStream/ConvertUtilToStream.h"
@@ -49,7 +47,7 @@
OpBuilder &builder) {
// Gather dynamic dimensions from the input value.
auto dynamicDims =
- Shape::buildOrFindDynamicDimsForValue(loc, sourceTensor, builder);
+ IREE::Util::buildDynamicDimsForValue(loc, sourceTensor, builder);
// Compute the size of the tensor once in the stream resource.
// This may differ from the external encoding of the tensor as imports are
@@ -154,7 +152,7 @@
auto tensorType = oldOperand.getType().dyn_cast<TensorType>();
assert(tensorType && "must have a tensor type to map to a resource");
- auto dynamicDims = Shape::buildOrFindDynamicDimsForValue(
+ auto dynamicDims = IREE::Util::buildDynamicDimsForValue(
op->getLoc(), oldOperand, rewriter);
newOperands.push_back(buildTensorExportOp(
op->getLoc(), newOperand, tensorType, dynamicDims, rewriter));
@@ -168,7 +166,7 @@
if (!tensorType) continue;
auto dynamicDims =
- Shape::buildOrFindDynamicDimsForValue(op->getLoc(), result, rewriter);
+ IREE::Util::buildDynamicDimsForValue(op->getLoc(), result, rewriter);
SmallPtrSet<Operation *, 4> consumingOps;
auto importedValue = buildTensorImportOp(
op->getLoc(), result, rewriter.getType<IREE::Stream::ResourceType>(),
@@ -183,9 +181,9 @@
class ConvertToStreamPass : public ConvertToStreamBase<ConvertToStreamPass> {
public:
void getDependentDialects(DialectRegistry &registry) const override {
- registry.insert<ShapeDialect>();
registry.insert<mlir::StandardOpsDialect>();
registry.insert<mlir::arith::ArithmeticDialect>();
+ registry.insert<mlir::tensor::TensorDialect>();
registry.insert<IREE::Stream::StreamDialect>();
registry.insert<IREE::Util::UtilDialect>();
}
@@ -244,6 +242,8 @@
typeConverter, patterns);
populateFlowToStreamConversionPatterns(context, conversionTarget,
typeConverter, patterns);
+ populateHALToStreamConversionPatterns(context, conversionTarget,
+ typeConverter, patterns);
conversionTarget.markUnknownOpDynamicallyLegal(
[&](Operation *op) -> bool { return !doesOperationNeedWrapping(op); });
diff --git a/iree/compiler/Dialect/Stream/Transforms/Passes.cpp b/iree/compiler/Dialect/Stream/Transforms/Passes.cpp
index fa4f63c..7380f2b 100644
--- a/iree/compiler/Dialect/Stream/Transforms/Passes.cpp
+++ b/iree/compiler/Dialect/Stream/Transforms/Passes.cpp
@@ -8,7 +8,6 @@
#include <memory>
-#include "iree/compiler/Dialect/Shape/Transforms/Passes.h"
#include "iree/compiler/Dialect/Util/Transforms/Passes.h"
#include "mlir/Pass/PassOptions.h"
#include "mlir/Pass/PassRegistry.h"
diff --git a/iree/compiler/Dialect/Stream/Transforms/ScheduleConcurrency.cpp b/iree/compiler/Dialect/Stream/Transforms/ScheduleConcurrency.cpp
index 577ed25..8907012 100644
--- a/iree/compiler/Dialect/Stream/Transforms/ScheduleConcurrency.cpp
+++ b/iree/compiler/Dialect/Stream/Transforms/ScheduleConcurrency.cpp
@@ -13,6 +13,7 @@
#include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
+#include "iree/compiler/Utils/GraphUtils.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
@@ -241,26 +242,16 @@
deadOps.insert(oldResult.getDefiningOp());
}
partitionBuilder.finish();
-
- // Extremely shady reordering of ops we know (should) be safe to move
- // after the partition - otherwise, we shouldn't have moved the source
- // ops into the partition.
- auto concurrentOp = partitionBuilder.concurrentOp;
- for (auto user : concurrentOp->getUsers()) {
- if (user->getBlock() == concurrentOp->getBlock() &&
- user->isBeforeInBlock(partitionBuilder.concurrentOp)) {
- LLVM_DEBUG({
- llvm::dbgs() << "Shady move of op to after partition: ";
- user->dump();
- });
- user->moveAfter(concurrentOp);
- }
- }
}
for (auto *deadOp : llvm::reverse(deadOps)) {
deadOp->erase();
}
+ // Sort the ops in the execution region as they may have gotten out of order
+ // during partitioning. This is safe because we are still unaliased and SSA
+ // values imply ordering.
+ sortBlockTopologically(block);
+
LLVM_DEBUG({
llvm::dbgs() << "\nWaves constructed:\n";
block->dump();
diff --git a/iree/compiler/Dialect/Stream/Transforms/ScheduleExecution.cpp b/iree/compiler/Dialect/Stream/Transforms/ScheduleExecution.cpp
index 42747ac..42739bd 100644
--- a/iree/compiler/Dialect/Stream/Transforms/ScheduleExecution.cpp
+++ b/iree/compiler/Dialect/Stream/Transforms/ScheduleExecution.cpp
@@ -13,6 +13,7 @@
#include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
+#include "iree/compiler/Utils/GraphUtils.h"
#include "llvm/ADT/EquivalenceClasses.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
@@ -284,28 +285,9 @@
deadOps.insert(oldResult.getDefiningOp());
}
- // Extremely shady reordering of ops we know (should) be safe to move
- // after the partition - otherwise, we shouldn't have moved the source
- // ops into the partition.
- SetVector<Operation *> worklist;
- for (auto user : executeOp->getUsers()) {
- worklist.insert(user);
- }
- while (!worklist.empty()) {
- auto *user = worklist.pop_back_val();
- if (user->getBlock() == executeOp->getBlock() &&
- user->isBeforeInBlock(executeOp)) {
- LLVM_DEBUG({
- llvm::dbgs() << "Shady move of op to after partition: ";
- user->dump();
- });
- user->moveAfter(builder.getInsertionBlock(),
- builder.getInsertionPoint());
- }
- for (auto subUser : user->getUsers()) {
- worklist.insert(subUser);
- }
- }
+ // Sort the ops in the execution region. This is safe because we are
+ // still unaliased and SSA values imply ordering.
+ sortBlockTopologically(block);
}
for (auto *deadOp : llvm::reverse(deadOps)) {
deadOp->erase();
diff --git a/iree/compiler/Dialect/Stream/Transforms/test/convert_to_stream.mlir b/iree/compiler/Dialect/Stream/Transforms/test/convert_to_stream.mlir
index ba5ef42..0ebf4b9 100644
--- a/iree/compiler/Dialect/Stream/Transforms/test/convert_to_stream.mlir
+++ b/iree/compiler/Dialect/Stream/Transforms/test/convert_to_stream.mlir
@@ -29,7 +29,7 @@
// CHECK: %[[ARG0_SIZE:.+]] = stream.tensor.sizeof tensor<?x4xf32>{%[[DIM0]]} : index
// CHECK: %[[ARG0_IMPORT:.+]] = stream.tensor.import %arg0 : !hal.buffer_view -> tensor<?x4xf32>{%[[DIM0]]} in !stream.resource<external>{%[[ARG0_SIZE]]}
// CHECK: %[[ARG0_T:.+]] = stream.async.transfer %[[ARG0_IMPORT]] : !stream.resource<external>{%[[ARG0_SIZE]]} -> !stream.resource<*>{%[[ARG0_SIZE]]}
- %0 = hal.tensor.cast %arg0 : !hal.buffer_view -> tensor<?x4xf32>{%dim0}
+ %0 = hal.tensor.import %arg0 : !hal.buffer_view -> tensor<?x4xf32>{%dim0}
%c1 = arith.constant 1 : index
%c4 = arith.constant 4 : index
@@ -39,7 +39,7 @@
// CHECK: %[[RET0_T:.+]] = stream.async.transfer %[[RET0]] : !stream.resource<*>{%[[RET0_SIZE]]} -> !stream.resource<external>{%[[RET0_SIZE]]}
// CHECK: %[[RET0_EXPORT:.+]] = stream.tensor.export %[[RET0_T]] : tensor<?xf32>{%[[DIM0]]} in !stream.resource<external>{%[[RET0_SIZE]]} -> !hal.buffer_view
- %2 = hal.tensor.cast %1 : tensor<?xf32>{%dim0} -> !hal.buffer_view
+ %2 = hal.tensor.export %1 : tensor<?xf32>{%dim0} -> !hal.buffer_view
// CHECK: return %[[RET0_EXPORT]] : !hal.buffer_view
return %2 : !hal.buffer_view
}
diff --git a/iree/compiler/Dialect/Stream/Transforms/test/schedule_concurrency.mlir b/iree/compiler/Dialect/Stream/Transforms/test/schedule_concurrency.mlir
index 2cd287a..f31e4f5 100644
--- a/iree/compiler/Dialect/Stream/Transforms/test/schedule_concurrency.mlir
+++ b/iree/compiler/Dialect/Stream/Transforms/test/schedule_concurrency.mlir
@@ -1,8 +1,61 @@
// RUN: iree-opt -split-input-file -pass-pipeline="builtin.func(iree-stream-schedule-concurrency)" %s | IreeFileCheck %s
-// CHECK-LABEL: @partitioning
+// Tests that when favor=min-peak-memory we assume ops are already in an order
+// that reduces live memory ranges, and that we only optimistically place them
+// in concurrency regions when doing so would not increase those ranges.
+
+// CHECK-LABEL: @partitioningForMinPeakMemory
// CHECK-SAME: (%[[ARG0:.+]]: !stream.resource<external>, %[[ARG1:.+]]: !stream.resource<external>)
-func @partitioning(%arg0: !stream.resource<external>, %arg1: !stream.resource<external>) -> !stream.resource<external> {
+func @partitioningForMinPeakMemory(%arg0: !stream.resource<external>, %arg1: !stream.resource<external>) -> !stream.resource<external>
+ attributes {stream.partitioning = #stream.partitioning_config<"min-peak-memory">} {
+ %c1 = arith.constant 1 : index
+ %c20 = arith.constant 20 : index
+ %c80 = arith.constant 80 : index
+ %c1280 = arith.constant 1280 : index
+ %cst = arith.constant 0x7F800000 : f32
+ // CHECK: stream.async.execute
+ %results, %result_timepoint = stream.async.execute
+ // CHECK-SAME: with(%[[ARG1]] as %[[ARG1_CAPTURE:.+]]: !stream.resource<external>{%c80},
+ // CHECK-SAME: %[[ARG0]] as %[[ARG0_CAPTURE:.+]]: !stream.resource<external>{%c20})
+ with(%arg1 as %arg2: !stream.resource<external>{%c80},
+ %arg0 as %arg3: !stream.resource<external>{%c20})
+ -> !stream.resource<external>{%c20} {
+
+ // CHECK: %[[SPLAT0:.+]] = stream.async.splat %cst : f32 -> !stream.resource<transient>{%c1280}
+ %1 = stream.async.splat %cst : f32 -> !stream.resource<transient>{%c1280}
+
+ // CHECK: %[[CON0:.+]]:2 = stream.async.concurrent
+ // CHECK-SAME: with(%[[SPLAT0]] as %[[SPLAT0_CAPTURE:.+]]: !stream.resource<transient>{%c1280},
+ // CHECK-SAME: %[[ARG1_CAPTURE]] as %[[ARG1_CON0_CAPTURE:.+]]: !stream.resource<external>{%c80})
+ // CHECK-SAME: -> (!stream.resource<transient>{%c1280}, !stream.resource<transient>{%c20}) {
+ // CHECK-NEXT: %[[DISPATCH0:.+]] = stream.async.dispatch @ex::@dispatch_0[%c1, %c1, %c1](%[[SPLAT0_CAPTURE]], %[[ARG1_CON0_CAPTURE]])
+ %2 = stream.async.dispatch @ex::@dispatch_0[%c1, %c1, %c1](%1, %arg2) : (!stream.resource<transient>{%c1280}, !stream.resource<external>{%c80}) -> %1{%c1280}
+ // CHECK-NEXT: %[[SPLAT1:.+]] = stream.async.splat %cst : f32 -> !stream.resource<transient>{%c20}
+ %3 = stream.async.splat %cst : f32 -> !stream.resource<transient>{%c20}
+ // CHECK-NEXT: stream.yield %[[DISPATCH0]], %[[SPLAT1]] : !stream.resource<transient>{%c1280}, !stream.resource<transient>{%c20}
+
+ // CHECK: %[[DISPATCH1:.+]] = stream.async.dispatch @ex::@dispatch_1[%c1, %c1, %c1](%[[ARG0_CAPTURE]], %[[CON0]]#1)
+ %4 = stream.async.dispatch @ex::@dispatch_1[%c1, %c1, %c1](%arg3, %3) : (!stream.resource<external>{%c20}, !stream.resource<transient>{%c20}) -> %3{%c20}
+
+ // CHECK: %[[DISPATCH2:.+]] = stream.async.dispatch @ex::@dispatch_2[%c1, %c1, %c1](%[[CON0]]#0, %[[DISPATCH1]])
+ %5 = stream.async.dispatch @ex::@dispatch_2[%c1, %c1, %c1](%2, %4) : (!stream.resource<transient>{%c1280}, !stream.resource<transient>{%c20}) -> !stream.resource<external>{%c20}
+
+ // CHECK-NEXT: stream.yield %[[DISPATCH2]]
+ stream.yield %5 : !stream.resource<external>{%c20}
+ } => !stream.timepoint
+ %0 = stream.timepoint.await %result_timepoint => %results : !stream.resource<external>{%c20}
+ return %0 : !stream.resource<external>
+}
+
+// -----
+
+// Tests that when favor=max-concurrency we reorder ops aggressively to maximize
+// the amount of work scheduled concurrently.
+
+// CHECK-LABEL: @partitioningForMaxConcurrency
+// CHECK-SAME: (%[[ARG0:.+]]: !stream.resource<external>, %[[ARG1:.+]]: !stream.resource<external>)
+func @partitioningForMaxConcurrency(%arg0: !stream.resource<external>, %arg1: !stream.resource<external>) -> !stream.resource<external>
+ attributes {stream.partitioning = #stream.partitioning_config<"max-concurrency">} {
%c1 = arith.constant 1 : index
%c20 = arith.constant 20 : index
%c80 = arith.constant 80 : index
diff --git a/iree/compiler/Dialect/Util/IR/BUILD b/iree/compiler/Dialect/Util/IR/BUILD
index e68c1c9..d41d8e9 100644
--- a/iree/compiler/Dialect/Util/IR/BUILD
+++ b/iree/compiler/Dialect/Util/IR/BUILD
@@ -74,11 +74,13 @@
"@llvm-project//mlir:ArithmeticDialect",
"@llvm-project//mlir:ControlFlowInterfaces",
"@llvm-project//mlir:IR",
+ "@llvm-project//mlir:MemRefDialect",
"@llvm-project//mlir:Parser",
"@llvm-project//mlir:SideEffectInterfaces",
"@llvm-project//mlir:SideEffects",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
+ "@llvm-project//mlir:TensorDialect",
"@llvm-project//mlir:Transforms",
],
)
@@ -151,11 +153,11 @@
)
iree_tablegen_doc(
- name = "UtilUtilDialectDocGen",
+ name = "UtilDialectDocGen",
tbl_outs = [
(
["-gen-dialect-doc"],
- "UtilUtilDialect.md",
+ "UtilDialect.md",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
diff --git a/iree/compiler/Dialect/Util/IR/CMakeLists.txt b/iree/compiler/Dialect/Util/IR/CMakeLists.txt
index cb1cc88..7d6350f 100644
--- a/iree/compiler/Dialect/Util/IR/CMakeLists.txt
+++ b/iree/compiler/Dialect/Util/IR/CMakeLists.txt
@@ -45,10 +45,12 @@
MLIRArithmetic
MLIRControlFlowInterfaces
MLIRIR
+ MLIRMemRef
MLIRParser
MLIRSideEffectInterfaces
MLIRStandard
MLIRSupport
+ MLIRTensor
MLIRTransforms
PUBLIC
)
@@ -89,11 +91,11 @@
iree_tablegen_doc(
NAME
- UtilUtilDialectDocGen
+ UtilDialectDocGen
TD_FILE
"UtilOps.td"
OUTS
- -gen-dialect-doc UtilUtilDialect.md
+ -gen-dialect-doc UtilDialect.md
)
### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/iree/compiler/Dialect/Util/IR/UtilDialect.cpp b/iree/compiler/Dialect/Util/IR/UtilDialect.cpp
index 3f43ccd..1b3e2a4 100644
--- a/iree/compiler/Dialect/Util/IR/UtilDialect.cpp
+++ b/iree/compiler/Dialect/Util/IR/UtilDialect.cpp
@@ -11,9 +11,12 @@
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/SourceMgr.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/Parser.h"
@@ -97,6 +100,49 @@
return nullptr;
}
+template <typename DimOp>
+struct FoldDimOp : public OpRewritePattern<DimOp> {
+ using OpRewritePattern<DimOp>::OpRewritePattern;
+ LogicalResult matchAndRewrite(DimOp op,
+ PatternRewriter &rewriter) const override {
+ auto shapeAwareOp =
+ dyn_cast_or_null<ShapeAwareOpInterface>(op.source().getDefiningOp());
+ if (!shapeAwareOp) return failure();
+
+    // We only support static dimension indices today (as in general we only
+    // support ranked shapes). If dynamic indices start showing up we will
+    // need to do something much more complex, or prevent them from being
+    // introduced in the first place.
+ APInt index;
+ if (!matchPattern(op.index(), m_ConstantInt(&index))) {
+ return rewriter.notifyMatchFailure(op,
+ "non-constant dim index unsupported");
+ }
+
+ // If it's a static dim then just fold to that.
+ auto type = op.source().getType().template cast<ShapedType>();
+ int64_t staticDim = type.getDimSize(index.getZExtValue());
+ if (staticDim != ShapedType::kDynamicSize) {
+ rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, staticDim);
+ return success();
+ }
+
+ // Otherwise try to get the dynamic dimension cheaply without the need to
+ // insert new IR.
+ unsigned dynamicIdx = type.getDynamicDimIndex(index.getZExtValue());
+ auto dynamicDims = shapeAwareOp.getResultDynamicDimsFromValue(op.source());
+ rewriter.replaceOp(op, dynamicDims[dynamicIdx]);
+
+ return success();
+ }
+};
+
+void UtilDialect::getCanonicalizationPatterns(
+ RewritePatternSet &results) const {
+ results.insert<FoldDimOp<memref::DimOp>>(getContext());
+ results.insert<FoldDimOp<tensor::DimOp>>(getContext());
+}
+
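For illustration, a rough sketch of how these patterns get picked up: canonicalization collects patterns from every loaded dialect, so a tensor.dim/memref.dim over a shape-aware producer folds away whenever canonicalization runs. The helper below only approximates that collection step; the function name is made up for the example.

static LogicalResult runDialectCanonicalizations(ModuleOp module) {
  // Gather canonicalization patterns contributed by all loaded dialects,
  // including UtilDialect's FoldDimOp<memref::DimOp>/FoldDimOp<tensor::DimOp>.
  RewritePatternSet patterns(module.getContext());
  for (Dialect *dialect : module.getContext()->getLoadedDialects())
    dialect->getCanonicalizationPatterns(patterns);
  // Apply greedily, as the canonicalizer pass would.
  return applyPatternsAndFoldGreedily(module, std::move(patterns));
}
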
} // namespace Util
} // namespace IREE
} // namespace iree_compiler
diff --git a/iree/compiler/Dialect/Util/IR/UtilDialect.h b/iree/compiler/Dialect/Util/IR/UtilDialect.h
index 468d3ed..8a07a6e 100644
--- a/iree/compiler/Dialect/Util/IR/UtilDialect.h
+++ b/iree/compiler/Dialect/Util/IR/UtilDialect.h
@@ -29,6 +29,8 @@
Operation* materializeConstant(OpBuilder& builder, Attribute value, Type type,
Location loc) override;
+ void getCanonicalizationPatterns(RewritePatternSet& results) const override;
+
private:
void registerAttributes();
void registerTypes();
diff --git a/iree/compiler/Dialect/Util/IR/UtilInterfaces.td b/iree/compiler/Dialect/Util/IR/UtilInterfaces.td
index fec6cd4..cbd57c1 100644
--- a/iree/compiler/Dialect/Util/IR/UtilInterfaces.td
+++ b/iree/compiler/Dialect/Util/IR/UtilInterfaces.td
@@ -309,6 +309,23 @@
let description = [{
An operation that is able to provide dynamic shape dimensions for all
shaped operands and results.
+
+ This is a more fine-grained version of ReifyRankedShapedTypeOpInterface that
+ allows for querying of individual operand and result dimensions without
+ unconditionally inserting IR. The dynamic dimension queries allow us to find
+ the dynamic shape dimension SSA values in the IR in read-only mode (such as
+ during analysis), and having the queries specific to a particular operand
+ and result allows us to walk through ops along use-def edges. When combined
+ with tied operands the split queries allow for operands and results to have
+ differing dimensions (such as after reshaping/casting).
+
+  `getOperandDynamicDims` and `getResultDynamicDims` are the load-bearing
+  methods; the rest are utilities built around them. If an upstream dialect
+  ever provided an equivalent pair of methods we could switch to using that
+  instead. A small usage sketch follows this interface definition.
+
+ `tensor.dim` and `memref.dim` will both resolve shape dimensions through
+ this interface when the dimension index is constant.
}];
let methods = [
@@ -332,6 +349,35 @@
return {};
}]
>,
+ InterfaceMethod<
+ /*desc=*/[{Returns a full shape for the given shaped operand index.}],
+ /*retTy=*/"SmallVector<Value>",
+ /*methodName=*/"buildOperandShape",
+ /*args=*/(ins "unsigned":$idx, "OpBuilder &":$builder),
+ /*defaultImplementation=*/[{
+ return IREE::Util::buildOperandShape($_self, idx, builder);
+ }]
+ >,
+ InterfaceMethod<
+ /*desc=*/[{Returns a full shape for the given shaped result index.}],
+ /*retTy=*/"SmallVector<Value>",
+ /*methodName=*/"buildResultShape",
+ /*args=*/(ins "unsigned":$idx, "OpBuilder &":$builder),
+ /*defaultImplementation=*/[{
+ return IREE::Util::buildResultShape($_self, idx, builder);
+ }]
+ >,
+ StaticInterfaceMethod<
+ /*desc=*/[{Builds a full shape for the given shaped result value.}],
+ /*retTy=*/"SmallVector<Value>",
+ /*methodName=*/"buildResultValueShape",
+ /*args=*/(ins "Value":$result, "OpBuilder &":$builder),
+ /*methodBody=*/[{
+ auto shapeAwareOp = dyn_cast<IREE::Util::ShapeAwareOpInterface>(result.getDefiningOp());
+ return shapeAwareOp.buildResultShape(
+ result.cast<mlir::OpResult>().getResultNumber(), builder);
+ }]
+ >,
];
}
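As noted in the interface description above, a minimal read-only usage sketch; `value` and the surrounding analysis are hypothetical stand-ins:

// Query the dynamic dims of |value| during analysis without inserting IR.
if (auto shapeAwareOp = dyn_cast_or_null<IREE::Util::ShapeAwareOpInterface>(
        value.getDefiningOp())) {
  // One SSA value per dynamic dimension of |value|'s shaped type.
  auto dynamicDims = shapeAwareOp.getResultDynamicDimsFromValue(value);
  (void)dynamicDims;
}
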
diff --git a/iree/compiler/Dialect/Util/IR/UtilTypes.cpp b/iree/compiler/Dialect/Util/IR/UtilTypes.cpp
index f01cca4..e5fc087 100644
--- a/iree/compiler/Dialect/Util/IR/UtilTypes.cpp
+++ b/iree/compiler/Dialect/Util/IR/UtilTypes.cpp
@@ -9,6 +9,7 @@
#include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
#include "llvm/ADT/BitVector.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
@@ -397,27 +398,6 @@
// IREE::Util::ShapeAware*
//===----------------------------------------------------------------------===//
-ValueRange findVariadicDynamicDims(unsigned idx, ValueRange values,
- ValueRange dynamicDims) {
- auto value = values[idx];
- auto shapedType = value.getType().dyn_cast<ShapedType>();
- if (!shapedType) return ValueRange{};
-
- // Bail immediately if the shape is static.
- if (shapedType.hasStaticShape()) return ValueRange{};
-
- // Find where the dynamic dims start in the flattened list.
- unsigned offset = 0;
- for (unsigned i = 0; i < idx; ++i) {
- if (auto type = values[i].getType().dyn_cast<ShapedType>()) {
- offset += type.getNumDynamicDims();
- }
- }
-
- // Return the subrange of dynamic dims for the value being queried.
- return dynamicDims.slice(offset, shapedType.getNumDynamicDims());
-}
-
Optional<ValueRange> findDynamicDims(Value shapedValue) {
// Look up the use-def chain: always safe, as any value we reach dominates
// {|block|, |insertionPoint|} implicitly.
@@ -463,6 +443,107 @@
return None;
}
+ValueRange findVariadicDynamicDims(unsigned idx, ValueRange values,
+ ValueRange dynamicDims) {
+ auto value = values[idx];
+ auto shapedType = value.getType().dyn_cast<ShapedType>();
+ if (!shapedType) return ValueRange{};
+
+ // Bail immediately if the shape is static.
+ if (shapedType.hasStaticShape()) return ValueRange{};
+
+ // Find where the dynamic dims start in the flattened list.
+ unsigned offset = 0;
+ for (unsigned i = 0; i < idx; ++i) {
+ if (auto type = values[i].getType().dyn_cast<ShapedType>()) {
+ offset += type.getNumDynamicDims();
+ }
+ }
+
+ // Return the subrange of dynamic dims for the value being queried.
+ return dynamicDims.slice(offset, shapedType.getNumDynamicDims());
+}
+
+SmallVector<Value> buildDynamicDimsForValue(Location loc, Value value,
+ OpBuilder &builder) {
+ auto valueType = value.getType().dyn_cast<ShapedType>();
+ if (!valueType) {
+    mlir::emitError(loc) << "cannot construct shape for non-shaped value: "
+ << value.getType();
+ return {};
+ }
+
+ // Early-exit if all dimensions are static.
+ if (valueType.hasStaticShape()) {
+ return {};
+ }
+
+ // Try the fast-path of scanning for the dynamic dims that exist in the IR
+ // already. For shape-aware ops this is free as the dynamic dim SSA values are
+ // always available.
+ auto foundDynamicDims = IREE::Util::findDynamicDims(
+ value, builder.getBlock(), builder.getInsertionPoint());
+ if (foundDynamicDims.hasValue()) {
+ return llvm::to_vector<4>(foundDynamicDims.getValue());
+ }
+
+ // Slower path that materializes the entire shape for a result. Some
+ // implementations may only support this (vs the fast find above).
+ if (auto shapeAwareOp = dyn_cast_or_null<IREE::Util::ShapeAwareOpInterface>(
+ value.getDefiningOp())) {
+ return shapeAwareOp.buildResultValueShape(value, builder);
+ }
+
+ // TODO(benvanik): add support for ReifyRankedShapedTypeOpInterface;
+ // unfortunately it is for all results and all dimensions so a lot of unneeded
+ // IR will be inserted.
+
+ // Fallback to inserting dim ops that can be resolved via normal upstream
+ // mechanisms. Depending on where this is called from within the parent
+ // pipeline these ops may not be desirable, but that's what the
+ // ShapeAwareOpInterface is for.
+ SmallVector<Value> dynamicDims;
+ for (unsigned i = 0; i < valueType.getRank(); ++i) {
+ if (valueType.isDynamicDim(i)) {
+ dynamicDims.push_back(builder.createOrFold<tensor::DimOp>(loc, value, i));
+ }
+ }
+ return dynamicDims;
+}
+
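A short usage sketch for the helper above; `loc`, `shapedValue`, and `builder` are assumed to come from whatever pass calls it:

// Locate-or-build the dynamic dimension SSA values for a shaped value.
// Empty for fully static shapes; otherwise one Value per dynamic dimension,
// reusing existing IR whenever the producer is shape-aware.
SmallVector<Value> dims =
    IREE::Util::buildDynamicDimsForValue(loc, shapedValue, builder);
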
+static SmallVector<Value> buildShape(Location loc, ShapedType type,
+ ValueRange dynamicDims,
+ OpBuilder &builder) {
+ SmallVector<Value> dims;
+ dims.reserve(type.getRank());
+ unsigned dynamicIdx = 0;
+ for (unsigned i = 0; i < type.getRank(); ++i) {
+ int64_t dim = type.getDimSize(i);
+ if (dim == ShapedType::kDynamicSize) {
+ dims.push_back(dynamicDims[dynamicIdx++]);
+ } else {
+ dims.push_back(builder.create<arith::ConstantIndexOp>(loc, dim));
+ }
+ }
+ return dims;
+}
+
+SmallVector<Value> buildOperandShape(ShapeAwareOpInterface op,
+ unsigned operandIdx, OpBuilder &builder) {
+ auto operand = op->getOperand(operandIdx);
+ auto type = operand.getType().cast<ShapedType>();
+ auto dynamicDims = op.getOperandDynamicDims(operandIdx);
+ return buildShape(op.getLoc(), type, dynamicDims, builder);
+}
+
+SmallVector<Value> buildResultShape(ShapeAwareOpInterface op,
+ unsigned resultIdx, OpBuilder &builder) {
+ auto result = op->getResult(resultIdx);
+ auto type = result.getType().cast<ShapedType>();
+ auto dynamicDims = op.getResultDynamicDims(resultIdx);
+ return buildShape(op.getLoc(), type, dynamicDims, builder);
+}
+
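And a sketch of what the shape builders return; the tensor<4x?xf32> type and %d value are hypothetical:

// Build the full rank-length shape for result 0 of a shape-aware op. For a
// result typed tensor<4x?xf32> with dynamic dim SSA value %d this yields
// {arith.constant 4 : index, %d}: static dims become constants and dynamic
// dims reuse the values reported by getResultDynamicDims.
SmallVector<Value> shape =
    IREE::Util::buildResultShape(shapeAwareOp, /*resultIdx=*/0, builder);
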
//===----------------------------------------------------------------------===//
// IREE::Util::UtilDialect
//===----------------------------------------------------------------------===//
diff --git a/iree/compiler/Dialect/Util/IR/UtilTypes.h b/iree/compiler/Dialect/Util/IR/UtilTypes.h
index e7643fd..2d47d4f 100644
--- a/iree/compiler/Dialect/Util/IR/UtilTypes.h
+++ b/iree/compiler/Dialect/Util/IR/UtilTypes.h
@@ -26,13 +26,17 @@
namespace IREE {
namespace Util {
+class ShapeAwareOpInterface;
class TiedOpInterface;
+//===----------------------------------------------------------------------===//
+// Common types
+//===----------------------------------------------------------------------===//
+
namespace detail {
struct ListTypeStorage;
struct PtrTypeStorage;
-struct RankedShapeTypeStorage;
} // namespace detail
@@ -149,6 +153,10 @@
using Base::Base;
};
+//===----------------------------------------------------------------------===//
+// Tied operand interface utilities
+//===----------------------------------------------------------------------===//
+
namespace detail {
llvm::Optional<unsigned> getTiedResultOperandIndex(Operation *op,
@@ -169,6 +177,10 @@
ArrayRef<unsigned> excludedResultIndices,
SmallVector<int64_t, 4> &tiedOperandIndices);
+//===----------------------------------------------------------------------===//
+// Shape-aware interface utilities
+//===----------------------------------------------------------------------===//
+
// Walks the SSA use-def chain upwards to find the dynamic dimensions of the
// value. Returns None if the shape cannot be found.
Optional<ValueRange> findDynamicDims(Value shapedValue);
@@ -183,6 +195,24 @@
ValueRange findVariadicDynamicDims(unsigned idx, ValueRange values,
ValueRange dynamicDims);
+// Returns dimension values for each dynamic dimension of the given |value|.
+// |value| must have a ShapedType. The returned vector will be empty if the
+// shape is fully static.
+SmallVector<Value> buildDynamicDimsForValue(Location loc, Value value,
+ OpBuilder &builder);
+
+// Builds a ranked shape with all dimension values for the given operand.
+SmallVector<Value> buildOperandShape(ShapeAwareOpInterface op,
+ unsigned operandIdx, OpBuilder &builder);
+
+// Builds a ranked shape with all dimension values for the given result.
+SmallVector<Value> buildResultShape(ShapeAwareOpInterface op,
+ unsigned resultIdx, OpBuilder &builder);
+
+//===----------------------------------------------------------------------===//
+// Alignment and byte offset/length manipulation
+//===----------------------------------------------------------------------===//
+
// Aligns |value| to |alignment|, rounding up if needed.
static inline uint64_t align(uint64_t value, uint64_t alignment) {
return (value + (alignment - 1)) & ~(alignment - 1);
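Two worked examples for the rounding arithmetic above (values chosen arbitrarily):

// align(10, 16) == (10 + 15) & ~15 == 16   (rounded up to the next boundary)
// align(32, 16) == (32 + 15) & ~15 == 32   (already aligned; unchanged)
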
diff --git a/iree/compiler/Dialect/VM/Conversion/BUILD b/iree/compiler/Dialect/VM/Conversion/BUILD
index a065da6..4b90b33 100644
--- a/iree/compiler/Dialect/VM/Conversion/BUILD
+++ b/iree/compiler/Dialect/VM/Conversion/BUILD
@@ -26,7 +26,6 @@
"TypeConverter.h",
],
deps = [
- "//iree/compiler/Dialect/Shape/IR",
"//iree/compiler/Dialect/Util/IR",
"//iree/compiler/Dialect/VM/IR",
"@llvm-project//llvm:Support",
diff --git a/iree/compiler/Dialect/VM/Conversion/CMakeLists.txt b/iree/compiler/Dialect/VM/Conversion/CMakeLists.txt
index fd8fc39..e0f872a 100644
--- a/iree/compiler/Dialect/VM/Conversion/CMakeLists.txt
+++ b/iree/compiler/Dialect/VM/Conversion/CMakeLists.txt
@@ -30,7 +30,6 @@
MLIRParser
MLIRStandard
MLIRTransforms
- iree::compiler::Dialect::Shape::IR
iree::compiler::Dialect::Util::IR
iree::compiler::Dialect::VM::IR
PUBLIC
diff --git a/iree/compiler/Dialect/VM/Conversion/ImportUtils.h b/iree/compiler/Dialect/VM/Conversion/ImportUtils.h
index 21992c3..4235814 100644
--- a/iree/compiler/Dialect/VM/Conversion/ImportUtils.h
+++ b/iree/compiler/Dialect/VM/Conversion/ImportUtils.h
@@ -7,8 +7,6 @@
#ifndef IREE_COMPILER_DIALECT_VM_CONVERSION_IMPORTUTILS_H_
#define IREE_COMPILER_DIALECT_VM_CONVERSION_IMPORTUTILS_H_
-#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
-#include "iree/compiler/Dialect/Shape/IR/ShapeTypes.h"
#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
#include "iree/compiler/Dialect/VM/IR/VMOps.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
@@ -89,27 +87,12 @@
auto newOperands =
llvm::to_vector<4>(adaptor.getODSOperands(inputSetIndex));
++inputSetIndex;
- if (oldOperands.size() == 1 &&
- oldOperands[0].getType().template isa<Shape::RankedShapeType>()) {
- // Expand a ranked_shape into its dimensions.
- // We need to rematerialize the static dimensions and then pass through
- // the new dynamic dimensions that we have the SSA values for.
- auto rankedShapeType = oldOperands[0]
- .getType()
- .template dyn_cast<Shape::RankedShapeType>();
- for (int i = 0; i < rankedShapeType.getRank(); ++i) {
- auto dimOp = rewriter.createOrFold<Shape::RankedDimOp>(
- op.getLoc(), oldOperands[0], i);
- state.addOperands(dimOp);
- }
- segmentSizes.push_back(rankedShapeType.getRank());
+
+ state.addOperands(newOperands);
+ if (importOp.isFuncArgumentVariadic(input.index())) {
+ segmentSizes.push_back(newOperands.size());
} else {
- state.addOperands(newOperands);
- if (importOp.isFuncArgumentVariadic(input.index())) {
- segmentSizes.push_back(newOperands.size());
- } else {
- segmentSizes.push_back(kFixedSingleValue);
- }
+ segmentSizes.push_back(kFixedSingleValue);
}
}
}
diff --git a/iree/compiler/Dialect/VM/Conversion/TypeConverter.cpp b/iree/compiler/Dialect/VM/Conversion/TypeConverter.cpp
index b119353..5a00c9c 100644
--- a/iree/compiler/Dialect/VM/Conversion/TypeConverter.cpp
+++ b/iree/compiler/Dialect/VM/Conversion/TypeConverter.cpp
@@ -6,8 +6,6 @@
#include "iree/compiler/Dialect/VM/Conversion/TypeConverter.h"
-#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
-#include "iree/compiler/Dialect/Shape/IR/ShapeTypes.h"
#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
#include "iree/compiler/Dialect/VM/IR/VMOps.h"
#include "llvm/Support/Debug.h"
@@ -113,28 +111,6 @@
IREE::VM::BufferType::get(vectorType.getContext()));
});
- // Convert ranked shape types (expanding all dims).
- addConversion([this](Shape::RankedShapeType rankedShape,
- SmallVectorImpl<Type> &results) {
- auto indexType =
- IntegerType::get(rankedShape.getContext(), targetOptions_.indexBits);
- for (int i = 0; i < rankedShape.getRank(); ++i) {
- if (rankedShape.isDimDynamic(i)) {
- results.push_back(indexType);
- }
- }
- return success();
- });
-
- // TODO(b/145876978): materialize conversion for other types
- addArgumentMaterialization([](OpBuilder &builder,
- Shape::RankedShapeType resultType,
- ValueRange inputs, Location loc) -> Value {
- LLVM_DEBUG(llvm::dbgs()
- << "MATERIALIZE CONVERSION: " << resultType << "\n");
- return builder.create<Shape::MakeRankedShapeOp>(loc, resultType, inputs);
- });
-
addSourceMaterialization([](OpBuilder &builder, IndexType type,
ValueRange inputs, Location loc) -> Value {
if (inputs.size() != 1 || !inputs.front().getType().isa<IntegerType>()) {
diff --git a/iree/compiler/Dialect/VM/Transforms/BUILD b/iree/compiler/Dialect/VM/Transforms/BUILD
index fd4d30a..426edfa 100644
--- a/iree/compiler/Dialect/VM/Transforms/BUILD
+++ b/iree/compiler/Dialect/VM/Transforms/BUILD
@@ -25,7 +25,6 @@
"Passes.h",
],
deps = [
- "//iree/compiler/Dialect/Shape/IR",
"//iree/compiler/Dialect/Util/Conversion",
"//iree/compiler/Dialect/Util/IR",
"//iree/compiler/Dialect/VM/Conversion",
diff --git a/iree/compiler/Dialect/VM/Transforms/CMakeLists.txt b/iree/compiler/Dialect/VM/Transforms/CMakeLists.txt
index 5558889..cc69f45 100644
--- a/iree/compiler/Dialect/VM/Transforms/CMakeLists.txt
+++ b/iree/compiler/Dialect/VM/Transforms/CMakeLists.txt
@@ -36,7 +36,6 @@
MLIRSupport
MLIRTransformUtils
MLIRTransforms
- iree::compiler::Dialect::Shape::IR
iree::compiler::Dialect::Util::Conversion
iree::compiler::Dialect::Util::IR
iree::compiler::Dialect::VM::Conversion
diff --git a/iree/compiler/Dialect/VM/Transforms/Conversion.cpp b/iree/compiler/Dialect/VM/Transforms/Conversion.cpp
index 0d0d28a..c190deb 100644
--- a/iree/compiler/Dialect/VM/Transforms/Conversion.cpp
+++ b/iree/compiler/Dialect/VM/Transforms/Conversion.cpp
@@ -7,7 +7,6 @@
#include <memory>
#include <tuple>
-#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
#include "iree/compiler/Dialect/Util/Conversion/ConversionPatterns.h"
#include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
@@ -139,7 +138,6 @@
dialectInterface->populateVMConversionPatterns(
importSymbols, conversionPatterns, typeConverter);
}
- Shape::populateFoldConversionPatterns(context, conversionPatterns);
if (failed(applyPartialConversion(outerModuleOp, conversionTarget,
std::move(conversionPatterns)))) {
diff --git a/iree/compiler/InputConversion/Common/IREEImportPublic.cpp b/iree/compiler/InputConversion/Common/IREEImportPublic.cpp
index b018efb..a990e5e 100644
--- a/iree/compiler/InputConversion/Common/IREEImportPublic.cpp
+++ b/iree/compiler/InputConversion/Common/IREEImportPublic.cpp
@@ -92,7 +92,7 @@
ConversionPatternRewriter &rewriter) const override {
Type resultType = typeConverter->convertType(srcOp.target().getType());
if (!resultType) return failure();
- rewriter.replaceOpWithNewOp<IREE::HAL::TensorCastOp>(
+ rewriter.replaceOpWithNewOp<IREE::HAL::TensorImportOp>(
srcOp, resultType, adaptor.source(), adaptor.target_dims());
return success();
}
@@ -107,7 +107,7 @@
ConversionPatternRewriter &rewriter) const override {
Type resultType = typeConverter->convertType(srcOp.target().getType());
if (!resultType) return failure();
- rewriter.replaceOpWithNewOp<IREE::HAL::TensorCastOp>(
+ rewriter.replaceOpWithNewOp<IREE::HAL::TensorExportOp>(
srcOp, resultType, adaptor.source(), adaptor.source_dims());
return success();
}
diff --git a/iree/compiler/InputConversion/Common/test/iree_import_public.mlir b/iree/compiler/InputConversion/Common/test/iree_import_public.mlir
index 9aaf540..f55f620 100644
--- a/iree/compiler/InputConversion/Common/test/iree_import_public.mlir
+++ b/iree/compiler/InputConversion/Common/test/iree_import_public.mlir
@@ -39,17 +39,17 @@
// -----
// CHECK-LABEL: func @tensor_to_buffer_view
-// CHECK: hal.tensor.cast %arg0 : tensor<?x?x3xf32>{%arg1, %arg2} -> !hal.buffer_view
+// CHECK: hal.tensor.export %arg0 : tensor<?x?x3xf32>{%arg1, %arg2} -> !hal.buffer_view
builtin.func @tensor_to_buffer_view(%arg0 : tensor<?x?x3xf32>, %arg1 : index, %arg2 : index) -> !iree_input.buffer_view {
- %0 = iree_input.cast.tensor_to_buffer_view %arg0 : tensor<?x?x3xf32> {%arg1, %arg2} -> !iree_input.buffer_view
+ %0 = iree_input.cast.tensor_to_buffer_view %arg0 : tensor<?x?x3xf32>{%arg1, %arg2} -> !iree_input.buffer_view
return %0 : !iree_input.buffer_view
}
// -----
// CHECK-LABEL: func @buffer_view_to_tensor
-// CHECK: hal.tensor.cast %arg0 : !hal.buffer_view -> tensor<?x?x3xf32>{%arg1, %arg2}
+// CHECK: hal.tensor.import %arg0 : !hal.buffer_view -> tensor<?x?x3xf32>{%arg1, %arg2}
builtin.func @buffer_view_to_tensor(%arg0 : !iree_input.buffer_view, %arg1 : index, %arg2 : index) -> tensor<?x?x3xf32> {
- %0 = iree_input.cast.buffer_view_to_tensor %arg0 : !iree_input.buffer_view -> tensor<?x?x3xf32> {%arg1, %arg2}
+ %0 = iree_input.cast.buffer_view_to_tensor %arg0 : !iree_input.buffer_view -> tensor<?x?x3xf32>{%arg1, %arg2}
return %0 : tensor<?x?x3xf32>
}
diff --git a/iree/compiler/InputConversion/MHLO/BUILD b/iree/compiler/InputConversion/MHLO/BUILD
index 4f4b6d8..1f3adc5 100644
--- a/iree/compiler/InputConversion/MHLO/BUILD
+++ b/iree/compiler/InputConversion/MHLO/BUILD
@@ -50,6 +50,7 @@
"ConvertMHLOToFlow.cpp",
"ConvertMHLOToFlow.h",
"ConvertMHLOToLinalgExt.cpp",
+ "FlattenTuplesInCFG.cpp",
"LegalizeInputTypes.cpp",
"MHLOToLinalgOnTensors.cpp",
"MHLOToMHLOPreprocessing.cpp",
@@ -64,7 +65,6 @@
":PassesIncGen",
"//iree/compiler/Dialect/Flow/IR",
"//iree/compiler/Dialect/Flow/Transforms",
- "//iree/compiler/Dialect/Shape/IR",
"//iree/compiler/Dialect/Util/IR",
"//iree/compiler/InputConversion/Common",
"//llvm-external-projects/iree-dialects:IREELinalgExtDialect",
@@ -81,6 +81,7 @@
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:ReconcileUnrealizedCasts",
"@llvm-project//mlir:SCFToStandard",
+ "@llvm-project//mlir:SCFTransforms",
"@llvm-project//mlir:Shape",
"@llvm-project//mlir:ShapeToStandard",
"@llvm-project//mlir:ShapeTransforms",
@@ -90,11 +91,15 @@
"@llvm-project//mlir:Transforms",
"@mlir-hlo//:chlo_legalize_to_hlo",
"@mlir-hlo//:hlo",
+ "@mlir-hlo//:legalize_control_flow",
"@mlir-hlo//:legalize_einsum_to_dot_general",
"@mlir-hlo//:legalize_gather_to_torch_index_select",
"@mlir-hlo//:legalize_to_linalg",
+ "@mlir-hlo//:legalize_to_standard",
+ "@mlir-hlo//:map_lmhlo_to_scalar_op",
"@mlir-hlo//:map_mhlo_to_scalar_op",
"@mlir-hlo//:materialize_broadcasts",
+ "@mlir-hlo//:mhlo_control_flow_to_scf",
"@mlir-hlo//:mhlo_to_mhlo_lowering_patterns",
"@mlir-hlo//:unfuse_batch_norm",
],
diff --git a/iree/compiler/InputConversion/MHLO/BroadcastingToLinalgPatterns.cpp b/iree/compiler/InputConversion/MHLO/BroadcastingToLinalgPatterns.cpp
index 5e98d4d..794be24 100644
--- a/iree/compiler/InputConversion/MHLO/BroadcastingToLinalgPatterns.cpp
+++ b/iree/compiler/InputConversion/MHLO/BroadcastingToLinalgPatterns.cpp
@@ -21,6 +21,7 @@
namespace mlir {
namespace iree_compiler {
+namespace MHLO {
namespace {
@@ -740,10 +741,11 @@
} // namespace
+} // namespace MHLO
} // namespace iree_compiler
} // namespace mlir
-void mlir::iree_compiler::populateMHLOBroadcastingToLinalgPatterns(
+void mlir::iree_compiler::MHLO::populateMHLOBroadcastingToLinalgPatterns(
MLIRContext *context, TypeConverter &typeConverter,
OwningRewritePatternList &patterns) {
#define POPULATE_SIMPLE_BCAST(ChloOp, HloOp) \
diff --git a/iree/compiler/InputConversion/MHLO/CMakeLists.txt b/iree/compiler/InputConversion/MHLO/CMakeLists.txt
index 3d6fc1e..bf40b06 100644
--- a/iree/compiler/InputConversion/MHLO/CMakeLists.txt
+++ b/iree/compiler/InputConversion/MHLO/CMakeLists.txt
@@ -45,6 +45,7 @@
"ConvertMHLOToFlow.cpp"
"ConvertMHLOToFlow.h"
"ConvertMHLOToLinalgExt.cpp"
+ "FlattenTuplesInCFG.cpp"
"LegalizeInputTypes.cpp"
"MHLOToLinalgOnTensors.cpp"
"MHLOToMHLOPreprocessing.cpp"
@@ -53,9 +54,12 @@
DEPS
::PassHeaders
::PassesIncGen
+ ChloDialect
+ ChloPasses
IREELinalgExtDialect
IREELinalgExtPasses
LLVMSupport
+ LmhloDialect
MLIRAffine
MLIRComplex
MLIRIR
@@ -63,9 +67,11 @@
MLIRLinalgTransforms
MLIRMath
MLIRMemRef
+ MLIRMhloUtils
MLIRPass
MLIRReconcileUnrealizedCasts
MLIRSCFToStandard
+ MLIRSCFTransforms
MLIRShape
MLIRShapeOpsTransforms
MLIRShapeToStandard
@@ -73,12 +79,15 @@
MLIRSupport
MLIRTensor
MLIRTransforms
+ MhloDialect
+ MhloLhloToLinalg
+ MhloPasses
+ MhloToStandard
iree::compiler::Dialect::Flow::IR
iree::compiler::Dialect::Flow::Transforms
- iree::compiler::Dialect::Shape::IR
iree::compiler::Dialect::Util::IR
iree::compiler::InputConversion::Common
- tensorflow::mlir_hlo
+ tensorflow::external_mhlo_includes
PUBLIC
)
diff --git a/iree/compiler/InputConversion/MHLO/ConvertComplexToReal.cpp b/iree/compiler/InputConversion/MHLO/ConvertComplexToReal.cpp
index c33f230..c6c4d1c 100644
--- a/iree/compiler/InputConversion/MHLO/ConvertComplexToReal.cpp
+++ b/iree/compiler/InputConversion/MHLO/ConvertComplexToReal.cpp
@@ -14,6 +14,7 @@
namespace mlir {
namespace iree_compiler {
+namespace MHLO {
namespace {
@@ -428,5 +429,6 @@
return std::make_unique<TestMHLOConvertComplexToRealPass>();
}
+} // namespace MHLO
} // namespace iree_compiler
} // namespace mlir
diff --git a/iree/compiler/InputConversion/MHLO/ConvertMHLOToFlow.cpp b/iree/compiler/InputConversion/MHLO/ConvertMHLOToFlow.cpp
index b562208..8aec9b1 100644
--- a/iree/compiler/InputConversion/MHLO/ConvertMHLOToFlow.cpp
+++ b/iree/compiler/InputConversion/MHLO/ConvertMHLOToFlow.cpp
@@ -19,6 +19,7 @@
namespace mlir {
namespace iree_compiler {
+namespace MHLO {
namespace {
@@ -43,5 +44,6 @@
patterns.insert<ConstOpLowering>(context);
}
+} // namespace MHLO
} // namespace iree_compiler
} // namespace mlir
diff --git a/iree/compiler/InputConversion/MHLO/ConvertMHLOToFlow.h b/iree/compiler/InputConversion/MHLO/ConvertMHLOToFlow.h
index 482e0e8..17f914b 100644
--- a/iree/compiler/InputConversion/MHLO/ConvertMHLOToFlow.h
+++ b/iree/compiler/InputConversion/MHLO/ConvertMHLOToFlow.h
@@ -12,6 +12,7 @@
namespace mlir {
namespace iree_compiler {
+namespace MHLO {
// Setup the |conversionTarget| op legality for early-phase direct-to-flow
// conversion from the MHLO dialect. This will make certain ops illegal that we
@@ -24,6 +25,7 @@
void populateMHLOToFlowPatterns(MLIRContext *context,
OwningRewritePatternList &patterns);
+} // namespace MHLO
} // namespace iree_compiler
} // namespace mlir
diff --git a/iree/compiler/InputConversion/MHLO/ConvertMHLOToLinalgExt.cpp b/iree/compiler/InputConversion/MHLO/ConvertMHLOToLinalgExt.cpp
index 4e33d60..df0ff46 100644
--- a/iree/compiler/InputConversion/MHLO/ConvertMHLOToLinalgExt.cpp
+++ b/iree/compiler/InputConversion/MHLO/ConvertMHLOToLinalgExt.cpp
@@ -29,6 +29,7 @@
namespace mlir {
namespace iree_compiler {
+namespace MHLO {
namespace {
@@ -571,5 +572,6 @@
return std::make_unique<ConvertMHLOToLinalgExtPass>();
}
+} // namespace MHLO
} // namespace iree_compiler
} // namespace mlir
diff --git a/integrations/tensorflow/iree_tf_compiler/MHLO/FlattenTuplesInCFG.cpp b/iree/compiler/InputConversion/MHLO/FlattenTuplesInCFG.cpp
similarity index 95%
rename from integrations/tensorflow/iree_tf_compiler/MHLO/FlattenTuplesInCFG.cpp
rename to iree/compiler/InputConversion/MHLO/FlattenTuplesInCFG.cpp
index 0fcf381..3a3e978 100644
--- a/integrations/tensorflow/iree_tf_compiler/MHLO/FlattenTuplesInCFG.cpp
+++ b/iree/compiler/InputConversion/MHLO/FlattenTuplesInCFG.cpp
@@ -4,7 +4,8 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-#include "iree_tf_compiler/MHLO/Passes.h"
+#include "iree/compiler/InputConversion/MHLO/PassDetail.h"
+#include "iree/compiler/InputConversion/MHLO/Passes.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/iterator_range.h"
@@ -18,7 +19,7 @@
#include "mlir/Transforms/Utils.h"
namespace mlir {
-namespace iree_integrations {
+namespace iree_compiler {
namespace MHLO {
namespace {
@@ -274,17 +275,8 @@
}
class FlattenTuplesInCFGPass
- : public PassWrapper<FlattenTuplesInCFGPass, OperationPass<ModuleOp>> {
+ : public FlattenTuplesInCFGBase<FlattenTuplesInCFGPass> {
public:
- StringRef getArgument() const override {
- return "iree-mhlo-flatten-tuples-in-cfg";
- }
-
- StringRef getDescription() const override {
- return "Convert functions to remove tuples from method signatures and "
- "blocks";
- }
-
void runOnOperation() override {
auto module = getOperation();
Builder builder(module.getContext());
@@ -332,5 +324,5 @@
static PassRegistration<FlattenTuplesInCFGPass> pass;
} // namespace MHLO
-} // namespace iree_integrations
+} // namespace iree_compiler
} // namespace mlir
diff --git a/iree/compiler/InputConversion/MHLO/LegalizeInputTypes.cpp b/iree/compiler/InputConversion/MHLO/LegalizeInputTypes.cpp
index 5d99e15..b6bec24 100644
--- a/iree/compiler/InputConversion/MHLO/LegalizeInputTypes.cpp
+++ b/iree/compiler/InputConversion/MHLO/LegalizeInputTypes.cpp
@@ -26,6 +26,7 @@
namespace mlir {
namespace iree_compiler {
+namespace MHLO {
static Attribute convertAttribute(Location loc, Attribute value,
FlowTypeConverter &typeConverter) {
@@ -253,5 +254,6 @@
return std::make_unique<LegalizeInputTypesPass>();
}
+} // namespace MHLO
} // namespace iree_compiler
} // namespace mlir
diff --git a/iree/compiler/InputConversion/MHLO/MHLOToLinalgOnTensors.cpp b/iree/compiler/InputConversion/MHLO/MHLOToLinalgOnTensors.cpp
index 6a61233..cabeb37 100644
--- a/iree/compiler/InputConversion/MHLO/MHLOToLinalgOnTensors.cpp
+++ b/iree/compiler/InputConversion/MHLO/MHLOToLinalgOnTensors.cpp
@@ -14,8 +14,6 @@
#include <memory>
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
-#include "iree/compiler/Dialect/Shape/IR/ShapeDialect.h"
-#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
#include "iree/compiler/InputConversion/MHLO/ConvertMHLOToFlow.h"
#include "iree/compiler/InputConversion/MHLO/PassDetail.h"
#include "iree/compiler/InputConversion/MHLO/Passes.h"
@@ -45,7 +43,7 @@
namespace mlir {
namespace iree_compiler {
-namespace {
+namespace MHLO {
//===----------------------------------------------------------------------===//
// mhlo.concatenate conversion patterns.
@@ -102,14 +100,11 @@
return success();
}
};
-} // namespace
//===----------------------------------------------------------------------===//
// mhlo.fft conversion patterns.
//===----------------------------------------------------------------------===//
-namespace {
-
/// Creates coefficients based on DFT definition, see
/// https://en.wikipedia.org/wiki/Discrete_Fourier_transform
Value getDFTMatmulCoeff(OpBuilder b, Location loc, RankedTensorType matrixType,
@@ -207,16 +202,94 @@
return success();
}
};
-} // namespace
+
+// We need to convert func ops in order to convert types.
+class BuiltinFuncOpPattern : public OpConversionPattern<FuncOp> {
+ using OpConversionPattern<FuncOp>::OpConversionPattern;
+ LogicalResult matchAndRewrite(
+ FuncOp srcOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ FunctionType srcFuncType = srcOp.getType();
+ TypeConverter::SignatureConversion signatureConversion(
+ srcOp.getNumArguments());
+
+ // Convert function arguments.
+ for (unsigned i = 0, e = srcFuncType.getNumInputs(); i < e; ++i) {
+ if (failed(getTypeConverter()->convertSignatureArg(
+ i, srcFuncType.getInput(i), signatureConversion))) {
+ return rewriter.notifyMatchFailure(srcOp, "argument failed to convert");
+ }
+ }
+
+ // Convert function results.
+ SmallVector<Type> convertedResultTypes;
+ if (failed(getTypeConverter()->convertTypes(srcFuncType.getResults(),
+ convertedResultTypes))) {
+ return rewriter.notifyMatchFailure(srcOp, "results failed to convert");
+ }
+
+ // Create new function with converted argument and result types.
+ auto newFuncType = mlir::FunctionType::get(
+ srcOp.getContext(), signatureConversion.getConvertedTypes(),
+ convertedResultTypes);
+
+ // Update the function in place.
+ rewriter.startRootUpdate(srcOp);
+ srcOp.setType(newFuncType);
+
+ // Tell the rewriter to convert the region signature.
+ TypeConverter &typeConverter = *getTypeConverter();
+ if (failed(rewriter.convertRegionTypes(&srcOp.getBody(), typeConverter,
+ &signatureConversion))) {
+ return failure();
+ }
+
+ rewriter.finalizeRootUpdate(srcOp);
+ return success();
+ }
+};
+
+class GenericTypeConvert : public ConversionPattern {
+ public:
+ GenericTypeConvert(StringRef rootName, TypeConverter &converter,
+ MLIRContext *context, PatternBenefit benefit = 0)
+ : ConversionPattern(converter, rootName, benefit, context) {}
+ LogicalResult matchAndRewrite(
+ Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ llvm::SmallVector<NamedAttribute, 4> newAttr;
+ llvm::append_range(newAttr, op->getAttrs());
+ llvm::SmallVector<Type, 4> newResults;
+ if (failed(getTypeConverter()->convertTypes(op->getResultTypes(),
+ newResults))) {
+ return rewriter.notifyMatchFailure(op, "result type conversion failed");
+ }
+ OperationState state(op->getLoc(), op->getName().getStringRef(), operands,
+ newResults, newAttr, op->getSuccessors());
+ for (Region &r : op->getRegions()) {
+ Region *newRegion = state.addRegion();
+ rewriter.inlineRegionBefore(r, *newRegion, newRegion->begin());
+ TypeConverter::SignatureConversion result(newRegion->getNumArguments());
+ if (failed(getTypeConverter()->convertSignatureArgs(
+ newRegion->getArgumentTypes(), result))) {
+ return rewriter.notifyMatchFailure(op,
+ "argument type conversion failed");
+ }
+ rewriter.applySignatureConversion(newRegion, result);
+ }
+ Operation *newOp = rewriter.createOperation(state);
+ rewriter.replaceOp(op, newOp->getResults());
+ return success();
+ }
+};
struct ConvertMHLOToLinalgOnTensorsPass
: public ConvertMHLOToLinalgOnTensorsBase<
ConvertMHLOToLinalgOnTensorsPass> {
void getDependentDialects(DialectRegistry ®istry) const override {
registry.insert<IREE::Flow::FlowDialect, linalg::LinalgDialect,
- mhlo::MhloDialect, shape::ShapeDialect, ShapeDialect,
- math::MathDialect, memref::MemRefDialect,
- complex::ComplexDialect>();
+ mhlo::MhloDialect, shape::ShapeDialect, math::MathDialect,
+ memref::MemRefDialect, complex::ComplexDialect>();
}
void runOnOperation() override {
@@ -235,12 +308,51 @@
patterns);
populateMHLOComplexToRealPatterns(context, *typeConverter, patterns);
+ // Structural patterns (functions, cfg, terminators).
+ patterns.insert<BuiltinFuncOpPattern>(*typeConverter, context);
+ patterns.insert<GenericTypeConvert>(ReturnOp::getOperationName(),
+ *typeConverter, context);
+ patterns.insert<GenericTypeConvert>(CallOp::getOperationName(),
+ *typeConverter, context);
+ patterns.insert<GenericTypeConvert>(CondBranchOp::getOperationName(),
+ *typeConverter, context);
+ patterns.insert<GenericTypeConvert>(BranchOp::getOperationName(),
+ *typeConverter, context);
+
ConversionTarget target(getContext());
+ auto isIllegalType = [&](Type t) { return !typeConverter->isLegal(t); };
+ auto isLegallyTypedOp = [&](Operation *op) -> bool {
+ for (Type type : op->getResultTypes()) {
+ if (isIllegalType(type)) return false;
+ }
+ for (Type type : op->getOperandTypes()) {
+ if (isIllegalType(type)) return false;
+ }
+ return true;
+ };
+
target.addIllegalDialect<chlo::HloClientDialect>();
target.addIllegalDialect<mhlo::MhloDialect>();
+ // Functions must have legal types.
+ target.addDynamicallyLegalOp<FuncOp>([&](FuncOp funcOp) {
+ for (Type type : funcOp.getType().getInputs()) {
+ if (isIllegalType(type)) return false;
+ }
+ for (Type type : funcOp.getType().getResults()) {
+ if (isIllegalType(type)) return false;
+ }
+ for (Block &block : funcOp.body()) {
+ for (Type type : block.getArgumentTypes()) {
+ if (isIllegalType(type)) return false;
+ }
+ }
+ return true;
+ });
+
// Let the rest fall through.
- target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
+ target.addLegalDialect<BuiltinDialect>();
+ target.markUnknownOpDynamicallyLegal(isLegallyTypedOp);
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns)))) {
@@ -249,26 +361,6 @@
}
};
-/// Convert mhlo.constant op into std.const.
-struct ConstOpConversion : public OpConversionPattern<mhlo::ConstOp> {
- using OpConversionPattern::OpConversionPattern;
- LogicalResult matchAndRewrite(
- mhlo::ConstOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- auto valueAttr = op.value();
- Type oldElType = valueAttr.getType().getElementType();
- Type newElType = this->typeConverter->convertType(oldElType);
- ElementsAttr newValueAttr = valueAttr;
- if (newElType != oldElType) {
- // Values don't change, just their reported type.
- newValueAttr = valueAttr.cast<DenseIntOrFPElementsAttr>().mapValues(
- newElType, [](const APInt &oldEl) { return oldEl; });
- }
- rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newValueAttr);
- return success();
- }
-};
-
} // namespace
void populateMHLOToLinalgOnTensorsConversionPatterns(
@@ -277,7 +369,7 @@
mhlo::populateHLOToLinalgConversionPattern(context, typeConverter, &patterns);
// TODO(#5809): Drop ConcatenateOp lowering in favor of the upstream version
// then remove the PatternBenefit here
- patterns.insert<ConstOpConversion, ConcatenateOpConversion, FftOpConversion>(
+ patterns.insert<ConcatenateOpConversion, FftOpConversion>(
typeConverter, context, PatternBenefit(1000));
}
@@ -285,5 +377,6 @@
return std::make_unique<ConvertMHLOToLinalgOnTensorsPass>();
}
+} // namespace MHLO
} // namespace iree_compiler
} // namespace mlir
diff --git a/iree/compiler/InputConversion/MHLO/MHLOToMHLOPreprocessing.cpp b/iree/compiler/InputConversion/MHLO/MHLOToMHLOPreprocessing.cpp
index bb1d5b9..b3e2eaf 100644
--- a/iree/compiler/InputConversion/MHLO/MHLOToMHLOPreprocessing.cpp
+++ b/iree/compiler/InputConversion/MHLO/MHLOToMHLOPreprocessing.cpp
@@ -28,6 +28,7 @@
namespace mlir {
namespace iree_compiler {
+namespace MHLO {
namespace {
@@ -876,7 +877,10 @@
patterns.insert<ReorderConvOpKernelDimensions>(context);
patterns.insert<ReorderConvOpOutputDimensions>(context);
}
- (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+ if (failed(applyPatternsAndFoldGreedily(getOperation(),
+ std::move(patterns)))) {
+ return signalPassFailure();
+ }
}
};
@@ -886,5 +890,6 @@
return std::make_unique<MHLOToMHLOPreprocessingPass>();
}
+} // namespace MHLO
} // namespace iree_compiler
} // namespace mlir
diff --git a/iree/compiler/InputConversion/MHLO/PassDetail.h b/iree/compiler/InputConversion/MHLO/PassDetail.h
index 4d2be6c..f7098ac 100644
--- a/iree/compiler/InputConversion/MHLO/PassDetail.h
+++ b/iree/compiler/InputConversion/MHLO/PassDetail.h
@@ -11,10 +11,12 @@
namespace mlir {
namespace iree_compiler {
+namespace MHLO {
#define GEN_PASS_CLASSES
#include "iree/compiler/InputConversion/MHLO/Passes.h.inc"
+} // namespace MHLO
} // namespace iree_compiler
} // namespace mlir
diff --git a/iree/compiler/InputConversion/MHLO/Passes.cpp b/iree/compiler/InputConversion/MHLO/Passes.cpp
index d0b27a7..6f452ad 100644
--- a/iree/compiler/InputConversion/MHLO/Passes.cpp
+++ b/iree/compiler/InputConversion/MHLO/Passes.cpp
@@ -7,8 +7,11 @@
#include "iree/compiler/InputConversion/MHLO/Passes.h"
#include "iree/compiler/InputConversion/Common/Passes.h"
+#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"
+#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"
+#include "mlir/Dialect/SCF/Passes.h"
#include "mlir/Dialect/Shape/Transforms/Passes.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Pass/PassOptions.h"
@@ -17,6 +20,7 @@
namespace mlir {
namespace iree_compiler {
+namespace MHLO {
void registerMHLOConversionPassPipeline() {
PassPipelineRegistration<> mhlo(
@@ -25,6 +29,11 @@
[](OpPassManager &passManager) {
buildMHLOInputConversionPassPipeline(passManager);
});
+ PassPipelineRegistration<> xla("iree-mhlo-xla-cleanup-pipeline",
+ "Runs the post-XLA import cleanup pipeline",
+ [](OpPassManager &passManager) {
+ buildXLACleanupPassPipeline(passManager);
+ });
}
// Prepare HLO for use as an input to the Flow dialect.
@@ -53,7 +62,7 @@
// TODO(nicolasvasilache): createLegalizeInputTypesPass is old and does not
// handle region conversion properly (parent cloned before children). Revisit
// when using ops with regions such as scf.for and linalg.generic.
- passManager.addPass(mlir::iree_compiler::createLegalizeInputTypesPass());
+ passManager.addPass(createLegalizeInputTypesPass());
// Perform initial cleanup. createLegalizeInputTypes could rewrite types. In
// this context, some operations could be folded away.
@@ -61,10 +70,8 @@
passManager.addNestedPass<FuncOp>(mlir::createCSEPass());
// Convert to Linalg. After this point, MHLO will be eliminated.
- passManager.addNestedPass<FuncOp>(
- mlir::iree_compiler::createConvertMHLOToLinalgExtPass());
- passManager.addNestedPass<FuncOp>(
- mlir::iree_compiler::createMHLOToLinalgOnTensorsPass());
+ passManager.addNestedPass<FuncOp>(createConvertMHLOToLinalgExtPass());
+ passManager.addNestedPass<FuncOp>(createMHLOToLinalgOnTensorsPass());
// Ensure conversion completed.
passManager.addPass(createReconcileUnrealizedCastsPass());
@@ -78,6 +85,12 @@
passManager.addPass(createVerifyCompilerMHLOInputLegality());
}
+void buildXLACleanupPassPipeline(OpPassManager &passManager) {
+ passManager.addNestedPass<FuncOp>(mhlo::createLegalizeControlFlowPass());
+ passManager.addPass(createFlattenTuplesInCFGPass());
+ passManager.addNestedPass<FuncOp>(mlir::createCanonicalizerPass());
+}
+
namespace {
#define GEN_PASS_REGISTRATION
#include "iree/compiler/InputConversion/MHLO/Passes.h.inc" // IWYU pragma: export
@@ -91,5 +104,6 @@
registerMHLOConversionPassPipeline();
}
+} // namespace MHLO
} // namespace iree_compiler
} // namespace mlir
diff --git a/iree/compiler/InputConversion/MHLO/Passes.h b/iree/compiler/InputConversion/MHLO/Passes.h
index 8bc4222..a027fc3 100644
--- a/iree/compiler/InputConversion/MHLO/Passes.h
+++ b/iree/compiler/InputConversion/MHLO/Passes.h
@@ -11,6 +11,7 @@
namespace mlir {
namespace iree_compiler {
+namespace MHLO {
//===----------------------------------------------------------------------===//
// Pipelines
@@ -19,9 +20,28 @@
// Performs input legalization for specific combination of input dialects.
void buildMHLOInputConversionPassPipeline(OpPassManager &passManager);
+// Performs some cleanup activities on programs that may have originated from
+// an XLA import (or were made to interoperate with it). This involves:
+// - Converting XLA control flow to SCF
+// - Converting SCF control flow to CFG
+// - Flattening tuples in CFG
+// - Canonicalizing
+// It is unfortunate to lose SCF so early in the process, but CFG provides a
+// large simplification for tuple-heavy programs, and this compromise is made
+// in the name of compatibility.
+void buildXLACleanupPassPipeline(OpPassManager &passManager);
+
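A hedged sketch of composing this with the regular MHLO conversion, mirroring the 'xla' input type handling in IREEVM.cpp below; the PassManager setup and `moduleOp` are illustrative:

// Run the XLA cleanup before the standard MHLO input conversion pipeline.
PassManager pm(context);
MHLO::buildXLACleanupPassPipeline(pm);
MHLO::buildMHLOInputConversionPassPipeline(pm);
if (failed(pm.run(moduleOp))) {
  // Handle the failure in the caller (e.g. report and bail).
}
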
void registerMHLOConversionPassPipelines();
//------------------------------------------------------------------------------
+// Cleanup passes
+//------------------------------------------------------------------------------
+
+// Flattens tuples in functions and CFG control flow. This is a common
+// form of MHLO as produced by XLA-based systems.
+std::unique_ptr<OperationPass<ModuleOp>> createFlattenTuplesInCFGPass();
+
+//------------------------------------------------------------------------------
// Conversions into Linalg
//------------------------------------------------------------------------------
@@ -59,6 +79,7 @@
void registerMHLOConversionPasses();
+} // namespace MHLO
} // namespace iree_compiler
} // namespace mlir
diff --git a/iree/compiler/InputConversion/MHLO/Passes.td b/iree/compiler/InputConversion/MHLO/Passes.td
index b0082b2..aa633a8 100644
--- a/iree/compiler/InputConversion/MHLO/Passes.td
+++ b/iree/compiler/InputConversion/MHLO/Passes.td
@@ -12,7 +12,7 @@
def ConvertMHLOToLinalgOnTensors :
Pass<"iree-mhlo-to-linalg-on-tensors", "FuncOp"> {
let summary = "Convert from XLA-HLO ops to Linalg ops on tensors";
- let constructor = "mlir::iree_compiler::createMHLOToLinalgOnTensorsPass()";
+ let constructor = "mlir::iree_compiler::MHLO::createMHLOToLinalgOnTensorsPass()";
}
def ConvertMHLOToLinalgExt
@@ -20,19 +20,25 @@
let summary =
"Convert from XLA-HLO ops to LinalgExt ops and distribute to Flow ops";
let constructor =
- "mlir::iree_compiler::createConvertMHLOToLinalgExtPass()";
+ "mlir::iree_compiler::MHLO::createConvertMHLOToLinalgExtPass()";
}
def LegalizeInputTypes :
Pass<"iree-mhlo-legalize-input-types", "ModuleOp"> {
let summary = "Legalizes input types to ones supported by the IREE flow dialect";
- let constructor = "mlir::iree_compiler::createLegalizeInputTypesPass()";
+ let constructor = "mlir::iree_compiler::MHLO::createLegalizeInputTypesPass()";
+}
+
+def FlattenTuplesInCFG :
+ Pass<"iree-mhlo-flatten-tuples-in-cfg", "ModuleOp"> {
+ let summary = "Flattens tuples in a CFG form of MHLO";
+ let constructor = "mlir::iree_compiler::MHLO::createFlattenTuplesInCFGPass()";
}
def MHLOToMHLOPreprocessing :
Pass<"iree-mhlo-to-mhlo-preprocessing", "FuncOp"> {
let summary = "Apply mhlo to mhlo transformations for some mhlo ops";
- let constructor = "mlir::iree_compiler::createMHLOToMHLOPreprocessingPass()";
+ let constructor = "mlir::iree_compiler::MHLO::createMHLOToMHLOPreprocessingPass()";
let options = [
Option<"extractPadFromConv", "extract-pad-from-conv", "bool", /*default=*/"true",
"Extract padding attributes from conv op">,
@@ -44,7 +50,7 @@
def VerifyCompilerMHLOInputLegality :
Pass<"iree-mhlo-verify-compiler-input-legality", "ModuleOp"> {
let summary = "Verifies that only supported IR constructs are passed to the compiler.";
- let constructor = "mlir::iree_compiler::createVerifyCompilerMHLOInputLegality()";
+ let constructor = "mlir::iree_compiler::MHLO::createVerifyCompilerMHLOInputLegality()";
}
//------------------------------------------------------------------------------
@@ -54,7 +60,7 @@
def TestMHLOConvertComplexToReal :
Pass<"iree-test-mhlo-convert-complex-to-real", "FuncOp"> {
let summary = "Test pass that does an MHLO->MHLO conversion of just complex arithmetic ops.";
- let constructor = "mlir::iree_compiler::createTestMHLOConvertComplexToRealPass()";
+ let constructor = "mlir::iree_compiler::MHLO::createTestMHLOConvertComplexToRealPass()";
}
#endif // IREE_COMPILER_INPUTCONVERSION_MHLO_PASSES
diff --git a/iree/compiler/InputConversion/MHLO/Rewriters.h b/iree/compiler/InputConversion/MHLO/Rewriters.h
index 8b9c3a5..f7cfb11 100644
--- a/iree/compiler/InputConversion/MHLO/Rewriters.h
+++ b/iree/compiler/InputConversion/MHLO/Rewriters.h
@@ -11,6 +11,7 @@
namespace mlir {
namespace iree_compiler {
+namespace MHLO {
/// Populates the patterns that convert from MHLO to Linalg on tensors. Imports
/// patterns from XLA, as well as some IREE specific modifications.
@@ -32,6 +33,7 @@
TypeConverter &typeConverter,
OwningRewritePatternList &patterns);
+} // namespace MHLO
} // namespace iree_compiler
} // namespace mlir
diff --git a/iree/compiler/InputConversion/MHLO/VerifyCompilerMHLOInputLegality.cpp b/iree/compiler/InputConversion/MHLO/VerifyCompilerMHLOInputLegality.cpp
index 0ddfa30..3c0ec3a 100644
--- a/iree/compiler/InputConversion/MHLO/VerifyCompilerMHLOInputLegality.cpp
+++ b/iree/compiler/InputConversion/MHLO/VerifyCompilerMHLOInputLegality.cpp
@@ -15,6 +15,7 @@
namespace mlir {
namespace iree_compiler {
+namespace MHLO {
struct VerifyCompilerMHLOInputLegalityPass
: public VerifyCompilerMHLOInputLegalityBase<
@@ -69,5 +70,6 @@
return std::make_unique<VerifyCompilerMHLOInputLegalityPass>();
}
+} // namespace MHLO
} // namespace iree_compiler
} // namespace mlir
diff --git a/iree/compiler/InputConversion/MHLO/test/BUILD b/iree/compiler/InputConversion/MHLO/test/BUILD
index cef5ec5..f17b8d7 100644
--- a/iree/compiler/InputConversion/MHLO/test/BUILD
+++ b/iree/compiler/InputConversion/MHLO/test/BUILD
@@ -22,8 +22,10 @@
"broadcasting.mlir",
"convert_mhlo_to_linalg_ext.mlir",
"convert_complex_to_real.mlir",
+ "convert_structural_types.mlir",
"dynamic_shape.mlir",
"fft.mlir",
+ "flatten_tuples_in_cfg.mlir",
"legalize_input_types.mlir",
"mhlo_to_mhlo_preprocessing.mlir",
"mhlo_to_mhlo_preprocessing_canoncalize_dot_general.mlir",
diff --git a/iree/compiler/InputConversion/MHLO/test/CMakeLists.txt b/iree/compiler/InputConversion/MHLO/test/CMakeLists.txt
index d7405b8..82d8ea0 100644
--- a/iree/compiler/InputConversion/MHLO/test/CMakeLists.txt
+++ b/iree/compiler/InputConversion/MHLO/test/CMakeLists.txt
@@ -17,8 +17,10 @@
"broadcasting.mlir"
"convert_complex_to_real.mlir"
"convert_mhlo_to_linalg_ext.mlir"
+ "convert_structural_types.mlir"
"dynamic_shape.mlir"
"fft.mlir"
+ "flatten_tuples_in_cfg.mlir"
"legalize_input_types.mlir"
"mhlo_to_mhlo_preprocessing.mlir"
"mhlo_to_mhlo_preprocessing_canoncalize_dot_general.mlir"
diff --git a/iree/compiler/InputConversion/MHLO/test/broadcasting.mlir b/iree/compiler/InputConversion/MHLO/test/broadcasting.mlir
index d657341..d572c25 100644
--- a/iree/compiler/InputConversion/MHLO/test/broadcasting.mlir
+++ b/iree/compiler/InputConversion/MHLO/test/broadcasting.mlir
@@ -411,7 +411,6 @@
// -----
// CHECK-LABEL: @fallbackDynamicReshape
func @fallbackDynamicReshape(%arg0 : tensor<4x?x3x?xui32>, %arg1 : tensor<5xindex>) -> tensor<12x?x?x1x?xui32> {
- // CHECK: %[[INPUT:.*]] = builtin.unrealized_conversion_cast %arg0 : tensor<4x?x3x?xui32> to tensor<4x?x3x?xi32>
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[RESULT_D1:.*]] = tensor.extract %arg1[%[[C1]]] : tensor<5xindex>
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
@@ -419,12 +418,11 @@
// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
// CHECK-DAG: %[[RESULT_D4:.*]] = tensor.extract %arg1[%[[C4]]] : tensor<5xindex>
// CHECK-DAG: %[[INDEX1:.*]] = arith.constant 1 : index
- // CHECK-DAG: %[[ARG_D1:.*]] = tensor.dim %[[INPUT]], %[[INDEX1]] : tensor<4x?x3x?xi32>
+ // CHECK-DAG: %[[ARG_D1:.*]] = tensor.dim %arg0, %[[INDEX1]] : tensor<4x?x3x?xi32>
// CHECK-DAG: %[[INDEX3:.*]] = arith.constant 3 : index
- // CHECK-DAG: %[[ARG_D3:.*]] = tensor.dim %[[INPUT]], %[[INDEX3]] : tensor<4x?x3x?xi32>
- // CHECK-DAG: %[[RESULT:.*]] = flow.tensor.reshape %[[INPUT]] : tensor<4x?x3x?xi32>{%[[ARG_D1]], %[[ARG_D3]]} -> tensor<12x?x?x1x?xi32>{%[[RESULT_D1]], %[[RESULT_D2]], %[[RESULT_D4]]}
+ // CHECK-DAG: %[[ARG_D3:.*]] = tensor.dim %arg0, %[[INDEX3]] : tensor<4x?x3x?xi32>
+ // CHECK-DAG: %[[RESULT:.*]] = flow.tensor.reshape %arg0 : tensor<4x?x3x?xi32>{%[[ARG_D1]], %[[ARG_D3]]} -> tensor<12x?x?x1x?xi32>{%[[RESULT_D1]], %[[RESULT_D2]], %[[RESULT_D4]]}
%0 = "mhlo.dynamic_reshape"(%arg0, %arg1) : (tensor<4x?x3x?xui32>, tensor<5xindex>) -> tensor<12x?x?x1x?xui32>
- // CHECK: %[[UNCONVERTED_RESULT:.*]] = builtin.unrealized_conversion_cast %[[RESULT]] : tensor<12x?x?x1x?xi32> to tensor<12x?x?x1x?xui32>
- // CHECK: return %[[UNCONVERTED_RESULT]]
+ // CHECK: return %[[RESULT]]
return %0 : tensor<12x?x?x1x?xui32>
}
diff --git a/iree/compiler/InputConversion/MHLO/test/convert_structural_types.mlir b/iree/compiler/InputConversion/MHLO/test/convert_structural_types.mlir
new file mode 100644
index 0000000..16e2d05
--- /dev/null
+++ b/iree/compiler/InputConversion/MHLO/test/convert_structural_types.mlir
@@ -0,0 +1,25 @@
+// RUN: iree-opt -split-input-file --iree-mhlo-to-linalg-on-tensors %s | IreeFileCheck %s
+
+// CHECK-LABEL: @func_cfg_conversion
+module @func_cfg_conversion {
+ // CHECK: func @caller(%arg0: tensor<2xi32>, %arg1: i1) -> tensor<2xi32>
+ func @caller(%arg0: tensor<2xui32>, %arg1 : i1) -> tensor<2xui32> {
+ // CHECK: %[[RESULT:.*]] = call @callee(%arg0, %arg1) : (tensor<2xi32>, i1) -> tensor<2xi32>
+ %1 = call @callee(%arg0, %arg1) : (tensor<2xui32>, i1) -> tensor<2xui32>
+ // CHECK: return %[[RESULT]] : tensor<2xi32>
+ return %1 : tensor<2xui32>
+ }
+
+ // CHECK: func @callee(%arg0: tensor<2xi32>, %arg1: i1) -> tensor<2xi32>
+ func @callee(%arg0: tensor<2xui32>, %arg1: i1) -> tensor<2xui32> {
+ // CHECK: cond_br %arg1, ^bb1(%arg0 : tensor<2xi32>), ^bb2(%arg0 : tensor<2xi32>)
+ cond_br %arg1, ^bb1(%arg0 : tensor<2xui32>), ^bb2(%arg0 : tensor<2xui32>)
+ // CHECK: ^bb1(%[[BB1_PHI:.*]]: tensor<2xi32>)
+ ^bb1(%phi0 : tensor<2xui32>) :
+ // CHECK: br ^bb2(%[[BB1_PHI]] : tensor<2xi32>)
+ br ^bb2(%phi0 : tensor<2xui32>)
+ // CHECK: ^bb2(%[[BB2_PHI:.*]]: tensor<2xi32>)
+ ^bb2(%phi1 : tensor<2xui32>):
+ return %phi1 : tensor<2xui32>
+ }
+}
diff --git a/iree/compiler/InputConversion/MHLO/test/flatten_tuples_in_cfg.mlir b/iree/compiler/InputConversion/MHLO/test/flatten_tuples_in_cfg.mlir
new file mode 100644
index 0000000..6fe968e
--- /dev/null
+++ b/iree/compiler/InputConversion/MHLO/test/flatten_tuples_in_cfg.mlir
@@ -0,0 +1,34 @@
+// RUN: iree-opt -split-input-file --iree-mhlo-flatten-tuples-in-cfg -canonicalize %s | IreeFileCheck %s
+// We rely on canonicalization to cancel out tuple/get_element operations, so
+// we test this followed by the canonicalizer rather than just the pass in
+// isolation.
+// TODO: It would be better if the pass were standalone.
+
+// CHECK-LABEL: @flatten_func
+module @flatten_func {
+ // CHECK: func @caller(%arg0: i1, %arg1: tensor<f32>) -> tensor<f32>
+ func @caller(%arg0 : i1, %arg1: tensor<f32>) -> tensor<f32> {
+ // CHECK: %[[RESULT:.*]]:2 = call @callee(%arg0, %arg1, %arg1, %arg1) : (i1, tensor<f32>, tensor<f32>, tensor<f32>) -> (tensor<f32>, tensor<f32>)
+ %0 = "mhlo.tuple"(%arg1, %arg1) : (tensor<f32>, tensor<f32>) -> tuple<tensor<f32>, tensor<f32>>
+ %1 = "mhlo.tuple"(%arg1, %0) : (tensor<f32>, tuple<tensor<f32>, tensor<f32>>) -> tuple<tensor<f32>, tuple<tensor<f32>, tensor<f32>>>
+ %2 = call @callee(%arg0, %1) : (i1, tuple<tensor<f32>, tuple<tensor<f32>, tensor<f32>>>) -> tuple<tensor<f32>, tensor<f32>>
+ %3 = "mhlo.get_tuple_element"(%2) {index = 0 : i32} : (tuple<tensor<f32>, tensor<f32>>) -> tensor<f32>
+ // CHECK: return %[[RESULT]]#0 : tensor<f32>
+ return %3 : tensor<f32>
+ }
+
+ // CHECK: func private @callee(%arg0: i1, %arg1: tensor<f32>, %arg2: tensor<f32>, %arg3: tensor<f32>) -> (tensor<f32>, tensor<f32>)
+ func private @callee(%arg0: i1, %arg1: tuple<tensor<f32>, tuple<tensor<f32>, tensor<f32>>>) -> tuple<tensor<f32>, tensor<f32>> {
+ // CHECK-DAG: %[[RESULT0:.*]] = select %arg0, %arg2, %arg1 : tensor<f32>
+ // CHECK-DAG: %[[RESULT1:.*]] = select %arg0, %arg3, %arg1 : tensor<f32>
+ // CHECK: return %[[RESULT0]], %[[RESULT1]] : tensor<f32>, tensor<f32>
+ %0 = "mhlo.get_tuple_element"(%arg1) {index = 0 : i32} : (tuple<tensor<f32>, tuple<tensor<f32>, tensor<f32>>>) -> tensor<f32>
+ %1 = "mhlo.get_tuple_element"(%arg1) {index = 1 : i32} : (tuple<tensor<f32>, tuple<tensor<f32>, tensor<f32>>>) -> tuple<tensor<f32>, tensor<f32>>
+ cond_br %arg0, ^bb1(%1 : tuple<tensor<f32>, tensor<f32>>), ^bb2(%0 : tensor<f32>)
+ ^bb1(%phi0 : tuple<tensor<f32>, tensor<f32>>):
+ return %phi0 : tuple<tensor<f32>, tensor<f32>>
+ ^bb2(%phi1 : tensor<f32>):
+ %2 = "mhlo.tuple"(%phi1, %phi1) : (tensor<f32>, tensor<f32>) -> tuple<tensor<f32>, tensor<f32>>
+ br ^bb1(%2 : tuple<tensor<f32>, tensor<f32>>)
+ }
+}
diff --git a/iree/compiler/Translation/IREEVM.cpp b/iree/compiler/Translation/IREEVM.cpp
index 91031b2..283cb18 100644
--- a/iree/compiler/Translation/IREEVM.cpp
+++ b/iree/compiler/Translation/IREEVM.cpp
@@ -59,12 +59,16 @@
static llvm::cl::opt<InputDialectOptions::Type> *typeFlag =
new llvm::cl::opt<InputDialectOptions::Type>{
"iree-input-type", llvm::cl::desc("IREE input type"),
- llvm::cl::values(clEnumValN(InputDialectOptions::Type::none, "none",
- "No input dialect transformation"),
- clEnumValN(InputDialectOptions::Type::tosa, "tosa",
- "Legalize from TOSA ops"),
- clEnumValN(InputDialectOptions::Type::mhlo, "mhlo",
- "Legalize from MHLO ops")),
+ llvm::cl::values(
+ clEnumValN(InputDialectOptions::Type::none, "none",
+ "No input dialect transformation"),
+ clEnumValN(InputDialectOptions::Type::tosa, "tosa",
+ "Legalize from TOSA ops"),
+ clEnumValN(InputDialectOptions::Type::mhlo, "mhlo",
+ "Legalize from MHLO ops"),
+ clEnumValN(
+ InputDialectOptions::Type::xla, "xla",
+ "Legalize from MHLO ops (with XLA cleanup preprocessing)")),
llvm::cl::init(InputDialectOptions::Type::none),
llvm::cl::cat(inputDialectOptions)};
@@ -77,13 +81,10 @@
BindingOptions bindingOptions, InputDialectOptions inputOptions,
IREE::HAL::TargetOptions executableOptions,
IREE::VM::TargetOptions targetOptions, OpPassManager &passManager) {
- if (bindingOptions.native) {
- IREE::ABI::buildTransformPassPipeline(passManager);
- }
- if (bindingOptions.tflite) {
- IREE::TFLite::buildTransformPassPipeline(passManager);
- }
-
+ // Input pipelines can result in changes to the exported functions and types
+ // and must run before generating bindings.
+ // After input processing, there should only be IREE legal types in
+ // signatures.
switch (inputOptions.type) {
case InputDialectOptions::Type::none:
break;
@@ -91,11 +92,23 @@
buildTOSAInputConversionPassPipeline(passManager);
break;
case InputDialectOptions::Type::mhlo:
- buildMHLOInputConversionPassPipeline(passManager);
+ MHLO::buildMHLOInputConversionPassPipeline(passManager);
+ break;
+ case InputDialectOptions::Type::xla:
+ MHLO::buildXLACleanupPassPipeline(passManager);
+ MHLO::buildMHLOInputConversionPassPipeline(passManager);
break;
}
-
buildCommonInputConversionPassPipeline(passManager);
+
+ // Now that inputs are legalized, generate wrapper for entry functions.
+ if (bindingOptions.native) {
+ IREE::ABI::buildTransformPassPipeline(passManager);
+ }
+ if (bindingOptions.tflite) {
+ IREE::TFLite::buildTransformPassPipeline(passManager);
+ }
+
IREE::Flow::TransformOptions flowOptions;
IREE::Flow::buildFlowTransformPassPipeline(passManager, flowOptions);
IREE::Stream::TransformOptions streamOptions;
diff --git a/iree/compiler/Translation/IREEVM.h b/iree/compiler/Translation/IREEVM.h
index 51edc6b..f1555a2 100644
--- a/iree/compiler/Translation/IREEVM.h
+++ b/iree/compiler/Translation/IREEVM.h
@@ -48,6 +48,10 @@
// Legalizes input defined over MHLO ops.
mhlo,
+
+ // Special case of 'mhlo' legalization which also performs some XLA
+ // cleanup activities.
+ xla,
};
Type type = Type::none;
};
diff --git a/iree/compiler/Utils/GraphUtils.cpp b/iree/compiler/Utils/GraphUtils.cpp
index 9a79f01..08342bd 100644
--- a/iree/compiler/Utils/GraphUtils.cpp
+++ b/iree/compiler/Utils/GraphUtils.cpp
@@ -48,5 +48,14 @@
return sortedOps;
}
+void sortBlockTopologically(Block *block) {
+ SetVector<Operation *> unsortedOps;
+ for (auto &op : *block) unsortedOps.insert(&op);
+ auto sortedOps = sortOpsTopologically(unsortedOps);
+ for (auto *op : llvm::reverse(sortedOps)) {
+ op->moveBefore(block, block->begin());
+ }
+}
+
} // namespace iree_compiler
} // namespace mlir
diff --git a/iree/compiler/Utils/GraphUtils.h b/iree/compiler/Utils/GraphUtils.h
index 98a9cbb..d2cca5b 100644
--- a/iree/compiler/Utils/GraphUtils.h
+++ b/iree/compiler/Utils/GraphUtils.h
@@ -32,6 +32,9 @@
return SmallVector<Operation *, N>(result.begin(), result.end());
}
+// Sorts all of the ops within |block| into an arbitrary topological order.
+void sortBlockTopologically(Block *block);
+
} // namespace iree_compiler
} // namespace mlir
diff --git a/iree/hal/allocator.h b/iree/hal/allocator.h
index c2e4096..03b87fc 100644
--- a/iree/hal/allocator.h
+++ b/iree/hal/allocator.h
@@ -174,9 +174,13 @@
// used in file IO or tests). Buffers allocated with this will not be compatible
// with real device allocators and will likely incur a copy (or failure) if
// used.
+//
+// The buffers created from the allocator will use |host_allocator| for their
+// metadata and |data_allocator| for their device storage allocations. If the
+// two are the same the buffers will be allocated in a single flat slab.
IREE_API_EXPORT iree_status_t iree_hal_allocator_create_heap(
- iree_string_view_t identifier, iree_allocator_t host_allocator,
- iree_hal_allocator_t** out_allocator);
+ iree_string_view_t identifier, iree_allocator_t data_allocator,
+ iree_allocator_t host_allocator, iree_hal_allocator_t** out_allocator);
//===----------------------------------------------------------------------===//
// iree_hal_allocator_t implementation details
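
A minimal usage sketch of the new signature, mirroring the call sites updated later in this diff (string_util_test.cc and the driver modules); passing the same allocator for both parameters keeps the prior single-slab behavior:

  // Same allocator for metadata and contents -> one flat slab allocation.
  iree_hal_allocator_t* heap_allocator = NULL;
  IREE_RETURN_IF_ERROR(iree_hal_allocator_create_heap(
      iree_make_cstring_view("host_local"),
      /*data_allocator=*/iree_allocator_system(),
      /*host_allocator=*/iree_allocator_system(), &heap_allocator));
  // ... allocate buffers ...
  iree_hal_allocator_release(heap_allocator);
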
diff --git a/iree/hal/allocator_heap.c b/iree/hal/allocator_heap.c
index e59cd17..dc5d4cc 100644
--- a/iree/hal/allocator_heap.c
+++ b/iree/hal/allocator_heap.c
@@ -16,6 +16,7 @@
typedef struct iree_hal_heap_allocator_t {
iree_hal_resource_t resource;
iree_allocator_t host_allocator;
+ iree_allocator_t data_allocator;
iree_string_view_t identifier;
IREE_STATISTICS(iree_hal_heap_allocator_statistics_t statistics;)
} iree_hal_heap_allocator_t;
@@ -28,8 +29,8 @@
}
IREE_API_EXPORT iree_status_t iree_hal_allocator_create_heap(
- iree_string_view_t identifier, iree_allocator_t host_allocator,
- iree_hal_allocator_t** out_allocator) {
+ iree_string_view_t identifier, iree_allocator_t data_allocator,
+ iree_allocator_t host_allocator, iree_hal_allocator_t** out_allocator) {
IREE_ASSERT_ARGUMENT(out_allocator);
IREE_TRACE_ZONE_BEGIN(z0);
@@ -42,6 +43,7 @@
iree_hal_resource_initialize(&iree_hal_heap_allocator_vtable,
&allocator->resource);
allocator->host_allocator = host_allocator;
+ allocator->data_allocator = data_allocator;
iree_string_view_append_to_buffer(
identifier, &allocator->identifier,
(char*)allocator + iree_sizeof_struct(*allocator));
@@ -82,9 +84,9 @@
static void iree_hal_heap_allocator_query_statistics(
iree_hal_allocator_t* base_allocator,
iree_hal_allocator_statistics_t* out_statistics) {
- iree_hal_heap_allocator_t* allocator =
- iree_hal_heap_allocator_cast(base_allocator);
IREE_STATISTICS({
+ iree_hal_heap_allocator_t* allocator =
+ iree_hal_heap_allocator_cast(base_allocator);
iree_slim_mutex_lock(&allocator->statistics.mutex);
memcpy(out_statistics, &allocator->statistics.base,
sizeof(*out_statistics));
@@ -158,9 +160,10 @@
// Allocate the buffer (both the wrapper and the contents).
iree_hal_heap_allocator_statistics_t* statistics = NULL;
IREE_STATISTICS(statistics = &allocator->statistics);
- return iree_hal_heap_buffer_create(
- base_allocator, statistics, memory_type, allowed_access, allowed_usage,
- allocation_size, allocator->host_allocator, out_buffer);
+ return iree_hal_heap_buffer_create(base_allocator, statistics, memory_type,
+ allowed_access, allowed_usage,
+ allocation_size, allocator->data_allocator,
+ allocator->host_allocator, out_buffer);
}
static iree_status_t iree_hal_heap_allocator_wrap_buffer(
diff --git a/iree/hal/buffer.c b/iree/hal/buffer.c
index a47b234..ca15374 100644
--- a/iree/hal/buffer.c
+++ b/iree/hal/buffer.c
@@ -170,9 +170,9 @@
!iree_all_bits_set(actual_memory_type, expected_memory_type))) {
// Missing one or more bits.
iree_bitfield_string_temp_t temp0, temp1;
- iree_string_view_t actual_memory_type_str =
+ iree_string_view_t actual_memory_type_str IREE_ATTRIBUTE_UNUSED =
iree_hal_memory_type_format(actual_memory_type, &temp0);
- iree_string_view_t expected_memory_type_str =
+ iree_string_view_t expected_memory_type_str IREE_ATTRIBUTE_UNUSED =
iree_hal_memory_type_format(expected_memory_type, &temp1);
return iree_make_status(
IREE_STATUS_PERMISSION_DENIED,
@@ -201,9 +201,9 @@
required_memory_access))) {
// Bits must match exactly.
iree_bitfield_string_temp_t temp0, temp1;
- iree_string_view_t allowed_memory_access_str =
+ iree_string_view_t allowed_memory_access_str IREE_ATTRIBUTE_UNUSED =
iree_hal_memory_access_format(allowed_memory_access, &temp0);
- iree_string_view_t required_memory_access_str =
+ iree_string_view_t required_memory_access_str IREE_ATTRIBUTE_UNUSED =
iree_hal_memory_access_format(required_memory_access, &temp1);
return iree_make_status(
IREE_STATUS_PERMISSION_DENIED,
@@ -221,9 +221,9 @@
if (IREE_UNLIKELY(!iree_all_bits_set(allowed_usage, required_usage))) {
// Missing one or more bits.
iree_bitfield_string_temp_t temp0, temp1;
- iree_string_view_t allowed_usage_str =
+ iree_string_view_t allowed_usage_str IREE_ATTRIBUTE_UNUSED =
iree_hal_buffer_usage_format(allowed_usage, &temp0);
- iree_string_view_t required_usage_str =
+ iree_string_view_t required_usage_str IREE_ATTRIBUTE_UNUSED =
iree_hal_buffer_usage_format(required_usage, &temp1);
return iree_make_status(
IREE_STATUS_PERMISSION_DENIED,
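
The IREE_ATTRIBUTE_UNUSED annotations here (and in the files below) cover locals that are only consumed by iree_make_status; a sketch of the pattern, assuming the usual configuration where status messages can be compiled out and the formatted arguments are ignored, which would otherwise trip -Wunused-variable:

  iree_bitfield_string_temp_t temp;
  iree_string_view_t allowed_usage_str IREE_ATTRIBUTE_UNUSED =
      iree_hal_buffer_usage_format(allowed_usage, &temp);
  // With messages stripped the arguments below are discarded, leaving the
  // local unused unless annotated.
  return iree_make_status(IREE_STATUS_PERMISSION_DENIED,
                          "buffer usage not allowed; allowed_usage=%.*s",
                          (int)allowed_usage_str.size, allowed_usage_str.data);
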
diff --git a/iree/hal/buffer_heap.c b/iree/hal/buffer_heap.c
index f7699bb..db5a229 100644
--- a/iree/hal/buffer_heap.c
+++ b/iree/hal/buffer_heap.c
@@ -27,23 +27,73 @@
static const iree_hal_buffer_vtable_t iree_hal_heap_buffer_vtable;
-iree_status_t iree_hal_heap_buffer_create(
- iree_hal_allocator_t* allocator,
- iree_hal_heap_allocator_statistics_t* statistics,
- iree_hal_memory_type_t memory_type, iree_hal_memory_access_t allowed_access,
- iree_hal_buffer_usage_t allowed_usage, iree_device_size_t allocation_size,
- iree_allocator_t host_allocator, iree_hal_buffer_t** out_buffer) {
- IREE_ASSERT_ARGUMENT(allocator);
- IREE_ASSERT_ARGUMENT(out_buffer);
- IREE_TRACE_ZONE_BEGIN(z0);
+// Allocates a buffer with the metadata and storage split.
+// This results in an additional host allocation but allows for user-overridden
+// data storage allocations.
+static iree_status_t iree_hal_heap_buffer_allocate_split(
+ iree_device_size_t allocation_size, iree_allocator_t data_allocator,
+ iree_allocator_t host_allocator, iree_hal_heap_buffer_t** out_buffer,
+ iree_byte_span_t* out_data) {
+ // Try allocating the storage first as it's the most likely to fail if OOM.
+ out_data->data_length = allocation_size;
+ IREE_RETURN_IF_ERROR(iree_allocator_malloc(data_allocator, allocation_size,
+ (void**)&out_data->data));
+ // Allocate the host metadata wrapper.
+ iree_status_t status = iree_allocator_malloc(
+ host_allocator, sizeof(**out_buffer), (void**)out_buffer);
+ if (!iree_status_is_ok(status)) {
+ // Need to free the storage we just allocated.
+ iree_allocator_free(data_allocator, out_data->data);
+ }
+ return status;
+}
+
+// Allocates a buffer with the metadata as a prefix to the storage.
+// This results in a single allocation per buffer but requires that both the
+// metadata and storage live together.
+static iree_status_t iree_hal_heap_buffer_allocate_slab(
+ iree_device_size_t allocation_size, iree_allocator_t host_allocator,
+ iree_hal_heap_buffer_t** out_buffer, iree_byte_span_t* out_data) {
// NOTE: we want the buffer data to always be 16-byte aligned.
iree_hal_heap_buffer_t* buffer = NULL;
iree_host_size_t header_size =
iree_host_align(iree_sizeof_struct(*buffer), 16);
iree_host_size_t total_size = header_size + allocation_size;
+ IREE_RETURN_IF_ERROR(
+ iree_allocator_malloc(host_allocator, total_size, (void**)&buffer));
+ *out_buffer = buffer;
+ *out_data =
+ iree_make_byte_span((uint8_t*)buffer + header_size, allocation_size);
+ return iree_ok_status();
+}
+
+iree_status_t iree_hal_heap_buffer_create(
+ iree_hal_allocator_t* allocator,
+ iree_hal_heap_allocator_statistics_t* statistics,
+ iree_hal_memory_type_t memory_type, iree_hal_memory_access_t allowed_access,
+ iree_hal_buffer_usage_t allowed_usage, iree_device_size_t allocation_size,
+ iree_allocator_t data_allocator, iree_allocator_t host_allocator,
+ iree_hal_buffer_t** out_buffer) {
+ IREE_ASSERT_ARGUMENT(allocator);
+ IREE_ASSERT_ARGUMENT(out_buffer);
+ IREE_TRACE_ZONE_BEGIN(z0);
+
+ // If the data and host allocators are the same we can allocate more
+ // efficiently as a large slab. Otherwise we need to allocate both the
+ // metadata and the storage independently.
+ bool same_allocator =
+ memcmp(&data_allocator, &host_allocator, sizeof(data_allocator)) == 0;
+
+ iree_hal_heap_buffer_t* buffer = NULL;
+ iree_byte_span_t data = iree_make_byte_span(NULL, 0);
iree_status_t status =
- iree_allocator_malloc(host_allocator, total_size, (void**)&buffer);
+ same_allocator
+ ? iree_hal_heap_buffer_allocate_slab(allocation_size, host_allocator,
+ &buffer, &data)
+ : iree_hal_heap_buffer_allocate_split(allocation_size, data_allocator,
+ host_allocator, &buffer, &data);
+
if (iree_status_is_ok(status)) {
iree_hal_resource_initialize(&iree_hal_heap_buffer_vtable,
&buffer->base.resource);
@@ -55,9 +105,9 @@
buffer->base.memory_type = memory_type;
buffer->base.allowed_access = allowed_access;
buffer->base.allowed_usage = allowed_usage;
- buffer->data =
- iree_make_byte_span((uint8_t*)buffer + header_size, allocation_size);
- buffer->data_allocator = iree_allocator_null(); // freed with the buffer
+ buffer->data = data;
+ buffer->data_allocator =
+ same_allocator ? iree_allocator_null() : data_allocator;
IREE_STATISTICS({
if (statistics != NULL) {
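
And the split path this enables, sketched with a hypothetical application-supplied data allocator (my_pinned_data_allocator is illustrative only, not part of this change): metadata comes from |host_allocator|, contents from |data_allocator|, and the two are freed independently on buffer release.

  extern iree_allocator_t my_pinned_data_allocator(void);  // hypothetical

  iree_hal_allocator_t* heap_allocator = NULL;
  IREE_RETURN_IF_ERROR(iree_hal_allocator_create_heap(
      iree_make_cstring_view("pinned"),
      /*data_allocator=*/my_pinned_data_allocator(),
      /*host_allocator=*/iree_allocator_system(), &heap_allocator));
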
diff --git a/iree/hal/buffer_heap_impl.h b/iree/hal/buffer_heap_impl.h
index 5c01aff..5a2c5ac 100644
--- a/iree/hal/buffer_heap_impl.h
+++ b/iree/hal/buffer_heap_impl.h
@@ -26,14 +26,17 @@
iree_hal_allocator_statistics_t base;
} iree_hal_heap_allocator_statistics_t;
-// Allocates a new heap buffer from the specified |host_allocator|.
-// |out_buffer| must be released by the caller.
+// Allocates a new heap buffer from the specified |data_allocator|.
+// |host_allocator| is used for the iree_hal_buffer_t metadata. If both
+// |data_allocator| and |host_allocator| are the same the buffer will be created
+// as a flat slab. |out_buffer| must be released by the caller.
iree_status_t iree_hal_heap_buffer_create(
iree_hal_allocator_t* allocator,
iree_hal_heap_allocator_statistics_t* statistics,
iree_hal_memory_type_t memory_type, iree_hal_memory_access_t allowed_access,
iree_hal_buffer_usage_t allowed_usage, iree_device_size_t allocation_size,
- iree_allocator_t host_allocator, iree_hal_buffer_t** out_buffer);
+ iree_allocator_t data_allocator, iree_allocator_t host_allocator,
+ iree_hal_buffer_t** out_buffer);
#ifdef __cplusplus
} // extern "C"
diff --git a/iree/hal/buffer_view.c b/iree/hal/buffer_view.c
index bdc55bc..4de4a38 100644
--- a/iree/hal/buffer_view.c
+++ b/iree/hal/buffer_view.c
@@ -363,7 +363,7 @@
iree_hal_element_type_t element_type,
iree_hal_encoding_type_t encoding_type,
iree_device_size_t* out_allocation_size) {
- IREE_ASSERT_ARGUMENT(shape);
+ IREE_ASSERT_ARGUMENT(!shape_rank || shape);
IREE_ASSERT_ARGUMENT(out_allocation_size);
*out_allocation_size = 0;
diff --git a/iree/hal/command_buffer_validation.c b/iree/hal/command_buffer_validation.c
index 1f49a91..f8dffa9 100644
--- a/iree/hal/command_buffer_validation.c
+++ b/iree/hal/command_buffer_validation.c
@@ -57,9 +57,9 @@
if (!iree_all_bits_set(command_buffer->allowed_categories,
required_categories)) {
iree_bitfield_string_temp_t temp0, temp1;
- iree_string_view_t required_categories_str =
+ iree_string_view_t required_categories_str IREE_ATTRIBUTE_UNUSED =
iree_hal_command_category_format(required_categories, &temp0);
- iree_string_view_t allowed_categories_str =
+ iree_string_view_t allowed_categories_str IREE_ATTRIBUTE_UNUSED =
iree_hal_command_category_format(command_buffer->allowed_categories,
&temp1);
return iree_make_status(
@@ -87,9 +87,10 @@
if (!iree_all_bits_set(allowed_compatibility, required_compatibility)) {
// Buffer cannot be used on the queue for the given usage.
iree_bitfield_string_temp_t temp0, temp1;
- iree_string_view_t allowed_usage_str = iree_hal_buffer_usage_format(
- iree_hal_buffer_allowed_usage(buffer), &temp0);
- iree_string_view_t intended_usage_str =
+ iree_string_view_t allowed_usage_str IREE_ATTRIBUTE_UNUSED =
+ iree_hal_buffer_usage_format(iree_hal_buffer_allowed_usage(buffer),
+ &temp0);
+ iree_string_view_t intended_usage_str IREE_ATTRIBUTE_UNUSED =
iree_hal_buffer_usage_format(intended_usage, &temp1);
return iree_make_status(
IREE_STATUS_PERMISSION_DENIED,
@@ -454,10 +455,12 @@
!iree_any_bit_set(iree_hal_buffer_memory_type(target_buffer),
IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE)) {
iree_bitfield_string_temp_t temp0, temp1;
- iree_string_view_t source_memory_type_str = iree_hal_memory_type_format(
- iree_hal_buffer_memory_type(source_buffer), &temp0);
- iree_string_view_t target_memory_type_str = iree_hal_memory_type_format(
- iree_hal_buffer_memory_type(target_buffer), &temp1);
+ iree_string_view_t source_memory_type_str IREE_ATTRIBUTE_UNUSED =
+ iree_hal_memory_type_format(iree_hal_buffer_memory_type(source_buffer),
+ &temp0);
+ iree_string_view_t target_memory_type_str IREE_ATTRIBUTE_UNUSED =
+ iree_hal_memory_type_format(iree_hal_buffer_memory_type(target_buffer),
+ &temp1);
return iree_make_status(
IREE_STATUS_PERMISSION_DENIED,
"at least one buffer must be device-visible for a copy; "
diff --git a/iree/hal/cuda/cuda_allocator.c b/iree/hal/cuda/cuda_allocator.c
index de23b2c..573ddb9 100644
--- a/iree/hal/cuda/cuda_allocator.c
+++ b/iree/hal/cuda/cuda_allocator.c
@@ -18,6 +18,7 @@
iree_hal_resource_t resource;
iree_hal_cuda_context_wrapper_t* context;
CUdevice device;
+ bool supports_concurrent_managed_access;
CUstream stream;
IREE_STATISTICS(iree_hal_allocator_statistics_t statistics;)
@@ -36,6 +37,28 @@
iree_hal_allocator_t** out_allocator) {
IREE_ASSERT_ARGUMENT(context);
IREE_TRACE_ZONE_BEGIN(z0);
+
+ // To support device-local + host-visible memory we need concurrent managed
+ // access indicating that the host and devices can concurrently access the
+ // device memory. If we don't have this feature then we fall back to forcing
+ // all device-local + host-visible memory into host-local + device-visible
+  // page-locked memory. The compiler tries to avoid this memory type for
+  // high-traffic buffers and only relies on it for readback staging buffers.
+ int supports_concurrent_managed_access = 0;
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, CU_RESULT_TO_STATUS(
+ context->syms,
+ cuDeviceGetAttribute(
+ &supports_concurrent_managed_access,
+ CU_DEVICE_ATTRIBUTE_CONCURRENT_MANAGED_ACCESS, device),
+ "cuDeviceGetAttribute"));
+
+ IREE_TRACE_ZONE_APPEND_TEXT(
+ z0, supports_concurrent_managed_access
+ ? "has CONCURRENT_MANAGED_ACCESS"
+ : "no CONCURRENT_MANAGED_ACCESS (expect slow accesses on "
+ "device-local + host-visible memory)");
+
iree_hal_cuda_allocator_t* allocator = NULL;
iree_status_t status = iree_allocator_malloc(
context->host_allocator, sizeof(*allocator), (void**)&allocator);
@@ -45,6 +68,8 @@
allocator->context = context;
allocator->device = device;
allocator->stream = stream;
+ allocator->supports_concurrent_managed_access =
+ supports_concurrent_managed_access != 0;
*out_allocator = (iree_hal_allocator_t*)allocator;
}
@@ -121,6 +146,21 @@
// application is unlikely to do anything when requesting a 0-byte buffer; but
// it can happen in real world use cases. So we should at least not crash.
if (allocation_size == 0) allocation_size = 4;
+
+ // If concurrent managed access is not supported then make device-local +
+ // host-visible allocations fall back to host-local + device-visible
+ // page-locked memory. This will be significantly slower for the device to
+ // access but the compiler only uses this type for readback staging buffers
+ // and it's better to function than function fast.
+ if (!allocator->supports_concurrent_managed_access &&
+ iree_all_bits_set(memory_type, IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL |
+ IREE_HAL_MEMORY_TYPE_HOST_VISIBLE)) {
+ memory_type &= ~(IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL |
+ IREE_HAL_MEMORY_TYPE_HOST_VISIBLE);
+ memory_type |=
+ IREE_HAL_MEMORY_TYPE_HOST_LOCAL | IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE;
+ }
+
iree_status_t status;
void* host_ptr = NULL;
CUdeviceptr device_ptr = 0;
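
The fallback above, condensed into a standalone sketch for readability (the helper name is illustrative; the bit manipulation is the same as in the change):

  static iree_hal_memory_type_t iree_hal_cuda_demote_memory_type(
      iree_hal_memory_type_t memory_type) {
    // Without CONCURRENT_MANAGED_ACCESS, device-local + host-visible requests
    // are served from host-local + device-visible page-locked memory instead.
    if (iree_all_bits_set(memory_type, IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL |
                                           IREE_HAL_MEMORY_TYPE_HOST_VISIBLE)) {
      memory_type &= ~(IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL |
                       IREE_HAL_MEMORY_TYPE_HOST_VISIBLE);
      memory_type |= IREE_HAL_MEMORY_TYPE_HOST_LOCAL |
                     IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE;
    }
    return memory_type;
  }
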
diff --git a/iree/hal/cuda/cuda_driver.c b/iree/hal/cuda/cuda_driver.c
index 82b241b..0d16334 100644
--- a/iree/hal/cuda/cuda_driver.c
+++ b/iree/hal/cuda/cuda_driver.c
@@ -124,16 +124,6 @@
// Return true if the device support all the extension required.
static bool iree_hal_cuda_is_valid_device(iree_hal_cuda_driver_t* driver,
CUdevice device) {
- int support_concurrent_managed_access = 0;
- iree_status_t status = CU_RESULT_TO_STATUS(
- &driver->syms,
- cuDeviceGetAttribute(&support_concurrent_managed_access,
- CU_DEVICE_ATTRIBUTE_CONCURRENT_MANAGED_ACCESS,
- device),
- "cuDeviceGetAttribute");
- if (!iree_status_is_ok(status) || !support_concurrent_managed_access) {
- return false;
- }
return true;
}
diff --git a/iree/hal/cuda/executable_layout.c b/iree/hal/cuda/executable_layout.c
index 21858ba..78ea17a 100644
--- a/iree/hal/cuda/executable_layout.c
+++ b/iree/hal/cuda/executable_layout.c
@@ -16,6 +16,7 @@
iree_hal_resource_t resource;
iree_hal_cuda_context_wrapper_t* context;
iree_host_size_t push_constant_base_index;
+ iree_host_size_t push_constant_count;
iree_host_size_t set_layout_count;
iree_hal_descriptor_set_layout_t* set_layouts[];
} iree_hal_cuda_executable_layout_t;
@@ -54,6 +55,14 @@
IREE_ASSERT_ARGUMENT(out_executable_layout);
*out_executable_layout = NULL;
IREE_TRACE_ZONE_BEGIN(z0);
+
+ if (push_constant_count > IREE_HAL_CUDA_MAX_PUSH_CONSTANT_COUNT) {
+ return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+ "push constant count %zu over the limit of %d",
+ push_constant_count,
+ IREE_HAL_CUDA_MAX_PUSH_CONSTANT_COUNT);
+ }
+
// Currently the executable layout doesn't do anything.
  // TODO: Handle creating the argument layout at that time handling both push
  // constants and buffers.
@@ -76,6 +85,7 @@
iree_hal_cuda_descriptor_set_layout_binding_count(set_layouts[i]);
}
executable_layout->push_constant_base_index = binding_number;
+ executable_layout->push_constant_count = push_constant_count;
*out_executable_layout = (iree_hal_executable_layout_t*)executable_layout;
}
IREE_TRACE_ZONE_END(z0);
@@ -103,6 +113,13 @@
return executable_layout->push_constant_base_index;
}
+iree_host_size_t iree_hal_cuda_executable_layout_num_constants(
+ iree_hal_executable_layout_t* base_executable_layout) {
+ iree_hal_cuda_executable_layout_t* executable_layout =
+ iree_hal_cuda_executable_layout_cast(base_executable_layout);
+ return executable_layout->push_constant_count;
+}
+
const iree_hal_executable_layout_vtable_t
iree_hal_cuda_executable_layout_vtable = {
.destroy = iree_hal_cuda_executable_layout_destroy,
diff --git a/iree/hal/cuda/executable_layout.h b/iree/hal/cuda/executable_layout.h
index adf29ea..b7810e0 100644
--- a/iree/hal/cuda/executable_layout.h
+++ b/iree/hal/cuda/executable_layout.h
@@ -15,6 +15,8 @@
extern "C" {
#endif // __cplusplus
+#define IREE_HAL_CUDA_MAX_PUSH_CONSTANT_COUNT 64
+
// Creates the kernel arguments.
iree_status_t iree_hal_cuda_executable_layout_create(
iree_hal_cuda_context_wrapper_t* context, iree_host_size_t set_layout_count,
@@ -30,6 +32,10 @@
iree_host_size_t iree_hal_cuda_push_constant_index(
iree_hal_executable_layout_t* base_executable_layout);
+// Return the number of constants in the executable layout.
+iree_host_size_t iree_hal_cuda_executable_layout_num_constants(
+ iree_hal_executable_layout_t* base_executable_layout);
+
#ifdef __cplusplus
} // extern "C"
#endif // __cplusplus
diff --git a/iree/hal/cuda/graph_command_buffer.c b/iree/hal/cuda/graph_command_buffer.c
index b42dfef..f3ac78c 100644
--- a/iree/hal/cuda/graph_command_buffer.c
+++ b/iree/hal/cuda/graph_command_buffer.c
@@ -18,6 +18,10 @@
#include "iree/hal/cuda/native_executable.h"
#include "iree/hal/cuda/status_util.h"
+#define IREE_HAL_CUDA_MAX_BINDING_COUNT 64
+// Kernel arguments contains binding and push constants.
+#define IREE_HAL_CUDA_MAX_KERNEL_ARG 128
+
// Command buffer implementation that directly maps to cuda graph.
// This records the commands on the calling thread without additional threading
// indirection.
@@ -32,14 +36,11 @@
// Keep track of the last node added to the command buffer as we are currently
// serializing all the nodes (each node depends on the previous one).
CUgraphNode last_node;
+ int32_t push_constant[IREE_HAL_CUDA_MAX_PUSH_CONSTANT_COUNT];
// Keep track of the current set of kernel arguments.
void* current_descriptor[];
} iree_hal_cuda_graph_command_buffer_t;
-#define IREE_HAL_CUDA_MAX_BINDING_COUNT 64
-// Kernel arguments contains binding and push constants.
-#define IREE_HAL_CUDA_MAX_KERNEL_ARG 128
-
extern const iree_hal_command_buffer_vtable_t
iree_hal_cuda_graph_command_buffer_vtable;
@@ -349,11 +350,9 @@
const void* values, iree_host_size_t values_length) {
iree_hal_cuda_graph_command_buffer_t* command_buffer =
iree_hal_cuda_graph_command_buffer_cast(base_command_buffer);
- iree_host_size_t constant_base_index =
- iree_hal_cuda_push_constant_index(executable_layout) +
- offset / sizeof(int32_t);
+ iree_host_size_t constant_base_index = offset / sizeof(int32_t);
for (iree_host_size_t i = 0; i < values_length / sizeof(int32_t); i++) {
- *((uint32_t*)command_buffer->current_descriptor[i + constant_base_index]) =
+ command_buffer->push_constant[i + constant_base_index] =
((uint32_t*)values)[i];
}
return iree_ok_status();
@@ -425,7 +424,17 @@
uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z) {
iree_hal_cuda_graph_command_buffer_t* command_buffer =
iree_hal_cuda_graph_command_buffer_cast(base_command_buffer);
-
+ iree_hal_executable_layout_t* layout =
+ iree_hal_cuda_executable_get_layout(executable, entry_point);
+ iree_host_size_t num_constants =
+ iree_hal_cuda_executable_layout_num_constants(layout);
+ iree_host_size_t constant_base_index =
+ iree_hal_cuda_push_constant_index(layout);
+ // Patch the push constants in the kernel arguments.
+ for (iree_host_size_t i = 0; i < num_constants; i++) {
+ *((uint32_t*)command_buffer->current_descriptor[i + constant_base_index]) =
+ command_buffer->push_constant[i];
+ }
int32_t block_size_x, block_size_y, block_size_z;
IREE_RETURN_IF_ERROR(iree_hal_cuda_native_executable_block_size(
executable, entry_point, &block_size_x, &block_size_y, &block_size_z));
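
How the two halves above fit together, as a sketch with hypothetical sizes (2 bindings, 3 push constants): push_constants() only stages values into the command buffer, and the dispatch patches them into the kernel argument slots that follow the binding pointers.

  // current_descriptor layout for this hypothetical executable layout:
  //   [0], [1]      -> binding device pointers
  //   [2], [3], [4] -> pointers to uint32_t push constant storage
  iree_host_size_t constant_base_index = 2;  // iree_hal_cuda_push_constant_index(layout)
  iree_host_size_t num_constants = 3;  // iree_hal_cuda_executable_layout_num_constants(layout)
  for (iree_host_size_t i = 0; i < num_constants; ++i) {
    *((uint32_t*)command_buffer->current_descriptor[i + constant_base_index]) =
        command_buffer->push_constant[i];
  }
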
diff --git a/iree/hal/cuda/native_executable.c b/iree/hal/cuda/native_executable.c
index eeed4bb..fb8d498 100644
--- a/iree/hal/cuda/native_executable.c
+++ b/iree/hal/cuda/native_executable.c
@@ -11,6 +11,7 @@
#include "iree/base/api.h"
#include "iree/base/tracing.h"
#include "iree/hal/cuda/dynamic_symbols.h"
+#include "iree/hal/cuda/executable_layout.h"
#include "iree/hal/cuda/status_util.h"
// flatcc schemas:
@@ -28,6 +29,7 @@
typedef struct iree_hal_cuda_native_executable_t {
iree_hal_resource_t resource;
iree_hal_cuda_context_wrapper_t* context;
+ iree_hal_executable_layout_t** executable_layouts;
iree_host_size_t entry_count;
CUmodule module;
iree_hal_cuda_native_executable_function_t entry_functions[];
@@ -68,9 +70,13 @@
iree_host_size_t entry_count = flatbuffers_string_vec_len(entry_points_vec);
iree_host_size_t total_size =
sizeof(*executable) +
- entry_count * sizeof(iree_hal_cuda_native_executable_function_t);
+ entry_count * sizeof(iree_hal_cuda_native_executable_function_t) +
+ entry_count * sizeof(iree_hal_executable_layout_t*);
iree_status_t status = iree_allocator_malloc(context->host_allocator,
total_size, (void**)&executable);
+ executable->executable_layouts =
+ (void*)((char*)executable + sizeof(*executable) +
+ entry_count * sizeof(iree_hal_cuda_native_executable_function_t));
CUmodule module = NULL;
CUDA_RETURN_IF_ERROR(context->syms,
cuModuleLoadDataEx(&module, ptx_image, 0, NULL, NULL),
@@ -86,6 +92,8 @@
executable->entry_functions[i].block_size_x = block_sizes_vec[i].x;
executable->entry_functions[i].block_size_y = block_sizes_vec[i].y;
executable->entry_functions[i].block_size_z = block_sizes_vec[i].z;
+ executable->executable_layouts[i] = executable_spec->executable_layouts[i];
+ iree_hal_executable_layout_retain(executable_spec->executable_layouts[i]);
}
iree_hal_resource_initialize(&iree_hal_cuda_native_executable_vtable,
@@ -115,6 +123,13 @@
return iree_ok_status();
}
+iree_hal_executable_layout_t* iree_hal_cuda_executable_get_layout(
+ iree_hal_executable_t* base_executable, int32_t entry_point) {
+ iree_hal_cuda_native_executable_t* executable =
+ iree_hal_cuda_native_executable_cast(base_executable);
+ return executable->executable_layouts[entry_point];
+}
+
static void iree_hal_cuda_native_executable_destroy(
iree_hal_executable_t* base_executable) {
iree_hal_cuda_native_executable_t* executable =
@@ -122,6 +137,9 @@
iree_allocator_t host_allocator = executable->context->host_allocator;
IREE_TRACE_ZONE_BEGIN(z0);
+ for (iree_host_size_t i = 0; i < executable->entry_count; ++i) {
+ iree_hal_executable_layout_release(executable->executable_layouts[i]);
+ }
iree_allocator_free(host_allocator, executable);
IREE_TRACE_ZONE_END(z0);
diff --git a/iree/hal/cuda/native_executable.h b/iree/hal/cuda/native_executable.h
index bfecb1a..007331b 100644
--- a/iree/hal/cuda/native_executable.h
+++ b/iree/hal/cuda/native_executable.h
@@ -33,6 +33,10 @@
iree_hal_executable_t* executable, int32_t entry_point, uint32_t* x,
uint32_t* y, uint32_t* z);
+// Return the layout associated with the entry point.
+iree_hal_executable_layout_t* iree_hal_cuda_executable_get_layout(
+ iree_hal_executable_t* executable, int32_t entry_point);
+
#ifdef __cplusplus
} // extern "C"
#endif // __cplusplus
diff --git a/iree/hal/dylib/registration/driver_module.c b/iree/hal/dylib/registration/driver_module.c
index 16fe55f..836db33 100644
--- a/iree/hal/dylib/registration/driver_module.c
+++ b/iree/hal/dylib/registration/driver_module.c
@@ -44,7 +44,7 @@
}
static iree_status_t iree_hal_dylib_driver_factory_try_create(
- void* self, iree_hal_driver_id_t driver_id, iree_allocator_t allocator,
+ void* self, iree_hal_driver_id_t driver_id, iree_allocator_t host_allocator,
iree_hal_driver_t** out_driver) {
if (driver_id != IREE_HAL_DYLIB_DRIVER_ID) {
return iree_make_status(IREE_STATUS_UNAVAILABLE,
@@ -62,26 +62,34 @@
iree_host_size_t loader_count = 0;
if (iree_status_is_ok(status)) {
status = iree_hal_embedded_library_loader_create(
- iree_hal_executable_import_provider_null(), allocator,
+ iree_hal_executable_import_provider_null(), host_allocator,
&loaders[loader_count++]);
}
if (iree_status_is_ok(status)) {
status = iree_hal_system_library_loader_create(
- iree_hal_executable_import_provider_null(), allocator,
+ iree_hal_executable_import_provider_null(), host_allocator,
&loaders[loader_count++]);
}
iree_task_executor_t* executor = NULL;
if (iree_status_is_ok(status)) {
- status = iree_task_executor_create_from_flags(allocator, &executor);
+ status = iree_task_executor_create_from_flags(host_allocator, &executor);
+ }
+
+ iree_hal_allocator_t* device_allocator = NULL;
+ if (iree_status_is_ok(status)) {
+ status = iree_hal_allocator_create_heap(iree_make_cstring_view("cpu"),
+ host_allocator, host_allocator,
+ &device_allocator);
}
if (iree_status_is_ok(status)) {
status = iree_hal_task_driver_create(
iree_make_cstring_view("cpu"), &default_params, executor, loader_count,
- loaders, allocator, out_driver);
+ loaders, device_allocator, host_allocator, out_driver);
}
+ iree_hal_allocator_release(device_allocator);
iree_task_executor_release(executor);
for (iree_host_size_t i = 0; i < loader_count; ++i) {
iree_hal_executable_loader_release(loaders[i]);
diff --git a/iree/hal/dylib/registration/driver_module_sync.c b/iree/hal/dylib/registration/driver_module_sync.c
index 3407a61..29f0a69 100644
--- a/iree/hal/dylib/registration/driver_module_sync.c
+++ b/iree/hal/dylib/registration/driver_module_sync.c
@@ -37,7 +37,7 @@
}
static iree_status_t iree_hal_dylib_sync_driver_factory_try_create(
- void* self, iree_hal_driver_id_t driver_id, iree_allocator_t allocator,
+ void* self, iree_hal_driver_id_t driver_id, iree_allocator_t host_allocator,
iree_hal_driver_t** out_driver) {
if (driver_id != IREE_HAL_DYLIB_SYNC_DRIVER_ID) {
return iree_make_status(IREE_STATUS_UNAVAILABLE,
@@ -53,15 +53,24 @@
iree_hal_executable_loader_t* loaders[1] = {NULL};
if (iree_status_is_ok(status)) {
status = iree_hal_embedded_library_loader_create(
- iree_hal_executable_import_provider_null(), allocator, &loaders[0]);
+ iree_hal_executable_import_provider_null(), host_allocator,
+ &loaders[0]);
+ }
+
+ iree_hal_allocator_t* device_allocator = NULL;
+ if (iree_status_is_ok(status)) {
+ status = iree_hal_allocator_create_heap(iree_make_cstring_view("cpu"),
+ host_allocator, host_allocator,
+ &device_allocator);
}
if (iree_status_is_ok(status)) {
status = iree_hal_sync_driver_create(
iree_make_cstring_view("cpu"), &default_params, IREE_ARRAYSIZE(loaders),
- loaders, allocator, out_driver);
+ loaders, device_allocator, host_allocator, out_driver);
}
+ iree_hal_allocator_release(device_allocator);
iree_hal_executable_loader_release(loaders[0]);
return status;
}
diff --git a/iree/hal/local/elf/CMakeLists.txt b/iree/hal/local/elf/CMakeLists.txt
index debba51..ecda505 100644
--- a/iree/hal/local/elf/CMakeLists.txt
+++ b/iree/hal/local/elf/CMakeLists.txt
@@ -88,9 +88,19 @@
# TODO(*): figure out how to make this work on Bazel+Windows.
if(${MSVC})
if(CMAKE_SYSTEM_PROCESSOR MATCHES "amd64.*|x86_64.*|AMD64.*")
- target_sources(iree_hal_local_elf_arch PRIVATE "arch/x86_64_msvc.asm")
set_source_files_properties(
- arch/x86_64_msvc.asm
- PROPERTIES LANGUAGE ASM_MASM)
+ arch/x86_64_msvc.asm
+ PROPERTIES
+ LANGUAGE ASM_MASM
+ )
+ # CMake + MASM does not work well and CMake ends up passing all our C/C++
+ # flags confusing MASM. We invoke MASM directly (ml64.exe) to keep it quiet.
+ target_sources(iree_hal_local_elf_arch PRIVATE "arch/x86_64_msvc.obj")
+ add_custom_command(
+ OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/arch/x86_64_msvc.obj
+ DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/arch/x86_64_msvc.asm
+ COMMAND ml64 /nologo /Zi /c /Fo ${CMAKE_CURRENT_BINARY_DIR}/arch/x86_64_msvc.obj ${CMAKE_CURRENT_SOURCE_DIR}/arch/x86_64_msvc.asm
+ VERBATIM
+ )
endif()
endif()
diff --git a/iree/hal/local/elf/elf_module.c b/iree/hal/local/elf/elf_module.c
index 3cffeb7..61f68e9 100644
--- a/iree/hal/local/elf/elf_module.c
+++ b/iree/hal/local/elf/elf_module.c
@@ -474,7 +474,8 @@
for (iree_host_size_t i = 1; i < module->dynsym_count; ++i) {
const iree_elf_sym_t* sym = &module->dynsym[i];
if (sym->st_shndx == IREE_ELF_SHN_UNDEF) {
- const char* symname = sym->st_name ? module->dynstr + sym->st_name : NULL;
+ const char* symname IREE_ATTRIBUTE_UNUSED =
+ sym->st_name ? module->dynstr + sym->st_name : NULL;
return iree_make_status(IREE_STATUS_UNAVAILABLE,
"ELF imports one or more symbols (trying "
"'%s'); imports are not supported in the "
diff --git a/iree/hal/local/executable_library_benchmark.c b/iree/hal/local/executable_library_benchmark.c
index 1f466f4..2e3588e 100644
--- a/iree/hal/local/executable_library_benchmark.c
+++ b/iree/hal/local/executable_library_benchmark.c
@@ -263,7 +263,8 @@
// memory accessed by the invocation will come from here.
iree_hal_allocator_t* heap_allocator = NULL;
IREE_RETURN_IF_ERROR(iree_hal_allocator_create_heap(
- iree_make_cstring_view("benchmark"), host_allocator, &heap_allocator));
+ iree_make_cstring_view("benchmark"), host_allocator, host_allocator,
+ &heap_allocator));
iree_hal_buffer_view_t* buffer_views[IREE_HAL_LOCAL_MAX_TOTAL_BINDING_COUNT];
void* binding_ptrs[IREE_HAL_LOCAL_MAX_TOTAL_BINDING_COUNT];
size_t binding_lengths[IREE_HAL_LOCAL_MAX_TOTAL_BINDING_COUNT];
diff --git a/iree/hal/local/sync_device.c b/iree/hal/local/sync_device.c
index 924da83..778a802 100644
--- a/iree/hal/local/sync_device.c
+++ b/iree/hal/local/sync_device.c
@@ -53,9 +53,11 @@
iree_status_t iree_hal_sync_device_create(
iree_string_view_t identifier, const iree_hal_sync_device_params_t* params,
iree_host_size_t loader_count, iree_hal_executable_loader_t** loaders,
- iree_allocator_t host_allocator, iree_hal_device_t** out_device) {
+ iree_hal_allocator_t* device_allocator, iree_allocator_t host_allocator,
+ iree_hal_device_t** out_device) {
IREE_ASSERT_ARGUMENT(params);
IREE_ASSERT_ARGUMENT(!loader_count || loaders);
+ IREE_ASSERT_ARGUMENT(device_allocator);
IREE_ASSERT_ARGUMENT(out_device);
*out_device = NULL;
IREE_TRACE_ZONE_BEGIN(z0);
@@ -76,6 +78,8 @@
iree_string_view_append_to_buffer(identifier, &device->identifier,
(char*)device + struct_size);
device->host_allocator = host_allocator;
+ device->device_allocator = device_allocator;
+ iree_hal_allocator_retain(device_allocator);
device->loader_count = loader_count;
for (iree_host_size_t i = 0; i < device->loader_count; ++i) {
@@ -87,11 +91,6 @@
}
if (iree_status_is_ok(status)) {
- status = iree_hal_allocator_create_heap(identifier, host_allocator,
- &device->device_allocator);
- }
-
- if (iree_status_is_ok(status)) {
*out_device = (iree_hal_device_t*)device;
} else {
iree_hal_device_release((iree_hal_device_t*)device);
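
Ownership sketch for the new device_allocator parameter (matching the driver changes in this diff): the device retains the allocator, so its creator can drop its own reference right after device creation. Locals such as identifier, default_params, and loaders are assumed from the surrounding driver code.

  iree_hal_allocator_t* device_allocator = NULL;
  iree_status_t status = iree_hal_allocator_create_heap(
      identifier, host_allocator, host_allocator, &device_allocator);
  iree_hal_device_t* device = NULL;
  if (iree_status_is_ok(status)) {
    status = iree_hal_sync_device_create(identifier, &default_params,
                                         loader_count, loaders,
                                         device_allocator, host_allocator,
                                         &device);
  }
  // The device retained device_allocator; release the local reference.
  iree_hal_allocator_release(device_allocator);
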
diff --git a/iree/hal/local/sync_device.h b/iree/hal/local/sync_device.h
index 8990d35..de990b7 100644
--- a/iree/hal/local/sync_device.h
+++ b/iree/hal/local/sync_device.h
@@ -31,7 +31,8 @@
iree_status_t iree_hal_sync_device_create(
iree_string_view_t identifier, const iree_hal_sync_device_params_t* params,
iree_host_size_t loader_count, iree_hal_executable_loader_t** loaders,
- iree_allocator_t host_allocator, iree_hal_device_t** out_device);
+ iree_hal_allocator_t* device_allocator, iree_allocator_t host_allocator,
+ iree_hal_device_t** out_device);
#ifdef __cplusplus
} // extern "C"
diff --git a/iree/hal/local/sync_driver.c b/iree/hal/local/sync_driver.c
index d9349cd..c8291f2 100644
--- a/iree/hal/local/sync_driver.c
+++ b/iree/hal/local/sync_driver.c
@@ -16,6 +16,7 @@
typedef struct iree_hal_sync_driver_t {
iree_hal_resource_t resource;
iree_allocator_t host_allocator;
+ iree_hal_allocator_t* device_allocator;
iree_string_view_t identifier;
iree_hal_sync_device_params_t default_params;
@@ -36,9 +37,11 @@
iree_string_view_t identifier,
const iree_hal_sync_device_params_t* default_params,
iree_host_size_t loader_count, iree_hal_executable_loader_t** loaders,
- iree_allocator_t host_allocator, iree_hal_driver_t** out_driver) {
+ iree_hal_allocator_t* device_allocator, iree_allocator_t host_allocator,
+ iree_hal_driver_t** out_driver) {
IREE_ASSERT_ARGUMENT(default_params);
IREE_ASSERT_ARGUMENT(!loader_count || loaders);
+ IREE_ASSERT_ARGUMENT(device_allocator);
IREE_ASSERT_ARGUMENT(out_driver);
*out_driver = NULL;
IREE_TRACE_ZONE_BEGIN(z0);
@@ -53,6 +56,8 @@
iree_hal_resource_initialize(&iree_hal_sync_driver_vtable,
&driver->resource);
driver->host_allocator = host_allocator;
+ driver->device_allocator = device_allocator;
+ iree_hal_allocator_retain(device_allocator);
iree_string_view_append_to_buffer(
identifier, &driver->identifier,
@@ -81,6 +86,7 @@
iree_allocator_t host_allocator = driver->host_allocator;
IREE_TRACE_ZONE_BEGIN(z0);
+ iree_hal_allocator_release(driver->device_allocator);
for (iree_host_size_t i = 0; i < driver->loader_count; ++i) {
iree_hal_executable_loader_release(driver->loaders[i]);
}
@@ -107,11 +113,11 @@
static iree_status_t iree_hal_sync_driver_create_device(
iree_hal_driver_t* base_driver, iree_hal_device_id_t device_id,
- iree_allocator_t allocator, iree_hal_device_t** out_device) {
+ iree_allocator_t host_allocator, iree_hal_device_t** out_device) {
iree_hal_sync_driver_t* driver = iree_hal_sync_driver_cast(base_driver);
return iree_hal_sync_device_create(
driver->identifier, &driver->default_params, driver->loader_count,
- driver->loaders, allocator, out_device);
+ driver->loaders, driver->device_allocator, host_allocator, out_device);
}
static const iree_hal_driver_vtable_t iree_hal_sync_driver_vtable = {
diff --git a/iree/hal/local/sync_driver.h b/iree/hal/local/sync_driver.h
index 04de5ac..f4ff241 100644
--- a/iree/hal/local/sync_driver.h
+++ b/iree/hal/local/sync_driver.h
@@ -23,7 +23,8 @@
iree_string_view_t identifier,
const iree_hal_sync_device_params_t* default_params,
iree_host_size_t loader_count, iree_hal_executable_loader_t** loaders,
- iree_allocator_t host_allocator, iree_hal_driver_t** out_driver);
+ iree_hal_allocator_t* device_allocator, iree_allocator_t host_allocator,
+ iree_hal_driver_t** out_driver);
#ifdef __cplusplus
} // extern "C"
diff --git a/iree/hal/local/sync_semaphore.c b/iree/hal/local/sync_semaphore.c
index f982f22..c990ec1 100644
--- a/iree/hal/local/sync_semaphore.c
+++ b/iree/hal/local/sync_semaphore.c
@@ -131,7 +131,7 @@
static iree_status_t iree_hal_sync_semaphore_signal_unsafe(
iree_hal_sync_semaphore_t* semaphore, uint64_t new_value) {
if (new_value <= semaphore->current_value) {
- uint64_t current_value = semaphore->current_value;
+ uint64_t current_value IREE_ATTRIBUTE_UNUSED = semaphore->current_value;
iree_slim_mutex_unlock(&semaphore->mutex);
return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
"semaphore values must be monotonically "
diff --git a/iree/hal/local/task_device.c b/iree/hal/local/task_device.c
index 236c83e..a800457 100644
--- a/iree/hal/local/task_device.c
+++ b/iree/hal/local/task_device.c
@@ -80,10 +80,12 @@
iree_status_t iree_hal_task_device_create(
iree_string_view_t identifier, const iree_hal_task_device_params_t* params,
iree_task_executor_t* executor, iree_host_size_t loader_count,
- iree_hal_executable_loader_t** loaders, iree_allocator_t host_allocator,
+ iree_hal_executable_loader_t** loaders,
+ iree_hal_allocator_t* device_allocator, iree_allocator_t host_allocator,
iree_hal_device_t** out_device) {
IREE_ASSERT_ARGUMENT(params);
IREE_ASSERT_ARGUMENT(!loader_count || loaders);
+ IREE_ASSERT_ARGUMENT(device_allocator);
IREE_ASSERT_ARGUMENT(out_device);
*out_device = NULL;
IREE_TRACE_ZONE_BEGIN(z0);
@@ -105,6 +107,9 @@
iree_string_view_append_to_buffer(identifier, &device->identifier,
(char*)device + struct_size);
device->host_allocator = host_allocator;
+ device->device_allocator = device_allocator;
+ iree_hal_allocator_retain(device_allocator);
+
iree_arena_block_pool_initialize(4096, host_allocator,
&device->small_block_pool);
iree_arena_block_pool_initialize(params->arena_block_size, host_allocator,
@@ -134,11 +139,6 @@
}
if (iree_status_is_ok(status)) {
- status = iree_hal_allocator_create_heap(identifier, host_allocator,
- &device->device_allocator);
- }
-
- if (iree_status_is_ok(status)) {
status = iree_hal_local_event_pool_allocate(
IREE_HAL_LOCAL_TASK_EVENT_POOL_CAPACITY, host_allocator,
&device->event_pool);
diff --git a/iree/hal/local/task_device.h b/iree/hal/local/task_device.h
index 6f4c9b3..d43c1cf 100644
--- a/iree/hal/local/task_device.h
+++ b/iree/hal/local/task_device.h
@@ -40,7 +40,8 @@
iree_status_t iree_hal_task_device_create(
iree_string_view_t identifier, const iree_hal_task_device_params_t* params,
iree_task_executor_t* executor, iree_host_size_t loader_count,
- iree_hal_executable_loader_t** loaders, iree_allocator_t host_allocator,
+ iree_hal_executable_loader_t** loaders,
+ iree_hal_allocator_t* device_allocator, iree_allocator_t host_allocator,
iree_hal_device_t** out_device);
#ifdef __cplusplus
diff --git a/iree/hal/local/task_driver.c b/iree/hal/local/task_driver.c
index 7711520..49218c4 100644
--- a/iree/hal/local/task_driver.c
+++ b/iree/hal/local/task_driver.c
@@ -16,6 +16,7 @@
typedef struct iree_hal_task_driver_t {
iree_hal_resource_t resource;
iree_allocator_t host_allocator;
+ iree_hal_allocator_t* device_allocator;
iree_string_view_t identifier;
iree_hal_task_device_params_t default_params;
@@ -38,10 +39,12 @@
iree_string_view_t identifier,
const iree_hal_task_device_params_t* default_params,
iree_task_executor_t* executor, iree_host_size_t loader_count,
- iree_hal_executable_loader_t** loaders, iree_allocator_t host_allocator,
+ iree_hal_executable_loader_t** loaders,
+ iree_hal_allocator_t* device_allocator, iree_allocator_t host_allocator,
iree_hal_driver_t** out_driver) {
IREE_ASSERT_ARGUMENT(default_params);
IREE_ASSERT_ARGUMENT(!loader_count || loaders);
+ IREE_ASSERT_ARGUMENT(device_allocator);
IREE_ASSERT_ARGUMENT(out_driver);
*out_driver = NULL;
IREE_TRACE_ZONE_BEGIN(z0);
@@ -56,6 +59,8 @@
iree_hal_resource_initialize(&iree_hal_task_driver_vtable,
&driver->resource);
driver->host_allocator = host_allocator;
+ driver->device_allocator = device_allocator;
+ iree_hal_allocator_retain(device_allocator);
iree_string_view_append_to_buffer(identifier, &driver->identifier,
(char*)driver + struct_size);
@@ -86,6 +91,7 @@
iree_allocator_t host_allocator = driver->host_allocator;
IREE_TRACE_ZONE_BEGIN(z0);
+ iree_hal_allocator_release(driver->device_allocator);
for (iree_host_size_t i = 0; i < driver->loader_count; ++i) {
iree_hal_executable_loader_release(driver->loaders[i]);
}
@@ -113,11 +119,12 @@
static iree_status_t iree_hal_task_driver_create_device(
iree_hal_driver_t* base_driver, iree_hal_device_id_t device_id,
- iree_allocator_t allocator, iree_hal_device_t** out_device) {
+ iree_allocator_t host_allocator, iree_hal_device_t** out_device) {
iree_hal_task_driver_t* driver = iree_hal_task_driver_cast(base_driver);
return iree_hal_task_device_create(
driver->identifier, &driver->default_params, driver->executor,
- driver->loader_count, driver->loaders, allocator, out_device);
+ driver->loader_count, driver->loaders, driver->device_allocator,
+ host_allocator, out_device);
}
static const iree_hal_driver_vtable_t iree_hal_task_driver_vtable = {
diff --git a/iree/hal/local/task_driver.h b/iree/hal/local/task_driver.h
index 415a832..4c36d2a 100644
--- a/iree/hal/local/task_driver.h
+++ b/iree/hal/local/task_driver.h
@@ -24,7 +24,8 @@
iree_string_view_t identifier,
const iree_hal_task_device_params_t* default_params,
iree_task_executor_t* executor, iree_host_size_t loader_count,
- iree_hal_executable_loader_t** loaders, iree_allocator_t host_allocator,
+ iree_hal_executable_loader_t** loaders,
+ iree_hal_allocator_t* device_allocator, iree_allocator_t host_allocator,
iree_hal_driver_t** out_driver);
#ifdef __cplusplus
diff --git a/iree/hal/local/task_semaphore.c b/iree/hal/local/task_semaphore.c
index b4c6289..6aad400 100644
--- a/iree/hal/local/task_semaphore.c
+++ b/iree/hal/local/task_semaphore.c
@@ -238,7 +238,7 @@
iree_slim_mutex_lock(&semaphore->mutex);
if (new_value <= semaphore->current_value) {
- uint64_t current_value = semaphore->current_value;
+ uint64_t current_value IREE_ATTRIBUTE_UNUSED = semaphore->current_value;
iree_slim_mutex_unlock(&semaphore->mutex);
return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
"semaphore values must be monotonically "
diff --git a/iree/hal/string_util_test.cc b/iree/hal/string_util_test.cc
index c661520..da3a51e 100644
--- a/iree/hal/string_util_test.cc
+++ b/iree/hal/string_util_test.cc
@@ -386,9 +386,9 @@
// used.
static StatusOr<Allocator> CreateHostLocal() {
Allocator allocator;
- iree_status_t status =
- iree_hal_allocator_create_heap(iree_make_cstring_view("host_local"),
- iree_allocator_system(), &allocator);
+ iree_status_t status = iree_hal_allocator_create_heap(
+ iree_make_cstring_view("host_local"), iree_allocator_system(),
+ iree_allocator_system(), &allocator);
IREE_RETURN_IF_ERROR(std::move(status));
return std::move(allocator);
}
diff --git a/iree/hal/vmvx/registration/driver_module.c b/iree/hal/vmvx/registration/driver_module.c
index b58836a..a1a0228 100644
--- a/iree/hal/vmvx/registration/driver_module.c
+++ b/iree/hal/vmvx/registration/driver_module.c
@@ -41,7 +41,7 @@
}
static iree_status_t iree_hal_vmvx_driver_factory_try_create(
- void* self, iree_hal_driver_id_t driver_id, iree_allocator_t allocator,
+ void* self, iree_hal_driver_id_t driver_id, iree_allocator_t host_allocator,
iree_hal_driver_t** out_driver) {
if (driver_id != IREE_HAL_VMVX_DRIVER_ID) {
return iree_make_status(IREE_STATUS_UNAVAILABLE,
@@ -51,27 +51,36 @@
}
iree_vm_instance_t* instance = NULL;
- IREE_RETURN_IF_ERROR(iree_vm_instance_create(allocator, &instance));
+ IREE_RETURN_IF_ERROR(iree_vm_instance_create(host_allocator, &instance));
iree_hal_task_device_params_t default_params;
iree_hal_task_device_params_initialize(&default_params);
iree_hal_executable_loader_t* vmvx_loader = NULL;
- iree_status_t status =
- iree_hal_vmvx_module_loader_create(instance, allocator, &vmvx_loader);
+ iree_status_t status = iree_hal_vmvx_module_loader_create(
+ instance, host_allocator, &vmvx_loader);
iree_hal_executable_loader_t* loaders[1] = {vmvx_loader};
iree_task_executor_t* executor = NULL;
if (iree_status_is_ok(status)) {
- status = iree_task_executor_create_from_flags(allocator, &executor);
+ status = iree_task_executor_create_from_flags(host_allocator, &executor);
+ }
+
+ iree_hal_allocator_t* device_allocator = NULL;
+ if (iree_status_is_ok(status)) {
+ status = iree_hal_allocator_create_heap(iree_make_cstring_view("vmvx"),
+ host_allocator, host_allocator,
+ &device_allocator);
}
if (iree_status_is_ok(status)) {
status = iree_hal_task_driver_create(
iree_make_cstring_view("vmvx"), &default_params, executor,
- IREE_ARRAYSIZE(loaders), loaders, allocator, out_driver);
+ IREE_ARRAYSIZE(loaders), loaders, device_allocator, host_allocator,
+ out_driver);
}
+ iree_hal_allocator_release(device_allocator);
iree_task_executor_release(executor);
iree_hal_executable_loader_release(vmvx_loader);
iree_vm_instance_release(instance);
diff --git a/iree/hal/vmvx/registration/driver_module_sync.c b/iree/hal/vmvx/registration/driver_module_sync.c
index 5ab7e0b..6a5fc70 100644
--- a/iree/hal/vmvx/registration/driver_module_sync.c
+++ b/iree/hal/vmvx/registration/driver_module_sync.c
@@ -41,7 +41,7 @@
}
static iree_status_t iree_hal_vmvx_sync_driver_factory_try_create(
- void* self, iree_hal_driver_id_t driver_id, iree_allocator_t allocator,
+ void* self, iree_hal_driver_id_t driver_id, iree_allocator_t host_allocator,
iree_hal_driver_t** out_driver) {
if (driver_id != IREE_HAL_VMVX_SYNC_DRIVER_ID) {
return iree_make_status(IREE_STATUS_UNAVAILABLE,
@@ -51,21 +51,31 @@
}
iree_vm_instance_t* instance = NULL;
- IREE_RETURN_IF_ERROR(iree_vm_instance_create(allocator, &instance));
+ IREE_RETURN_IF_ERROR(iree_vm_instance_create(host_allocator, &instance));
iree_hal_executable_loader_t* vmvx_loader = NULL;
- iree_status_t status =
- iree_hal_vmvx_module_loader_create(instance, allocator, &vmvx_loader);
+ iree_status_t status = iree_hal_vmvx_module_loader_create(
+ instance, host_allocator, &vmvx_loader);
iree_hal_executable_loader_t* loaders[1] = {vmvx_loader};
+ iree_hal_allocator_t* device_allocator = NULL;
+ if (iree_status_is_ok(status)) {
+ status = iree_hal_allocator_create_heap(iree_make_cstring_view("vmvx"),
+ host_allocator, host_allocator,
+ &device_allocator);
+ }
+
// Set parameters for the device created in the next step.
iree_hal_sync_device_params_t default_params;
iree_hal_sync_device_params_initialize(&default_params);
if (iree_status_is_ok(status)) {
status = iree_hal_sync_driver_create(
iree_make_cstring_view("vmvx"), &default_params,
- IREE_ARRAYSIZE(loaders), loaders, allocator, out_driver);
+ IREE_ARRAYSIZE(loaders), loaders, device_allocator, host_allocator,
+ out_driver);
}
+
+ iree_hal_allocator_release(device_allocator);
iree_hal_executable_loader_release(vmvx_loader);
iree_vm_instance_release(instance);
return status;
diff --git a/iree/modules/check/test/success.mlir b/iree/modules/check/test/success.mlir
index 236c83f..9c9650f 100644
--- a/iree/modules/check/test/success.mlir
+++ b/iree/modules/check/test/success.mlir
@@ -15,7 +15,7 @@
func @expect_all_true() {
%all_true = util.unfoldable_constant dense<1> : tensor<2x2xi32>
- %all_true_view = hal.tensor.cast %all_true : tensor<2x2xi32> -> !hal.buffer_view
+ %all_true_view = hal.tensor.export %all_true : tensor<2x2xi32> -> !hal.buffer_view
check.expect_all_true(%all_true_view) : !hal.buffer_view
return
}
diff --git a/iree/modules/hal/module.c b/iree/modules/hal/module.c
index 1faa691..72a0bae 100644
--- a/iree/modules/hal/module.c
+++ b/iree/modules/hal/module.c
@@ -414,7 +414,8 @@
IREE_RETURN_IF_ERROR(iree_hal_buffer_check_deref(args->r0, &buffer));
iree_vm_buffer_t* message = NULL;
IREE_RETURN_IF_ERROR(iree_vm_buffer_check_deref(args->r1, &message));
- iree_string_view_t message_str = iree_vm_buffer_as_string(message);
+ iree_string_view_t message_str IREE_ATTRIBUTE_UNUSED =
+ iree_vm_buffer_as_string(message);
iree_hal_allocator_t* allocator = NULL;
IREE_RETURN_IF_ERROR(iree_hal_allocator_check_deref(args->r2, &allocator));
iree_vm_size_t minimum_length = (iree_vm_size_t)args->i3;
@@ -621,7 +622,8 @@
iree_hal_buffer_view_check_deref(args->r0, &buffer_view));
iree_vm_buffer_t* message = NULL;
IREE_RETURN_IF_ERROR(iree_vm_buffer_check_deref(args->r1, &message));
- iree_string_view_t message_str = iree_vm_buffer_as_string(message);
+ iree_string_view_t message_str IREE_ATTRIBUTE_UNUSED =
+ iree_vm_buffer_as_string(message);
iree_hal_element_type_t expected_element_type =
(iree_hal_element_type_t)args->i2;
iree_hal_encoding_type_t expected_encoding_type =
diff --git a/iree/samples/dynamic_shapes/dynamic_shapes.ipynb b/iree/samples/dynamic_shapes/dynamic_shapes.ipynb
index 29a85af..8b1b6da 100644
--- a/iree/samples/dynamic_shapes/dynamic_shapes.ipynb
+++ b/iree/samples/dynamic_shapes/dynamic_shapes.ipynb
@@ -191,10 +191,10 @@
" func @add_one(%arg0: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi = \"{\\22a\\22:[[\\22ndarray\\22,\\22i32\\22,1,null]],\\22r\\22:[[\\22ndarray\\22,\\22i32\\22,1,null]],\\22v\\22:1}\"} {\n",
" %c0 = arith.constant 0 : index\n",
" %0 = hal.buffer_view.dim %arg0, 0 : index\n",
- " %1 = hal.tensor.cast %arg0 : !hal.buffer_view -> tensor<?xi32>{%0}\n",
+ " %1 = hal.tensor.import %arg0 : !hal.buffer_view -> tensor<?xi32>{%0}\n",
" %2 = call @__inference_add_one_70(%1) : (tensor<?xi32>) -> tensor<?xi32>\n",
" %3 = memref.dim %2, %c0 : tensor<?xi32>\n",
- " %4 = hal.tensor.cast %2 : tensor<?xi32>{%3} -> !hal.buffer_view\n",
+ " %4 = hal.tensor.export %2 : tensor<?xi32>{%3} -> !hal.buffer_view\n",
" return %4 : !hal.buffer_view\n",
" }\n",
" func private @__inference_add_one_70(%arg0: tensor<?xi32> {tf._user_specified_name = \"values\"}) -> tensor<?xi32> attributes {tf._construction_context = \"kEagerRuntime\", tf._input_shapes = [#tf.shape<?>]} {\n",
@@ -218,7 +218,7 @@
" func @reduce_sum_1d(%arg0: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi = \"{\\22a\\22:[[\\22ndarray\\22,\\22i32\\22,1,null]],\\22r\\22:[[\\22ndarray\\22,\\22i32\\22,0]],\\22v\\22:1}\"} {\n",
" %c0_i32 = arith.constant 0 : i32\n",
" %0 = hal.buffer_view.dim %arg0, 0 : index\n",
- " %1 = hal.tensor.cast %arg0 : !hal.buffer_view -> tensor<?xi32>{%0}\n",
+ " %1 = hal.tensor.import %arg0 : !hal.buffer_view -> tensor<?xi32>{%0}\n",
" %2 = linalg.init_tensor [] : tensor<i32>\n",
" %3 = linalg.fill(%2, %c0_i32) : tensor<i32>, i32 -> tensor<i32> \n",
" %4 = linalg.generic {indexing_maps = [#map1, #map0], iterator_types = [\"reduction\"]} ins(%1 : tensor<?xi32>) outs(%3 : tensor<i32>) {\n",
@@ -226,13 +226,13 @@
" %6 = arith.addi %arg1, %arg2 : i32\n",
" linalg.yield %6 : i32\n",
" } -> tensor<i32>\n",
- " %5 = hal.tensor.cast %4 : tensor<i32> -> !hal.buffer_view\n",
+ " %5 = hal.tensor.export %4 : tensor<i32> -> !hal.buffer_view\n",
" return %5 : !hal.buffer_view\n",
" }\n",
" func @reduce_sum_2d(%arg0: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi = \"{\\22a\\22:[[\\22ndarray\\22,\\22i32\\22,2,null,3]],\\22r\\22:[[\\22ndarray\\22,\\22i32\\22,1,3]],\\22v\\22:1}\"} {\n",
" %c0_i32 = arith.constant 0 : i32\n",
" %0 = hal.buffer_view.dim %arg0, 0 : index\n",
- " %1 = hal.tensor.cast %arg0 : !hal.buffer_view -> tensor<?x3xi32>{%0}\n",
+ " %1 = hal.tensor.import %arg0 : !hal.buffer_view -> tensor<?x3xi32>{%0}\n",
" %2 = linalg.init_tensor [3] : tensor<3xi32>\n",
" %3 = linalg.fill(%2, %c0_i32) : tensor<3xi32>, i32 -> tensor<3xi32> \n",
" %4 = linalg.generic {indexing_maps = [#map2, #map3], iterator_types = [\"parallel\", \"reduction\"]} ins(%1 : tensor<?x3xi32>) outs(%3 : tensor<3xi32>) {\n",
@@ -240,7 +240,7 @@
" %6 = arith.addi %arg1, %arg2 : i32\n",
" linalg.yield %6 : i32\n",
" } -> tensor<3xi32>\n",
- " %5 = hal.tensor.cast %4 : tensor<3xi32> -> !hal.buffer_view\n",
+ " %5 = hal.tensor.export %4 : tensor<3xi32> -> !hal.buffer_view\n",
" return %5 : !hal.buffer_view\n",
" }\n",
"}\n",
@@ -414,4 +414,4 @@
"outputs": []
}
]
-}
\ No newline at end of file
+}
diff --git a/iree/samples/simple_embedding/device_cuda.c b/iree/samples/simple_embedding/device_cuda.c
index e8cf2f4..06f439d 100644
--- a/iree/samples/simple_embedding/device_cuda.c
+++ b/iree/samples/simple_embedding/device_cuda.c
@@ -15,18 +15,24 @@
// Compiled module embedded here to avoid file IO:
#include "iree/samples/simple_embedding/simple_embedding_test_bytecode_module_cuda_c.h"
-iree_status_t create_sample_device(iree_hal_device_t** device) {
- // Only register the cuda HAL driver.
+iree_status_t create_sample_device(iree_allocator_t host_allocator,
+ iree_hal_device_t** out_device) {
+ // Only register the CUDA HAL driver.
IREE_RETURN_IF_ERROR(
iree_hal_cuda_driver_module_register(iree_hal_driver_registry_default()));
- // Create the hal driver from the name.
+
+ // Create the HAL driver from the name.
iree_hal_driver_t* driver = NULL;
iree_string_view_t identifier = iree_make_cstring_view("cuda");
- IREE_RETURN_IF_ERROR(iree_hal_driver_registry_try_create_by_name(
- iree_hal_driver_registry_default(), identifier, iree_allocator_system(),
- &driver));
- IREE_RETURN_IF_ERROR(iree_hal_driver_create_default_device(
- driver, iree_allocator_system(), device));
+ iree_status_t status = iree_hal_driver_registry_try_create_by_name(
+ iree_hal_driver_registry_default(), identifier, host_allocator, &driver);
+
+ // Create the default device (primary GPU).
+ if (iree_status_is_ok(status)) {
+ status = iree_hal_driver_create_default_device(driver, host_allocator,
+ out_device);
+ }
+
iree_hal_driver_release(driver);
- return iree_ok_status();
+ return status;
}
diff --git a/iree/samples/simple_embedding/device_dylib.c b/iree/samples/simple_embedding/device_dylib.c
index ffdabcf..e8ab727 100644
--- a/iree/samples/simple_embedding/device_dylib.c
+++ b/iree/samples/simple_embedding/device_dylib.c
@@ -20,27 +20,37 @@
#include "iree/samples/simple_embedding/simple_embedding_test_bytecode_module_dylib_riscv_64_c.h"
#include "iree/samples/simple_embedding/simple_embedding_test_bytecode_module_dylib_x86_64_c.h"
-iree_status_t create_sample_device(iree_hal_device_t** device) {
+iree_status_t create_sample_device(iree_allocator_t host_allocator,
+ iree_hal_device_t** out_device) {
// Set parameters for the device created in the next step.
iree_hal_task_device_params_t params;
iree_hal_task_device_params_initialize(&params);
iree_hal_executable_loader_t* loader = NULL;
IREE_RETURN_IF_ERROR(iree_hal_embedded_library_loader_create(
- iree_hal_executable_import_provider_null(), iree_allocator_system(),
- &loader));
+ iree_hal_executable_import_provider_null(), host_allocator, &loader));
iree_task_executor_t* executor = NULL;
iree_status_t status =
- iree_task_executor_create_from_flags(iree_allocator_system(), &executor);
+ iree_task_executor_create_from_flags(host_allocator, &executor);
+ // Use the default host allocator for buffer allocations.
iree_string_view_t identifier = iree_make_cstring_view("dylib");
+ iree_hal_allocator_t* device_allocator = NULL;
if (iree_status_is_ok(status)) {
- // Create the device.
+ status = iree_hal_allocator_create_heap(identifier, host_allocator,
+ host_allocator, &device_allocator);
+ }
+
+ // Create the device.
+ if (iree_status_is_ok(status)) {
status = iree_hal_task_device_create(identifier, &params, executor,
/*loader_count=*/1, &loader,
- iree_allocator_system(), device);
+ device_allocator, host_allocator,
+ out_device);
}
+
+ iree_hal_allocator_release(device_allocator);
iree_task_executor_release(executor);
iree_hal_executable_loader_release(loader);
return status;
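The dylib sample above and the embedded-sync, VMVX, and static-library samples below all gain the same step: a heap allocator is created from the caller-provided host allocator and handed to the device as its buffer allocator, with the local reference dropped once the device holds its own. A minimal compilable sketch of that shared pattern, assuming only the public iree/base and iree/hal headers; the my_ prefix and the "sample" identifier are illustrative, not names used by the samples:

#include "iree/base/api.h"
#include "iree/hal/api.h"

// Sketch: build a heap-backed device allocator from a host allocator, as the
// updated samples do before calling the iree_hal_*_device_create() functions.
// Both allocator arguments mirror the samples, which pass the same host
// allocator twice.
static iree_status_t my_create_device_allocator(
    iree_allocator_t host_allocator,
    iree_hal_allocator_t** out_device_allocator) {
  iree_string_view_t identifier = iree_make_cstring_view("sample");
  return iree_hal_allocator_create_heap(identifier, host_allocator,
                                        host_allocator, out_device_allocator);
}

Each sample releases its local reference with iree_hal_allocator_release(device_allocator) right after device creation, since the created device keeps its own reference.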
diff --git a/iree/samples/simple_embedding/device_embedded_sync.c b/iree/samples/simple_embedding/device_embedded_sync.c
index 02cc6db..edd6565 100644
--- a/iree/samples/simple_embedding/device_embedded_sync.c
+++ b/iree/samples/simple_embedding/device_embedded_sync.c
@@ -27,22 +27,30 @@
#include "iree/samples/simple_embedding/simple_embedding_test_bytecode_module_dylib_x86_64_c.h"
#endif
-iree_status_t create_sample_device(iree_hal_device_t** device) {
+iree_status_t create_sample_device(iree_allocator_t host_allocator,
+ iree_hal_device_t** out_device) {
// Set parameters for the device created in the next step.
iree_hal_sync_device_params_t params;
iree_hal_sync_device_params_initialize(&params);
iree_hal_executable_loader_t* loader = NULL;
IREE_RETURN_IF_ERROR(iree_hal_embedded_library_loader_create(
- iree_hal_executable_import_provider_null(), iree_allocator_system(),
- &loader));
+ iree_hal_executable_import_provider_null(), host_allocator, &loader));
+ // Use the default host allocator for buffer allocations.
iree_string_view_t identifier = iree_make_cstring_view("dylib");
+ iree_hal_allocator_t* device_allocator = NULL;
+ iree_status_t status = iree_hal_allocator_create_heap(
+ identifier, host_allocator, host_allocator, &device_allocator);
// Create the synchronous device and release the loader afterwards.
- iree_status_t status =
- iree_hal_sync_device_create(identifier, &params, /*loader_count=*/1,
- &loader, iree_allocator_system(), device);
+ if (iree_status_is_ok(status)) {
+ status = iree_hal_sync_device_create(
+ identifier, &params, /*loader_count=*/1, &loader, device_allocator,
+ host_allocator, out_device);
+ }
+
+ iree_hal_allocator_release(device_allocator);
iree_hal_executable_loader_release(loader);
return status;
}
diff --git a/iree/samples/simple_embedding/device_vmvx_sync.c b/iree/samples/simple_embedding/device_vmvx_sync.c
index a41fb9d..44f2a72 100644
--- a/iree/samples/simple_embedding/device_vmvx_sync.c
+++ b/iree/samples/simple_embedding/device_vmvx_sync.c
@@ -17,27 +17,36 @@
// Compiled module embedded here to avoid file IO:
#include "iree/samples/simple_embedding/simple_embedding_test_bytecode_module_vmvx_c.h"
-iree_status_t create_sample_device(iree_hal_device_t** device) {
+iree_status_t create_sample_device(iree_allocator_t host_allocator,
+ iree_hal_device_t** out_device) {
// Set parameters for the device created in the next step.
iree_hal_sync_device_params_t params;
iree_hal_sync_device_params_initialize(&params);
iree_vm_instance_t* instance = NULL;
- IREE_RETURN_IF_ERROR(
- iree_vm_instance_create(iree_allocator_system(), &instance));
+ IREE_RETURN_IF_ERROR(iree_vm_instance_create(host_allocator, &instance));
iree_hal_executable_loader_t* loader = NULL;
- iree_status_t status = iree_hal_vmvx_module_loader_create(
- instance, iree_allocator_system(), &loader);
+ iree_status_t status =
+ iree_hal_vmvx_module_loader_create(instance, host_allocator, &loader);
iree_vm_instance_release(instance);
+ // Use the default host allocator for buffer allocations.
iree_string_view_t identifier = iree_make_cstring_view("vmvx");
+ iree_hal_allocator_t* device_allocator = NULL;
+ if (iree_status_is_ok(status)) {
+ status = iree_hal_allocator_create_heap(identifier, host_allocator,
+ host_allocator, &device_allocator);
+ }
+
if (iree_status_is_ok(status)) {
// Create the synchronous device.
- status =
- iree_hal_sync_device_create(identifier, &params, /*loader_count=*/1,
- &loader, iree_allocator_system(), device);
+ status = iree_hal_sync_device_create(
+ identifier, &params, /*loader_count=*/1, &loader, device_allocator,
+ host_allocator, out_device);
}
+
+ iree_hal_allocator_release(device_allocator);
iree_hal_executable_loader_release(loader);
return status;
}
diff --git a/iree/samples/simple_embedding/device_vulkan.c b/iree/samples/simple_embedding/device_vulkan.c
index cc73a44..c77e164 100644
--- a/iree/samples/simple_embedding/device_vulkan.c
+++ b/iree/samples/simple_embedding/device_vulkan.c
@@ -15,18 +15,24 @@
// Compiled module embedded here to avoid file IO:
#include "iree/samples/simple_embedding/simple_embedding_test_bytecode_module_vulkan_c.h"
-iree_status_t create_sample_device(iree_hal_device_t** device) {
- // Only register the vulkan HAL driver.
+iree_status_t create_sample_device(iree_allocator_t host_allocator,
+ iree_hal_device_t** out_device) {
+ // Only register the Vulkan HAL driver.
IREE_RETURN_IF_ERROR(iree_hal_vulkan_driver_module_register(
iree_hal_driver_registry_default()));
- // Create the hal driver from the name.
+
+ // Create the HAL driver from the name.
iree_hal_driver_t* driver = NULL;
iree_string_view_t identifier = iree_make_cstring_view("vulkan");
- IREE_RETURN_IF_ERROR(iree_hal_driver_registry_try_create_by_name(
- iree_hal_driver_registry_default(), identifier, iree_allocator_system(),
- &driver));
- IREE_RETURN_IF_ERROR(iree_hal_driver_create_default_device(
- driver, iree_allocator_system(), device));
+ iree_status_t status = iree_hal_driver_registry_try_create_by_name(
+ iree_hal_driver_registry_default(), identifier, host_allocator, &driver);
+
+ // Create the default device (primary GPU).
+ if (iree_status_is_ok(status)) {
+ status = iree_hal_driver_create_default_device(driver, host_allocator,
+ out_device);
+ }
+
iree_hal_driver_release(driver);
- return iree_ok_status();
+ return status;
}
diff --git a/iree/samples/simple_embedding/simple_embedding.c b/iree/samples/simple_embedding/simple_embedding.c
index b9ce44d..93b7a1f 100644
--- a/iree/samples/simple_embedding/simple_embedding.c
+++ b/iree/samples/simple_embedding/simple_embedding.c
@@ -19,7 +19,8 @@
// A function to create the HAL device from the different backend targets.
// The HAL device is returned based on the implementation, and it must be
// released by the caller.
-extern iree_status_t create_sample_device(iree_hal_device_t** device);
+extern iree_status_t create_sample_device(iree_allocator_t host_allocator,
+ iree_hal_device_t** out_device);
// A function to load the vm bytecode module from the different backend targets.
// The bytecode module is generated for the specific backend and platform.
@@ -34,7 +35,8 @@
iree_vm_instance_create(iree_allocator_system(), &instance));
iree_hal_device_t* device = NULL;
- IREE_RETURN_IF_ERROR(create_sample_device(&device), "create device");
+ IREE_RETURN_IF_ERROR(create_sample_device(iree_allocator_system(), &device),
+ "create device");
iree_vm_module_t* hal_module = NULL;
IREE_RETURN_IF_ERROR(
iree_hal_module_create(device, iree_allocator_system(), &hal_module));
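The comment on create_sample_device() in simple_embedding.c spells out the contract: it hands back a retained device that the caller must release, and it now takes the host allocator explicitly instead of assuming iree_allocator_system(). A hedged caller sketch under that contract, using only APIs visible in this diff plus iree_hal_device_release(); my_run_once is an illustrative name:

#include "iree/base/api.h"
#include "iree/hal/api.h"

// Provided by whichever device_*.c backend file is linked in (see above).
extern iree_status_t create_sample_device(iree_allocator_t host_allocator,
                                          iree_hal_device_t** out_device);

// Illustrative caller: pass the system allocator as the host allocator and
// release the device when done, per the ownership comment in the sample.
iree_status_t my_run_once(void) {
  iree_hal_device_t* device = NULL;
  IREE_RETURN_IF_ERROR(
      create_sample_device(iree_allocator_system(), &device));
  // ... create the HAL module and run work against the device here ...
  iree_hal_device_release(device);
  return iree_ok_status();
}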
diff --git a/iree/samples/static_library/static_library_demo.c b/iree/samples/static_library/static_library_demo.c
index 6b9e36e..637d01b 100644
--- a/iree/samples/static_library/static_library_demo.c
+++ b/iree/samples/static_library/static_library_demo.c
@@ -23,7 +23,8 @@
// A function to create the HAL device from the different backend targets.
// The HAL device is returned based on the implementation, and it must be
// released by the caller.
-iree_status_t create_device_with_static_loader(iree_hal_device_t** device) {
+iree_status_t create_device_with_static_loader(iree_allocator_t host_allocator,
+ iree_hal_device_t** out_device) {
iree_status_t status = iree_ok_status();
// Set parameters for the device created in the next step.
@@ -40,18 +41,27 @@
if (iree_status_is_ok(status)) {
status = iree_hal_static_library_loader_create(
IREE_ARRAYSIZE(libraries), libraries,
- iree_hal_executable_import_provider_null(), iree_allocator_system(),
+ iree_hal_executable_import_provider_null(), host_allocator,
&library_loader);
}
+ // Use the default host allocator for buffer allocations.
+ iree_string_view_t identifier = iree_make_cstring_view("sync");
+ iree_hal_allocator_t* device_allocator = NULL;
+ if (iree_status_is_ok(status)) {
+ status = iree_hal_allocator_create_heap(identifier, host_allocator,
+ host_allocator, &device_allocator);
+ }
+
// Create the device and release the executor and loader afterwards.
if (iree_status_is_ok(status)) {
status = iree_hal_sync_device_create(
- iree_make_cstring_view("dylib"), &params, /*loader_count=*/1,
- &library_loader, iree_allocator_system(), device);
+ identifier, &params, /*loader_count=*/1, &library_loader,
+ device_allocator, host_allocator, out_device);
}
- iree_hal_executable_loader_release(library_loader);
+ iree_hal_allocator_release(device_allocator);
+ iree_hal_executable_loader_release(library_loader);
return status;
}
@@ -73,7 +83,7 @@
// Create dylib device with static loader.
iree_hal_device_t* device = NULL;
if (iree_status_is_ok(status)) {
- status = create_device_with_static_loader(&device);
+ status = create_device_with_static_loader(iree_allocator_system(), &device);
}
// Session configuration (one per loaded module to hold module state).
diff --git a/iree/samples/variables_and_state/main.c b/iree/samples/variables_and_state/main.c
index 5abb8a0..302ff4e 100644
--- a/iree/samples/variables_and_state/main.c
+++ b/iree/samples/variables_and_state/main.c
@@ -46,7 +46,6 @@
session, iree_make_cstring_view("module.set_value"), &call));
iree_hal_buffer_view_t* arg0 = NULL;
- static const iree_hal_dim_t arg0_shape[1] = {1};
int arg0_data[1] = {new_value};
// TODO(scotttodd): use iree_hal_buffer_view_wrap_or_clone_heap_buffer
@@ -54,8 +53,8 @@
iree_status_t status = iree_ok_status();
if (iree_status_is_ok(status)) {
status = iree_hal_buffer_view_clone_heap_buffer(
- iree_runtime_session_device_allocator(session), arg0_shape,
- IREE_ARRAYSIZE(arg0_shape), IREE_HAL_ELEMENT_TYPE_SINT_32,
+ iree_runtime_session_device_allocator(session), /*shape=*/NULL,
+ /*shape_rank=*/0, IREE_HAL_ELEMENT_TYPE_SINT_32,
IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR,
IREE_HAL_MEMORY_TYPE_HOST_LOCAL | IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE,
IREE_HAL_BUFFER_USAGE_ALL,
@@ -79,7 +78,6 @@
session, iree_make_cstring_view("module.add_to_value"), &call));
iree_hal_buffer_view_t* arg0 = NULL;
- static const iree_hal_dim_t arg0_shape[1] = {1};
int arg0_data[1] = {x};
// TODO(scotttodd): use iree_hal_buffer_view_wrap_or_clone_heap_buffer
@@ -87,8 +85,8 @@
iree_status_t status = iree_ok_status();
if (iree_status_is_ok(status)) {
status = iree_hal_buffer_view_clone_heap_buffer(
- iree_runtime_session_device_allocator(session), arg0_shape,
- IREE_ARRAYSIZE(arg0_shape), IREE_HAL_ELEMENT_TYPE_SINT_32,
+ iree_runtime_session_device_allocator(session), /*shape=*/NULL,
+ /*shape_rank=*/0, IREE_HAL_ELEMENT_TYPE_SINT_32,
IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR,
IREE_HAL_MEMORY_TYPE_HOST_LOCAL | IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE,
IREE_HAL_BUFFER_USAGE_ALL,
diff --git a/iree/samples/variables_and_state/variables_and_state.ipynb b/iree/samples/variables_and_state/variables_and_state.ipynb
index bacfc57..733f9f6 100644
--- a/iree/samples/variables_and_state/variables_and_state.ipynb
+++ b/iree/samples/variables_and_state/variables_and_state.ipynb
@@ -192,7 +192,7 @@
"module {\n",
" util.global private mutable @counter = dense<0> : tensor<i32>\n",
" func @add_to_value(%arg0: !hal.buffer_view) attributes {iree.abi = \"{\\22a\\22:[[\\22ndarray\\22,\\22i32\\22,0]],\\22r\\22:[],\\22v\\22:1}\"} {\n",
- " %0 = hal.tensor.cast %arg0 : !hal.buffer_view -> tensor<i32>\n",
+ " %0 = hal.tensor.import %arg0 : !hal.buffer_view -> tensor<i32>\n",
" call @__inference_add_to_value_100(%0) : (tensor<i32>) -> ()\n",
" return\n",
" }\n",
@@ -209,7 +209,7 @@
" }\n",
" func @get_value() -> !hal.buffer_view attributes {iree.abi = \"{\\22a\\22:[],\\22r\\22:[[\\22ndarray\\22,\\22i32\\22,0]],\\22v\\22:1}\"} {\n",
" %0 = util.global.load @counter : tensor<i32>\n",
- " %1 = hal.tensor.cast %0 : tensor<i32> -> !hal.buffer_view\n",
+ " %1 = hal.tensor.export %0 : tensor<i32> -> !hal.buffer_view\n",
" return %1 : !hal.buffer_view\n",
" }\n",
" func @reset_value() attributes {iree.abi = \"{\\22a\\22:[],\\22r\\22:[],\\22v\\22:1}\"} {\n",
@@ -218,7 +218,7 @@
" return\n",
" }\n",
" func @set_value(%arg0: !hal.buffer_view) attributes {iree.abi = \"{\\22a\\22:[[\\22ndarray\\22,\\22i32\\22,0]],\\22r\\22:[],\\22v\\22:1}\"} {\n",
- " %0 = hal.tensor.cast %arg0 : !hal.buffer_view -> tensor<i32>\n",
+ " %0 = hal.tensor.import %arg0 : !hal.buffer_view -> tensor<i32>\n",
" util.global.store %0, @counter : tensor<i32>\n",
" return\n",
" }\n",
diff --git a/iree/tools/BUILD b/iree/tools/BUILD
index 039fc0c..4be1159 100644
--- a/iree/tools/BUILD
+++ b/iree/tools/BUILD
@@ -101,15 +101,13 @@
"//iree/compiler/Bindings/Native/Transforms",
"//iree/compiler/Bindings/TFLite/Transforms",
"//iree/compiler/Codegen/Dialect:IREECodegenDialect",
- "//iree/compiler/Codegen/Dialect:ProcessorOpInterfaces",
+ "//iree/compiler/Codegen/Interfaces",
"//iree/compiler/Dialect/Flow/IR",
"//iree/compiler/Dialect/Flow/Transforms",
"//iree/compiler/Dialect/HAL/IR:HALDialect",
"//iree/compiler/Dialect/HAL/Transforms",
"//iree/compiler/Dialect/Modules/VMVX/IR:VMVXDialect",
"//iree/compiler/Dialect/Modules/VMVX/Transforms",
- "//iree/compiler/Dialect/Shape/IR",
- "//iree/compiler/Dialect/Shape/Transforms",
"//iree/compiler/Dialect/Stream/IR",
"//iree/compiler/Dialect/Stream/Transforms",
"//iree/compiler/Dialect/Util/IR",
diff --git a/iree/tools/CMakeLists.txt b/iree/tools/CMakeLists.txt
index 0dbda94..66be56d 100644
--- a/iree/tools/CMakeLists.txt
+++ b/iree/tools/CMakeLists.txt
@@ -217,15 +217,13 @@
iree::compiler::Bindings::Native::Transforms
iree::compiler::Bindings::TFLite::Transforms
iree::compiler::Codegen::Dialect::IREECodegenDialect
- iree::compiler::Codegen::Dialect::ProcessorOpInterfaces
+ iree::compiler::Codegen::Interfaces::Interfaces
iree::compiler::Dialect::Flow::IR
iree::compiler::Dialect::Flow::Transforms
iree::compiler::Dialect::HAL::IR::HALDialect
iree::compiler::Dialect::HAL::Transforms
iree::compiler::Dialect::Modules::VMVX::IR::VMVXDialect
iree::compiler::Dialect::Modules::VMVX::Transforms
- iree::compiler::Dialect::Shape::IR
- iree::compiler::Dialect::Shape::Transforms
iree::compiler::Dialect::Stream::IR
iree::compiler::Dialect::Stream::Transforms
iree::compiler::Dialect::Util::IR
@@ -319,7 +317,9 @@
HDRS
"init_xla_dialects.h"
DEPS
- tensorflow::mlir_hlo
+ tensorflow::external_mhlo_includes
+ ChloDialect
+ MhloDialect
PUBLIC
)
diff --git a/iree/tools/init_iree_dialects.h b/iree/tools/init_iree_dialects.h
index 07c0221..e4c8277 100644
--- a/iree/tools/init_iree_dialects.h
+++ b/iree/tools/init_iree_dialects.h
@@ -17,11 +17,10 @@
#include "iree-dialects/Dialect/LinalgExt/IR/TiledOpInterface.h"
#include "iree-dialects/Dialect/PyDM/IR/PyDMDialect.h"
#include "iree/compiler/Codegen/Dialect/IREECodegenDialect.h"
-#include "iree/compiler/Codegen/Dialect/ProcessorOpInterfaces.h"
+#include "iree/compiler/Codegen/Interfaces/Interfaces.h"
#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
#include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
#include "iree/compiler/Dialect/Modules/VMVX/IR/VMVXDialect.h"
-#include "iree/compiler/Dialect/Shape/IR/ShapeDialect.h"
#include "iree/compiler/Dialect/Stream/IR/StreamDialect.h"
#include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
#include "iree/compiler/Dialect/VM/IR/VMDialect.h"
@@ -38,7 +37,6 @@
IREE::Flow::FlowDialect,
IREE::HAL::HALDialect,
IREE::LinalgExt::IREELinalgExtDialect,
- ShapeDialect,
IREE::Stream::StreamDialect,
IREE::Util::UtilDialect,
IREE::VM::VMDialect,
@@ -49,7 +47,7 @@
// clang-format on
IREE::LinalgExt::registerTiledOpInterfaceExternalModels(registry);
- registerProcessorOpInterfaceExternalModels(registry);
+ registerCodegenInterfaces(registry);
}
} // namespace iree_compiler
diff --git a/iree/tools/init_iree_passes.h b/iree/tools/init_iree_passes.h
index fc1e15e..fd26b3e 100644
--- a/iree/tools/init_iree_passes.h
+++ b/iree/tools/init_iree_passes.h
@@ -20,7 +20,6 @@
#include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
#include "iree/compiler/Dialect/HAL/Transforms/Passes.h"
#include "iree/compiler/Dialect/Modules/VMVX/Transforms/Passes.h"
-#include "iree/compiler/Dialect/Shape/Transforms/Passes.h"
#include "iree/compiler/Dialect/Stream/Transforms/Passes.h"
#include "iree/compiler/Dialect/Util/Transforms/Passes.h"
#include "iree/compiler/Dialect/VM/Analysis/TestPasses.h"
@@ -42,10 +41,9 @@
IREE::TFLite::registerTransformPassPipeline();
registerCommonInputConversionPasses();
- registerMHLOConversionPasses();
+ MHLO::registerMHLOConversionPasses();
registerTOSAConversionPasses();
- Shape::registerShapePasses();
IREE::Flow::registerFlowPasses();
IREE::HAL::registerHALPasses();
IREE::LinalgExt::registerPasses();
diff --git a/iree/tools/utils/yaml_util.c b/iree/tools/utils/yaml_util.c
index f721d00..a3092a5 100644
--- a/iree/tools/utils/yaml_util.c
+++ b/iree/tools/utils/yaml_util.c
@@ -99,7 +99,7 @@
size_t decoded_length = 0;
size_t i = 0;
while (i < source.size) {
- uint8_t c = iree_yaml_base64_decode_table[source.data[i++]];
+ uint8_t c = iree_yaml_base64_decode_table[(uint8_t)source.data[i++]];
if (c == IREE_YAML_BASE64_WHITESPACE) {
// Skip whitespace.
continue;
@@ -138,7 +138,7 @@
size_t i = 0;
uint8_t* p = target.data;
while (i < source.size) {
- uint8_t c = iree_yaml_base64_decode_table[source.data[i++]];
+ uint8_t c = iree_yaml_base64_decode_table[(uint8_t)source.data[i++]];
if (c == IREE_YAML_BASE64_WHITESPACE) {
// Skip whitespace.
continue;
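The two changes above cast each byte to uint8_t before indexing the 256-entry iree_yaml_base64_decode_table. Plain char may be signed, so input bytes >= 0x80 would otherwise produce a negative index and an out-of-bounds read. A small self-contained illustration of the same issue and fix (the table contents and input are made up for the example):

#include <stdint.h>
#include <stdio.h>
#include <string.h>

// 256-entry lookup table, standing in for iree_yaml_base64_decode_table.
static uint8_t decode_table[256];

int main(void) {
  memset(decode_table, 0xFF, sizeof(decode_table));
  const char* source = "\xC3\xA9";  // bytes >= 0x80; char may be signed.
  // decode_table[source[0]] would use a negative index on signed-char
  // platforms (undefined behavior); the cast keeps the index in 0..255.
  uint8_t c = decode_table[(uint8_t)source[0]];
  printf("decoded byte class: 0x%02X\n", c);
  return 0;
}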
diff --git a/iree/vm/native_module.c b/iree/vm/native_module.c
index 7e90813..e941089 100644
--- a/iree/vm/native_module.c
+++ b/iree/vm/native_module.c
@@ -304,14 +304,20 @@
iree_status_t status = function_ptr->shim(stack, call, function_ptr->target,
module, module_state, out_result);
if (IREE_UNLIKELY(!iree_status_is_ok(status))) {
- iree_string_view_t module_name = iree_vm_native_module_name(module);
- iree_string_view_t function_name = iree_string_view_empty();
+#if IREE_STATUS_FEATURES & IREE_STATUS_FEATURE_ANNOTATIONS
+ iree_string_view_t module_name IREE_ATTRIBUTE_UNUSED =
+ iree_vm_native_module_name(module);
+ iree_string_view_t function_name IREE_ATTRIBUTE_UNUSED =
+ iree_string_view_empty();
iree_status_ignore(iree_vm_native_module_get_export_function(
module, call->function.ordinal, NULL, &function_name, NULL));
return iree_status_annotate_f(status,
"while invoking native function %.*s.%.*s",
(int)module_name.size, module_name.data,
(int)function_name.size, function_name.data);
+#else
+ return status;
+#endif // IREE_STATUS_FEATURES & IREE_STATUS_FEATURE_ANNOTATIONS
}
return iree_vm_stack_function_leave(stack);
diff --git a/llvm-external-projects/iree-compiler-api/BUILD.bazel b/llvm-external-projects/iree-compiler-api/BUILD.bazel
index 85b9a95..94afcee 100644
--- a/llvm-external-projects/iree-compiler-api/BUILD.bazel
+++ b/llvm-external-projects/iree-compiler-api/BUILD.bazel
@@ -66,6 +66,8 @@
deps = [
"//iree/compiler/Dialect/VM/IR",
"//iree/compiler/Dialect/VM/Target/Bytecode",
+ "//iree/compiler/InputConversion/MHLO",
+ "//iree/compiler/InputConversion/TOSA",
"//iree/compiler/Translation:IREEVM",
"//iree/tools:init_targets",
"//iree/tools:iree_translate_lib",
diff --git a/llvm-external-projects/iree-compiler-api/include/iree-compiler-c/Compiler.h b/llvm-external-projects/iree-compiler-api/include/iree-compiler-c/Compiler.h
index 2be688c..c1fd471 100644
--- a/llvm-external-projects/iree-compiler-api/include/iree-compiler-c/Compiler.h
+++ b/llvm-external-projects/iree-compiler-api/include/iree-compiler-c/Compiler.h
@@ -41,6 +41,8 @@
IreeCompilerOptions options);
MLIR_CAPI_EXPORTED void ireeCompilerOptionsSetInputDialectTOSA(
IreeCompilerOptions options);
+MLIR_CAPI_EXPORTED void ireeCompilerOptionsSetInputDialectXLA(
+ IreeCompilerOptions options);
MLIR_CAPI_EXPORTED void ireeCompilerOptionsAddTargetBackend(
IreeCompilerOptions options, const char *targetBackend);
@@ -48,6 +50,26 @@
// Compiler stages.
//===----------------------------------------------------------------------===//
+// Builds a pass pipeline to clean up MHLO dialect input derived from XLA.
+// A pass pipeline of this plus ireeCompilerBuildMHLOImportPassPipeline is
+// equivalent to ireeCompilerOptionsSetInputDialectXLA as a one-shot.
+MLIR_CAPI_EXPORTED void ireeCompilerBuildXLACleanupPassPipeline(
+ MlirOpPassManager passManager);
+
+// Builds a pass pipeline to lower IREE-compatible MHLO functions and ops to
+// be a legal input to IREE. This performs the standalone work that
+// ireeCompilerOptionsSetInputDialectMHLO will do as a one-shot. Notably, this
+// requires that XLA control flow has been legalized to SCF or CFG and that
+// no tuples are in the input program.
+MLIR_CAPI_EXPORTED void ireeCompilerBuildMHLOImportPassPipeline(
+ MlirOpPassManager passManager);
+
+// Builds a pass pipeline to lower IREE-compatible TOSA functions and ops to
+// be a legal input to IREE. This performs the standalone work that
+// ireeCompilerOptionsSetInputDialectTOSA will do as a one-shot.
+MLIR_CAPI_EXPORTED void ireeCompilerBuildTOSAImportPassPipeline(
+ MlirOpPassManager passManager);
+
// Builds a pass manager for transforming from an input module op to the IREE VM
// dialect. This represents the primary compilation stage with serialization to
// specific formats following.
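The three pipeline builders declared above are plain C entry points over MLIR's C API, so they compose with a caller-owned pass manager. A hedged usage sketch, assuming the MLIR C API of this era (mlirPassManagerCreate, mlirPassManagerGetAsOpPassManager, mlirPassManagerRun); my_import_xla_module and the surrounding setup are illustrative:

#include <stdbool.h>

#include "iree-compiler-c/Compiler.h"
#include "mlir-c/IR.h"
#include "mlir-c/Pass.h"
#include "mlir-c/Support.h"

// Illustrative sketch: run the XLA cleanup + MHLO import pipelines on an
// already-parsed module before handing it to the main IREE VM pipeline.
bool my_import_xla_module(MlirContext context, MlirModule module) {
  MlirPassManager pm = mlirPassManagerCreate(context);
  MlirOpPassManager opm = mlirPassManagerGetAsOpPassManager(pm);
  ireeCompilerBuildXLACleanupPassPipeline(opm);
  ireeCompilerBuildMHLOImportPassPipeline(opm);
  bool ok = !mlirLogicalResultIsFailure(mlirPassManagerRun(pm, module));
  mlirPassManagerDestroy(pm);
  return ok;
}

Per the comments above, this pairing matches what ireeCompilerOptionsSetInputDialectXLA does during main compilation, while TOSA input needs only ireeCompilerBuildTOSAImportPassPipeline.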
diff --git a/llvm-external-projects/iree-compiler-api/lib/CAPI/CMakeLists.txt b/llvm-external-projects/iree-compiler-api/lib/CAPI/CMakeLists.txt
index 13d041f..b24e995 100644
--- a/llvm-external-projects/iree-compiler-api/lib/CAPI/CMakeLists.txt
+++ b/llvm-external-projects/iree-compiler-api/lib/CAPI/CMakeLists.txt
@@ -8,6 +8,8 @@
Support
LINK_LIBS PUBLIC
MLIRIR
+ iree::compiler::InputConversion::MHLO::MHLO
+ iree::compiler::InputConversion::TOSA::TOSA
iree::compiler::Dialect::VM::IR::IR
iree::compiler::Dialect::VM::Target::Bytecode::Bytecode
iree::compiler::Translation::IREEVM
diff --git a/llvm-external-projects/iree-compiler-api/lib/CAPI/Compiler.cpp b/llvm-external-projects/iree-compiler-api/lib/CAPI/Compiler.cpp
index d13028d..0f996f8 100644
--- a/llvm-external-projects/iree-compiler-api/lib/CAPI/Compiler.cpp
+++ b/llvm-external-projects/iree-compiler-api/lib/CAPI/Compiler.cpp
@@ -8,6 +8,8 @@
#include "iree/compiler/Dialect/VM/IR/VMOps.h"
#include "iree/compiler/Dialect/VM/Target/Bytecode/BytecodeModuleTarget.h"
+#include "iree/compiler/InputConversion/MHLO/Passes.h"
+#include "iree/compiler/InputConversion/TOSA/Passes.h"
#include "iree/compiler/Translation/IREEVM.h"
#include "iree/tools/init_targets.h"
#include "mlir/CAPI/IR.h"
@@ -68,6 +70,25 @@
unwrap(options)->inputDialectOptions.type = InputDialectOptions::Type::tosa;
}
+void ireeCompilerOptionsSetInputDialectXLA(IreeCompilerOptions options) {
+ unwrap(options)->inputDialectOptions.type = InputDialectOptions::Type::xla;
+}
+
+void ireeCompilerBuildXLACleanupPassPipeline(MlirOpPassManager passManager) {
+ auto *passManagerCpp = unwrap(passManager);
+ MHLO::buildXLACleanupPassPipeline(*passManagerCpp);
+}
+
+void ireeCompilerBuildMHLOImportPassPipeline(MlirOpPassManager passManager) {
+ auto *passManagerCpp = unwrap(passManager);
+ MHLO::buildMHLOInputConversionPassPipeline(*passManagerCpp);
+}
+
+void ireeCompilerBuildTOSAImportPassPipeline(MlirOpPassManager passManager) {
+ auto *passManagerCpp = unwrap(passManager);
+ buildTOSAInputConversionPassPipeline(*passManagerCpp);
+}
+
void ireeCompilerBuildIREEVMPassPipeline(IreeCompilerOptions options,
MlirOpPassManager passManager) {
auto *optionsCpp = unwrap(options);
diff --git a/llvm-external-projects/iree-compiler-api/python/CompilerModule.cpp b/llvm-external-projects/iree-compiler-api/python/CompilerModule.cpp
index b4c6038..37256ce 100644
--- a/llvm-external-projects/iree-compiler-api/python/CompilerModule.cpp
+++ b/llvm-external-projects/iree-compiler-api/python/CompilerModule.cpp
@@ -57,6 +57,26 @@
} // namespace
+static const char BUILD_MHLO_IMPORT_PASS_PIPELINE_DOCSTRING[] =
+ R"(Populates MHLO import passes on a PassManager.
+
+This enables standalone access to the import pipeline that can be run as part
+of main compilation with the `set_input_dialect_mhlo()` option. It is provided
+separately to facilitate integration with frontend workflows.
+
+This pipeline requires IREE-compatible MHLO input: MHLO control flow must have
+been legalized to SCF or CFG and tuples must not exist. See
+`build_xla_cleanup_pass_pipeline()` for assistance when interoperating with such IR.
+)";
+
+static const char BUILD_TOSA_IMPORT_PASS_PIPELINE_DOCSTRING[] =
+ R"(Populates TOSA import passes on a PassManager.
+
+This enables standalone access to the import pipeline that can be run as part
+of main compilation with the `set_input_dialect_tosa()` option. It is provided
+separately to facilitate integration with frontend workflows.
+)";
+
static const char BUILD_IREE_VM_PASS_PIPELINE_DOCSTRING[] =
R"(Populates VM compilation pass on a PassManager.
@@ -65,6 +85,15 @@
IREE's lowest level representation.
)";
+static const char BUILD_XLA_CLEANUP_PASS_PIPELINE_DOCSTRING[] =
+ R"(Populates passes to cleanup XLA-imported MHLO to comply with IREE.
+
+Combining this pipeline with `build_mhlo_import_pass_pipeline()` provides
+standalone access to the import pipeline that can be run as part of main
+compilation with the `set_input_dialect_xla()` option. It is provided
+separately to facilitate integration with frontend workflows.
+)";
+
static const char TRANSLATE_MODULE_TO_VM_BYTECODE_DOCSTRING[] =
R"(Given a `vm.module` translate it to VM bytecode.
@@ -92,6 +121,13 @@
},
"Sets the input type to the 'tosa' dialect")
.def(
+ "set_input_dialect_xla",
+ [](PyCompilerOptions &self) {
+ ireeCompilerOptionsSetInputDialectXLA(self.options);
+ },
+ "Sets the input type to the 'mhlo' dialect with XLA compatibility "
+ "cleanups")
+ .def(
"add_target_backend",
[](PyCompilerOptions &self, const std::string &targetBackend) {
ireeCompilerOptionsAddTargetBackend(self.options,
@@ -99,7 +135,30 @@
},
py::arg("target_backend"),
"Adds a target backend (i.e. 'cpu', 'vulkan-spirv', etc)");
-
+ m.def(
+ "build_mhlo_import_pass_pipeline",
+ [](MlirPassManager passManager) {
+ MlirOpPassManager opPassManager =
+ mlirPassManagerGetAsOpPassManager(passManager);
+ ireeCompilerBuildMHLOImportPassPipeline(opPassManager);
+ },
+ py::arg("pass_manager"), BUILD_MHLO_IMPORT_PASS_PIPELINE_DOCSTRING);
+ m.def(
+ "build_tosa_import_pass_pipeline",
+ [](MlirPassManager passManager) {
+ MlirOpPassManager opPassManager =
+ mlirPassManagerGetAsOpPassManager(passManager);
+ ireeCompilerBuildTOSAImportPassPipeline(opPassManager);
+ },
+ py::arg("pass_manager"), BUILD_TOSA_IMPORT_PASS_PIPELINE_DOCSTRING);
+ m.def(
+ "build_xla_cleanup_pass_pipeline",
+ [](MlirPassManager passManager) {
+ MlirOpPassManager opPassManager =
+ mlirPassManagerGetAsOpPassManager(passManager);
+ ireeCompilerBuildXLACleanupPassPipeline(opPassManager);
+ },
+ py::arg("pass_manager"), BUILD_XLA_CLEANUP_PASS_PIPELINE_DOCSTRING);
m.def(
"build_iree_vm_pass_pipeline",
[](PyCompilerOptions &options, MlirPassManager passManager) {