Merge pull request #2742 from google/benvanik-vm64-flags
diff --git a/SUBMODULE_VERSIONS b/SUBMODULE_VERSIONS
index 879a1cf..90e5f67 100644
--- a/SUBMODULE_VERSIONS
+++ b/SUBMODULE_VERSIONS
@@ -4,15 +4,15 @@
a5d9d0f7d368054fd1691aedf1db4116efcc233e third_party/flatbuffers
4fb0ff7069bd88ee85902f4d0bb62794e5f6d021 third_party/flatcc
f2fb48c3b3d79a75a88a99fba6576b25d42ec528 third_party/googletest
-950f1bf976b332eca60267b25bf759e2ad564e0c third_party/llvm-project
+30c1633386e7cfb01c0a54b31ccf4c3a3873e71b third_party/llvm-project
17b12a4481daa150e2d1ea3ada086b551b856707 third_party/marl
-80885f899e12d55a45561ef758eea47bb340dbf1 third_party/mlir-emitc
+a3479bbf9161df8c8cac55a08205864e6f371491 third_party/mlir-emitc
d8c7ee00a687ac369e62e2032514a93a9b413502 third_party/pybind11
9f53ba413e6fc879236dcaa3e008915973d67a4f third_party/ruy
a1390ed39ec77ecfb574bc6fcd5bfc5e3adbdea9 third_party/sdl2
f8bf11a0253a32375c32cad92c841237b96696c0 third_party/spirv_headers
57eb48aed36160c4876bc8310d9ca84d42ee9e2a third_party/swiftshader
-8a4ffe2e1ae722cff5306778df0cfca8b7f503fe third_party/tensorflow
+86efb18ca5812c76dd52c8536f336e6962b7f8ca third_party/tensorflow
864d86e8b6d21449474db5e9313dbff90aa9c24f third_party/tracy
9bd3f561bcee3f01d22912de10bb07ce4e23d378 third_party/vulkan_headers
909f36b714c9239ee0b112a321220213a474ba53 third_party/vulkan_memory_allocator
diff --git a/bindings/python/build_tools/python/generate_build.py b/bindings/python/build_tools/python/generate_build.py
index 6705cfd..a5c0bda 100644
--- a/bindings/python/build_tools/python/generate_build.py
+++ b/bindings/python/build_tools/python/generate_build.py
@@ -18,10 +18,6 @@
# Debugging hint: Just run this with python to see what it prints.
"""Generates a bazel BUILD file for the repo."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
import json
import os
import sys
diff --git a/bindings/python/pyiree/compiler/compiler.cc b/bindings/python/pyiree/compiler/compiler.cc
index 726cabc..78a2bd4 100644
--- a/bindings/python/pyiree/compiler/compiler.cc
+++ b/bindings/python/pyiree/compiler/compiler.cc
@@ -31,6 +31,7 @@
#include "iree/tools/init_mlir_dialects.h"
#include "iree/tools/init_mlir_passes.h"
#include "iree/tools/init_targets.h"
+#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/PrettyStackTrace.h"
#include "llvm/Support/Signals.h"
@@ -185,71 +186,61 @@
// doesn't do any path shortening, which seems to make long Python stack traces
// a bit easier to scan.
void PrintLocation(Location loc, raw_ostream& out) {
- switch (loc->getKind()) {
- case StandardAttributes::OpaqueLocation:
- PrintLocation(loc.cast<OpaqueLoc>().getFallbackLocation(), out);
- break;
- case StandardAttributes::UnknownLocation:
- out << " [unknown location]\n";
- break;
- case StandardAttributes::FileLineColLocation: {
- auto line_col_loc = loc.cast<FileLineColLoc>();
- StringRef this_filename = line_col_loc.getFilename();
- auto slash_pos = this_filename.find_last_of("/\\");
- // We print both the basename and extended names with a structure like
- // `foo.py:35:4`. Even though technically the line/col
- // information is redundant to include in both names, having it on both
- // makes it easier to paste the paths into an editor and jump to the exact
- // location.
- std::string line_col_suffix =
- ":" + std::to_string(line_col_loc.getLine()) + ":" +
- std::to_string(line_col_loc.getColumn());
- bool has_basename = false;
- StringRef basename = this_filename;
- if (slash_pos != StringRef::npos) {
- has_basename = true;
- basename = this_filename.substr(slash_pos + 1);
- }
- out << " at: " << basename << line_col_suffix;
- if (has_basename) {
- // When running through bazel, such as in our e2e test suite,
- // the paths involved can be quite large, and will have a very long
- // prefix before the sandboxed "runfiles" directory that the program
- // runs in. Trim off that long prefix. By convention, the path names
- // with this prefix dropped will correspond to the path in the source
- // directory, which is probably what we want anyway.
- StringRef kRunfiles(".runfiles/");
- StringRef extended_name = this_filename;
- auto runfiles_pos = extended_name.rfind(kRunfiles);
- if (runfiles_pos != StringRef::npos) {
- extended_name =
- extended_name.drop_front(runfiles_pos + kRunfiles.size());
+ TypeSwitch<Location>(loc)
+ .Case<OpaqueLoc>(
+ [&](OpaqueLoc loc) { PrintLocation(loc.getFallbackLocation(), out); })
+ .Case<UnknownLoc>([&](UnknownLoc) { out << " [unknown location]\n"; })
+ .Case<FileLineColLoc>([&](FileLineColLoc line_col_loc) {
+ StringRef this_filename = line_col_loc.getFilename();
+ auto slash_pos = this_filename.find_last_of("/\\");
+ // We print both the basename and extended names with a structure like
+ // `foo.py:35:4`. Even though technically the line/col
+ // information is redundant to include in both names, having it on both
+ // makes it easier to paste the paths into an editor and jump to the
+ // exact location.
+ std::string line_col_suffix =
+ ":" + std::to_string(line_col_loc.getLine()) + ":" +
+ std::to_string(line_col_loc.getColumn());
+ bool has_basename = false;
+ StringRef basename = this_filename;
+ if (slash_pos != StringRef::npos) {
+ has_basename = true;
+ basename = this_filename.substr(slash_pos + 1);
}
- // Print out two tabs, as basenames usually vary in length by more than
- // one tab width.
- out << "\t\t( " << extended_name << line_col_suffix << " )";
- }
- out << "\n";
- break;
- }
- case StandardAttributes::NameLocation: {
- auto nameLoc = loc.cast<NameLoc>();
- out << " @'" << nameLoc.getName() << "':\n";
- auto childLoc = nameLoc.getChildLoc();
- if (!childLoc.isa<UnknownLoc>()) {
- out << "(...\n";
- PrintLocation(childLoc, out);
- out << ")\n";
- }
- break;
- }
- case StandardAttributes::CallSiteLocation: {
- auto call_site = loc.cast<CallSiteLoc>();
- PrintLocation(call_site.getCaller(), out);
- PrintLocation(call_site.getCallee(), out);
- break;
- }
- }
+ out << " at: " << basename << line_col_suffix;
+ if (has_basename) {
+ // When running through bazel, such as in our e2e test suite,
+ // the paths involved can be quite large, and will have a very long
+ // prefix before the sandboxed "runfiles" directory that the program
+ // runs in. Trim off that long prefix. By convention, the path names
+ // with this prefix dropped will correspond to the path in the source
+ // directory, which is probably what we want anyway.
+ StringRef kRunfiles(".runfiles/");
+ StringRef extended_name = this_filename;
+ auto runfiles_pos = extended_name.rfind(kRunfiles);
+ if (runfiles_pos != StringRef::npos) {
+ extended_name =
+ extended_name.drop_front(runfiles_pos + kRunfiles.size());
+ }
+ // Print out two tabs, as basenames usually vary in length by more
+ // than one tab width.
+ out << "\t\t( " << extended_name << line_col_suffix << " )";
+ }
+ out << "\n";
+ })
+ .Case<NameLoc>([&](NameLoc name_loc) {
+ out << " @'" << name_loc.getName() << "':\n";
+ auto child_loc = name_loc.getChildLoc();
+ if (!child_loc.isa<UnknownLoc>()) {
+ out << "(...\n";
+ PrintLocation(child_loc, out);
+ out << ")\n";
+ }
+ })
+ .Case<CallSiteLoc>([&](CallSiteLoc call_site) {
+ PrintLocation(call_site.getCaller(), out);
+ PrintLocation(call_site.getCallee(), out);
+ });
}
std::string DiagnosticCapture::ConsumeDiagnosticsAsString(
diff --git a/bindings/python/pyiree/rt/BUILD b/bindings/python/pyiree/rt/BUILD
index 6d77a7f..a89363a 100644
--- a/bindings/python/pyiree/rt/BUILD
+++ b/bindings/python/pyiree/rt/BUILD
@@ -117,10 +117,6 @@
name = "function_abi_test",
srcs = ["function_abi_test.py"],
python_version = "PY3",
- # TODO(laurenzo): Enable once test does not depend on a real vulkan device.
- tags = [
- "nokokoro",
- ],
deps = NUMPY_DEPS + [
"//bindings/python:pathsetup", # build_cleaner: keep
"@absl_py//absl/testing:absltest",
diff --git a/bindings/python/pyiree/rt/function_abi.cc b/bindings/python/pyiree/rt/function_abi.cc
index 58a0295..7434653 100644
--- a/bindings/python/pyiree/rt/function_abi.cc
+++ b/bindings/python/pyiree/rt/function_abi.cc
@@ -356,7 +356,7 @@
VmVariantList& f_args,
VmVariantList& f_results) {
if (f_args.size() != raw_config().inputs.size()) {
- throw RaiseValueError("Mismatched AllocatResults() input arity");
+ throw RaiseValueError("Mismatched AllocateResults() input arity");
}
for (size_t i = 0, e = descs.size(); i < e; ++i) {
diff --git a/bindings/python/pyiree/rt/function_abi_test.py b/bindings/python/pyiree/rt/function_abi_test.py
index cb8c804..6e23f10 100644
--- a/bindings/python/pyiree/rt/function_abi_test.py
+++ b/bindings/python/pyiree/rt/function_abi_test.py
@@ -16,6 +16,7 @@
import re
+from absl import logging
from absl.testing import absltest
import numpy as np
@@ -40,7 +41,7 @@
def test_baseclass(self):
htf = rt.HostTypeFactory()
- print(htf)
+ logging.info("HostTypeFactory: %s", htf)
class FunctionAbiTest(absltest.TestCase):
@@ -50,12 +51,12 @@
super().setUpClass()
driver_names = rt.HalDriver.query()
for driver_name in driver_names:
- print("Try create driver:", driver_name)
+ logging.info("Try to create driver: %s", driver_name)
try:
cls.driver = rt.HalDriver.create(driver_name)
cls.device = cls.driver.create_default_device()
except Exception:
- print("Could not create driver:", driver_name)
+ logging.error("Could not create driver: %s", driver_name)
else:
break
@@ -66,7 +67,7 @@
def test_static_arg_success(self):
fabi = rt.FunctionAbi(self.device, self.htf,
ATTRS_1ARG_FLOAT32_10X128X64_TO_SINT32_32X8X64_V1)
- print(fabi)
+ logging.info("fabi: %s", fabi)
self.assertEqual(
"<FunctionAbi (Buffer<float32[10x128x64]>) -> "
"(Buffer<sint32[32x8x64]>)>", repr(fabi))
@@ -75,7 +76,7 @@
arg = np.zeros((10, 128, 64), dtype=np.float32)
packed = fabi.raw_pack_inputs([arg])
- print(packed)
+ logging.info("packed: %s", packed)
self.assertEqual("<VmVariantList(1): [HalBufferView(10x128x64:0x3000020)]>",
repr(packed))
@@ -85,7 +86,7 @@
arg = np.zeros((10, 128, 64), dtype=np.float32)
f_args = fabi.raw_pack_inputs([arg])
f_results = fabi.allocate_results(f_args)
- print(f_results)
+ logging.info("f_results: %s", f_results)
self.assertEqual("<VmVariantList(1): [HalBufferView(32x8x64:0x1000020)]>",
repr(f_results))
py_result, = fabi.raw_unpack_results(f_results)
@@ -98,13 +99,13 @@
arg = np.zeros((10, 128, 64), dtype=np.float32)
f_args = fabi.raw_pack_inputs([arg])
f_results = fabi.allocate_results(f_args, static_alloc=False)
- print(f_results)
+ logging.info("f_results: %s", f_results)
self.assertEqual("<VmVariantList(0): []>", repr(f_results))
def test_dynamic_arg_success(self):
fabi = rt.FunctionAbi(self.device, self.htf,
ATTRS_1ARG_FLOAT32_DYNX128X64_TO_SINT32_DYNX8X64_V1)
- print(fabi)
+ logging.info("fabi: %s", fabi)
self.assertEqual(
"<FunctionAbi (Buffer<float32[?x128x64]>) -> "
"(Buffer<sint32[?x8x64]>)>", repr(fabi))
@@ -113,14 +114,14 @@
arg = np.zeros((10, 128, 64), dtype=np.float32)
packed = fabi.raw_pack_inputs([arg])
- print(packed)
+ logging.info("packed: %s", packed)
self.assertEqual("<VmVariantList(1): [HalBufferView(10x128x64:0x3000020)]>",
repr(packed))
def test_static_arg_rank_mismatch(self):
fabi = rt.FunctionAbi(self.device, self.htf,
ATTRS_1ARG_FLOAT32_10X128X64_TO_SINT32_32X8X64_V1)
- print(fabi)
+ logging.info("fabi: %s", fabi)
arg = np.zeros((10,), dtype=np.float32)
with self.assertRaisesRegex(
ValueError,
@@ -130,7 +131,7 @@
def test_static_arg_eltsize_mismatch(self):
fabi = rt.FunctionAbi(self.device, self.htf,
ATTRS_1ARG_FLOAT32_10X128X64_TO_SINT32_32X8X64_V1)
- print(fabi)
+ logging.info("fabi: %s", fabi)
arg = np.zeros((10, 128, 64), dtype=np.float64)
with self.assertRaisesRegex(
ValueError,
@@ -140,7 +141,7 @@
def test_static_arg_dtype_mismatch(self):
fabi = rt.FunctionAbi(self.device, self.htf,
ATTRS_1ARG_FLOAT32_10X128X64_TO_SINT32_32X8X64_V1)
- print(fabi)
+ logging.info("fabi: %s", fabi)
arg = np.zeros((10, 128, 64), dtype=np.int32)
with self.assertRaisesRegex(
ValueError,
@@ -150,7 +151,7 @@
def test_static_arg_static_dim_mismatch(self):
fabi = rt.FunctionAbi(self.device, self.htf,
ATTRS_1ARG_FLOAT32_10X128X64_TO_SINT32_32X8X64_V1)
- print(fabi)
+ logging.info("fabi: %s", fabi)
arg = np.zeros((10, 32, 64), dtype=np.float32)
with self.assertRaisesRegex(
ValueError,
diff --git a/bindings/python/pyiree/rt/hal_test.py b/bindings/python/pyiree/rt/hal_test.py
index b7ab59b..a7f1a12 100644
--- a/bindings/python/pyiree/rt/hal_test.py
+++ b/bindings/python/pyiree/rt/hal_test.py
@@ -12,10 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
+from absl import logging
from absl.testing import absltest
import numpy as np
@@ -25,8 +22,8 @@
class HalTest(absltest.TestCase):
def testEnums(self):
- print("MemoryType =", rt.MemoryType)
- print("HOST_VISIBLE =", int(rt.MemoryType.HOST_VISIBLE))
+ logging.info("MemoryType: %s", rt.MemoryType)
+ logging.info("HOST_VISIBLE: %s", int(rt.MemoryType.HOST_VISIBLE))
def testAllocateHeap(self):
b = rt.HalBuffer.allocate_heap(
diff --git a/bindings/python/pyiree/rt/system_api.py b/bindings/python/pyiree/rt/system_api.py
index aaea01f..327883bc 100644
--- a/bindings/python/pyiree/rt/system_api.py
+++ b/bindings/python/pyiree/rt/system_api.py
@@ -138,7 +138,7 @@
def __call__(self, *args):
# NOTE: This is just doing sync dispatch right now. In the future,
# this should default to async and potentially have some kind of policy
- # flag that can allow it to be overriden.
+ # flag that can allow it to be overridden.
inputs = self._abi.raw_pack_inputs(args)
results = self._abi.allocate_results(inputs, static_alloc=False)
self._context._vm_context.invoke(self._vm_function, inputs, results)
@@ -269,8 +269,7 @@
def load_modules(*modules, config: Optional[Config] = None):
"""Loads modules into a new or shared context and returns them."""
context = SystemContext(modules=modules, config=config)
- context_modules = context.modules
- bound_modules = [context_modules[m.name] for m in modules]
+ bound_modules = [context.modules[m.name] for m in modules]
return bound_modules
diff --git a/bindings/python/pyiree/rt/system_api_test.py b/bindings/python/pyiree/rt/system_api_test.py
index 9d670ce..ca47439 100644
--- a/bindings/python/pyiree/rt/system_api_test.py
+++ b/bindings/python/pyiree/rt/system_api_test.py
@@ -17,6 +17,7 @@
import re
+from absl import logging
from absl.testing import absltest
import numpy as np
from pyiree import compiler
@@ -68,7 +69,7 @@
self.assertEqual(ctx.modules.arithmetic.name, "arithmetic")
f = ctx.modules.arithmetic["simple_mul"]
f_repr = repr(f)
- print(f_repr)
+ logging.info("f_repr: %s", f_repr)
self.assertRegex(
f_repr,
re.escape(
diff --git a/bindings/python/pyiree/rt/vm_test.py b/bindings/python/pyiree/rt/vm_test.py
index 5a3c1ec..da05a6c 100644
--- a/bindings/python/pyiree/rt/vm_test.py
+++ b/bindings/python/pyiree/rt/vm_test.py
@@ -15,6 +15,7 @@
# pylint: disable=unused-variable
+from absl import logging
from absl.testing import absltest
import numpy as np
from pyiree import compiler
@@ -70,7 +71,7 @@
def setUpClass(cls):
super().setUpClass()
driver_names = rt.HalDriver.query()
- print("DRIVER_NAMES =", driver_names)
+ logging.info("driver_names: %s", driver_names)
cls.driver = rt.HalDriver.create("vmla")
cls.device = cls.driver.create_default_device()
cls.hal_module = rt.create_hal_module(cls.device)
@@ -78,7 +79,7 @@
def test_variant_list(self):
l = rt.VmVariantList(5)
- print(l)
+ logging.info("variant_list: %s", l)
self.assertEqual(l.size, 0)
def test_context_id(self):
@@ -102,19 +103,19 @@
def test_static_module_context(self):
m = create_simple_static_mul_module()
- print(m)
+ logging.info("module: %s", m)
instance = rt.VmInstance()
- print(instance)
+ logging.info("instance: %s", instance)
context = rt.VmContext(instance, modules=[self.hal_module, m])
- print(context)
+ logging.info("context: %s", context)
def test_dynamic_shape_compile(self):
m = create_simple_dynamic_abs_module()
- print(m)
+ logging.info("module: %s", m)
instance = rt.VmInstance()
- print(instance)
+ logging.info("instance: %s", instance)
context = rt.VmContext(instance, modules=[self.hal_module, m])
- print(context)
+ logging.info("context: %s", context)
def test_add_scalar(self):
m = create_add_scalar_module()
@@ -122,18 +123,19 @@
context = rt.VmContext(instance, modules=[self.hal_module, m])
f = m.lookup_function("add_scalar")
abi = context.create_function_abi(self.device, self.htf, f)
- print("INVOKING:", abi)
- arg0 = np.array([1., 2., 3., 4.], dtype=np.float32)
- arg1 = np.array([4., 5., 6., 7.], dtype=np.float32)
+ logging.info("abi: %s", abi)
+
inputs = abi.raw_pack_inputs((5, 6))
- print("INPUTS:", inputs)
+ logging.info("inputs: %s", inputs)
+
allocated_results = abi.allocate_results(inputs, static_alloc=False)
- print("ALLOCATED RESULTS:", allocated_results)
- print("--- INVOKE:")
+ logging.info("allocated_results: %s", allocated_results)
+ logging.info("Invoking...")
context.invoke(f, inputs, allocated_results)
- print("--- DONE.")
+ logging.info("...done")
+
results = abi.raw_unpack_results(allocated_results)
- print("RESULTS:", results)
+ logging.info("results: %s", results)
self.assertEqual(results[0], 11)
def test_synchronous_dynamic_shape_invoke_function(self):
@@ -142,17 +144,20 @@
context = rt.VmContext(instance, modules=[self.hal_module, m])
f = m.lookup_function("simple_mul")
abi = context.create_function_abi(self.device, self.htf, f)
- print("INVOKING:", abi)
+ logging.info("abi: %s", abi)
+
arg0 = np.array([[-1., 2.], [3., -4.]], dtype=np.float32)
inputs = abi.raw_pack_inputs((arg0,))
- print("INPUTS:", inputs)
+ logging.info("inputs: %s", inputs)
+
allocated_results = abi.allocate_results(inputs, static_alloc=False)
- print("ALLOCATED RESULTS:", allocated_results)
- print("--- INVOKE:")
+ logging.info("allocated_results: %s", allocated_results)
+ logging.info("Invoking...")
context.invoke(f, inputs, allocated_results)
- print("--- DONE.")
+ logging.info("...done")
+
results = abi.raw_unpack_results(allocated_results)
- print("RESULTS:", results)
+ logging.info("results: %s", results)
np.testing.assert_allclose(results[0], [[1., 2.], [3., 4.]])
def test_synchronous_invoke_function(self):
@@ -161,18 +166,21 @@
context = rt.VmContext(instance, modules=[self.hal_module, m])
f = m.lookup_function("simple_mul")
abi = context.create_function_abi(self.device, self.htf, f)
- print("INVOKING:", abi)
+ logging.info("abi: %s", abi)
+
arg0 = np.array([1., 2., 3., 4.], dtype=np.float32)
arg1 = np.array([4., 5., 6., 7.], dtype=np.float32)
inputs = abi.raw_pack_inputs((arg0, arg1))
- print("INPUTS:", inputs)
+ logging.info("inputs: %s", inputs)
+
allocated_results = abi.allocate_results(inputs, static_alloc=False)
- print("ALLOCATED RESULTS:", allocated_results)
- print("--- INVOKE:")
+ logging.info("allocated_results: %s", allocated_results)
+ logging.info("Invoking...")
context.invoke(f, inputs, allocated_results)
- print("--- DONE.")
+ logging.info("...done")
+
results = abi.raw_unpack_results(allocated_results)
- print("RESULTS:", results)
+ logging.info("results: %s", results)
np.testing.assert_allclose(results[0], [4., 10., 18., 28.])
diff --git a/build_tools/cmake/iree_copts.cmake b/build_tools/cmake/iree_copts.cmake
index 7563c1f..8ac8c2d 100644
--- a/build_tools/cmake/iree_copts.cmake
+++ b/build_tools/cmake/iree_copts.cmake
@@ -249,6 +249,8 @@
${PROJECT_SOURCE_DIR}/third_party/tensorflow/tensorflow/compiler/mlir/hlo/include/
${PROJECT_BINARY_DIR}/build_tools/third_party/tensorflow
${PROJECT_BINARY_DIR}/build_tools/third_party/tensorflow/tensorflow/compiler/mlir/hlo/include/
+ ${PROJECT_BINARY_DIR}/build_tools/third_party/tensorflow/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/
+ ${PROJECT_BINARY_DIR}/build_tools/third_party/tensorflow/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms
)
#-------------------------------------------------------------------------------
diff --git a/experimental/ModelBuilder/test/TestMatMulVulkan.cpp b/experimental/ModelBuilder/test/TestMatMulVulkan.cpp
index b31c663..2a2fb6b 100644
--- a/experimental/ModelBuilder/test/TestMatMulVulkan.cpp
+++ b/experimental/ModelBuilder/test/TestMatMulVulkan.cpp
@@ -60,6 +60,10 @@
"workgroup-size", llvm::cl::desc("Workgroup size to use"),
llvm::cl::CommaSeparated);
+static llvm::cl::list<int> tileSizes("tile-sizes",
+ llvm::cl::desc("Tile sizes to use"),
+ llvm::cl::CommaSeparated);
+
using namespace mlir; // NOLINT
using namespace mlir::edsc; // NOLINT
using namespace mlir::edsc::intrinsics; // NOLINT
@@ -96,9 +100,10 @@
SmallVector<Type, 3> args = {typeA, typeB, typeC};
SmallVector<int64_t, 4> vWorkgroupSizes(workgroupSize.begin(),
workgroupSize.end());
+ SmallVector<int64_t, 4> vTileSizes(tileSizes.begin(), tileSizes.end());
auto lowering = [&](mlir::PassManager &pm) {
pm.addPass(mlir::iree_compiler::createLinalgTileAndFusePass(
- vWorkgroupSizes, useWorkgroupMemory));
+ vWorkgroupSizes, vTileSizes, useWorkgroupMemory));
pm.addPass(mlir::iree_compiler::createConvertToGPUPass());
pm.addPass(mlir::createLowerAffinePass());
pm.addPass(mlir::createLegalizeStdOpsForSPIRVLoweringPass());
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
index 87a087b..54b79b9 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
@@ -58,14 +58,7 @@
parent_dir = os.path.join(tempfile.gettempdir(), "iree", "modules")
artifacts_dir = os.path.join(parent_dir, module_name)
logging.info("Saving compilation artifacts and traces to '%s'", artifacts_dir)
-
- # If the artifacts already exist then we overwrite/update them.
- try:
- # Use try/except instead of os.path.exists to address a race condition
- # between multiple tests targets.
- os.makedirs(artifacts_dir)
- except IOError:
- pass
+ tf_utils._makedirs(artifacts_dir)
return artifacts_dir
@@ -314,9 +307,9 @@
return True
def _get_trace_dir(self, artifacts_dir):
- trace_dir = os.path.join(artifacts_dir, "traces")
- if not os.path.exists(trace_dir):
- os.makedirs(trace_dir)
+ trace_dir = os.path.join(artifacts_dir, self.backend, "traces",
+ self.function_name)
+ tf_utils._makedirs(trace_dir)
return trace_dir
def save_plaintext(self, artifacts_dir, summarize=True):
@@ -335,7 +328,7 @@
edgeitems=10) # Can show more items since they won't clutter the logs.
trace_dir = self._get_trace_dir(artifacts_dir)
- path = os.path.join(trace_dir, f"{self.function_name}__{self.backend}.txt")
+ path = os.path.join(trace_dir, "log.txt")
with open(path, "w") as f:
f.write(str(self))
f.write("\n")
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py
index d1cd14c..b93541a 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py
@@ -48,16 +48,6 @@
return np.arange(np.prod(shape), dtype=dtype).reshape(shape)
-def backends_to_str(backend_infos):
- """Creates a normalized string representing the provided backends."""
- normalized_names = []
- for backend_info in backend_infos:
- # Remove unusual characters and ensure names don't end or start in "_".
- name = re.sub("[^0-9a-zA-Z_]+", "_", backend_info.name)
- normalized_names.append(name.strip("_"))
- return "__".join(normalized_names)
-
-
def to_mlir_type(dtype):
"""Returns a string that denotes the type `dtype` in MLIR style."""
bits = dtype.itemsize * 8
@@ -97,6 +87,37 @@
return result
+def backends_to_str(backend_infos):
+ """Creates a normalized string representing the provided backends."""
+ normalized_names = []
+ for backend_info in backend_infos:
+ # Remove unusual characters and ensure names don't end or start in "_".
+ name = re.sub("[^0-9a-zA-Z_]+", "_", backend_info.name)
+ normalized_names.append(name.strip("_"))
+ return "__".join(normalized_names)
+
+
+def _get_backends_path(artifact_name, backend_infos, artifacts_dir):
+ backends_string = backends_to_str(backend_infos)
+ # Put the artifact in a directory if there's only one backend.
+ if len(backend_infos) == 1:
+ backend_dir = os.path.join(artifacts_dir, backends_string)
+ _makedirs(backend_dir)
+ return os.path.join(artifacts_dir, backends_string, artifact_name)
+ else:
+ return os.path.join(artifacts_dir, f"{artifact_name}__{backends_string}")
+
+
+def _makedirs(path):
+ # If the artifacts already exist then we overwrite/update them.
+ try:
+ # Use try/except instead of os.path.exists to address any race conditions
+ # that might arise between multiple test targets.
+ os.makedirs(path)
+ except IOError:
+ pass
+
+
def compile_tf_module(tf_module,
backend_infos=(),
exported_names=(),
@@ -107,16 +128,21 @@
that returns a module that can be called without any further steps.
If artifacts_dir is provided then the following artifacts will be saved:
- saved_model:
+ backend_name/saved_model:
A TF SavedModel directory containing the files used to translate the
- tf.Module into an IREE module.
+ tf.Module into an IREE module. Only saved if '--keep_saved_model=True'.
tf_input.mlir:
MLIR for the module in TF's input dialect.
iree_input.mlir:
The MLIR above translated to IREE via compiler.TF_IMPORT_PASS_PIPELINE.
- compiled__backends.vmfb:
+ backend_name/compiled.vmfb:
A VM FlatBuffer compiled to the target backends from the IREE MLIR above.
- Here 'backends' is a '__' delimited list of iree backends (e.g. vmla__llvm_ir)
+
+ If multiple backends are specified, then instead of saving the SavedModel and
+ compiled 'vmfb' under 'backend_name/', they will be saved as follows:
+ - 'saved_model__{backends}'
+ - 'compiled__{backends}.vmfb'
+ where 'backends' is a '__' delimited list (e.g. iree_vmla__iree_llvmjit).
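+ For example, with a single 'iree_vmla' backend the compiled module is saved
+ as 'iree_vmla/compiled.vmfb', whereas ('iree_vmla', 'iree_llvmjit') produces
+ 'compiled__iree_vmla__iree_llvmjit.vmfb' directly under the artifacts dir.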
Args:
tf_module: A tf.Module.
@@ -168,8 +194,9 @@
compiled_module = compiler_module.compile(target_backends=target_backends)
if artifacts_dir is not None:
- compiled_name = f"compiled__{backends_string}.vmfb"
- compiled_path = os.path.join(artifacts_dir, compiled_name)
+ compiled_path = _get_backends_path("compiled", backend_infos,
+ artifacts_dir)
+ compiled_path = f"{compiled_path}.vmfb"
logging.info("Saving compiled IREE module to: %s", compiled_path)
with open(compiled_path, "wb") as f:
f.write(compiled_module)
@@ -187,7 +214,7 @@
# Create a saved model for these target backends to avoid a race condition
# when running a test suite.
# TODO(meadowlark): Remove this once we have a TfLiteCompiledModule.
- sm_path = os.path.join(artifacts_dir, f"saved_model__{backends_string}")
+ sm_path = _get_backends_path("saved_model", backend_infos, artifacts_dir)
tf.saved_model.save(tf_module, sm_path, options=options)
return _compile_from_path(sm_path)
else:
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils_test.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils_test.py
index aa1df8e..4f4084e 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils_test.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils_test.py
@@ -19,9 +19,9 @@
from absl import logging
from absl.testing import parameterized
+import numpy as np
from pyiree.tf.support import tf_utils
import tensorflow as tf
-import numpy as np
class ConstantModule(tf.Module):
@@ -67,10 +67,13 @@
iree_compiled_module = tf_utils.compile_tf_module(
tf_module, backend_infos=backend_infos, artifacts_dir=artifacts_dir)
+ compiled_path = tf_utils._get_backends_path('compiled', backend_infos,
+ artifacts_dir)
+ compiled_path = f'{compiled_path}.vmfb'
artifacts_to_check = [
'tf_input.mlir',
'iree_input.mlir',
- f'compiled__{tf_utils.backends_to_str(backend_infos)}.vmfb',
+ compiled_path,
]
for artifact in artifacts_to_check:
artifact_path = os.path.join(artifacts_dir, artifact)
diff --git a/integrations/tensorflow/compiler/dialect/tf_strings/ir/dialect.cpp b/integrations/tensorflow/compiler/dialect/tf_strings/ir/dialect.cpp
index 0f2a336..031bc0c 100644
--- a/integrations/tensorflow/compiler/dialect/tf_strings/ir/dialect.cpp
+++ b/integrations/tensorflow/compiler/dialect/tf_strings/ir/dialect.cpp
@@ -54,13 +54,14 @@
}
void TFStringsDialect::printType(Type type, DialectAsmPrinter& os) const {
- switch (type.getKind()) {
- case TFStringsTypes::String:
- os << "string";
- break;
- default:
- llvm_unreachable("unhandled string type");
- }
+ if (type.isa<tf_strings::StringType>())
+ os << "string";
+ else
+ llvm_unreachable("unhandled string type");
+}
+
+bool TFStringsType::classof(Type type) {
+ return llvm::isa<TFStringsDialect>(type.getDialect());
}
} // namespace tf_strings
diff --git a/integrations/tensorflow/compiler/dialect/tf_strings/ir/types.h b/integrations/tensorflow/compiler/dialect/tf_strings/ir/types.h
index a7f4fe3..c61e0b8 100644
--- a/integrations/tensorflow/compiler/dialect/tf_strings/ir/types.h
+++ b/integrations/tensorflow/compiler/dialect/tf_strings/ir/types.h
@@ -41,10 +41,7 @@
public:
using Type::Type;
- static bool classof(Type type) {
- return type.getKind() >= TFStringsTypes::FIRST_USED_STRINGS_TYPE &&
- type.getKind() <= TFStringsTypes::LAST_USED_STRINGS_TYPE;
- }
+ static bool classof(Type type);
};
class StringType
@@ -54,8 +51,6 @@
static StringType get(MLIRContext* context) {
return Base::get(context, TFStringsTypes::String);
}
-
- static bool kindof(unsigned kind) { return kind == TFStringsTypes::String; }
};
} // namespace tf_strings
diff --git a/integrations/tensorflow/compiler/dialect/tf_tensorlist/ir/tf_tensorlist_types.h b/integrations/tensorflow/compiler/dialect/tf_tensorlist/ir/tf_tensorlist_types.h
index f9602cb..e31bc95 100644
--- a/integrations/tensorflow/compiler/dialect/tf_tensorlist/ir/tf_tensorlist_types.h
+++ b/integrations/tensorflow/compiler/dialect/tf_tensorlist/ir/tf_tensorlist_types.h
@@ -30,7 +30,6 @@
: public Type::TypeBase<TensorListType, Type, TypeStorage> {
public:
using Base::Base;
- static bool kindof(unsigned kind) { return kind == TypeKind::kTensorList; }
static TensorListType get(MLIRContext *context) {
return Base::get(context, TypeKind::kTensorList);
}
diff --git a/integrations/tensorflow/e2e/README.md b/integrations/tensorflow/e2e/README.md
index 300cadf..89b3bb1 100644
--- a/integrations/tensorflow/e2e/README.md
+++ b/integrations/tensorflow/e2e/README.md
@@ -13,7 +13,7 @@
See [Install TensorFlow with pip](https://www.tensorflow.org/install/pip) for
instructions.
-## Vulkan setup
+## Vulkan Setup
If you do not have your environment set up to use IREE with Vulkan (see
[the doc](../../../docs/vulkan_and_spirv.md)), then you can run the manual test
@@ -48,7 +48,7 @@
By default the TensorFlow SavedModels will not be kept. This can be overridden
via the `--keep_saved_model` flag.
-## Running tests
+## Running Tests
For locally running tests and iterating on backend development, `bazel run` is
preferred.
@@ -150,7 +150,33 @@
bazel test :e2e_tests_failing_broadcasting_test__tf__iree_vulkan
```
-## Debugging tests
+## Generated Artifacts
+
+By default, running an E2E test generates a number of compilation, debugging and
+benchmarking artifacts in `/tmp/iree/modules/`. The location of these artifacts
+can be changed via the `--artifacts_dir` flag. The generated directory structure
+for each module is as follows:
+
+```
+/tmp/iree/modules/ModuleName
+├── tf_input.mlir # MLIR for ModuleName in TF's input dialect
+├── iree_input.mlir # tf_input.mlir translated to IREE MLIR
+├── backend_name # e.g. iree_vmla, tf or tf_ref
+│ ├── compiled.vmfb # flatbuffer of ModuleName compiled to this backend
+│ ├── saved_model
+│ └── traces
+│ ├── trace_function
+│ │ └── log.txt # A more detailed version of the test logs
+│ └── trace_function
+│ └── log.txt
+└── backend_name
+ └── ...
+```
+
+The `saved_model` directory is only created if `--keep_saved_model` is
+specified.
+
+## Debugging Tests
If the compiler fails to compile the program, then it will create a crash
reproducer (see [MLIR documentation](https://mlir.llvm.org/docs/WritingAPass/)),
@@ -159,7 +185,7 @@
TODO(silvasean): debugging miscompiles
-## Test harnesses
+## Test Harnesses
### Simple function tests
diff --git a/iree/compiler/Conversion/CodegenUtils/MatmulCodegenStrategy.cpp b/iree/compiler/Conversion/CodegenUtils/MatmulCodegenStrategy.cpp
index 31c7d29..1372f3e 100644
--- a/iree/compiler/Conversion/CodegenUtils/MatmulCodegenStrategy.cpp
+++ b/iree/compiler/Conversion/CodegenUtils/MatmulCodegenStrategy.cpp
@@ -228,10 +228,16 @@
linalg::getLinalgTilingCanonicalizationPatterns(context);
stage2Patterns.insert<AffineMinCanonicalizationPattern>(context);
- auto stage3Transforms = [this](Operation *op) {
- // Some of these may be too aggressive as a stage 3 that is applied on each
- // stage 1 application and may have to be split out to post staged patterns
- // application (in which case they could just be passes, TBD).
+ auto stage3Transforms = [](Operation *op) {
+ promoteSingleIterationLoops(cast<FuncOp>(op));
+ return success();
+ };
+ linalg::applyStagedPatterns(func, stage1Patterns, stage2Patterns,
+ stage3Transforms);
+
+ auto postStageTransforms = [this](Operation *op) {
+ // Run LICM and hoisting patterns after all the stages, as we want
+ // unrolling to happen before moving transfer ops out of the loop.
if (hoistInvariantCode) {
PassManager pm(op->getContext());
pm.addPass(createLoopInvariantCodeMotionPass());
@@ -241,11 +247,8 @@
hoistRedundantVectorTransfers(cast<FuncOp>(op));
hoistRedundantCopies(cast<FuncOp>(op));
}
- promoteSingleIterationLoops(cast<FuncOp>(op));
- return success();
};
- linalg::applyStagedPatterns(func, stage1Patterns, stage2Patterns,
- stage3Transforms);
+ postStageTransforms(func);
if (lowering != nullptr) lowering(func);
}
diff --git a/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnBuffers.cpp b/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnBuffers.cpp
index 3d754a7..35f8bdb 100644
--- a/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnBuffers.cpp
+++ b/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnBuffers.cpp
@@ -66,10 +66,11 @@
if (!matchPattern(init, m_Constant(&attr))) return {};
auto type = attr.getType().dyn_cast<ShapedType>();
if (!type || type.getRank() != 0) return {};
- if (auto intType = type.getElementType().dyn_cast<IntegerType>())
+ if (auto intType = type.getElementType().dyn_cast<IntegerType>()) {
return IntegerAttr::get(intType, attr.getValue<APInt>({}));
- else if (auto floatType = type.getElementType().dyn_cast<FloatType>())
+ } else if (auto floatType = type.getElementType().dyn_cast<FloatType>()) {
return FloatAttr::get(floatType, attr.getValue<APFloat>({}));
+ }
return {};
}
@@ -254,14 +255,17 @@
a == b;
};
if (lhsShape.size() == 1 && rhsShape.size() == 1 &&
- shapeMatches(lhsShape[0], rhsShape[0]))
+ shapeMatches(lhsShape[0], rhsShape[0])) {
return DotOperationType::VectorDot;
+ }
if (lhsShape.size() == 2 && rhsShape.size() == 1 &&
- shapeMatches(lhsShape[1], rhsShape[0]))
+ shapeMatches(lhsShape[1], rhsShape[0])) {
return DotOperationType::MatrixVector;
+ }
if (lhsShape.size() == 2 && rhsShape.size() == 2 &&
- shapeMatches(lhsShape[1], rhsShape[0]))
+ shapeMatches(lhsShape[1], rhsShape[0])) {
return DotOperationType::MatrixMatrix;
+ }
return DotOperationType::Unsupported;
}
@@ -317,8 +321,9 @@
// batch_count, spatial_dims..., input_feature_count.
if (dimensionNumbers.input_batch_dimension().getInt() != 0 ||
dimensionNumbers.input_feature_dimension().getInt() !=
- (inputSpatialRank + 1))
+ (inputSpatialRank + 1)) {
return failure();
+ }
const int kernelSpatialRank =
llvm::size(dimensionNumbers.kernel_spatial_dimensions());
@@ -327,8 +332,9 @@
if (dimensionNumbers.kernel_input_feature_dimension().getInt() !=
kernelSpatialRank ||
dimensionNumbers.kernel_output_feature_dimension().getInt() !=
- (kernelSpatialRank + 1))
+ (kernelSpatialRank + 1)) {
return failure();
+ }
const int outputSpatialRank =
llvm::size(dimensionNumbers.output_spatial_dimensions());
@@ -336,12 +342,14 @@
// batch_count, spatial_dims.., output_feature_count.
if (dimensionNumbers.output_batch_dimension().getInt() != 0 ||
dimensionNumbers.output_feature_dimension().getInt() !=
- (outputSpatialRank + 1))
+ (outputSpatialRank + 1)) {
return failure();
+ }
if (inputSpatialRank != outputSpatialRank ||
- inputSpatialRank != kernelSpatialRank)
+ inputSpatialRank != kernelSpatialRank) {
return failure();
+ }
auto inputSpatialDim = dimensionNumbers.input_spatial_dimensions().begin();
auto kernelSpatialDim =
@@ -353,8 +361,9 @@
const int dim = i + 1;
if ((*inputSpatialDim++).getZExtValue() != dim ||
(*outputSpatialDim++).getZExtValue() != dim ||
- (*kernelSpatialDim++).getZExtValue() != i)
+ (*kernelSpatialDim++).getZExtValue() != i) {
return failure();
+ }
}
}
@@ -669,8 +678,9 @@
ArrayRef<Value> resultBuffers, ConversionPatternRewriter &rewriter) const {
auto loc = op.getLoc();
auto argType = inputBuffers[0].getType().template dyn_cast<ShapedType>();
- if (!argType || !argType.hasRank())
+ if (!argType || !argType.hasRank()) {
return op.emitError("expected known-rank args");
+ }
SmallVector<Value, 3> offsets, sizes, strides;
for (int i = 0, e = argType.getRank(); i < e; ++i) {
@@ -731,9 +741,12 @@
int rank = output.getType().cast<ShapedType>().getRank();
SmallVector<Attribute, 2> indexingMaps;
SmallVector<AffineExpr, 4> exprs;
- for (int i = 0; i < batch; ++i) exprs.push_back(rewriter.getAffineDimExpr(i));
- for (int i = 0, e = nIndices - batch; i < e; ++i)
+ for (int i = 0; i < batch; ++i) {
+ exprs.push_back(rewriter.getAffineDimExpr(i));
+ }
+ for (int i = 0, e = nIndices - batch; i < e; ++i) {
exprs.push_back(rewriter.getAffineDimExpr(axis + i));
+ }
indexingMaps.emplace_back(AffineMapAttr::get(
AffineMap::get(rank, /*symbolCount=*/0, exprs, rewriter.getContext())));
indexingMaps.emplace_back(
@@ -763,10 +776,13 @@
SmallVector<Value, 4> indices;
Value castedValue = rewriter.create<IndexCastOp>(
loc, block->getArgument(rank), rewriter.getIndexType());
- for (int i = 0; i < axis; ++i) indices.push_back(block->getArgument(i));
- indices.push_back(castedValue);
- for (int i = axis + nIndices - batch; i < rank; ++i)
+ for (int i = 0; i < axis; ++i) {
indices.push_back(block->getArgument(i));
+ }
+ indices.push_back(castedValue);
+ for (int i = axis + nIndices - batch; i < rank; ++i) {
+ indices.push_back(block->getArgument(i));
+ }
Value res = rewriter.create<LoadOp>(loc, adaptor.input(), indices);
rewriter.create<linalg::YieldOp>(loc, res);
@@ -822,8 +838,9 @@
// Create a fake window dimension.
SmallVector<int64_t, 4> shapes;
- for (auto dim : op.window_dimensions().getValues<int64_t>())
+ for (auto dim : op.window_dimensions().getValues<int64_t>()) {
shapes.push_back(dim);
+ }
Type type = rewriter.getIntegerType(32);
auto memrefType = MemRefType::get(shapes, type);
auto fakeWindowDims = rewriter.create<AllocOp>(loc, memrefType);
@@ -889,8 +906,9 @@
for (auto dim : reductionDims) s.insert(dim);
SmallVector<unsigned, 4> permutation;
- for (int i = 0; i < rank; ++i)
+ for (int i = 0; i < rank; ++i) {
if (!s.count(i)) permutation.push_back(i);
+ }
for (auto dim : reductionDims) permutation.push_back(dim);
auto map = AffineMap::getPermutationMap(permutation, context);
@@ -1002,8 +1020,9 @@
auto loc = reduceOp.getLoc();
DenseIntElementsAttr dimensionsAttr = reduceOp.dimensions();
SmallVector<int, 4> reductionDims;
- for (const auto &dim : dimensionsAttr.getIntValues())
+ for (const auto &dim : dimensionsAttr.getIntValues()) {
reductionDims.push_back(dim.getSExtValue());
+ }
// Check if initVal is constant. If so, inline the value into the region.
Attribute initConstVal = getInitValueAsConst(initVal);
@@ -1020,16 +1039,18 @@
SmallVector<Attribute, 3> indexingMaps;
indexingMaps.emplace_back(AffineMapAttr::get(getTransposeMapForReduction(
rewriter.getContext(), nInputRank, reductionDims)));
- if (!initConstVal)
+ if (!initConstVal) {
indexingMaps.emplace_back(AffineMapAttr::get(
AffineMap::get(nInputRank, /*symbolCount=*/0, rewriter.getContext())));
+ }
// The indexing map of `dst` should drop the reduction loops. Since the
// reduction loops now are all in the innermost, drops `reductionDims.size()`
// dimensions. We don't need an inverse permutation here because they are the
// same.
SmallVector<AffineExpr, 4> exprs;
- for (int i = 0, e = nInputRank - reductionDims.size(); i < e; ++i)
+ for (int i = 0, e = nInputRank - reductionDims.size(); i < e; ++i) {
exprs.push_back(rewriter.getAffineDimExpr(i));
+ }
indexingMaps.emplace_back(AffineMapAttr::get(
exprs.empty()
? AffineMap::get(nInputRank, /*symbolCount=*/0, rewriter.getContext())
@@ -1038,7 +1059,9 @@
SmallVector<Type, 2> resultTypes = {};
SmallVector<Value, 2> linalgOpArgs = {inputBuffers[0]};
- if (!initConstVal) linalgOpArgs.push_back(inputBuffers[1]);
+ if (!initConstVal) {
+ linalgOpArgs.push_back(inputBuffers[1]);
+ }
linalgOpArgs.push_back(resultBuffers[0]);
if (failed(zeroFillBuffer(loc, resultBuffers[0], rewriter))) {
rewriter.notifyMatchFailure(reduceOp, "failed to zero fill result buffer");
@@ -1141,8 +1164,9 @@
// type.
TypeConverter::SignatureConversion signatureConverter(numIndices +
numTensorOperands);
- for (int i = 0; i < numIndices; ++i)
+ for (int i = 0; i < numIndices; ++i) {
signatureConverter.addInputs(i, rewriter.getIndexType());
+ }
for (auto arg : llvm::enumerate(opArgs)) {
if (arg.index() < numTensorOperands) {
signatureConverter.addInputs(
@@ -1224,13 +1248,15 @@
Shape::TieShapeOp shapeOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
Shape::TieShapeOp::Adaptor adaptor(operands);
- if (Value buffer = resolveResult(shapeOp.operand(), adaptor.operand(),
- shapeOp.result(), resultTensorToBufferMap))
+ if (Value buffer =
+ resolveResult(shapeOp.operand(), adaptor.operand(),
+ shapeOp.result(), resultTensorToBufferMap)) {
rewriter.replaceOp(shapeOp, buffer);
- else
+ } else {
rewriter.replaceOpWithNewOp<Shape::TieShapeOp>(
shapeOp, getMemrefTypeForTensor(shapeOp.result()), adaptor.operand(),
adaptor.shape());
+ }
return success();
}
@@ -1250,8 +1276,9 @@
LogicalResult matchAndRewrite(IREE::HAL::InterfaceLoadTensorOp loadOp,
ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
- if (!matchPattern(loadOp.offset(), m_Zero()))
+ if (!matchPattern(loadOp.offset(), m_Zero())) {
return loadOp.emitError("unhandled non-zero offset");
+ }
// Get the corresponding memref type from the tensor type.
auto tensorType = loadOp.result().getType().cast<RankedTensorType>();
@@ -1363,8 +1390,9 @@
static LogicalResult createAndPropagateBufferUsedForResultTensor(
IREE::HAL::InterfaceStoreTensorOp op, OutputBufferMap &outputBufferMap,
TensorToBufferMap &resultTensorToBufferMap, OpBuilder &builder) {
- if (!matchPattern(op.offset(), m_Zero()))
+ if (!matchPattern(op.offset(), m_Zero())) {
return op.emitError("unhandled non-zero offset");
+ }
// Get the corresponding memref type from the tensor type.
Value tensor = op.operand();
@@ -1460,8 +1488,9 @@
OutputBufferMap outputBufferMap;
TensorToBufferMap resultTensorToBufferMap;
if (failed(createAndPropagateBufferUsedForResultTensors(
- funcOp, outputBufferMap, resultTensorToBufferMap)))
+ funcOp, outputBufferMap, resultTensorToBufferMap))) {
return signalPassFailure();
+ }
OwningRewritePatternList patterns;
populateHLOToLinalgOnBuffersConversionPatterns(context, patterns,
@@ -1486,8 +1515,9 @@
target.addDynamicallyLegalDialect<linalg::LinalgDialect>(
Optional<ConversionTarget::DynamicLegalityCallbackFn>([](Operation *op) {
// The generated structured Linalg ops should have buffer semantics.
- if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op))
+ if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
return linalgOp.hasBufferSemantics();
+ }
// The other Linalg ops (like linalg.yield) are okay.
return true;
}));
diff --git a/iree/compiler/Conversion/HLOToLinalg/Passes.h b/iree/compiler/Conversion/HLOToLinalg/Passes.h
index 445bd51..b50d243 100644
--- a/iree/compiler/Conversion/HLOToLinalg/Passes.h
+++ b/iree/compiler/Conversion/HLOToLinalg/Passes.h
@@ -48,7 +48,7 @@
using TensorToBufferMap = DenseMap<Value, Value>;
void populateHLOToLinalgOnBuffersConversionPatterns(
MLIRContext *context, OwningRewritePatternList &patterns,
- TensorToBufferMap const &outputTensorToBuffer);
+ TensorToBufferMap const &resultTensorToBufferMap);
/// Populates passes to convert from XLA-HLO to Linalg on buffers as well as
/// handling some IREE specific conversions (like iree.interface.* and
diff --git a/iree/compiler/Conversion/HLOToLinalg/ResolveShapeOps.cpp b/iree/compiler/Conversion/HLOToLinalg/ResolveShapeOps.cpp
index cd86250..1ddeaf4 100644
--- a/iree/compiler/Conversion/HLOToLinalg/ResolveShapeOps.cpp
+++ b/iree/compiler/Conversion/HLOToLinalg/ResolveShapeOps.cpp
@@ -99,8 +99,9 @@
ConversionTarget target(*context);
target.addIllegalOp<DimOp>();
target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
- if (failed(applyFullConversion(getFunction(), target, dimPatterns)))
+ if (failed(applyFullConversion(getFunction(), target, dimPatterns))) {
return signalPassFailure();
+ }
OwningRewritePatternList shapePatterns;
shapePatterns.insert<TieShapeElider>(context);
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/BUILD b/iree/compiler/Conversion/LinalgToSPIRV/BUILD
index 7a80343..7877459 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/BUILD
+++ b/iree/compiler/Conversion/LinalgToSPIRV/BUILD
@@ -26,6 +26,7 @@
"CooperativeMatrixAnalysis.cpp",
"LinalgTileAndFusePass.cpp",
"MarkerUtils.cpp",
+ "MatMulVectorizationTest.cpp",
"Passes.cpp",
"SplitDispatchFunctionPass.cpp",
"Utils.cpp",
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/CMakeLists.txt b/iree/compiler/Conversion/LinalgToSPIRV/CMakeLists.txt
index 7bca4c8..af6fbe9 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/CMakeLists.txt
+++ b/iree/compiler/Conversion/LinalgToSPIRV/CMakeLists.txt
@@ -30,6 +30,7 @@
"CooperativeMatrixAnalysis.cpp"
"LinalgTileAndFusePass.cpp"
"MarkerUtils.cpp"
+ "MatMulVectorizationTest.cpp"
"Passes.cpp"
"SplitDispatchFunctionPass.cpp"
"Utils.cpp"
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/ConvertToGPUPass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/ConvertToGPUPass.cpp
index 6ec315c..d623eae 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/ConvertToGPUPass.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/ConvertToGPUPass.cpp
@@ -19,9 +19,11 @@
//===----------------------------------------------------------------------===//
#include <array>
+#include <numeric>
#include "iree/compiler/Conversion/LinalgToSPIRV/Attributes.h"
#include "iree/compiler/Conversion/LinalgToSPIRV/MarkerUtils.h"
+#include "iree/compiler/Conversion/LinalgToSPIRV/MemorySpace.h"
#include "iree/compiler/Conversion/LinalgToSPIRV/Passes.h"
#include "iree/compiler/Conversion/LinalgToSPIRV/Utils.h"
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
@@ -30,6 +32,7 @@
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/SPIRV/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/TargetAndABI.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/AffineMap.h"
@@ -41,22 +44,21 @@
namespace mlir {
namespace iree_compiler {
+static constexpr int kNumDims = 3;
+
+/// In some cases the iterations of the loops when partitioned to workgroups
+/// need to be distributed in a cyclic manner. The main use case here is when
+/// the number of workgroups is constrained such that the number of iterations
+/// is greater than or equal to number of processors (along any dimension). In
+/// those cases, distribute the iterations in a cyclic manner. This adds
+/// additional control flow, but isn't too detrimental to performance since they
+/// are convergent for the most part.
static llvm::cl::opt<bool> isWorkgroupCountConstrained(
"iree-codegen-constrained-workgroup-count",
llvm::cl::desc("Specify whether the number of workgroups can be assumed to "
"be large enough to cover the entire workload"),
llvm::cl::init(false));
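For intuition, the cyclic distribution described above amounts to a strided loop in which each processor takes every `count`-th iteration starting from its own id. A minimal sketch, assuming plain C++ in place of the generated IR (the names here are illustrative only, not part of this pass):

```cpp
// Cyclic (round-robin) distribution: processor `id` out of `count` processors
// handles iterations id, id + count, id + 2 * count, ...
void cyclicDistribution(int id, int count, int numIterations) {
  for (int i = id; i < numIterations; i += count) {
    // process iteration i
  }
}
```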
-// TODO(#2134): Remove this flag/set it to false always when issue with
-// convolution is resolved (see bug for more details).
-// TODO(#2346): Make this a pass specific option.
-llvm::cl::opt<bool> useLegacyConvLowering{
- "iree-codegen-use-legacy-conv-lowering",
- llvm::cl::desc("Use conv lowering that does not assume 1:1 mapping "
- "between threads within a block and iterations of "
- "parallel loops distributed to the block"),
- llvm::cl::init(true)};
-
//===----------------------------------------------------------------------===//
// Loop utilities
//===----------------------------------------------------------------------===//
@@ -420,7 +422,7 @@
unsigned numDims,
MutableArrayRef<Value> id,
MutableArrayRef<Value> count) {
- std::array<StringRef, 3> dims{"x", "y", "z"};
+ std::array<StringRef, kNumDims> dims{"x", "y", "z"};
assert(id.size() == numDims);
assert(count.size() == numDims);
for (unsigned i = 0; i < numDims; ++i) {
@@ -431,6 +433,24 @@
}
}
+template <typename GPUIdOp, typename GPUCountOp>
+static ProcessorIdAndCount getLinearizedGPUProcessorIdAndCount(
+ Location loc, ConversionPatternRewriter &rewriter) {
+ std::array<Value, kNumDims> ids, counts;
+ getGPUProcessorIdsAndCounts<GPUIdOp, GPUCountOp>(loc, rewriter, kNumDims, ids,
+ counts);
+ ProcessorIdAndCount linearized;
+ linearized.id = ids[0];
+ linearized.count = counts[0];
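+ // Fold the y and z dimensions in: id = (x * count_y + y) * count_z + z and
+ // count = count_x * count_y * count_z.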
+ for (unsigned i = 0; i < kNumDims - 1; ++i) {
+ linearized.id = rewriter.create<MulIOp>(loc, linearized.id, counts[i + 1]);
+ linearized.id = rewriter.create<AddIOp>(loc, linearized.id, ids[i + 1]);
+ linearized.count =
+ rewriter.create<MulIOp>(loc, linearized.count, counts[i + 1]);
+ }
+ return linearized;
+}
+
/// Distributes scf.parallel to processors where `IdOp` is used to get the
/// processor ID and `DimOp` is used to get the number of processors along a
/// dimension.
@@ -474,16 +494,22 @@
static LogicalResult mapToWorkgroups(ConversionPatternRewriter &rewriter,
scf::ParallelOp pLoopOp,
bool useCyclicDistribution = false) {
- if (useCyclicDistribution)
+ if (useCyclicDistribution) {
return distributeCyclicallyToProcessors<gpu::BlockIdOp, gpu::GridDimOp>(
rewriter, pLoopOp);
+ }
return distributeSingleIterationPerProcessor<gpu::BlockIdOp, gpu::GridDimOp>(
rewriter, pLoopOp, false);
}
/// Distributes scf.parallel to workitems using local invocation ID.
-static LogicalResult mapToLocalInvocationId(ConversionPatternRewriter &rewriter,
- scf::ParallelOp pLoopOp) {
+static LogicalResult mapToLocalInvocationId(
+ ConversionPatternRewriter &rewriter, scf::ParallelOp pLoopOp,
+ bool useCyclicDistribution = false) {
+ if (useCyclicDistribution) {
+ return distributeCyclicallyToProcessors<gpu::ThreadIdOp, gpu::BlockDimOp>(
+ rewriter, pLoopOp);
+ }
return distributeSingleIterationPerProcessor<gpu::ThreadIdOp,
gpu::BlockDimOp>(rewriter,
pLoopOp);
@@ -498,22 +524,47 @@
rewriter, pLoopOp);
}
+/// Returns the number of elements copied when loading to/storing from
+/// workgroup memory. It is approximated as the size of the underlying
+/// allocation being copied into/from.
+static Optional<int64_t> getLinearizedCopySize(linalg::CopyOp copyOp) {
+ Value src = copyOp.input();
+ Value dst = copyOp.output();
+ MemRefType srcType = src.getType().cast<MemRefType>();
+ MemRefType dstType = dst.getType().cast<MemRefType>();
+
+ Value workgroupMemoryView;
+ MemRefType workgroupMemoryType;
+ if (srcType.getMemorySpace() == getWorkgroupMemorySpace()) {
+ workgroupMemoryView = src;
+ workgroupMemoryType = srcType;
+ } else if (dstType.getMemorySpace() == getWorkgroupMemorySpace()) {
+ workgroupMemoryView = dst;
+ workgroupMemoryType = dstType;
+ } else {
+ return {};
+ }
+
+ SubViewOp workgroupMemorySubviewOp =
+ dyn_cast_or_null<SubViewOp>(workgroupMemoryView.getDefiningOp());
+ if (!workgroupMemorySubviewOp) return {};
+ AllocOp allocOp = dyn_cast_or_null<AllocOp>(
+ workgroupMemorySubviewOp.source().getDefiningOp());
+ if (!allocOp) return {};
+
+ MemRefType allocOpType = allocOp.getType();
+ if (!allocOpType.hasStaticShape()) return {};
+ return allocOpType.getNumElements();
+}
+
//===----------------------------------------------------------------------===//
// Pass and patterns.
//===----------------------------------------------------------------------===//
-/// In some cases the iterations of the loops when partitioned to workgroups
-/// need to be distributed in a cyclic manner. The main use cases here is when
-/// the number of workgroups is constrained such that the number of iterations
-/// is greater than equal to number of processors (along any dimension). In
-/// those cases, distribute the iterations in a cyclic manner. This adds
-/// additional control flow, but isn't too detrimental to performance since they
-/// are convergent for the most part.
// TODO(#2134): Mapping iterations to processors directly by assuming number of
// iterations <= number of processors again seems to have an issue with
// convolution/pooling. Needs further investigation.
static bool useCyclicLoopDistribution(scf::ParallelOp pLoopOp) {
- if (!useLegacyConvLowering) return false;
auto walkResult = pLoopOp.walk([](Operation *op) -> WalkResult {
if (isa<linalg::ConvOp>(op) || isa<linalg::PoolingMaxOp>(op) ||
isa<linalg::PoolingMinOp>(op) || isa<linalg::PoolingSumOp>(op))
@@ -545,6 +596,70 @@
}
};
+/// Implementation of the mapping of tiled linalg op to workitems within a
+/// workgroup.
+template <typename LinalgOpTy>
+static LogicalResult mapLinalgOpToLocalInvocationIdImpl(
+ LinalgOpTy linalgOp, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) {
+ // Check for marker that specifies that the linalg op is to be partitioned
+ // across threads within a workgroup.
+ if (!hasMarker(linalgOp)) return failure();
+ Optional<linalg::LinalgLoops> loops =
+ linalg::linalgLowerOpToLoops<scf::ParallelOp>(rewriter, linalgOp);
+ if (!loops) return failure();
+ if (loops.getValue().empty()) return success();
+
+ scf::ParallelOp pLoopOp = dyn_cast<scf::ParallelOp>(loops.getValue()[0]);
+ if (!pLoopOp) return success();
+
+ return mapToLocalInvocationId(
+ rewriter, pLoopOp,
+ hasMarker(linalgOp, {getWorkgroupMarker(), getWorkgroupMemoryMarker()}));
+}
+
+/// CopyOps that load to/store from workgroup memory are special cased
+/// to use all workitems to do a copy. This is done by linearizing the copy
+/// operation.
+// TODO(ravishankarm): This linearization is achieved through collapsing the
+// generated parallel loops from a multi-dimensional copy. Such lowering results
+// in mods/divs in the collapsed loop body. This can be removed by reshaping the
+// copy to be a 1D copy. This seems to be hitting an error in reshape
+// canonicalization. Investigate this further.
+template <>
+LogicalResult mapLinalgOpToLocalInvocationIdImpl<linalg::CopyOp>(
+ linalg::CopyOp copyOp, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) {
+ if (!hasMarker(copyOp, getCopyToWorkgroupMemoryMarker())) return failure();
+ Optional<linalg::LinalgLoops> loops =
+ linalg::linalgLowerOpToLoops<scf::ParallelOp>(rewriter, copyOp);
+ if (!loops) return failure();
+ if (loops.getValue().empty()) return success();
+
+ scf::ParallelOp pLoopOp = dyn_cast<scf::ParallelOp>(loops.getValue()[0]);
+ if (!pLoopOp) return success();
+ pLoopOp = collapseParallelLoops(rewriter, pLoopOp);
+ if (!pLoopOp) return failure();
+
+ Optional<int64_t> copyLength = getLinearizedCopySize(copyOp);
+ ProcessorIdAndCount idAndCount =
+ getLinearizedGPUProcessorIdAndCount<gpu::ThreadIdOp, gpu::BlockDimOp>(
+ copyOp.getLoc(), rewriter);
+ auto workgroupSize =
+ spirv::lookupLocalWorkGroupSize(copyOp).getValues<APInt>();
+ int64_t linearizedWorkgroupSize = std::accumulate(
+ workgroupSize.begin(), workgroupSize.end(), 1,
+ [](int64_t total, APInt value) { return total * value.getSExtValue(); });
+
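+ // e.g. a workgroup size of (32, 4, 1) linearizes to 128, so a copy of at
+ // most 128 elements gets one element per workitem below, while larger
+ // copies are distributed cyclically.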
+ if (copyLength.hasValue() && !workgroupSize.empty() &&
+ copyLength.getValue() <= linearizedWorkgroupSize) {
+ return distributeSingleIterationPerProcessor(
+ rewriter, pLoopOp, idAndCount.id, /*generateGuard=*/true);
+ }
+ return distributeCyclicallyToProcessors(rewriter, pLoopOp, idAndCount.id,
+ idAndCount.count);
+}
+
/// Map tiled linalg op to workitems by lowering it to scf.parallel and
/// partitioning it to workitems.
template <typename LinalgOpTy>
@@ -553,41 +668,20 @@
LogicalResult matchAndRewrite(
LinalgOpTy linalgOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- // Check for marker that specifies that the linalg op is to be partitioned
- // across threads within a workgroup.
- if (!hasWorkGroupMarker(linalgOp)) return failure();
- Optional<linalg::LinalgLoops> loops =
- linalg::linalgLowerOpToLoops<scf::ParallelOp>(rewriter, linalgOp);
- if (!loops) return failure();
- if (!loops.getValue().empty()) {
- scf::ParallelOp pLoopOp = dyn_cast<scf::ParallelOp>(loops.getValue()[0]);
- if (!pLoopOp || failed(mapToLocalInvocationId(rewriter, pLoopOp)))
- return failure();
- }
- rewriter.eraseOp(linalgOp);
- return success();
- }
-};
-
-/// Legacy path for lowering tiled conv/pooling op to loops.
-// TODO(#2134): Remove this pattern. The default path of using
-// `MapLinalgOpToLocalInvocationId` seems to have a bug. It only shows up
-// currently on Resnet50. Remove this pattern after the bug is triaged/fixed.
-template <typename LinalgOpTy>
-struct MapConvPoolToLocalInvocationId : public OpConversionPattern<LinalgOpTy> {
- using OpConversionPattern<LinalgOpTy>::OpConversionPattern;
- LogicalResult matchAndRewrite(
- LinalgOpTy linalgOp, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const override {
- if (!hasWorkGroupMarker(linalgOp)) return failure();
- Optional<linalg::LinalgLoops> loops =
- linalg::linalgLowerOpToLoops<scf::ParallelOp>(rewriter, linalgOp);
- if (!loops) return failure();
- scf::ParallelOp pLoopOp = cast<scf::ParallelOp>(loops.getValue()[0]);
if (failed(
- distributeCyclicallyToProcessors<gpu::ThreadIdOp, gpu::BlockDimOp>(
- rewriter, pLoopOp)))
+ mapLinalgOpToLocalInvocationIdImpl(linalgOp, operands, rewriter)))
return failure();
+
+    // If the `linalgOp` writes to workgroup memory, insert a barrier after
+    // the op.
+ if (llvm::any_of(linalgOp.getOperands(), [](Value output) {
+ return output.getType().cast<MemRefType>().getMemorySpace() ==
+ getWorkgroupMemorySpace();
+ })) {
+ rewriter.create<spirv::ControlBarrierOp>(
+ linalgOp.getLoc(), spirv::Scope::Workgroup, spirv::Scope::Workgroup,
+ spirv::MemorySemantics::AcquireRelease);
+ }
rewriter.eraseOp(linalgOp);
return success();
}
@@ -674,39 +768,17 @@
OwningRewritePatternList patterns;
- // clang-format off
- patterns.insert<
-
-#define ADD_ALL_LINALG_PATTERNS(OP_NAME) \
- MapLinalgOpToGlobalInvocationId<OP_NAME>, \
- MapLinalgOpToLocalInvocationId<OP_NAME>
-
- ADD_ALL_LINALG_PATTERNS(linalg::CopyOp),
- ADD_ALL_LINALG_PATTERNS(linalg::FillOp),
- ADD_ALL_LINALG_PATTERNS(linalg::GenericOp),
- ADD_ALL_LINALG_PATTERNS(linalg::IndexedGenericOp),
-
-#undef ADD_ALL_LINALG_PATTERNS
-
-#define ADD_ALL_CONV_POOL_PATTERNS(OP_NAME) \
- MapConvPoolToLocalInvocationId<OP_NAME>, \
- MapLinalgOpToGlobalInvocationId<OP_NAME>
-
- ADD_ALL_CONV_POOL_PATTERNS(linalg::PoolingMaxOp),
- ADD_ALL_CONV_POOL_PATTERNS(linalg::PoolingMinOp),
- ADD_ALL_CONV_POOL_PATTERNS(linalg::PoolingSumOp),
-
-#undef ADD_ALL_CONV_POOL_PATTERNS
-
- MapLinalgOpToLocalInvocationId<linalg::MatmulOp>,
- PartitionPLoopToWorkgroups, RemoveLinalgRange>(context);
- // clang-format on
-
- patterns.insert<MapLinalgOpToGlobalInvocationId<linalg::ConvOp>>(context);
- if (useLegacyConvLowering)
- patterns.insert<MapConvPoolToLocalInvocationId<linalg::ConvOp>>(context);
- else
- patterns.insert<MapLinalgOpToLocalInvocationId<linalg::ConvOp>>(context);
+ patterns.insert<MapLinalgOpToGlobalInvocationId<linalg::CopyOp>,
+ MapLinalgOpToGlobalInvocationId<linalg::FillOp>,
+ MapLinalgOpToGlobalInvocationId<linalg::GenericOp>,
+ MapLinalgOpToGlobalInvocationId<linalg::IndexedGenericOp>,
+ MapLinalgOpToLocalInvocationId<linalg::ConvOp>,
+ MapLinalgOpToLocalInvocationId<linalg::CopyOp>,
+ MapLinalgOpToLocalInvocationId<linalg::MatmulOp>,
+ MapLinalgOpToLocalInvocationId<linalg::PoolingMaxOp>,
+ MapLinalgOpToLocalInvocationId<linalg::PoolingMinOp>,
+ MapLinalgOpToLocalInvocationId<linalg::PoolingSumOp>,
+ PartitionPLoopToWorkgroups, RemoveLinalgRange>(context);
if (failed(applyFullConversion(funcOp, target, patterns)))
return signalPassFailure();
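The distribution helpers invoked above encode the contract behind the markers: ops carrying the plain "workgroup"/"workgroup_memory" markers make no guarantee that the workgroup has enough workitems for all iterations, so their loops are distributed cyclically, while the "*_numprocs_ge_numiters" markers allow a single guarded iteration per workitem. A minimal sketch of that decision, using only the helpers that appear in this file (the wrapper function itself is hypothetical, not part of the patch):

  // Hypothetical wrapper: pick a distribution strategy for an scf.parallel
  // loop based on whether the number of workitems may be smaller than the
  // number of iterations (i.e. whether cyclic distribution is required).
  static LogicalResult distributeToWorkitems(
      ConversionPatternRewriter &rewriter, scf::ParallelOp pLoopOp,
      bool needsCyclicDistribution) {
    ProcessorIdAndCount idAndCount =
        getLinearizedGPUProcessorIdAndCount<gpu::ThreadIdOp, gpu::BlockDimOp>(
            pLoopOp.getLoc(), rewriter);
    if (needsCyclicDistribution) {
      // "workgroup" / "workgroup_memory" markers: iterate cyclically so any
      // workgroup size produces a correct result.
      return distributeCyclicallyToProcessors(rewriter, pLoopOp, idAndCount.id,
                                              idAndCount.count);
    }
    // "*_numprocs_ge_numiters" markers: one guarded iteration per workitem.
    return distributeSingleIterationPerProcessor(rewriter, pLoopOp,
                                                 idAndCount.id,
                                                 /*generateGuard=*/true);
  }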
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp
index 934e5ae..f83d444 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp
@@ -25,18 +25,16 @@
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
-#include "mlir/Dialect/SPIRV/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/TargetAndABI.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Identifier.h"
+#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/FoldUtils.h"
#define DEBUG_TYPE "iree-linalg-tile-and-fuse-buffer"
-static std::string PromotionMarker = "promotion";
-
namespace mlir {
namespace iree_compiler {
@@ -81,26 +79,24 @@
workgroupSize.resize(3, 1);
}
- /// Compute the tile sizes based on workgroup size specified.
- LogicalResult setTileSizesBasedOnWorkgroupSize(
- ArrayRef<int64_t> vWorkGroupSize) {
- if (!vWorkGroupSize.empty()) {
- vWorkGroupSize = dropTrailingOnes(vWorkGroupSize);
- workgroupSize.assign(vWorkGroupSize.begin(), vWorkGroupSize.end());
- auto rev = reverse(workgroupSize);
- tileSizes.assign(rev.begin(), rev.end());
- }
- return success();
+ /// Set tile sizes to use.
+ void setTileSizes(ArrayRef<int64_t> sizes) {
+ tileSizes.assign(sizes.begin(), sizes.end());
+ }
+
+ /// Set workgroup size to use.
+ void setWorkgroupSize(ArrayRef<int64_t> sizes) {
+ workgroupSize.assign(sizes.begin(), sizes.end());
}
/// Compute the tile sizes based on the Linalg Ops within the dispatch region.
- LogicalResult setTileSizesBasedOnOps(ArrayRef<linalg::LinalgOp> linalgOps);
+ LogicalResult inferTileAndWorkgroupSize(ArrayRef<linalg::LinalgOp> linalgOps);
/// Get the current tile size computed.
ArrayRef<int64_t> getTileSizes() const { return tileSizes; }
/// Returns the workgroup size to use based on the tile sizes.
- ArrayRef<int64_t> getWorkGroupSize() const { return workgroupSize; }
+ ArrayRef<int64_t> getWorkgroupSize() const { return workgroupSize; }
private:
/// Current tile size configuration.
@@ -114,7 +110,7 @@
};
} // namespace
-LogicalResult TileSizeCalculator::setTileSizesBasedOnOps(
+LogicalResult TileSizeCalculator::inferTileAndWorkgroupSize(
ArrayRef<linalg::LinalgOp> linalgOps) {
tileSizes.clear();
if (linalgOps.empty()) {
@@ -134,11 +130,16 @@
uint32_t opInfo = OpInfo::None;
for (linalg::LinalgOp linalgOp : linalgOps) {
Operation *op = linalgOp.getOperation();
- if (isa<linalg::ConvOp>(op)) opInfo |= OpInfo::Convolution;
- if (isa<linalg::MatmulOp>(op)) opInfo |= OpInfo::Matmul;
- if (isa<linalg::PoolingMaxOp>(op)) opInfo |= OpInfo::Pooling;
- if (isa<linalg::PoolingMinOp>(op)) opInfo |= OpInfo::Pooling;
- if (isa<linalg::PoolingSumOp>(op)) opInfo |= OpInfo::Pooling;
+ if (isa<linalg::ConvOp>(op))
+ opInfo |= OpInfo::Convolution;
+ else if (isa<linalg::MatmulOp>(op))
+ opInfo |= OpInfo::Matmul;
+ else if (isa<linalg::PoolingMaxOp>(op))
+ opInfo |= OpInfo::Pooling;
+ else if (isa<linalg::PoolingMinOp>(op))
+ opInfo |= OpInfo::Pooling;
+ else if (isa<linalg::PoolingSumOp>(op))
+ opInfo |= OpInfo::Pooling;
}
// If there are no tilable ops, there is nothing to do here.
if (!opInfo) return success();
@@ -155,10 +156,6 @@
unsigned maxWorkgroupSize =
resourceLimits.max_compute_workgroup_invocations().getInt();
if (opInfo & OpInfo::Convolution) {
- // TODO(ravishankarm): This tiling is meant to enable promotion to workgroup
- // memory, but doesnt actually get us to a state where we can do this. The
- // promotion is possible only when the subviews created are constant
- // size. For now this doesnt really matter. Revisit this later.
int64_t tileSizeX = 32;
int64_t tileSizeY = maxWorkgroupSize / 32;
tileSizes = {1, tileSizeY, tileSizeX};
@@ -187,18 +184,22 @@
//===----------------------------------------------------------------------===//
 /// Allocation callback for allocating workgroup local memory.
-static Value allocateWorkgroupMemory(OpBuilder &b, SubViewOp subview,
- ArrayRef<Value> boundingSubViewSize,
- OperationFolder *folder) {
+static Optional<Value> allocateWorkgroupMemory(
+ OpBuilder &b, SubViewOp subview, ArrayRef<Value> boundingSubViewSize,
+ OperationFolder *folder) {
   // The bounding subview size is expected to be constant. This specifies the
// shape of the allocation.
- SmallVector<int64_t, 2> shape(boundingSubViewSize.size(),
- ShapedType::kDynamicSize);
- return b.create<AllocOp>(
- subview.getLoc(),
- MemRefType::get(shape, subview.getType().getElementType(), {},
- getWorkgroupMemorySpace()),
- boundingSubViewSize);
+ SmallVector<int64_t, 2> shape = llvm::to_vector<2>(
+ llvm::map_range(boundingSubViewSize, [](Value v) -> int64_t {
+ APInt value;
+ if (matchPattern(v, m_ConstantInt(&value))) return value.getSExtValue();
+ return -1;
+ }));
+ if (llvm::any_of(shape, [](int64_t v) { return v == -1; })) return {};
+ MemRefType allocType = MemRefType::get(
+ shape, subview.getType().getElementType(), {}, getWorkgroupMemorySpace());
+ Value buffer = b.create<AllocOp>(subview.getLoc(), allocType);
+ return buffer;
}
/// Deallocation callback for allocation workgroup local memory.
@@ -208,56 +209,50 @@
return success();
}
-/// Insert barrier after `op`.
-static void insertBarrierAfter(OpBuilder &b, Location loc, Operation *op) {
- OpBuilder::InsertionGuard guard(b);
- b.setInsertionPointAfter(op);
- b.create<spirv::ControlBarrierOp>(loc, spirv::Scope::Workgroup,
- spirv::Scope::Workgroup,
- spirv::MemorySemantics::AcquireRelease);
-}
-
-/// Function used as callback for copyin/copyout in promotion pattern used to
-/// promote subviews to workgroup memory.
-static LogicalResult copyToFromWorkgroupMemory(
- OpBuilder &b, Value src, Value dst, StringRef marker = PromotionMarker) {
- auto copyOp = b.create<linalg::CopyOp>(src.getLoc(), src, dst);
- setMarker(copyOp, marker);
- return success();
-}
-
namespace {
/// Function pass that implements tiling and fusion in Linalg on buffers.
struct LinalgTileAndFusePass
: public PassWrapper<LinalgTileAndFusePass, FunctionPass> {
- LinalgTileAndFusePass(ArrayRef<int64_t> workGroupSize = {},
- bool useWorkgroupMem = false)
- : workGroupSize(workGroupSize.begin(), workGroupSize.end()) {
+ LinalgTileAndFusePass(ArrayRef<int64_t> workgroupSize = {},
+ ArrayRef<int64_t> tileSizes = {},
+ bool useWorkgroupMem = false) {
+ this->workgroupSize = workgroupSize;
+ this->tileSizes = tileSizes;
this->useWorkgroupMemory = useWorkgroupMem;
}
LinalgTileAndFusePass(const LinalgTileAndFusePass &pass) {}
void runOnFunction() override;
+ private:
Option<bool> useWorkgroupMemory{
*this, "use-workgroup-memory",
llvm::cl::desc("Promote subviews to use workgroup memory"),
llvm::cl::init(false)};
- private:
- SmallVector<int64_t, 3> workGroupSize;
+ ListOption<int64_t> workgroupSize{
+ *this, "workgroup-size",
+ llvm::cl::desc("Override the default workgroup size"),
+ llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
+
+ ListOption<int64_t> tileSizes{
+ *this, "tile-sizes", llvm::cl::desc("Set tile sizes to use"),
+ llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
};
/// Pattern for tiling operations. Updates the workgroup size in the surrounding
/// function operation if tiling succeeds.
-template <typename OpTy>
-struct TilingPattern : public linalg::LinalgTilingPattern<OpTy> {
- using Base = linalg::LinalgTilingPattern<OpTy>;
- TilingPattern(MLIRContext *context, linalg::LinalgTilingOptions options,
- ArrayRef<int64_t> workgroupSize,
- linalg::LinalgMarker marker = linalg::LinalgMarker(),
- PatternBenefit benefit = 1)
- : Base(context, options, marker, benefit),
+struct TileMatmulPattern
+ : public linalg::LinalgTilingPattern<linalg::MatmulOp> {
+ using Base = linalg::LinalgTilingPattern<linalg::MatmulOp>;
+ TileMatmulPattern(MLIRContext *context, linalg::LinalgTilingOptions options,
+ ArrayRef<int64_t> workgroupSize, PatternBenefit benefit = 1)
+ : Base(context, options,
+ linalg::LinalgMarker(
+ ArrayRef<Identifier>(),
+ Identifier::get(getWorkgroupNumItemsGENumItersMarker(),
+ context)),
+ benefit),
workgroupSize(workgroupSize.begin(), workgroupSize.end()) {}
virtual LogicalResult matchAndRewrite(Operation *op,
@@ -280,9 +275,17 @@
 /// Pattern for tiling convolution and pooling operations. Currently this is
 /// just a way to avoid tiling when the operation has padding.
template <typename OpTy>
-struct TileConvPoolPattern : public TilingPattern<OpTy> {
- using Base = TilingPattern<OpTy>;
- using Base::TilingPattern;
+struct TileConvPoolPattern : public linalg::LinalgTilingPattern<OpTy> {
+ using Base = linalg::LinalgTilingPattern<OpTy>;
+ TileConvPoolPattern(MLIRContext *context, linalg::LinalgTilingOptions options,
+ ArrayRef<int64_t> workgroupSize,
+ PatternBenefit benefit = 1)
+ : Base(context, options,
+ linalg::LinalgMarker(
+ ArrayRef<Identifier>(),
+ Identifier::get(getWorkgroupMarker(), context)),
+ benefit),
+ workgroupSize(workgroupSize.begin(), workgroupSize.end()) {}
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
@@ -296,28 +299,69 @@
WorkgroupCountMethodology::Default)));
return success();
}
+
+ SmallVector<int64_t, 3> workgroupSize;
};
-/// Pattern to promote subviews to memory.
-// TODO(ravishankarm): Generalize this for other operations.
-struct PromoteSubviewsPattern
+//===----------------------------------------------------------------------===//
+// Patterns to promote subviews to workgroup memory
+//===----------------------------------------------------------------------===//
+
+/// Function used as the copy-in/copy-out callback in the promotion pattern
+/// that promotes subviews to workgroup memory, when the number of threads is
+/// known to be greater than or equal to the number of iterations of the loops
+/// the copy is lowered to.
+static LogicalResult copyToWorkgroupMemory(OpBuilder &b, Value src, Value dst) {
+ auto copyOp = b.create<linalg::CopyOp>(src.getLoc(), src, dst);
+ setMarker(copyOp, getCopyToWorkgroupMemoryMarker());
+ return success();
+}
+
+/// Pattern to promote matmul operands to workgroup memory.
+struct PromoteMatmulSubviewsPattern
: public linalg::LinalgPromotionPattern<linalg::MatmulOp> {
- PromoteSubviewsPattern(MLIRContext *context,
- linalg::LinalgPromotionOptions options,
- linalg::LinalgMarker marker = linalg::LinalgMarker(),
- PatternBenefit benefit = 1)
+ PromoteMatmulSubviewsPattern(
+ MLIRContext *context, linalg::LinalgPromotionOptions options,
+ linalg::LinalgMarker marker = linalg::LinalgMarker(),
+ PatternBenefit benefit = 1)
: linalg::LinalgPromotionPattern<linalg::MatmulOp>(
context,
options.setOperandsToPromote({0, 1}).setUseFullTileBuffers(
{false, false}),
- marker, benefit) {}
+ linalg::LinalgMarker(
+ Identifier::get(getWorkgroupNumItemsGENumItersMarker(),
+ context),
+ Identifier::get(getWorkgroupMemoryNumItemsGENumItersMarker(),
+ context)),
+ benefit) {}
+};
- LogicalResult matchAndRewrite(Operation *op,
- PatternRewriter &rewriter) const override {
- if (!hasWorkGroupMarker(op)) return failure();
- return linalg::LinalgPromotionPattern<linalg::MatmulOp>::matchAndRewrite(
- op, rewriter);
- }
+/// Pattern to promote convolution operands to workgroup memory.
+// TODO(ravishankarm): This pattern only promotes the image subview to
+// workgroup memory. In reality we should be able to promote the filter
+// subview to workgroup memory as well. Since none of the loops used to access
+// the filter are tiled, this would mean the entire filter is moved to
+// workgroup memory. Two reasons this is not done right now:
+// 1) When tiling, Linalg doesn't create a subview for the filter (since none
+//    of its dimensions are tiled). This needs to be relaxed (maybe by using an
+//    option).
+// 2) There might be better alternatives for handling the filter (e.g. using a
+//    different storage class, since for inference workloads these are model
+//    constants). This is TBD.
+struct PromoteConvolutionSubviewsPattern
+ : public linalg::LinalgPromotionPattern<linalg::ConvOp> {
+ PromoteConvolutionSubviewsPattern(
+ MLIRContext *context, linalg::LinalgPromotionOptions options,
+ linalg::LinalgMarker marker = linalg::LinalgMarker(),
+ PatternBenefit benefit = 1)
+ : linalg::LinalgPromotionPattern<linalg::ConvOp>(
+ context,
+ options.setOperandsToPromote({1}).setUseFullTileBuffers(
+ {false, false}),
+ linalg::LinalgMarker(
+ Identifier::get(getWorkgroupMarker(), context),
+ Identifier::get(getWorkgroupMemoryMarker(), context)),
+ benefit) {}
};
} // namespace
@@ -334,28 +378,29 @@
if (linalgOps.empty()) return;
TileSizeCalculator tileSizeCalculator(funcOp);
- if (workGroupSize.empty()) {
+ if (tileSizes.empty()) {
// Get the tile sizes to use for the lowering.
SmallVector<int64_t, 3> tileSizes;
SmallVector<linalg::LinalgOp, 1> opsVec(linalgOps.begin(), linalgOps.end());
- if (failed(tileSizeCalculator.setTileSizesBasedOnOps(opsVec)))
+ if (failed(tileSizeCalculator.inferTileAndWorkgroupSize(opsVec)))
return signalPassFailure();
} else {
- tileSizeCalculator.setTileSizesBasedOnWorkgroupSize(workGroupSize);
+ tileSizeCalculator.setTileSizes(tileSizes);
+ if (!workgroupSize.empty())
+ tileSizeCalculator.setWorkgroupSize(workgroupSize);
}
LLVM_DEBUG({
llvm::dbgs() << "--- IREE Linalg tile and fuse configuration ---\n";
- llvm::dbgs() << "# workgroup sizes at start: [";
- interleaveComma(workGroupSize, llvm::dbgs());
+ llvm::dbgs() << "# workgroup sizes: [";
+ interleaveComma(tileSizeCalculator.getWorkgroupSize(), llvm::dbgs());
llvm::dbgs() << "]\ntile sizes: [";
interleaveComma(tileSizeCalculator.getTileSizes(), llvm::dbgs());
llvm::dbgs() << "]\n";
});
OwningRewritePatternList tilingPatterns;
- tilingPatterns.insert<TileConvPoolPattern<linalg::ConvOp>,
- TilingPattern<linalg::MatmulOp>,
+ tilingPatterns.insert<TileConvPoolPattern<linalg::ConvOp>, TileMatmulPattern,
TileConvPoolPattern<linalg::PoolingMaxOp>,
TileConvPoolPattern<linalg::PoolingMinOp>,
TileConvPoolPattern<linalg::PoolingSumOp>>(
@@ -363,9 +408,7 @@
linalg::LinalgTilingOptions()
.setTileSizes(tileSizeCalculator.getTileSizes())
.setLoopType(linalg::LinalgTilingLoopType::ParallelLoops),
- tileSizeCalculator.getWorkGroupSize(),
- linalg::LinalgMarker(ArrayRef<Identifier>(),
- Identifier::get(getWorkGroupMarker(), context)));
+ tileSizeCalculator.getWorkgroupSize());
applyPatternsAndFoldGreedily(getOperation(), tilingPatterns);
if (useWorkgroupMemory) {
@@ -373,31 +416,15 @@
// sure that the allocated scratchspace memory is constant sizes which
// requires some folding to trigger.
OwningRewritePatternList promotionPatterns;
- promotionPatterns.insert<PromoteSubviewsPattern>(
+ promotionPatterns.insert<PromoteMatmulSubviewsPattern,
+ PromoteConvolutionSubviewsPattern>(
context,
linalg::LinalgPromotionOptions()
.setAllocationDeallocationFns(allocateWorkgroupMemory,
deallocateWorkgroupMemory)
- .setCopyInOutFns(
- [&](OpBuilder &b, Value src, Value dst) -> LogicalResult {
- return copyToFromWorkgroupMemory(b, src, dst);
- },
- [&](OpBuilder &b, Value src, Value dst) -> LogicalResult {
- return copyToFromWorkgroupMemory(b, src, dst);
- }),
- linalg::LinalgMarker(Identifier::get(getWorkGroupMarker(), context),
- Identifier::get(PromotionMarker, context)));
+ .setCopyInOutFns(copyToWorkgroupMemory, copyToWorkgroupMemory));
applyPatternsAndFoldGreedily(getOperation(), promotionPatterns);
}
-
- // Add barrier after all linalg operations marked with workitem marker.
- OpBuilder builder(context);
- funcOp.walk([&builder](linalg::LinalgOp linalgOp) {
- if (hasMarker(linalgOp, PromotionMarker)) {
- setWorkGroupMarker(linalgOp);
- insertBarrierAfter(builder, linalgOp.getLoc(), linalgOp);
- }
- });
}
//===----------------------------------------------------------------------===//
@@ -405,8 +432,9 @@
//===----------------------------------------------------------------------===//
std::unique_ptr<OperationPass<FuncOp>> createLinalgTileAndFusePass(
- ArrayRef<int64_t> workGroupSize, bool useWorkgroupMemory) {
- return std::make_unique<LinalgTileAndFusePass>(workGroupSize,
+ ArrayRef<int64_t> workgroupSize, ArrayRef<int64_t> tileSizes,
+ bool useWorkgroupMemory) {
+ return std::make_unique<LinalgTileAndFusePass>(workgroupSize, tileSizes,
useWorkgroupMemory);
}
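The reworked allocation callback above only succeeds when every bounding subview size folds to a constant, which is what lets the promoted buffers come out as statically shaped allocations (memref<8x4xf32, 3> rather than the previous memref<?x?xf32, 3> in the updated tests). A small sketch of that constant-extraction idiom in isolation; the helper name is made up for illustration:

  // Illustrative helper (not part of this change): turn a list of bounding
  // sizes into a static shape, bailing out if any size is not a constant.
  static Optional<SmallVector<int64_t, 2>> getStaticShape(
      ArrayRef<Value> boundingSizes) {
    SmallVector<int64_t, 2> shape;
    for (Value size : boundingSizes) {
      APInt value;
      if (!matchPattern(size, m_ConstantInt(&value))) return llvm::None;
      shape.push_back(value.getSExtValue());
    }
    return shape;
  }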
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/MarkerUtils.cpp b/iree/compiler/Conversion/LinalgToSPIRV/MarkerUtils.cpp
index 47747de..9c45419 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/MarkerUtils.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/MarkerUtils.cpp
@@ -26,22 +26,30 @@
};
const StringLiteral VectorTransforms::kVectorTransformMarker =
"__internal_vector_transform__";
-/// Checks if the operation has the `marker` If `marker` is null string, checks
-/// if any marker is set.
-static bool checkMarkerValue(Operation *op, StringRef marker = "") {
+
+StringRef getWorkgroupMarker() { return "workgroup"; }
+
+StringRef getWorkgroupMemoryMarker() { return "workgroup_memory"; }
+
+StringRef getWorkgroupNumItemsGENumItersMarker() {
+ return "workgroup_numprocs_ge_numiters";
+}
+
+StringRef getWorkgroupMemoryNumItemsGENumItersMarker() {
+ return "workgroup_memory_numprocs_ge_numiters";
+}
+
+StringRef getCopyToWorkgroupMemoryMarker() {
+ return "copy_to_workgroup_memory";
+}
+
+bool hasMarker(Operation *op, ArrayRef<StringRef> marker) {
StringAttr attr = op->getAttrOfType<StringAttr>(
linalg::LinalgTransforms::kLinalgTransformMarker);
- return attr && (marker.empty() || attr.getValue() == marker);
-}
-
-StringRef getWorkGroupMarker() { return "workgroup"; }
-
-bool hasMarker(Operation *op, StringRef marker) {
- return checkMarkerValue(op, marker);
-}
-
-bool hasWorkGroupMarker(Operation *op) {
- return checkMarkerValue(op, getWorkGroupMarker());
+ return attr && (marker.empty() ||
+ llvm::any_of(marker, [&attr](StringRef markerValue) {
+ return attr.getValue() == markerValue;
+ }));
}
void setMarker(Operation *op, StringRef marker) {
@@ -49,7 +57,5 @@
StringAttr::get(marker, op->getContext()));
}
-void setWorkGroupMarker(Operation *op) { setMarker(op, getWorkGroupMarker()); }
-
} // namespace iree_compiler
} // namespace mlir
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/MarkerUtils.h b/iree/compiler/Conversion/LinalgToSPIRV/MarkerUtils.h
index e512ead..222291d 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/MarkerUtils.h
+++ b/iree/compiler/Conversion/LinalgToSPIRV/MarkerUtils.h
@@ -22,7 +22,7 @@
#ifndef IREE_COMPILER_CONVERSION_LINALGTOSPIRV_MARKERUTILS_H_
#define IREE_COMPILER_CONVERSION_LINALGTOSPIRV_MARKERUTILS_H_
-#include "llvm/ADT/StringRef.h"
+#include "llvm/ADT/ArrayRef.h"
#include "mlir/Support/LLVM.h"
namespace mlir {
@@ -30,24 +30,29 @@
class Operation;
namespace iree_compiler {
-/// Marker to denote that a linalg operation is to be partitioned to workitems.
-StringRef getWorkGroupMarker();
+/// Marker to denote that a linalg operation is to be partitioned to
+/// workitems. No assumption can be made about the number of workitems in the
+/// workgroup relative to the number of iterations, i.e. a cyclic distribution
+/// is required.
+StringRef getWorkgroupMarker();
+StringRef getWorkgroupMemoryMarker();
+
+/// Marker to denote that a linalg operation is to be partitioned to workitems
+/// with the assumption that the number of workitems in the workgroup is
+/// greater than or equal to the number of iterations.
+StringRef getWorkgroupNumItemsGENumItersMarker();
+StringRef getWorkgroupMemoryNumItemsGENumItersMarker();
+
+/// Marker for copy operations that are moving data from StorageClass to
+/// Workgroup memory.
+StringRef getCopyToWorkgroupMemoryMarker();
/// Returns true if an operation has any of the specified `markers`. When
/// `markers` is empty, returns true if the operation has any marker.
-bool hasMarker(Operation *, StringRef marker = "");
-
-/// Returns true if an operation has marker to denote that it is to be
-/// partitioned to workitems.
-bool hasWorkGroupMarker(Operation *);
+bool hasMarker(Operation *, ArrayRef<StringRef> markers = {});
/// Sets a given marker on an operation.
void setMarker(Operation *, StringRef);
-/// Sets marker to denote that a linalg operation is to be partitioned to
-/// workitems.
-void setWorkGroupMarker(Operation *);
-
} // namespace iree_compiler
} // namespace mlir
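With the widened hasMarker signature, a pattern can query a single marker, several alternatives, or the presence of any marker through the same entry point. A short usage sketch based only on the declarations above (the op variable is assumed):

  // Usage sketch for the marker helpers (illustrative only).
  setMarker(op, getCopyToWorkgroupMemoryMarker());
  bool isWorkgroupCopy = hasMarker(op, getCopyToWorkgroupMemoryMarker());
  bool needsCyclicDist =
      hasMarker(op, {getWorkgroupMarker(), getWorkgroupMemoryMarker()});
  bool hasAnyMarker = hasMarker(op);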
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/MatMulVectorizationTest.cpp b/iree/compiler/Conversion/LinalgToSPIRV/MatMulVectorizationTest.cpp
new file mode 100644
index 0000000..9183933
--- /dev/null
+++ b/iree/compiler/Conversion/LinalgToSPIRV/MatMulVectorizationTest.cpp
@@ -0,0 +1,74 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#include "iree/compiler/Conversion/CodegenUtils/MatmulCodegenStrategy.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassRegistry.h"
+
+static llvm::cl::opt<int> wgTileSize(
+ "iree-codegen-linalg-to-gpu-wg-tile-size",
+ llvm::cl::desc(
+ "Specify the size of workgroup tile for matmul vector lowering"),
+ llvm::cl::init(32));
+
+static llvm::cl::list<uint32_t> unrollSize(
+ "iree-codegen-linalg-to-gpu-unroll-size",
+ llvm::cl::desc("Specify the size of the "), llvm::cl::CommaSeparated);
+
+static llvm::cl::opt<bool> enableLICM(
+ "iree-codegen-linalg-to-gpu-matmul-licm",
+ llvm::cl::desc(
+ "If true run LICM and hoisting passes after the staged transforms"),
+ llvm::cl::init(true));
+
+namespace mlir {
+namespace iree_compiler {
+
+namespace {
+struct MatMulTileAndVectorizeGPUPass
+ : PassWrapper<MatMulTileAndVectorizeGPUPass, FunctionPass> {
+ void runOnFunction() override;
+};
+} // namespace
+
+void MatMulTileAndVectorizeGPUPass::runOnFunction() {
+ FuncOp fn = getFunction();
+ SmallVector<uint32_t, 3> vUnrollSize(unrollSize.begin(), unrollSize.end());
+  if (vUnrollSize.size() != 3) return signalPassFailure();
+ MatmulCodegenStrategy strategy;
+ strategy
+ .tile<linalg::MatmulOp>(
+ linalg::LinalgTilingOptions()
+ // TODO(thomasraoux): Enable parallel loops once affine.min
+ // canonicalize supports it.
+ //.setLoopType(linalg::LinalgTilingLoopType::ParallelLoops)
+ .setTileSizes({wgTileSize, wgTileSize, wgTileSize}))
+ .setHoistInvariantCode(enableLICM)
+ .vectorize<linalg::MatmulOp>()
+ .unrollVector<vector::ContractionOp>(
+ {vUnrollSize[0], vUnrollSize[1], vUnrollSize[2]});
+ strategy.transform(fn);
+}
+
+std::unique_ptr<FunctionPass> createMatMulTileAndVectorizeGPUPass() {
+ return std::make_unique<MatMulTileAndVectorizeGPUPass>();
+}
+
+static PassRegistration<MatMulTileAndVectorizeGPUPass> pass(
+ "iree-codegen-linalg-to-gpu-matmul-vectorization-pass",
+ "Tile and vectorize linalg.matmul operation",
+ [] { return std::make_unique<MatMulTileAndVectorizeGPUPass>(); });
+
+} // namespace iree_compiler
+} // namespace mlir
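Besides the static registration above, the pass can be added to a pipeline programmatically; a minimal sketch under the assumption that a PassManager, context, and module are already set up:

  // Assumed usage: run the new matmul vectorization pass on every function.
  mlir::PassManager pm(context);
  pm.nest<mlir::FuncOp>().addPass(
      mlir::iree_compiler::createMatMulTileAndVectorizeGPUPass());
  if (failed(pm.run(module))) {
    // handle pipeline failure
  }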
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/Passes.cpp b/iree/compiler/Conversion/LinalgToSPIRV/Passes.cpp
index fd7b316..5c1b672 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/Passes.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/Passes.cpp
@@ -59,6 +59,9 @@
"three integers standarding for the x, y, and z dimension; "
"additional arguments will be ignored (used only for testing)"),
llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
+ ListOption<int64_t> tileSizes{
+ *this, "tile-sizes", llvm::cl::desc("Set tile sizes to use"),
+ llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
Option<bool> useWorkgroupMemory{
*this, "use-workgroup-memory",
llvm::cl::desc(
@@ -97,8 +100,8 @@
// afterwards. This gives each Linalg op a second chance to be tiled,
// with the second tile and fuse pass.
//===--------------------------------------------------------------------===//
- pm.addPass(createLinalgTileAndFusePass(options.workgroupSize,
- options.useWorkgroupMemory));
+ pm.addPass(createLinalgTileAndFusePass(
+ options.workgroupSize, options.tileSizes, options.useWorkgroupMemory));
pm.addPass(createSplitDispatchFunctionPass());
pm.addPass(createLinalgTileAndFusePass(options.workgroupSize,
options.useWorkgroupMemory));
@@ -221,6 +224,7 @@
SPIRVCodegenOptions options;
options.workgroupSize.assign(clOpts.workgroupSize.begin(),
clOpts.workgroupSize.end());
+ options.tileSizes.assign(clOpts.tileSizes.begin(), clOpts.tileSizes.end());
options.useWorkgroupMemory = clOpts.useWorkgroupMemory;
return options;
}
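For callers that build the pipeline in C++ rather than through the command-line options, the extra tileSizes argument threads through createLinalgTileAndFusePass directly; a hedged sketch (the concrete sizes are placeholders, not values from this patch):

  // Assumed usage: pin both the workgroup size and the tile sizes explicitly.
  pm.addPass(createLinalgTileAndFusePass(
      /*workGroupSize=*/{8, 8, 1}, /*tileSizes=*/{8, 8, 4},
      /*useWorkgroupMem=*/true));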
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/Passes.h b/iree/compiler/Conversion/LinalgToSPIRV/Passes.h
index 8267780..35de230 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/Passes.h
+++ b/iree/compiler/Conversion/LinalgToSPIRV/Passes.h
@@ -25,6 +25,7 @@
// Options that can be used to configure SPIR-V codegeneration.
struct SPIRVCodegenOptions {
SmallVector<int64_t, 3> workgroupSize = {};
+ SmallVector<int64_t, 3> tileSizes = {};
bool useWorkgroupMemory = false;
};
@@ -35,7 +36,8 @@
/// it exists) and along "z" for the next loop (if it exists). The workgroup
/// size is expected to be of size at-most 3.
std::unique_ptr<OperationPass<FuncOp>> createLinalgTileAndFusePass(
- ArrayRef<int64_t> workGroupSize = {}, bool useWorkgroupMem = false);
+ ArrayRef<int64_t> workGroupSize = {}, ArrayRef<int64_t> tileSizes = {},
+ bool useWorkgroupMem = false);
/// Pass to add the synchronizations and attributes needed to lower from PLoops
/// to GPU dialect.
@@ -60,6 +62,9 @@
/// vector size equal to subgroup size are distributed across the subgroup.
std::unique_ptr<OperationPass<FuncOp>> createVectorToGPUPass();
+/// Pass to apply tiling and vectorization transformations on linalg::MatmulOp.
+std::unique_ptr<FunctionPass> createMatMulTileAndVectorizeGPUPass();
+
/// Populates passes needed to lower an XLA HLO op to the SPIR-V dialect via
/// the structured ops path. The pass manager `pm` here operates on the module
/// within the IREE::HAL::ExecutableOp. The `workGroupSize` can be used to
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_gpu.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_gpu.mlir
index 64621f3..599af08 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_gpu.mlir
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_gpu.mlir
@@ -162,7 +162,7 @@
%12 = dim %arg2, %c1 : memref<?x?xf32>
%13 = affine.min #map0(%arg4)[%12]
%14 = subview %arg2[%arg3, %arg4] [%11, %13] [1, 1] : memref<?x?xf32> to memref<?x?xf32, #map2>
- linalg.matmul %5, %9, %14 {__internal_linalg_transform__ = "workgroup"} : (memref<?x?xf32, #map2>, memref<?x?xf32, #map2>, memref<?x?xf32, #map2>)
+ linalg.matmul %5, %9, %14 {__internal_linalg_transform__ = "workgroup_numprocs_ge_numiters"} : (memref<?x?xf32, #map2>, memref<?x?xf32, #map2>, memref<?x?xf32, #map2>)
}
scf.yield
}
@@ -288,57 +288,6 @@
// -----
-#map0 = affine_map<(d0, d1, d2) -> (32, d1 - d2)>
-#map1 = affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3, s4] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3 + d3 * s4)>
-
-
-module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], [SPV_KHR_storage_buffer_storage_class]>, {max_compute_workgroup_invocations = 128 : i32, max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
- func @conv_padding(%arg0: memref<?x?x?x?xf32>, %arg1: memref<?x?x?x?xf32>, %arg2: memref<?x?x?x?xf32>) attributes {spv.entry_point_abi = {local_size = dense<[32, 1, 1]> : vector<3xi32>}} {
- linalg.conv(%arg0, %arg1, %arg2) {dilations = [1, 1], padding = dense<[[1, 1], [0, 1]]> : tensor<2x2xi64>, strides = [1, 1]} : memref<?x?x?x?xf32>, memref<?x?x?x?xf32>, memref<?x?x?x?xf32>
- return
- }
-}
-
-// CHECK-LABEL: func @conv_padding
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9$._-]+]]: memref<?x?x?x?xf32>
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9$._-]+]]: memref<?x?x?x?xf32>
-// CHECK-SAME: %[[ARG2:[a-zA-Z0-9$._-]+]]: memref<?x?x?x?xf32>
-// CHECK-SAME: local_size = dense<[32, 1, 1]>
-// CHECK-SAME: vkspv.workgroup_count_from_result_shape = 1
-// CHECK-DAG: %[[C2:.+]] = constant 2 : index
-// CHECK-DAG: %[[C3:.+]] = constant 3 : index
-// CHECK-DAG: %[[C0:.+]] = constant 0 : index
-// CHECK-DAG: %[[C1:.+]] = constant 1 : index
-// CHECK-DAG: %[[UB0:.+]] = dim %[[ARG0]], %[[C0]]
-// CHECK-DAG: %[[UB1:.+]] = dim %[[ARG0]], %[[C1]]
-// CHECK-DAG: %[[UB2:.+]] = dim %[[ARG0]], %[[C2]]
-// CHECK-DAG: %[[UB3:.+]] = dim %[[ARG0]], %[[C3]]
-// CHECK-DAG: %[[UB4:.+]] = dim %[[ARG1]], %[[C0]]
-// CHECK-DAG: %[[UB5:.+]] = dim %[[ARG2]], %[[C1]]
-// CHECK-DAG: %[[UB6:.+]] = dim %[[ARG2]], %[[C2]]
-// CHECK: %[[T7:.+]] = muli %[[UB3]], %[[UB6]]
-// CHECK: %[[T8:.+]] = muli %[[T7]], %[[UB5]]
-// CHECK: %[[UB:.+]] = muli %[[T8]], %[[UB4]]
-// CHECK-DAG: %[[BIDX:.+]] = "gpu.block_id"() {dimension = "x"}
-// CHECK-DAG: %[[NTHREADSX:.+]] = "gpu.block_dim"() {dimension = "x"}
-// CHECK-DAG: %[[TIDX:.+]] = "gpu.thread_id"() {dimension = "x"}
-// CHECK: %[[T13:.+]] = muli %[[BIDX]], %[[NTHREADSX]]
-// CHECK: %[[PROCID:.+]] = addi %[[T13]], %[[TIDX]]
-// CHECK: %[[COND:.+]] = cmpi "slt", %[[PROCID]], %[[UB]]
-// CHECK: scf.if %[[COND]]
-// CHECK: %[[IV0:.+]] = divi_signed %[[PROCID]], %[[T8]]
-// CHECK: %[[T17:.+]] = remi_signed %[[PROCID]], %[[T8]]
-// CHECK: %[[IV1:.+]] = divi_signed %[[T17]], %[[T7]]
-// CHECK: %[[T19:.+]] = remi_signed %[[T17]], %[[T7]]
-// CHECK: %[[IV2:.+]] = divi_signed %[[T19]], %[[UB3]]
-// CHECK: %[[T21:.+]] = remi_signed %[[T19]], %[[UB3]]
-// CHECK: scf.for %[[IV3:.+]] = %[[C0]] to %[[UB2]] step %[[C1]]
-// CHECK: scf.for %[[IV4:.+]] = %[[C0]] to %[[UB0]] step %[[C1]]
-// CHECK: scf.for %[[IV5:.+]]= %[[C0]] to %[[UB1]] step %[[C1]]
-// CHECK-NOT: linalg.conv
-
-// -----
-
#map0 = affine_map<(d0)[s0, s1] -> (s0 + 4, -d0 + s1)>
#map1 = affine_map<(d0)[s0, s1] -> (s0 + 32, -d0 + s1)>
#map2 = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_gpu_option.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_gpu_option.mlir
deleted file mode 100644
index 63f8aa5..0000000
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_gpu_option.mlir
+++ /dev/null
@@ -1,76 +0,0 @@
-// RUN: iree-opt -iree-codegen-convert-to-gpu -iree-codegen-use-legacy-conv-lowering=false -canonicalize -cse -split-input-file %s | IreeFileCheck %s
-
-#map0 = affine_map<(d0)[s0] -> (1, -d0 + s0)>
-#map1 = affine_map<(d0)[s0, s1] -> (s0 + 4, -d0 + s1)>
-#map2 = affine_map<(d0)[s0, s1] -> (s0 + 32, -d0 + s1)>
-#map3 = affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3 + d3)>
-#map4 = affine_map<(d0)[s0] -> (4, -d0 + s0)>
-#map5 = affine_map<(d0)[s0] -> (32, -d0 + s0)>
-
-module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], [SPV_KHR_storage_buffer_storage_class]>, {max_compute_workgroup_invocations = 128 : i32, max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
- func @conv_no_padding(%arg0: memref<?x?x?x?xf32>, %arg1: memref<?x?x?x?xf32>, %arg2: memref<?x?x?x?xf32>) attributes {spv.entry_point_abi = {local_size = dense<[32, 4, 1]> : vector<3xi32>}} {
- %c4 = constant 4 : index
- %c32 = constant 32 : index
- %c2 = constant 2 : index
- %c0 = constant 0 : index
- %c3 = constant 3 : index
- %c1 = constant 1 : index
- %0 = dim %arg1, %c0 : memref<?x?x?x?xf32>
- %1 = dim %arg1, %c1 : memref<?x?x?x?xf32>
- %2 = dim %arg1, %c2 : memref<?x?x?x?xf32>
- %3 = dim %arg2, %c1 : memref<?x?x?x?xf32>
- %4 = dim %arg2, %c2 : memref<?x?x?x?xf32>
- scf.parallel (%arg3, %arg4, %arg5) = (%c0, %c0, %c0) to (%0, %3, %4) step (%c1, %c4, %c32) {
- %5 = affine.min #map0(%arg3)[%0]
- %6 = affine.min #map1(%arg4)[%1, %1]
- %7 = affine.min #map2(%arg5)[%2, %2]
- %8 = dim %arg1, %c3 : memref<?x?x?x?xf32>
- %9 = subview %arg1[%arg3, %arg4, %arg5, 0] [%5, %6, %7, %8] [1, 1, 1, 1] : memref<?x?x?x?xf32> to memref<?x?x?x?xf32, #map3>
- %10 = dim %arg2, %c0 : memref<?x?x?x?xf32>
- %11 = affine.min #map0(%arg3)[%10]
- %12 = affine.min #map4(%arg4)[%3]
- %13 = affine.min #map5(%arg5)[%4]
- %14 = dim %arg2, %c3 : memref<?x?x?x?xf32>
- %15 = subview %arg2[%arg3, %arg4, %arg5, 0] [%11, %12, %13, %14] [1, 1, 1, 1] : memref<?x?x?x?xf32> to memref<?x?x?x?xf32, #map3>
- linalg.conv(%arg0, %9, %15) {__internal_linalg_transform__ = "workgroup", dilations = [1, 1], strides = [1, 1]} : memref<?x?x?x?xf32>, memref<?x?x?x?xf32, #map3>, memref<?x?x?x?xf32, #map3>
- scf.yield
- }
- return
- }
-}
-
-// CHECK-LABEL: func @conv_no_padding
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9$._-]+]]: memref<?x?x?x?xf32>
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9$._-]+]]: memref<?x?x?x?xf32>
-// CHECK-SAME: %[[ARG2:[a-zA-Z0-9$._-]+]]: memref<?x?x?x?xf32>
-// CHECK-DAG: %[[C4:.+]] = constant 4 : index
-// CHECK-DAG: %[[C32:.+]] = constant 32 : index
-// CHECK-DAG: %[[C2:.+]] = constant 2 : index
-// CHECK-DAG: %[[C0:.+]] = constant 0 : index
-// CHECK-DAG: %[[C1:.+]] = constant 1 : index
-// CHECK-DAG: %[[UB0:.+]] = dim %[[ARG1]], %[[C0]]
-// CHECK-DAG: %[[UB1:.+]] = dim %[[ARG1]], %[[C1]]
-// CHECK-DAG: %[[UB2:.+]] = dim %[[ARG1]], %[[C2]]
-// CHECK-DAG: %[[UB3:.+]] = dim %[[ARG2]], %[[C1]]
-// CHECK-DAG: %[[UB4:.+]] = dim %[[ARG2]], %[[C2]]
-// CHECK-DAG: %[[BIDX:.+]] = "gpu.block_id"() {dimension = "x"}
-// CHECK-DAG: %[[BIDY:.+]] = "gpu.block_id"() {dimension = "y"}
-// CHECK-DAG: %[[BIDZ:.+]] = "gpu.block_id"() {dimension = "z"}
-// CHECK-DAG: %[[BOFFSETY:.+]] = muli %[[BIDY]], %[[C4]]
-// CHECK-DAG: %[[BOFFSETX:.+]] = muli %[[BIDX]], %[[C32]]
-// CHECK: %[[SV1:.+]] = subview %[[ARG1]][%[[BIDZ]], %[[BOFFSETY]], %[[BOFFSETX]], 0]
-// CHECK: %[[SV2:.+]] = subview %[[ARG2]][%[[BIDZ]], %[[BOFFSETY]], %[[BOFFSETX]], 0]
-// CHECK-DAG: %[[TIDX:.+]] = "gpu.thread_id"() {dimension = "x"}
-// CHECK-DAG: %[[TIDY:.+]] = "gpu.thread_id"() {dimension = "y"}
-// CHECK-DAG: %[[TIDZ:.+]] = "gpu.thread_id"() {dimension = "z"}
-// CHECK: %[[INBOUNDSZ:.+]] = cmpi "slt", %[[TIDZ]], %{{.+}}
-// CHECK: %[[INBOUNDSY:.+]] = cmpi "slt", %[[TIDY]], %{{.+}}
-// CHECK: %[[T35:.+]] = and %[[INBOUNDSZ]], %[[INBOUNDSY]]
-// CHECK: %[[INBOUNDSX:.+]] = cmpi "slt", %[[TIDX]], %{{.+}}
-// CHECK: %[[INBOUNDS:.+]] = and %[[T35]], %[[INBOUNDSX]]
-// CHECK: scf.if %[[INBOUNDS]]
-// CHECK: scf.for
-// CHECK: scf.for
-// CHECK: scf.for
-// CHECK: scf.for
-// CHECK-NOT: linalg.conv
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/cyclic_to_workgroup.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/cyclic_to_workgroup.mlir
deleted file mode 100644
index cac18ab..0000000
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/cyclic_to_workgroup.mlir
+++ /dev/null
@@ -1,69 +0,0 @@
-// RUN: iree-opt -iree-codegen-convert-to-gpu -canonicalize -cse -split-input-file -iree-codegen-constrained-workgroup-count %s | IreeFileCheck %s
-
-#map0 = affine_map<(d0)[s0] -> (8, -d0 + s0)>
-#map1 = affine_map<(d0)[s0] -> (4, -d0 + s0)>
-#map2 = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
-
-module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], [SPV_KHR_storage_buffer_storage_class]>, {max_compute_workgroup_invocations = 128 : i32, max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
- func @matmul(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>) attributes {spv.entry_point_abi = {local_size = dense<[8, 8, 1]> : vector<3xi32>}} {
- %c0 = constant 0 : index
- %c1 = constant 1 : index
- %c4 = constant 4 : index
- %c8 = constant 8 : index
- %0 = dim %arg0, %c0 : memref<?x?xf32>
- %1 = dim %arg0, %c1 : memref<?x?xf32>
- %2 = dim %arg1, %c1 : memref<?x?xf32>
- scf.parallel (%arg3, %arg4) = (%c0, %c0) to (%0, %2) step (%c8, %c8) {
- scf.for %arg5 = %c0 to %1 step %c4 {
- %3 = affine.min #map0(%arg3)[%0]
- %4 = affine.min #map1(%arg5)[%1]
- %5 = subview %arg0[%arg3, %arg5] [%3, %4] [1, 1] : memref<?x?xf32> to memref<?x?xf32, #map2>
- %6 = dim %arg1, %c0 : memref<?x?xf32>
- %7 = affine.min #map1(%arg5)[%6]
- %8 = affine.min #map0(%arg4)[%2]
- %9 = subview %arg1[%arg5, %arg4] [%7, %8] [1, 1] : memref<?x?xf32> to memref<?x?xf32, #map2>
- %10 = dim %arg2, %c0 : memref<?x?xf32>
- %11 = affine.min #map0(%arg3)[%10]
- %12 = dim %arg2, %c1 : memref<?x?xf32>
- %13 = affine.min #map0(%arg4)[%12]
- %14 = subview %arg2[%arg3, %arg4] [%11, %13] [1, 1] : memref<?x?xf32> to memref<?x?xf32, #map2>
- linalg.matmul %5, %9, %14 {__internal_linalg_transform__ = "workgroup"} : (memref<?x?xf32, #map2>, memref<?x?xf32, #map2>, memref<?x?xf32, #map2>)
- }
- scf.yield
- }
- return
- }
-}
-
-// CHECK-LABEL: func @matmul
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9$._-]+]]: memref<?x?xf32>
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9$._-]+]]: memref<?x?xf32>
-// CHECK-SAME: %[[ARG2:[a-zA-Z0-9$._-]+]]: memref<?x?xf32>
-// CHECK-DAG: %[[C4:.+]] = constant 4 : index
-// CHECK-DAG: %[[C8:.+]] = constant 8 : index
-// CHECK-DAG: %[[C0:.+]] = constant 0 : index
-// CHECK-DAG: %[[C1:.+]] = constant 1 : index
-// CHECK-DAG: %[[UB0:.+]] = dim %[[ARG0]], %[[C0]]
-// CHECK-DAG: %[[UB1:.+]] = dim %[[ARG1]], %[[C1]]
-// CHECK-DAG: %[[UB2:.+]] = dim %[[ARG0]], %[[C1]]
-// CHECK-DAG: %[[BIDX:.+]] = "gpu.block_id"() {dimension = "x"}
-// CHECK-DAG: %[[GDIMX:.+]] = "gpu.grid_dim"() {dimension = "x"}
-// CHECK-DAG: %[[BIDY:.+]] = "gpu.block_id"() {dimension = "y"}
-// CHECK-DAG: %[[GDIMY:.+]] = "gpu.grid_dim"() {dimension = "y"}
-// CHECK: %[[BOFFSETY:.+]] = muli %[[BIDY]], %[[C8]]
-// CHECK: %[[BSTEPY:.+]] = muli %[[GDIMY]], %[[C8]]
-// CHECK: %[[BOFFSETX:.+]] = muli %[[BIDX]], %[[C8]]
-// CHECK: %[[BSTEPX:.+]] = muli %[[GDIMX]], %[[C8]]
-// CHECK: scf.for %[[BIV0:.+]] = %[[BOFFSETY]] to %[[UB0]] step %[[BSTEPY]]
-// CHECK: scf.for %[[BIV1:.+]] = %[[BOFFSETX]] to %[[UB1]] step %[[BSTEPX]]
-// CHECK: scf.for %[[BIV2:.+]] = %[[C0]] to %[[UB2]] step %[[C4]]
-// CHECK-DAG: %[[VIEWUB0:.+]] = affine.min #{{.*}}(%[[BIV0]])[%[[UB0]]]
-// CHECK-DAG: %[[VIEWUB1:.+]] = affine.min #{{.*}}(%[[BIV1]])[%[[UB1]]]
-// CHECK-DAG: %[[VIEWUB2:.+]] = affine.min #{{.*}}(%[[BIV2]])[%[[UB2]]]
-// CHECK-DAG: %[[TIDX:.+]] = "gpu.thread_id"() {dimension = "x"}
-// CHECK-DAG: %[[TIDY:.+]] = "gpu.thread_id"() {dimension = "y"}
-// CHECK: %[[INBOUNDY:.+]] = cmpi "slt", %[[TIDY]], %[[VIEWUB0]]
-// CHECK: %[[INBOUNDX:.+]] = cmpi "slt", %[[TIDX]], %[[VIEWUB1]]
-// CHECK: %[[COND:.+]] = and %[[INBOUNDY]], %[[INBOUNDX]]
-// CHECK: scf.if %[[COND]]
-// CHECK: scf.for %{{.*}} = %[[C0]] to %[[VIEWUB2]] step %[[C1]]
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/linalg_tile_and_fuse.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/linalg_tile_and_fuse.mlir
index 1728d35..7e7433a 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/linalg_tile_and_fuse.mlir
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/linalg_tile_and_fuse.mlir
@@ -81,7 +81,7 @@
// CHECK: %[[VIEW1:.+]] = subview %[[ARG1]]
// CHECK: %[[VIEW2:.+]] = subview %[[ARG2]]
// CHECK: linalg.matmul
-// CHECK-SAME: "workgroup"
+// CHECK-SAME: "workgroup_numprocs_ge_numiters"
// CHECK-SAME: %[[VIEW0]], %[[VIEW1]], %[[VIEW2]]
// -----
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/matmul_vectorization.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/matmul_vectorization.mlir
new file mode 100644
index 0000000..63ba65f
--- /dev/null
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/matmul_vectorization.mlir
@@ -0,0 +1,21 @@
+// RUN: iree-opt --iree-codegen-linalg-to-gpu-matmul-vectorization-pass \
+// RUN: -split-input-file %s --iree-codegen-linalg-to-gpu-unroll-size=8,8,32 \
+// RUN: -iree-codegen-linalg-to-gpu-matmul-licm | IreeFileCheck %s
+
+// CHECK-LABEL: func @matmul_128x128x128
+// CHECK-SAME: (%[[ARG0:.+]]: memref<128x128xf32>, %[[ARG1:.+]]: memref<128x128xf32>, %[[ARG2:.+]]: memref<128x128xf32>)
+func @matmul_128x128x128(%arg0 : memref<128x128xf32>, %arg1: memref<128x128xf32>, %arg2: memref<128x128xf32>) {
+ linalg.matmul %arg0, %arg1, %arg2 : (memref<128x128xf32>, memref<128x128xf32>, memref<128x128xf32>)
+ return
+}
+
+// CHECK: %[[TILESIZE:.+]] = constant 32 : index
+// CHECK: %[[MATSIZE:.+]] = constant 128 : index
+// CHECK: %[[START:.+]] = constant 0 : index
+// CHECK: scf.for %[[IL:.+]] = %[[START]] to %[[MATSIZE]] step %[[TILESIZE]]
+// CHECK: scf.for %[[JL:.+]] = %[[START]] to %[[MATSIZE]] step %[[TILESIZE]]
+// CHECK: %[[SUBVVIEWC:.+]] = subview %[[ARG2]][%[[IL]], %[[JL]]] [32, 32] [1, 1] : memref<128x128xf32> to memref<32x32xf32
+// CHECK: scf.for %[[KL:.+]] = %[[START]] to %[[MATSIZE]] step %[[TILESIZE]]
+// CHECK: %[[SUBVVIEWA:.+]] = subview %[[ARG0]][%[[IL]], %[[KL]]] [32, 32] [1, 1] : memref<128x128xf32> to memref<32x32xf32
+// CHECK: %[[SUBVVIEWB:.+]] = subview %[[ARG1]][%[[KL]], %[[JL]]] [32, 32] [1, 1] : memref<128x128xf32> to memref<32x32xf32
+
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/workgroup_memory_promotion.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/workgroup_memory_promotion.mlir
index a24c77b..0f97d1e 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/workgroup_memory_promotion.mlir
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/workgroup_memory_promotion.mlir
@@ -1,47 +1,55 @@
-// RUN: iree-opt -split-input-file -iree-codegen-linalg-tile-and-fuse=use-workgroup-memory %s | IreeFileCheck %s
+// RUN: iree-opt -split-input-file -iree-codegen-linalg-tile-and-fuse=use-workgroup-memory -canonicalize -cse %s | IreeFileCheck %s
module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], [SPV_KHR_storage_buffer_storage_class]>, {max_compute_workgroup_invocations = 128 : i32, max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
- func @matmul_tile() {
- %arg0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<96x96xf32>
- %arg1 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg1} : memref<96x96xf32>
- %arg2 = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<96x96xf32>
+ func @matmul_tile(%arg0 : memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>) {
linalg.matmul %arg0, %arg1, %arg2 :
- (memref<96x96xf32>, memref<96x96xf32>, memref<96x96xf32>)
+ (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>)
return
}
+}
+// CHECK-LABEL: func @matmul_tile
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: memref<?x?xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: memref<?x?xf32>
+// CHECK: scf.parallel (%{{.*}}, %{{.*}})
+// CHECK: scf.for %{{.*}}
+// CHECK: %[[ARG0SV:.+]] = subview %[[ARG0]]
+// CHECK: %[[ARG1SV:.+]] = subview %[[ARG1]]
+// CHECK: %[[ARG2SV:.+]] = subview %[[ARG2]]
+// CHECK: %[[ALLOC1:.+]] = alloc() : memref<8x4xf32, 3>
+// CHECK: %[[SUBVIEW1:.+]] = subview %[[ALLOC1]]
+// CHECK: %[[ALLOC2:.+]] = alloc() : memref<4x8xf32, 3>
+// CHECK: %[[SUBVIEW2:.+]] = subview %[[ALLOC2]]
+// CHECK: linalg.copy(%[[ARG0SV]], %[[SUBVIEW1]])
+// CHECK-SAME: "copy_to_workgroup_memory"
+// CHECK: linalg.copy(%[[ARG1SV]], %[[SUBVIEW2]])
+// CHECK-SAME: "copy_to_workgroup_memory"
+// CHECK: linalg.matmul
+// CHECK-SAME: "workgroup_memory_numprocs_ge_numiters"
+// CHECK-SAME: %[[SUBVIEW1]], %[[SUBVIEW2]], %[[ARG2SV]]
+// CHECK-DAG: dealloc %[[ALLOC1]] : memref<8x4xf32, 3>
+// CHECK-DAG: dealloc %[[ALLOC2]] : memref<4x8xf32, 3>
- hal.interface @legacy_io attributes {push_constants = 5 : i32, sym_visibility = "private"} {
- hal.interface.binding @arg0, set=0, binding=1, type="StorageBuffer", access="Read"
- hal.interface.binding @arg1, set=0, binding=2, type="StorageBuffer", access="Read"
- hal.interface.binding @ret0, set=0, binding=3, type="StorageBuffer", access="Write"
+// -----
+
+
+module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], [SPV_KHR_storage_buffer_storage_class]>, {max_compute_workgroup_invocations = 128 : i32, max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
+ func @conv_no_padding_tile(%arg0: memref<3x4x3x2xf32>, %arg1: memref<?x?x?x3xf32>, %arg2: memref<?x?x?x2xf32>) {
+ linalg.conv(%arg0, %arg1, %arg2) {dilations = [1, 1], strides = [1, 1]} : memref<3x4x3x2xf32>, memref<?x?x?x3xf32>, memref<?x?x?x2xf32>
+ return
}
}
-// CHECK-DAG: %[[C4:.+]] = constant 4 : index
-// CHECK-DAG: %[[C8:.+]] = constant 8 : index
-// CHECK-DAG: %[[C96:.+]] = constant 96 : index
-// CHECK-DAG: %[[C0:.+]] = constant 0 : index
-// CHECK: %[[ARG0:.+]] = iree.placeholder
-// CHECK-SAME: binding = @legacy_io::@arg0
-// CHECK: %[[ARG1:.+]] = iree.placeholder
-// CHECK-SAME: binding = @legacy_io::@arg1
-// CHECK: %[[RET0:.+]] = iree.placeholder
-// CHECK-SAME: binding = @legacy_io::@ret0
-// CHECK: scf.parallel (%{{.*}}, %{{.*}})
-// CHECK: scf.for %{{.*}} = %[[C0]] to %{{.*}} step %[[C4]]
-// CHECK: %[[ARG0SV:.+]] = subview %[[ARG0]]
-// CHECK: %[[ARG1SV:.+]] = subview %[[ARG1]]
-// CHECK: %[[RET0SV:.+]] = subview %[[RET0]]
-// CHECK: %[[ALLOC1:.+]] = alloc(%[[C8]], %[[C4]]) : memref<?x?xf32, 3>
-// CHECK: %[[SUBVIEW1:.+]] = subview %[[ALLOC1]]
-// CHECK: %[[ALLOC2:.+]] = alloc(%[[C4]], %[[C8]]) : memref<?x?xf32, 3>
-// CHECK: %[[SUBVIEW2:.+]] = subview %[[ALLOC2]]
-// CHECK: linalg.copy(%[[ARG0SV]], %[[SUBVIEW1]])
-// CHECK-SAME: "workgroup"
-// CHECK: spv.ControlBarrier "Workgroup", "Workgroup", "AcquireRelease"
-// CHECK: linalg.copy(%[[ARG1SV]], %[[SUBVIEW2]])
-// CHECK-SAME: "workgroup"
-// CHECK: spv.ControlBarrier "Workgroup", "Workgroup", "AcquireRelease"
-// CHECK: linalg.matmul {{.*}}"workgroup"{{.*}} %[[SUBVIEW1]], %[[SUBVIEW2]], %[[RET0SV]]
-// CHECK: spv.ControlBarrier "Workgroup", "Workgroup", "AcquireRelease"
-// CHECK-DAG: dealloc %[[ALLOC1]] : memref<?x?xf32, 3>
-// CHECK-DAG: dealloc %[[ALLOC2]] : memref<?x?xf32, 3>
+// CHECK-LABEL: func @conv_no_padding_tile
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<3x4x3x2xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: memref<?x?x?x3xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: memref<?x?x?x2xf32>
+// CHECK: scf.parallel (%{{.*}}, %{{.*}}, %{{.*}})
+// CHECK: %[[ARG1SV:.+]] = subview %[[ARG1]]
+// CHECK: %[[ARG2SV:.+]] = subview %[[ARG2]]
+// CHECK: %[[ALLOC1:.+]] = alloc() : memref<1x7x36x3xf32, 3>
+// CHECK: %[[SUBVIEW1:.+]] = subview %[[ALLOC1]]
+// CHECK: linalg.copy(%[[ARG1SV]], %[[SUBVIEW1]])
+// CHECK-SAME: "copy_to_workgroup_memory"
+// CHECK: linalg.conv(%[[ARG0]], %[[SUBVIEW1]], %[[ARG2SV]])
+// CHECK-SAME: "workgroup_memory"
+// CHECK: dealloc %[[ALLOC1]] : memref<1x7x36x3xf32, 3>
diff --git a/iree/compiler/Conversion/init_conversions.h b/iree/compiler/Conversion/init_conversions.h
index bc218e8..74eaf71 100644
--- a/iree/compiler/Conversion/init_conversions.h
+++ b/iree/compiler/Conversion/init_conversions.h
@@ -40,6 +40,7 @@
createLinalgTileAndFusePass();
createSplitDispatchFunctionPass();
createVectorToGPUPass();
+ createMatMulTileAndVectorizeGPUPass();
return true;
}();
(void)init_once;
diff --git a/iree/compiler/Dialect/HAL/IR/HALTypes.h b/iree/compiler/Dialect/HAL/IR/HALTypes.h
index 3df35e1..790e8e6 100644
--- a/iree/compiler/Dialect/HAL/IR/HALTypes.h
+++ b/iree/compiler/Dialect/HAL/IR/HALTypes.h
@@ -63,7 +63,6 @@
static AllocatorType get(MLIRContext *context) {
return Base::get(context, TypeKind::Allocator);
}
- static bool kindof(unsigned kind) { return kind == TypeKind::Allocator; }
};
class BufferType : public Type::TypeBase<BufferType, Type, TypeStorage> {
@@ -72,7 +71,6 @@
static BufferType get(MLIRContext *context) {
return Base::get(context, TypeKind::Buffer);
}
- static bool kindof(unsigned kind) { return kind == TypeKind::Buffer; }
};
class BufferViewType
@@ -82,7 +80,6 @@
static BufferViewType get(MLIRContext *context) {
return Base::get(context, TypeKind::BufferView);
}
- static bool kindof(unsigned kind) { return kind == TypeKind::BufferView; }
};
class CommandBufferType
@@ -92,7 +89,6 @@
static CommandBufferType get(MLIRContext *context) {
return Base::get(context, TypeKind::CommandBuffer);
}
- static bool kindof(unsigned kind) { return kind == TypeKind::CommandBuffer; }
};
class DescriptorSetType
@@ -102,7 +98,6 @@
static DescriptorSetType get(MLIRContext *context) {
return Base::get(context, TypeKind::DescriptorSet);
}
- static bool kindof(unsigned kind) { return kind == TypeKind::DescriptorSet; }
};
class DescriptorSetLayoutType
@@ -112,9 +107,6 @@
static DescriptorSetLayoutType get(MLIRContext *context) {
return Base::get(context, TypeKind::DescriptorSetLayout);
}
- static bool kindof(unsigned kind) {
- return kind == TypeKind::DescriptorSetLayout;
- }
};
class DeviceType : public Type::TypeBase<DeviceType, Type, TypeStorage> {
@@ -123,7 +115,6 @@
static DeviceType get(MLIRContext *context) {
return Base::get(context, TypeKind::Device);
}
- static bool kindof(unsigned kind) { return kind == TypeKind::Device; }
};
class EventType : public Type::TypeBase<EventType, Type, TypeStorage> {
@@ -132,7 +123,6 @@
static EventType get(MLIRContext *context) {
return Base::get(context, TypeKind::Event);
}
- static bool kindof(unsigned kind) { return kind == TypeKind::Event; }
};
class ExecutableType
@@ -142,7 +132,6 @@
static ExecutableType get(MLIRContext *context) {
return Base::get(context, TypeKind::Executable);
}
- static bool kindof(unsigned kind) { return kind == TypeKind::Executable; }
};
class ExecutableCacheType
@@ -152,9 +141,6 @@
static ExecutableCacheType get(MLIRContext *context) {
return Base::get(context, TypeKind::ExecutableCache);
}
- static bool kindof(unsigned kind) {
- return kind == TypeKind::ExecutableCache;
- }
};
class ExecutableLayoutType
@@ -164,9 +150,6 @@
static ExecutableLayoutType get(MLIRContext *context) {
return Base::get(context, TypeKind::ExecutableLayout);
}
- static bool kindof(unsigned kind) {
- return kind == TypeKind::ExecutableLayout;
- }
};
class RingBufferType
@@ -176,7 +159,6 @@
static RingBufferType get(MLIRContext *context) {
return Base::get(context, TypeKind::RingBuffer);
}
- static bool kindof(unsigned kind) { return kind == TypeKind::RingBuffer; }
};
class SemaphoreType : public Type::TypeBase<SemaphoreType, Type, TypeStorage> {
@@ -185,7 +167,6 @@
static SemaphoreType get(MLIRContext *context) {
return Base::get(context, TypeKind::Semaphore);
}
- static bool kindof(unsigned kind) { return kind == TypeKind::Semaphore; }
};
//===----------------------------------------------------------------------===//
diff --git a/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp b/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp
index a0cafcc..957bb53 100644
--- a/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp
+++ b/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp
@@ -63,6 +63,11 @@
"Workgroup size to use for XLA-HLO to Linalg to SPIR-V path"),
llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated);
+ static llvm::cl::list<unsigned> clTileSizes(
+ "iree-spirv-tile-size",
+ llvm::cl::desc("Tile size to use for tiling Linalg operations"),
+ llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated);
+
static llvm::cl::opt<std::string> clVulkanTargetEnv(
"iree-vulkan-target-env",
llvm::cl::desc(
@@ -70,9 +75,10 @@
llvm::cl::init(Vulkan::swiftShaderTargetEnvAssembly));
VulkanSPIRVTargetOptions targetOptions;
- for (unsigned dim : clWorkgroupSize) {
- targetOptions.codegenOptions.workgroupSize.push_back(dim);
- }
+ targetOptions.codegenOptions.workgroupSize.assign(clWorkgroupSize.begin(),
+ clWorkgroupSize.end());
+ targetOptions.codegenOptions.tileSizes.assign(clTileSizes.begin(),
+ clTileSizes.end());
targetOptions.codegenOptions.useWorkgroupMemory = clUseWorkgroupMemory;
targetOptions.vulkanTargetEnv = clVulkanTargetEnv;
return targetOptions;
diff --git a/iree/compiler/Dialect/IREE/IR/IREEDialect.cpp b/iree/compiler/Dialect/IREE/IR/IREEDialect.cpp
index cf95bff..216a258 100644
--- a/iree/compiler/Dialect/IREE/IR/IREEDialect.cpp
+++ b/iree/compiler/Dialect/IREE/IR/IREEDialect.cpp
@@ -65,21 +65,14 @@
}
void IREEDialect::printType(Type type, DialectAsmPrinter& os) const {
- switch (type.getKind()) {
- case IREE::TypeKind::Ptr: {
- auto targetType = type.cast<IREE::PtrType>().getTargetType();
- os << "ptr<" << targetType << ">";
- break;
- }
- case IREE::TypeKind::ByteBuffer:
- os << "byte_buffer";
- break;
- case IREE::TypeKind::MutableByteBuffer:
- os << "mutable_byte_buffer";
- break;
- default:
- llvm_unreachable("unhandled IREE type");
- }
+ if (auto ptrType = type.dyn_cast<IREE::PtrType>())
+ os << "ptr<" << ptrType.getTargetType() << ">";
+ else if (type.isa<IREE::ByteBufferType>())
+ os << "byte_buffer";
+ else if (type.isa<IREE::MutableByteBufferType>())
+ os << "mutable_byte_buffer";
+ else
+ llvm_unreachable("unhandled IREE type");
}
} // namespace iree_compiler
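The same kind-free dispatch can also be expressed with llvm::TypeSwitch, which stays compact as more types are added; an equivalent sketch of the printer above (illustrative only, not what the patch uses):

  // Illustrative alternative using llvm::TypeSwitch (llvm/ADT/TypeSwitch.h).
  void printIREEType(Type type, DialectAsmPrinter &os) {
    llvm::TypeSwitch<Type>(type)
        .Case<IREE::PtrType>([&](IREE::PtrType ptrType) {
          os << "ptr<" << ptrType.getTargetType() << ">";
        })
        .Case<IREE::ByteBufferType>([&](Type) { os << "byte_buffer"; })
        .Case<IREE::MutableByteBufferType>(
            [&](Type) { os << "mutable_byte_buffer"; })
        .Default([](Type) { llvm_unreachable("unhandled IREE type"); });
  }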
diff --git a/iree/compiler/Dialect/IREE/IR/IREETypes.h b/iree/compiler/Dialect/IREE/IR/IREETypes.h
index 7170a50..db40e1a 100644
--- a/iree/compiler/Dialect/IREE/IR/IREETypes.h
+++ b/iree/compiler/Dialect/IREE/IR/IREETypes.h
@@ -143,7 +143,6 @@
public:
static PtrType get(Type targetType);
static PtrType getChecked(Type targetType, Location location);
- static bool kindof(unsigned kind) { return kind == TypeKind::Ptr; }
using Base::Base;
@@ -156,8 +155,6 @@
public:
using Base::Base;
- static bool kindof(unsigned kind) { return kind == TypeKind::ByteBuffer; }
-
static ByteBufferType get(MLIRContext *context) {
return Base::get(context, TypeKind::ByteBuffer);
}
@@ -169,10 +166,6 @@
public:
using Base::Base;
- static bool kindof(unsigned kind) {
- return kind == TypeKind::MutableByteBuffer;
- }
-
static MutableByteBufferType get(MLIRContext *context) {
return Base::get(context, TypeKind::MutableByteBuffer);
}
diff --git a/iree/compiler/Dialect/IREE/Tools/StructAttrGen.cpp b/iree/compiler/Dialect/IREE/Tools/StructAttrGen.cpp
index b56be53..1e19cc7 100644
--- a/iree/compiler/Dialect/IREE/Tools/StructAttrGen.cpp
+++ b/iree/compiler/Dialect/IREE/Tools/StructAttrGen.cpp
@@ -106,7 +106,6 @@
using Base::Base;
static StringRef getKindName() { return "{2}"; }
- static bool kindof(unsigned kind) { return kind == AttrKind::{1}; }
)",
structAttr.getDescription(), structAttr.getStructClassName(),
diff --git a/iree/compiler/Dialect/Modules/Strings/IR/Types.h b/iree/compiler/Dialect/Modules/Strings/IR/Types.h
index d67078e..86854a2 100644
--- a/iree/compiler/Dialect/Modules/Strings/IR/Types.h
+++ b/iree/compiler/Dialect/Modules/Strings/IR/Types.h
@@ -30,7 +30,6 @@
static StringType get(MLIRContext *context) {
return Base::get(context, TypeKind::String);
}
- static bool kindof(unsigned kind) { return kind == TypeKind::String; }
};
class StringTensorType
@@ -40,7 +39,6 @@
static StringTensorType get(MLIRContext *context) {
return Base::get(context, TypeKind::StringTensor);
}
- static bool kindof(unsigned kind) { return kind == TypeKind::StringTensor; }
};
} // namespace Strings
diff --git a/iree/compiler/Dialect/Modules/TensorList/IR/TensorListTypes.h b/iree/compiler/Dialect/Modules/TensorList/IR/TensorListTypes.h
index bc76f9c..1a23958 100644
--- a/iree/compiler/Dialect/Modules/TensorList/IR/TensorListTypes.h
+++ b/iree/compiler/Dialect/Modules/TensorList/IR/TensorListTypes.h
@@ -37,7 +37,6 @@
static TensorListType get(MLIRContext *context) {
return Base::get(context, TypeKind::kTensorList);
}
- static bool kindof(unsigned kind) { return kind == TypeKind::kTensorList; }
};
} // namespace TensorList
diff --git a/iree/compiler/Dialect/Sequence/IR/SequenceDialect.cpp b/iree/compiler/Dialect/Sequence/IR/SequenceDialect.cpp
index 20a640b..4e0a33d 100644
--- a/iree/compiler/Dialect/Sequence/IR/SequenceDialect.cpp
+++ b/iree/compiler/Dialect/Sequence/IR/SequenceDialect.cpp
@@ -53,17 +53,8 @@
}
void SequenceDialect::printType(Type type, DialectAsmPrinter& os) const {
- switch (type.getKind()) {
- case TypeKind::Sequence: {
- auto targetType = type.cast<SequenceType>().getTargetType();
- os << "of<";
- os.printType(targetType);
- os << ">";
- break;
- }
- default:
- llvm_unreachable("unhandled sequence type");
- }
+  if (auto sequenceType = type.dyn_cast<SequenceType>())
+    os << "of<" << sequenceType.getTargetType() << ">";
+  else
+    llvm_unreachable("unhandled sequence type");
}
} // namespace Sequence
diff --git a/iree/compiler/Dialect/Sequence/IR/SequenceTypes.h b/iree/compiler/Dialect/Sequence/IR/SequenceTypes.h
index 3e98f08..21a0099 100644
--- a/iree/compiler/Dialect/Sequence/IR/SequenceTypes.h
+++ b/iree/compiler/Dialect/Sequence/IR/SequenceTypes.h
@@ -32,7 +32,6 @@
public:
using Base::Base;
- static bool kindof(unsigned kind) { return kind == TypeKind::Sequence; }
static SequenceType get(Type targetType);
static SequenceType getChecked(Type targetType, Location location);
diff --git a/iree/compiler/Dialect/Shape/IR/ShapeDialect.cpp b/iree/compiler/Dialect/Shape/IR/ShapeDialect.cpp
index 451bc1d..90d592a 100644
--- a/iree/compiler/Dialect/Shape/IR/ShapeDialect.cpp
+++ b/iree/compiler/Dialect/Shape/IR/ShapeDialect.cpp
@@ -164,13 +164,9 @@
}
void ShapeDialect::printType(Type type, DialectAsmPrinter& os) const {
- switch (type.getKind()) {
- case Shape::TypeKind::RankedShape:
- printRankedShape(type.cast<Shape::RankedShapeType>(), os);
- break;
- default:
- llvm_unreachable("unhandled Shape type");
- }
+ if (auto rankedShapeTy = type.dyn_cast<Shape::RankedShapeType>())
+    return printRankedShape(rankedShapeTy, os);
+ llvm_unreachable("unhandled Shape type");
}
} // namespace iree_compiler
diff --git a/iree/compiler/Dialect/Shape/IR/ShapeTypes.h b/iree/compiler/Dialect/Shape/IR/ShapeTypes.h
index 2946627..45ea8cf 100644
--- a/iree/compiler/Dialect/Shape/IR/ShapeTypes.h
+++ b/iree/compiler/Dialect/Shape/IR/ShapeTypes.h
@@ -39,9 +39,6 @@
public:
using Base::Base;
- /// Support method to enable LLVM-style type casting.
- static bool kindof(unsigned kind) { return kind == TypeKind::RankedShape; }
-
// Gets an instance of a RankedShapeType given an array of dimensions.
// Any dynamic dim should be -1.
static RankedShapeType get(ArrayRef<int64_t> dims, MLIRContext *context);
diff --git a/iree/compiler/Dialect/VM/Analysis/RegisterAllocation.cpp b/iree/compiler/Dialect/VM/Analysis/RegisterAllocation.cpp
index 36ba5cc..ce0962d 100644
--- a/iree/compiler/Dialect/VM/Analysis/RegisterAllocation.cpp
+++ b/iree/compiler/Dialect/VM/Analysis/RegisterAllocation.cpp
@@ -340,19 +340,18 @@
int indegree = 0;
int outdegree = 0;
};
- SmallVector<FASNode, 8> nodeStorage;
- llvm::SmallDenseMap<NodeID, FASNode *> nodes;
+  // This map should not be modified after this loop populates it: we take
+  // pointers to its entries and do not want reallocation to invalidate them.
+ llvm::SmallDenseMap<NodeID, FASNode> nodes;
for (auto &edge : inputEdges) {
NodeID sourceID = edge.first.asBaseRegister();
NodeID sinkID = edge.second.asBaseRegister();
assert(sourceID != sinkID && "self-cycles not supported");
if (nodes.count(sourceID) == 0) {
- nodeStorage.push_back({sourceID, 0, 0});
- nodes.insert({sourceID, &nodeStorage.back()});
+ nodes.insert({sourceID, {sourceID, 0, 0}});
}
if (nodes.count(sinkID) == 0) {
- nodeStorage.push_back({sinkID, 0, 0});
- nodes.insert({sinkID, &nodeStorage.back()});
+ nodes.insert({sinkID, {sinkID, 0, 0}});
}
}
@@ -366,13 +365,13 @@
for (auto &edge : inputEdges) {
NodeID sourceID = edge.first.asBaseRegister();
NodeID sinkID = edge.second.asBaseRegister();
- auto *sourceNode = nodes[sourceID];
- ++sourceNode->outdegree;
- maxOutdegree = std::max(maxOutdegree, sourceNode->outdegree);
- auto *sinkNode = nodes[sinkID];
- ++sinkNode->indegree;
- maxIndegree = std::max(maxIndegree, sinkNode->indegree);
- edges.push_back({sourceNode, sinkNode});
+ auto &sourceNode = nodes[sourceID];
+ ++sourceNode.outdegree;
+ maxOutdegree = std::max(maxOutdegree, sourceNode.outdegree);
+ auto &sinkNode = nodes[sinkID];
+ ++sinkNode.indegree;
+ maxIndegree = std::max(maxIndegree, sinkNode.indegree);
+ edges.push_back({&sourceNode, &sinkNode});
}
std::vector<SmallVector<FASNode *, 2>> buckets;
@@ -392,8 +391,10 @@
buckets[index].erase(it);
}
};
+ llvm::SmallPtrSet<FASNode *, 8> remainingNodes;
for (auto &nodeEntry : nodes) {
- assignBucket(nodeEntry.second);
+ assignBucket(&nodeEntry.getSecond());
+ remainingNodes.insert(&nodeEntry.getSecond());
}
auto removeNode = [&](FASNode *node) {
@@ -416,6 +417,7 @@
}
removeBucket(edge.source);
--edge.source->outdegree;
+ assert(edge.source->outdegree >= 0 && "outdegree has become negative");
assignBucket(edge.source);
}
for (auto &edge : outEdges) {
@@ -425,10 +427,11 @@
}
removeBucket(edge.sink);
--edge.sink->indegree;
+ assert(edge.sink->indegree >= 0 && "indegree has become negative");
assignBucket(edge.sink);
}
- nodes.erase(node->id);
+ remainingNodes.erase(node);
edges.erase(std::remove_if(edges.begin(), edges.end(),
[&](const FASEdge &edge) {
return edge.source == node ||
@@ -438,13 +441,13 @@
return results;
};
auto ends = buckets.back();
- while (!nodes.empty()) {
+ while (!remainingNodes.empty()) {
while (!ends.empty()) {
auto *node = ends.front();
ends.erase(ends.begin());
removeNode(node);
}
- if (nodes.empty()) break;
+ if (remainingNodes.empty()) break;
for (ssize_t i = buckets.size() - 1; i >= 0; --i) {
if (buckets[i].empty()) continue;
auto *bucket = buckets[i].front();
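Storing FASNode values directly in the SmallDenseMap (instead of pointers into a side vector) relies on the map not being touched once the build loop above finishes, since growth rehashes and relocates the values that edges and remainingNodes point at. A small self-contained sketch of that constraint, using a hypothetical Node type rather than the IREE ones:

  #include <vector>
  #include "llvm/ADT/DenseMap.h"
  #include "llvm/ADT/SmallPtrSet.h"

  struct Node { int id; int indegree; int outdegree; };

  void build(const std::vector<int> &ids) {
    llvm::SmallDenseMap<int, Node> nodes;
    for (int id : ids) nodes.insert({id, {id, 0, 0}});  // populate completely first
    llvm::SmallPtrSet<Node *, 8> live;
    for (auto &entry : nodes) live.insert(&entry.getSecond());
    // No insertions into `nodes` from here on: a rehash would relocate the
    // Node values and leave `live` full of dangling pointers.
  }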
diff --git a/iree/compiler/Dialect/VM/IR/VMDialect.cpp b/iree/compiler/Dialect/VM/IR/VMDialect.cpp
index 8c2c1db..ee0f5d6 100644
--- a/iree/compiler/Dialect/VM/IR/VMDialect.cpp
+++ b/iree/compiler/Dialect/VM/IR/VMDialect.cpp
@@ -246,26 +246,21 @@
}
void VMDialect::printType(Type type, DialectAsmPrinter &os) const {
- switch (type.getKind()) {
- case IREE::VM::TypeKind::Ref: {
- auto objectType = type.cast<IREE::VM::RefType>().getObjectType();
- if (auto listType = objectType.dyn_cast<IREE::VM::ListType>()) {
- printType(listType, os);
- } else if (objectType.isa<IREE::VM::OpaqueType>()) {
- os << "ref<?>";
- } else {
- os << "ref<" << objectType << ">";
- }
- break;
+ if (auto refType = type.dyn_cast<IREE::VM::RefType>()) {
+ auto objectType = refType.getObjectType();
+ if (auto listType = objectType.dyn_cast<IREE::VM::ListType>()) {
+ printType(listType, os);
+ } else if (objectType.isa<IREE::VM::OpaqueType>()) {
+ os << "ref<?>";
+ } else {
+ os << "ref<" << objectType << ">";
}
- case IREE::VM::TypeKind::Opaque:
- os << "opaque";
- break;
- case IREE::VM::TypeKind::List:
- os << "list<" << type.cast<IREE::VM::ListType>().getElementType() << ">";
- break;
- default:
- llvm_unreachable("unhandled VM type");
+ } else if (type.isa<IREE::VM::OpaqueType>()) {
+ os << "opaque";
+ } else if (auto listType = type.dyn_cast<IREE::VM::ListType>()) {
+ os << "list<" << listType.getElementType() << ">";
+ } else {
+ llvm_unreachable("unhandled VM type");
}
}
diff --git a/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp b/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp
index 3bd554c..819373f 100644
--- a/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp
+++ b/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp
@@ -46,26 +46,14 @@
/// Creates a constant one attribute matching the given type.
Attribute oneOfType(Type type) {
Builder builder(type.getContext());
- switch (type.getKind()) {
- case StandardTypes::BF16:
- case StandardTypes::F16:
- case StandardTypes::F32:
- case StandardTypes::F64:
- return builder.getFloatAttr(type, 1.0);
- case StandardTypes::Integer: {
- auto width = type.cast<IntegerType>().getWidth();
- if (width == 1) return builder.getBoolAttr(true);
- return builder.getIntegerAttr(type, APInt(width, 1));
- }
- case StandardTypes::Vector:
- case StandardTypes::RankedTensor: {
- auto vtType = type.cast<ShapedType>();
- auto element = oneOfType(vtType.getElementType());
- if (!element) return {};
- return DenseElementsAttr::get(vtType, element);
- }
- default:
- break;
+ if (type.isa<FloatType>()) return builder.getFloatAttr(type, 1.0);
+ if (auto integerTy = type.dyn_cast<IntegerType>())
+ return builder.getIntegerAttr(integerTy, APInt(integerTy.getWidth(), 1));
+ if (type.isa<RankedTensorType, VectorType>()) {
+ auto vtType = type.cast<ShapedType>();
+ auto element = oneOfType(vtType.getElementType());
+ if (!element) return {};
+ return DenseElementsAttr::get(vtType, element);
}
return {};
}
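The same rewrite shows up here: the float, integer, and shaped cases become isa/dyn_cast checks, with shaped types handled recursively through their element type. A usage sketch of what the helper returns, assuming an MLIRContext and the Builder APIs already used above (the call sites are hypothetical):

  void demoOnes(mlir::MLIRContext *context) {
    mlir::Builder b(context);
    mlir::Attribute intOne = oneOfType(b.getIntegerType(32));  // 1 : i32
    auto tensorTy = mlir::RankedTensorType::get({4}, b.getF32Type());
    mlir::Attribute splatOne = oneOfType(tensorTy);  // dense<1.0> : tensor<4xf32>
  }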
diff --git a/iree/compiler/Dialect/VM/IR/VMTypes.h b/iree/compiler/Dialect/VM/IR/VMTypes.h
index 67bc785..2a85ae6 100644
--- a/iree/compiler/Dialect/VM/IR/VMTypes.h
+++ b/iree/compiler/Dialect/VM/IR/VMTypes.h
@@ -64,8 +64,6 @@
}
Type getElementType();
-
- static bool kindof(unsigned kind) { return kind == TypeKind::List; }
};
/// An opaque ref object that comes from an external source.
@@ -73,8 +71,6 @@
public:
using Base::Base;
- static bool kindof(unsigned kind) { return kind == TypeKind::Opaque; }
-
static OpaqueType get(MLIRContext *context) {
return Base::get(context, TypeKind::Opaque);
}
@@ -106,8 +102,6 @@
}
Type getObjectType();
-
- static bool kindof(unsigned kind) { return kind == TypeKind::Ref; }
};
} // namespace VM
diff --git a/iree/compiler/Dialect/VMLA/IR/VMLATypes.h b/iree/compiler/Dialect/VMLA/IR/VMLATypes.h
index 634b903..0bebaf7 100644
--- a/iree/compiler/Dialect/VMLA/IR/VMLATypes.h
+++ b/iree/compiler/Dialect/VMLA/IR/VMLATypes.h
@@ -49,7 +49,6 @@
static BufferType get(MLIRContext *context) {
return Base::get(context, TypeKind::Buffer);
}
- static bool kindof(unsigned kind) { return kind == TypeKind::Buffer; }
};
class InterfaceType : public Type::TypeBase<InterfaceType, Type, TypeStorage> {
@@ -58,7 +57,6 @@
static InterfaceType get(MLIRContext *context) {
return Base::get(context, TypeKind::Interface);
}
- static bool kindof(unsigned kind) { return kind == TypeKind::Interface; }
};
} // namespace VMLA
diff --git a/iree/compiler/Dialect/Vulkan/IR/VulkanAttributes.h b/iree/compiler/Dialect/Vulkan/IR/VulkanAttributes.h
index 0ae3957..f9dbc38 100644
--- a/iree/compiler/Dialect/Vulkan/IR/VulkanAttributes.h
+++ b/iree/compiler/Dialect/Vulkan/IR/VulkanAttributes.h
@@ -70,8 +70,6 @@
/// bits.
CapabilitiesAttr getCapabilitiesAttr();
- static bool kindof(unsigned kind) { return kind == AttrKind::TargetEnv; }
-
static LogicalResult verifyConstructionInvariants(
Location loc, IntegerAttr version, IntegerAttr revision,
ArrayAttr extensions, DictionaryAttr capabilities);
diff --git a/iree/samples/custom_modules/dialect/custom_dialect.h b/iree/samples/custom_modules/dialect/custom_dialect.h
index 7a5c9b1..fe35ed4 100644
--- a/iree/samples/custom_modules/dialect/custom_dialect.h
+++ b/iree/samples/custom_modules/dialect/custom_dialect.h
@@ -47,7 +47,6 @@
static MessageType get(MLIRContext *context) {
return Base::get(context, TypeKind::Message);
}
- static bool kindof(unsigned kind) { return kind == TypeKind::Message; }
};
#define GET_OP_CLASSES
diff --git a/iree/test/e2e/vulkan_specific/BUILD b/iree/test/e2e/vulkan_specific/BUILD
index 481366f..9686098 100644
--- a/iree/test/e2e/vulkan_specific/BUILD
+++ b/iree/test/e2e/vulkan_specific/BUILD
@@ -27,7 +27,9 @@
name = "check_vulkan-spirv_vulkan",
srcs = glob(
["*.mlir"],
- exclude = ["gemm.mlir"],
+ exclude = [
+ "gemm.mlir",
+ ],
),
driver = "vulkan",
target_backend = "vulkan-spirv",
@@ -35,7 +37,10 @@
iree_check_single_backend_test_suite(
name = "check_vulkan-spirv_vulkan_wgmem",
- srcs = ["gemm.mlir"],
+ srcs = [
+ "conv.mlir",
+ "gemm.mlir",
+ ],
compiler_flags = ["-iree-spirv-use-workgroup-memory"],
driver = "vulkan",
target_backend = "vulkan-spirv",
diff --git a/iree/test/e2e/vulkan_specific/CMakeLists.txt b/iree/test/e2e/vulkan_specific/CMakeLists.txt
index edf9415..4260b07 100644
--- a/iree/test/e2e/vulkan_specific/CMakeLists.txt
+++ b/iree/test/e2e/vulkan_specific/CMakeLists.txt
@@ -32,6 +32,7 @@
NAME
check_vulkan-spirv_vulkan_wgmem
SRCS
+ "conv.mlir"
"gemm.mlir"
TARGET_BACKEND
vulkan-spirv
diff --git a/iree/test/e2e/vulkan_specific/conv.mlir b/iree/test/e2e/vulkan_specific/conv.mlir
new file mode 100644
index 0000000..e2dc59d
--- /dev/null
+++ b/iree/test/e2e/vulkan_specific/conv.mlir
@@ -0,0 +1,83 @@
+func @conv() attributes { iree.module.export } {
+ %0 = iree.unfoldable_constant dense<
+ [[[[0.5 , 0.5212766 ],
+ [0.54255319, 0.56382979],
+ [0.58510638, 0.60638298],
+ [0.62765957, 0.64893617],
+ [0.67021277, 0.69148936],
+ [0.71276596, 0.73404255]],
+
+ [[0.75531915, 0.77659574],
+ [0.79787234, 0.81914894],
+ [0.84042553, 0.86170213],
+ [0.88297872, 0.90425532],
+ [0.92553191, 0.94680851],
+ [0.96808511, 0.9893617 ]],
+
+ [[1.0106383 , 1.03191489],
+ [1.05319149, 1.07446809],
+ [1.09574468, 1.11702128],
+ [1.13829787, 1.15957447],
+ [1.18085106, 1.20212766],
+ [1.22340426, 1.24468085]],
+
+ [[1.26595745, 1.28723404],
+ [1.30851064, 1.32978723],
+ [1.35106383, 1.37234043],
+ [1.39361702, 1.41489362],
+ [1.43617021, 1.45744681],
+ [1.4787234 , 1.5 ]]]]> : tensor<1x4x6x2xf32>
+ %1 = iree.unfoldable_constant dense<
+ [[[[0.5 , 0.52857143, 0.55714286],
+ [0.58571429, 0.61428571, 0.64285714]],
+
+ [[0.67142857, 0.7 , 0.72857143],
+ [0.75714286, 0.78571429, 0.81428571]],
+
+ [[0.84285714, 0.87142857, 0.9 ],
+ [0.92857143, 0.95714286, 0.98571429]]],
+
+
+ [[[1.01428571, 1.04285714, 1.07142857],
+ [1.1 , 1.12857143, 1.15714286]],
+
+ [[1.18571429, 1.21428571, 1.24285714],
+ [1.27142857, 1.3 , 1.32857143]],
+
+ [[1.35714286, 1.38571429, 1.41428571],
+ [1.44285714, 1.47142857, 1.5 ]]]]>
+ : tensor<2x3x2x3xf32>
+ %2 = "mhlo.convolution"(%0, %1) {
+ batch_group_count = 1 : i64,
+ dimension_numbers = {
+ input_batch_dimension = 0 : i64,
+ input_feature_dimension = 3 : i64,
+ input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>,
+ kernel_input_feature_dimension = 2 : i64,
+ kernel_output_feature_dimension = 3 : i64,
+ kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>,
+ output_batch_dimension = 0 : i64,
+ output_feature_dimension = 3 : i64,
+ output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>},
+ feature_group_count = 1 : i64,
+ rhs_dilation = dense<1> : tensor<2xi64>,
+ window_strides = dense<1> : tensor<2xi64>}
+ : (tensor<1x4x6x2xf32>, tensor<2x3x2x3xf32>) -> (tensor<1x3x4x3xf32>)
+ check.expect_almost_eq_const(%2, dense<
+ [[[[ 8.39452888, 8.62796353, 8.86139818],
+ [ 8.89057751, 9.13860182, 9.38662614],
+ [ 9.38662614, 9.64924012, 9.9118541 ],
+ [ 9.88267477, 10.15987842, 10.43708207]],
+
+ [[11.37082067, 11.69179331, 12.01276596],
+ [11.8668693 , 12.20243161, 12.53799392],
+ [12.36291793, 12.71306991, 13.06322188],
+ [12.85896657, 13.22370821, 13.58844985]],
+
+ [[14.34711246, 14.7556231 , 15.16413374],
+ [14.84316109, 15.2662614 , 15.6893617 ],
+ [15.33920973, 15.7768997 , 16.21458967],
+ [15.83525836, 16.28753799, 16.73981763]]]]>
+ : tensor<1x3x4x3xf32>) : tensor<1x3x4x3xf32>
+ return
+}
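Per the BUILD changes above, conv.mlir is picked up by both check suites: the default glob in check_vulkan-spirv_vulkan and, explicitly, the workgroup-memory variant that compiles with -iree-spirv-use-workgroup-memory. Assuming the generated suite keeps the name attribute as its Bazel target (illustrative invocation; the exact target spelling may differ):

  bazel test //iree/test/e2e/vulkan_specific:check_vulkan-spirv_vulkan_wgmem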
diff --git a/third_party/llvm-project b/third_party/llvm-project
index 950f1bf..30c1633 160000
--- a/third_party/llvm-project
+++ b/third_party/llvm-project
@@ -1 +1 @@
-Subproject commit 950f1bf976b332eca60267b25bf759e2ad564e0c
+Subproject commit 30c1633386e7cfb01c0a54b31ccf4c3a3873e71b
diff --git a/third_party/mlir-emitc b/third_party/mlir-emitc
index 80885f8..a3479bb 160000
--- a/third_party/mlir-emitc
+++ b/third_party/mlir-emitc
@@ -1 +1 @@
-Subproject commit 80885f899e12d55a45561ef758eea47bb340dbf1
+Subproject commit a3479bbf9161df8c8cac55a08205864e6f371491
diff --git a/third_party/tensorflow b/third_party/tensorflow
index 8a4ffe2..86efb18 160000
--- a/third_party/tensorflow
+++ b/third_party/tensorflow
@@ -1 +1 @@
-Subproject commit 8a4ffe2e1ae722cff5306778df0cfca8b7f503fe
+Subproject commit 86efb18ca5812c76dd52c8536f336e6962b7f8ca