Merge google -> main (#5574)
* 97a7c8e18 Synchronize submodules with LLVM at llvm/llvm-project@2f69975683f5
* 3c09f24d1 Integrate LLVM at llvm/llvm-project@2f69975683f5
* 44d106317 Merge pull request #5551 from GMNGeoffrey:main-to-google
* 124307091 Integrate LLVM at llvm/llvm-project@1a3f88658a02
* 361e47bc4 Integrate LLVM at llvm/llvm-project@4bb60c285cb3
* eaeb09d38 Integrate LLVM at llvm/llvm-project@9430efa18b02
* dc0af0cfc Synchronize submodules with LLVM at llvm/llvm-project@e0adf7e06a9e
* 678b1e104 Integrate LLVM at llvm/llvm-project@e0adf7e06a9e
* 98cf08799 Integrate LLVM at llvm/llvm-project@01ace074fcb6
* 307cef033 Integrate LLVM at llvm/llvm-project@aa80ea8a617b
* 297f730be Integrate LLVM at llvm/llvm-project@fb69b92c7b33
* 5f1f4fe13 Integrate LLVM at llvm/llvm-project@f549176ad976
* bda62f0e6 Merge main -> google
* 0adee0f42 Merge pull request #5520 from GMNGeoffrey:main-to-google
* e302cb36a Synchronize submodules with LLVM at llvm/llvm-project@80e166f81abd
* 6c9457038 Integrate LLVM at llvm/llvm-project@80e166f81abd
* ce364ae7d Integrate LLVM at llvm/llvm-project@517c3aee4de5
* 32792f7a9 Merge pull request #5500 from rsuderman:main-to-google
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 9958c07..92f7dfe 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -218,6 +218,7 @@
include(iree_tablegen_library)
include(iree_tablegen_doc)
include(iree_cc_embed_data)
+include(iree_c_embed_data)
include(iree_bytecode_module)
include(iree_c_module)
include(iree_python)
diff --git a/build_tools/bazel_to_cmake/bazel_to_cmake_converter.py b/build_tools/bazel_to_cmake/bazel_to_cmake_converter.py
index 72ce425..5d86fae 100644
--- a/build_tools/bazel_to_cmake/bazel_to_cmake_converter.py
+++ b/build_tools/bazel_to_cmake/bazel_to_cmake_converter.py
@@ -447,6 +447,37 @@
f"{flatten_block}"
f" PUBLIC\n)\n\n")
+ def c_embed_data(self,
+ name,
+ srcs,
+ c_file_output,
+ h_file_output,
+ testonly=None,
+ strip_prefix=None,
+ flatten=None,
+ identifier=None,
+ **kwargs):
+ if identifier:
+ self._convert_unimplemented_function("c_embed_data",
+ name + " has identifier")
+ name_block = _convert_string_arg_block("NAME", name, quote=False)
+ srcs_block = _convert_srcs_block(srcs)
+ c_file_output_block = _convert_string_arg_block("C_FILE_OUTPUT",
+ c_file_output)
+ h_file_output_block = _convert_string_arg_block("H_FILE_OUTPUT",
+ h_file_output)
+ testonly_block = _convert_option_block("TESTONLY", testonly)
+ flatten_block = _convert_option_block("FLATTEN", flatten)
+
+ self.converter.body += (f"iree_c_embed_data(\n"
+ f"{name_block}"
+ f"{srcs_block}"
+ f"{c_file_output_block}"
+ f"{h_file_output_block}"
+ f"{testonly_block}"
+ f"{flatten_block}"
+ f" PUBLIC\n)\n\n")
+
def spirv_kernel_cc_library(self, name, srcs):
name_block = _convert_string_arg_block("NAME", name, quote=False)
srcs_block = _convert_srcs_block(srcs)
@@ -462,10 +493,12 @@
flags=None,
translate_tool=None,
cc_namespace=None,
+ c_output=None,
testonly=None):
name_block = _convert_string_arg_block("NAME", name, quote=False)
src_block = _convert_string_arg_block("SRC", src)
namespace_block = _convert_string_arg_block("CC_NAMESPACE", cc_namespace)
+ c_output_block = _convert_option_block("C_OUTPUT", c_output)
translate_tool_block = _convert_translate_tool_block(translate_tool)
flags_block = _convert_string_list_block("FLAGS", flags)
testonly_block = _convert_option_block("TESTONLY", testonly)
@@ -474,6 +507,7 @@
f"{name_block}"
f"{src_block}"
f"{namespace_block}"
+ f"{c_output_block}"
f"{translate_tool_block}"
f"{flags_block}"
f"{testonly_block}"
diff --git a/build_tools/buildkite/cmake/android/arm64-v8a/benchmark.yml b/build_tools/buildkite/cmake/android/arm64-v8a/benchmark.yml
index bb2e1b1..ea55ba6 100644
--- a/build_tools/buildkite/cmake/android/arm64-v8a/benchmark.yml
+++ b/build_tools/buildkite/cmake/android/arm64-v8a/benchmark.yml
@@ -89,3 +89,5 @@
- "mako-uploader=true"
- "test-desktop-gpu=false"
branches: "main"
+ soft_fail:
+ - exit_status: 1
diff --git a/build_tools/cmake/iree_bytecode_module.cmake b/build_tools/cmake/iree_bytecode_module.cmake
index 0af567a..e297083 100644
--- a/build_tools/cmake/iree_bytecode_module.cmake
+++ b/build_tools/cmake/iree_bytecode_module.cmake
@@ -26,6 +26,7 @@
# TRANSLATE_TOOL: Translation tool to invoke (CMake target). The default
# tool is "iree-translate".
# CC_NAMESPACE: Wraps everything in a C++ namespace.
+# C_OUTPUT: When set, additionally generates C embed data via iree_c_embed_data.
# PUBLIC: Add this so that this library will be exported under ${PACKAGE}::
# Also in IDE, target will appear in ${PACKAGE} folder while non PUBLIC
# will be in ${PACKAGE}/internal.
@@ -39,7 +40,7 @@
function(iree_bytecode_module)
cmake_parse_arguments(
_RULE
- "PUBLIC;TESTONLY"
+ "PUBLIC;TESTONLY;C_OUTPUT"
"NAME;SRC;TRANSLATE_TOOL;CC_NAMESPACE"
"FLAGS"
${ARGN}
@@ -104,4 +105,22 @@
"${_TESTONLY_ARG}"
)
endif()
+
+ if(_RULE_C_OUTPUT)
+ iree_c_embed_data(
+ NAME
+ "${_RULE_NAME}_c"
+ IDENTIFIER
+ "${_RULE_NAME}_c"
+ GENERATED_SRCS
+ "${_RULE_NAME}.vmfb"
+ C_FILE_OUTPUT
+ "${_RULE_NAME}_c.c"
+ H_FILE_OUTPUT
+ "${_RULE_NAME}_c.h"
+ FLATTEN
+ "${_PUBLIC_ARG}"
+ "${_TESTONLY_ARG}"
+ )
+ endif()
endfunction()
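As a usage sketch (module name hypothetical), a bytecode module declared with the new flag also produces a `<name>_c` C embed-data library wrapping the generated `.vmfb`:

```cmake
# Hypothetical invocation; "simple_module" is a placeholder name.
iree_bytecode_module(
  NAME
    simple_module          # emits simple_module.vmfb
  SRC
    "simple_module.mlir"
  C_OUTPUT                 # additionally creates the simple_module_c target
)
# C sources can then include "simple_module_c.h" and call
# simple_module_c_create() to get the embedded .vmfb bytes.
```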
diff --git a/build_tools/cmake/iree_c_embed_data.cmake b/build_tools/cmake/iree_c_embed_data.cmake
new file mode 100644
index 0000000..0037daf
--- /dev/null
+++ b/build_tools/cmake/iree_c_embed_data.cmake
@@ -0,0 +1,97 @@
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+include(CMakeParseArguments)
+
+# iree_c_embed_data()
+#
+# CMake function to imitate Bazel's c_embed_data rule.
+#
+# Parameters:
+# NAME: Name of target (see Note).
+# SRCS: List of source files to embed.
+# GENERATED_SRCS: List of generated source files to embed.
+# C_FILE_OUTPUT: The C implementation file to output.
+# H_FILE_OUTPUT: The H header file to output.
+# STRIP_PREFIX: Strips this verbatim prefix from filenames (in the TOC).
+# FLATTEN: Removes all directory components from filenames (in the TOC).
+# IDENTIFIER: The identifier to use in generated names (defaults to name).
+# PUBLIC: Add this so that this library will be exported under ${PACKAGE}::
+# Also in IDE, target will appear in ${PACKAGE} folder while non PUBLIC will be
+# in ${PACKAGE}/internal.
+# TESTONLY: When added, this target will only be built if user passes
+# -DIREE_BUILD_TESTS=ON to CMake.
+# TODO(scotttodd): Support passing KWARGS down into iree_cc_library?
+#
+function(iree_c_embed_data)
+ cmake_parse_arguments(
+ _RULE
+ "PUBLIC;TESTONLY;FLATTEN"
+ "NAME;IDENTIFIER;STRIP_PREFIX;C_FILE_OUTPUT;H_FILE_OUTPUT"
+ "SRCS;GENERATED_SRCS"
+ ${ARGN}
+ )
+
+ if(_RULE_TESTONLY AND NOT IREE_BUILD_TESTS)
+ return()
+ endif()
+
+ if(DEFINED _RULE_IDENTIFIER)
+ set(_IDENTIFIER ${_RULE_IDENTIFIER})
+ else()
+ set(_IDENTIFIER ${_RULE_NAME})
+ endif()
+
+ set(_ARGS)
+ list(APPEND _ARGS "--output_header=${_RULE_H_FILE_OUTPUT}")
+ list(APPEND _ARGS "--output_impl=${_RULE_C_FILE_OUTPUT}")
+ list(APPEND _ARGS "--identifier=${_IDENTIFIER}")
+ list(APPEND _ARGS "--c_output=true")
+  if(DEFINED _RULE_STRIP_PREFIX)
+ list(APPEND _ARGS "--strip_prefix=${_RULE_STRIP_PREFIX}")
+ endif()
+  if(_RULE_FLATTEN)
+ list(APPEND _ARGS "--flatten")
+ endif()
+
+ foreach(SRC ${_RULE_SRCS})
+ list(APPEND _ARGS "${CMAKE_CURRENT_SOURCE_DIR}/${SRC}")
+ endforeach(SRC)
+ foreach(SRC ${_RULE_GENERATED_SRCS})
+ list(APPEND _ARGS "${SRC}")
+ endforeach(SRC)
+
+ iree_get_executable_path(_EXE_PATH generate_embed_data)
+
+ add_custom_command(
+ OUTPUT "${_RULE_H_FILE_OUTPUT}" "${_RULE_C_FILE_OUTPUT}"
+ COMMAND ${_EXE_PATH} ${_ARGS}
+ DEPENDS ${_EXE_PATH} ${_RULE_SRCS} ${_RULE_GENERATED_SRCS}
+ )
+
+ if(_RULE_TESTONLY)
+ set(_TESTONLY_ARG "TESTONLY")
+ endif()
+ if(_RULE_PUBLIC)
+ set(_PUBLIC_ARG "PUBLIC")
+ endif()
+
+ iree_cc_library(
+ NAME ${_RULE_NAME}
+ HDRS "${_RULE_H_FILE_OUTPUT}"
+ SRCS "${_RULE_C_FILE_OUTPUT}"
+ "${_PUBLIC_ARG}"
+ "${_TESTONLY_ARG}"
+ )
+endfunction()
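A minimal invocation sketch of the new function (file names hypothetical), mirroring the Bazel `c_embed_data` rule:

```cmake
# Hypothetical usage; produces a C library "testdata_c" with an embedded TOC.
iree_c_embed_data(
  NAME
    testdata_c
  SRCS
    "file1.txt"
  C_FILE_OUTPUT
    "testdata_c.c"
  H_FILE_OUTPUT
    "testdata_c.h"
  FLATTEN
)
```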
diff --git a/build_tools/cmake/iree_cc_embed_data.cmake b/build_tools/cmake/iree_cc_embed_data.cmake
index 7eeac23..893c2e9 100644
--- a/build_tools/cmake/iree_cc_embed_data.cmake
+++ b/build_tools/cmake/iree_cc_embed_data.cmake
@@ -62,6 +62,7 @@
list(APPEND _ARGS "--output_header=${_RULE_H_FILE_OUTPUT}")
list(APPEND _ARGS "--output_impl=${_RULE_CC_FILE_OUTPUT}")
list(APPEND _ARGS "--identifier=${_IDENTIFIER}")
+ list(APPEND _ARGS "--c_output=false")
if(DEFINED _RULE_CPP_NAMESPACE)
list(APPEND _ARGS "--cpp_namespace=${_RULE_CPP_NAMESPACE}")
endif()
@@ -79,7 +80,7 @@
list(APPEND _ARGS "${SRC}")
endforeach(SRC)
- iree_get_executable_path(_EXE_PATH generate_cc_embed_data)
+ iree_get_executable_path(_EXE_PATH generate_embed_data)
add_custom_command(
OUTPUT "${_RULE_H_FILE_OUTPUT}" "${_RULE_CC_FILE_OUTPUT}"
diff --git a/build_tools/embed_data/BUILD b/build_tools/embed_data/BUILD
index 37097f6..0ade681 100644
--- a/build_tools/embed_data/BUILD
+++ b/build_tools/embed_data/BUILD
@@ -14,7 +14,7 @@
# Generates source files with embedded file contents.
-load(":build_defs.bzl", "cc_embed_data")
+load(":build_defs.bzl", "c_embed_data", "cc_embed_data")
package(
default_visibility = ["//visibility:public"],
@@ -23,8 +23,8 @@
)
cc_binary(
- name = "generate_cc_embed_data",
- srcs = ["generate_cc_embed_data_main.cc"],
+ name = "generate_embed_data",
+ srcs = ["generate_embed_data_main.cc"],
deps = [
"@com_google_absl//absl/flags:flag",
"@com_google_absl//absl/flags:parse",
@@ -66,3 +66,25 @@
"//iree/testing:gtest_main",
],
)
+
+c_embed_data(
+ name = "testembed1_c",
+ # do not sort
+ srcs = [
+ "file1.txt",
+ "data/file2.txt",
+ ],
+ c_file_output = "testembed1_c.c",
+ flatten = True,
+ h_file_output = "testembed1_c.h",
+)
+
+cc_test(
+ name = "c_embed_data_test",
+ srcs = ["c_embed_data_test.cc"],
+ deps = [
+ ":testembed1_c",
+ "//iree/testing:gtest",
+ "//iree/testing:gtest_main",
+ ],
+)
diff --git a/build_tools/embed_data/CMakeLists.txt b/build_tools/embed_data/CMakeLists.txt
index 7f28f75..92ac014 100644
--- a/build_tools/embed_data/CMakeLists.txt
+++ b/build_tools/embed_data/CMakeLists.txt
@@ -13,16 +13,17 @@
# limitations under the License.
if(NOT CMAKE_CROSSCOMPILING)
- add_executable(generate_cc_embed_data)
- target_sources(generate_cc_embed_data PRIVATE generate_cc_embed_data_main.cc)
- set_target_properties(generate_cc_embed_data PROPERTIES OUTPUT_NAME generate_cc_embed_data)
+ add_executable(generate_embed_data)
+ target_sources(generate_embed_data PRIVATE generate_embed_data_main.cc)
+ set_target_properties(generate_embed_data PROPERTIES OUTPUT_NAME generate_embed_data)
- target_link_libraries(generate_cc_embed_data
+ target_link_libraries(generate_embed_data
absl::flags
absl::flags_parse
absl::strings
)
- install(TARGETS generate_cc_embed_data
- COMPONENT generate_cc_embed_data
+
+ install(TARGETS generate_embed_data
+ COMPONENT generate_embed_data
RUNTIME DESTINATION bin)
endif()
diff --git a/build_tools/embed_data/build_defs.bzl b/build_tools/embed_data/build_defs.bzl
index 872542a..bb16601 100644
--- a/build_tools/embed_data/build_defs.bzl
+++ b/build_tools/embed_data/build_defs.bzl
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-"""Embeds data files into a C++ module."""
+"""Embeds data files into a C or C++ module."""
def cc_embed_data(
name,
@@ -57,7 +57,7 @@
identifier: The identifier to use in generated names (defaults to name).
**kwargs: Args to pass to the cc_library.
"""
- generator = "//build_tools/embed_data:generate_cc_embed_data"
+ generator = "//build_tools/embed_data:generate_embed_data"
generator_location = "$(location %s)" % generator
if identifier == None:
identifier = name
@@ -66,6 +66,7 @@
cc_file_output,
)
flags += " --identifier='%s'" % (identifier,)
+ flags += " --c_output=false"
if cpp_namespace != None:
flags += " --cpp_namespace='%s'" % (cpp_namespace,)
if strip_prefix != None:
@@ -91,3 +92,87 @@
testonly = testonly,
**kwargs
)
+
+def c_embed_data(
+ name,
+ srcs,
+ c_file_output,
+ h_file_output,
+ testonly = False,
+ strip_prefix = None,
+ flatten = False,
+ identifier = None,
+ **kwargs):
+ """Embeds 'srcs' into a C module.
+
+ Generates a header like:
+ #if __cplusplus
+ extern "C" {
+ #endif // __cplusplus
+ struct FileToc {
+ const char* name; // the file's original name
+ const char* data; // beginning of the file
+ size_t size; // length of the file
+ };
+ #if __cplusplus
+ }
+ #endif // __cplusplus
+
+ #if __cplusplus
+ extern "C" {
+ #endif // __cplusplus
+      const struct FileToc* this_rule_name_create();
+ #if __cplusplus
+ }
+ #endif // __cplusplus
+
+    The 'this_rule_name_create()' function will return an array of FileToc
+    structs terminated by one that has NULL 'name' and 'data' fields.
+ The 'data' field always has an extra null terminator at the end (which
+ is not included in the size).
+
+ Args:
+ name: The rule name, which will also be the identifier of the generated
+ code symbol.
+ srcs: List of files to embed.
+ c_file_output: The C implementation file to output.
+ h_file_output: The H header file to output.
+ testonly: If True, only testonly targets can depend on this target.
+ strip_prefix: Strips this verbatim prefix from filenames (in the TOC).
+ flatten: Removes all directory components from filenames (in the TOC).
+ identifier: The identifier to use in generated names (defaults to name).
+ **kwargs: Args to pass to the cc_library.
+ """
+ generator = "//build_tools/embed_data:generate_embed_data"
+ generator_location = "$(location %s)" % generator
+ if identifier == None:
+ identifier = name
+ flags = "--output_header='$(location %s)' --output_impl='$(location %s)'" % (
+ h_file_output,
+ c_file_output,
+ )
+ flags += " --c_output=true"
+ flags += " --identifier='%s'" % (identifier,)
+ if strip_prefix != None:
+ flags += " --strip_prefix='%s'" % (strip_prefix,)
+ if flatten:
+ flags += " --flatten"
+
+ native.genrule(
+ name = name + "__generator",
+ srcs = srcs,
+ outs = [
+ c_file_output,
+ h_file_output,
+ ],
+ tools = [generator],
+ cmd = "%s $(SRCS) %s" % (generator_location, flags),
+ testonly = testonly,
+ )
+ native.cc_library(
+ name = name,
+ hdrs = [h_file_output],
+ srcs = [c_file_output],
+ testonly = testonly,
+ **kwargs
+ )
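For instance (target and file names hypothetical), a caller can override the generated symbol name and strip a directory prefix from TOC entries:

```python
# Hypothetical BUILD usage exercising identifier and strip_prefix.
c_embed_data(
    name = "kernels_c",
    srcs = ["data/kernel.bin"],
    c_file_output = "kernels_c.c",
    h_file_output = "kernels_c.h",
    identifier = "kernels",    # symbols become kernels_create()/kernels_size()
    strip_prefix = "data/",    # TOC name becomes "kernel.bin"
)
```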
diff --git a/build_tools/embed_data/c_embed_data_test.cc b/build_tools/embed_data/c_embed_data_test.cc
new file mode 100644
index 0000000..a112907
--- /dev/null
+++ b/build_tools/embed_data/c_embed_data_test.cc
@@ -0,0 +1,42 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "build_tools/embed_data/testembed1_c.h"
+#include "iree/testing/gtest.h"
+
+namespace {
+
+TEST(Generator, TestContents) {
+ auto* toc1 = testembed1_c_create();
+ ASSERT_EQ("file1.txt", std::string(toc1->name));
+ ASSERT_EQ(R"(Are you '"Still"' here?)"
+ "\n",
+ std::string(toc1->data));
+ ASSERT_EQ(24, toc1->size);
+ ASSERT_EQ(0, *(toc1->data + toc1->size));
+
+ ++toc1;
+ ASSERT_EQ("file2.txt", std::string(toc1->name));
+ ASSERT_EQ(R"(¯\_(ツ)_/¯)"
+ "\n",
+ std::string(toc1->data));
+ ASSERT_EQ(14, toc1->size);
+ ASSERT_EQ(0, *(toc1->data + toc1->size));
+
+ ++toc1;
+ ASSERT_EQ(nullptr, toc1->name);
+ ASSERT_EQ(nullptr, toc1->data);
+}
+
+} // namespace
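Because the generated header is plain C, the same TOC is also usable without gtest; a minimal sketch assuming the testembed1_c target above:

```c
#include <stdio.h>

#include "build_tools/embed_data/testembed1_c.h"

int main(void) {
  // Walk the TOC until the {NULL, NULL, 0} sentinel entry.
  for (const struct FileToc* toc = testembed1_c_create(); toc->name; ++toc) {
    printf("%s: %zu bytes\n", toc->name, toc->size);
  }
  return 0;
}
```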
diff --git a/build_tools/embed_data/generate_cc_embed_data_main.cc b/build_tools/embed_data/generate_embed_data_main.cc
similarity index 68%
rename from build_tools/embed_data/generate_cc_embed_data_main.cc
rename to build_tools/embed_data/generate_embed_data_main.cc
index 0ded0e6..137479c 100644
--- a/build_tools/embed_data/generate_cc_embed_data_main.cc
+++ b/build_tools/embed_data/generate_embed_data_main.cc
@@ -25,12 +25,25 @@
ABSL_FLAG(std::string, identifier, "resources",
"name of the resources function");
ABSL_FLAG(std::string, output_header, "", "output header file");
-ABSL_FLAG(std::string, output_impl, "", "output cc impl file");
+ABSL_FLAG(std::string, output_impl, "", "output impl file");
ABSL_FLAG(std::string, cpp_namespace, "", "generate in a c++ namespace");
+ABSL_FLAG(bool, c_output, false, "generate C output");
ABSL_FLAG(std::string, strip_prefix, "", "strip prefix from filenames");
ABSL_FLAG(bool, flatten, false,
"whether to flatten the directory structure (only include basename)");
+void GenerateExternCOpen(std::ofstream& f) {
+ f << "\n#if __cplusplus\n";
+ f << "extern \"C\" {\n";
+ f << "#endif // __cplusplus\n";
+}
+
+void GenerateExternCClose(std::ofstream& f) {
+ f << "#if __cplusplus\n";
+ f << "}\n";
+ f << "#endif // __cplusplus\n\n";
+}
+
void GenerateNamespaceOpen(std::ofstream& f) {
const auto& ns = absl::GetFlag(FLAGS_cpp_namespace);
if (ns.empty()) return;
@@ -54,32 +67,58 @@
}
void GenerateTocStruct(std::ofstream& f) {
+ const auto& c_output = absl::GetFlag(FLAGS_c_output);
f << "#ifndef IREE_FILE_TOC\n";
f << "#define IREE_FILE_TOC\n";
- f << "namespace iree {\n";
+ if (c_output) {
+ GenerateExternCOpen(f);
+ } else {
+ f << "namespace iree {\n";
+ }
f << "struct FileToc {\n";
f << " const char* name; // the file's original name\n";
f << " const char* data; // beginning of the file\n";
- f << " std::size_t size; // length of the file\n";
- f << "};\n";
- f << "} // namespace iree\n";
+ if (c_output) {
+ f << " size_t size; // length of the file\n";
+ f << "};\n";
+ GenerateExternCClose(f);
+ } else {
+ f << " std::size_t size; // length of the file\n";
+ f << "};\n";
+ f << "} // namespace iree\n";
+ }
f << "#endif // IREE_FILE_TOC\n";
}
bool GenerateHeader(const std::string& header_file,
const std::vector<std::string>& toc_files) {
std::ofstream f(header_file, std::ios::out | std::ios::trunc);
+ const auto& c_output = absl::GetFlag(FLAGS_c_output);
+
f << "#pragma once\n"; // Pragma once isn't great but is the best we can do.
- f << "#include <cstddef>\n";
- GenerateTocStruct(f);
- GenerateNamespaceOpen(f);
- f << "extern const struct ::iree::FileToc* "
- << absl::GetFlag(FLAGS_identifier) << "_create();\n";
- f << "static inline std::size_t " << absl::GetFlag(FLAGS_identifier)
- << "_size() { \n";
- f << " return " << toc_files.size() << ";\n";
- f << "}\n";
- GenerateNamespaceClose(f);
+ if (c_output) {
+ f << "#include <stddef.h>\n";
+ GenerateTocStruct(f);
+ GenerateExternCOpen(f);
+ f << "const struct FileToc* " << absl::GetFlag(FLAGS_identifier)
+ << "_create();\n";
+ f << "static inline size_t " << absl::GetFlag(FLAGS_identifier)
+ << "_size() {\n";
+ f << " return " << toc_files.size() << ";\n";
+ f << "}\n";
+ GenerateExternCClose(f);
+ } else {
+ f << "#include <cstddef>\n";
+ GenerateTocStruct(f);
+ GenerateNamespaceOpen(f);
+ f << "extern const struct ::iree::FileToc* "
+ << absl::GetFlag(FLAGS_identifier) << "_create();\n";
+ f << "static inline std::size_t " << absl::GetFlag(FLAGS_identifier)
+ << "_size() { \n";
+ f << " return " << toc_files.size() << ";\n";
+ f << "}\n";
+ GenerateNamespaceClose(f);
+ }
f.close();
return f.good();
}
@@ -109,9 +148,16 @@
const std::vector<std::string>& input_files,
const std::vector<std::string>& toc_files) {
std::ofstream f(impl_file, std::ios::out | std::ios::trunc);
- f << "#include <cstddef>\n";
- GenerateTocStruct(f);
- GenerateNamespaceOpen(f);
+ const auto& c_output = absl::GetFlag(FLAGS_c_output);
+ if (c_output) {
+ f << "#include <stddef.h>\n";
+ f << "#include <stdalign.h>\n";
+ GenerateTocStruct(f);
+ } else {
+ f << "#include <cstddef>\n";
+ GenerateTocStruct(f);
+ GenerateNamespaceOpen(f);
+ }
for (size_t i = 0, e = input_files.size(); i < e; ++i) {
f << "alignas(alignof(void*)) static char const file_" << i << "[] = {\n";
std::string contents;
@@ -128,7 +174,11 @@
}
f << "};\n";
}
- f << "static const struct ::iree::FileToc toc[] = {\n";
+ if (c_output) {
+ f << "static const struct FileToc toc[] = {\n";
+ } else {
+ f << "static const struct ::iree::FileToc toc[] = {\n";
+ }
assert(input_files.size() == toc_files.size());
for (size_t i = 0, e = input_files.size(); i < e; ++i) {
f << " {\n";
@@ -137,14 +187,22 @@
f << " sizeof(file_" << i << ") - 1\n";
f << " },\n";
}
- f << " {nullptr, nullptr, 0},\n";
- f << "};\n";
- f << "const struct ::iree::FileToc* " << absl::GetFlag(FLAGS_identifier)
- << "_create() {\n";
+ if (c_output) {
+ f << " {NULL, NULL, 0},\n";
+ f << "};\n";
+ f << "const struct FileToc* " << absl::GetFlag(FLAGS_identifier)
+ << "_create() {\n";
+ } else {
+ f << " {nullptr, nullptr, 0},\n";
+ f << "};\n";
+ f << "const struct ::iree::FileToc* " << absl::GetFlag(FLAGS_identifier)
+ << "_create() {\n";
+ }
f << " return &toc[0];\n";
f << "}\n";
-
- GenerateNamespaceClose(f);
+ if (!c_output) {
+ GenerateNamespaceClose(f);
+ }
f.close();
return f.good();
}
@@ -175,7 +233,12 @@
}
toc_files.push_back(toc_file);
}
-
+  // Only one of C or C++ output can be generated.
+ if (!absl::GetFlag(FLAGS_cpp_namespace).empty() &&
+ absl::GetFlag(FLAGS_c_output)) {
+ std::cerr << "Can only generate either c or c++ output.\n";
+ return 1;
+ }
if (!absl::GetFlag(FLAGS_output_header).empty()) {
if (!GenerateHeader(absl::GetFlag(FLAGS_output_header), toc_files)) {
std::cerr << "Error generating headers.\n";
diff --git a/iree/base/internal/flags.h b/iree/base/internal/flags.h
index f98e19b..ef2c9da 100644
--- a/iree/base/internal/flags.h
+++ b/iree/base/internal/flags.h
@@ -17,6 +17,10 @@
#include "iree/base/api.h"
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
+
//===----------------------------------------------------------------------===//
// Flag parsing
//===----------------------------------------------------------------------===//
@@ -56,4 +60,8 @@
// typo and shut down your entire server/sandbox/Android app/etc.
void iree_flags_parse_checked(int* argc, char*** argv);
+#ifdef __cplusplus
+} // extern "C"
+#endif // __cplusplus
+
#endif // IREE_BASE_INTERNAL_FLAGS_H_
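With the guards in place, the header can now be consumed from a plain C translation unit; a minimal sketch:

```c
#include "iree/base/internal/flags.h"

int main(int argc, char** argv) {
  // Parses IREE flags out of argc/argv; per the header comment, the "checked"
  // variant shuts the process down on invalid flags.
  iree_flags_parse_checked(&argc, &argv);
  // Remaining entries in argv are the non-flag arguments.
  return 0;
}
```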
diff --git a/iree/base/target_platform.h b/iree/base/target_platform.h
index 89234e9..118ebc6 100644
--- a/iree/base/target_platform.h
+++ b/iree/base/target_platform.h
@@ -15,6 +15,7 @@
#ifndef IREE_BASE_TARGET_PLATFORM_H_
#define IREE_BASE_TARGET_PLATFORM_H_
+#include <assert.h>
#include <stdint.h>
// The build system defines one of the following top-level platforms and then
@@ -104,12 +105,18 @@
// IREE_PTR_SIZE_*
//==============================================================================
-#if UINTPTR_MAX > UINT_MAX
+// See https://stackoverflow.com/q/51616057
+static_assert(sizeof(void*) == sizeof(uintptr_t),
+ "can't determine pointer size");
+
+#if UINTPTR_MAX == 0xFFFFFFFF
+#define IREE_PTR_SIZE_32
+#define IREE_PTR_SIZE 4
+#elif UINTPTR_MAX == 0xFFFFFFFFFFFFFFFFu
#define IREE_PTR_SIZE_64
#define IREE_PTR_SIZE 8
#else
-#define IREE_PTR_SIZE_32
-#define IREE_PTR_SIZE 4
+#error "can't determine pointer size"
#endif
//==============================================================================
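A small compile-time sanity check illustrating how the new macros can be used (sketch; relies only on the definitions above):

```c
#include <assert.h>

#include "iree/base/target_platform.h"

// static_assert is available in C11 via <assert.h>; verify the detected
// pointer size against the real one on this target.
static_assert(sizeof(void*) == IREE_PTR_SIZE,
              "IREE_PTR_SIZE disagrees with sizeof(void*)");
```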
diff --git a/iree/compiler/Conversion/Common/LinalgBufferizePass.cpp b/iree/compiler/Conversion/Common/LinalgBufferizePass.cpp
index 5a34433..3919ccb 100644
--- a/iree/compiler/Conversion/Common/LinalgBufferizePass.cpp
+++ b/iree/compiler/Conversion/Common/LinalgBufferizePass.cpp
@@ -14,13 +14,34 @@
//===- LinalgBufferizePass.cpp.cpp - Pass to bufferize Linalg on tensors --===//
//
-// Pass to convert from Linalg ops on tensors to Linalg ops on buffers.
-// This just inserts AllocOp to address space 0 that can be later hoisted,
-// promoted and generally rewritten to the desired backend.
+// The overall bufferization algorithm is summarized here. Each of the
+// individual steps is explained in detail later.
//
-// TODO(nicolasvasilache): the implementation of this pass is unnecessarily
-// convoluted due to asymmetries arising from tie_shape weirdness. Revisit once
-// this abstraction is replaced.
+// Problem statement:
+//
+// The bufferization in this file is intended for converting tensor operations
+// into memref operations for ops within a dispatch region. The goal is to reuse
+// the buffers provided as inputs/outputs by the HAL layer as memrefs for each
+// of the operations. If the transformation cannot reuse an input/output buffer
+// to store an intermediate tensor, an allocation is done. This allocation is
+// typically meant to target scratchspace memory.
+//
+// The algorithm has two phases: an analysis phase and a transformation phase.
+//
+// - The analysis phase walks the function and organizes relevant tensors
+// (tensors that need to be converted to memrefs) into equivalence classes. Two
+// tensors are part of the same equivalence class if they can eventually be
+// mapped to the same memref. This allows determining which operations can use
+// the buffer provided for the outputs to compute the results in place.
+// - The transformation phase walks the function again and inserts corresponding
+// memref operations. The tensor operations are still kept around since the
+// analysis driving the transformation is based on the tensor values.
+// - Converting tensor operations to memref operations when all operands use
+// either buffers that are inputs to the dispatch or are allocated
+// temporarily within the dispatch region can be achieved by a
+// straightforward walk.
+// - Reusing the memref for the result of the dispatch for other operations is
+// more involved and is explained below.
//
//===----------------------------------------------------------------------===//
@@ -33,6 +54,7 @@
#include "iree/compiler/Dialect/IREE/IR/IREEDialect.h"
#include "iree/compiler/Dialect/IREE/IR/IREEOps.h"
#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
+#include "llvm/ADT/EquivalenceClasses.h"
#include "llvm/ADT/TypeSwitch.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
@@ -49,218 +71,22 @@
namespace iree_compiler {
//===----------------------------------------------------------------------===//
-// Utility functions.
+// Analysis to compute equivalence sets.
+//
+// These functions compute the equivalence relationships between all tensors in
+// the program. Two tensors are equivalent if they are to be mapped to the same
+// buffer. For every operation, based on the operation semantics, the result of
+// the operation may reuse the buffer of one of its operands. This
+// information is captured by adding these two tensors to the same equivalence
+// class. Eventually the result tensor of the dispatch is added to some
+// equivalence set. All tensors in that equivalence set can reuse the result
+// buffer and compute the values in place. A tensor can be added to an
+// equivalence set only if
+// - it has a single use, or
+// - it is derived from a read-only buffer.
+//
//===----------------------------------------------------------------------===//
-static MemRefType getMemrefTypeForTensor(RankedTensorType tensorType,
- ArrayRef<AffineMap> layout = {},
- unsigned memorySpace = 0) {
- return MemRefType::get(tensorType.getShape(), tensorType.getElementType(),
- layout, memorySpace);
-}
-
-// Transfer all `dim` ops on `tensor` to `memref`.
-static void transferShapeOpsToMemref(OpBuilder &b, Value tensor, Value memref,
- BlockAndValueMapping &bvm) {
- for (OpOperand &opOperand : llvm::make_early_inc_range(tensor.getUses())) {
- if (isa<memref::DimOp>(opOperand.getOwner())) {
- opOperand.set(memref);
- continue;
- }
- if (auto flowTieShapeOp =
- dyn_cast<IREE::Flow::DispatchTieShapeOp>(opOperand.getOwner())) {
- OpBuilder::InsertionGuard g(b);
- b.setInsertionPoint(flowTieShapeOp);
- auto tieShapeOp =
- b.create<Shape::TieShapeOp>(flowTieShapeOp.getLoc(), memref.getType(),
- memref, flowTieShapeOp.shape());
- bvm.map(flowTieShapeOp.getResult(), tieShapeOp.getResult());
- continue;
- }
- }
-}
-
-/// Creates a subview operation given the `src`, `offsets`, `sizes` and
-/// `strides`. Handles the corner case where the `offsets`, `sizes` and
-/// `strides` are empty in which case just forward the `src` value.
-/// TODO(ataei): Instead create memref.subview %v [][][] folder.
-static Value createSubviewOp(OpBuilder &b, Location loc, Value src,
- ArrayRef<OpFoldResult> offsets,
- ArrayRef<OpFoldResult> sizes,
- ArrayRef<OpFoldResult> strides) {
- if (offsets.empty() && sizes.empty() && strides.empty()) return src;
- return b.create<memref::SubViewOp>(loc, src, offsets, sizes, strides);
-}
-
-//===----------------------------------------------------------------------===//
-// Bufferization helper functions using BlockAndValueMapping.
-//===----------------------------------------------------------------------===//
-
-// Non-conversion equivalent of the core MLIR Linalg bufferization patterns.
-// Allocate the output buffers for the bufferized Linalg op to write into.
-// If the tensor is an init tensor, we additionally copy the original value into
-// the newly allocated buffer.
-static LogicalResult allocateBuffersForResults(
- OpBuilder &b, Location loc, WorkgroupMemoryAllocationFn allocationFn,
- linalg::LinalgOp op, SmallVectorImpl<Value> &resultBuffers,
- BlockAndValueMapping &bvm) {
- // Lazily compute loopRanges.
- SmallVector<Range, 4> loopRanges;
-
- assert(op.getNumOutputs() == op->getNumResults());
- for (auto en : llvm::enumerate(op->getResultTypes())) {
- size_t resultIndex = en.index();
- Value outTensor = op.getOutput(resultIndex);
- Value resultTensor = op->getResult(en.index());
-
- // If output tensor was produced by a LinalgOp, just reuse the buffer.
- // TODO(nicolasvasilache): this may be too brutal and we may prefer to leave
- // this decision to a copy + alloc removal pass.
- if (outTensor.getDefiningOp<linalg::LinalgOp>()) {
- Value outBuffer = bvm.lookup(outTensor);
- bvm.map(resultTensor, outBuffer);
- resultBuffers.push_back(outBuffer);
- continue;
- }
-
- // If resultTensor already has a buffer, just use that.
- Value alloc = bvm.lookupOrNull(resultTensor);
- if (!alloc) {
- Type resultType = en.value();
- auto tensorType = resultType.dyn_cast<RankedTensorType>();
- auto tensorShape = tensorType.getShape();
- SmallVector<Value, 4> dynOperands;
- for (auto dim : llvm::enumerate(tensorShape)) {
- Value dimTensor = bvm.lookupOrNull(outTensor);
- if (!dimTensor) dimTensor = outTensor;
- if (dim.value() == TensorType::kDynamicSize) {
- dynOperands.push_back(
- b.createOrFold<memref::DimOp>(loc, dimTensor, dim.index()));
- }
- }
- alloc = allocationFn(b, loc, tensorShape, tensorType.getElementType(),
- dynOperands);
- bvm.map(resultTensor, alloc);
- }
- resultBuffers.push_back(alloc);
-
- // Additionally, if the output buffer is used, clone its value for now. The
- // method `payloadUsesValueFromOutputOperandIndex` only works on named ops
- // that have a region. Named ops like `conv`, etc. that are manually defined
- // do not have this generated by default. So for now, just handled these
- // manually defined ops specifically.
- if (!isa<linalg::FillOp>(op.getOperation()) &&
- op.payloadUsesValueFromOutputOperandIndex(resultIndex)) {
- b.create<linalg::CopyOp>(loc, bvm.lookup(outTensor), alloc);
- }
- }
- for (auto it : llvm::zip(op->getResults(), resultBuffers)) {
- transferShapeOpsToMemref(b, std::get<0>(it), std::get<1>(it), bvm);
- }
- return success();
-}
-
-// Non-conversion equivalent of the core MLIR Linalg bufferization patterns.
-static LogicalResult finalizeBufferAllocation(OpBuilder &b, linalg::LinalgOp op,
- ValueRange inputs,
- ValueRange outputs,
- BlockAndValueMapping &bvm) {
- SmallVector<Value, 8> newOperands = inputs;
- newOperands.append(outputs.begin(), outputs.end());
- auto otherOperands =
- llvm::map_range(op.getAssumedNonShapedOperands(),
- [&bvm](Value v) { return bvm.lookupOrDefault(v); });
- newOperands.append(otherOperands.begin(), otherOperands.end());
- Location loc = op.getLoc();
- op.clone(b, loc, {}, newOperands);
-
- // Replace the results of the old op with the new output buffers.
- for (auto result : llvm::enumerate(op.getOperation()->getResults())) {
- Value resultValue = result.value();
- Value resultBuffer = bvm.lookup(resultValue);
- if (resultBuffer != outputs[result.index()]) {
- b.create<linalg::CopyOp>(loc, outputs[result.index()], resultBuffer);
- }
- }
- return success();
-}
-
-/// Generic conversion pattern that matches any linalg::LinalgOp. This avoids
-/// template instantiating one pattern for each linalg::LinalgOp.
-static LogicalResult convertAnyLinalgOp(
- OpBuilder &b, WorkgroupMemoryAllocationFn allocationFn, linalg::LinalgOp op,
- BlockAndValueMapping &bvm) {
- // Skip linalg ops inserted by this pass.
- if (op.hasBufferSemantics()) return success();
-
- OpBuilder::InsertionGuard g(b);
- b.setInsertionPoint(op);
- Location loc = op.getLoc();
- SmallVector<Value, 2> newInputBuffers;
- newInputBuffers.reserve(op.getNumInputs());
- for (Value v : op.getInputs()) {
- newInputBuffers.push_back(bvm.lookup(v));
- }
- SmallVector<Value, 2> newOutputBuffers;
- if (failed(allocateBuffersForResults(b, loc, allocationFn, op,
- newOutputBuffers, bvm))) {
- LLVM_DEBUG(llvm::dbgs()
- << "failed to allocate output buffers for op: " << op << "\n");
- return failure();
- }
-
- // Delegate to the linalg generic pattern.
- if (auto genericOp = dyn_cast<linalg::GenericOp>(op.getOperation())) {
- return finalizeBufferAllocation(b, genericOp, newInputBuffers,
- newOutputBuffers, bvm);
- }
-
- return finalizeBufferAllocation(b, op, newInputBuffers, newOutputBuffers,
- bvm);
-}
-
-/// Constants that return tensor types can be handled natively by the
-/// backends. Here just provide a cast to memref to bridge the gap from tensors
-/// to memrefs.
-static LogicalResult convertConstantOp(OpBuilder &b, ConstantOp constantOp,
- BlockAndValueMapping &bvm) {
- Value result = constantOp.getResult();
- RankedTensorType tensorType = result.getType().dyn_cast<RankedTensorType>();
- if (!tensorType) return success();
- OpBuilder::InsertionGuard g(b);
- b.setInsertionPointAfter(constantOp);
- auto memrefType = getMemrefTypeForTensor(tensorType);
- Value memref =
- b.create<memref::BufferCastOp>(constantOp.getLoc(), memrefType, result);
- if (Value resultBuffer = bvm.lookupOrNull(result)) {
- // Since this is already remapped to a buffer, copy the data. Note that
- // constant ops are typicaly placed at the beginning of the block; we need
- // to make sure to insert the new copy op after the result buffer, which can
- // be after the constant op.
- b.setInsertionPointAfterValue(resultBuffer);
- b.create<linalg::CopyOp>(constantOp.getLoc(), memref, resultBuffer);
- } else {
- bvm.map(result, memref);
- }
- return success();
-}
-
-/// Converts a linalg.init_tensor op to memref.alloc op. This provides a shaped
-/// operand for pooling ops. The op will be deleted after going to loops.
-static LogicalResult convertInitTensorOp(
- OpBuilder &b, WorkgroupMemoryAllocationFn allocationFn,
- linalg::InitTensorOp initTensorOp, BlockAndValueMapping &bvm) {
- if (bvm.contains(initTensorOp.getResult())) return success();
- RankedTensorType tensorType = initTensorOp.getType();
- OpBuilder::InsertionGuard g(b);
- b.setInsertionPointAfter(initTensorOp);
- Value alloc = allocationFn(b, initTensorOp.getLoc(), tensorType.getShape(),
- tensorType.getElementType(),
- llvm::to_vector<4>(initTensorOp.sizes()));
- bvm.map(initTensorOp.getResult(), alloc);
- return success();
-}
-
/// Walks the use-def chain and see if this value comes from a read-only tensor.
static bool isFromReadOnlyTensor(Value v) {
auto definingOp = v.getDefiningOp();
@@ -283,224 +109,446 @@
.Default([&](Operation *op) { return false; });
}
-/// Avoids creating an allocation if the result tensor can just be aliased to
-/// use the same buffer (`inputBuffer`) that `srcTensor` is mapped to. This can
-/// be done if `srcTensor` has a single use, which is the operation which is
-/// being converted to buffers.
-/// Note that the mapping for `srcTensor` need not be mapped to `inputBuffer`
-/// directly. It could also be mapped to an alias of the `inputBuffer.
-static LogicalResult createAliasingBufferOrAllocationForResult(
- OpBuilder &b, Location loc, WorkgroupMemoryAllocationFn allocationFn,
- Value srcTensor, Value inputBuffer, Value resultTensor,
- ArrayRef<Value> allocationDynamicDims, BlockAndValueMapping &bvm) {
- // Case 1 : If result tensor is already mapped to a buffer just copy the
- // value.
- if (Value outputBuffer = bvm.lookupOrNull(resultTensor)) {
- if (inputBuffer != outputBuffer) {
- b.create<linalg::CopyOp>(loc, inputBuffer, outputBuffer);
+/// Class that tracks the equivalence relationship between tensors. It is a
+/// lightweight wrapper around `llvm::EquivalenceClasses` to account for
+/// `Value` not being directly supported as a value type by that class.
+class BufferizationPlan {
+ public:
+ llvm::EquivalenceClasses<void *>::iterator findValue(Value v) {
+ return mappedTensors.findValue(getPointer(v));
+ }
+
+ llvm::EquivalenceClasses<void *>::iterator end() {
+ return mappedTensors.end();
+ }
+
+ SmallVector<Value> getTensorsMappedToSameSet(Value v) {
+ SmallVector<Value> tensors;
+ for (auto it = mappedTensors.findLeader(getPointer(v)),
+ ie = mappedTensors.member_end();
+ it != ie; ++it) {
+ tensors.push_back(getValue(*it));
}
- return success();
+ return tensors;
}
- // Case 2: If the input tensor has only one use (this operation) or is from a
- // read-only tensor, then no need to create a copy either.
- if (srcTensor.hasOneUse() || isFromReadOnlyTensor(srcTensor)) {
- bvm.map(resultTensor, inputBuffer);
- return success();
+
+ bool isEquivalent(Value v1, Value v2) {
+ return mappedTensors.isEquivalent(getPointer(v1), getPointer(v2));
}
- // Fallback is to create an allocation and copy the output.
- MemRefType inputBufferType = inputBuffer.getType().cast<MemRefType>();
- assert(allocationDynamicDims.size() ==
- static_cast<size_t>(inputBufferType.getRank()));
- Value alloc = allocationFn(
- b, loc, SmallVector<int64_t, 4>(inputBufferType.getRank(), -1),
- inputBufferType.getElementType(), allocationDynamicDims);
- b.create<linalg::CopyOp>(loc, inputBuffer, alloc);
- bvm.map(resultTensor, alloc);
+
+ void insert(Value v) { mappedTensors.insert(getPointer(v)); }
+
+ void unionSets(Value v1, Value v2) {
+ mappedTensors.unionSets(getPointer(v1), getPointer(v2));
+ }
+
+  /// Marks the equivalence class that contains `v` as a set that contains the
+  /// result tensor of the dispatch region (i.e. a tensor that is the `value`
+  /// operand of a `flow.dispatch.tensor.store` op). All operations in this
+ /// equivalence class can use the result buffer of the dispatch region to
+ /// compute their values in place.
+ void storeSet(Value v) { storeLeaders.insert(getLeaderValue(v)); }
+
+ /// Queries if the value `v` is in the same equivalence class as the result of
+ /// the dispatch region.
+ bool isInStoreSet(Value v) { return storeLeaders.count(getLeaderValue(v)); }
+
+ void dump() {
+ llvm::dbgs() << "BufferMappings : \n";
+ unsigned numSets = 0;
+ for (auto it = mappedTensors.begin(), ie = mappedTensors.end(); it != ie;
+ ++it) {
+ if (!it->isLeader()) continue;
+ llvm::dbgs() << "\tSet " << numSets << ":\n";
+ for (auto member : llvm::make_range(mappedTensors.member_begin(it),
+ mappedTensors.member_end())) {
+ llvm::dbgs() << "\t\t";
+ getValue(member).print(llvm::dbgs());
+ llvm::dbgs() << "\n";
+ }
+ numSets++;
+ }
+ }
+
+ private:
+ Value getLeaderValue(Value v1) {
+ return getValue(mappedTensors.getLeaderValue(getPointer(v1)));
+ }
+
+ void *getPointer(Value v) { return v.getAsOpaquePointer(); }
+
+ Value getValue(void *v) { return Value::getFromOpaquePointer(v); }
+
+ llvm::EquivalenceClasses<void *> mappedTensors;
+
+ /// Leaders of the sets that contain the result tensor of the dispatch
+ /// region, i.e. a tensor that is the `value` operand of a
+  /// `flow.dispatch.tensor.store` op.
+ llvm::DenseSet<Value> storeLeaders;
+};
+
+/// Adds the result of `std.constant` to its set (there is nothing to tie to
+/// here).
+static LogicalResult analyseConstantOp(ConstantOp constantOp,
+ BufferizationPlan &plan) {
+ if (!constantOp.getResult().getType().isa<ShapedType>()) return success();
+ plan.insert(constantOp.getResult());
return success();
}
-/// Converts a `linalg.tensor_reshape` operation to a `linalg.reshape`
-/// operation.
-static LogicalResult convertTensorReshapeOp(
- OpBuilder &b, WorkgroupMemoryAllocationFn allocationFn,
- linalg::TensorReshapeOp op, BlockAndValueMapping &bvm) {
- OpBuilder::InsertionGuard g(b);
- b.setInsertionPoint(op);
- Location loc = op.getLoc();
- Value srcTensor = op.src();
- RankedTensorType srcTensorType = op.getSrcType();
- Value resultTensor = op.result();
- RankedTensorType resultTensorType = op.getResultType();
- Value inputBuffer = bvm.lookup(srcTensor);
- MemRefType inputBufferType = inputBuffer.getType().cast<MemRefType>();
- // Create the reshape op.
- auto reshapeSrcType = getMemrefTypeForTensor(
- srcTensorType, {}, inputBufferType.getMemorySpaceAsInt());
- Value reshapeSrc =
- b.createOrFold<memref::CastOp>(loc, inputBuffer, reshapeSrcType);
- auto reshapeResultType = getMemrefTypeForTensor(
- resultTensorType, {}, inputBufferType.getMemorySpaceAsInt());
- Value bufferReshape = b.create<linalg::ReshapeOp>(
- loc, reshapeResultType, reshapeSrc, op.reassociation());
- SmallVector<SmallVector<Value>> reshapeResultShape;
- if (failed(op.reifyReturnTypeShapesPerResultDim(b, reshapeResultShape)) ||
- reshapeResultShape.size() != 1) {
- return op.emitError("failed to get shape of result");
- }
- return createAliasingBufferOrAllocationForResult(
- b, loc, allocationFn, srcTensor, bufferReshape, resultTensor,
- reshapeResultShape[0], bvm);
-}
-
-static SmallVector<int64_t, 4> extractFromI64ArrayAttr(ArrayAttr attr) {
- return llvm::to_vector<4>(llvm::map_range(attr, [](Attribute a) -> int64_t {
- return a.cast<IntegerAttr>().getInt();
- }));
-}
-
-/// Converts a `subtensor` operation to a `subview` operation.
-static LogicalResult convertSubTensorOp(
- OpBuilder &b, WorkgroupMemoryAllocationFn allocationFn, SubTensorOp op,
- BlockAndValueMapping &bvm) {
- OpBuilder::InsertionGuard g(b);
- b.setInsertionPoint(op);
- Location loc = op.getLoc();
- Value srcTensor = op.source();
- Value resultTensor = op.result();
- Value inputBuffer = bvm.lookup(srcTensor);
- MemRefType inputBufferType = inputBuffer.getType().cast<MemRefType>();
-
- auto subViewResultType = memref::SubViewOp::inferResultType(
- inputBufferType, extractFromI64ArrayAttr(op.static_offsets()),
- extractFromI64ArrayAttr(op.static_sizes()),
- extractFromI64ArrayAttr(op.static_strides()));
- auto subViewOp = b.create<memref::SubViewOp>(
- loc, subViewResultType, inputBuffer, op.offsets(), op.sizes(),
- op.strides(), op.static_offsets(), op.static_sizes(),
- op.static_strides());
- auto allocationDynamicSizes = llvm::to_vector<4>(
- llvm::map_range(subViewOp.getOrCreateRanges(b, loc), [](Range range) {
- assert(matchPattern(range.stride, m_One()) &&
- "unhandled non-unit stride");
- return range.size;
- }));
- return createAliasingBufferOrAllocationForResult(
- b, loc, allocationFn, srcTensor, subViewOp, resultTensor,
- allocationDynamicSizes, bvm);
-}
-
-/// Converts a `subtensor_insert` operation to buffers by
-/// - Allocating a buffer for the result (if needed), and copying the
-/// destination value into this buffer.
-/// - Copying the source values into a subview of the result buffer.
-static LogicalResult convertSubTensorInsertOp(
- OpBuilder &b, WorkgroupMemoryAllocationFn allocationFn,
- SubTensorInsertOp op, BlockAndValueMapping &bvm) {
- OpBuilder::InsertionGuard g(b);
- b.setInsertionPoint(op);
- Location loc = op.getLoc();
- Value dest = op.dest();
- Value inputBuffer = bvm.lookup(dest);
- SmallVector<Value> allocationDynamicSizes;
- int64_t rank = inputBuffer.getType().cast<ShapedType>().getRank();
- for (auto dim : llvm::seq<int64_t>(0, rank)) {
- allocationDynamicSizes.push_back(
- b.createOrFold<memref::DimOp>(loc, inputBuffer, dim));
- }
- if (failed(createAliasingBufferOrAllocationForResult(
- b, loc, allocationFn, dest, inputBuffer, op.getResult(),
- allocationDynamicSizes, bvm))) {
- return failure();
- }
-
- Value source = op.source();
- Value outputBuffer = bvm.lookup(op.result());
- Value sourceBuffer = bvm.lookup(source);
- auto subViewOp = createSubviewOp(b, loc, outputBuffer, op.getMixedOffsets(),
- op.getMixedSizes(), op.getMixedStrides());
- b.create<linalg::CopyOp>(loc, sourceBuffer, subViewOp);
+/// Adds the result of the `flow.dispatch.tensor.load` op to the same
+/// equivalence class as the source.
+static LogicalResult analyseInterfaceLoadTensorOp(
+ IREE::Flow::DispatchTensorLoadOp loadOp, BufferizationPlan &plan) {
+ plan.unionSets(loadOp.result(), loadOp.source());
return success();
}
-/// Converts a `tensor.extract` operation into a `load`.
-static LogicalResult convertTensorExtractOp(OpBuilder &b, tensor::ExtractOp op,
- BlockAndValueMapping &bvm) {
- OpBuilder::InsertionGuard g(b);
- b.setInsertionPoint(op);
- Value inputBuffer = bvm.lookup(op.tensor());
- Value load =
- b.createOrFold<memref::LoadOp>(op.getLoc(), inputBuffer, op.indices());
- bvm.map(op.result(), load);
+/// Helper method to return an operation of type `OpType` whose result is in
+/// the same equivalence set as `value`. Returns an operation if there is only
+/// one such op in the equivalence set or nullptr in all other cases.
+template <typename OpType>
+static OpType getEquivalentOpOfType(Value value, BufferizationPlan &plan) {
+ OpType equivalentOp;
+ SmallVector<Value> mappedTensors = plan.getTensorsMappedToSameSet(value);
+ for (auto v : mappedTensors) {
+ auto definingOp = v.getDefiningOp<OpType>();
+ if (!definingOp) continue;
+ assert((!equivalentOp || equivalentOp == definingOp) &&
+ "found two interface binding ops marked as equivalent");
+ if (!equivalentOp) equivalentOp = definingOp;
+ }
+ return equivalentOp;
+}
+
+/// Returns true if the value and target of a `flow.dispatch.tensor.store`
+/// operation can be added to the same equivalence set. This can be done only if
+/// - The `value` is not from an equivalence set that contains a read-only
+/// tensor.
+/// - All `hal.interface.binding.subspan` operations in the equivalence class of
+/// `value` and `target` have the same binding and offset. For now, it is
+/// assumed that the equivalence classes contain only 1 such instruction.
+/// This method asserts that the `target` equivalence class already contains a
+/// `hal.interface.binding.subspan` op.
+static bool canSetStoreValueAndTargetAsEquivalent(
+ IREE::Flow::DispatchTensorStoreOp storeOp, BufferizationPlan &plan) {
+ Value value = storeOp.value();
+ Value target = storeOp.target();
+ auto targetInterfaceOp =
+ getEquivalentOpOfType<IREE::HAL::InterfaceBindingSubspanOp>(target, plan);
+ assert(targetInterfaceOp);
+ if (auto valueConstantOp = getEquivalentOpOfType<ConstantOp>(value, plan)) {
+ return false;
+ }
+ if (auto valueInterfaceOp =
+ getEquivalentOpOfType<IREE::HAL::InterfaceBindingSubspanOp>(value,
+ plan)) {
+ if (targetInterfaceOp.binding() != valueInterfaceOp.binding() ||
+ targetInterfaceOp.byte_offset() != valueInterfaceOp.byte_offset()) {
+ // If the binding and offsets are different, map these to different
+ // memrefs.
+ return false;
+ }
+ // If the binding and offsets are the same, make sure that the
+ // !flow.dispatch.tensor is read-write.
+ auto sourceType =
+ valueInterfaceOp.getType().dyn_cast<IREE::Flow::DispatchTensorType>();
+ return sourceType &&
+ sourceType.getAccess() == IREE::Flow::TensorAccess::ReadWrite;
+ }
+ return true;
+}
+
+/// Tries to add the `value` and `target` to the same equivalence class.
+static LogicalResult analyseInterfaceStoreTensorOp(
+ IREE::Flow::DispatchTensorStoreOp storeOp, BufferizationPlan &plan) {
+  // The value and target can be unioned if the set the value is part of does
+ // not contain any hal.interface.binding.subspan from a different binding.
+ Value value = storeOp.value();
+ Value target = storeOp.target();
+ if (!getEquivalentOpOfType<IREE::HAL::InterfaceBindingSubspanOp>(target,
+ plan)) {
+ return storeOp.emitError(
+ "expected target of store op to already be added to an equivalence "
+ "set");
+ }
+ if (canSetStoreValueAndTargetAsEquivalent(storeOp, plan)) {
+ plan.unionSets(value, target);
+ } else {
+ plan.insert(value);
+ }
+ plan.storeSet(target);
return success();
}
-static LogicalResult convertTransferOp(OpBuilder &b,
- WorkgroupMemoryAllocationFn allocationFn,
- VectorTransferOpInterface op,
- BlockAndValueMapping &bvm) {
- if (op.getShapedType().isa<MemRefType>()) return failure();
- assert(op->getNumResults() == 1);
- Value outputTensor = op->getResult(0);
- OpBuilder::InsertionGuard g(b);
- b.setInsertionPoint(op);
- Location loc = op.getLoc();
- Value newInputBuffer = bvm.lookup(op.source());
- if (auto tensorType =
- op->getResult(0).getType().dyn_cast<RankedTensorType>()) {
- // If the op return a Tensor allocate a buffer for the returned value.
- auto tensorShape = tensorType.getShape();
- SmallVector<Value, 4> dynOperands;
- for (size_t idx : llvm::seq(size_t(0), tensorShape.size())) {
- if (tensorType.isDynamicDim(idx)) {
- Value tensor = bvm.lookupOrNull(outputTensor);
- if (!tensor) tensor = outputTensor;
- dynOperands.push_back(b.createOrFold<memref::DimOp>(loc, tensor, idx));
+static LogicalResult analyseInterfaceBindingSubspanOp(
+ IREE::HAL::InterfaceBindingSubspanOp subspanOp, BufferizationPlan &plan) {
+ plan.insert(subspanOp.getResult());
+ return success();
+}
+
+/// For every result of the LinalgOp, gets the operands (`ins` or `outs`) whose
+/// buffer can be reused for the result.
+static SmallVector<Value> getTiedOperandsForLinalgOps(
+ linalg::LinalgOp linalgOp) {
+ SmallVector<Value> tiedOperands(linalgOp.getOperation()->getNumResults());
+ for (auto outTensor : llvm::enumerate(linalgOp.getOutputs())) {
+ if (linalgOp.payloadUsesValueFromOutputOperandIndex(outTensor.index())) {
+ // If the `outs` tensor has a single use (this op) and is not from a
+ // read-only buffer, the `outs` tensor can be tied to the result.
+ if (outTensor.value().hasOneUse() &&
+ !isFromReadOnlyTensor(outTensor.value())) {
+ tiedOperands[outTensor.index()] = outTensor.value();
}
}
- auto alloc = allocationFn(b, loc, tensorShape, tensorType.getElementType(),
- dynOperands);
- bvm.map(op->getResult(0), alloc);
- transferShapeOpsToMemref(b, op->getResult(0), alloc, bvm);
}
+ for (auto result : llvm::enumerate(linalgOp.getOutputs())) {
+ // If the output tensor is not actually used (for initialization) by this
+ // op, we can reuse the result tensor's buffer for some operands.
+ // TODO(#5040): A better way to handle this case is to allocate a buffer and
+ // then vectorization + load-store forwarding to remove the intermediate
+ // buffer. This requires vectorization to handle all cases downstream. This
+ // is a WAR for current use cases.
+ if (linalgOp.payloadUsesValueFromOutputOperandIndex(result.index())) {
+ continue;
+ }
+ for (auto input : llvm::enumerate(linalgOp.getInputTensors())) {
+ auto producerOp = input.value().getDefiningOp<linalg::LinalgOp>();
+ if (producerOp && input.value().hasOneUse() &&
+ input.value().getType() == result.value().getType() &&
+ linalgOp.getInputIndexingMap(input.index()) ==
+ linalgOp.getOutputIndexingMap(result.index())) {
+ assert(!tiedOperands[result.index()]);
+ tiedOperands[result.index()] = input.value();
+ break;
+ }
+ }
+ }
+ return tiedOperands;
+}
- // Replace the tensor operand.
- if (auto readOp = dyn_cast<vector::TransferReadOp>(op.getOperation())) {
- readOp.sourceMutable().assign(newInputBuffer);
+/// Adds the corresponding `outs` and result tensors of the linalg op into the
+/// same equivalence class.
+static LogicalResult analyseLinalgOps(linalg::LinalgOp linalgOp,
+ BufferizationPlan &plan) {
+ if (!linalgOp.hasTensorSemantics()) return success();
+ auto tiedOperands = getTiedOperandsForLinalgOps(linalgOp);
+ for (auto it :
+ llvm::enumerate(llvm::zip(linalgOp->getResults(), tiedOperands))) {
+ Value resultTensor = std::get<0>(it.value());
+ Value tiedOperand = std::get<1>(it.value());
+ if (tiedOperand) {
+ plan.unionSets(resultTensor, tiedOperand);
+ }
+ plan.insert(linalgOp.getOutput(it.index()));
+ plan.insert(resultTensor);
+ }
+ return success();
+}
+
+/// For operations that have a single operand and result, adds both to the same
+/// equivalence class.
+static LogicalResult analyseSingleOperandResultOp(Value source, Value result,
+ BufferizationPlan &plan) {
+ if (source.hasOneUse() || isFromReadOnlyTensor(source)) {
+ plan.unionSets(source, result);
+ return success();
+ }
+ plan.insert(source);
+ plan.insert(result);
+ return success();
+}
+
+/// Adds the `dest` and `result` tensor of a subtensor insert operation into the
+/// same equivalence class. If `source` is not null, also checks that the
+/// `source` and `dest` are not equivalent.
+static LogicalResult analyseDestructiveUpdateOp(Operation *op, Value source,
+ Value dest, Value result,
+ BufferizationPlan &plan) {
+ if (dest.hasOneUse() && !isFromReadOnlyTensor(dest)) {
+ plan.unionSets(dest, result);
+ }
+ if (source && plan.isEquivalent(source, dest)) {
+ return op->emitError(
+ "unexpected source and dest being mapped to same buffer");
+ }
+ plan.insert(dest);
+ plan.insert(result);
+ return success();
+}
+
+static LogicalResult analyseOperations(FuncOp funcOp, BufferizationPlan &plan) {
+ auto bufferMappingFn = [&](Operation *op) -> WalkResult {
+ return TypeSwitch<Operation *, LogicalResult>(op)
+ .Case<ConstantOp>([&](ConstantOp constantOp) {
+ return analyseConstantOp(constantOp, plan);
+ })
+ .Case<IREE::Flow::DispatchTensorLoadOp>(
+ [&](IREE::Flow::DispatchTensorLoadOp loadOp) {
+ return analyseInterfaceLoadTensorOp(loadOp, plan);
+ })
+ .Case<IREE::Flow::DispatchTensorStoreOp>(
+ [&](IREE::Flow::DispatchTensorStoreOp storeOp) {
+ return analyseInterfaceStoreTensorOp(storeOp, plan);
+ })
+ .Case<IREE::Flow::DispatchTieShapeOp>(
+ [&](IREE::Flow::DispatchTieShapeOp tieShapeOp) {
+ return analyseSingleOperandResultOp(tieShapeOp.operand(),
+ tieShapeOp.result(), plan);
+ })
+ .Case<IREE::HAL::InterfaceBindingSubspanOp>(
+ [&](IREE::HAL::InterfaceBindingSubspanOp subspanOp) {
+ return analyseInterfaceBindingSubspanOp(subspanOp, plan);
+ })
+ .Case<linalg::LinalgOp>([&](linalg::LinalgOp linalgOp) {
+ return analyseLinalgOps(linalgOp, plan);
+ })
+ .Case<linalg::TensorReshapeOp>(
+ [&](linalg::TensorReshapeOp tensorReshapeOp) {
+ return analyseSingleOperandResultOp(
+ tensorReshapeOp.src(), tensorReshapeOp.result(), plan);
+ })
+ .Case<SubTensorOp>([&](SubTensorOp subTensorOp) {
+ return analyseSingleOperandResultOp(subTensorOp.source(),
+ subTensorOp.result(), plan);
+ })
+ .Case<SubTensorInsertOp>([&](SubTensorInsertOp subTensorInsertOp) {
+ return analyseDestructiveUpdateOp(
+ subTensorInsertOp, subTensorInsertOp.source(),
+ subTensorInsertOp.dest(), subTensorInsertOp.result(), plan);
+ })
+ .Case<tensor::CastOp>([&](tensor::CastOp castOp) {
+ return analyseSingleOperandResultOp(castOp.source(), castOp.dest(),
+ plan);
+ })
+ .Case<vector::TransferReadOp>(
+ [&](vector::TransferReadOp transferReadOp) {
+ plan.insert(transferReadOp.source());
+ return success();
+ })
+ .Case<vector::TransferWriteOp>(
+ [&](vector::TransferWriteOp transferWriteOp) {
+ return analyseDestructiveUpdateOp(transferWriteOp, nullptr,
+ transferWriteOp.source(),
+ transferWriteOp.result(), plan);
+ })
+ .Default([&](Operation *op) { return success(); });
+ };
+ if (funcOp.walk(bufferMappingFn).wasInterrupted()) {
+ return failure();
+ }
+ DEBUG_WITH_TYPE(DEBUG_TYPE, plan.dump());
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Bufferization helper functions using BlockAndValueMapping.
+//===----------------------------------------------------------------------===//
+
+/// Returns the dynamic dimensions of a Value `v` that is assumed to be
+/// ShapedType.
+static SmallVector<Value, 4> getDynamicDims(OpBuilder &b, Location loc,
+ Value v) {
+ SmallVector<Value, 4> dynamicDims;
+ for (auto shape : enumerate(v.getType().cast<ShapedType>().getShape())) {
+ if (shape.value() == ShapedType::kDynamicSize) {
+ dynamicDims.push_back(
+ b.createOrFold<memref::DimOp>(loc, v, shape.index()));
+ }
+ }
+ return dynamicDims;
+}
+
+/// Allocates a memref for the result of an operation. Uses the
+/// `InferShapedTypeOpInterface` where possible to get the shape of the output
+/// in terms of the shapes of the operands.
+static Value allocateBufferForResult(OpBuilder &b, Operation *op,
+ WorkgroupMemoryAllocationFn allocationFn) {
+ assert(op->getNumResults() == 1);
+ RankedTensorType resultType =
+ op->getResult(0).getType().cast<RankedTensorType>();
+ SmallVector<Value, 4> dynamicDims;
+
+ // Get the shape of the result
+ Location loc = op->getLoc();
+ if (auto shapedOp = dyn_cast<InferShapedTypeOpInterface>(op)) {
+ SmallVector<SmallVector<Value>> resultShape;
+ if (failed(shapedOp.reifyReturnTypeShapesPerResultDim(b, resultShape))) {
+ return nullptr;
+ }
+ for (auto shape : enumerate(resultShape[0])) {
+ if (resultType.isDynamicDim(shape.index())) {
+ dynamicDims.push_back(shape.value());
+ }
+ }
+ } else if (auto subTensorOp = dyn_cast<SubTensorOp>(op)) {
+ dynamicDims = llvm::to_vector<4>(subTensorOp.sizes());
+ } else if (auto subTensorInsertOp = dyn_cast<SubTensorInsertOp>(op)) {
+ dynamicDims = getDynamicDims(b, loc, subTensorInsertOp.dest());
+ } else if (auto transferWriteOp = dyn_cast<vector::TransferWriteOp>(op)) {
+ dynamicDims = getDynamicDims(b, loc, transferWriteOp.source());
} else {
- auto writeOp = cast<vector::TransferWriteOp>(op.getOperation());
- // Create a new transfer_write on buffer that doesn't have a return value.
- // Leave the previous transfer_write to dead code as it still has uses at
- // this point.
- b.create<vector::TransferWriteOp>(
- loc, writeOp.vector(), newInputBuffer, writeOp.indices(),
- writeOp.permutation_map(), writeOp.mask(),
- writeOp.in_bounds() ? *writeOp.in_bounds() : ArrayAttr());
+ return nullptr;
}
- return success();
+ return allocationFn(b, loc, resultType.getShape(),
+ resultType.getElementType(), dynamicDims);
}
-// Extract int64_t values from the assumed ArrayAttr of IntegerAttr.
-static SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr) {
- return llvm::to_vector<4>(llvm::map_range(
- attr.cast<ArrayAttr>(),
- [](Attribute a) -> int64_t { return a.cast<IntegerAttr>().getInt(); }));
+template <typename TensorType>
+static MemRefType getMemrefTypeForTensor(TensorType tensorType,
+ ArrayRef<AffineMap> layout = {},
+ unsigned memorySpace = 0) {
+ return MemRefType::get(tensorType.getShape(), tensorType.getElementType(),
+ layout, memorySpace);
}
-LogicalResult convertInterfaceLoadTensorOp(
- OpBuilder &b, IREE::Flow::DispatchTensorLoadOp loadOp,
- BlockAndValueMapping &bvm) {
- OpBuilder::InsertionGuard g(b);
- b.setInsertionPoint(loadOp);
- Location loc = loadOp.getLoc();
- Value memref = bvm.lookup(loadOp.source());
- Value res = createSubviewOp(b, loc, memref, loadOp.getMixedOffsets(),
- loadOp.getMixedSizes(), loadOp.getMixedStrides());
-
- bvm.map(loadOp.result(), res);
- transferShapeOpsToMemref(b, loadOp.result(), res, bvm);
- return success();
+/// Creates a subview operation given the `src`, `offsets`, `sizes` and
+/// `strides`. Handles the corner case where `offsets`, `sizes` and `strides`
+/// are all empty, in which case the `src` value is forwarded unchanged.
+/// TODO(ataei): Instead create memref.subview %v [][][] folder.
+static Value createSubviewOp(OpBuilder &b, Location loc, Value src,
+ ArrayRef<OpFoldResult> offsets,
+ ArrayRef<OpFoldResult> sizes,
+ ArrayRef<OpFoldResult> strides,
+ MemRefType resultType = MemRefType()) {
+ if (offsets.empty() && sizes.empty() && strides.empty()) return src;
+ return b.create<memref::SubViewOp>(loc, resultType, src, offsets, sizes,
+ strides);
}
+//===----------------------------------------------------------------------===//
+// There might be cases where the `value` stored by a
+// `flow.dispatch.tensor.store` operation is obtained from the operation that
+// computes it (say a `linalg` operation) through a series of `reshape`,
+// `cast`, etc. operations. To reuse the buffer for the result passed in to
+// the dispatch region for such a chain, these operations need to be
+// "replayed" in reverse so that the type of the buffer seen by the operation
+// computing the value matches what it expects.
+//
+// For example,
+// ```mlir
+// %buffer = hal.interface.binding.subspan .. : tensor<?xf32>
+// %result = linalg.matmul ins(%lhs, %rhs : tensor<?x?xf32>, tensor<?x?xf32>)
+// outs(%init : tensor<?x?xf32>) -> tensor<?x?xf32>
+// %value = linalg.tensor_reshape %result [affine_map<(d0, d1) -> (d0, d1)>]
+// : tensor<?x?xf32> into tensor<?xf32>
+// flow.dispatch.tensor.store %value, %buffer[..] [..] [..]
+// ```
+//
+// needs to be converted to
+//
+// ```mlir
+// %buffer = hal.interface.binding.subspan .. : memref<?xf32>
+// %result = subview %buffer[..] [..] [..] : memref<?xf32>
+// %value = linalg.reshape %result [affine_map<(d0, d1) -> (d0, d1)>]
+// : memref<?xf32> into memref<?x?xf32>
+// linalg.matmul ins(%lhs, %rhs : memref<?x?xf32>, memref<?x?xf32>)
+// outs(%result : memref<?x?xf32>)
+// flow.dispatch.tensor.store %value, %buffer[..] [..] [..]
+// ```
+//
+//===----------------------------------------------------------------------===//
+
/// For a given store-like `op` that is to be replaced, find the insertion point
/// in the same block earliest possible when
/// - the replacement op uses values in `usedValues`, so has to be inserted
@@ -526,13 +574,12 @@
return nullptr;
}
-/// For cases where the value operand of the `storeOp` is produced by a
-/// LinalgOp, create the subview operation that can be used by the op itself to
-/// store the result into directly. This avoids an extra allocation + copies.
-LogicalResult preProcessInterfaceStoreTensorOp(
- OpBuilder &b, IREE::Flow::DispatchTensorStoreOp storeOp,
- BlockAndValueMapping &bvm) {
- // Find the insertion point for the subview.
+/// Returns the subview into the buffer that is supposed to be populated with
+/// the `value` of the `flow.dispatch.tensor.store` operation. This can be used
+/// to compute the results in place.
+static Value getSubviewOpForTensorStoreOp(
+ OpBuilder &b, Operation *insertBefore,
+ IREE::Flow::DispatchTensorStoreOp storeOp, BlockAndValueMapping &bvm) {
SmallVector<Value, 4> operandsOfSubviewOp;
operandsOfSubviewOp.push_back(bvm.lookup(storeOp.target()));
operandsOfSubviewOp.append(storeOp.offsets().begin(),
@@ -541,80 +588,338 @@
operandsOfSubviewOp.append(storeOp.strides().begin(),
storeOp.strides().end());
Operation *insertionPoint = getInsertionPointForReplacementStoreOp(
- storeOp.getOperation(), storeOp.value().getDefiningOp(),
- operandsOfSubviewOp);
- if (!insertionPoint) return success();
+ storeOp.getOperation(), insertBefore, operandsOfSubviewOp);
+ if (!insertionPoint) return nullptr;
OpBuilder::InsertionGuard g(b);
- b.setInsertionPoint(insertionPoint);
Value subview =
createSubviewOp(b, storeOp.getLoc(), bvm.lookup(storeOp.target()),
storeOp.getMixedOffsets(), storeOp.getMixedSizes(),
storeOp.getMixedStrides());
- bvm.map(storeOp.value(), subview);
+ return subview;
+}
+
+/// Gets the reverse of a `linalg.tensor_reshape` op to get a memref type that
+/// can be used for in-place computation of the result of a dispatch region.
+static Value getReverseOfReshapeOp(OpBuilder &b,
+ linalg::TensorReshapeOp reshapeOp,
+ Value resultBuffer) {
+ auto memrefType = getMemrefTypeForTensor(
+ reshapeOp.getSrcType(), {},
+ resultBuffer.getType().cast<MemRefType>().getMemorySpaceAsInt());
+ return b.create<linalg::ReshapeOp>(reshapeOp.getLoc(), memrefType,
+ resultBuffer, reshapeOp.reassociation());
+}
+
+/// Gets the reverse of a `tensor.cast` op to get a memref type that
+/// can be used for in-place computation of the result of a dispatch region.
+static Value getReverseOfCastOp(OpBuilder &b, tensor::CastOp castOp,
+ Value resultBuffer) {
+ auto memrefType = getMemrefTypeForTensor(
+ castOp.source().getType().cast<RankedTensorType>(),
+ resultBuffer.getType().cast<MemRefType>().getAffineMaps(),
+ resultBuffer.getType().cast<MemRefType>().getMemorySpaceAsInt());
+ return b.create<memref::CastOp>(castOp.getLoc(), memrefType, resultBuffer);
+}
+
+/// For an operation whose `resultValue` is the result of the dispatch region,
+/// returns the buffer to use to compute the value in place.
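+///
+/// For example (a sketch; names are illustrative), for a chain
+///   linalg.matmul -> linalg.tensor_reshape -> flow.dispatch.tensor.store
+/// the subview of the store target is created first and then "reshaped back"
+/// so that the matmul can write into it directly.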
+static Value getInplaceResultBuffer(OpBuilder &b, OpResult resultValue,
+ BlockAndValueMapping &bvm) {
+ Operation *currOp = resultValue.getOwner();
+ SmallVector<Operation *> traversedOps;
+
+ // Traverse the use-def chains to get the `flow.dispatch.tensor.store`
+ // operation keeping track of all the traversed operations. Note that the
+ // equivalence set construction should ensure that all operations traversed
+ // here have a single use.
+ while (!isa<IREE::Flow::DispatchTensorStoreOp>(currOp)) {
+ traversedOps.push_back(currOp);
+ if (!currOp->hasOneUse() || currOp->getNumResults() != 1) return nullptr;
+ currOp = *currOp->user_begin();
+ }
+ auto storeOp = dyn_cast<IREE::Flow::DispatchTensorStoreOp>(currOp);
+ if (!storeOp) return nullptr;
+ Operation *insertBefore = &(*b.getInsertionPoint());
+ Value resultBuffer =
+ getSubviewOpForTensorStoreOp(b, insertBefore, storeOp, bvm);
+ if (!resultBuffer) return nullptr;
+ DEBUG_WITH_TYPE(DEBUG_TYPE, {
+ llvm::dbgs() << "Pair :\n\tTensor :";
+ currOp->print(llvm::dbgs());
+    llvm::dbgs() << "\n\tMemref :";
+ resultBuffer.print(llvm::dbgs());
+ llvm::dbgs() << "\n";
+ });
+
+ // Now replay the instructions that are essentially doing type-conversion, in
+ // reverse, to get the type needed for the operation computing the value.
+ for (auto op : traversedOps) {
+ resultBuffer =
+ TypeSwitch<Operation *, Value>(op)
+ .Case<linalg::LinalgOp, SubTensorInsertOp, vector::TransferWriteOp>(
+ [&](auto op) { return resultBuffer; })
+ .Case<linalg::TensorReshapeOp>(
+ [&](linalg::TensorReshapeOp reshapeOp) {
+ return getReverseOfReshapeOp(b, reshapeOp, resultBuffer);
+ })
+ .Case<tensor::CastOp>([&](tensor::CastOp castOp) {
+ return getReverseOfCastOp(b, castOp, resultBuffer);
+ })
+ .Default([&](Operation *) { return nullptr; });
+ if (!resultBuffer) return nullptr;
+ bvm.map(op->getResult(0), resultBuffer);
+ DEBUG_WITH_TYPE(DEBUG_TYPE, {
+ llvm::dbgs() << "Pair :\n\tTensor :";
+ op->print(llvm::dbgs());
+      llvm::dbgs() << "\n\tMemref :";
+ resultBuffer.print(llvm::dbgs());
+ llvm::dbgs() << "\n";
+ });
+ }
+ return resultBuffer;
+}
+
+/// Converts a `tensor.cast` operation into a `memref.cast` operation with the
+/// result aliasing the buffer for the operand.
+static Value getAliasingBufferForCastResult(OpBuilder &b, tensor::CastOp castOp,
+ BlockAndValueMapping &bvm) {
+ Value inputBuffer = bvm.lookup(castOp.source());
+ Value resultTensor = castOp.dest();
+ auto outputType = getMemrefTypeForTensor(
+ resultTensor.getType().cast<RankedTensorType>(), {},
+ inputBuffer.getType().cast<MemRefType>().getMemorySpaceAsInt());
+ return b.create<memref::CastOp>(castOp.getLoc(), outputType, inputBuffer);
+}
+
+/// Converts a `linalg.tensor_reshape` operation to a `linalg.reshape`
+/// operation with the result aliasing the buffer for the operand.
+static Value getAliasingBufferForReshapeResult(OpBuilder &b,
+ linalg::TensorReshapeOp op,
+ BlockAndValueMapping &bvm) {
+ Location loc = op.getLoc();
+ Value srcTensor = op.src();
+ RankedTensorType resultTensorType = op.getResultType();
+ Value inputBuffer = bvm.lookup(srcTensor);
+
+ // Create the reshape op.
+ MemRefType inputBufferType = inputBuffer.getType().cast<MemRefType>();
+ auto reshapeResultType = getMemrefTypeForTensor(
+ resultTensorType, {}, inputBufferType.getMemorySpaceAsInt());
+ Value bufferReshape = b.create<linalg::ReshapeOp>(
+ loc, reshapeResultType, inputBuffer, op.reassociation());
+ return bufferReshape;
+}
+
+/// Converts a `subtensor` operation to a `subview` operation.
+static Value getAliasingBufferForSubtensorResult(OpBuilder &b, SubTensorOp op,
+ BlockAndValueMapping &bvm) {
+ Location loc = op.getLoc();
+ Value srcTensor = op.source();
+ Value inputBuffer = bvm.lookup(srcTensor);
+
+ ShapedType sourceType = op.getSourceType();
+ ShapedType resultType = op.getType();
+ SmallVector<OpFoldResult> offsets = op.getMixedOffsets();
+ SmallVector<OpFoldResult> sizes = op.getMixedSizes();
+ SmallVector<OpFoldResult> strides = op.getMixedStrides();
+ MemRefType subViewResultType =
+ (resultType.getRank() < sourceType.getRank()
+ ? memref::SubViewOp::inferRankReducedResultType(
+ resultType.getRank(), inputBuffer.getType().cast<MemRefType>(),
+ offsets, sizes, strides)
+ .cast<MemRefType>()
+ : MemRefType());
+ return b.create<memref::SubViewOp>(loc, subViewResultType, inputBuffer,
+ offsets, sizes, strides);
+}
+
+/// Computes the `memrefs` to use for the results of an operation as follows:
+/// - If the result has a tied operand, reuse the buffer for the tied operand
+///   (or an alias of it) as the buffer for the result. The `tiedOperands`
+///   vector is expected to be as large as the number of results.
+/// - If the result has no tied operand, the corresponding position in the
+///   `tiedOperands` list must be `nullptr`.
+/// - If the result is in the same equivalence set as the result of the
+///   dispatch region (i.e. the `value` operand of a
+///   `flow.dispatch.tensor.store`), return an alias/view of the buffer passed
+///   into the dispatch region to store the results.
+/// - Lastly, allocate a temporary buffer for the result using the passed
+/// allocation function.
+static LogicalResult getOrAllocateResultBuffers(
+ OpBuilder &b, Operation *op, ArrayRef<Value> tiedOperands,
+ BlockAndValueMapping &bvm, BufferizationPlan &plan,
+ WorkgroupMemoryAllocationFn allocationFn) {
+ for (auto result : llvm::enumerate(op->getResults())) {
+ if (bvm.contains(result.value())) continue;
+ Value buffer;
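+    // When the result is equivalent to its tied operand, alias that operand's
+    // buffer (or a view of it) instead of allocating a new one.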
+ if (tiedOperands[result.index()] &&
+ plan.isEquivalent(tiedOperands[result.index()], result.value())) {
+ buffer =
+ TypeSwitch<Operation *, Value>(op)
+ .Case<linalg::TensorReshapeOp>(
+ [&](linalg::TensorReshapeOp reshapeOp) {
+ return getAliasingBufferForReshapeResult(b, reshapeOp, bvm);
+ })
+ .Case<SubTensorInsertOp>(
+ [&](SubTensorInsertOp subTensorInsertOp) {
+ return bvm.lookupOrNull(subTensorInsertOp.dest());
+ })
+ .Case<SubTensorOp>([&](SubTensorOp subTensorOp) {
+ return getAliasingBufferForSubtensorResult(b, subTensorOp, bvm);
+ })
+ .Case<tensor::CastOp>([&](tensor::CastOp castOp) {
+ return getAliasingBufferForCastResult(b, castOp, bvm);
+ })
+ .Case<linalg::LinalgOp>([&](linalg::LinalgOp linalgOp) {
+ return bvm.lookupOrNull(linalgOp.getOutput(result.index()));
+ })
+ .Default([&](Operation *op) { return nullptr; });
+ }
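+    // Results feeding the dispatch region output can reuse a view of the
+    // buffer bound to the dispatch; otherwise fall back to a fresh allocation.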
+ if (!buffer && plan.isInStoreSet(result.value())) {
+ buffer = getInplaceResultBuffer(b, result.value(), bvm);
+ }
+ if (!buffer) {
+ buffer = allocateBufferForResult(b, op, allocationFn);
+ }
+ if (!buffer) {
+ return op->emitError("unable to get result buffer for op");
+ }
+ bvm.map(result.value(), buffer);
+ DEBUG_WITH_TYPE(DEBUG_TYPE, {
+ llvm::dbgs() << "Pair :\n\tTensor :";
+ op->print(llvm::dbgs());
+      llvm::dbgs() << "\n\tMemref :";
+ buffer.print(llvm::dbgs());
+ llvm::dbgs() << "\n";
+ });
+ }
return success();
}
-/// Pre process linalg operations (on tensors) to propagate buffer assignment
-/// from results to operands wherever possible.
-LogicalResult preProcessLinalgOps(OpBuilder &b, linalg::LinalgOp op,
- BlockAndValueMapping &bvm) {
- if (!op.hasTensorSemantics()) return success();
+/// Generic conversion pattern that matches any linalg::LinalgOp. This avoids
+/// instantiating one templated pattern per linalg::LinalgOp. The method
+/// expects that all operands and results have already been mapped to memrefs.
+static LogicalResult convertAnyLinalgOp(
+ OpBuilder &b, linalg::LinalgOp op, BlockAndValueMapping &bvm,
+ BufferizationPlan &plan, WorkgroupMemoryAllocationFn allocationFn) {
+ // Skip linalg ops inserted by this pass.
+ if (op.hasBufferSemantics()) return success();
- for (auto en :
- llvm::zip(op.getOperation()->getResults(), op.getOutputTensors())) {
- Value resultTensor = std::get<0>(en);
- Value outTensor = std::get<1>(en);
- unsigned resultIndex = resultTensor.cast<OpResult>().getResultNumber();
- Value resultBuffer = bvm.lookupOrNull(resultTensor);
-
- // If the result is mapped to a buffer, the corresponding output tensor can
- // be mapped to the same buffer to make this an inplace update.
- if (resultBuffer && outTensor.hasOneUse()) {
- bvm.map(outTensor, resultBuffer);
+ Location loc = op.getLoc();
+ SmallVector<Value, 2> newInputBuffers;
+ newInputBuffers.reserve(op.getNumInputs());
+ for (Value v : op.getInputs()) {
+    // For `linalg.pooling_*` ops, the input might come from a
+    // `linalg.init_tensor`. In such cases, the `BlockAndValueMapping` won't
+    // have a mapping for the buffer. Allocate a buffer for these.
+ Value inputBuffer = bvm.lookupOrNull(v);
+ if (!inputBuffer) {
+ inputBuffer = allocateBufferForResult(b, v.getDefiningOp(), allocationFn);
}
+ newInputBuffers.push_back(inputBuffer);
+ }
+ SmallVector<Value, 2> newOutputBuffers;
+ for (auto it : llvm::enumerate(
+ llvm::zip(op.getOperation()->getResults(), op.getOutputs()))) {
+ Value resultTensor = std::get<0>(it.value());
+ Value resultBuffer = bvm.lookup(resultTensor);
- // If the output tensor is not actually used (for initialization) by this
- // op, we can reuse the result tensor's buffer for some operands.
- if (!op.payloadUsesValueFromOutputOperandIndex(resultIndex)) {
- for (auto en : llvm::enumerate(op.getInputTensors())) {
- Value operand = en.value();
- auto producerOp = operand.getDefiningOp<linalg::LinalgOp>();
- if (producerOp && operand.hasOneUse() &&
- operand.getType() == resultTensor.getType() &&
- op.getInputIndexingMap(en.index()) ==
- op.getOutputIndexingMap(resultIndex)) {
- bvm.map(operand, resultBuffer);
- break;
- }
- }
+ Value outTensor = std::get<1>(it.value());
+ Value outBuffer = bvm.lookupOrNull(outTensor);
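+    // If the payload actually reads the `outs` value and its buffer differs
+    // from the result buffer, copy the initial contents over first.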
+ if (outBuffer && !plan.isEquivalent(outTensor, resultTensor) &&
+ op.payloadUsesValueFromOutputOperandIndex(it.index())) {
+ b.create<linalg::CopyOp>(loc, outBuffer, resultBuffer);
}
+ newOutputBuffers.push_back(resultBuffer);
}
+ SmallVector<Value, 8> newOperands(newInputBuffers.begin(),
+ newInputBuffers.end());
+ newOperands.append(newOutputBuffers.begin(), newOutputBuffers.end());
+ auto otherOperands =
+ llvm::map_range(op.getAssumedNonShapedOperands(),
+ [&bvm](Value v) { return bvm.lookupOrDefault(v); });
+ newOperands.append(otherOperands.begin(), otherOperands.end());
+ op.clone(b, loc, {}, newOperands);
return success();
}
-// Check if the buffer being copied from and being stored to are the same. If so
-// this copy is unnecessary since the output has been updated in place.
-bool isRedundantCopy(Value storeTo, Value storeFrom) {
- if (storeTo == storeFrom) return true;
- auto storeFromOp = storeFrom.getDefiningOp<memref::SubViewOp>();
- return storeFromOp && storeFromOp.source() == storeTo;
+/// Constants that return tensor types can be handled natively by the
+/// backends. Here we just insert a cast to memref to bridge the gap from
+/// tensors to memrefs.
+static LogicalResult convertConstantOp(OpBuilder &b, ConstantOp constantOp,
+ BlockAndValueMapping &bvm) {
+ Value result = constantOp.getResult();
+ assert(!bvm.lookupOrNull(result));
+ RankedTensorType tensorType = result.getType().dyn_cast<RankedTensorType>();
+ if (!tensorType) return success();
+ OpBuilder::InsertionGuard g(b);
+ b.setInsertionPointAfter(constantOp);
+ auto memrefType = getMemrefTypeForTensor(tensorType);
+ Value memref =
+ b.create<memref::BufferCastOp>(constantOp.getLoc(), memrefType, result);
+ bvm.map(result, memref);
+ return success();
}
-LogicalResult convertInterfaceStoreTensorOp(
- OpBuilder &b, IREE::Flow::DispatchTensorStoreOp storeOp,
+static LogicalResult convertDimOp(OpBuilder &b, memref::DimOp dimOp,
+ BlockAndValueMapping &bvm) {
+ if (Value v = bvm.lookupOrNull(dimOp.memrefOrTensor())) {
+ dimOp.memrefOrTensorMutable().assign(v);
+ }
+ return success();
+}
+
+static LogicalResult convertDispatchTieShapeOp(
+ OpBuilder &b, IREE::Flow::DispatchTieShapeOp shapeOp,
BlockAndValueMapping &bvm) {
+ if (Value v = bvm.lookupOrNull(shapeOp.operand())) {
+ auto tieShapeOp = b.create<Shape::TieShapeOp>(shapeOp.getLoc(), v.getType(),
+ v, shapeOp.shape());
+ bvm.map(shapeOp.getResult(), tieShapeOp.getResult());
+ }
+ return success();
+}
+
+/// Converts a `tensor.extract` operation into a `load`.
+static LogicalResult convertTensorExtractOp(OpBuilder &b, tensor::ExtractOp op,
+ BlockAndValueMapping &bvm) {
+ OpBuilder::InsertionGuard g(b);
+ b.setInsertionPoint(op);
+ Value inputBuffer = bvm.lookup(op.tensor());
+ Value load =
+ b.createOrFold<memref::LoadOp>(op.getLoc(), inputBuffer, op.indices());
+ bvm.map(op.result(), load);
+ return success();
+}
+
+static LogicalResult convertInterfaceLoadTensorOp(
+ OpBuilder &b, IREE::Flow::DispatchTensorLoadOp loadOp,
+ BlockAndValueMapping &bvm) {
+ OpBuilder::InsertionGuard g(b);
+ b.setInsertionPoint(loadOp);
+ Location loc = loadOp.getLoc();
+ Value memref = bvm.lookup(loadOp.source());
+ Value res = createSubviewOp(b, loc, memref, loadOp.getMixedOffsets(),
+ loadOp.getMixedSizes(), loadOp.getMixedStrides());
+ bvm.map(loadOp.result(), res);
+ return success();
+}
+
+/// Converts a `flow.dispatch.tensor.store` operation to memrefs. If the
+/// `value` and `target` are in the same equivalence set, there is nothing to
+/// do. If not, create a subview into the result buffer and copy the `value`.
+static LogicalResult convertInterfaceStoreTensorOp(
+ OpBuilder &b, IREE::Flow::DispatchTensorStoreOp storeOp,
+ BlockAndValueMapping &bvm, BufferizationPlan &plan) {
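+  // If the producer already wrote the result into the target buffer in place,
+  // the store is a no-op and can be erased.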
+ if (plan.isEquivalent(storeOp.target(), storeOp.value())) {
+ storeOp->erase();
+ return success();
+ }
OpBuilder::InsertionGuard g(b);
b.setInsertionPoint(storeOp);
Value storeTo = bvm.lookup(storeOp.target());
Value storeFrom = bvm.lookup(storeOp.value());
- // If the value already has a mapping, it should already have been updated in
- // place by the converted producer.
- if (isRedundantCopy(storeTo, storeFrom)) {
- storeOp->erase();
- return success();
- }
-
Value subview =
createSubviewOp(b, storeOp.getLoc(), storeTo, storeOp.getMixedOffsets(),
storeOp.getMixedSizes(), storeOp.getMixedStrides());
@@ -624,19 +929,99 @@
return success();
}
-// Forwards buffer assigned to cast inputs to its outputs.
-LogicalResult convertTensorCastOp(OpBuilder &b,
- WorkgroupMemoryAllocationFn allocationFn,
- tensor::CastOp castOp,
- BlockAndValueMapping &bvm) {
- Value inputBuffer = bvm.lookup(castOp.source());
- // Note: tensor.cast isn't suppose to do any data-movements, so we should
- // never need to allocate and copy data to the result tensor.
- bvm.map(castOp.dest(), inputBuffer);
+/// Converts a `subtensor_insert` operation to buffers by
+/// - Allocating a buffer for the result (if needed), and copying the
+/// destination value into this buffer.
+/// - Copying the source values into a subview of the result buffer.
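+///
+/// For example (a sketch; SSA names are illustrative):
+///   %r = subtensor_insert %src into %dest[%o] [%sz] [%st]
+/// becomes, when `dest` and the result do not share a buffer:
+///   linalg.copy(%dest_buffer, %result_buffer)
+///   %sv = memref.subview %result_buffer[%o] [%sz] [%st]
+///   linalg.copy(%src_buffer, %sv)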
+static LogicalResult convertSubTensorInsertOp(OpBuilder &b,
+ SubTensorInsertOp op,
+ BlockAndValueMapping &bvm,
+ BufferizationPlan &plan) {
+ Location loc = op.getLoc();
+ Value result = op.getResult();
+ ShapedType resultType = op.getType();
+ Value resultBuffer = bvm.lookup(result);
+
+  // If `dest` and `result` are not equivalent, a copy is needed.
+ if (!plan.isEquivalent(op.dest(), result)) {
+ Value destBuffer = bvm.lookup(op.dest());
+ b.create<linalg::CopyOp>(loc, destBuffer, resultBuffer);
+ }
+
+ // Copy from the source to the result subview.
+ Value source = op.source();
+ ShapedType sourceType = op.getSourceType();
+ Value sourceBuffer = bvm.lookup(source);
+ SmallVector<OpFoldResult> offsets = op.getMixedOffsets();
+ SmallVector<OpFoldResult> sizes = op.getMixedSizes();
+ SmallVector<OpFoldResult> strides = op.getMixedStrides();
+ MemRefType subViewResultType =
+ (sourceType.getRank() < resultType.getRank()
+ ? memref::SubViewOp::inferRankReducedResultType(
+ sourceType.getRank(),
+ resultBuffer.getType().cast<MemRefType>(), offsets, sizes,
+ strides)
+ .cast<MemRefType>()
+ : MemRefType());
+ Value subViewOp = createSubviewOp(b, loc, resultBuffer, offsets, sizes,
+ strides, subViewResultType);
+ b.create<linalg::CopyOp>(loc, sourceBuffer, subViewOp);
return success();
}
+/// Converts a `vector.transfer_read` op to use a memref operand for its
+/// source.
+static LogicalResult convertVectorTransferReadOp(
+ OpBuilder &b, vector::TransferReadOp transferReadOp,
+ BlockAndValueMapping &bvm) {
+ Value source = transferReadOp.source();
+ if (!source.getType().isa<RankedTensorType>()) return success();
+ Value memref = bvm.lookup(source);
+ transferReadOp.sourceMutable().assign(memref);
+ return success();
+}
+
+/// Converts a `vector.transfer_write` op to write into the buffer assigned
+/// to its result.
+static LogicalResult convertVectorTransferWriteOp(OpBuilder &b,
+ vector::TransferWriteOp op,
+ BlockAndValueMapping &bvm,
+ BufferizationPlan &plan) {
+ Location loc = op.getLoc();
+ Value result = op.result();
+ RankedTensorType resultType = result.getType().dyn_cast<RankedTensorType>();
+ if (!resultType) return success();
+ Value resultBuffer = bvm.lookup(result);
+
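+  // Initialize the result buffer from the source buffer when the two do not
+  // alias, so the partial write below lands on the right initial contents.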
+ if (!plan.isEquivalent(op.source(), result)) {
+ Value destBuffer = bvm.lookup(op.source());
+ b.create<linalg::CopyOp>(loc, destBuffer, resultBuffer);
+ }
+
+ // Create a new vector.transfer_write operation without a result value.
+ b.create<vector::TransferWriteOp>(
+ loc, op.vector(), resultBuffer, op.indices(), op.permutation_map(),
+ op.mask(), op.in_bounds() ? *op.in_bounds() : ArrayAttr());
+ return success();
+}
+
+/// If the alias of the buffer for an input operand cannot be used for the
+/// "tied" result, an explicit copy of the memory pointed to by the aliased
+/// buffer into the buffer assigned to the result is needed.
+static void copyFromAliasingBufferToResultBuffer(OpBuilder &b, Location loc,
+ ArrayRef<Value> tiedOperands,
+ ArrayRef<Value> tiedResults,
+ BlockAndValueMapping &bvm,
+ BufferizationPlan &plan) {
+ for (auto result : enumerate(tiedResults)) {
+ Value operand = tiedOperands[result.index()];
+ if (!plan.isEquivalent(result.value(), operand)) {
+ b.create<linalg::CopyOp>(loc, bvm.lookup(operand),
+ bvm.lookup(result.value()));
+ }
+ }
+}
+
namespace {
+/// Pass to convert from tensor-based ops to memref-based ops.
class LinalgBufferizePass
: public PassWrapper<LinalgBufferizePass, FunctionPass> {
public:
@@ -652,26 +1037,20 @@
};
} // namespace
-// Special handling of dynamic sizes that must tie to InterfaceBindingSubspanOp.
-// This is necessary to propagate the InterfaceLoadConstantOp to memrefs.
-// In tensor world, the information is carried by TieShape ops.
-// TODO(ravishankarm): This needs to be moved to MaterializeInterface pass so
-// that here we dont need to deal with tie-shape ops.
-static Shape::MakeRankedShapeOp getMakeRankedShapeFromInterface(
- IREE::HAL::InterfaceBindingSubspanOp op) {
- for (Operation *user : op->getUsers()) {
- auto tieOp = dyn_cast<IREE::Flow::DispatchTieShapeOp>(user);
- if (!tieOp) continue;
- auto makeRankedShapeOp =
- tieOp.shape().getDefiningOp<Shape::MakeRankedShapeOp>();
- assert(makeRankedShapeOp);
- return makeRankedShapeOp;
- }
- llvm_unreachable("Expected IREE::Flow::DispatchTieShapeOp of op");
-}
-
void LinalgBufferizePass::runOnFunction() {
+ BufferizationPlan plan;
FuncOp funcOp = getFunction();
+ if (failed(analyseOperations(funcOp, plan))) {
+ return signalPassFailure();
+ }
+ if (funcOp
+ .walk([&](IREE::Flow::DispatchTensorStoreOp storeOp) -> WalkResult {
+ return analyseInterfaceStoreTensorOp(storeOp, plan);
+ })
+ .wasInterrupted()) {
+ return signalPassFailure();
+ }
+
MLIRContext *context = &getContext();
OpBuilder b(context);
@@ -689,89 +1068,78 @@
// the base buffer.
auto tensorType =
op.result().getType().cast<IREE::Flow::DispatchTensorType>();
- auto memRefType =
- MemRefType::get(tensorType.getShape(), tensorType.getElementType());
+ auto memRefType = getMemrefTypeForTensor(tensorType);
auto baseBuffer = b.create<IREE::HAL::InterfaceBindingSubspanOp>(
op->getLoc(), memRefType, op.binding(), op.byte_offset(),
op.byte_length());
bvm.map(op, baseBuffer);
- transferShapeOpsToMemref(b, op.getResult(), baseBuffer.getResult(), bvm);
});
- if (funcOp
- .walk([&](IREE::Flow::DispatchTensorStoreOp op) -> WalkResult {
- return preProcessInterfaceStoreTensorOp(b, op, bvm);
- })
- .wasInterrupted()) {
- return signalPassFailure();
- }
-
- // Walk backward and forward buffers assigned to tensor.cast results to their
- // inputs.
- SmallVector<tensor::CastOp> castOps;
- funcOp.walk([&castOps](tensor::CastOp castOp) { castOps.push_back(castOp); });
- for (tensor::CastOp castOp : llvm::reverse(castOps)) {
- auto outBuffer = bvm.lookup(castOp.dest());
- if (outBuffer) {
- bvm.map(castOp.source(), outBuffer);
- }
- }
-
- /// Walk the linalg operations backwards (if they are all in the same basic
- /// block) to propagate buffer usage backwards to reduce the need for
- /// allocation. This works for simple cases where all the linalg operations
- /// are within the same basic block. Fallback is to create a separate
- /// allocation for the output.
- {
- SmallVector<linalg::LinalgOp, 4> linalgOps;
- SmallVector<Operation *, 4> tiledLoops;
- if (succeeded(getLinalgOps(funcOp, linalgOps, tiledLoops))) {
- for (linalg::LinalgOp op : llvm::reverse(linalgOps)) {
- if (failed(preProcessLinalgOps(b, op, bvm))) {
- return signalPassFailure();
- }
- }
- }
- }
-
auto conversionDispatch = [&](Operation *op) -> WalkResult {
return TypeSwitch<Operation *, LogicalResult>(op)
.Case<ConstantOp>([&](ConstantOp constantOp) {
return convertConstantOp(b, constantOp, bvm);
})
+ .Case<memref::DimOp>(
+ [&](memref::DimOp dimOp) { return convertDimOp(b, dimOp, bvm); })
.Case<IREE::Flow::DispatchTensorLoadOp>(
[&](IREE::Flow::DispatchTensorLoadOp loadOp) {
return convertInterfaceLoadTensorOp(b, loadOp, bvm);
})
.Case<IREE::Flow::DispatchTensorStoreOp>(
[&](IREE::Flow::DispatchTensorStoreOp storeOp) {
- return convertInterfaceStoreTensorOp(b, storeOp, bvm);
+ return convertInterfaceStoreTensorOp(b, storeOp, bvm, plan);
})
- .Case<tensor::CastOp>([&](tensor::CastOp castOp) {
- return convertTensorCastOp(b, allocationFn, castOp, bvm);
- })
+ .Case<IREE::Flow::DispatchTieShapeOp>(
+ [&](IREE::Flow::DispatchTieShapeOp shapeOp) {
+ return convertDispatchTieShapeOp(b, shapeOp, bvm);
+ })
+ .Case<linalg::TensorReshapeOp, tensor::CastOp, SubTensorOp>(
+ [&](auto aliasingOp) {
+ if (failed(getOrAllocateResultBuffers(b, aliasingOp,
+ aliasingOp->getOperand(0),
+ bvm, plan, allocationFn))) {
+ return failure();
+ }
+ copyFromAliasingBufferToResultBuffer(
+ b, aliasingOp->getLoc(), aliasingOp->getOperand(0),
+ aliasingOp->getResult(0), bvm, plan);
+ return success();
+ })
.Case<linalg::LinalgOp>([&](linalg::LinalgOp linalgOp) {
- return convertAnyLinalgOp(b, allocationFn, linalgOp, bvm);
+ SmallVector<Value> tiedOperands =
+ getTiedOperandsForLinalgOps(linalgOp);
+ if (failed(getOrAllocateResultBuffers(b, linalgOp.getOperation(),
+ tiedOperands, bvm, plan,
+ allocationFn))) {
+ return failure();
+ }
+ return convertAnyLinalgOp(b, linalgOp, bvm, plan, allocationFn);
})
.Case<SubTensorInsertOp>([&](SubTensorInsertOp subTensorInsertOp) {
- return convertSubTensorInsertOp(b, allocationFn, subTensorInsertOp,
- bvm);
- })
- .Case<SubTensorOp>([&](SubTensorOp subTensorOp) {
- return convertSubTensorOp(b, allocationFn, subTensorOp, bvm);
- })
- .Case<linalg::TensorReshapeOp>([&](linalg::TensorReshapeOp reshapeOp) {
- return convertTensorReshapeOp(b, allocationFn, reshapeOp, bvm);
- })
- .Case<linalg::InitTensorOp>([&](linalg::InitTensorOp initTensorOp) {
- return convertInitTensorOp(b, allocationFn, initTensorOp, bvm);
+ if (failed(getOrAllocateResultBuffers(b, subTensorInsertOp,
+ subTensorInsertOp.dest(), bvm,
+ plan, allocationFn))) {
+ return failure();
+ }
+ return convertSubTensorInsertOp(b, subTensorInsertOp, bvm, plan);
})
.Case<tensor::ExtractOp>([&](tensor::ExtractOp extractOp) {
return convertTensorExtractOp(b, extractOp, bvm);
})
- .Case<VectorTransferOpInterface>(
- [&](VectorTransferOpInterface vectorTransferOp) {
- return convertTransferOp(b, allocationFn, vectorTransferOp, bvm);
+ .Case<vector::TransferReadOp>(
+ [&](vector::TransferReadOp transferReadOp) {
+ return convertVectorTransferReadOp(b, transferReadOp, bvm);
+ })
+ .Case<vector::TransferWriteOp>(
+ [&](vector::TransferWriteOp transferWriteOp) {
+ if (failed(getOrAllocateResultBuffers(b, transferWriteOp,
+ transferWriteOp.source(),
+ bvm, plan, allocationFn))) {
+ return failure();
+ }
+ return convertVectorTransferWriteOp(b, transferWriteOp, bvm,
+ plan);
})
.Default([&](Operation *op) {
// Replace any scalar remapped operands to the new values.
@@ -787,7 +1155,12 @@
return success();
});
};
- if (funcOp.walk(conversionDispatch).wasInterrupted()) {
+
+ auto walkResult = funcOp.walk([&](Operation *op) -> WalkResult {
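+    // Anchor the builder at the op being converted so that replacement ops
+    // are created in place.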
+ b.setInsertionPoint(op);
+ return conversionDispatch(op);
+ });
+ if (walkResult.wasInterrupted()) {
return signalPassFailure();
}
}
diff --git a/iree/compiler/Conversion/Common/test/linalg_bufferize.mlir b/iree/compiler/Conversion/Common/test/linalg_bufferize.mlir
index 85fa706..d506e8b 100644
--- a/iree/compiler/Conversion/Common/test/linalg_bufferize.mlir
+++ b/iree/compiler/Conversion/Common/test/linalg_bufferize.mlir
@@ -10,7 +10,6 @@
%1 = hal.interface.binding.subspan @io::@TENSOR_RHS[%c0] : !flow.dispatch.tensor<readonly:?x?xf32>
%2 = hal.interface.binding.subspan @io::@TENSOR_INIT[%c0] : !flow.dispatch.tensor<readonly:?x?xf32>
%3 = hal.interface.binding.subspan @io::@ret0[%c0] : !flow.dispatch.tensor<writeonly:?x?xf32>
-
%4 = hal.interface.workgroup.id[0] : index
%5 = hal.interface.workgroup.id[1] : index
scf.for %arg0 = %5 to %c2 step %c2 {
@@ -49,6 +48,98 @@
// -----
+func @tile_from_tensor_load_inplace() {
+ %c0 = constant 0 : index
+ %c2 = constant 2 : index
+ %c4 = constant 4 : index
+ %c1 = constant 1 : index
+ %c3 = constant 3 : index
+ %0 = hal.interface.binding.subspan @io::@TENSOR_LHS[%c0] : !flow.dispatch.tensor<readonly:?x?xf32>
+ %1 = hal.interface.binding.subspan @io::@TENSOR_RHS[%c0] : !flow.dispatch.tensor<readonly:?x?xf32>
+ %2 = hal.interface.binding.subspan @io::@TENSOR_INIT[%c0] : !flow.dispatch.tensor<readwrite:?x?xf32>
+ %4 = hal.interface.workgroup.id[0] : index
+ %5 = hal.interface.workgroup.id[1] : index
+ scf.for %arg0 = %5 to %c2 step %c2 {
+ scf.for %arg1 = %4 to %c4 step %c4 {
+ %6 = flow.dispatch.tensor.load %0, offsets = [%arg0, %c0], sizes = [%c1, %c3], strides = [%c1, %c1] : !flow.dispatch.tensor<readonly:?x?xf32> -> tensor<1x3xf32>
+ %7 = flow.dispatch.tensor.load %1, offsets = [%c0, %arg1], sizes = [%c3, %c1], strides = [%c1, %c1] : !flow.dispatch.tensor<readonly:?x?xf32> -> tensor<3x1xf32>
+ %8 = flow.dispatch.tensor.load %2, offsets = [%arg0, %arg1], sizes = [%c1, %c1], strides = [%c1, %c1] : !flow.dispatch.tensor<readwrite:?x?xf32> -> tensor<1x1xf32>
+ %9 = linalg.matmul ins(%6, %7 : tensor<1x3xf32>, tensor<3x1xf32>) outs(%8 : tensor<1x1xf32>) -> tensor<1x1xf32>
+ flow.dispatch.tensor.store %9, %2, offsets = [%arg0, %arg1], sizes = [%c1, %c1], strides = [%c1, %c1] : tensor<1x1xf32> -> !flow.dispatch.tensor<readwrite:?x?xf32>
+ }
+ }
+ return
+}
+
+hal.interface @io attributes {sym_visibility = "private"} {
+ hal.interface.binding @TENSOR_LHS, set=0, binding=0, type="StorageBuffer", access="Read"
+ hal.interface.binding @TENSOR_RHS, set=0, binding=1, type="StorageBuffer", access="Read"
+ hal.interface.binding @TENSOR_INIT, set=0, binding=2, type="StorageBuffer", access="Read|Write"
+}
+// CHECK-LABEL: func @tile_from_tensor_load_inplace()
+// CHECK-DAG: %[[TENSOR_LHS:.+]] = hal.interface.binding.subspan @io::@TENSOR_LHS
+// CHECK-DAG: %[[TENSOR_RHS:.+]] = hal.interface.binding.subspan @io::@TENSOR_RHS
+// CHECK-DAG: %[[RETURN:.+]] = hal.interface.binding.subspan @io::@TENSOR_INIT
+// CHECK: scf.for %[[IV0:.+]] = {{.+}} {
+// CHECK: scf.for %[[IV1:.+]] = {{.+}} {
+// CHECK-DAG: %[[LHS:.+]] = memref.subview %[[TENSOR_LHS]][%[[IV0]], 0] [1, 3] [1, 1]
+// CHECK-DAG: %[[RHS:.+]] = memref.subview %[[TENSOR_RHS]][0, %[[IV1]]] [3, 1] [1, 1]
+// CHECK-DAG: %[[RESULT:.+]] = memref.subview %[[RETURN]][%[[IV0]], %[[IV1]]] [1, 1] [1, 1]
+// CHECK: linalg.matmul
+// CHECK-SAME: ins(%[[LHS]], %[[RHS]]
+// CHECK-SAME: outs(%[[RESULT]]
+
+// -----
+
+func @tile_from_tensor_load_inplace_and_copy() {
+ %c0 = constant 0 : index
+ %c2 = constant 2 : index
+ %c4 = constant 4 : index
+ %c1 = constant 1 : index
+ %c3 = constant 3 : index
+ %0 = hal.interface.binding.subspan @io::@TENSOR_LHS[%c0] : !flow.dispatch.tensor<readonly:?x?xf32>
+ %1 = hal.interface.binding.subspan @io::@TENSOR_RHS[%c0] : !flow.dispatch.tensor<readonly:?x?xf32>
+ %2 = hal.interface.binding.subspan @io::@TENSOR_INIT[%c0] : !flow.dispatch.tensor<readwrite:?x?xf32>
+ %3 = hal.interface.binding.subspan @io::@ret0[%c0] : !flow.dispatch.tensor<writeonly:?x?xf32>
+ %4 = hal.interface.workgroup.id[0] : index
+ %5 = hal.interface.workgroup.id[1] : index
+ scf.for %arg0 = %5 to %c2 step %c2 {
+ scf.for %arg1 = %4 to %c4 step %c4 {
+ %6 = flow.dispatch.tensor.load %0, offsets = [%arg0, %c0], sizes = [%c1, %c3], strides = [%c1, %c1] : !flow.dispatch.tensor<readonly:?x?xf32> -> tensor<1x3xf32>
+ %7 = flow.dispatch.tensor.load %1, offsets = [%c0, %arg1], sizes = [%c3, %c1], strides = [%c1, %c1] : !flow.dispatch.tensor<readonly:?x?xf32> -> tensor<3x1xf32>
+ %8 = flow.dispatch.tensor.load %2, offsets = [%arg0, %arg1], sizes = [%c1, %c1], strides = [%c1, %c1] : !flow.dispatch.tensor<readwrite:?x?xf32> -> tensor<1x1xf32>
+ %9 = linalg.matmul ins(%6, %7 : tensor<1x3xf32>, tensor<3x1xf32>) outs(%8 : tensor<1x1xf32>) -> tensor<1x1xf32>
+ flow.dispatch.tensor.store %9, %2, offsets = [%arg0, %arg1], sizes = [%c1, %c1], strides = [%c1, %c1] : tensor<1x1xf32> -> !flow.dispatch.tensor<readwrite:?x?xf32>
+ flow.dispatch.tensor.store %9, %3, offsets = [%arg0, %arg1], sizes = [%c1, %c1], strides = [%c1, %c1] : tensor<1x1xf32> -> !flow.dispatch.tensor<writeonly:?x?xf32>
+ }
+ }
+ return
+}
+
+hal.interface @io attributes {sym_visibility = "private"} {
+ hal.interface.binding @TENSOR_LHS, set=0, binding=0, type="StorageBuffer", access="Read"
+ hal.interface.binding @TENSOR_RHS, set=0, binding=1, type="StorageBuffer", access="Read"
+ hal.interface.binding @TENSOR_INIT, set=0, binding=2, type="StorageBuffer", access="Read|Write"
+ hal.interface.binding @ret0, set=0, binding=3, type="StorageBuffer", access="Write|Discard"
+}
+// CHECK-LABEL: func @tile_from_tensor_load_inplace_and_copy()
+// CHECK-DAG: %[[TENSOR_LHS:.+]] = hal.interface.binding.subspan @io::@TENSOR_LHS
+// CHECK-DAG: %[[TENSOR_RHS:.+]] = hal.interface.binding.subspan @io::@TENSOR_RHS
+// CHECK-DAG: %[[RETURN1:.+]] = hal.interface.binding.subspan @io::@TENSOR_INIT
+// CHECK-DAG: %[[RETURN2:.+]] = hal.interface.binding.subspan @io::@ret0
+// CHECK: scf.for %[[IV0:.+]] = {{.+}} {
+// CHECK: scf.for %[[IV1:.+]] = {{.+}} {
+// CHECK-DAG: %[[LHS:.+]] = memref.subview %[[TENSOR_LHS]][%[[IV0]], 0] [1, 3] [1, 1]
+// CHECK-DAG: %[[RHS:.+]] = memref.subview %[[TENSOR_RHS]][0, %[[IV1]]] [3, 1] [1, 1]
+// CHECK-DAG: %[[RESULT1:.+]] = memref.subview %[[RETURN1]][%[[IV0]], %[[IV1]]] [1, 1] [1, 1]
+// CHECK: linalg.matmul
+// CHECK-SAME: ins(%[[LHS]], %[[RHS]]
+// CHECK-SAME: outs(%[[RESULT1]]
+// CHECK: %[[RESULT2:.+]] = memref.subview %[[RETURN2]][%[[IV0]], %[[IV1]]] [1, 1] [1, 1]
+// CHECK: linalg.copy(%[[RESULT1]], %[[RESULT2]])
+
+// -----
+
#map = affine_map<(d0, d1) -> (d0, d1)>
func @tile_from_pointwise_lhs() {
%c0 = constant 0 : index
@@ -93,7 +184,6 @@
// CHECK-DAG: %[[RETURN:.+]] = hal.interface.binding.subspan @io::@ret0
// CHECK: scf.for %[[IV0:.+]] = {{.+}} {
// CHECK: scf.for %[[IV1:.+]] = {{.+}} {
-// CHECK-DAG: %[[RESULT:.+]] = memref.subview %[[RETURN]][%[[IV0]], %[[IV1]]] [1, 1] [1, 1]
// CHECK-DAG: %[[LHS:.+]] = memref.subview %[[TENSOR_LHS]][%[[IV0]], 0] [1, 3] [1, 1]
// CHECK-DAG: %[[RHS:.+]] = memref.subview %[[TENSOR_RHS]][0, %[[IV1]]] [3, 1] [1, 1]
// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<1x3xf32>
@@ -101,6 +191,7 @@
// CHECK-SAME: ins(%[[LHS]] :
// CHECK-SAME: outs(%[[ALLOC]]
// CHECK-DAG: %[[INIT:.+]] = memref.subview %[[TENSOR_INIT]][%[[IV0]], %[[IV1]]] [1, 1] [1, 1]
+// CHECK-DAG: %[[RESULT:.+]] = memref.subview %[[RETURN]][%[[IV0]], %[[IV1]]] [1, 1] [1, 1]
// CHECK: linalg.copy(%[[INIT]], %[[RESULT]])
// CHECK: linalg.matmul
// CHECK-SAME: ins(%[[ALLOC]], %[[RHS]]
@@ -109,6 +200,60 @@
// -----
#map = affine_map<(d0, d1) -> (d0, d1)>
+func @tile_from_pointwise_lhs_inplace() {
+ %c0 = constant 0 : index
+ %c2 = constant 2 : index
+ %c4 = constant 4 : index
+ %c1 = constant 1 : index
+ %c3 = constant 3 : index
+ %0 = hal.interface.binding.subspan @io::@TENSOR_LHS[%c0] : !flow.dispatch.tensor<readonly:?x?xf32>
+ %1 = hal.interface.binding.subspan @io::@TENSOR_RHS[%c0] : !flow.dispatch.tensor<readonly:?x?xf32>
+ %2 = hal.interface.binding.subspan @io::@TENSOR_INIT[%c0] : !flow.dispatch.tensor<readwrite:?x?xf32>
+ %4 = hal.interface.workgroup.id[0] : index
+ %5 = hal.interface.workgroup.id[1] : index
+ scf.for %arg0 = %5 to %c2 step %c2 {
+ scf.for %arg1 = %4 to %c4 step %c4 {
+ %6 = flow.dispatch.tensor.load %0, offsets = [%arg0, %c0], sizes = [%c1, %c3], strides = [%c1, %c1] : !flow.dispatch.tensor<readonly:?x?xf32> -> tensor<1x3xf32>
+ %7 = flow.dispatch.tensor.load %1, offsets = [%c0, %arg1], sizes = [%c3, %c1], strides = [%c1, %c1] : !flow.dispatch.tensor<readonly:?x?xf32> -> tensor<3x1xf32>
+ %shape = linalg.init_tensor [1, 3] : tensor<1x3xf32>
+ %8 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]}
+ ins(%6 : tensor<1x3xf32>) outs(%shape : tensor<1x3xf32>) {
+ ^bb0(%arg2: f32, %s: f32): // no predecessors
+ linalg.yield %arg2 : f32
+ } -> tensor<1x3xf32>
+ %9 = flow.dispatch.tensor.load %2, offsets = [%arg0, %arg1], sizes = [%c1, %c1], strides = [%c1, %c1] : !flow.dispatch.tensor<readwrite:?x?xf32> -> tensor<1x1xf32>
+ %10 = linalg.matmul ins(%8, %7 : tensor<1x3xf32>, tensor<3x1xf32>) outs(%9 : tensor<1x1xf32>) -> tensor<1x1xf32>
+ flow.dispatch.tensor.store %10, %2, offsets = [%arg0, %arg1], sizes = [%c1, %c1], strides = [%c1, %c1] : tensor<1x1xf32> -> !flow.dispatch.tensor<readwrite:?x?xf32>
+ }
+ }
+ return
+}
+
+hal.interface @io attributes {sym_visibility = "private"} {
+ hal.interface.binding @TENSOR_LHS, set=0, binding=0, type="StorageBuffer", access="Read"
+ hal.interface.binding @TENSOR_RHS, set=0, binding=1, type="StorageBuffer", access="Read"
+ hal.interface.binding @TENSOR_INIT, set=0, binding=2, type="StorageBuffer", access="Read|Write"
+}
+// CHECK-LABEL: func @tile_from_pointwise_lhs_inplace()
+// CHECK-DAG: %[[TENSOR_LHS:.+]] = hal.interface.binding.subspan @io::@TENSOR_LHS
+// CHECK-DAG: %[[TENSOR_RHS:.+]] = hal.interface.binding.subspan @io::@TENSOR_RHS
+// CHECK-DAG: %[[RETURN:.+]] = hal.interface.binding.subspan @io::@TENSOR_INIT
+// CHECK: scf.for %[[IV0:.+]] = {{.+}} {
+// CHECK: scf.for %[[IV1:.+]] = {{.+}} {
+// CHECK-DAG: %[[LHS:.+]] = memref.subview %[[TENSOR_LHS]][%[[IV0]], 0] [1, 3] [1, 1]
+// CHECK-DAG: %[[RHS:.+]] = memref.subview %[[TENSOR_RHS]][0, %[[IV1]]] [3, 1] [1, 1]
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<1x3xf32>
+// CHECK: linalg.generic
+// CHECK-SAME: ins(%[[LHS]] :
+// CHECK-SAME: outs(%[[ALLOC]]
+// CHECK-DAG: %[[RESULT:.+]] = memref.subview %[[RETURN]][%[[IV0]], %[[IV1]]] [1, 1] [1, 1]
+// CHECK: linalg.matmul
+// CHECK-SAME: ins(%[[ALLOC]], %[[RHS]]
+// CHECK-SAME: outs(%[[RESULT]]
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
func @tile_from_pointwise_outs() {
%c0 = constant 0 : index
%c2 = constant 2 : index
@@ -164,6 +309,154 @@
// -----
+#map = affine_map<(d0, d1) -> (d0, d1)>
+func @tile_from_pointwise_outs_inplace() {
+ %c0 = constant 0 : index
+ %c2 = constant 2 : index
+ %c4 = constant 4 : index
+ %c1 = constant 1 : index
+ %c3 = constant 3 : index
+ %0 = hal.interface.binding.subspan @io::@TENSOR_LHS[%c0] : !flow.dispatch.tensor<readonly:?x?xf32>
+ %1 = hal.interface.binding.subspan @io::@TENSOR_RHS[%c0] : !flow.dispatch.tensor<readonly:?x?xf32>
+ %2 = hal.interface.binding.subspan @io::@TENSOR_INIT[%c0] : !flow.dispatch.tensor<readwrite:?x?xf32>
+ %4 = hal.interface.workgroup.id[0] : index
+ %5 = hal.interface.workgroup.id[1] : index
+ scf.for %arg0 = %5 to %c2 step %c2 {
+ scf.for %arg1 = %4 to %c4 step %c4 {
+ %6 = flow.dispatch.tensor.load %0, offsets = [%arg0, %c0], sizes = [%c1, %c3], strides = [%c1, %c1] : !flow.dispatch.tensor<readonly:?x?xf32> -> tensor<1x3xf32>
+ %7 = flow.dispatch.tensor.load %1, offsets = [%c0, %arg1], sizes = [%c3, %c1], strides = [%c1, %c1] : !flow.dispatch.tensor<readonly:?x?xf32> -> tensor<3x1xf32>
+ %8 = flow.dispatch.tensor.load %2, offsets = [%arg0, %arg1], sizes = [%c1, %c1], strides = [%c1, %c1] : !flow.dispatch.tensor<readwrite:?x?xf32> -> tensor<1x1xf32>
+ %shape = linalg.init_tensor [1, 1] : tensor<1x1xf32>
+ %9 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]}
+ ins(%8 : tensor<1x1xf32>) outs(%shape : tensor<1x1xf32>) {
+ ^bb0(%arg2: f32, %s: f32): // no predecessors
+ linalg.yield %arg2 : f32
+ } -> tensor<1x1xf32>
+ %10 = linalg.matmul ins(%6, %7 : tensor<1x3xf32>, tensor<3x1xf32>) outs(%9 : tensor<1x1xf32>) -> tensor<1x1xf32>
+ flow.dispatch.tensor.store %10, %2, offsets = [%arg0, %arg1], sizes = [%c1, %c1], strides = [%c1, %c1] : tensor<1x1xf32> -> !flow.dispatch.tensor<readwrite:?x?xf32>
+ }
+ }
+ return
+}
+hal.interface @io attributes {sym_visibility = "private"} {
+ hal.interface.binding @TENSOR_LHS, set=0, binding=0, type="StorageBuffer", access="Read"
+ hal.interface.binding @TENSOR_RHS, set=0, binding=1, type="StorageBuffer", access="Read"
+ hal.interface.binding @TENSOR_INIT, set=0, binding=2, type="StorageBuffer", access="Read|Write"
+}
+// CHECK-LABEL: func @tile_from_pointwise_outs_inplace()
+// CHECK-DAG: %[[TENSOR_LHS:.+]] = hal.interface.binding.subspan @io::@TENSOR_LHS
+// CHECK-DAG: %[[TENSOR_RHS:.+]] = hal.interface.binding.subspan @io::@TENSOR_RHS
+// CHECK-DAG: %[[RETURN:.+]] = hal.interface.binding.subspan @io::@TENSOR_INIT
+// CHECK: scf.for %[[IV0:.+]] = {{.+}} {
+// CHECK: scf.for %[[IV1:.+]] = {{.+}} {
+// CHECK-DAG: %[[RESULT:.+]] = memref.subview %[[RETURN]][%[[IV0]], %[[IV1]]] [1, 1] [1, 1]
+// CHECK-DAG: %[[LHS:.+]] = memref.subview %[[TENSOR_LHS]][%[[IV0]], 0] [1, 3] [1, 1]
+// CHECK-DAG: %[[RHS:.+]] = memref.subview %[[TENSOR_RHS]][0, %[[IV1]]] [3, 1] [1, 1]
+// CHECK: linalg.generic
+// CHECK-SAME: outs(%[[RESULT]]
+// CHECK: linalg.matmul
+// CHECK-SAME: ins(%[[LHS]], %[[RHS]]
+// CHECK-SAME: outs(%[[RESULT]]
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+func @tile_from_matmul_outs() {
+ %c0 = constant 0 : index
+ %c2 = constant 2 : index
+ %c4 = constant 4 : index
+ %c1 = constant 1 : index
+ %c3 = constant 3 : index
+ %0 = hal.interface.binding.subspan @io::@TENSOR_LHS[%c0] : !flow.dispatch.tensor<readonly:?x?xf32>
+ %1 = hal.interface.binding.subspan @io::@TENSOR_RHS[%c0] : !flow.dispatch.tensor<readonly:?x?xf32>
+ %2 = hal.interface.binding.subspan @io::@TENSOR_INIT[%c0] : !flow.dispatch.tensor<readonly:?x?xf32>
+ %3 = hal.interface.binding.subspan @io::@ret0[%c0] : !flow.dispatch.tensor<writeonly:?x?xf32>
+ %4 = hal.interface.workgroup.id[0] : index
+ %5 = hal.interface.workgroup.id[1] : index
+ scf.for %arg0 = %5 to %c2 step %c2 {
+ scf.for %arg1 = %4 to %c4 step %c4 {
+ %6 = flow.dispatch.tensor.load %0, offsets = [%arg0, %c0], sizes = [%c1, %c3], strides = [%c1, %c1] : !flow.dispatch.tensor<readonly:?x?xf32> -> tensor<1x3xf32>
+ %7 = flow.dispatch.tensor.load %1, offsets = [%c0, %arg1], sizes = [%c3, %c1], strides = [%c1, %c1] : !flow.dispatch.tensor<readonly:?x?xf32> -> tensor<3x1xf32>
+ %8 = flow.dispatch.tensor.load %2, offsets = [%arg0, %arg1], sizes = [%c1, %c1], strides = [%c1, %c1] : !flow.dispatch.tensor<readonly:?x?xf32> -> tensor<1x1xf32>
+ %shape = linalg.init_tensor [1, 1] : tensor<1x1xf32>
+ %9 = linalg.matmul ins(%6, %7 : tensor<1x3xf32>, tensor<3x1xf32>) outs(%8 : tensor<1x1xf32>) -> tensor<1x1xf32>
+ %10 = linalg.matmul ins(%6, %7 : tensor<1x3xf32>, tensor<3x1xf32>) outs(%9 : tensor<1x1xf32>) -> tensor<1x1xf32>
+ flow.dispatch.tensor.store %10, %3, offsets = [%arg0, %arg1], sizes = [%c1, %c1], strides = [%c1, %c1] : tensor<1x1xf32> -> !flow.dispatch.tensor<writeonly:?x?xf32>
+ }
+ }
+ return
+}
+hal.interface @io attributes {sym_visibility = "private"} {
+ hal.interface.binding @TENSOR_LHS, set=0, binding=0, type="StorageBuffer", access="Read"
+ hal.interface.binding @TENSOR_RHS, set=0, binding=1, type="StorageBuffer", access="Read"
+ hal.interface.binding @TENSOR_INIT, set=0, binding=2, type="StorageBuffer", access="Read"
+ hal.interface.binding @ret0, set=0, binding=3, type="StorageBuffer", access="Write|Discard"
+}
+// CHECK-LABEL: func @tile_from_matmul_outs()
+// CHECK-DAG: %[[TENSOR_LHS:.+]] = hal.interface.binding.subspan @io::@TENSOR_LHS
+// CHECK-DAG: %[[TENSOR_RHS:.+]] = hal.interface.binding.subspan @io::@TENSOR_RHS
+// CHECK-DAG: %[[TENSOR_INIT:.+]] = hal.interface.binding.subspan @io::@TENSOR_INIT
+// CHECK-DAG: %[[RETURN:.+]] = hal.interface.binding.subspan @io::@ret0
+// CHECK: scf.for %[[IV0:.+]] = {{.+}} {
+// CHECK: scf.for %[[IV1:.+]] = {{.+}} {
+// CHECK-DAG: %[[LHS:.+]] = memref.subview %[[TENSOR_LHS]][%[[IV0]], 0] [1, 3] [1, 1]
+// CHECK-DAG: %[[RHS:.+]] = memref.subview %[[TENSOR_RHS]][0, %[[IV1]]] [3, 1] [1, 1]
+// CHECK-DAG: %[[INIT:.+]] = memref.subview %[[TENSOR_INIT]][%[[IV0]], %[[IV1]]] [1, 1] [1, 1]
+// CHECK-DAG: %[[RESULT:.+]] = memref.subview %[[RETURN]][%[[IV0]], %[[IV1]]] [1, 1] [1, 1]
+// CHECK: linalg.copy(%[[INIT]], %[[RESULT]])
+// CHECK: linalg.matmul
+// CHECK-SAME: outs(%[[RESULT]]
+// CHECK: linalg.matmul
+// CHECK-SAME: outs(%[[RESULT]]
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+func @tile_from_matmul_outs_inplace() {
+ %c0 = constant 0 : index
+ %c2 = constant 2 : index
+ %c4 = constant 4 : index
+ %c1 = constant 1 : index
+ %c3 = constant 3 : index
+ %0 = hal.interface.binding.subspan @io::@TENSOR_LHS[%c0] : !flow.dispatch.tensor<readonly:?x?xf32>
+ %1 = hal.interface.binding.subspan @io::@TENSOR_RHS[%c0] : !flow.dispatch.tensor<readonly:?x?xf32>
+ %2 = hal.interface.binding.subspan @io::@TENSOR_INIT[%c0] : !flow.dispatch.tensor<readwrite:?x?xf32>
+ %4 = hal.interface.workgroup.id[0] : index
+ %5 = hal.interface.workgroup.id[1] : index
+ scf.for %arg0 = %5 to %c2 step %c2 {
+ scf.for %arg1 = %4 to %c4 step %c4 {
+ %6 = flow.dispatch.tensor.load %0, offsets = [%arg0, %c0], sizes = [%c1, %c3], strides = [%c1, %c1] : !flow.dispatch.tensor<readonly:?x?xf32> -> tensor<1x3xf32>
+ %7 = flow.dispatch.tensor.load %1, offsets = [%c0, %arg1], sizes = [%c3, %c1], strides = [%c1, %c1] : !flow.dispatch.tensor<readonly:?x?xf32> -> tensor<3x1xf32>
+ %8 = flow.dispatch.tensor.load %2, offsets = [%arg0, %arg1], sizes = [%c1, %c1], strides = [%c1, %c1] : !flow.dispatch.tensor<readwrite:?x?xf32> -> tensor<1x1xf32>
+ %9 = linalg.matmul ins(%6, %7 : tensor<1x3xf32>, tensor<3x1xf32>) outs(%8 : tensor<1x1xf32>) -> tensor<1x1xf32>
+ %10 = linalg.matmul ins(%6, %7 : tensor<1x3xf32>, tensor<3x1xf32>) outs(%9 : tensor<1x1xf32>) -> tensor<1x1xf32>
+ flow.dispatch.tensor.store %10, %2, offsets = [%arg0, %arg1], sizes = [%c1, %c1], strides = [%c1, %c1] : tensor<1x1xf32> -> !flow.dispatch.tensor<readwrite:?x?xf32>
+ }
+ }
+ return
+}
+hal.interface @io attributes {sym_visibility = "private"} {
+ hal.interface.binding @TENSOR_LHS, set=0, binding=0, type="StorageBuffer", access="Read"
+ hal.interface.binding @TENSOR_RHS, set=0, binding=1, type="StorageBuffer", access="Read"
+ hal.interface.binding @TENSOR_INIT, set=0, binding=2, type="StorageBuffer", access="Read|Write"
+}
+// CHECK-LABEL: func @tile_from_matmul_outs_inplace()
+// CHECK-DAG: %[[TENSOR_LHS:.+]] = hal.interface.binding.subspan @io::@TENSOR_LHS
+// CHECK-DAG: %[[TENSOR_RHS:.+]] = hal.interface.binding.subspan @io::@TENSOR_RHS
+// CHECK-DAG: %[[RETURN:.+]] = hal.interface.binding.subspan @io::@TENSOR_INIT
+// CHECK: scf.for %[[IV0:.+]] = {{.+}} {
+// CHECK: scf.for %[[IV1:.+]] = {{.+}} {
+// CHECK-DAG: %[[RESULT:.+]] = memref.subview %[[RETURN]][%[[IV0]], %[[IV1]]] [1, 1] [1, 1]
+// CHECK-DAG: %[[LHS:.+]] = memref.subview %[[TENSOR_LHS]][%[[IV0]], 0] [1, 3] [1, 1]
+// CHECK-DAG: %[[RHS:.+]] = memref.subview %[[TENSOR_RHS]][0, %[[IV1]]] [3, 1] [1, 1]
+// CHECK: linalg.matmul
+// CHECK-SAME: outs(%[[RESULT]]
+// CHECK: linalg.matmul
+// CHECK-SAME: outs(%[[RESULT]]
+
+// -----
+
func @bufferize_dynamic() {
%c0 = constant 0 : index
%c1 = constant 1 : index
@@ -260,47 +553,86 @@
// -----
-// TODO(GH-4734): Enable after fixing the allocation for vector.transfer_writes.
-// #map0 = affine_map<(d0, d1, d2) -> (d0, d2)>
-// #map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
-// #map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
-// module {
-// func @bufferize_transfer_op() {
-// %c3 = constant 3 : index
-// %cst = constant 0.000000e+00 : f32
-// %c0 = constant 0 : index
-// %c2 = constant 2 : index
-// %c1 = constant 1 : index
-// %0 = hal.interface.binding.subspan @io::@arg0[%c0] : !flow.dispatch.tensor<readonly:2x3xf32>
-// %1 = hal.interface.binding.subspan @io::@arg1[%c0] : !flow.dispatch.tensor<readonly:3x4xf32>
-// %2 = hal.interface.binding.subspan @io::@arg2[%c0] : !flow.dispatch.tensor<readonly:2x4xf32>
-// %3 = hal.interface.binding.subspan @io::@ret0[%c0] : !flow.dispatch.tensor<writeonly:2x4xf32>
-// %4 = flow.dispatch.tensor.load %0, offsets = [%c0, %c0], sizes = [%c1, %c3], strides = [%c1, %c1] : !flow.dispatch.tensor<readonly:2x3xf32> -> tensor<2x3xf32>
-// %5 = flow.dispatch.tensor.load %1, offsets = [%c0, %c0], sizes = [%c3, %c1], strides = [%c1, %c1] : !flow.dispatch.tensor<readonly:3x4xf32> -> tensor<3x1xf32>
-// %6 = flow.dispatch.tensor.load %2, offsets = [%c0, %c0], sizes = [%c1, %c1], strides = [%c1, %c1] : !flow.dispatch.tensor<readonly:2x4xf32> -> tensor<2x1xf32>
-// %7 = vector.transfer_read %4[%c0, %c0], %cst {in_bounds = [true, true]} : tensor<2x3xf32>, vector<1x1xf32>
-// %8 = vector.transfer_read %4[%c0, %c1], %cst {in_bounds = [true, true]} : tensor<2x3xf32>, vector<1x1xf32>
-// %9 = vector.transfer_read %4[%c0, %c2], %cst {in_bounds = [true, true]} : tensor<2x3xf32>, vector<1x1xf32>
-// %10 = vector.transfer_read %4[%c1, %c0], %cst {in_bounds = [true, true]} : tensor<2x3xf32>, vector<1x1xf32>
-// %11 = vector.transfer_read %4[%c1, %c1], %cst {in_bounds = [true, true]} : tensor<2x3xf32>, vector<1x1xf32>
-// %12 = vector.transfer_read %4[%c1, %c2], %cst {in_bounds = [true, true]} : tensor<2x3xf32>, vector<1x1xf32>
-// %13 = vector.transfer_read %5[%c0, %c0], %cst {in_bounds = [true, true]} : tensor<3x1xf32>, vector<1x1xf32>
-// %14 = vector.transfer_read %5[%c1, %c0], %cst {in_bounds = [true, true]} : tensor<3x1xf32>, vector<1x1xf32>
-// %15 = vector.transfer_read %5[%c2, %c0], %cst {in_bounds = [true, true]} : tensor<3x1xf32>, vector<1x1xf32>
-// %16 = vector.transfer_read %6[%c0, %c0], %cst {in_bounds = [true, true]} : tensor<2x1xf32>, vector<1x1xf32>
-// %17 = vector.transfer_read %6[%c1, %c0], %cst {in_bounds = [true, true]} : tensor<2x1xf32>, vector<1x1xf32>
-// %18 = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %7, %13, %16 : vector<1x1xf32>, vector<1x1xf32> into vector<1x1xf32>
-// %19 = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %8, %14, %18 : vector<1x1xf32>, vector<1x1xf32> into vector<1x1xf32>
-// %20 = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %9, %15, %19 : vector<1x1xf32>, vector<1x1xf32> into vector<1x1xf32>
-// %21 = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %10, %13, %17 : vector<1x1xf32>, vector<1x1xf32> into vector<1x1xf32>
-// %22 = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %11, %14, %21 : vector<1x1xf32>, vector<1x1xf32> into vector<1x1xf32>
-// %23 = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %12, %15, %22 : vector<1x1xf32>, vector<1x1xf32> into vector<1x1xf32>
-// %24 = vector.transfer_write %20, %6[%c0, %c0] {in_bounds = [true, true]} : vector<1x1xf32>, tensor<2x1xf32>
-// %25 = vector.transfer_write %23, %24[%c1, %c0] {in_bounds = [true, true]} : vector<1x1xf32>, tensor<2x1xf32>
-// flow.dispatch.tensor.store %25, %3, offsets = [%c0, %c0], sizes = [%c1, %c1], strides = [%c1, %c1] : tensor<2x1xf32> -> !flow.dispatch.tensor<writeonly:2x4xf32>
-// return
-// }
-// }
+func @bufferize_dynamic_inplace() {
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %0 = hal.interface.binding.subspan @io::@arg0[%c0] : !flow.dispatch.tensor<readonly:?x?xf32>
+ %1 = hal.interface.binding.subspan @io::@arg1[%c0] : !flow.dispatch.tensor<readonly:?x?xf32>
+ %2 = hal.interface.binding.subspan @io::@arg2[%c0] : !flow.dispatch.tensor<readwrite:?x?xf32>
+ %4 = hal.interface.load.constant offset = 0 : index
+ %5 = hal.interface.load.constant offset = 1 : index
+ %6 = hal.interface.load.constant offset = 2 : index
+ %7 = hal.interface.load.constant offset = 3 : index
+ %8 = hal.interface.load.constant offset = 4 : index
+ %9 = hal.interface.load.constant offset = 5 : index
+ %12 = shapex.make_ranked_shape %4, %5 : (index, index) -> !shapex.ranked_shape<[?,?]>
+ %13 = flow.dispatch.tie_shape %0, %12 : (!flow.dispatch.tensor<readonly:?x?xf32>, !shapex.ranked_shape<[?,?]>) -> !flow.dispatch.tensor<readonly:?x?xf32>
+ %14 = shapex.make_ranked_shape %6, %7 : (index, index) -> !shapex.ranked_shape<[?,?]>
+ %15 = flow.dispatch.tie_shape %1, %14 : (!flow.dispatch.tensor<readonly:?x?xf32>, !shapex.ranked_shape<[?,?]>) -> !flow.dispatch.tensor<readonly:?x?xf32>
+ %16 = shapex.make_ranked_shape %8, %9 : (index, index) -> !shapex.ranked_shape<[?,?]>
+ %17 = flow.dispatch.tie_shape %2, %16 : (!flow.dispatch.tensor<readwrite:?x?xf32>, !shapex.ranked_shape<[?,?]>) -> !flow.dispatch.tensor<readwrite:?x?xf32>
+ %workgroup_size_x = hal.interface.workgroup.size[0] : index
+ %workgroup_size_y = hal.interface.workgroup.size[1] : index
+ %workgroup_id_x = hal.interface.workgroup.id[0] : index
+ %workgroup_count_x = hal.interface.workgroup.count[0] : index
+ %workgroup_id_y = hal.interface.workgroup.id[1] : index
+ %workgroup_count_y = hal.interface.workgroup.count[1] : index
+ %20 = muli %workgroup_size_y, %workgroup_id_y : index
+ %21 = muli %workgroup_size_y, %workgroup_count_y : index
+ scf.for %arg0 = %20 to %4 step %21 {
+ %22 = muli %workgroup_size_x, %workgroup_id_x : index
+ %23 = muli %workgroup_size_x, %workgroup_count_x : index
+ scf.for %arg1 = %22 to %7 step %23 {
+ %24 = affine.min affine_map<(d0)[s0, s1] -> (s1, -d0 + s0)>(%arg0)[%4, %workgroup_size_y]
+ %25 = flow.dispatch.tensor.load %13, offsets = [%arg0, %c0], sizes = [%24, %5], strides = [%c1, %c1] : !flow.dispatch.tensor<readonly:?x?xf32> -> tensor<?x?xf32>
+ %26 = affine.min affine_map<(d0)[s0, s1] -> (s1, -d0 + s0)>(%arg1)[%7, %workgroup_size_x]
+ %27 = flow.dispatch.tensor.load %15, offsets = [%c0, %arg1], sizes = [%6, %26], strides = [%c1, %c1] : !flow.dispatch.tensor<readonly:?x?xf32> -> tensor<?x?xf32>
+ %28 = affine.min affine_map<(d0)[s0, s1] -> (s0, -d0 + s1)>(%arg0)[%workgroup_size_y, %8]
+ %29 = affine.min affine_map<(d0)[s0, s1] -> (s0, -d0 + s1)>(%arg1)[%workgroup_size_x, %9]
+ %30 = flow.dispatch.tensor.load %17, offsets = [%arg0, %arg1], sizes = [%28, %29], strides = [%c1, %c1] : !flow.dispatch.tensor<readwrite:?x?xf32> -> tensor<?x?xf32>
+ %31 = linalg.matmul {__internal_linalg_transform__ = "workgroup"} ins(%25, %27 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%30 : tensor<?x?xf32>) -> tensor<?x?xf32>
+ flow.dispatch.tensor.store %31, %17, offsets = [%arg0, %arg1], sizes = [%28, %29], strides = [%c1, %c1] : tensor<?x?xf32> -> !flow.dispatch.tensor<readwrite:?x?xf32>
+ }
+ }
+ return
+}
+hal.interface @io attributes {sym_visibility = "private"} {
+ hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+ hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+ hal.interface.binding @arg2, set=0, binding=2, type="StorageBuffer", access="Read|Write"
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0, s1] -> (s1, -d0 + s0)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0, s1] -> (s0, -d0 + s1)>
+// CHECK: func @bufferize_dynamic_inplace()
+// CHECK-DAG: %[[LHS:.+]] = hal.interface.binding.subspan @io::@arg0
+// CHECK-DAG: %[[RHS:.+]] = hal.interface.binding.subspan @io::@arg1
+// CHECK-DAG: %[[RESULT:.+]] = hal.interface.binding.subspan @io::@arg2
+// CHECK-DAG: %[[DIM0:.+]] = hal.interface.load.constant offset = 0 : index
+// CHECK-DAG: %[[DIM1:.+]] = hal.interface.load.constant offset = 1 : index
+// CHECK-DAG: %[[DIM2:.+]] = hal.interface.load.constant offset = 2 : index
+// CHECK-DAG: %[[DIM3:.+]] = hal.interface.load.constant offset = 3 : index
+// CHECK-DAG: %[[DIM4:.+]] = hal.interface.load.constant offset = 4 : index
+// CHECK-DAG: %[[DIM5:.+]] = hal.interface.load.constant offset = 5 : index
+// CHECK: %[[SHAPE_LHS:.+]] = shapex.make_ranked_shape %[[DIM0]], %[[DIM1]]
+// CHECK: %[[LHS_SHAPED:.+]] = shapex.tie_shape %[[LHS]], %[[SHAPE_LHS]]
+// CHECK: %[[SHAPE_RHS:.+]] = shapex.make_ranked_shape %[[DIM2]], %[[DIM3]]
+// CHECK: %[[RHS_SHAPED:.+]] = shapex.tie_shape %[[RHS]], %[[SHAPE_RHS]]
+// CHECK: %[[SHAPE_RESULT:.+]] = shapex.make_ranked_shape %[[DIM4]], %[[DIM5]]
+// CHECK: %[[RESULT_SHAPED:.+]] = shapex.tie_shape %[[RESULT]], %[[SHAPE_RESULT]]
+// CHECK-DAG: %[[WGSIZE_X:.+]] = hal.interface.workgroup.size[0]
+// CHECK-DAG: %[[WGSIZE_Y:.+]] = hal.interface.workgroup.size[1]
+// CHECK: scf.for %[[IV0:.+]] = {{.+}} {
+// CHECK: scf.for %[[IV1:.+]] = {{.+}} {
+// CHECK: %[[TILE_M:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[DIM0]], %[[WGSIZE_Y]]]
+// CHECK: %[[LHS_TILE:.+]] = memref.subview %[[LHS_SHAPED]][%[[IV0]], 0] [%[[TILE_M]], %[[DIM1]]]
+// CHECK: %[[TILE_N:.+]] = affine.min #[[MAP0]](%[[IV1]])[%[[DIM3]], %[[WGSIZE_X]]]
+// CHECK-DAG: %[[RHS_TILE:.+]] = memref.subview %[[RHS_SHAPED]][0, %[[IV1]]] [%[[DIM2]], %[[TILE_N]]]
+// CHECK: %[[TILE_M_2:.+]] = affine.min #[[MAP2]](%[[IV0]])[%[[WGSIZE_Y]], %[[DIM4]]]
+// CHECK: %[[TILE_N_2:.+]] = affine.min #[[MAP2]](%[[IV1]])[%[[WGSIZE_X]], %[[DIM5]]]
+// CHECK-DAG: %[[RESULT_TILE:.+]] = memref.subview %[[RESULT_SHAPED]][%[[IV0]], %[[IV1]]] [%[[TILE_M_2]], %[[TILE_N_2]]]
+// CHECK: linalg.matmul
+// CHECK-SAME: ins(%[[LHS_TILE]], %[[RHS_TILE]]
+// CHECK-SAME: outs(%[[RESULT_TILE]]
// -----
@@ -312,9 +644,9 @@
%c12 = constant 12 : index
%0 = hal.interface.binding.subspan @io::@arg0[%c0] : !flow.dispatch.tensor<readonly:12xi32>
%1 = hal.interface.binding.subspan @io::@ret0[%c0] : !flow.dispatch.tensor<writeonly:3x4xi32>
- %2 = flow.dispatch.tensor.load %0, offsets = [%c0], sizes = [%c12], strides = [%c1] : !flow.dispatch.tensor<readonly:12xi32> -> tensor<12xi32>
+ %2 = flow.dispatch.tensor.load %0, offsets = [], sizes = [], strides = [] : !flow.dispatch.tensor<readonly:12xi32> -> tensor<12xi32>
%3 = linalg.tensor_reshape %2 [affine_map<(d0, d1) -> (d0, d1)>] : tensor<12xi32> into tensor<3x4xi32>
- flow.dispatch.tensor.store %3, %1, offsets = [%c0, %c0], sizes = [%c3, %c4], strides = [%c1, %c1] : tensor<3x4xi32> -> !flow.dispatch.tensor<writeonly:3x4xi32>
+ flow.dispatch.tensor.store %3, %1, offsets = [], sizes = [], strides = [] : tensor<3x4xi32> -> !flow.dispatch.tensor<writeonly:3x4xi32>
return
}
hal.interface @io attributes {sym_visibility = "private"} {
@@ -323,12 +655,10 @@
}
// CHECK: #[[MAP:.+]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK: func @reshape_simple()
-// CHECK: %[[C0:.+]] = constant 0
-// CHECK-DAG: %[[ARG0:.+]] = hal.interface.binding.subspan @io::@arg0[%[[C0]]] : memref<12xi32>
-// CHECK-DAG: %[[RET0:.+]] = hal.interface.binding.subspan @io::@ret0[%[[C0]]] : memref<3x4xi32>
-// CHECK-DAG: %[[RET0V:.+]] = memref.subview %[[RET0]][0, 0] [3, 4] [1, 1]
+// CHECK-DAG: %[[ARG0:.+]] = hal.interface.binding.subspan @io::@arg0
+// CHECK-DAG: %[[RET0:.+]] = hal.interface.binding.subspan @io::@ret0
// CHECK: %[[RESHAPE:.+]] = linalg.reshape %[[ARG0]] [#[[MAP]]]
-// CHECK: linalg.copy(%[[RESHAPE]], %[[RET0V]])
+// CHECK: linalg.copy(%[[RESHAPE]], %[[RET0]])
// -----
@@ -340,7 +670,7 @@
%c12 = constant 12 : index
%0 = hal.interface.binding.subspan @io::@arg0[%c0] : !flow.dispatch.tensor<readonly:12xi32>
%1 = hal.interface.binding.subspan @io::@ret0[%c0] : !flow.dispatch.tensor<writeonly:3x4xi32>
- %2 = flow.dispatch.tensor.load %0, offsets = [%c0], sizes = [%c12], strides = [%c1] : !flow.dispatch.tensor<readonly:12xi32> -> tensor<12xi32>
+ %2 = flow.dispatch.tensor.load %0, offsets = [], sizes = [], strides = [] : !flow.dispatch.tensor<readonly:12xi32> -> tensor<12xi32>
%3 = linalg.tensor_reshape %2 [affine_map<(d0, d1) -> (d0, d1)>] : tensor<12xi32> into tensor<3x4xi32>
%4 = linalg.init_tensor [3, 4] : tensor<3x4xi32>
%5 = linalg.generic {
@@ -351,24 +681,22 @@
%6 = addi %arg0, %arg0 : i32
linalg.yield %6 : i32
} -> tensor<3x4xi32>
- flow.dispatch.tensor.store %5, %1, offsets = [%c0, %c0], sizes = [%c3, %c4], strides = [%c1, %c1] : tensor<3x4xi32> -> !flow.dispatch.tensor<writeonly:3x4xi32>
+ flow.dispatch.tensor.store %5, %1, offsets = [], sizes = [], strides = [] : tensor<3x4xi32> -> !flow.dispatch.tensor<writeonly:3x4xi32>
return
}
hal.interface @io attributes {sym_visibility = "private"} {
hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
hal.interface.binding @ret0, set=0, binding=1, type="StorageBuffer", access="Write|Discard"
}
-// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0 * 4 + d1)>
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK: func @reshape_fused_source()
// CHECK: %[[C0:.+]] = constant 0
// CHECK-DAG: %[[ARG0:.+]] = hal.interface.binding.subspan @io::@arg0[%[[C0]]] : memref<12xi32>
// CHECK-DAG: %[[RET0:.+]] = hal.interface.binding.subspan @io::@ret0[%[[C0]]] : memref<3x4xi32>
-// CHECK-DAG: %[[RET0V:.+]] = memref.subview %[[RET0]][0, 0] [3, 4] [1, 1]
-// CHECK: %[[RESHAPE:.+]] = linalg.reshape %[[ARG0]] [#[[MAP1]]]
+// CHECK: %[[RESHAPE:.+]] = linalg.reshape %[[ARG0]] [#[[MAP]]]
// CHECK: linalg.generic
// CHECK-SAME: ins(%[[RESHAPE]] : memref<3x4xi32>)
-// CHECK-SAME: outs(%[[RET0V]] : memref<3x4xi32, #[[MAP0]]>)
+// CHECK-SAME: outs(%[[RET0]] : memref<3x4xi32>)
// -----
@@ -381,7 +709,7 @@
%0 = hal.interface.binding.subspan @io::@arg0[%c0] : !flow.dispatch.tensor<readonly:12xi32>
%1 = hal.interface.binding.subspan @io::@ret0[%c0] : !flow.dispatch.tensor<writeonly:3x4xi32>
%2 = hal.interface.binding.subspan @io::@ret1[%c0] : !flow.dispatch.tensor<writeonly:3x4xi32>
- %3 = flow.dispatch.tensor.load %0, offsets = [%c0], sizes = [%c12], strides = [%c1] : !flow.dispatch.tensor<readonly:12xi32> -> tensor<12xi32>
+ %3 = flow.dispatch.tensor.load %0, offsets = [], sizes = [], strides = [] : !flow.dispatch.tensor<readonly:12xi32> -> tensor<12xi32>
%4 = linalg.tensor_reshape %3 [affine_map<(d0, d1) -> (d0, d1)>] : tensor<12xi32> into tensor<3x4xi32>
%5 = linalg.init_tensor [3, 4] : tensor<3x4xi32>
%6 = linalg.generic {
@@ -392,8 +720,8 @@
%7 = addi %arg0, %arg0 : i32
linalg.yield %7 : i32
} -> tensor<3x4xi32>
- flow.dispatch.tensor.store %6, %1, offsets = [%c0, %c0], sizes = [%c3, %c4], strides = [%c1, %c1] : tensor<3x4xi32> -> !flow.dispatch.tensor<writeonly:3x4xi32>
- flow.dispatch.tensor.store %4, %2, offsets = [%c0, %c0], sizes = [%c3, %c4], strides = [%c1, %c1] : tensor<3x4xi32> -> !flow.dispatch.tensor<writeonly:3x4xi32>
+ flow.dispatch.tensor.store %6, %1, offsets = [], sizes = [], strides = [] : tensor<3x4xi32> -> !flow.dispatch.tensor<writeonly:3x4xi32>
+ flow.dispatch.tensor.store %4, %2, offsets = [], sizes = [], strides = [] : tensor<3x4xi32> -> !flow.dispatch.tensor<writeonly:3x4xi32>
return
}
hal.interface @io attributes {sym_visibility = "private"} {
@@ -401,21 +729,17 @@
hal.interface.binding @ret0, set=0, binding=1, type="StorageBuffer", access="Write|Discard"
hal.interface.binding @ret1, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
}
-// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0 * 4 + d1)>
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK: func @reshape_fused_source_and_copyout()
// CHECK: %[[C0:.+]] = constant 0
// CHECK-DAG: %[[ARG0:.+]] = hal.interface.binding.subspan @io::@arg0[%[[C0]]] : memref<12xi32>
// CHECK-DAG: %[[RET0:.+]] = hal.interface.binding.subspan @io::@ret0[%[[C0]]] : memref<3x4xi32>
// CHECK-DAG: %[[RET1:.+]] = hal.interface.binding.subspan @io::@ret1[%[[C0]]] : memref<3x4xi32>
-// CHECK-DAG: %[[RET0V:.+]] = memref.subview %[[RET0]][0, 0] [3, 4] [1, 1]
-// CHECK-DAG: %[[RET1V:.+]] = memref.subview %[[RET1]][0, 0] [3, 4] [1, 1]
-// CHECK: %[[RESHAPE:.+]] = linalg.reshape %[[ARG0]] [#[[MAP1]]]
-// CHECK-DAG: linalg.copy(%[[RESHAPE]], %[[RET1V]])
-// CHECK-DAG: linalg.generic
-// CHECK-SAME: ins(%[[RET1V]] : memref<3x4xi32, #[[MAP0]]>)
-// CHECK-SAME: outs(%[[RET0V]] : memref<3x4xi32, #[[MAP0]]>)
-
+// CHECK: %[[RESHAPE:.+]] = linalg.reshape %[[ARG0]] [#[[MAP]]]
+// CHECK: linalg.generic
+// CHECK-SAME: ins(%[[RESHAPE]] : memref<3x4xi32>)
+// CHECK-SAME: outs(%[[RET0]] : memref<3x4xi32>)
+// CHECK: linalg.copy(%[[RESHAPE]], %[[RET1]])
// -----
@@ -427,7 +751,7 @@
%c12 = constant 12 : index
%0 = hal.interface.binding.subspan @io::@arg0[%c0] : !flow.dispatch.tensor<readonly:3x4xi32>
%1 = hal.interface.binding.subspan @io::@ret0[%c0] : !flow.dispatch.tensor<writeonly:12xi32>
- %2 = flow.dispatch.tensor.load %0, offsets = [%c0, %c0], sizes = [%c3, %c4], strides = [%c1, %c1] : !flow.dispatch.tensor<readonly:3x4xi32> -> tensor<3x4xi32>
+ %2 = flow.dispatch.tensor.load %0, offsets = [], sizes = [], strides = [] : !flow.dispatch.tensor<readonly:3x4xi32> -> tensor<3x4xi32>
%3 = linalg.init_tensor [3, 4] : tensor<3x4xi32>
%4 = linalg.generic {
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
@@ -438,26 +762,22 @@
linalg.yield %5 : i32
} -> tensor<3x4xi32>
%5 = linalg.tensor_reshape %4 [affine_map<(d0, d1) -> (d0, d1)>] : tensor<3x4xi32> into tensor<12xi32>
- flow.dispatch.tensor.store %5, %1, offsets = [%c0], sizes = [%c12], strides = [%c1] : tensor<12xi32> -> !flow.dispatch.tensor<writeonly:12xi32>
+ flow.dispatch.tensor.store %5, %1, offsets = [], sizes = [], strides = [] : tensor<12xi32> -> !flow.dispatch.tensor<writeonly:12xi32>
return
}
hal.interface @io attributes {sym_visibility = "private"} {
hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
hal.interface.binding @ret0, set=0, binding=1, type="StorageBuffer", access="Write|Discard"
}
-// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0 * 4 + d1)>
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK: func @reshape_fused_target()
// CHECK: %[[C0:.+]] = constant 0
// CHECK-DAG: %[[ARG0:.+]] = hal.interface.binding.subspan @io::@arg0[%[[C0]]] : memref<3x4xi32>
// CHECK-DAG: %[[RET0:.+]] = hal.interface.binding.subspan @io::@ret0[%[[C0]]] : memref<12xi32>
-// CHECK-DAG: %[[ARG0V:.+]] = memref.subview %[[ARG0]][0, 0] [3, 4] [1, 1]
-// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<3x4xi32>
+// CHECK-DAG: %[[RESHAPE:.+]] = linalg.reshape %[[RET0]] [#[[MAP]]]
// CHECK: linalg.generic
-// CHECK-SAME: ins(%[[ARG0V]] : memref<3x4xi32, #[[MAP0]]>)
-// CHECK-SAME: outs(%[[ALLOC]] : memref<3x4xi32>)
-// CHECK: %[[RESULT:.+]] = linalg.reshape %[[ALLOC]] [#[[MAP1]]]
-// CHECK: linalg.copy(%[[RESULT]], %[[RET0]])
+// CHECK-SAME: ins(%[[ARG0]] : memref<3x4xi32>)
+// CHECK-SAME: outs(%[[RESHAPE]] : memref<3x4xi32>)
// -----
@@ -518,6 +838,146 @@
// -----
+func @slice() {
+ %c0 = constant 0 : index
+ %0 = hal.interface.binding.subspan @io::@arg0[%c0] : !flow.dispatch.tensor<readonly:?x?xi32>
+ %1 = hal.interface.binding.subspan @io::@ret0[%c0] : !flow.dispatch.tensor<writeonly:?x?xi32>
+ %2 = hal.interface.load.constant offset = 0 : index
+ %3 = hal.interface.load.constant offset = 1 : index
+ %4 = hal.interface.load.constant offset = 2 : index
+ %5 = hal.interface.load.constant offset = 3 : index
+ %6 = flow.dispatch.tensor.load %0, offsets = [], sizes = [], strides = [] : !flow.dispatch.tensor<readonly:?x?xi32> -> tensor<?x?xi32>
+ %7 = subtensor %6[%2, %3] [%4, %5] [1, 1] : tensor<?x?xi32> to tensor<?x?xi32>
+ flow.dispatch.tensor.store %7, %1, offsets = [], sizes = [], strides = [] : tensor<?x?xi32> -> !flow.dispatch.tensor<writeonly:?x?xi32>
+ return
+}
+hal.interface @io attributes {sym_visibility = "private"} {
+ hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+ hal.interface.binding @ret0, set=0, binding=1, type="StorageBuffer", access="Write|Discard"
+}
+// CHECK-LABEL: func @slice()
+// CHECK-DAG: %[[ARG:.+]] = hal.interface.binding.subspan @io::@arg0
+// CHECK-DAG: %[[RETURN:.+]] = hal.interface.binding.subspan @io::@ret0
+// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG]]
+// CHECK: linalg.copy(%[[SUBVIEW]], %[[RETURN]])
+
+// -----
+
+func @slice_rank_reducing() {
+ %c0 = constant 0 : index
+ %0 = hal.interface.binding.subspan @io::@arg0[%c0] : !flow.dispatch.tensor<readonly:?x?x?xi32>
+ %1 = hal.interface.binding.subspan @io::@ret0[%c0] : !flow.dispatch.tensor<writeonly:?x?xi32>
+ %2 = hal.interface.load.constant offset = 0 : index
+ %3 = hal.interface.load.constant offset = 1 : index
+ %4 = hal.interface.load.constant offset = 2 : index
+ %5 = hal.interface.load.constant offset = 3 : index
+ %6 = flow.dispatch.tensor.load %0, offsets = [], sizes = [], strides = [] : !flow.dispatch.tensor<readonly:?x?x?xi32> -> tensor<?x?x?xi32>
+ %7 = subtensor %6[%2, %2, %3] [%4, 1, %5] [1, 1, 1] : tensor<?x?x?xi32> to tensor<?x?xi32>
+ flow.dispatch.tensor.store %7, %1, offsets = [], sizes = [], strides = [] : tensor<?x?xi32> -> !flow.dispatch.tensor<writeonly:?x?xi32>
+ return
+}
+hal.interface @io attributes {sym_visibility = "private"} {
+ hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+ hal.interface.binding @ret0, set=0, binding=1, type="StorageBuffer", access="Write|Discard"
+}
+// CHECK-LABEL: func @slice_rank_reducing()
+// CHECK-DAG: %[[ARG:.+]] = hal.interface.binding.subspan @io::@arg0
+// CHECK-DAG: %[[RETURN:.+]] = hal.interface.binding.subspan @io::@ret0
+// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG]]
+// CHECK: linalg.copy(%[[SUBVIEW]], %[[RETURN]])
+
+// -----
+
+func @slice_multiple_copy() {
+ %c0 = constant 0 : index
+ %0 = hal.interface.binding.subspan @io::@arg0[%c0] : !flow.dispatch.tensor<readonly:?x?x?xi32>
+ %1 = hal.interface.binding.subspan @io::@ret0[%c0] : !flow.dispatch.tensor<writeonly:?x?x?xi32>
+ %2 = hal.interface.binding.subspan @io::@ret1[%c0] : !flow.dispatch.tensor<writeonly:?x?xi32>
+ %3 = hal.interface.load.constant offset = 0 : index
+ %4 = hal.interface.load.constant offset = 1 : index
+ %5 = hal.interface.load.constant offset = 2 : index
+ %6 = hal.interface.load.constant offset = 3 : index
+ %7 = hal.interface.load.constant offset = 4 : index
+ %8 = hal.interface.load.constant offset = 5 : index
+ %9 = flow.dispatch.tensor.load %0, offsets = [], sizes = [], strides = [] : !flow.dispatch.tensor<readonly:?x?x?xi32> -> tensor<?x?x?xi32>
+ %10 = subtensor %9[%3, %4, %5] [%6, %7, %8] [1, 1, 1] : tensor<?x?x?xi32> to tensor<?x?x?xi32>
+ %11 = subtensor %9[%3, %4, %5] [%6, 1, %8] [1, 1, 1] : tensor<?x?x?xi32> to tensor<?x?xi32>
+ flow.dispatch.tensor.store %10, %1, offsets = [], sizes = [], strides = [] : tensor<?x?x?xi32> -> !flow.dispatch.tensor<writeonly:?x?x?xi32>
+ flow.dispatch.tensor.store %11, %2, offsets = [%3, %5], sizes = [%6, %8], strides = [1, 1] : tensor<?x?xi32> -> !flow.dispatch.tensor<writeonly:?x?xi32>
+ return
+}
+hal.interface @io attributes {sym_visibility = "private"} {
+ hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+ hal.interface.binding @ret0, set=0, binding=1, type="StorageBuffer", access="Write|Discard"
+ hal.interface.binding @ret1, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
+}
+// CHECK-LABEL: func @slice_multiple_copy()
+// CHECK-DAG: %[[ARG:.+]] = hal.interface.binding.subspan @io::@arg0
+// CHECK-DAG: %[[RETURN1:.+]] = hal.interface.binding.subspan @io::@ret0
+// CHECK-DAG: %[[RETURN2:.+]] = hal.interface.binding.subspan @io::@ret1
+// CHECK: %[[SUBVIEW1:.+]] = memref.subview %[[ARG]]
+// CHECK: %[[SUBVIEW2:.+]] = memref.subview %[[ARG]]
+// CHECK: linalg.copy(%[[SUBVIEW1]], %[[RETURN1]])
+// CHECK: %[[RETURNVIEW:.+]] = memref.subview %[[RETURN2]]
+// CHECK: linalg.copy(%[[SUBVIEW2]], %[[RETURNVIEW]])
+
+// -----
+
+func @slice_in_place() {
+ %c0 = constant 0 : index
+ %0 = hal.interface.binding.subspan @io::@arg0[%c0] : !flow.dispatch.tensor<readwrite:?x?xi32>
+ %2 = hal.interface.load.constant offset = 0 : index
+ %3 = hal.interface.load.constant offset = 1 : index
+ %4 = hal.interface.load.constant offset = 2 : index
+ %5 = hal.interface.load.constant offset = 3 : index
+ %6 = flow.dispatch.tensor.load %0, offsets = [], sizes = [], strides = [] : !flow.dispatch.tensor<readwrite:?x?xi32> -> tensor<?x?xi32>
+ flow.dispatch.tensor.store %6, %0, offsets = [], sizes = [], strides = [] : tensor<?x?xi32> -> !flow.dispatch.tensor<readwrite:?x?xi32>
+ return
+}
+hal.interface @io attributes {sym_visibility = "private"} {
+ hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read|Write"
+}
+// CHECK-LABEL: func @slice_in_place()
+// CHECK-NOT: linalg.copy
+
+// -----
+
func @slice_whole_stride_dispatch_0() {
%c0 = constant 0 : index
%0 = hal.interface.binding.subspan @io::@arg0[%c0] : !flow.dispatch.tensor<readonly:?x?xi32>
@@ -752,9 +1212,9 @@
// CHECK-DAG: %[[INPUT:.+]] = hal.interface.binding.subspan @io::@ro1[%c0] : memref<1x4x6x1xf32>
// CHECK-DAG: %[[INIT:.+]] = hal.interface.binding.subspan @io::@ro0[%c0] : memref<f32>
// CHECK-DAG: %[[RET0:.+]] = hal.interface.binding.subspan @io::@wo2[%c0] : memref<1x2x2x1xf32>
-// CHECK: %[[WINDOW:.+]] = memref.alloc() : memref<2x3xf32>
// CHECK: %[[INIT_VAL:.+]] = memref.load %[[INIT]][] : memref<f32>
// CHECK: linalg.fill(%[[RET0]], %[[INIT_VAL]]) : memref<1x2x2x1xf32>, f32
+// CHECK: %[[WINDOW:.+]] = memref.alloc() : memref<2x3xf32>
// CHECK: linalg.pooling_nhwc_sum
// CHECK-SAME: dilations = dense<1> : vector<2xi64>
// CHECK-SAME: strides = dense<[2, 3]> : vector<2xi64>
@@ -1014,7 +1474,6 @@
// CHECK-SAME: ins(%[[INPUT]], %[[CAST5]] : memref<5xf32>, memref<5xi32>)
// CHECK-SAME: outs(%[[OUTPUT]] : memref<i32>)
-
// -----
func @cast_follwed_by_store() {
@@ -1054,12 +1513,158 @@
}
// CHECK-LABEL: func @cast_follwed_by_store()
-// CHECK: %[[ZERO:.+]] = constant 0.000000e+00 : f32
-// CHECK: %[[LHS:.+]] = hal.interface.binding.subspan @io::@s0b0_ro_external[%c0] : memref<4x32x1024xf32>
-// CHECK: %[[RHS:.+]] = hal.interface.binding.subspan @io::@s0b1_ro_external[%c0] : memref<4x1024x64xf32>
-// CHECK: %[[RESULT:.+]] = hal.interface.binding.subspan @io::@s0b2_xw_external[%c0] : memref<4x32x64xf32>
-// CHECK: %[[RESULTV:.+]] = memref.subview %[[RESULT]]
-// CHECK: %[[LHSV:.+]] = memref.subview %[[LHS]]
-// CHECK: %[[RHSV:.+]] = memref.subview %[[RHS]]
-// CHECK: linalg.fill(%[[RESULTV]], %[[ZERO]])
-// CHECK: linalg.batch_matmul {{.*}} ins(%[[LHSV]], %[[RHSV]] : {{.*}}) outs(%[[RESULTV]]
+// CHECK-DAG: %[[ZERO:.+]] = constant 0.000000e+00 : f32
+// CHECK-DAG: %[[LHS:.+]] = hal.interface.binding.subspan @io::@s0b0_ro_external[%c0] : memref<4x32x1024xf32>
+// CHECK-DAG: %[[RHS:.+]] = hal.interface.binding.subspan @io::@s0b1_ro_external[%c0] : memref<4x1024x64xf32>
+// CHECK-DAG: %[[RESULT:.+]] = hal.interface.binding.subspan @io::@s0b2_xw_external[%c0] : memref<4x32x64xf32>
+// CHECK: %[[LHSV:.+]] = memref.subview %[[LHS]]
+// CHECK: %[[RHSV:.+]] = memref.subview %[[RHS]]
+// CHECK: %[[RESULTV:.+]] = memref.subview %[[RESULT]]
+// CHECK: linalg.fill(%[[RESULTV]], %[[ZERO]])
+// CHECK: linalg.batch_matmul {{.*}} ins(%[[LHSV]], %[[RHSV]] : {{.*}}) outs(%[[RESULTV]]
+
+// -----
+
+func @rank_reduced_subtensor_insert() {
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %c2 = constant 2 : index
+ %0 = hal.interface.binding.subspan @io::@arg0[%c0] : !flow.dispatch.tensor<readonly:?x?xf32>
+ %1 = hal.interface.binding.subspan @io::@ret0[%c0] : !flow.dispatch.tensor<readwrite:?x?x?xf32>
+ %2 = flow.dispatch.tensor.load %0, offsets = [], sizes = [], strides = [] : !flow.dispatch.tensor<readonly:?x?xf32> -> tensor<?x?xf32>
+ %3 = flow.dispatch.tensor.load %1, offsets = [], sizes = [], strides = [] : !flow.dispatch.tensor<readwrite:?x?x?xf32> -> tensor<?x?x?xf32>
+ %4 = memref.dim %3, %c1 : tensor<?x?x?xf32>
+ %5 = memref.dim %3, %c2 : tensor<?x?x?xf32>
+ %6 = subtensor_insert %2 into %3[0, 0, 0] [1, %4, %5] [1, 1, 1] : tensor<?x?xf32> into tensor<?x?x?xf32>
+ flow.dispatch.tensor.store %6, %1, offsets = [], sizes = [], strides = [] : tensor<?x?x?xf32> -> !flow.dispatch.tensor<readwrite:?x?x?xf32>
+ return
+}
+hal.interface @io attributes {sym_visibility = "private"} {
+ hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+ hal.interface.binding @ret0, set=0, binding=1, type="StorageBuffer", access="Read|Write"
+}
+// CHECK-LABEL: func @rank_reduced_subtensor_insert()
+// CHECK-DAG: %[[ARG:.+]] = hal.interface.binding.subspan @io::@arg0
+// CHECK-DAG: %[[RET:.+]] = hal.interface.binding.subspan @io::@ret0
+// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[RET]]
+// CHECK: linalg.copy(%[[ARG]], %[[SUBVIEW]])
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func @bufferize_transfer_op() {
+ %c3 = constant 3 : index
+ %cst = constant 0.000000e+00 : f32
+ %c0 = constant 0 : index
+ %c2 = constant 2 : index
+ %c1 = constant 1 : index
+ %0 = hal.interface.binding.subspan @io::@arg0[%c0] : !flow.dispatch.tensor<readonly:2x3xf32>
+ %1 = hal.interface.binding.subspan @io::@arg1[%c0] : !flow.dispatch.tensor<readonly:3x4xf32>
+ %2 = hal.interface.binding.subspan @io::@arg2[%c0] : !flow.dispatch.tensor<readonly:2x4xf32>
+ %3 = hal.interface.binding.subspan @io::@ret0[%c0] : !flow.dispatch.tensor<writeonly:2x4xf32>
+ %4 = flow.dispatch.tensor.load %0, offsets = [%c0, %c0], sizes = [%c1, %c3], strides = [%c1, %c1] : !flow.dispatch.tensor<readonly:2x3xf32> -> tensor<2x3xf32>
+ %5 = flow.dispatch.tensor.load %1, offsets = [%c0, %c0], sizes = [%c3, %c1], strides = [%c1, %c1] : !flow.dispatch.tensor<readonly:3x4xf32> -> tensor<3x1xf32>
+ %6 = flow.dispatch.tensor.load %2, offsets = [%c0, %c0], sizes = [%c1, %c1], strides = [%c1, %c1] : !flow.dispatch.tensor<readonly:2x4xf32> -> tensor<2x1xf32>
+ %7 = vector.transfer_read %4[%c0, %c0], %cst {in_bounds = [true, true]} : tensor<2x3xf32>, vector<1x1xf32>
+ %8 = vector.transfer_read %4[%c0, %c1], %cst {in_bounds = [true, true]} : tensor<2x3xf32>, vector<1x1xf32>
+ %9 = vector.transfer_read %4[%c0, %c2], %cst {in_bounds = [true, true]} : tensor<2x3xf32>, vector<1x1xf32>
+ %10 = vector.transfer_read %4[%c1, %c0], %cst {in_bounds = [true, true]} : tensor<2x3xf32>, vector<1x1xf32>
+ %11 = vector.transfer_read %4[%c1, %c1], %cst {in_bounds = [true, true]} : tensor<2x3xf32>, vector<1x1xf32>
+ %12 = vector.transfer_read %4[%c1, %c2], %cst {in_bounds = [true, true]} : tensor<2x3xf32>, vector<1x1xf32>
+ %13 = vector.transfer_read %5[%c0, %c0], %cst {in_bounds = [true, true]} : tensor<3x1xf32>, vector<1x1xf32>
+ %14 = vector.transfer_read %5[%c1, %c0], %cst {in_bounds = [true, true]} : tensor<3x1xf32>, vector<1x1xf32>
+ %15 = vector.transfer_read %5[%c2, %c0], %cst {in_bounds = [true, true]} : tensor<3x1xf32>, vector<1x1xf32>
+ %16 = vector.transfer_read %6[%c0, %c0], %cst {in_bounds = [true, true]} : tensor<2x1xf32>, vector<1x1xf32>
+ %17 = vector.transfer_read %6[%c1, %c0], %cst {in_bounds = [true, true]} : tensor<2x1xf32>, vector<1x1xf32>
+ %18 = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %7, %13, %16 : vector<1x1xf32>, vector<1x1xf32> into vector<1x1xf32>
+ %19 = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %8, %14, %18 : vector<1x1xf32>, vector<1x1xf32> into vector<1x1xf32>
+ %20 = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %9, %15, %19 : vector<1x1xf32>, vector<1x1xf32> into vector<1x1xf32>
+ %21 = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %10, %13, %17 : vector<1x1xf32>, vector<1x1xf32> into vector<1x1xf32>
+ %22 = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %11, %14, %21 : vector<1x1xf32>, vector<1x1xf32> into vector<1x1xf32>
+ %23 = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %12, %15, %22 : vector<1x1xf32>, vector<1x1xf32> into vector<1x1xf32>
+ %24 = vector.transfer_write %20, %6[%c0, %c0] {in_bounds = [true, true]} : vector<1x1xf32>, tensor<2x1xf32>
+ %25 = vector.transfer_write %23, %24[%c1, %c0] {in_bounds = [true, true]} : vector<1x1xf32>, tensor<2x1xf32>
+ flow.dispatch.tensor.store %25, %3, offsets = [%c0, %c0], sizes = [%c1, %c1], strides = [%c1, %c1] : tensor<2x1xf32> -> !flow.dispatch.tensor<writeonly:2x4xf32>
+ return
+}
+hal.interface @io attributes {sym_visibility = "private"} {
+ hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+ hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+ hal.interface.binding @arg2, set=0, binding=2, type="StorageBuffer", access="Read"
+ hal.interface.binding @ret0, set=0, binding=3, type="StorageBuffer", access="Write|Discard"
+}
+// CHECK-LABEL: func @bufferize_transfer_op()
+// CHECK-DAG: %[[ARG0:.+]] = hal.interface.binding.subspan @io::@arg0
+// CHECK-DAG: %[[ARG1:.+]] = hal.interface.binding.subspan @io::@arg1
+// CHECK-DAG: %[[ARG2:.+]] = hal.interface.binding.subspan @io::@arg2
+// CHECK-DAG: %[[RET0:.+]] = hal.interface.binding.subspan @io::@ret0
+// CHECK-DAG: %[[ARG0V:.+]] = memref.subview %[[ARG0]]
+// CHECK-DAG: %[[ARG1V:.+]] = memref.subview %[[ARG1]]
+// CHECK-DAG: %[[ARG2V:.+]] = memref.subview %[[ARG2]]
+// CHECK-COUNT-6: vector.transfer_read %[[ARG0V]]
+// CHECK-COUNT-3: vector.transfer_read %[[ARG1V]]
+// CHECK-COUNT-2: vector.transfer_read %[[ARG2V]]
+// CHECK: %[[RET0V:.+]] = memref.subview %[[RET0]]
+// CHECK: linalg.copy(%[[ARG2V]], %[[RET0V]])
+// CHECK: vector.transfer_write %{{.+}}, %[[RET0V]]
+// CHECK: vector.transfer_write %{{.+}}, %[[RET0V]]
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func @bufferize_transfer_op_inplace() {
+ %c3 = constant 3 : index
+ %cst = constant 0.000000e+00 : f32
+ %c0 = constant 0 : index
+ %c2 = constant 2 : index
+ %c1 = constant 1 : index
+ %0 = hal.interface.binding.subspan @io::@arg0[%c0] : !flow.dispatch.tensor<readonly:2x3xf32>
+ %1 = hal.interface.binding.subspan @io::@arg1[%c0] : !flow.dispatch.tensor<readonly:3x4xf32>
+ %3 = hal.interface.binding.subspan @io::@ret0[%c0] : !flow.dispatch.tensor<readwrite:2x4xf32>
+ %4 = flow.dispatch.tensor.load %0, offsets = [%c0, %c0], sizes = [%c1, %c3], strides = [%c1, %c1] : !flow.dispatch.tensor<readonly:2x3xf32> -> tensor<2x3xf32>
+ %5 = flow.dispatch.tensor.load %1, offsets = [%c0, %c0], sizes = [%c3, %c1], strides = [%c1, %c1] : !flow.dispatch.tensor<readonly:3x4xf32> -> tensor<3x1xf32>
+ %6 = flow.dispatch.tensor.load %3, offsets = [%c0, %c0], sizes = [%c1, %c1], strides = [%c1, %c1] : !flow.dispatch.tensor<readwrite:2x4xf32> -> tensor<2x1xf32>
+ %7 = vector.transfer_read %4[%c0, %c0], %cst {in_bounds = [true, true]} : tensor<2x3xf32>, vector<1x1xf32>
+ %8 = vector.transfer_read %4[%c0, %c1], %cst {in_bounds = [true, true]} : tensor<2x3xf32>, vector<1x1xf32>
+ %9 = vector.transfer_read %4[%c0, %c2], %cst {in_bounds = [true, true]} : tensor<2x3xf32>, vector<1x1xf32>
+ %10 = vector.transfer_read %4[%c1, %c0], %cst {in_bounds = [true, true]} : tensor<2x3xf32>, vector<1x1xf32>
+ %11 = vector.transfer_read %4[%c1, %c1], %cst {in_bounds = [true, true]} : tensor<2x3xf32>, vector<1x1xf32>
+ %12 = vector.transfer_read %4[%c1, %c2], %cst {in_bounds = [true, true]} : tensor<2x3xf32>, vector<1x1xf32>
+ %13 = vector.transfer_read %5[%c0, %c0], %cst {in_bounds = [true, true]} : tensor<3x1xf32>, vector<1x1xf32>
+ %14 = vector.transfer_read %5[%c1, %c0], %cst {in_bounds = [true, true]} : tensor<3x1xf32>, vector<1x1xf32>
+ %15 = vector.transfer_read %5[%c2, %c0], %cst {in_bounds = [true, true]} : tensor<3x1xf32>, vector<1x1xf32>
+ %16 = vector.transfer_read %6[%c0, %c0], %cst {in_bounds = [true, true]} : tensor<2x1xf32>, vector<1x1xf32>
+ %17 = vector.transfer_read %6[%c1, %c0], %cst {in_bounds = [true, true]} : tensor<2x1xf32>, vector<1x1xf32>
+ %18 = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %7, %13, %16 : vector<1x1xf32>, vector<1x1xf32> into vector<1x1xf32>
+ %19 = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %8, %14, %18 : vector<1x1xf32>, vector<1x1xf32> into vector<1x1xf32>
+ %20 = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %9, %15, %19 : vector<1x1xf32>, vector<1x1xf32> into vector<1x1xf32>
+ %21 = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %10, %13, %17 : vector<1x1xf32>, vector<1x1xf32> into vector<1x1xf32>
+ %22 = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %11, %14, %21 : vector<1x1xf32>, vector<1x1xf32> into vector<1x1xf32>
+ %23 = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %12, %15, %22 : vector<1x1xf32>, vector<1x1xf32> into vector<1x1xf32>
+ %24 = vector.transfer_write %20, %6[%c0, %c0] {in_bounds = [true, true]} : vector<1x1xf32>, tensor<2x1xf32>
+ %25 = vector.transfer_write %23, %24[%c1, %c0] {in_bounds = [true, true]} : vector<1x1xf32>, tensor<2x1xf32>
+ flow.dispatch.tensor.store %25, %3, offsets = [%c0, %c0], sizes = [%c1, %c1], strides = [%c1, %c1] : tensor<2x1xf32> -> !flow.dispatch.tensor<readwrite:2x4xf32>
+ return
+}
+hal.interface @io attributes {sym_visibility = "private"} {
+ hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+ hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+ hal.interface.binding @ret0, set=0, binding=3, type="StorageBuffer", access="Write|Discard"
+}
+// CHECK-LABEL: func @bufferize_transfer_op_inplace()
+// CHECK-DAG: %[[ARG0:.+]] = hal.interface.binding.subspan @io::@arg0
+// CHECK-DAG: %[[ARG1:.+]] = hal.interface.binding.subspan @io::@arg1
+// CHECK-DAG: %[[RET0:.+]] = hal.interface.binding.subspan @io::@ret0
+// CHECK-DAG: %[[ARG0V:.+]] = memref.subview %[[ARG0]]
+// CHECK-DAG: %[[ARG1V:.+]] = memref.subview %[[ARG1]]
+// CHECK-DAG: %[[RET0V:.+]] = memref.subview %[[RET0]]
+// CHECK-COUNT-6: vector.transfer_read %[[ARG0V]]
+// CHECK-COUNT-3: vector.transfer_read %[[ARG1V]]
+// CHECK-COUNT-2: vector.transfer_read %[[RET0V]]
+// CHECK-NOT: linalg.copy
+// CHECK: vector.transfer_write %{{.+}}, %[[RET0V]]
+// CHECK: vector.transfer_write %{{.+}}, %[[RET0V]]
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/split_dispatch_function.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/split_dispatch_function.mlir
index be65cc4..977d0c3 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/split_dispatch_function.mlir
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/split_dispatch_function.mlir
@@ -386,7 +386,7 @@
hal.executable.entry_point @kernel attributes {
interface = @io, ordinal = 0 : index,
signature = (!flow.dispatch.tensor<readonly:1x3x3x512xf32>, !flow.dispatch.tensor<readonly:3x3x512x1xf32>,
- !flow.dispatch.tensor<writeonly:1x1x1x512xf32>) -> ()}
+ !flow.dispatch.tensor<writeonly:1x1x1x1xf32>) -> ()}
// CHECK-NOT: hal.entry_point_schedule
module {
// CHECK-LABEL: @kernel()
@@ -394,12 +394,12 @@
%cst = constant 0.000000e+00 : f32
%0 = iree.placeholder for "interface buffer" {binding = @io::@arg0} : memref<1x3x3x512xf32>
%1 = iree.placeholder for "interface buffer" {binding = @io::@arg1} : memref<3x3x512x1xf32>
- %2 = iree.placeholder for "interface buffer" {binding = @io::@ret0} : memref<1x1x1x512xf32>
+ %2 = iree.placeholder for "interface buffer" {binding = @io::@ret0} : memref<1x1x1x1xf32>
linalg.conv_2d_input_nhwc_filter_hwcf {
dilations = dense<1> : tensor<2xi64>,
strides = dense<2> : tensor<2xi64>}
ins(%0, %1 : memref<1x3x3x512xf32>, memref<3x3x512x1xf32>)
- outs(%2 : memref<1x1x1x512xf32>)
+ outs(%2 : memref<1x1x1x1xf32>)
return
}
// CHECK-LABEL: @kernel__num_workgroups__
@@ -428,21 +428,21 @@
hal.executable.entry_point @kernel attributes {
interface = @io, ordinal = 0 : index,
signature = (!flow.dispatch.tensor<readonly:?x3x512xf32>, !flow.dispatch.tensor<readonly:3x512x1xf32>,
- !flow.dispatch.tensor<writeonly:?x1x512xf32>) -> ()}
+ !flow.dispatch.tensor<writeonly:?x1x1xf32>) -> ()}
module {
// expected-error @+1 {{cannot separate Linalg/Parallel ops into multiple kernels}}
func @kernel() {
%cst = constant 0.000000e+00 : f32
%0 = iree.placeholder for "interface buffer" {binding = @io::@arg0} : memref<1x3x3x512xf32>
%1 = iree.placeholder for "interface buffer" {binding = @io::@arg1} : memref<3x3x512x1xf32>
- %2 = iree.placeholder for "interface buffer" {binding = @io::@ret0} : memref<1x1x1x512xf32>
- linalg.fill(%2, %cst) : memref<1x1x1x512xf32>, f32
+ %2 = iree.placeholder for "interface buffer" {binding = @io::@ret0} : memref<1x1x1x1xf32>
+ linalg.fill(%2, %cst) : memref<1x1x1x1xf32>, f32
"some_op"() : () -> ()
linalg.conv_2d_input_nhwc_filter_hwcf {
dilations = dense<1> : tensor<2xi64>,
strides = dense<2> : tensor<2xi64>}
ins(%0, %1 : memref<1x3x3x512xf32>, memref<3x3x512x1xf32>)
- outs(%2 : memref<1x1x1x512xf32>)
+ outs(%2 : memref<1x1x1x1xf32>)
return
}
hal.interface @io attributes {sym_visibility = "private"} {
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/LibraryBuilder.cpp b/iree/compiler/Dialect/HAL/Target/LLVM/LibraryBuilder.cpp
index 1762060..c3e7f8c 100644
--- a/iree/compiler/Dialect/HAL/Target/LLVM/LibraryBuilder.cpp
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/LibraryBuilder.cpp
@@ -198,13 +198,16 @@
llvm::Function *LibraryBuilder::build(StringRef queryFuncName) {
auto &context = module->getContext();
auto *i32Type = llvm::IntegerType::getInt32Ty(context);
+ auto *ptrType = llvm::Type::getInt8PtrTy(context);
auto *libraryHeaderType = makeLibraryHeaderType(context);
- // %struct.iree_hal_executable_library_header_t** @iree_hal_library_query(i32)
+ // %struct.iree_hal_executable_library_header_t**
+ // @iree_hal_library_query(i32, void*)
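+  // The trailing void* is a reserved argument; all callers in this change
+  // pass NULL for it.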
auto *queryFuncType =
llvm::FunctionType::get(libraryHeaderType->getPointerTo(),
{
i32Type,
+ ptrType,
},
/*isVarArg=*/false);
auto *func =
diff --git a/iree/hal/local/BUILD b/iree/hal/local/BUILD
index f31005a..f6eb991 100644
--- a/iree/hal/local/BUILD
+++ b/iree/hal/local/BUILD
@@ -60,6 +60,20 @@
hdrs = ["executable_library.h"],
)
+cc_test(
+ name = "executable_library_test",
+ srcs = [
+ "executable_library_demo.c",
+ "executable_library_demo.h",
+ "executable_library_test.c",
+ ],
+ deps = [
+ "//iree/base:api",
+ "//iree/base:core_headers",
+ "//iree/hal/local:executable_library",
+ ],
+)
+
cc_library(
name = "local",
srcs = [
diff --git a/iree/hal/local/CMakeLists.txt b/iree/hal/local/CMakeLists.txt
index a83dc9b..c25cdab 100644
--- a/iree/hal/local/CMakeLists.txt
+++ b/iree/hal/local/CMakeLists.txt
@@ -50,6 +50,19 @@
PUBLIC
)
+iree_cc_test(
+ NAME
+ executable_library_test
+ SRCS
+ "executable_library_demo.c"
+ "executable_library_demo.h"
+ "executable_library_test.c"
+ DEPS
+ iree::base::api
+ iree::base::core_headers
+ iree::hal::local::executable_library
+)
+
iree_cc_library(
NAME
local
diff --git a/iree/hal/local/elf/arch.h b/iree/hal/local/elf/arch.h
index 4cef2ec..849b9a1 100644
--- a/iree/hal/local/elf/arch.h
+++ b/iree/hal/local/elf/arch.h
@@ -58,6 +58,9 @@
// void*(*)(int)
void* iree_elf_call_p_i(const void* symbol_ptr, int a0);
+// void*(*)(int, void*)
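+// (The suffix encodes the signature: p = pointer result, then one letter per
+// argument, here an int and a pointer; this matches the two-argument library
+// query function.)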
+void* iree_elf_call_p_ip(const void* symbol_ptr, int a0, void* a1);
+
// int(*)(void*, void*)
int iree_elf_call_i_pp(const void* symbol_ptr, void* a0, void* a1);
diff --git a/iree/hal/local/elf/arch/arm_32.c b/iree/hal/local/elf/arch/arm_32.c
index dacce6f..8b173c6 100644
--- a/iree/hal/local/elf/arch/arm_32.c
+++ b/iree/hal/local/elf/arch/arm_32.c
@@ -131,6 +131,11 @@
return ((ptr_t)symbol_ptr)(a0);
}
+void* iree_elf_call_p_ip(const void* symbol_ptr, int a0, void* a1) {
+ typedef void* (*ptr_t)(int, void*);
+ return ((ptr_t)symbol_ptr)(a0, a1);
+}
+
int iree_elf_call_i_pp(const void* symbol_ptr, void* a0, void* a1) {
typedef int (*ptr_t)(void*, void*);
return ((ptr_t)symbol_ptr)(a0, a1);
diff --git a/iree/hal/local/elf/arch/arm_64.c b/iree/hal/local/elf/arch/arm_64.c
index aa6815f..8823739 100644
--- a/iree/hal/local/elf/arch/arm_64.c
+++ b/iree/hal/local/elf/arch/arm_64.c
@@ -128,6 +128,11 @@
return ((ptr_t)symbol_ptr)(a0);
}
+void* iree_elf_call_p_ip(const void* symbol_ptr, int a0, void* a1) {
+ typedef void* (*ptr_t)(int, void*);
+ return ((ptr_t)symbol_ptr)(a0, a1);
+}
+
int iree_elf_call_i_pp(const void* symbol_ptr, void* a0, void* a1) {
typedef int (*ptr_t)(void*, void*);
return ((ptr_t)symbol_ptr)(a0, a1);
diff --git a/iree/hal/local/elf/arch/riscv.c b/iree/hal/local/elf/arch/riscv.c
index 5114c0d..0694352 100644
--- a/iree/hal/local/elf/arch/riscv.c
+++ b/iree/hal/local/elf/arch/riscv.c
@@ -171,6 +171,11 @@
return ((ptr_t)symbol_ptr)(a0);
}
+void* iree_elf_call_p_ip(const void* symbol_ptr, int a0, void* a1) {
+ typedef void* (*ptr_t)(int, void*);
+ return ((ptr_t)symbol_ptr)(a0, a1);
+}
+
int iree_elf_call_i_pp(const void* symbol_ptr, void* a0, void* a1) {
typedef int (*ptr_t)(void*, void*);
return ((ptr_t)symbol_ptr)(a0, a1);
diff --git a/iree/hal/local/elf/arch/x86_32.c b/iree/hal/local/elf/arch/x86_32.c
index 4db662b..9e431d7 100644
--- a/iree/hal/local/elf/arch/x86_32.c
+++ b/iree/hal/local/elf/arch/x86_32.c
@@ -152,6 +152,11 @@
return ((ptr_t)symbol_ptr)(a0);
}
+void* iree_elf_call_p_ip(const void* symbol_ptr, int a0, void* a1) {
+ typedef void* (*ptr_t)(int, void*);
+ return ((ptr_t)symbol_ptr)(a0, a1);
+}
+
int iree_elf_call_i_pp(const void* symbol_ptr, void* a0, void* a1) {
typedef int (*ptr_t)(void*, void*);
return ((ptr_t)symbol_ptr)(a0, a1);
diff --git a/iree/hal/local/elf/arch/x86_64.c b/iree/hal/local/elf/arch/x86_64.c
index a3181fe..221a128 100644
--- a/iree/hal/local/elf/arch/x86_64.c
+++ b/iree/hal/local/elf/arch/x86_64.c
@@ -193,6 +193,11 @@
return ((ptr_t)symbol_ptr)(a0);
}
+void* iree_elf_call_p_ip(const void* symbol_ptr, int a0, void* a1) {
+ typedef void* (*ptr_t)(int, void*);
+ return ((ptr_t)symbol_ptr)(a0, a1);
+}
+
int iree_elf_call_i_pp(const void* symbol_ptr, void* a0, void* a1) {
typedef int (*ptr_t)(void*, void*);
return ((ptr_t)symbol_ptr)(a0, a1);
diff --git a/iree/hal/local/elf/arch/x86_64_msvc.asm b/iree/hal/local/elf/arch/x86_64_msvc.asm
index 1b8ffec..5118d9f 100644
--- a/iree/hal/local/elf/arch/x86_64_msvc.asm
+++ b/iree/hal/local/elf/arch/x86_64_msvc.asm
@@ -148,6 +148,21 @@
ret
iree_elf_call_p_i ENDP
+; void* iree_elf_call_p_ip(const void* symbol_ptr, int a0, void* a1)
+iree_elf_call_p_ip PROC FRAME
+ _sysv_interop_prolog
+
+ ; RCX = symbol_ptr
+ ; RDX = a0
+ ; R8 = a1
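+  ; The ELF-compiled callee uses the SysV convention, which expects a0 in RDI
+  ; and a1 in RSI, so remap before the call.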
+ mov rdi, rdx
+ mov rsi, r8
+ call rcx
+
+ _sysv_interop_epilog
+ ret
+iree_elf_call_p_ip ENDP
+
; int iree_elf_call_i_pp(const void* symbol_ptr, void* a0, void* a1)
iree_elf_call_i_pp PROC FRAME
_sysv_interop_prolog
diff --git a/iree/hal/local/elf/elf_module_test.cc b/iree/hal/local/elf/elf_module_test.cc
index 3539412..da07cb6 100644
--- a/iree/hal/local/elf/elf_module_test.cc
+++ b/iree/hal/local/elf/elf_module_test.cc
@@ -82,8 +82,9 @@
const iree_hal_executable_library_v0_t* v0;
} library;
library.header =
- (const iree_hal_executable_library_header_t**)iree_elf_call_p_i(
- query_fn_ptr, IREE_HAL_EXECUTABLE_LIBRARY_LATEST_VERSION);
+ (const iree_hal_executable_library_header_t**)iree_elf_call_p_ip(
+ query_fn_ptr, IREE_HAL_EXECUTABLE_LIBRARY_LATEST_VERSION,
+ /*reserved=*/NULL);
ASSERT_TRUE(library.header != NULL);
auto* header = *library.header;
diff --git a/iree/hal/local/executable_library.h b/iree/hal/local/executable_library.h
index ae50f85..6b9c97f 100644
--- a/iree/hal/local/executable_library.h
+++ b/iree/hal/local/executable_library.h
@@ -116,7 +116,7 @@
// than the max version supported by the caller.
typedef const iree_hal_executable_library_header_t** (
*iree_hal_executable_library_query_fn_t)(
- iree_hal_executable_library_version_t max_version);
+ iree_hal_executable_library_version_t max_version, void* reserved);
// Function name exported from dynamic libraries (pass to dlsym).
#define IREE_HAL_EXECUTABLE_LIBRARY_EXPORT_NAME \
diff --git a/iree/hal/local/executable_library_demo.c b/iree/hal/local/executable_library_demo.c
new file mode 100644
index 0000000..935d1b7
--- /dev/null
+++ b/iree/hal/local/executable_library_demo.c
@@ -0,0 +1,98 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iree/hal/local/executable_library_demo.h"
+
+// An executable entry point, called one or more times based on the 3D XYZ
+// workgroup count specified during the dispatch. Each invocation gets access to
+// the dispatch state via |dispatch_state| such as workgroup parameters, push
+// constants providing small arguments, and buffer bindings.
+//
+// See the iree_hal_executable_dispatch_state_v0_t struct for more
+// information on the fields here and how they can be used.
+//
+// WARNING: these functions must not access mutable global state. Read-only
+// data may be used, but since each invocation may be running concurrently with
+// any number of other invocations (from any number of user sessions!), all
+// communication between invocations must use the buffer bindings for I/O.
+//
+// This is a simple scalar addition:
+// binding[1] = binding[0] + push_constant[0]
+static int dispatch_tile_a(
+ const iree_hal_executable_dispatch_state_v0_t* dispatch_state,
+ const iree_hal_vec3_t* workgroup_id) {
+ const dispatch_tile_a_push_constants_t* push_constants =
+ (const dispatch_tile_a_push_constants_t*)dispatch_state->push_constants;
+ const float* src = ((const float*)dispatch_state->binding_ptrs[0]);
+ float* dst = ((float*)dispatch_state->binding_ptrs[1]);
+ dst[workgroup_id->x] = src[workgroup_id->x] + push_constants->f0;
+ return 0;
+}
+
+// Just another entry point.
+static int dispatch_tile_b(
+ const iree_hal_executable_dispatch_state_v0_t* dispatch_state,
+ const iree_hal_vec3_t* workgroup_id) {
+ return 0;
+}
+
+// Version/metadata header.
+static const iree_hal_executable_library_header_t header = {
+ // Declares what library version is present: newer runtimes may support
+ // loading older executables but newer executables cannot load on older
+ // runtimes.
+ .version = IREE_HAL_EXECUTABLE_LIBRARY_LATEST_VERSION,
+ // Name used for logging/diagnostics and rendezvous.
+ .name = "demo_library",
+ .features = IREE_HAL_EXECUTABLE_LIBRARY_FEATURE_NONE,
+ .sanitizer = IREE_HAL_EXECUTABLE_LIBRARY_SANITIZER_NONE,
+};
+// Table of export function entry points.
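+// Ordinals match index-for-index with entry_point_names and entry_point_tags
+// below; callers resolve functions by ordinal.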
+static const iree_hal_executable_dispatch_v0_t entry_points[2] = {
+ dispatch_tile_a,
+ dispatch_tile_b,
+};
+// Names for each entry point.
+static const char* entry_point_names[2] = {
+ "dispatch_tile_a",
+ "dispatch_tile_b",
+};
+// User tags for debugging/logging; not used for anything but presentation.
+static const char* entry_point_tags[2] = {
+ "matmul+div",
+ "conv2d[512x512]",
+};
+static const iree_hal_executable_library_v0_t library = {
+ .header = &header,
+ .entry_point_count = 2,
+ .entry_points = entry_points,
+ .entry_point_names = entry_point_names,
+ .entry_point_tags = entry_point_tags,
+};
+
+// The primary access point to the executable: in a static library this is
+// just like any other C symbol that can be called from other code (like
+// executable_library_test.c does), and in dynamic libraries this is the symbol
+// that you would be dlsym'ing.
+//
+// This is just code: if the executable wants to return different headers based
+// on the currently executing architecture or the requested version it can. For
+// example, an executable may want to swap out a few entry points to an
+// architecture-specific version.
+const iree_hal_executable_library_header_t** demo_executable_library_query(
+ iree_hal_executable_library_version_t max_version, void* reserved) {
+ return max_version <= 0
+ ? (const iree_hal_executable_library_header_t**)&library
+ : NULL;
+}
diff --git a/iree/hal/local/executable_library_demo.h b/iree/hal/local/executable_library_demo.h
new file mode 100644
index 0000000..bf183ce
--- /dev/null
+++ b/iree/hal/local/executable_library_demo.h
@@ -0,0 +1,58 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_LOCAL_EXECUTABLE_LIBRARY_DEMO_H_
+#define IREE_HAL_LOCAL_EXECUTABLE_LIBRARY_DEMO_H_
+
+#include "iree/hal/local/executable_library.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
+
+// Ideally we would have the IREE compiler generate a header like this so that
+// it's possible to manually call into executables. For now this is just an
+// example for the demo: the real HAL does not require this header as it
+// dlsym's the function pointer and packs the push constants itself.
+
+// Push constants used in the 'dispatch_tile_a' entry point.
+typedef union {
+ uint32_t values[1];
+ struct {
+ float f0;
+ };
+} dispatch_tile_a_push_constants_t;
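+// (The union members alias the same storage: the HAL passes push constants as
+// raw 4-byte words and entry points reinterpret them as typed fields.)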
+
+// Returns a simple demo library with the following structure:
+//
+// Name: 'demo_library'
+//
+// [0] 'dispatch_tile_a': matmul+div
+// push constants: 1 (dispatch_tile_a_push_constants_t)
+// bindings: 2
+// [0] = R
+// [1] = W
+//
+// [1] 'dispatch_tile_b': conv2d[512x512]
+// push constants: 0
+// bindings: 0
+//
+const iree_hal_executable_library_header_t** demo_executable_library_query(
+ iree_hal_executable_library_version_t max_version, void* reserved);
+
+#ifdef __cplusplus
+} // extern "C"
+#endif // __cplusplus
+
+#endif // IREE_HAL_LOCAL_EXECUTABLE_LIBRARY_DEMO_H_
diff --git a/iree/hal/local/executable_library_test.c b/iree/hal/local/executable_library_test.c
new file mode 100644
index 0000000..f2f5acc
--- /dev/null
+++ b/iree/hal/local/executable_library_test.c
@@ -0,0 +1,113 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iree/hal/local/executable_library.h"
+
+#include <assert.h>
+#include <stdbool.h>
+#include <string.h>
+
+#include "iree/base/api.h"
+#include "iree/hal/local/executable_library_demo.h"
+
+// Demonstration of the HAL-side of the iree_hal_executable_library_t ABI.
+// This is the lowest level of the system right before calling into generated
+// code.
+//
+// This shows what the various execution systems are doing (through a lot
+// of fancy means): everything in `inline_command_buffer.c` and
+// `task_command_buffer.c` leads up to just calling into the
+// iree_hal_executable_dispatch_v0_t entry point functions with a state
+// structure and a workgroup XYZ.
+//
+// Below walks through acquiring the library pointer (which in this case is a
+// hand-coded example to show the codegen-side), setting up the I/O buffers and
+// state, and calling the function to do some math.
+//
+// See iree/hal/local/executable_library.h for more information.
+int main(int argc, char** argv) {
+ // Query the library header at the requested version.
+ // The query call in this example is going into the handwritten demo code
+ // but could be targeted at generated files or runtime-loaded shared objects.
+ union {
+ const iree_hal_executable_library_header_t** header;
+ const iree_hal_executable_library_v0_t* v0;
+ } library;
+ library.header = demo_executable_library_query(
+ IREE_HAL_EXECUTABLE_LIBRARY_LATEST_VERSION, /*reserved=*/NULL);
+  assert(library.header != NULL && "version may not have matched");
+  const iree_hal_executable_library_header_t* header = *library.header;
+ assert(header->version <= IREE_HAL_EXECUTABLE_LIBRARY_LATEST_VERSION &&
+ "expecting the library to have the same or older version as us");
+ assert(strcmp(header->name, "demo_library") == 0 &&
+ "library name can be used to rendezvous in a registry");
+ assert(library.v0->entry_point_count > 0 &&
+ "expected at least one entry point");
+
+ // Push constants are an array of 4-byte values that are much more efficient
+ // to specify (no buffer pointer indirection) and more efficient to access
+ // (static struct offset address calculation, all fit in a few cache lines,
+ // etc). They are limited in capacity, though, so only <=64(ish) are usable.
+ dispatch_tile_a_push_constants_t push_constants;
+ memset(&push_constants, 0, sizeof(push_constants));
+ push_constants.f0 = 5.0f;
+
+ // Setup the two buffer bindings the entry point is expecting.
+ // They only need to remain valid for the duration of the invocation and all
+ // memory accessed by the invocation will come from here.
+ float arg0[4] = {1.0f, 2.0f, 3.0f, 4.0f};
+ float ret0[4] = {0.0f, 0.0f, 0.0f, 0.0f};
+ const float ret0_expected[4] = {6.0f, 7.0f, 8.0f, 9.0f};
+ iree_device_size_t binding_lengths[2] = {
+ sizeof(arg0),
+ sizeof(ret0),
+ };
+ void* binding_ptrs[2] = {
+ arg0,
+ ret0,
+ };
+
+ // Resolve the entry point by ordinal.
+ const iree_hal_executable_dispatch_v0_t entry_fn_ptr =
+ library.v0->entry_points[0];
+
+ // Dispatch each workgroup with the same state.
+ iree_hal_executable_dispatch_state_v0_t dispatch_state = {
+ .workgroup_count = {{4, 1, 1}},
+ .workgroup_size = {{1, 1, 1}},
+ .push_constant_count = IREE_ARRAYSIZE(push_constants.values),
+ .push_constants = push_constants.values,
+ .binding_count = IREE_ARRAYSIZE(binding_ptrs),
+ .binding_ptrs = binding_ptrs,
+ .binding_lengths = binding_lengths,
+ .imports = NULL, // not yet implemented
+ };
+ for (uint32_t z = 0; z < dispatch_state.workgroup_count.z; ++z) {
+ for (uint32_t y = 0; y < dispatch_state.workgroup_count.y; ++y) {
+ for (uint32_t x = 0; x < dispatch_state.workgroup_count.x; ++x) {
+ // Invoke the workgroup (x, y, z).
+ iree_hal_vec3_t workgroup_id = {{x, y, z}};
+ int ret = entry_fn_ptr(&dispatch_state, &workgroup_id);
+ assert(ret == 0 &&
+ "if we have bounds checking enabled the executable will signal "
+ "us of badness");
+ }
+ }
+ }
+
+ // Ensure it worked.
+ bool all_match = true;
+ for (size_t i = 0; i < IREE_ARRAYSIZE(ret0_expected); ++i) {
+ assert(ret0[i] == ret0_expected[i] && "math is hard");
+ all_match = all_match && ret0[i] == ret0_expected[i];
+ }
+ return all_match ? 0 : 1;
+}
diff --git a/iree/hal/local/loaders/embedded_library_loader.c b/iree/hal/local/loaders/embedded_library_loader.c
index 34b5c32..dc19c84 100644
--- a/iree/hal/local/loaders/embedded_library_loader.c
+++ b/iree/hal/local/loaders/embedded_library_loader.c
@@ -51,8 +51,9 @@
// Query for a compatible version of the library.
executable->library.header =
- (const iree_hal_executable_library_header_t**)iree_elf_call_p_i(
- query_fn, IREE_HAL_EXECUTABLE_LIBRARY_LATEST_VERSION);
+ (const iree_hal_executable_library_header_t**)iree_elf_call_p_ip(
+ query_fn, IREE_HAL_EXECUTABLE_LIBRARY_LATEST_VERSION,
+ /*reserved=*/NULL);
if (!executable->library.header) {
return iree_make_status(
IREE_STATUS_FAILED_PRECONDITION,
diff --git a/iree/hal/local/loaders/legacy_library_loader.c b/iree/hal/local/loaders/legacy_library_loader.c
index 8459bbd..aca8993 100644
--- a/iree/hal/local/loaders/legacy_library_loader.c
+++ b/iree/hal/local/loaders/legacy_library_loader.c
@@ -124,7 +124,7 @@
// Query for a compatible version of the library.
executable->library.header =
- query_fn(IREE_HAL_EXECUTABLE_LIBRARY_LATEST_VERSION);
+ query_fn(IREE_HAL_EXECUTABLE_LIBRARY_LATEST_VERSION, /*reserved=*/NULL);
if (!executable->library.header) {
return iree_make_status(
IREE_STATUS_FAILED_PRECONDITION,
diff --git a/iree/testing/BUILD b/iree/testing/BUILD
index 76c0323..d7f7ae3 100644
--- a/iree/testing/BUILD
+++ b/iree/testing/BUILD
@@ -21,12 +21,28 @@
)
cc_library(
+ name = "benchmark",
+ testonly = True,
+ srcs = [
+ "benchmark_full.cc",
+ ],
+ hdrs = [
+ "benchmark.h",
+ ],
+ deps = [
+ "//iree/base:api",
+ "//iree/base:tracing",
+ "@com_google_benchmark//:benchmark",
+ ],
+)
+
+cc_library(
name = "benchmark_main",
testonly = True,
- srcs = ["benchmark_main.cc"],
+ srcs = ["benchmark_main.c"],
deps = [
+ ":benchmark",
"//iree/base/internal:flags",
- "@com_google_benchmark//:benchmark",
],
)
diff --git a/iree/testing/CMakeLists.txt b/iree/testing/CMakeLists.txt
index ec7e8e3..f3cc7d9 100644
--- a/iree/testing/CMakeLists.txt
+++ b/iree/testing/CMakeLists.txt
@@ -12,11 +12,26 @@
iree_cc_library(
NAME
- benchmark_main
+ benchmark
+ HDRS
+ "benchmark.h"
SRCS
- "benchmark_main.cc"
+ "benchmark_full.cc"
DEPS
benchmark
+ iree::base::api
+ iree::base::tracing
+ TESTONLY
+ PUBLIC
+)
+
+iree_cc_library(
+ NAME
+ benchmark_main
+ SRCS
+ "benchmark_main.c"
+ DEPS
+ ::benchmark
iree::base::internal::flags
TESTONLY
PUBLIC
diff --git a/iree/testing/benchmark.h b/iree/testing/benchmark.h
new file mode 100644
index 0000000..76fa307
--- /dev/null
+++ b/iree/testing/benchmark.h
@@ -0,0 +1,148 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_TESTING_BENCHMARK_H_
+#define IREE_TESTING_BENCHMARK_H_
+
+// This is a C API shim for a benchmark-like interface.
+// The intent is that benchmarks can be written portably for bare-metal
+// systems using simple tooling while still being able to run against the
+// full benchmark library with all of its useful reporting and statistics.
+
+#include "iree/base/api.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
+
+//===----------------------------------------------------------------------===//
+// iree_benchmark_state_t
+//===----------------------------------------------------------------------===//
+
+// Benchmark state manipulator.
+// Passed to each benchmark during execution to control the benchmark state
+// or append information beyond just timing.
+typedef struct iree_benchmark_state_s {
+ // Internal implementation handle.
+ void* impl;
+
+ // Allocator that can be used for host allocations required during benchmark
+ // execution.
+ iree_allocator_t host_allocator;
+} iree_benchmark_state_t;
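+
+// For example, a benchmark needing transient host memory can allocate it from
+// the state (an illustrative sketch; not itself part of this API):
+//  void* scratch = NULL;
+//  IREE_RETURN_IF_ERROR(
+//      iree_allocator_malloc(state->host_allocator, 4096, &scratch));
+//  // ... use scratch for benchmark setup or outputs ...
+//  iree_allocator_free(state->host_allocator, scratch);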
+
+// Returns the range argument with the given ordinal.
+int64_t iree_benchmark_get_range(iree_benchmark_state_t* state,
+ iree_host_size_t ordinal);
+
+// Returns true while the benchmark should keep running its step loop.
+//
+// Usage:
+// while (iree_benchmark_keep_running(state, 1000)) {
+// // process 1000 elements
+// }
+bool iree_benchmark_keep_running(iree_benchmark_state_t* state,
+ uint64_t batch_count);
+
+// Reports that the currently executing benchmark cannot be run.
+// Callers should return immediately afterward as further benchmark-related
+// calls may fail.
+void iree_benchmark_skip(iree_benchmark_state_t* state, const char* message);
+
+// Pauses the benchmark timer until iree_benchmark_resume_timing is called.
+// This can be used to guard per-step code that is required to initialize the
+// work but that should not be accounted for in the benchmark timing.
+// Pausing introduces non-trivial overhead: only use this ~once per step, and
+// only when the step goes on to perform a large amount of batched work.
+void iree_benchmark_pause_timing(iree_benchmark_state_t* state);
+
+// Resumes the benchmark timer after a prior iree_benchmark_pause_timing call.
+void iree_benchmark_resume_timing(iree_benchmark_state_t* state);
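+//
+// Usage (an illustrative sketch):
+//  while (iree_benchmark_keep_running(state, /*batch_count=*/1)) {
+//    iree_benchmark_pause_timing(state);
+//    // ... per-step setup excluded from the timing ...
+//    iree_benchmark_resume_timing(state);
+//    // ... the timed work ...
+//  }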
+
+// Sets a label string that will be displayed alongside the report line of
+// the currently executing benchmark.
+void iree_benchmark_set_label(iree_benchmark_state_t* state, const char* label);
+
+// Adds a 'bytes/s' label with the given value.
+//
+// REQUIRES: must only be called outside of the benchmark step loop.
+void iree_benchmark_set_bytes_processed(iree_benchmark_state_t* state,
+ int64_t bytes);
+
+// Adds an `items/s` label with the given value.
+//
+// REQUIRES: must only be called outside of the benchmark step loop.
+void iree_benchmark_set_items_processed(iree_benchmark_state_t* state,
+ int64_t items);
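+//
+// Usage (an illustrative sketch covering both counters):
+//  int64_t total_items = 0;
+//  while (iree_benchmark_keep_running(state, /*batch_count=*/1000)) {
+//    // ... process 1000 4-byte elements ...
+//    total_items += 1000;
+//  }
+//  iree_benchmark_set_items_processed(state, total_items);
+//  iree_benchmark_set_bytes_processed(state, total_items * 4);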
+
+//===----------------------------------------------------------------------===//
+// iree_benchmark_def_t
+//===----------------------------------------------------------------------===//
+
+typedef enum {
+ IREE_BENCHMARK_FLAG_MEASURE_PROCESS_CPU_TIME = 1u << 0,
+
+ IREE_BENCHMARK_FLAG_USE_REAL_TIME = 1u << 1,
+ IREE_BENCHMARK_FLAG_USE_MANUAL_TIME = 1u << 2,
+} iree_benchmark_flags_t;
+
+typedef enum {
+ IREE_BENCHMARK_UNIT_MILLISECOND = 0,
+ IREE_BENCHMARK_UNIT_MICROSECOND,
+ IREE_BENCHMARK_UNIT_NANOSECOND,
+} iree_benchmark_unit_t;
+
+// A benchmark case definition.
+typedef struct iree_benchmark_def_s {
+ // IREE_BENCHMARK_FLAG_* bitmask controlling benchmark behavior and reporting.
+ iree_benchmark_flags_t flags;
+
+ // Time unit used in display.
+ iree_benchmark_unit_t time_unit; // MILLISECOND by default
+
+  // Optional minimum duration the benchmark should run for, in nanoseconds.
+  iree_duration_t minimum_duration_ns; // 0 to autodetect
+  // Optional iteration count the benchmark should run for.
+  uint64_t iteration_count; // 0 to autodetect
+
+ // TODO(benvanik): add range arguments.
+
+ // Runs the benchmark to completion.
+ // Implementations must call iree_benchmark_keep_running in a loop until it
+ // returns false.
+ iree_status_t (*run)(iree_benchmark_state_t* state);
+} iree_benchmark_def_t;
+
+// Registers a benchmark with the given definition.
+// The definition is captured by reference and must remain live until the
+// benchmark has run (process lifetime is safest).
+void iree_benchmark_register(iree_string_view_t name,
+                             const iree_benchmark_def_t* benchmark_def);
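+//
+// Example (an illustrative sketch; my_benchmark_run and "my_benchmark" are
+// hypothetical names):
+//  static iree_status_t my_benchmark_run(iree_benchmark_state_t* state) {
+//    while (iree_benchmark_keep_running(state, /*batch_count=*/1000)) {
+//      // ... process 1000 elements ...
+//    }
+//    return iree_ok_status();
+//  }
+//  ...
+//  static const iree_benchmark_def_t benchmark_def = {
+//      .flags = IREE_BENCHMARK_FLAG_USE_REAL_TIME,
+//      .time_unit = IREE_BENCHMARK_UNIT_MICROSECOND,
+//      .run = my_benchmark_run,
+//  };
+//  iree_benchmark_register(iree_make_cstring_view("my_benchmark"),
+//                          &benchmark_def);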
+
+//===----------------------------------------------------------------------===//
+// Benchmark infra management
+//===----------------------------------------------------------------------===//
+
+// Initializes the benchmark framework.
+// Must be called before any other iree_benchmark_* functions.
+void iree_benchmark_initialize(int* argc, char** argv);
+
+// Runs all registered benchmarks matching the command line flags.
+// Must be called after iree_benchmark_initialize and after any benchmarks
+// have been registered with iree_benchmark_register.
+void iree_benchmark_run_specified(void);
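+//
+// Usage (mirroring iree/testing/benchmark_main.c):
+//  int main(int argc, char** argv) {
+//    iree_benchmark_initialize(&argc, argv);
+//    // ... register benchmarks with iree_benchmark_register ...
+//    iree_benchmark_run_specified();
+//    return 0;
+//  }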
+
+#ifdef __cplusplus
+} // extern "C"
+#endif // __cplusplus
+
+#endif // IREE_TESTING_BENCHMARK_H_
diff --git a/iree/testing/benchmark_full.cc b/iree/testing/benchmark_full.cc
new file mode 100644
index 0000000..a805970
--- /dev/null
+++ b/iree/testing/benchmark_full.cc
@@ -0,0 +1,164 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "benchmark/benchmark.h"
+#include "iree/base/tracing.h"
+#include "iree/testing/benchmark.h"
+
+//===----------------------------------------------------------------------===//
+// iree_benchmark_state_t
+//===----------------------------------------------------------------------===//
+
+static benchmark::State& GetBenchmarkState(iree_benchmark_state_t* state) {
+ return *(benchmark::State*)state->impl;
+}
+
+int64_t iree_benchmark_get_range(iree_benchmark_state_t* state,
+ iree_host_size_t ordinal) {
+ auto& s = GetBenchmarkState(state);
+ return s.range(ordinal);
+}
+
+bool iree_benchmark_keep_running(iree_benchmark_state_t* state,
+ uint64_t batch_count) {
+ auto& s = GetBenchmarkState(state);
+ return s.KeepRunningBatch(batch_count);
+}
+
+void iree_benchmark_skip(iree_benchmark_state_t* state, const char* message) {
+ auto& s = GetBenchmarkState(state);
+ s.SkipWithError(message);
+}
+
+void iree_benchmark_pause_timing(iree_benchmark_state_t* state) {
+ auto& s = GetBenchmarkState(state);
+ s.PauseTiming();
+}
+
+void iree_benchmark_resume_timing(iree_benchmark_state_t* state) {
+ auto& s = GetBenchmarkState(state);
+ s.ResumeTiming();
+}
+
+void iree_benchmark_set_label(iree_benchmark_state_t* state,
+ const char* label) {
+ auto& s = GetBenchmarkState(state);
+ s.SetLabel(label);
+}
+
+void iree_benchmark_set_bytes_processed(iree_benchmark_state_t* state,
+ int64_t bytes) {
+ auto& s = GetBenchmarkState(state);
+ s.SetBytesProcessed(bytes);
+}
+
+void iree_benchmark_set_items_processed(iree_benchmark_state_t* state,
+ int64_t items) {
+ auto& s = GetBenchmarkState(state);
+ s.SetItemsProcessed(items);
+}
+
+//===----------------------------------------------------------------------===//
+// iree_benchmark_def_t
+//===----------------------------------------------------------------------===//
+
+static std::string StatusToString(iree_status_t status) {
+ if (iree_status_is_ok(status)) {
+ return "OK";
+ }
+ iree_host_size_t buffer_length = 0;
+ if (IREE_UNLIKELY(!iree_status_format(status, /*buffer_capacity=*/0,
+ /*buffer=*/NULL, &buffer_length))) {
+ return "<!>";
+ }
+ std::string result(buffer_length, '\0');
+ if (IREE_UNLIKELY(!iree_status_format(status, result.size() + 1,
+ const_cast<char*>(result.data()),
+ &buffer_length))) {
+ return "<!>";
+ }
+ return result;
+}
+
+static void iree_benchmark_run(const char* benchmark_name,
+ const iree_benchmark_def_t* benchmark_def,
+ benchmark::State& benchmark_state) {
+ IREE_TRACE_SCOPE_DYNAMIC(benchmark_name);
+ IREE_TRACE_FRAME_MARK();
+
+ iree_benchmark_state_t state;
+ memset(&state, 0, sizeof(state));
+ state.impl = &benchmark_state;
+ state.host_allocator = iree_allocator_system();
+
+ iree_status_t status = benchmark_def->run(&state);
+ if (!iree_status_is_ok(status)) {
+ auto status_str = StatusToString(status);
+ iree_status_ignore(status);
+ benchmark_state.SkipWithError(status_str.c_str());
+ }
+}
+
+void iree_benchmark_register(iree_string_view_t name,
+ const iree_benchmark_def_t* benchmark_def) {
+ std::string name_str(name.data, name.size);
+ std::string prefixed_str = "BM_" + name_str;
+ auto* instance = benchmark::RegisterBenchmark(
+ prefixed_str.c_str(),
+ [name_str, benchmark_def](benchmark::State& state) -> void {
+ iree_benchmark_run(name_str.c_str(), benchmark_def, state);
+ });
+
+ if (iree_all_bits_set(benchmark_def->flags,
+ IREE_BENCHMARK_FLAG_MEASURE_PROCESS_CPU_TIME)) {
+ instance->MeasureProcessCPUTime();
+ }
+ if (iree_all_bits_set(benchmark_def->flags,
+ IREE_BENCHMARK_FLAG_USE_REAL_TIME)) {
+ instance->UseRealTime();
+ }
+ if (iree_all_bits_set(benchmark_def->flags,
+ IREE_BENCHMARK_FLAG_USE_MANUAL_TIME)) {
+ instance->UseManualTime();
+ }
+
+ if (benchmark_def->minimum_duration_ns != 0) {
+    // MinTime takes seconds; scale the nanosecond duration down.
+    instance->MinTime((double)benchmark_def->minimum_duration_ns * 1e-9);
+ } else if (benchmark_def->iteration_count != 0) {
+ instance->Iterations(benchmark_def->iteration_count);
+ }
+
+ switch (benchmark_def->time_unit) {
+ default:
+ case IREE_BENCHMARK_UNIT_MILLISECOND:
+ instance->Unit(benchmark::kMillisecond);
+ break;
+ case IREE_BENCHMARK_UNIT_MICROSECOND:
+ instance->Unit(benchmark::kMicrosecond);
+ break;
+ case IREE_BENCHMARK_UNIT_NANOSECOND:
+ instance->Unit(benchmark::kNanosecond);
+ break;
+ }
+}
+
+//===----------------------------------------------------------------------===//
+// Benchmark infra management
+//===----------------------------------------------------------------------===//
+
+void iree_benchmark_initialize(int* argc, char** argv) {
+ benchmark::Initialize(argc, argv);
+}
+
+void iree_benchmark_run_specified() { benchmark::RunSpecifiedBenchmarks(); }
diff --git a/iree/testing/benchmark_main.cc b/iree/testing/benchmark_main.c
similarity index 77%
rename from iree/testing/benchmark_main.cc
rename to iree/testing/benchmark_main.c
index eeb5224..2fcf14b 100644
--- a/iree/testing/benchmark_main.cc
+++ b/iree/testing/benchmark_main.c
@@ -12,16 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "benchmark/benchmark.h"
#include "iree/base/internal/flags.h"
+#include "iree/testing/benchmark.h"
-namespace iree {
-
-extern "C" int main(int argc, char** argv) {
- ::benchmark::Initialize(&argc, argv);
+int main(int argc, char** argv) {
+ iree_benchmark_initialize(&argc, argv);
iree_flags_parse_checked(&argc, &argv);
- ::benchmark::RunSpecifiedBenchmarks();
+ iree_benchmark_run_specified();
return 0;
}
-
-} // namespace iree
diff --git a/iree/tools/compilation.bzl b/iree/tools/compilation.bzl
index b3a85be..159df41 100644
--- a/iree/tools/compilation.bzl
+++ b/iree/tools/compilation.bzl
@@ -14,7 +14,7 @@
"""Rules for compiling IREE executables, modules, and archives."""
-load("//build_tools/embed_data:build_defs.bzl", "cc_embed_data")
+load("//build_tools/embed_data:build_defs.bzl", "c_embed_data", "cc_embed_data")
# TODO(benvanik): port to a full starlark rule, document, etc.
def iree_bytecode_module(
@@ -23,6 +23,7 @@
flags = ["-iree-mlir-to-vm-bytecode-module"],
translate_tool = "//iree/tools:iree-translate",
cc_namespace = None,
+ c_output = False,
**kwargs):
native.genrule(
name = name,
@@ -57,3 +58,14 @@
flatten = True,
**kwargs
)
+ # Embed the module for use in C.
+ if c_output:
+ c_embed_data(
+ name = "%s_c" % (name),
+ identifier = "%s_c" % (name),
+ srcs = ["%s.vmfb" % (name)],
+ c_file_output = "%s_c.c" % (name),
+ h_file_output = "%s_c.h" % (name),
+ flatten = True,
+ **kwargs
+ )