Adding iree_hal_device_queue_update and improving queue DMA operations. (#19000)
As with all queue DMA operations, it's best if things are batched into
command buffers, but it's bad to have a command buffer with a single DMA
operation; this completes the set of fill/update/copy operations at the
queue level to match the command buffer DMA operations. In practice, this
is useful when combined with reusable/indirect command buffers for
uploading new parameters in queue order prior to issuing a command
buffer that references them. The compiler will use this to turn push
constants into uniform buffers. An emulated version is added but
implementations are encouraged to do better... they currently don't.
While updating the queue API I've added placeholder flags to all DMA
operations in preparation for compiler updates that will provide them.
`iree_hal_device_queue_execute` has needed simplification for a while, and
that's done here so that implementations no longer need to worry about
batched command buffer juggling. The unused-since-its-inception
`iree_hal_command_buffer_discard_buffer` API has been renamed to
`iree_hal_command_buffer_advise_buffer` ahead of compiler changes that
will use it for multi-device cache management.
No breaking changes to the compiler here - future PRs will update the
HAL module and ops.
diff --git a/compiler/bindings/python/CMakeLists.txt b/compiler/bindings/python/CMakeLists.txt
index bc8119f..76caebb 100644
--- a/compiler/bindings/python/CMakeLists.txt
+++ b/compiler/bindings/python/CMakeLists.txt
@@ -249,11 +249,15 @@
SOURCES
__init__.py
__main__.py
+ args.py
+ compile_actions.py
executor.py
lang.py
main.py
+ metadata.py
net_actions.py
onnx_actions.py
+ target_machine.py
)
add_mlir_python_modules(IREECompilerBuildPythonModules
diff --git a/compiler/bindings/python/iree/build/__init__.py b/compiler/bindings/python/iree/build/__init__.py
index 95ee3f7..3cc054e 100644
--- a/compiler/bindings/python/iree/build/__init__.py
+++ b/compiler/bindings/python/iree/build/__init__.py
@@ -8,5 +8,7 @@
from iree.build.lang import *
from iree.build.main import *
+
+from iree.build.compile_actions import *
from iree.build.net_actions import *
from iree.build.onnx_actions import *
diff --git a/compiler/bindings/python/iree/build/args.py b/compiler/bindings/python/iree/build/args.py
new file mode 100644
index 0000000..9a2482a
--- /dev/null
+++ b/compiler/bindings/python/iree/build/args.py
@@ -0,0 +1,228 @@
+# Copyright 2024 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+"""Command line argument plumbing shared by `iree.build` actions and entry-points.
+
+Global options are contributed with `register_arg_parser_callback()` and
+processed with `register_arg_handler_callback()`. Functions reference parsed
+values via `cl_arg()` / `cl_arg_ref()` default values, which
+`expand_cl_arg_defaults()` resolves against the namespace installed on the
+current thread by `argument_namespace_context()`.
+"""
+
+import argparse
+import contextlib
+import functools
+import inspect
+import threading
+
+from typing import Callable, Generator, TypeVar
+
+# Per-thread stack of active argument namespaces (see `argument_namespace_context`).
+_locals = threading.local()
+_ALL_ARG_REGISTRARS: list[Callable[[argparse.ArgumentParser], None]] = []
+_ALL_ARG_HANDLERS: list[Callable[[argparse.Namespace], None]] = []
+
+
+def register_arg_parser_callback(registrar: Callable[[argparse.ArgumentParser], None]):
+    """Decorator that adds a global argument registration callback.
+
+    The callback is invoked with each parser passed to `configure_arg_parser()`.
+    The registrar is returned unchanged so this can be used as a decorator.
+    """
+    _ALL_ARG_REGISTRARS.append(registrar)
+    return registrar
+
+
+def register_arg_handler_callback(handler: Callable[[argparse.Namespace], None]):
+    """Decorator that registers a handler to be run on global arguments at startup.
+
+    Handlers are invoked in registration order by `run_global_arg_handlers()`.
+    """
+    _ALL_ARG_HANDLERS.append(handler)
+    return handler
+
+
+def configure_arg_parser(p: argparse.ArgumentParser):
+    """Invokes all callbacks from `register_arg_parser_callback` on the parser."""
+    for callback in _ALL_ARG_REGISTRARS:
+        callback(p)
+
+
+def run_global_arg_handlers(ns: argparse.Namespace):
+    """Invokes all global argument handlers (in registration order)."""
+    for h in _ALL_ARG_HANDLERS:
+        h(ns)
+
+
+@contextlib.contextmanager
+def argument_namespace_context(ns: argparse.Namespace):
+    """Establishes the given namespace as the current namespace for this thread.
+
+    Contexts nest: the namespace is pushed on entry and popped on exit, so the
+    innermost active context is the one `current_args_namespace()` returns.
+
+    Note that as a thread local, this does not propagate to child threads or
+    sub-processes. This means that all argument management must be done during
+    action setup and action invocations will not typically have access to args.
+    """
+    if not hasattr(_locals, "arg_ns_stack"):
+        _locals.arg_ns_stack = []
+    _locals.arg_ns_stack.append(ns)
+    try:
+        yield ns
+    finally:
+        _locals.arg_ns_stack.pop()
+
+
+def current_args_namespace() -> argparse.Namespace:
+    """Returns the innermost namespace from `argument_namespace_context()`.
+
+    Raises:
+      AssertionError: If no namespace is active on the calling thread.
+    """
+    try:
+        return _locals.arg_ns_stack[-1]
+    except (AttributeError, IndexError):
+        # The internal lookup failure is noise; suppress it as the context.
+        raise AssertionError(
+            "No current argument namespace: Is it possible you are trying to resolve "
+            "arguments from another thread or process"
+        ) from None
+
+
+_Decorated = TypeVar("_Decorated", bound=Callable)
+
+
+def expand_cl_arg_defaults(wrapped: _Decorated) -> _Decorated:
+    """Decorator that substitutes `ClArgRef` argument values at call time.
+
+    On each call, every bound argument (explicit or defaulted) whose value is a
+    `ClArgRef` is replaced by the value it resolves to in
+    `current_args_namespace()`. The decorated function must therefore be called
+    while an `argument_namespace_context()` is active on the calling thread.
+    """
+    sig = inspect.signature(wrapped)
+
+    @functools.wraps(wrapped)
+    def wrapper(*args, **kwargs):
+        args_ns = current_args_namespace()
+        bound = sig.bind(*args, **kwargs)
+        bound.apply_defaults()
+
+        def resolve(value):
+            if isinstance(value, ClArgRef):
+                return value.resolve(args_ns)
+            return value
+
+        new_args = [resolve(arg) for arg in bound.args]
+        new_kwargs = {k: resolve(v) for k, v in bound.kwargs.items()}
+        return wrapped(*new_args, **new_kwargs)
+
+    return wrapper
+
+
+class ClArgRef:
+    """Used in default values of function arguments to indicate that the default should
+    be derived from an argument reference.
+
+    Actually defining the argument must be done elsewhere.
+
+    See `cl_arg_ref()` for canonical use.
+    """
+
+    def __init__(self, dest: str):
+        # Attribute name to look up on the parsed argparse.Namespace.
+        self.dest = dest
+
+    def resolve(self, arg_namespace: argparse.Namespace):
+        """Returns the `dest` attribute of `arg_namespace`.
+
+        Raises:
+          RuntimeError: If the namespace has no such attribute.
+        """
+        try:
+            return getattr(arg_namespace, self.dest)
+        except AttributeError as e:
+            raise RuntimeError(
+                f"Unable to resolve command line argument '{self.dest}' in namespace"
+            ) from e
+
+
+def cl_arg_ref(dest: str):
+    """Used as a default value for functions wrapped in @expand_cl_arg_defaults to
+    indicate that an argument must come from the command line environment.
+
+    Note that this does not have a typing annotation, allowing the argument to be
+    annotated with a type, assuming that resolution will happen dynamically in some
+    fashion.
+    """
+    return ClArgRef(dest)
+
+
+class ClArg(ClArgRef):
+    """Used in default values of function arguments to indicate that an argument needs
+    to be defined and referenced.
+
+    This is used in user-defined entry points, and the executor has special logic to
+    collect all needed arguments automatically.
+
+    See `cl_arg()` for canonical use.
+    """
+
+    def __init__(self, name: str, dest: str, **add_argument_kw):
+        super().__init__(dest)
+        # Flag name without leading dashes; defined as `--{name}`.
+        self.name = name
+        # Remaining keywords are forwarded verbatim to `add_argument()`.
+        self.add_argument_kw = add_argument_kw
+
+    def define_arg(self, parser: argparse.ArgumentParser):
+        """Adds the `--{name}` option to `parser` (or to an argument group)."""
+        parser.add_argument(f"--{self.name}", dest=self.dest, **self.add_argument_kw)
+
+
+def cl_arg(name: str, *, action=None, default=None, type=None, help=None):
+    """Used to define or reference a command-line argument from within actions
+    and entry-points.
+
+    Keywords have the same interpretation as `ArgumentParser.add_argument()`.
+
+    Any ClArg set as a default value for an argument to an `entrypoint` will be
+    added to the global argument parser. Any particular argument name can only be
+    registered once and must not conflict with a built-in command line option.
+    The implication of this is that for single-use arguments, the `=cl_arg(...)`
+    can just be added as a default argument. Otherwise, for shared arguments,
+    it should be created at the module level and referenced.
+
+    When called, any entrypoint arguments that do not have an explicit keyword
+    set will get their value from the command line environment.
+
+    Note that this does not have a typing annotation, allowing the argument to be
+    annotated with a type, assuming that resolution will happen dynamically in some
+    fashion.
+
+    Raises:
+      ValueError: If `name` starts with a dash.
+    """
+    if name.startswith("-"):
+        raise ValueError("cl_arg name must not be prefixed with dashes")
+    # Mirrors argparse's default dest derivation ("foo-bar" -> "foo_bar").
+    dest = name.replace("-", "_")
+    return ClArg(name, action=action, default=default, type=type, dest=dest, help=help)
+
+
+def extract_cl_arg_defs(callable: Callable) -> Generator[ClArg, None, None]:
+    """Extracts all `ClArg` default values from a callable.
+
+    This is used in order to eagerly register argument definitions for some set
+    of functions.
+    """
+    sig = inspect.signature(callable)
+    for p in sig.parameters.values():
+        def_value = p.default
+        if isinstance(def_value, ClArg):
+            yield def_value
diff --git a/compiler/bindings/python/iree/build/compile_actions.py b/compiler/bindings/python/iree/build/compile_actions.py
new file mode 100644
index 0000000..2c2056c
--- /dev/null
+++ b/compiler/bindings/python/iree/build/compile_actions.py
@@ -0,0 +1,241 @@
+# Copyright 2024 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+"""Build actions that invoke `iree-compile` on source files."""
+
+import argparse
+import shlex
+
+import iree.compiler.api as compiler_api
+import iree.compiler.tools as compiler_tools
+
+from iree.build.args import (
+    expand_cl_arg_defaults,
+    register_arg_handler_callback,
+    register_arg_parser_callback,
+    cl_arg_ref,
+)
+
+from iree.build.executor import (
+    BuildAction,
+    BuildContext,
+    BuildFile,
+    BuildFileLike,
+    FileNamespace,
+)
+
+from iree.build.metadata import CompileSourceMeta
+from iree.build.target_machine import compute_target_machines_from_flags
+
+__all__ = [
+    "compile",
+]
+
+
+@register_arg_parser_callback
+def _(p: argparse.ArgumentParser):
+    # Defines the global flags read by `CompilerInvocation` and the handler below.
+    g = p.add_argument_group(
+        title="IREE Compiler Options",
+        description="Global options controlling invocation of iree-compile",
+    )
+    g.add_argument(
+        "--iree-compile-out-of-process",
+        action=argparse.BooleanOptionalAction,
+        help="Invokes iree-compiler as an out of process executable (the default is to "
+        "invoke it in-process via API bindings). This can make debugging somewhat "
+        "easier and also grants access to global command line options that may not "
+        "otherwise be available.",
+    )
+    g.add_argument(
+        "--iree-compile-extra-args",
+        help="Extra arguments to pass to iree-compile. When running in-process, these "
+        "will be passed as globals to the library and affect all compilation in the "
+        "process. These are split with shlex rules.",
+    )
+
+
+@register_arg_handler_callback
+def _(ns: argparse.Namespace):
+    # In-process, extra args are installed once as global compiler options
+    # rather than per invocation (out of process they are per invocation).
+    in_process = not ns.iree_compile_out_of_process
+    extra_args_str = ns.iree_compile_extra_args
+    if in_process and extra_args_str:
+        # TODO: This is very unsafe. If called multiple times (i.e. in a library)
+        # or with illegal arguments, the program will abort. The safe way to do
+        # this is to spawn one child process and route all "in process" compilation
+        # there. This would allow explicit control of startup/shutdown and would
+        # provide isolation in the event of a compiler crash. It is still important
+        # for a single process to handle all compilation activities since this
+        # allows global compiler resources (like threads) to be pooled and not
+        # saturate the machine resources.
+        extra_args_list = shlex.split(extra_args_str)
+        compiler_api._initializeGlobalCL("unused_prog_name", *extra_args_list)
+
+
+class CompilerInvocation:
+    """A single compilation of `input_file` to `output_file`.
+
+    Callers may add flags via `kw_flags` and `extra_flags` before calling
+    `run()`, which compiles either in-process or out of process depending on
+    `--iree-compile-out-of-process`.
+    """
+
+    @expand_cl_arg_defaults
+    def __init__(
+        self,
+        *,
+        input_file: BuildFile,
+        output_file: BuildFile,
+        out_of_process: bool = cl_arg_ref("iree_compile_out_of_process"),
+        extra_args_str=cl_arg_ref("iree_compile_extra_args"),
+    ):
+        self.input_file = input_file
+        self.output_file = output_file
+        # We manage most flags as keyword values that can have at most one
+        # setting.
+        self.kw_flags: dict[str, str | None] = {}
+        # Flags can also be set free-form. These are always added to the command
+        # line after the kw_flags.
+        self.extra_flags: list[str] = []
+        self.out_of_process = out_of_process
+
+        # Only used out of process. In-process, these same args are installed
+        # globally by the arg handler registered in this module.
+        if extra_args_str:
+            self.extra_args = shlex.split(extra_args_str)
+        else:
+            self.extra_args = []
+
+    def run(self):
+        """Assembles the flag list and runs the compilation."""
+        raw_flags: list[str] = []
+
+        # Set any defaults derived from the input_file metadata. These are set
+        # first because they can be overridden by explicit flag settings.
+        meta = CompileSourceMeta.get(self.input_file)
+        raw_flags.append(f"--iree-input-type={meta.input_type}")
+
+        # Process kw_flags. A value of None yields a bare `--key` flag.
+        for key, value in self.kw_flags.items():
+            if value is None:
+                raw_flags.append(f"--{key}")
+            else:
+                raw_flags.append(f"--{key}={value}")
+
+        # Process extra_flags.
+        raw_flags.extend(self.extra_flags)
+
+        if self.out_of_process:
+            self.run_out_of_process(raw_flags)
+        else:
+            self.run_inprocess(raw_flags)
+
+    def run_inprocess(self, flags: list[str]):
+        """Compiles through the in-process compiler API bindings.
+
+        Raises:
+          RuntimeError: If the compiler invocation fails.
+        """
+        with compiler_api.Session() as session:
+            session.set_flags(*flags)
+            with compiler_api.Invocation(session) as inv, compiler_api.Source.open_file(
+                session, str(self.input_file.get_fs_path())
+            ) as source, compiler_api.Output.open_file(
+                str(self.output_file.get_fs_path())
+            ) as output:
+                inv.enable_console_diagnostics()
+                inv.parse_source(source)
+                if not inv.execute():
+                    raise RuntimeError("COMPILE FAILED (TODO)")
+                inv.output_vm_bytecode(output)
+                output.keep()
+
+    def run_out_of_process(self, flags: list[str]):
+        """Compiles through the `iree.compiler.tools` wrapper.
+
+        The global `--iree-compile-extra-args` are passed before `flags`.
+        """
+        # TODO: This Python executable wrapper is really long in the tooth. We should
+        # just invoke iree-compile directly (which would also let us have a flag for
+        # the path to it).
+        all_extra_args = self.extra_args + flags
+        compiler_tools.compile_file(
+            str(self.input_file.get_fs_path()),
+            output_file=str(self.output_file.get_fs_path()),
+            extra_args=all_extra_args,
+        )
+
+
+def compile(
+    *,
+    name: str,
+    source: BuildFileLike,
+    target_default: bool = True,
+) -> list[BuildFile] | BuildFile:
+    """Invokes iree-compile on a source file, producing binaries for one or more target
+    machines.
+
+    Args:
+      name: The logical name of the compilation command. This is used as the stem
+        for multiple kinds of output files.
+      source: Input source file.
+      target_default: Whether to use command line arguments to compute a target
+        machine configuration (default True). This would be set to False to explicitly
+        depend on target information contained in the source file and not require
+        any target flags passed to the build tool.
+
+    Returns:
+      One `.vmfb` BuildFile per target machine (as a list) when `target_default`
+      is True; otherwise a single `.vmfb` BuildFile.
+    """
+    context = BuildContext.current()
+    input_file = context.file(source)
+    if target_default:
+        # Compute the target machines from flags and create one compilation for each.
+        tms = compute_target_machines_from_flags()
+        output_files: list[BuildFile] = []
+        for tm in tms:
+            output_file = context.allocate_file(
+                f"{name}_{tm.target_spec}.vmfb", namespace=FileNamespace.BIN
+            )
+            inv = CompilerInvocation(input_file=input_file, output_file=output_file)
+            inv.extra_flags.extend(tm.flag_list)
+            CompileAction(
+                inv,
+                desc=f"Compiling {input_file} (for {tm.target_spec})",
+                executor=context.executor,
+            )
+            output_files.append(output_file)
+        return output_files
+    else:
+        # The compilation is self contained, so just directly compile it.
+        output_file = context.allocate_file(f"{name}.vmfb", namespace=FileNamespace.BIN)
+        inv = CompilerInvocation(input_file=input_file, output_file=output_file)
+        CompileAction(
+            inv,
+            desc=f"Compiling {name}",
+            executor=context.executor,
+        )
+        return output_file
+
+
+class CompileAction(BuildAction):
+    """Build action that runs a `CompilerInvocation`.
+
+    The action becomes a dependency of the invocation's output file, and the
+    invocation's input file becomes a dependency of the action.
+    """
+
+    def __init__(self, inv: CompilerInvocation, **kwargs):
+        super().__init__(**kwargs)
+        self.inv = inv
+        self.inv.output_file.deps.add(self)
+        self.deps.add(self.inv.input_file)
+
+    def _invoke(self):
+        self.inv.run()
diff --git a/compiler/bindings/python/iree/build/executor.py b/compiler/bindings/python/iree/build/executor.py
index 35b9173..1d580bd 100644
--- a/compiler/bindings/python/iree/build/executor.py
+++ b/compiler/bindings/python/iree/build/executor.py
@@ -4,20 +4,23 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-from typing import Callable, Collection, Generator, IO
+from typing import Callable, Collection, IO, Type, TypeVar
import abc
-import argparse
import concurrent.futures
import enum
-import inspect
import multiprocessing
-import sys
import time
import traceback
from pathlib import Path
import threading
+from iree.build.args import (
+ current_args_namespace,
+ expand_cl_arg_defaults,
+ extract_cl_arg_defs,
+)
+
_locals = threading.local()
@@ -52,24 +55,6 @@
return f"{prefix}/{suffix}"
-class ClArg:
- def __init__(self, name, dest: str, **add_argument_kw):
- self.name = name
- self.dest = dest
- self.add_argument_kw = add_argument_kw
-
- def define_arg(self, parser: argparse.ArgumentParser):
- parser.add_argument(f"--{self.name}", dest=self.dest, **self.add_argument_kw)
-
- def resolve(self, arg_namespace: argparse.Namespace):
- try:
- return getattr(arg_namespace, self.dest)
- except AttributeError as e:
- raise RuntimeError(
- f"Unable to resolve command line argument '{self.dest}' in namespace"
- ) from e
-
-
class Entrypoint:
def __init__(
self,
@@ -79,17 +64,12 @@
):
self.name = name
self.description = description
- self._wrapped = wrapped
-
- def cl_args(self) -> Generator[ClArg, None, None]:
- sig = inspect.signature(self._wrapped)
- for p in sig.parameters.values():
- def_value = p.default
- if isinstance(def_value, ClArg):
- yield def_value
+ self.cl_arg_defs = list(extract_cl_arg_defs(wrapped))
+ self._wrapped = expand_cl_arg_defaults(wrapped)
def __call__(self, *args, **kwargs):
parent_context = BuildContext.current()
+ args_ns = current_args_namespace()
bep = BuildEntrypoint(
join_namespace(parent_context.path, self.name),
parent_context.executor,
@@ -97,18 +77,7 @@
)
parent_context.executor.entrypoints.append(bep)
with bep:
- sig = inspect.signature(self._wrapped)
- bound = sig.bind(*args, **kwargs)
- bound.apply_defaults()
-
- def filter(arg):
- if isinstance(arg, ClArg):
- return arg.resolve(parent_context.executor.args_namespace)
- return arg
-
- new_args = [filter(arg) for arg in bound.args]
- new_kwargs = {k: filter(v) for k, v in bound.kwargs.items()}
- results = self._wrapped(*new_args, **new_kwargs)
+ results = self._wrapped(*args, **kwargs)
if results is not None:
files = bep.files(results)
bep.deps.update(files)
@@ -119,15 +88,12 @@
class Executor:
"""Executor that all build contexts share."""
- def __init__(
- self, output_dir: Path, args_namespace: argparse.Namespace, stderr: IO
- ):
+ def __init__(self, output_dir: Path, stderr: IO):
self.output_dir = output_dir
self.verbose_level = 0
# Keyed by path
self.all: dict[str, "BuildContext" | "BuildFile"] = {}
self.entrypoints: list["BuildEntrypoint"] = []
- self.args_namespace = args_namespace
self.stderr = stderr
BuildContext("", self)
@@ -189,6 +155,46 @@
scheduler.shutdown()
+BuildMetaType = TypeVar("BuildMetaType", bound="BuildMeta")
+
+
+class BuildMeta:
+    """Base class for typed metadata that can be set on a BuildDependency.
+
+    This is an open namespace where each sub-class must have a unique key as the class
+    level attribute `KEY`. Instances are created and attached to a dependency
+    lazily by `get()`.
+    """
+
+    # Declared (empty) so that `__slots__` on sub-classes can actually prevent
+    # setting undeclared attributes; otherwise every instance has a `__dict__`.
+    __slots__ = ()
+
+    def __init__(self):
+        key = getattr(self, "KEY", None)
+        assert isinstance(key, str), "BuildMeta.KEY must be a str"
+
+    @classmethod
+    def get(cls: Type[BuildMetaType], dep: "BuildDependency") -> BuildMetaType:
+        """Gets a metadata instance of this type from a dependency.
+
+        If it does not yet exist, one is created with `create_default()` (by
+        default a new instance), set on the dep, and returned.
+        """
+        key = getattr(cls, "KEY", None)
+        assert isinstance(key, str), f"{cls.__name__}.KEY must be a str"
+        instance = dep._metadata.get(key)
+        if instance is None:
+            instance = cls.create_default()
+            dep._metadata[key] = instance
+        return instance
+
+    @classmethod
+    def create_default(cls) -> "BuildMeta":
+        """Creates a default instance (by default, `cls()`)."""
+        return cls()
+
+
class BuildDependency:
"""Base class of entities that can act as a build dependency."""
@@ -205,6 +206,9 @@
self.start_time: float | None = None
self.finish_time: float | None = None
+ # Metadata.
+ self._metadata: dict[str, BuildMeta] = {}
+
@property
def is_scheduled(self) -> bool:
return self.future is not None
@@ -287,8 +291,11 @@
def __repr__(self):
return f"Action[{type(self).__name__}]('{self.desc}')"
- @abc.abstractmethod
def invoke(self):
+ self._invoke()
+
+ @abc.abstractmethod
+ def _invoke(self):
...
@@ -317,7 +324,8 @@
"""
if not path.startswith("/"):
path = join_namespace(self.path, path)
- return BuildFile(executor=self.executor, path=path, namespace=namespace)
+ build_file = BuildFile(executor=self.executor, path=path, namespace=namespace)
+ return build_file
def file(self, file: str | BuildFile) -> BuildFile:
"""Accesses a BuildFile by either string (path) or BuildFile.
@@ -374,9 +382,6 @@
existing = stack.pop()
assert existing is self, "Unbalanced BuildContext enter/exit"
- def populate_arg_parser(self, parser: argparse.ArgumentParser):
- ...
-
class BuildEntrypoint(BuildContext):
def __init__(self, path: str, executor: Executor, entrypoint: Entrypoint):
@@ -506,7 +511,16 @@
return dep
print(f"Scheduling action: {dep}", file=self.stderr)
- dep.start(self.thread_pool_executor.submit(invoke))
+ if dep.concurrnecy == ActionConcurrency.NONE:
+ invoke()
+ elif dep.concurrnecy == ActionConcurrency.THREAD:
+ dep.start(self.thread_pool_executor.submit(invoke))
+ elif dep.concurrnecy == ActionConcurrency.PROCESS:
+ dep.start(self.process_pool_executor.submit(invoke))
+ else:
+ raise AssertionError(
+ f"Unhandled ActionConcurrency value: {dep.concurrnecy}"
+ )
else:
# Not schedulable. Just mark it as done.
dep.start(concurrent.futures.Future())
diff --git a/compiler/bindings/python/iree/build/lang.py b/compiler/bindings/python/iree/build/lang.py
index 5cb8779..60f4c9f 100644
--- a/compiler/bindings/python/iree/build/lang.py
+++ b/compiler/bindings/python/iree/build/lang.py
@@ -9,7 +9,8 @@
import argparse
import functools
-from iree.build.executor import ClArg, Entrypoint
+from iree.build.args import cl_arg # Export as part of the public API
+from iree.build.executor import Entrypoint
__all__ = [
"cl_arg",
@@ -28,25 +29,3 @@
target = Entrypoint(f.__name__, f, description=description)
functools.wraps(target, f)
return target
-
-
-def cl_arg(name: str, *, action=None, default=None, type=None, help=None):
- """Used to define or reference a command-line argument from within actions
- and entry-points.
-
- Keywords have the same interpretation as `ArgumentParser.add_argument()`.
-
- Any ClArg set as a default value for an argument to an `entrypoint` will be
- added to the global argument parser. Any particular argument name can only be
- registered once and must not conflict with a built-in command line option.
- The implication of this is that for single-use arguments, the `=cl_arg(...)`
- can just be added as a default argument. Otherwise, for shared arguments,
- it should be created at the module level and referenced.
-
- When called, any entrypoint arguments that do not have an explicit keyword
- set will get their value from the command line environment.
- """
- if name.startswith("-"):
- raise ValueError("cl_arg name must not be prefixed with dashes")
- dest = name.replace("-", "_")
- return ClArg(name, action=action, default=default, type=type, dest=dest, help=help)
diff --git a/compiler/bindings/python/iree/build/main.py b/compiler/bindings/python/iree/build/main.py
index 272533c..6d40d18 100644
--- a/compiler/bindings/python/iree/build/main.py
+++ b/compiler/bindings/python/iree/build/main.py
@@ -12,7 +12,12 @@
from pathlib import Path
import sys
-from iree.build.executor import BuildEntrypoint, Entrypoint, Executor
+from iree.build.args import (
+ argument_namespace_context,
+ configure_arg_parser,
+ run_global_arg_handlers,
+)
+from iree.build.executor import BuildEntrypoint, BuildFile, Entrypoint, Executor
__all__ = [
"iree_build_main",
@@ -117,6 +122,7 @@
help="Paths of actions to build (default to top-level actions)",
)
+ configure_arg_parser(p)
self._define_action_arguments(p)
self.args = self.arg_parser.parse_args(args)
@@ -126,7 +132,7 @@
def _define_action_arguments(self, p: argparse.ArgumentParser):
user_group = p.add_argument_group("Action defined options")
for ep in self.top_module.entrypoints.values():
- for cl_arg in ep.cl_args():
+ for cl_arg in ep.cl_arg_defs:
cl_arg.define_arg(user_group)
def _resolve_module_arguments(
@@ -169,15 +175,17 @@
return rem_args, top_module
def _create_executor(self) -> Executor:
- executor = Executor(self.args.output_dir, self.args, stderr=self.stderr)
+ executor = Executor(self.args.output_dir, stderr=self.stderr)
executor.analyze(*self.top_module.entrypoints.values())
return executor
def run(self):
- command = self.args.command
- if command is None:
- command = self.build_command
- command()
+ with argument_namespace_context(self.args):
+ run_global_arg_handlers(self.args)
+ command = self.args.command
+ if command is None:
+ command = self.build_command
+ command()
def build_command(self):
executor = self._create_executor()
@@ -203,7 +211,9 @@
for build_action in build_actions:
if isinstance(build_action, BuildEntrypoint):
for output in build_action.outputs:
- print(f"{output.get_fs_path()}", file=self.stdout)
+ print(output.get_fs_path(), file=self.stdout)
+ elif isinstance(build_action, BuildFile):
+ print(build_action.get_fs_path(), file=self.stdout)
def list_command(self):
executor = self._create_executor()
diff --git a/compiler/bindings/python/iree/build/metadata.py b/compiler/bindings/python/iree/build/metadata.py
new file mode 100644
index 0000000..91767dd
--- /dev/null
+++ b/compiler/bindings/python/iree/build/metadata.py
@@ -0,0 +1,39 @@
+# Copyright 2024 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+"""Common `BuildMeta` subclasses for built-in actions.
+
+These are maintained here purely as an aid to avoiding circular dependencies.
+Typically, in out of tree actions, they would just be inlined into the implementation
+file.
+"""
+
+from iree.build.executor import BuildMeta
+
+
+class CompileSourceMeta(BuildMeta):
+    """CompileSourceMeta tracks source level properties that can influence compilation.
+
+    This meta can be set on any dependency that ultimately is used as a source to a
+    `compile` action.
+    """
+
+    # Slots are intended to catch attempts to set undefined attributes. Note that
+    # this only takes effect if every base class (including `BuildMeta`) also
+    # declares `__slots__`; otherwise instances still get a `__dict__`.
+    __slots__ = [
+        "input_type",
+    ]
+    KEY = "iree.compile.source"
+
+    def __init__(self):
+        super().__init__()
+
+        # The value to the --iree-input-type= flag for this source file.
+        self.input_type: str = "auto"
+
+    def __repr__(self):
+        return f"CompileSourceMeta(input_type={self.input_type})"
diff --git a/compiler/bindings/python/iree/build/net_actions.py b/compiler/bindings/python/iree/build/net_actions.py
index 1d2e158..da74d9a 100644
--- a/compiler/bindings/python/iree/build/net_actions.py
+++ b/compiler/bindings/python/iree/build/net_actions.py
@@ -30,7 +30,7 @@
self.url = url
self.output_file = output_file
- def invoke(self):
+ def _invoke(self):
path = self.output_file.get_fs_path()
self.executor.write_status(f"Fetching URL: {self.url} -> {path}")
try:
diff --git a/compiler/bindings/python/iree/build/onnx_actions.py b/compiler/bindings/python/iree/build/onnx_actions.py
index 5220890..83ec3a1 100644
--- a/compiler/bindings/python/iree/build/onnx_actions.py
+++ b/compiler/bindings/python/iree/build/onnx_actions.py
@@ -5,6 +5,7 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from iree.build.executor import BuildAction, BuildContext, BuildFile, BuildFileLike
+from iree.build.metadata import CompileSourceMeta
__all__ = [
"onnx_import",
@@ -30,7 +31,7 @@
input_file=input_file,
output_file=processed_file,
executor=context.executor,
- desc=f"Upgrading ONNX {name}",
+ desc=f"Upgrading ONNX {input_file} -> {processed_file}",
deps=[
input_file,
],
@@ -41,14 +42,13 @@
ImportOnnxAction(
input_file=input_file,
output_file=output_file,
- desc=f"Importing ONNX {name}",
+ desc=f"Importing ONNX {name} -> {output_file}",
executor=context.executor,
deps=[
input_file,
],
)
- output_file.deps.add(processed_file)
return output_file
@@ -57,9 +57,11 @@
super().__init__(**kwargs)
self.input_file = input_file
self.output_file = output_file
+ self.deps.add(self.input_file)
output_file.deps.add(self)
+ CompileSourceMeta.get(output_file).input_type = "onnx"
- def invoke(self):
+ def _invoke(self):
import onnx
input_path = self.input_file.get_fs_path()
@@ -75,9 +77,11 @@
super().__init__(**kwargs)
self.input_file = input_file
self.output_file = output_file
+ self.deps.add(input_file)
output_file.deps.add(self)
+ CompileSourceMeta.get(output_file).input_type = "onnx"
- def invoke(self):
+ def _invoke(self):
import iree.compiler.tools.import_onnx.__main__ as m
args = m.parse_arguments(
diff --git a/compiler/bindings/python/iree/build/target_machine.py b/compiler/bindings/python/iree/build/target_machine.py
new file mode 100644
index 0000000..c8eca7d
--- /dev/null
+++ b/compiler/bindings/python/iree/build/target_machine.py
@@ -0,0 +1,231 @@
+# Copyright 2024 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+"""Handles the messy affair of deriving options for targeting machines."""
+
+import argparse
+
+from iree.build.args import (
+    expand_cl_arg_defaults,
+    register_arg_parser_callback,
+    cl_arg_ref,
+)
+
+
+class TargetMachine:
+    """A named compilation target and the iree-compile flags that select it.
+
+    Args:
+      target_spec: Short identifier for the target (`compile` uses it in output
+        file names).
+      iree_compile_device_type: Value passed as `--iree-hal-target-device`.
+      extra_flags: Additional iree-compile flags appended after the device flag.
+    """
+
+    def __init__(
+        self,
+        target_spec: str,
+        *,
+        iree_compile_device_type: str | None = None,
+        extra_flags: list[str] | None = None,
+    ):
+        self.target_spec = target_spec
+        self.iree_compile_device_type = iree_compile_device_type
+        self.extra_flags = extra_flags
+
+    @property
+    def flag_list(self) -> list[str]:
+        """The iree-compile flags selecting this target.
+
+        Raises:
+          RuntimeError: If `iree_compile_device_type` is not set.
+        """
+        if self.iree_compile_device_type is not None:
+            # This is just a hard-coded machine model using a single IREE device
+            # type alias in the default configuration.
+            return [f"--iree-hal-target-device={self.iree_compile_device_type}"] + (
+                self.extra_flags or []
+            )
+        raise RuntimeError(f"Cannot compute iree-compile flags for: {self}")
+
+    def __repr__(self):
+        r = f"TargetMachine({self.target_spec}, "
+        if self.iree_compile_device_type is not None:
+            r += f"iree_compile_device_type='{self.iree_compile_device_type}', "
+        if self.extra_flags:
+            r += f"extra_flags={self.extra_flags}, "
+        r += ")"
+        return r
+
+
+################################################################################
+# Handling of --iree-hal-target-device from flags
+################################################################################
+
+
+# Maps a --iree-hal-target-device mnemonic to a handler returning the
+# TargetMachines for it. Populated by `@handle_hal_target_devices_from_flags`.
+HAL_TARGET_DEVICES_FROM_FLAGS_HANDLERS = {}
+
+
+def handle_hal_target_devices_from_flags(*mnemonics: str):
+    """Decorator registering the decorated function as the handler for `mnemonics`.
+
+    Handlers are called with the mnemonic and return a list of TargetMachines.
+    """
+
+    def decorator(f):
+        for mn in mnemonics:
+            HAL_TARGET_DEVICES_FROM_FLAGS_HANDLERS[mn] = f
+        return f
+
+    return decorator
+
+
+def handle_unknown_hal_target_device(mnemonic: str) -> list[TargetMachine]:
+    """Fallback that passes an unrecognized mnemonic through to iree-compile."""
+    return [TargetMachine(mnemonic, iree_compile_device_type=mnemonic)]
+
+
+@handle_hal_target_devices_from_flags("amdgpu", "hip")
+@expand_cl_arg_defaults
+def amdgpu_hal_target_from_flags(
+    mnemonic: str, *, amdgpu_target=cl_arg_ref("iree_amdgpu_target")
+) -> list[TargetMachine]:
+    """Targets the AMDGPU chip given by `--iree-amdgpu-target`.
+
+    Raises:
+      RuntimeError: If no chip was specified.
+    """
+    if not amdgpu_target:
+        raise RuntimeError(
+            "No AMDGPU targets specified. Pass a chip to target as "
+            "--iree-amdgpu-target=gfx..."
+        )
+    return [
+        TargetMachine(
+            f"amdgpu-{amdgpu_target}",
+            iree_compile_device_type="amdgpu",
+            extra_flags=[f"--iree-hip-target={amdgpu_target}"],
+        )
+    ]
+
+
+@handle_hal_target_devices_from_flags("llvm-cpu", "cpu")
+@expand_cl_arg_defaults
+def cpu_hal_target_from_flags(
+    mnemonic: str,
+    *,
+    cpu=cl_arg_ref("iree_llvmcpu_target_cpu"),
+    features=cl_arg_ref("iree_llvmcpu_target_cpu_features"),
+) -> list[TargetMachine]:
+    """Targets the LLVM CPU backend, passing through the CPU/feature flags if set."""
+    extra_flags = []
+    if cpu:
+        extra_flags.append(f"--iree-llvmcpu-target-cpu={cpu}")
+    if features:
+        extra_flags.append(f"--iree-llvmcpu-target-cpu-features={features}")
+
+    # The target spec only names the CPU; features are not encoded in it.
+    return [
+        TargetMachine(
+            f"cpu-{cpu or 'generic'}",
+            iree_compile_device_type="llvm-cpu",
+            extra_flags=extra_flags,
+        )
+    ]
+
+
+################################################################################
+# Flag definition
+################################################################################
+
+
+@register_arg_parser_callback
+def _(p: argparse.ArgumentParser):
+    # Sorted for a stable help message. "default" is excluded as a whole name
+    # (subtracting the bare string would instead subtract its characters).
+    supported_devices = ", ".join(
+        sorted(HAL_TARGET_DEVICES_FROM_FLAGS_HANDLERS.keys() - {"default"})
+    )
+    g = p.add_argument_group(
+        title="IREE Target Machine Options",
+        description="Global options controlling invocation of iree-compile",
+    )
+    g.add_argument(
+        "--iree-hal-target-device",
+        help="Compiles with a single machine model and a single specified device"
+        " (mutually exclusive with other ways to set the machine target). This "
+        "emulates the simple case of device targeting if invoking `iree-compile` "
+        "directly and is mostly a pass-through which also enforces other flags "
+        "depending on the value given. Supported options (or any supported by the "
+        "compiler): "
+        f"{supported_devices}",
+        nargs="*",
+    )
+
+    hip_g = p.add_argument_group(
+        title="IREE AMDGPU Target Options",
+        description="Options controlling explicit targeting of AMDGPU devices",
+    )
+    hip_g.add_argument(
+        "--iree-amdgpu-target",
+        "--iree-hip-target",
+        help="AMDGPU target selection (i.e. 'gfxYYYY')",
+    )
+
+    cpu_g = p.add_argument_group(
+        title="IREE CPU Target Options",
+        description="These are mostly pass-through. See `iree-compile --help` for "
+        "full information. Advanced usage will require an explicit machine config "
+        "file",
+    )
+    cpu_g.add_argument(
+        "--iree-llvmcpu-target-cpu",
+        help="'generic', 'host', or an explicit CPU name. See iree-compile help.",
+    )
+    cpu_g.add_argument(
+        "--iree-llvmcpu-target-cpu-features",
+        help="Comma separated list of '+' prefixed CPU features. See iree-compile help.",
+    )
+
+
+################################################################################
+# Global flag dispatch
+################################################################################
+
+
+@expand_cl_arg_defaults
+def compute_target_machines_from_flags(
+    *,
+    explicit_hal_target_devices: list[str]
+    | None = cl_arg_ref("iree_hal_target_device"),
+) -> list[TargetMachine]:
+    """Computes the TargetMachines requested by `--iree-hal-target-device`.
+
+    Each value is expanded by its registered handler, falling back to
+    `handle_unknown_hal_target_device` for unregistered mnemonics.
+
+    Raises:
+      RuntimeError: If no target devices were given.
+    """
+    # An empty list (flag given without values) carries no target information,
+    # so it is treated the same as the flag being absent.
+    if explicit_hal_target_devices:
+        # Most basic default case for setting up compilation.
+        machines = []
+        for explicit_hal_target_device in explicit_hal_target_devices:
+            handler = (
+                HAL_TARGET_DEVICES_FROM_FLAGS_HANDLERS.get(explicit_hal_target_device)
+                or handle_unknown_hal_target_device
+            )
+            machines.extend(handler(explicit_hal_target_device))
+        return machines
+
+    raise RuntimeError(
+        "iree-compile target information is required but none was provided. "
+        "See flags: --iree-hal-target-device"
+    )
diff --git a/compiler/bindings/python/iree/compiler/api/ctypes_dl.py b/compiler/bindings/python/iree/compiler/api/ctypes_dl.py
index 784fc61..26f7228 100644
--- a/compiler/bindings/python/iree/compiler/api/ctypes_dl.py
+++ b/compiler/bindings/python/iree/compiler/api/ctypes_dl.py
@@ -189,6 +189,11 @@
return len(view) >= 4 and view[:4].hex() == "4d4cef52"
+class SessionObject:
+ def close(self):
+ ...
+
+
class Session:
def __init__(self):
self._global_init = _global_init
@@ -197,12 +202,27 @@
# its ownership of it, so we must cache the new Python-level MLIRContext
# so its lifetime extends at least to our own.
self._owned_context = None
+ self._dependents: set[SessionObject] = set()
def __del__(self):
- _dylib.ireeCompilerSessionDestroy(self._session_p)
+ self.close()
+
+ def close(self):
+ if self._session_p:
+ for dep in list(self._dependents):
+ dep.close()
+ _dylib.ireeCompilerSessionDestroy(self._session_p)
+ self._session_p = c_void_p()
+
+ def __enter__(self) -> "Session":
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ self.close()
@property
def context(self):
+ assert self._session_p, "Session is closed"
if self._owned_context is None:
from .. import ir
@@ -218,9 +238,11 @@
return self._owned_context
def invocation(self) -> "Invocation":
+ assert self._session_p, "Session is closed"
return Invocation(self)
def get_flags(self, non_default_only: bool = False) -> Sequence[str]:
+ assert self._session_p, "Session is closed"
results = []
@_GET_FLAG_CALLBACK
@@ -235,6 +257,7 @@
return results
def set_flags(self, *flags: str):
+ assert self._session_p, "Session is closed"
argv_type = c_char_p * len(flags)
argv = argv_type(*[flag.encode("UTF-8") for flag in flags])
_handle_error(
@@ -257,6 +280,12 @@
self._local_dylib.ireeCompilerOutputDestroy(self._output_p)
self._output_p = None
+ def __enter__(self) -> "Output":
+ """Returns this output for use in a `with` block."""
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ """Closes this output when the `with` block exits."""
+ self.close()
+
@staticmethod
def open_file(file_path: str) -> "Output":
output_p = c_void_p()
@@ -300,11 +329,12 @@
return pointer
-class Source:
+class Source(SessionObject):
"""Wraps an iree_compiler_source_t."""
- def __init__(self, session: c_void_p, source_p: c_void_p, backing_ref):
- self._session: c_void_p = session # Keeps ref alive.
+ def __init__(self, session: Session, source_p: c_void_p, backing_ref):
+ self._session: Session | None = session # Keeps ref alive.
+ self._session._dependents.add(self)
self._source_p: c_void_p = source_p
self._backing_ref = backing_ref
self._local_dylib = _dylib
@@ -318,7 +348,14 @@
self._source_p = c_void_p()
self._local_dylib.ireeCompilerSourceDestroy(s)
self._backing_ref = None
- self._session = c_void_p()
+ self._session._dependents.remove(self)
+ self._session = None
+
+ def __enter__(self) -> "Source":
+ """Returns this source for use in a `with` block."""
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ """Closes this source (and detaches it from its session) on exit."""
+ self.close()
def __repr__(self):
return f"<Source {self._source_p}>"
@@ -362,9 +399,10 @@
IREE_COMPILER_PIPELINE_PRECOMPILE = 2
-class Invocation:
+class Invocation(SessionObject):
def __init__(self, session: Session):
- self._session = session
+ self._session: Session | None = session
+ self._session._dependents.add(self)
self._inv_p = _dylib.ireeCompilerInvocationCreate(self._session._session_p)
self._sources: list[Source] = []
self._local_dylib = _dylib
@@ -387,6 +425,14 @@
for s in self._sources:
s.close()
self._sources.clear()
+ self._session._dependents.remove(self)
+ self._session = None
+
+ def __enter__(self) -> "Invocation":
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ self.close()
def enable_console_diagnostics(self):
_dylib.ireeCompilerInvocationEnableConsoleDiagnostics(self._inv_p)
diff --git a/compiler/bindings/python/test/build_api/mnist_builder.py b/compiler/bindings/python/test/build_api/mnist_builder.py
index 54d6e30..31c4845 100644
--- a/compiler/bindings/python/test/build_api/mnist_builder.py
+++ b/compiler/bindings/python/test/build_api/mnist_builder.py
@@ -23,7 +23,10 @@
name="mnist.mlir",
source="mnist.onnx",
)
- return "mnist.mlir"
+ return compile(
+ name="mnist",
+ source="mnist.mlir",
+ )
if __name__ == "__main__":
diff --git a/compiler/bindings/python/test/build_api/mnist_builder_test.py b/compiler/bindings/python/test/build_api/mnist_builder_test.py
index 02a9140..fb8cefb 100644
--- a/compiler/bindings/python/test/build_api/mnist_builder_test.py
+++ b/compiler/bindings/python/test/build_api/mnist_builder_test.py
@@ -16,6 +16,11 @@
THIS_DIR = Path(__file__).resolve().parent
+DEFAULT_TARGET_ARGS = [
+ "--iree-hal-target-device=cpu",
+ "--iree-llvmcpu-target-cpu=host",
+]
+
class MnistBuilderTest(unittest.TestCase):
def setUp(self):
@@ -39,14 +44,14 @@
"--output-dir",
str(self.output_path),
]
+ + DEFAULT_TARGET_ARGS
).decode()
print("OUTPUT:", output)
output_paths = output.splitlines()
- self.assertEqual(len(output_paths), 1)
+ self.assertEqual(len(output_paths), 1, msg=f"Found {output_paths}")
output_path = Path(output_paths[0])
self.assertTrue(output_path.is_relative_to(self.output_path))
- contents = output_path.read_text()
- self.assertIn("module", contents)
+ self.assertIn("mnist_cpu-host.vmfb", output_paths[0])
# Tests that invoking via the build module itself works
# python {path to py file}
@@ -59,22 +64,24 @@
"--output-dir",
str(self.output_path),
]
+ + DEFAULT_TARGET_ARGS
).decode()
print("OUTPUT:", output)
output_paths = output.splitlines()
- self.assertEqual(len(output_paths), 1)
+ self.assertEqual(len(output_paths), 1, msg=f"Found {output_paths}")
+ self.assertIn("mnist_cpu-host.vmfb", output_paths[0])
def testListCommand(self):
mod = load_build_module(THIS_DIR / "mnist_builder.py")
out_file = io.StringIO()
- iree_build_main(mod, args=["--list"], stdout=out_file)
+ iree_build_main(mod, args=["--list"] + DEFAULT_TARGET_ARGS, stdout=out_file)
output = out_file.getvalue().strip()
self.assertEqual(output, "mnist")
def testListAllCommand(self):
mod = load_build_module(THIS_DIR / "mnist_builder.py")
out_file = io.StringIO()
- iree_build_main(mod, args=["--list-all"], stdout=out_file)
+ iree_build_main(mod, args=["--list-all"] + DEFAULT_TARGET_ARGS, stdout=out_file)
output = out_file.getvalue().splitlines()
self.assertIn("mnist", output)
self.assertIn("mnist/mnist.onnx", output)
@@ -92,11 +99,24 @@
args=[
"--mnist-onnx-url",
"https://github.com/iree-org/doesnotexist",
- ],
+ ]
+ + DEFAULT_TARGET_ARGS,
stdout=out_file,
stderr=err_file,
)
+ def testBuildNonDefaultSubTarget(self):
+ mod = load_build_module(THIS_DIR / "mnist_builder.py")
+ out_file = io.StringIO()
+ iree_build_main(
+ mod, args=["mnist/mnist.mlir"] + DEFAULT_TARGET_ARGS, stdout=out_file
+ )
+ output = out_file.getvalue().strip()
+ self.assertIn("genfiles/mnist/mnist.mlir", output)
+ output_path = Path(output)
+ contents = output_path.read_text()
+ self.assertIn("module", contents)
+
if __name__ == "__main__":
try:
diff --git a/compiler/plugins/target/LLVMCPU/LLVMIRPasses.cpp b/compiler/plugins/target/LLVMCPU/LLVMIRPasses.cpp
index 1e79555..c045c06 100644
--- a/compiler/plugins/target/LLVMCPU/LLVMIRPasses.cpp
+++ b/compiler/plugins/target/LLVMCPU/LLVMIRPasses.cpp
@@ -56,7 +56,7 @@
case SanitizerKind::kAddress: {
passBuilder.registerOptimizerLastEPCallback(
[](llvm::ModulePassManager &modulePassManager,
- llvm::OptimizationLevel Level) {
+ llvm::OptimizationLevel Level, llvm::ThinOrFullLTOPhase) {
llvm::AddressSanitizerOptions opts;
// Can use Never or Always, just not the default Runtime, which
// introduces a reference to
@@ -73,7 +73,7 @@
case SanitizerKind::kThread: {
passBuilder.registerOptimizerLastEPCallback(
[](llvm::ModulePassManager &modulePassManager,
- llvm::OptimizationLevel Level) {
+ llvm::OptimizationLevel Level, llvm::ThinOrFullLTOPhase) {
modulePassManager.addPass(llvm::ModuleThreadSanitizerPass());
modulePassManager.addPass(llvm::createModuleToFunctionPassAdaptor(
llvm::ThreadSanitizerPass()));
diff --git a/compiler/plugins/target/ROCM/test/opt_pass_plugin/GPUHello.cpp b/compiler/plugins/target/ROCM/test/opt_pass_plugin/GPUHello.cpp
index 6433530..ada0382 100644
--- a/compiler/plugins/target/ROCM/test/opt_pass_plugin/GPUHello.cpp
+++ b/compiler/plugins/target/ROCM/test/opt_pass_plugin/GPUHello.cpp
@@ -67,10 +67,11 @@
llvm::PassPluginLibraryInfo getPassPluginInfo() {
const auto callback = [](llvm::PassBuilder &pb) {
- pb.registerOptimizerLastEPCallback([&](llvm::ModulePassManager &mpm, auto) {
- mpm.addPass(GpuHello());
- return true;
- });
+ pb.registerOptimizerLastEPCallback(
+ [&](llvm::ModulePassManager &mpm, auto, auto) {
+ mpm.addPass(GpuHello());
+ return true;
+ });
};
return {LLVM_PLUGIN_API_VERSION, "gpu-hello", LLVM_VERSION_STRING, callback};
};
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
index dfa6745..0be2db3 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
@@ -236,7 +236,8 @@
case MMAIntrinsic::MFMA_F32_32x32x8_BF16: {
return OpaqueMmaLayout{32, 32, 8, bf16, bf16, f32};
}
- case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ: {
+ case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ:
+ case MMAIntrinsic::VMFMA_F32_16x16x32_F8E4M3FNUZ: {
return OpaqueMmaLayout{16, 16, 32, f8E4M3FNUZ, f8E4M3FNUZ, f32};
}
case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ: {
@@ -420,6 +421,7 @@
}
case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ:
case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ:
+ case MMAIntrinsic::VMFMA_F32_16x16x32_F8E4M3FNUZ:
case MMAIntrinsic::VMFMA_F32_16x16x32_F16:
case MMAIntrinsic::MFMA_I32_16x16x32_I8: {
auto aType = VectorType::get({8}, getAType());
@@ -471,6 +473,7 @@
case MMAIntrinsic::MFMA_I32_32x32x8_I8:
case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ:
case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ:
+ case MMAIntrinsic::VMFMA_F32_16x16x32_F8E4M3FNUZ:
case MMAIntrinsic::VMFMA_F32_16x16x32_F16:
case MMAIntrinsic::MFMA_I32_16x16x32_I8:
case MMAIntrinsic::VMFMA_F32_32x32x16_F16:
@@ -496,6 +499,7 @@
case MMAIntrinsic::MFMA_I32_32x32x8_I8:
case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ:
case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ:
+ case MMAIntrinsic::VMFMA_F32_16x16x32_F8E4M3FNUZ:
case MMAIntrinsic::VMFMA_F32_16x16x32_F16:
case MMAIntrinsic::MFMA_I32_16x16x32_I8:
case MMAIntrinsic::VMFMA_F32_32x32x16_F16:
@@ -578,6 +582,18 @@
return {/*outer=*/{1, 1}, /*thread=*/{4, 16}, /*tstrides=*/{16, 1},
/*element=*/{4, 1}};
}
+ case MMAIntrinsic::VMFMA_F32_16x16x32_F8E4M3FNUZ:
+ switch (fragment) {
+ case MMAFragment::Lhs:
+ return {/*outer=*/{1, 2}, /*thread=*/{16, 4}, /*tstrides=*/{1, 16},
+ /*element=*/{1, 4}};
+ case MMAFragment::Rhs:
+ return {/*outer=*/{2, 1}, /*thread=*/{4, 16}, /*tstrides=*/{16, 1},
+ /*element=*/{4, 1}};
+ case MMAFragment::Acc:
+ return {/*outer=*/{1, 1}, /*thread=*/{4, 16}, /*tstrides=*/{16, 1},
+ /*element=*/{4, 1}};
+ }
case MMAIntrinsic::VMFMA_F32_32x32x16_F16:
case MMAIntrinsic::MFMA_I32_32x32x16_I8:
switch (fragment) {
@@ -639,6 +655,8 @@
return {MMAIntrinsic::VMFMA_F32_16x16x32_F16};
case MMAIntrinsic::MFMA_F32_32x32x8_F16:
return {MMAIntrinsic::VMFMA_F32_32x32x16_F16};
+ case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ:
+ return {MMAIntrinsic::VMFMA_F32_16x16x32_F8E4M3FNUZ};
default:
return {};
}
@@ -711,6 +729,7 @@
case MMAIntrinsic::MFMA_F32_32x32x8_BF16:
case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ:
case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ:
+ case MMAIntrinsic::VMFMA_F32_16x16x32_F8E4M3FNUZ:
case MMAIntrinsic::MFMA_I32_16x16x32_I8:
case MMAIntrinsic::MFMA_I32_32x32x16_I8: {
auto [m, n, k] = getMNKShape();
@@ -1618,6 +1637,51 @@
return getAttributes().getAs<IREE::GPU::MmaInterfaceAttr>(kMmaKindName);
}
+void LoweringConfigAttr::setMmaKind(MLIRContext *context,
+ SmallVectorImpl<NamedAttribute> &attrs,
+ IREE::GPU::MmaInterfaceAttr kind) {
+ attrs.emplace_back(StringAttr::get(context, kMmaKindName), kind);
+}
+
+// TODO: Merge subgroup counts functionality into subgroup tiling level
+// lowering, when we have it implemented.
+constexpr StringLiteral kSubgroupMCountName = "subgroup_m_count";
+constexpr StringLiteral kSubgroupNCountName = "subgroup_n_count";
+
+std::optional<int64_t> LoweringConfigAttr::getSubgroupMCount() const {
+ auto subgroup_m_count_attr =
+ getAttributes().getAs<IntegerAttr>(kSubgroupMCountName);
+ if (!subgroup_m_count_attr) {
+ return std::nullopt;
+ }
+ return subgroup_m_count_attr.getInt();
+}
+
+std::optional<int64_t> LoweringConfigAttr::getSubgroupNCount() const {
+ auto subgroup_n_count_attr =
+ getAttributes().getAs<IntegerAttr>(kSubgroupNCountName);
+ if (!subgroup_n_count_attr) {
+ return std::nullopt;
+ }
+ return subgroup_n_count_attr.getInt();
+}
+
+void LoweringConfigAttr::setSubgroupMCount(
+ MLIRContext *context, SmallVectorImpl<NamedAttribute> &attrs,
+ int64_t subgroup_m_count) {
+ attrs.emplace_back(
+ StringAttr::get(context, kSubgroupMCountName),
+ IntegerAttr::get(IntegerType::get(context, 64), subgroup_m_count));
+}
+
+void LoweringConfigAttr::setSubgroupNCount(
+ MLIRContext *context, SmallVectorImpl<NamedAttribute> &attrs,
+ int64_t subgroup_n_count) {
+ attrs.emplace_back(
+ StringAttr::get(context, kSubgroupNCountName),
+ IntegerAttr::get(IntegerType::get(context, 64), subgroup_n_count));
+}
+
constexpr StringLiteral kPromoteOperandsName = "promote_operands";
std::optional<SmallVector<int64_t>>
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td
index 411d382..ee4cb93 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td
@@ -57,8 +57,23 @@
"The configured fields, including tiling levels">:$attributes
);
let extraClassDeclaration = [{
- /// Helper to retrieve a target mma intrinsic if present.
+ /// Helper to retrieve/set a target mma intrinsic.
::mlir::iree_compiler::IREE::GPU::MmaInterfaceAttr getMmaKind() const;
+ static void setMmaKind(MLIRContext *context,
+ SmallVectorImpl<NamedAttribute> &attrs,
+ ::mlir::iree_compiler::IREE::GPU::MmaInterfaceAttr kind);
+
+ // TODO: Merge subgroup counts functionality into subgroup tiling level
+ // lowering, when we have it implemented.
+ /// Helpers to retrieve/set the target subgroup M/N counts.
+ std::optional<int64_t> getSubgroupMCount() const;
+ std::optional<int64_t> getSubgroupNCount() const;
+ static void setSubgroupMCount(MLIRContext *context,
+ SmallVectorImpl<NamedAttribute> &attrs,
+ int64_t subgroup_m_count);
+ static void setSubgroupNCount(MLIRContext *context,
+ SmallVectorImpl<NamedAttribute> &attrs,
+ int64_t subgroup_n_count);
/// Helper to retrieve/set a list of operand indices to promote.
std::optional<SmallVector<int64_t>> getPromotedOperandList() const;
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.td b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.td
index 1afdf0d..49f210e 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.td
@@ -133,6 +133,9 @@
def MFMA_F32_32x32x8_BF16 : I32EnumAttrCase<"MFMA_F32_32x32x8_BF16", 0x0921>;
def MFMA_F32_16x16x32_F8E5M2FNUZ : I32EnumAttrCase<"MFMA_F32_16x16x32_F8E5M2FNUZ", 0x0930>;
def MFMA_F32_16x16x32_F8E4M3FNUZ : I32EnumAttrCase<"MFMA_F32_16x16x32_F8E4M3FNUZ", 0x0940>;
+// The V-intrinsic below interleaves K-dim reads, splitting one 8xF8 read into two 4xF8 reads.
+// (Useful in F8 chained-MM to align B-layout of 2nd MM to C-layout of 1st MM)
+def VMFMA_F32_16x16x32_F8E4M3FNUZ : I32EnumAttrCase<"VMFMA_F32_16x16x32_F8E4M3FNUZ", 0x0941>;
def MFMA_I32_16x16x32_I8 : I32EnumAttrCase<"MFMA_I32_16x16x32_I8", 0x0980>;
def MFMA_I32_32x32x16_I8 : I32EnumAttrCase<"MFMA_I32_32x32x16_I8", 0x0981>;
@@ -159,6 +162,7 @@
MFMA_F32_32x32x8_BF16,
MFMA_F32_16x16x32_F8E4M3FNUZ,
MFMA_F32_16x16x32_F8E5M2FNUZ,
+ VMFMA_F32_16x16x32_F8E4M3FNUZ,
MFMA_I32_16x16x32_I8,
MFMA_I32_32x32x16_I8,
MFMA_I32_16x16x16_I8,
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtAttrs.td b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtAttrs.td
index bf31ba6..446ff77 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtAttrs.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtAttrs.td
@@ -256,25 +256,24 @@
}];
let parameters = (ins
- ArrayRefParameter<"int64_t", "subgroup_tile">:$subgroupTile,
- ArrayRefParameter<"int64_t", "batch_tile">:$batchTile,
- ArrayRefParameter<"int64_t", "outer_tile">:$outerTile,
- ArrayRefParameter<"int64_t", "thread_tile">:$threadTile,
- ArrayRefParameter<"int64_t", "element_tile">:$elementTile,
+ OptionalArrayRefParameter<"int64_t", "subgroup_tile">:$subgroupTile,
+ OptionalArrayRefParameter<"int64_t", "batch_tile">:$batchTile,
+ OptionalArrayRefParameter<"int64_t", "outer_tile">:$outerTile,
+ OptionalArrayRefParameter<"int64_t", "thread_tile">:$threadTile,
+ OptionalArrayRefParameter<"int64_t", "element_tile">:$elementTile,
- ArrayRefParameter<"int64_t", "subgroup_strides">:$subgroupStrides,
- ArrayRefParameter<"int64_t", "thread_strides">:$threadStrides
+ OptionalArrayRefParameter<"int64_t", "subgroup_strides">:$subgroupStrides,
+ OptionalArrayRefParameter<"int64_t", "thread_strides">:$threadStrides
);
let assemblyFormat = [{
- `<` `subgroup_tile` `=` `[` $subgroupTile `]` `,`
- `batch_tile` `=` `[` $batchTile `]` `,`
- `outer_tile` `=` `[` $outerTile `]` `,`
- `thread_tile` `=` `[` $threadTile `]` `,`
- `element_tile` `=` `[` $elementTile `]` `,`
-
- `subgroup_strides` `=` `[` $subgroupStrides `]` `,`
- `thread_strides` `=` `[` $threadStrides `]`
+ `<` `subgroup_tile` `=` `[` (`]`) : ($subgroupTile^ `]`)? `,`
+ `batch_tile` `=` `[` (`]`) : ($batchTile^ `]`)? `,`
+ `outer_tile` `=` `[` (`]`) : ($outerTile^ `]`)? `,`
+ `thread_tile` `=` `[` (`]`) : ($threadTile^ `]`)? `,`
+ `element_tile` `=` `[` (`]`) : ($elementTile^ `]`)? `,`
+ `subgroup_strides` `=` `[` (`]`) : ($subgroupStrides^ `]`)? `,`
+ `thread_strides` `=` `[` (`]`) : ($threadStrides^ `]`)?
`>`
}];
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtOps.td b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtOps.td
index 4e40cd8..04055bf 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtOps.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtOps.td
@@ -84,10 +84,10 @@
distributed vectors.
}];
let arguments = (ins
- AnyVector:$input
+ AnyVectorOfAnyRank:$input
);
let results = (outs
- AnyVector:$output
+ AnyVectorOfAnyRank:$output
);
let extraClassDeclaration = [{}];
let assemblyFormat = "$input attr-dict `:` type($input) `->` type($output)";
@@ -103,10 +103,10 @@
distributed vectors.
}];
let arguments = (ins
- AnyVector:$input
+ AnyVectorOfAnyRank:$input
);
let results = (outs
- AnyVector:$output
+ AnyVectorOfAnyRank:$output
);
let extraClassDeclaration = [{}];
let assemblyFormat = "$input attr-dict `:` type($input) `->` type($output)";
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/test/roundtrip.mlir b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/test/roundtrip.mlir
index f320654..fc14c3b 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/test/roundtrip.mlir
+++ b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/test/roundtrip.mlir
@@ -88,6 +88,37 @@
// -----
+#nested_0 = #iree_vector_ext.nested_layout<
+ subgroup_tile = [],
+ batch_tile = [],
+ outer_tile = [],
+ thread_tile = [],
+ element_tile = [],
+
+ subgroup_strides = [],
+ thread_strides = []
+>
+
+func.func @specify_nested_0d(%lhs: vector<f16>) -> vector<f16> {
+ %result = iree_vector_ext.to_layout %lhs to layout(#nested_0) : vector<f16>
+ func.return %result : vector<f16>
+}
+
+// CHECK: #[[$LAYOUT0:.+]] = #iree_vector_ext.nested_layout<
+// CHECK-SAME: subgroup_tile = [],
+// CHECK-SAME: batch_tile = [],
+// CHECK-SAME: outer_tile = [],
+// CHECK-SAME: thread_tile = [],
+// CHECK-SAME: element_tile = [],
+// CHECK-SAME: subgroup_strides = [],
+// CHECK-SAME: thread_strides = []>
+
+// CHECK-LABEL: func.func @specify_nested_0d
+// CHECK: to_layout
+// CHECK-SAME: layout(#[[$LAYOUT0]])
+
+// -----
+
func.func @to_simd_op(%simt: vector<4x4x4xf16>) -> vector<64x64xf16> {
%simd = iree_vector_ext.to_simd %simt : vector<4x4x4xf16> -> vector<64x64xf16>
func.return %simd : vector<64x64xf16>
@@ -103,3 +134,21 @@
}
// CHECK-LABEL: func.func @to_simt_op
// CHECK: iree_vector_ext.to_simd
+
+// -----
+
+// Checks that to_simd accepts 0-d vectors.
+func.func @to_simd_op_0d(%simt: vector<f16>) -> vector<f16> {
+ %simd = iree_vector_ext.to_simd %simt : vector<f16> -> vector<f16>
+ func.return %simd : vector<f16>
+}
+// CHECK-LABEL: func.func @to_simd_op_0d
+// CHECK: iree_vector_ext.to_simd
+
+// -----
+
+// Checks that to_simt accepts 0-d vectors.
+func.func @to_simt_op_0d(%simd: vector<f32>) -> vector<f32> {
+ %simt = iree_vector_ext.to_simt %simd : vector<f32> -> vector<f32>
+ func.return %simt : vector<f32>
+}
+// CHECK-LABEL: func.func @to_simt_op_0d
+// CHECK: iree_vector_ext.to_simt
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
index e8926eb..45d592f 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
@@ -98,7 +98,7 @@
static llvm::cl::opt<bool>
clLLVMGPUUseIgemm("iree-codegen-llvmgpu-use-igemm",
llvm::cl::desc("Enable implicit gemm for convolutions."),
- llvm::cl::init(false));
+ llvm::cl::init(true));
namespace {
using CodeGenPipeline = IREE::Codegen::DispatchLoweringPassPipeline;
@@ -416,18 +416,17 @@
attrs.emplace_back(StringAttr::get(context, "reduction"),
b.getI64ArrayAttr(reductionTileSizes));
IREE::GPU::LoweringConfigAttr::setPromotedOperandList(context, attrs, {0, 1});
+ IREE::GPU::LoweringConfigAttr::setMmaKind(context, attrs,
+ mmaAttrs[schedule->index]);
+ IREE::GPU::LoweringConfigAttr::setSubgroupMCount(
+ context, attrs, schedule->mSubgroupCounts[0]);
+ IREE::GPU::LoweringConfigAttr::setSubgroupNCount(
+ context, attrs, schedule->nSubgroupCounts[0]);
auto configDict = DictionaryAttr::get(context, attrs);
auto loweringConfig = IREE::GPU::LoweringConfigAttr::get(context, configDict);
- // Attach the MMA schedule as an attribute to the entry point export function
- // for later access in the pipeline.
SmallVector<NamedAttribute, 1> pipelineAttrs;
- auto scheduleAttr = IREE::GPU::MMAScheduleAttr::get(
- context, mmaAttrs[schedule->index], schedule->mSubgroupCounts[0],
- schedule->nSubgroupCounts[0]);
- pipelineAttrs.emplace_back(StringAttr::get(context, "mma_schedule"),
- scheduleAttr);
// Prefetch shared memory if requested.
if (clLLVMGPUEnablePrefetch) {
@@ -682,6 +681,12 @@
attrs.emplace_back(StringAttr::get(context, "reduction"),
b.getI64ArrayAttr(reductionTileSizes));
IREE::GPU::LoweringConfigAttr::setPromotedOperandList(context, attrs, {0, 1});
+ IREE::GPU::LoweringConfigAttr::setMmaKind(context, attrs,
+ mmaAttrs[schedule->index]);
+ IREE::GPU::LoweringConfigAttr::setSubgroupMCount(
+ context, attrs, schedule->mSubgroupCounts[0]);
+ IREE::GPU::LoweringConfigAttr::setSubgroupNCount(
+ context, attrs, schedule->nSubgroupCounts[0]);
auto configDict = DictionaryAttr::get(context, attrs);
auto loweringConfig = IREE::GPU::LoweringConfigAttr::get(context, configDict);
@@ -689,11 +694,6 @@
// Attach the MMA schedule as an attribute to the entry point export function
// for later access in the pipeline.
SmallVector<NamedAttribute, 1> pipelineAttrs;
- auto scheduleAttr = IREE::GPU::MMAScheduleAttr::get(
- context, mmaAttrs[schedule->index], schedule->mSubgroupCounts[0],
- schedule->nSubgroupCounts[0]);
- pipelineAttrs.emplace_back(StringAttr::get(context, "mma_schedule"),
- scheduleAttr);
// Prefetch shared memory if requested.
if (clLLVMGPUEnablePrefetch) {
@@ -902,9 +902,32 @@
SmallVector<NamedAttribute, 2> qkConfig;
SmallVector<NamedAttribute, 2> pvConfig;
+ // On attention subgroup distribution:
+ // The subgroup distribution in attention is controlled by the second matmul
+ // (Parallel dimension distribution is usually (almost always) controlled by
+ // the last reduction operation in a dispatch). Since VectorDistribution
+ // doesn't have logic to set subgroup and thread layouts separately, we
+ // explicitly set the subgroup count for the first matmul as well,
+ // corresponding to what the second matmul dictates.
+
+ // Configuring for qk matmul.
+ // subgroup_n count for qk matmul is always 1, since we do not tile K1.
IREE::GPU::LoweringConfigAttr::setPromotedOperandList(context, qkConfig,
{0, 1});
+ IREE::GPU::LoweringConfigAttr::setMmaKind(context, qkConfig,
+ mmaAttrs[schedule->index]);
+ IREE::GPU::LoweringConfigAttr::setSubgroupMCount(
+ context, qkConfig, schedule->mSubgroupCounts[0]);
+ IREE::GPU::LoweringConfigAttr::setSubgroupNCount(context, qkConfig, 1);
+
+ // Configuring for pv matmul.
IREE::GPU::LoweringConfigAttr::setPromotedOperandList(context, pvConfig, {1});
+ IREE::GPU::LoweringConfigAttr::setMmaKind(context, pvConfig,
+ mmaAttrs[schedule->index]);
+ IREE::GPU::LoweringConfigAttr::setSubgroupMCount(
+ context, pvConfig, schedule->mSubgroupCounts[0]);
+ IREE::GPU::LoweringConfigAttr::setSubgroupNCount(
+ context, pvConfig, schedule->nSubgroupCounts[0]);
SmallVector<NamedAttribute, 2> qkAttrs;
SmallVector<NamedAttribute, 2> pvAttrs;
@@ -938,14 +961,7 @@
auto configDict = b.getDictionaryAttr(attrs);
auto loweringConfig = IREE::GPU::LoweringConfigAttr::get(context, configDict);
- // Attach the MMA schedule as an attribute to the entry point export function
- // for later access in the pipeline.
SmallVector<NamedAttribute, 1> pipelineAttrs;
- auto scheduleAttr = IREE::GPU::MMAScheduleAttr::get(
- context, mmaAttrs[schedule->index], schedule->mSubgroupCounts[0],
- schedule->nSubgroupCounts[0]);
- pipelineAttrs.emplace_back(StringAttr::get(context, "mma_schedule"),
- scheduleAttr);
// TODO: We do not turn prefetching on even when requested by the prefetching
// flag because there is a shared memory allocation the two matmuls, which
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUConfigureTensorLayouts.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUConfigureTensorLayouts.cpp
index 4945e66..7008f3e 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUConfigureTensorLayouts.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUConfigureTensorLayouts.cpp
@@ -48,6 +48,29 @@
return promotedOperands;
}
+static IREE::GPU::MmaInterfaceAttr getIntrinsic(Operation *op) {
+ auto config = getLoweringConfig<IREE::GPU::LoweringConfigAttr>(op);
+ assert(config && "Cannot find intrinsic from unconfigured op.");
+
+ IREE::GPU::MmaInterfaceAttr mmaIntrinsic = config.getMmaKind();
+ assert(mmaIntrinsic && "Cannot find intrinsic in lowering config.");
+ return mmaIntrinsic;
+}
+
+// Returns the subgroup_m_count stored in `op`'s GPU lowering config. The op
+// must carry a LoweringConfigAttr with that entry set.
+static int64_t getSubgroupMCount(Operation *op) {
+ auto config = getLoweringConfig<IREE::GPU::LoweringConfigAttr>(op);
+ assert(config && "Cannot find subgroup_m_count from unconfigured op.");
+
+ std::optional<int64_t> subgroupMCount = config.getSubgroupMCount();
+ assert(subgroupMCount && "Cannot find subgroup_m_count in lowering config.");
+ return *subgroupMCount;
+}
+
+// Returns the subgroup_n_count stored in `op`'s GPU lowering config. The op
+// must carry a LoweringConfigAttr with that entry set.
+static int64_t getSubgroupNCount(Operation *op) {
+ auto config = getLoweringConfig<IREE::GPU::LoweringConfigAttr>(op);
+ assert(config && "Cannot find subgroup_n_count from unconfigured op.");
+
+ std::optional<int64_t> subgroupNCount = config.getSubgroupNCount();
+ assert(subgroupNCount && "Cannot find subgroup_n_count in lowering config.");
+ return *subgroupNCount;
+}
+
static LogicalResult setContractionAnchor(IREE::GPU::MMAScheduleAttr schedule,
SmallVector<bool> promotedOperands,
RewriterBase &rewriter,
@@ -264,14 +287,19 @@
schedule.getSubgroupMCount());
}
-static LogicalResult
-setAttentionMatmulAnchor(IREE::GPU::MMAScheduleAttr schedule,
- RewriterBase &rewriter, linalg::LinalgOp qkMatmul,
- linalg::LinalgOp pvMatmul) {
- // TODO: Add SIMT fallback.
- if (!schedule) {
- return pvMatmul->emitError("missing mma schedule for contraction");
- }
+static LogicalResult setAttentionMatmulAnchor(RewriterBase &rewriter,
+ linalg::LinalgOp qkMatmul,
+ linalg::LinalgOp pvMatmul) {
+
+ IREE::GPU::MMAScheduleAttr qkSchedule =
+ rewriter.getAttr<IREE::GPU::MMAScheduleAttr>(getIntrinsic(qkMatmul),
+ getSubgroupMCount(qkMatmul),
+ getSubgroupNCount(qkMatmul));
+
+ IREE::GPU::MMAScheduleAttr pvSchedule =
+ rewriter.getAttr<IREE::GPU::MMAScheduleAttr>(getIntrinsic(pvMatmul),
+ getSubgroupMCount(pvMatmul),
+ getSubgroupNCount(pvMatmul));
// Check if the intrinsic output for qkMatmul can be reused for pvMatmul.
// We know that pvMatmul takes result of qkMatmul as it's lhs.
@@ -280,13 +308,14 @@
bool reuseIntrinsicOutput = false;
bool transposeIntrinsic = false;
- auto intrinsic = cast<IREE::GPU::MMAAttr>(schedule.getIntrinsic());
+ auto qkIntrinsic = cast<IREE::GPU::MMAAttr>(qkSchedule.getIntrinsic());
+ auto pvIntrinsic = cast<IREE::GPU::MMAAttr>(pvSchedule.getIntrinsic());
IREE::GPU::MMASingleSubgroupLayout lhsLayout =
- intrinsic.getASingleSubgroupLayout();
+ pvIntrinsic.getASingleSubgroupLayout();
IREE::GPU::MMASingleSubgroupLayout rhsLayout =
- intrinsic.getBSingleSubgroupLayout();
+ pvIntrinsic.getBSingleSubgroupLayout();
IREE::GPU::MMASingleSubgroupLayout outLayout =
- intrinsic.getCSingleSubgroupLayout();
+ qkIntrinsic.getCSingleSubgroupLayout();
auto matchLayout = [](IREE::GPU::MMASingleSubgroupLayout layoutA,
IREE::GPU::MMASingleSubgroupLayout layoutB) -> bool {
@@ -305,15 +334,6 @@
transposeIntrinsic = true;
}
- // subgroup_n count for attention matmul is always 1, because it is the
- // reduction dimension. The subgroup_n count is in reality, for the pvMatmul.
- IREE::GPU::MMAScheduleAttr qkSchedule =
- rewriter.getAttr<IREE::GPU::MMAScheduleAttr>(
- schedule.getIntrinsic(),
- /*subgroup_m_count=*/schedule.getSubgroupMCount(),
- /*subgroup_n_count=*/1);
- IREE::GPU::MMAScheduleAttr pvSchedule = schedule;
-
SmallVector<bool> promotedQKOperands = getPromotedOperands(qkMatmul);
SmallVector<bool> promotedPVOperands = getPromotedOperands(pvMatmul);
@@ -488,12 +508,6 @@
return signalPassFailure();
}
- llvm::StringLiteral scheduleAttrName =
- IREE::GPU::MMAScheduleAttr::getMnemonic();
- DictionaryAttr configDict = getTranslationInfo(func).getConfiguration();
- auto scheduleAttr = dyn_cast_or_null<IREE::GPU::MMAScheduleAttr>(
- configDict.get(scheduleAttrName));
-
// Vector layout option setter aimed at contractions and convolutions. For
// now, layout setting for other problems like reductions is TODO.
SmallVector<linalg::LinalgOp> contracts;
@@ -529,23 +543,28 @@
for (linalg::LinalgOp contract : contracts) {
SmallVector<bool> promotedOperands = getPromotedOperands(contract);
- if (failed(setContractionAnchor(scheduleAttr, promotedOperands, rewriter,
- contract))) {
+ auto contractionSchedule = rewriter.getAttr<IREE::GPU::MMAScheduleAttr>(
+ getIntrinsic(contract), getSubgroupMCount(contract),
+ getSubgroupNCount(contract));
+ if (failed(setContractionAnchor(contractionSchedule, promotedOperands,
+ rewriter, contract))) {
return signalPassFailure();
}
}
for (linalg::LinalgOp conv : convs) {
SmallVector<bool> promotedOperands = getPromotedOperands(conv);
- if (failed(setConvolutionAnchor(scheduleAttr, promotedOperands, rewriter,
+ auto convSchedule = rewriter.getAttr<IREE::GPU::MMAScheduleAttr>(
+ getIntrinsic(conv), getSubgroupMCount(conv), getSubgroupNCount(conv));
+ if (failed(setConvolutionAnchor(convSchedule, promotedOperands, rewriter,
conv))) {
return signalPassFailure();
}
}
if (attentionQKMatmul && attentionPVMatmul) {
- if (failed(setAttentionMatmulAnchor(
- scheduleAttr, rewriter, attentionQKMatmul, attentionPVMatmul))) {
+ if (failed(setAttentionMatmulAnchor(rewriter, attentionQKMatmul,
+ attentionPVMatmul))) {
return signalPassFailure();
}
}
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel
index 4e5776f..3a0d7d3 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel
@@ -23,7 +23,6 @@
"amdgpu_set_anchor_layouts.mlir",
"assign_constant_ordinals.mlir",
"conv_pipeline_test_cuda.mlir",
- "conv_pipeline_test_rocm.mlir",
"convert_to_nvvm.mlir",
"convert_to_rocdl.mlir",
"create_async_groups.mlir",
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
index bc3935f..f628010 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
@@ -25,7 +25,6 @@
"config_winograd.mlir"
"configure_tensor_layout.mlir"
"conv_pipeline_test_cuda.mlir"
- "conv_pipeline_test_rocm.mlir"
"convert_to_nvvm.mlir"
"convert_to_rocdl.mlir"
"create_async_groups.mlir"
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir
index 6bef11e..80d4691 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir
@@ -1,5 +1,6 @@
// RUN: iree-opt --mlir-print-local-scope --split-input-file --iree-gpu-test-target=gfx940 \
// RUN: --iree-codegen-llvmgpu-test-tile-and-fuse-matmul=true --iree-codegen-llvmgpu-test-tile-and-fuse-vectorize=true \
+// RUN: --iree-codegen-llvmgpu-use-igemm=false \
// RUN: --pass-pipeline="builtin.module(iree-llvmgpu-select-lowering-strategy)" %s | FileCheck %s
// TODO: This test is still using the legacy LLVMGPU kernel config. This needs
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_user_vector_distribute.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_user_vector_distribute.mlir
index a1b1627..2820162 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_user_vector_distribute.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_user_vector_distribute.mlir
@@ -11,12 +11,12 @@
// OPT-OUT: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [128, 2, 1] subgroup_size = 64
// OPT-OUT-SAME: gpu_pipeline_options = #iree_gpu.pipeline_options<no_reduce_shared_memory_bank_conflicts = true>
-// OPT-OUT-SAME: mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>,
// OPT-IN: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [128, 2, 1] subgroup_size = 64
// OPT-IN-SAME: gpu_pipeline_options = #iree_gpu.pipeline_options<no_reduce_shared_memory_bank_conflicts = true>
-// OPT-IN-SAME: mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>,
-#config = #iree_gpu.lowering_config<{workgroup = [128, 128, 0], reduction = [0, 0, 32], promote_operands = [0, 1]}>
+#config = #iree_gpu.lowering_config<{workgroup = [128, 128, 0], reduction = [0, 0, 32], promote_operands = [0, 1],
+ mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>,
+ subgroup_m_count = 2, subgroup_n_count = 2}>
#pipeline_layout = #hal.pipeline.layout<bindings = [
#hal.pipeline.binding<storage_buffer>,
#hal.pipeline.binding<storage_buffer>,
@@ -48,7 +48,6 @@
func.func @main_0_dispatch_0_matmul_transpose_b_2048x10240x1280_f16xf16xf32()
attributes {translation_info = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [128, 2, 1] subgroup_size = 64, {
- mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>, subgroup_m_count = 2, subgroup_n_count = 2>,
gpu_pipeline_options = #iree_gpu.pipeline_options<no_reduce_shared_memory_bank_conflicts = true> // Disable the 'reduceSharedMemoryBankConflicts' pass.
}>} {
%cst = arith.constant 0.000000e+00 : f16
@@ -87,12 +86,12 @@
// OPT-OUT: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [128, 2, 1] subgroup_size = 64
// OPT-OUT-SAME: gpu_pipeline_options = #iree_gpu.pipeline_options<reorder_workgroups_strategy = <Transpose>>
-// OPT-OUT-SAME: mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>,
// OPT-IN: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [128, 2, 1] subgroup_size = 64
// OPT-IN-SAME: gpu_pipeline_options = #iree_gpu.pipeline_options<reorder_workgroups_strategy = <Transpose>>
-// OPT-IN-SAME: mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>,
-#config = #iree_gpu.lowering_config<{workgroup = [128, 128, 0], reduction = [0, 0, 32], promote_operands = [0, 1]}>
+#config = #iree_gpu.lowering_config<{workgroup = [128, 128, 0], reduction = [0, 0, 32], promote_operands = [0, 1],
+ mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>,
+ subgroup_m_count = 2, subgroup_n_count = 2}>
#pipeline_layout = #hal.pipeline.layout<bindings = [
#hal.pipeline.binding<storage_buffer>,
#hal.pipeline.binding<storage_buffer>,
@@ -125,7 +124,6 @@
// OPT-IN: scf.for
func.func @main_0_dispatch_0_matmul_transpose_b_2048x10240x1280_f16xf16xf32()
attributes {translation_info = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [128, 2, 1] subgroup_size = 64, {
- mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>, subgroup_m_count = 2, subgroup_n_count = 2>,
gpu_pipeline_options = #iree_gpu.pipeline_options<reorder_workgroups_strategy = <Transpose>> // enable the 'reorderWorkgroups' pass.
}>} {
%cst = arith.constant 0.000000e+00 : f16
@@ -163,8 +161,9 @@
// OPT-OUT: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [128, 2, 1] subgroup_size = 64
// OPT-OUT-SAME: gpu_pipeline_options = #iree_gpu.pipeline_options<reorder_workgroups_strategy = <None>>
-// OPT-OUT-SAME: mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>,
-#config = #iree_gpu.lowering_config<{workgroup = [128, 128, 0], reduction = [0, 0, 32], promote_operands = [0, 1]}>
+#config = #iree_gpu.lowering_config<{workgroup = [128, 128, 0], reduction = [0, 0, 32], promote_operands = [0, 1],
+ mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>,
+ subgroup_m_count = 2, subgroup_n_count = 2}>
#pipeline_layout = #hal.pipeline.layout<bindings = [
#hal.pipeline.binding<storage_buffer>,
#hal.pipeline.binding<storage_buffer>,
@@ -186,7 +185,6 @@
// OPT-OUT-NEXT: scf.for
func.func @main_0_dispatch_0_matmul_transpose_b_2048x10240x1280_f16xf16xf32()
attributes {translation_info = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [128, 2, 1] subgroup_size = 64, {
- mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>, subgroup_m_count = 2, subgroup_n_count = 2>,
gpu_pipeline_options = #iree_gpu.pipeline_options<reorder_workgroups_strategy = <None>> // Disable the 'reorderWorkgroups' pass.
}>} {
%cst = arith.constant 0.000000e+00 : f16
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute_gfx1100.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute_gfx1100.mlir
index 9d45ea0..6028b47 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute_gfx1100.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute_gfx1100.mlir
@@ -6,9 +6,8 @@
// located here.
// WMMA: #iree_codegen.translation_info<LLVMGPUVectorDistribute
-// WMMA-SAME: mma_schedule = #iree_gpu.mma_schedule
-// WMMA-SAME: intrinsic = #iree_gpu.mma_layout<WMMA_F32_16x16x16_F16>
-// WMMA-SAME: subgroup_m_count = 2, subgroup_n_count = 2
+// WMMA-SAME: workgroup_size = [128, 1, 1]
+// WMMA-SAME: subgroup_size = 32
#pipeline_layout = #hal.pipeline.layout<bindings = [
#hal.pipeline.binding<storage_buffer>,
@@ -32,5 +31,8 @@
// WMMA-LABEL: func.func @wmma_matmul_1024x1024x1024()
// WMMA: linalg.matmul {{.*}}lowering_config = #iree_gpu.lowering_config
+// WMMA-SAME: mma_kind = #iree_gpu.mma_layout<WMMA_F32_16x16x16_F16>
// WMMA-SAME: reduction = [0, 0, 64]
+// WMMA-SAME: subgroup_m_count = 2
+// WMMA-SAME: subgroup_n_count = 2
// WMMA-SAME: workgroup = [64, 128, 0]
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute_gfx940.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute_gfx940.mlir
index da6a563..d848c03 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute_gfx940.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute_gfx940.mlir
@@ -1,5 +1,5 @@
// RUN: iree-opt --split-input-file --iree-gpu-test-target=gfx940 --iree-codegen-llvmgpu-use-vector-distribution \
-// RUN: --iree-codegen-llvmgpu-use-unaligned-gemm-vector-distribution \
+// RUN: --iree-codegen-llvmgpu-use-unaligned-gemm-vector-distribution --iree-codegen-llvmgpu-use-igemm=false \
// RUN: --pass-pipeline="builtin.module(iree-llvmgpu-select-lowering-strategy)" %s | FileCheck %s
// TODO: This test is still using the legacy LLVMGPU kernel config. This needs
@@ -7,9 +7,6 @@
// located here.
// CHECK: #iree_codegen.translation_info<LLVMGPUVectorDistribute
-// CHECK-SAME: mma_schedule = #iree_gpu.mma_schedule
-// CHECK-SAME: intrinsic = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>
-// CHECK-SAME: subgroup_m_count = 1, subgroup_n_count = 4
#pipeline_layout = #hal.pipeline.layout<bindings = [
#hal.pipeline.binding<storage_buffer>,
@@ -41,15 +38,15 @@
// CHECK-LABEL: func.func @expanded_matmul_transpose_b()
// CHECK: linalg.generic {{.*}}lowering_config = #iree_gpu.lowering_config
+// CHECK-SAME: mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>
// CHECK-SAME: reduction = [0, 0, 0, 0, 128]
+// CHECK-SAME: subgroup_m_count = 1
+// CHECK-SAME: subgroup_n_count = 4
// CHECK-SAME: workgroup = [1, 1, 64, 64, 0]
// -----
// CHECK: #iree_codegen.translation_info<LLVMGPUVectorDistribute
-// CHECK-SAME: mma_schedule = #iree_gpu.mma_schedule
-// CHECK-SAME: intrinsic = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>
-// CHECK-SAME: subgroup_m_count = 2, subgroup_n_count = 2
#pipeline_layout = #hal.pipeline.layout<bindings = [
#hal.pipeline.binding<storage_buffer>,
@@ -73,7 +70,10 @@
// CHECK-LABEL: func.func @conv_nhwc()
// CHECK: linalg.conv_2d_nhwc_hwcf {{.*}} lowering_config = #iree_gpu.lowering_config
+// CHECK-SAME: mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>
// CHECK-SAME: reduction = [0, 0, 0, 0, 1, 1, 32]
+// CHECK-SAME: subgroup_m_count = 2
+// CHECK-SAME: subgroup_n_count = 2
// CHECK-SAME: workgroup = [1, 1, 64, 128, 0, 0, 0]
// -----
@@ -113,9 +113,6 @@
// -----
// CHECK: #iree_codegen.translation_info<LLVMGPUVectorDistribute
-// CHECK-SAME: mma_schedule = #iree_gpu.mma_schedule
-// CHECK-SAME: intrinsic = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>
-// CHECK-SAME: subgroup_m_count = 2, subgroup_n_count = 2
#pipeline_layout = #hal.pipeline.layout<bindings = [
#hal.pipeline.binding<storage_buffer>,
@@ -139,15 +136,15 @@
// CHECK-LABEL: func.func @mfma_matmul_1024x1024x1024()
// CHECK: linalg.matmul {{.*}}lowering_config = #iree_gpu.lowering_config
+// CHECK-SAME: mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>
// CHECK-SAME: reduction = [0, 0, 64]
+// CHECK-SAME: subgroup_m_count = 2
+// CHECK-SAME: subgroup_n_count = 2
// CHECK-SAME: workgroup = [64, 128, 0]
// -----
// CHECK: #iree_codegen.translation_info<LLVMGPUVectorDistribute
-// CHECK-SAME: mma_schedule = #iree_gpu.mma_schedule
-// CHECK-SAME: intrinsic = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>
-// CHECK-SAME: subgroup_m_count = 2, subgroup_n_count = 2
#pipeline_layout = #hal.pipeline.layout<bindings = [
#hal.pipeline.binding<storage_buffer>,
@@ -190,15 +187,15 @@
// CHECK-LABEL: func.func @conv_nchwc()
// CHECK: linalg.generic {{.*}}lowering_config = #iree_gpu.lowering_config
+// CHECK-SAME: mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>
// CHECK-SAME: reduction = [0, 0, 0, 0, 0, 1, 1, 1, 32]
+// CHECK-SAME: subgroup_m_count = 2
+// CHECK-SAME: subgroup_n_count = 2
// CHECK-SAME: workgroup = [1, 1, 1, 32, 32, 0, 0, 0, 0]
// -----
// CHECK: #iree_codegen.translation_info<LLVMGPUPadAndVectorDistribute
-// CHECK-SAME: mma_schedule = #iree_gpu.mma_schedule
-// CHECK-SAME: intrinsic = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>
-// CHECK-SAME: subgroup_m_count = 1, subgroup_n_count = 1
#pipeline_layout = #hal.pipeline.layout<bindings = [
#hal.pipeline.binding<storage_buffer>,
@@ -222,15 +219,15 @@
// CHECK-LABEL: func.func @unaligned_mk_batch_matmul()
// CHECK: linalg.batch_matmul
// CHECK-SAME: lowering_config = #iree_gpu.lowering_config
+// CHECK-SAME: mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>
// CHECK-SAME: reduction = [0, 0, 0, 16]
+// CHECK-SAME: subgroup_m_count = 1
+// CHECK-SAME: subgroup_n_count = 1
// CHECK-SAME: workgroup = [1, 16, 16, 0]
// -----
// CHECK: #iree_codegen.translation_info<LLVMGPUPadAndVectorDistribute
-// CHECK-SAME: mma_schedule = #iree_gpu.mma_schedule
-// CHECK-SAME: intrinsic = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>
-// CHECK-SAME: subgroup_m_count = 1, subgroup_n_count = 4
#pipeline_layout = #hal.pipeline.layout<bindings = [
#hal.pipeline.binding<storage_buffer>,
@@ -254,7 +251,10 @@
// CHECK-LABEL: func.func @unaligned_m_batch_matmul_64x72x1280x1280()
// CHECK: linalg.batch_matmul
// CHECK-SAME: lowering_config = #iree_gpu.lowering_config
+// CHECK-SAME: mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>
// CHECK-SAME: reduction = [0, 0, 0, 128]
+// CHECK-SAME: subgroup_m_count = 1
+// CHECK-SAME: subgroup_n_count = 4
// CHECK-SAME: workgroup = [1, 16, 128, 0]
// -----
@@ -318,7 +318,6 @@
// -----
// CHECK: #iree_codegen.translation_info<LLVMGPUVectorDistribute
-// CHECK-SAME: subgroup_m_count = 2, subgroup_n_count = 1
// CHECK-NOT: prefetch_shared_memory = true
// CHECK-LABEL: func.func @attention_20x4096x64x4096x64()
@@ -354,13 +353,14 @@
}
// CHECK: #iree_gpu.lowering_config
+// CHECK-SAME: subgroup_m_count = 2
+// CHECK-SAME: subgroup_n_count = 1
// CHECK-SAME: reduction = [0, 0, 0, 64, 0]
// CHECK-SAME: workgroup = [1, 64, 0, 0, 64]
// -----
// CHECK: #iree_codegen.translation_info<LLVMGPUVectorDistribute
-// CHECK-SAME: subgroup_m_count = 2, subgroup_n_count = 1
// CHECK-NOT: prefetch_shared_memory = true
// CHECK-LABEL: func.func @attention_large_head_dim_shared_mem()
@@ -399,5 +399,7 @@
}
// CHECK: #iree_gpu.lowering_config
+// CHECK-SAME: subgroup_m_count = 2
+// CHECK-SAME: subgroup_n_count = 1
// CHECK-SAME: reduction = [0, 0, 16, 0]
// CHECK-SAME: workgroup = [32, 0, 0, 32]
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx1100.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx1100.mlir
index 58bfb00..128a259 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx1100.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx1100.mlir
@@ -3,8 +3,8 @@
// RUN: --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(builtin.module(func.func(iree-llvmgpu-lower-executable-target)))))" \
// RUN: %s | FileCheck %s
-#config = #iree_gpu.lowering_config<{workgroup = [64, 64, 0], reduction = [0, 0, 128]}>
-#translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [64, 2, 1] subgroup_size = 32, {gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = true, no_reduce_shared_memory_bank_conflicts = false>, mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<WMMA_F32_16x16x16_F16>, subgroup_m_count = 2, subgroup_n_count = 2>}>
+#config = #iree_gpu.lowering_config<{workgroup = [64, 64, 0], reduction = [0, 0, 128], mma_kind = #iree_gpu.mma_layout<WMMA_F32_16x16x16_F16>, subgroup_m_count = 2, subgroup_n_count = 2}>
+#translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [64, 2, 1] subgroup_size = 32, {gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = true, no_reduce_shared_memory_bank_conflicts = false>}>
#pipeline_layout = #hal.pipeline.layout<bindings = [
#hal.pipeline.binding<storage_buffer>,
@@ -50,8 +50,8 @@
// -----
-#config = #iree_gpu.lowering_config<{workgroup = [64, 64, 0], reduction = [0, 0, 128]}>
-#translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [64, 2, 1] subgroup_size = 32, {gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = true, no_reduce_shared_memory_bank_conflicts = false>, mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<WMMA_F16_16x16x16_F16>, subgroup_m_count = 2, subgroup_n_count = 2>}>
+#config = #iree_gpu.lowering_config<{workgroup = [64, 64, 0], reduction = [0, 0, 128], mma_kind = #iree_gpu.mma_layout<WMMA_F16_16x16x16_F16>, subgroup_m_count = 2, subgroup_n_count = 2}>
+#translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [64, 2, 1] subgroup_size = 32, {gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = true, no_reduce_shared_memory_bank_conflicts = false>}>
#pipeline_layout = #hal.pipeline.layout<bindings = [
#hal.pipeline.binding<storage_buffer>,
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx940.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx940.mlir
index cedec2d..57ebf56 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx940.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx940.mlir
@@ -8,8 +8,8 @@
// RUN: --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(builtin.module(func.func(iree-llvmgpu-lower-executable-target)))))" \
// RUN: %s | FileCheck %s --check-prefix=MEMORY
-#config = #iree_gpu.lowering_config<{workgroup = [64, 64, 0], reduction = [0, 0, 128], promote_operands = [0, 1]}>
-#translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [256, 1, 1] subgroup_size = 64, {gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = true, no_reduce_shared_memory_bank_conflicts = false>, mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>, subgroup_m_count = 2, subgroup_n_count = 2>}>
+#config = #iree_gpu.lowering_config<{workgroup = [64, 64, 0], reduction = [0, 0, 128], promote_operands = [0, 1], mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>, subgroup_m_count = 2, subgroup_n_count = 2}>
+#translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [256, 1, 1] subgroup_size = 64, {gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = true, no_reduce_shared_memory_bank_conflicts = false>}>
#pipeline_layout = #hal.pipeline.layout<bindings = [
#hal.pipeline.binding<storage_buffer>,
@@ -54,8 +54,8 @@
// -----
-#config = #iree_gpu.lowering_config<{workgroup = [64, 64, 0], reduction = [0, 0, 128], promote_operands = [0, 1]}>
-#translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [256, 1, 1] subgroup_size = 64, {gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = true, no_reduce_shared_memory_bank_conflicts = false>, mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>, subgroup_m_count = 2, subgroup_n_count = 2>}>
+#config = #iree_gpu.lowering_config<{workgroup = [64, 64, 0], reduction = [0, 0, 128], promote_operands = [0, 1], mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>, subgroup_m_count = 2, subgroup_n_count = 2}>
+#translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [256, 1, 1] subgroup_size = 64, {gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = true, no_reduce_shared_memory_bank_conflicts = false>}>
#pipeline_layout = #hal.pipeline.layout<bindings = [
#hal.pipeline.binding<storage_buffer>,
@@ -98,8 +98,8 @@
// -----
-#config = #iree_gpu.lowering_config<{workgroup = [1, 1, 64, 64, 0], reduction = [0, 0, 0, 0, 128], promote_operands = [0, 1]}>
-#translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [256, 1, 1] subgroup_size = 64, {gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = true, no_reduce_shared_memory_bank_conflicts = false>, mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>, subgroup_m_count = 1, subgroup_n_count = 4>}>
+#config = #iree_gpu.lowering_config<{workgroup = [1, 1, 64, 64, 0], reduction = [0, 0, 0, 0, 128], promote_operands = [0, 1], mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>, subgroup_m_count = 1, subgroup_n_count = 4}>
+#translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [256, 1, 1] subgroup_size = 64, {gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = true, no_reduce_shared_memory_bank_conflicts = false>}>
#pipeline_layout = #hal.pipeline.layout<bindings = [
#hal.pipeline.binding<storage_buffer>,
@@ -180,7 +180,7 @@
hal.return %x, %y, %z : index, index, index
}
builtin.module {
- func.func @matmul_multiple_k() attributes {translation_info = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [256, 1, 1] subgroup_size = 64, {mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>, subgroup_m_count = 1, subgroup_n_count = 4>}>} {
+ func.func @matmul_multiple_k() attributes {translation_info = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [256, 1, 1] subgroup_size = 64>} {
%cst = arith.constant 0.000000e+00 : f16
%c0 = arith.constant 0 : index
%0 = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(0) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : !flow.dispatch.tensor<readonly:tensor<2x128x64x2048xf16>>
@@ -190,7 +190,7 @@
%4 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0, 0], sizes = [10, 128, 64, 2048], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<10x128x64x2048xf16>> -> tensor<10x128x64x2048xf16>
%5 = tensor.empty() : tensor<2x10x64x64xf16>
%6 = linalg.fill ins(%cst : f16) outs(%5 : tensor<2x10x64x64xf16>) -> tensor<2x10x64x64xf16>
- %7 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d2, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d4, d3, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%3, %4 : tensor<2x128x64x2048xf16>, tensor<10x128x64x2048xf16>) outs(%6 : tensor<2x10x64x64xf16>) attrs = {lowering_config = #iree_gpu.lowering_config<{reduction = [0, 0, 0, 0, 1, 128], workgroup = [1, 1, 64, 64, 0, 0], promote_operands = [0, 1]}>} {
+ %7 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d2, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d4, d3, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%3, %4 : tensor<2x128x64x2048xf16>, tensor<10x128x64x2048xf16>) outs(%6 : tensor<2x10x64x64xf16>) attrs = {lowering_config = #iree_gpu.lowering_config<{reduction = [0, 0, 0, 0, 1, 128], workgroup = [1, 1, 64, 64, 0, 0], promote_operands = [0, 1], mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>, subgroup_m_count = 1, subgroup_n_count = 4}>} {
^bb0(%in: f16, %in_0: f16, %out: f16):
%8 = arith.mulf %in, %in_0 : f16
%9 = arith.addf %8, %out : f16
@@ -217,8 +217,8 @@
// Basic f8, f8 -> f32 matmul.
-#config = #iree_gpu.lowering_config<{workgroup = [64, 64, 0], reduction = [0, 0, 256], promote_operands = [0, 1]}>
-#translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [256, 1, 1] subgroup_size = 64, {gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = true, no_reduce_shared_memory_bank_conflicts = false>, mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_F32_16x16x32_F8E4M3FNUZ>, subgroup_m_count = 2, subgroup_n_count = 2>}>
+#config = #iree_gpu.lowering_config<{workgroup = [64, 64, 0], reduction = [0, 0, 256], promote_operands = [0, 1], mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x32_F8E4M3FNUZ>, subgroup_m_count = 2, subgroup_n_count = 2}>
+#translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [256, 1, 1] subgroup_size = 64, {gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = true, no_reduce_shared_memory_bank_conflicts = false>}>
#pipeline_layout = #hal.pipeline.layout<bindings = [
#hal.pipeline.binding<storage_buffer>,
@@ -263,8 +263,8 @@
// Basic i8, i8 -> i32 matmul.
-#config = #iree_gpu.lowering_config<{workgroup = [64, 64, 0], reduction = [0, 0, 256], promote_operands = [0, 1]}>
-#translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [256, 1, 1] subgroup_size = 64, {gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = true, no_reduce_shared_memory_bank_conflicts = false>, mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_I32_16x16x32_I8>, subgroup_m_count = 2, subgroup_n_count = 2>}>
+#config = #iree_gpu.lowering_config<{workgroup = [64, 64, 0], reduction = [0, 0, 256], promote_operands = [0, 1], mma_kind = #iree_gpu.mma_layout<MFMA_I32_16x16x32_I8>, subgroup_m_count = 2, subgroup_n_count = 2}>
+#translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [256, 1, 1] subgroup_size = 64, {gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = true, no_reduce_shared_memory_bank_conflicts = false>}>
#pipeline_layout = #hal.pipeline.layout<bindings = [
#hal.pipeline.binding<storage_buffer>,
@@ -309,8 +309,8 @@
// Basic i8, i8 -> i32 matmul_transpose_b.
-#config = #iree_gpu.lowering_config<{workgroup = [64, 64, 0], reduction = [0, 0, 256], promote_operands = [0, 1]}>
-#translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [256, 1, 1] subgroup_size = 64, {gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = true, no_reduce_shared_memory_bank_conflicts = false>, mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_I32_16x16x32_I8>, subgroup_m_count = 2, subgroup_n_count = 2>}>
+#config = #iree_gpu.lowering_config<{workgroup = [64, 64, 0], reduction = [0, 0, 256], promote_operands = [0, 1], mma_kind = #iree_gpu.mma_layout<MFMA_I32_16x16x32_I8>, subgroup_m_count = 2, subgroup_n_count = 2}>
+#translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [256, 1, 1] subgroup_size = 64, {gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = true, no_reduce_shared_memory_bank_conflicts = false>}>
#pipeline_layout = #hal.pipeline.layout<bindings = [
#hal.pipeline.binding<storage_buffer>,
@@ -353,8 +353,8 @@
// -----
-#config = #iree_gpu.lowering_config<{workgroup = [1, 1, 64, 128, 0, 0, 0], reduction = [0, 0, 0, 0, 1, 1, 32], promote_operands = [0, 1]}>
-#translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [256, 1, 1] subgroup_size = 64, {gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = true, no_reduce_shared_memory_bank_conflicts = false>, mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>, subgroup_m_count = 2, subgroup_n_count = 2>}>
+#config = #iree_gpu.lowering_config<{workgroup = [1, 1, 64, 128, 0, 0, 0], reduction = [0, 0, 0, 0, 1, 1, 32], promote_operands = [0, 1], mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>, subgroup_m_count = 2, subgroup_n_count = 2}>
+#translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [256, 1, 1] subgroup_size = 64, {gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = true, no_reduce_shared_memory_bank_conflicts = false>}>
#pipeline_layout = #hal.pipeline.layout<bindings = [
#hal.pipeline.binding<storage_buffer>,
@@ -396,8 +396,8 @@
// -----
-#config = #iree_gpu.lowering_config<{workgroup = [1, 64, 1, 64, 0], reduction = [0, 0, 0, 0, 128], promote_operands = [0, 1]}>
-#translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [256, 1, 1] subgroup_size = 64, {gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = true, no_reduce_shared_memory_bank_conflicts = false>, mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>, subgroup_m_count = 2, subgroup_n_count = 2>}>
+#config = #iree_gpu.lowering_config<{workgroup = [1, 64, 1, 64, 0], reduction = [0, 0, 0, 0, 128], promote_operands = [0, 1], mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>, subgroup_m_count = 2, subgroup_n_count = 2}>
+#translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [256, 1, 1] subgroup_size = 64, {gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = true, no_reduce_shared_memory_bank_conflicts = false>}>
#pipeline_layout = #hal.pipeline.layout<constants = 2, bindings = [
#hal.pipeline.binding<storage_buffer>,
@@ -462,8 +462,8 @@
// -----
-#config = #iree_gpu.lowering_config<{workgroup = [1, 16, 16, 0], reduction = [0, 0, 0, 16], promote_operands = [0, 1]}>
-#translation = #iree_codegen.translation_info<LLVMGPUPadAndVectorDistribute workgroup_size = [64, 1, 1] subgroup_size = 64, {gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = true, no_reduce_shared_memory_bank_conflicts = false>, mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>, subgroup_m_count = 1, subgroup_n_count = 1>}>
+#config = #iree_gpu.lowering_config<{workgroup = [1, 16, 16, 0], reduction = [0, 0, 0, 16], promote_operands = [0, 1], mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>, subgroup_m_count = 1, subgroup_n_count = 1}>
+#translation = #iree_codegen.translation_info<LLVMGPUPadAndVectorDistribute workgroup_size = [64, 1, 1] subgroup_size = 64, {gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = true, no_reduce_shared_memory_bank_conflicts = false>}>
#pipeline_layout = #hal.pipeline.layout<bindings = [
#hal.pipeline.binding<storage_buffer>,
@@ -533,8 +533,8 @@
// -----
-#config = #iree_gpu.lowering_config<{workgroup = [1, 16, 32, 0], reduction = [0, 0, 0, 8], promote_operands = [0, 1]}>
-#translation = #iree_codegen.translation_info<LLVMGPUPadAndVectorDistribute workgroup_size = [128, 1, 1] subgroup_size = 64, {mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_F32_16x16x4_F32>, subgroup_m_count = 1, subgroup_n_count = 2>}>
+#config = #iree_gpu.lowering_config<{workgroup = [1, 16, 32, 0], reduction = [0, 0, 0, 8], promote_operands = [0, 1], mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x4_F32>, subgroup_m_count = 1, subgroup_n_count = 2}>
+#translation = #iree_codegen.translation_info<LLVMGPUPadAndVectorDistribute workgroup_size = [128, 1, 1] subgroup_size = 64>
#pipeline_layout = #hal.pipeline.layout<bindings = [
#hal.pipeline.binding<storage_buffer>,
@@ -604,8 +604,8 @@
// NOTE: This test is not exhaustive of all possible ways the above condition is breaking,
// but rather is an example of a matmul shape from a model that broke our compilation heuristic.
-#config = #iree_gpu.lowering_config<{workgroup = [1, 16, 128, 0], reduction = [0, 0, 0, 128], promote_operands = [0, 1]}>
-#translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [256, 1, 1] subgroup_size = 64, {gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = true, no_reduce_shared_memory_bank_conflicts = false>, mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>, subgroup_m_count = 1, subgroup_n_count = 4>}>
+#config = #iree_gpu.lowering_config<{workgroup = [1, 16, 128, 0], reduction = [0, 0, 0, 128], promote_operands = [0, 1], mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>, subgroup_m_count = 1, subgroup_n_count = 4}>
+#translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [256, 1, 1] subgroup_size = 64, {gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = true, no_reduce_shared_memory_bank_conflicts = false>}>
#pipeline_layout = #hal.pipeline.layout<constants = 3, bindings = [
#hal.pipeline.binding<storage_buffer>,
@@ -657,8 +657,8 @@
// This test ensures that we can generate and decompose the right instructions from V(Virtual) MFMAs.
-#config = #iree_gpu.lowering_config<{workgroup = [64, 64, 0], reduction = [0, 0, 128], promote_operands = [0, 1]}>
-#translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [256, 1, 1] subgroup_size = 64, {gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = true, no_reduce_shared_memory_bank_conflicts = false>, mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<VMFMA_F32_32x32x16_F16>, subgroup_m_count = 2, subgroup_n_count = 2>}>
+#config = #iree_gpu.lowering_config<{workgroup = [64, 64, 0], reduction = [0, 0, 128], promote_operands = [0, 1], mma_kind = #iree_gpu.mma_layout<VMFMA_F32_32x32x16_F16>, subgroup_m_count = 2, subgroup_n_count = 2}>
+#translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [256, 1, 1] subgroup_size = 64, {gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = true, no_reduce_shared_memory_bank_conflicts = false>}>
#pipeline_layout = #hal.pipeline.layout<bindings = [
#hal.pipeline.binding<storage_buffer>,
@@ -718,8 +718,78 @@
// -----
+// This test ensures we can generate correct instructions from V(Virtual) MFMAs that differ only in their read layouts.
+
+#config = #iree_gpu.lowering_config<{workgroup = [32, 32, 0], reduction = [0, 0, 128], promote_operands = [0, 1], mma_kind = #iree_gpu.mma_layout<VMFMA_F32_16x16x32_F8E4M3FNUZ>, subgroup_m_count = 2, subgroup_n_count = 2}>
+#translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [256, 1, 1] subgroup_size = 64, {gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = true, no_reduce_shared_memory_bank_conflicts = false>}>
+
+#pipeline_layout = #hal.pipeline.layout<bindings = [
+ #hal.pipeline.binding<storage_buffer>,
+ #hal.pipeline.binding<storage_buffer>,
+ #hal.pipeline.binding<storage_buffer>
+]>
+hal.executable @virtual_intrinsic_256x256x256_f8E4M3FNUZ_f32 {
+hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb">) {
+ hal.executable.export @virtual_intrinsic_256x256x256_f8E4M3FNUZ_f32 layout(#pipeline_layout) {
+ ^bb0(%arg0: !hal.device, %arg1: index, %arg2 : index):
+ %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2
+ hal.return %x, %y, %z : index, index, index
+ }
+ builtin.module {
+ func.func @virtual_intrinsic_256x256x256_f8E4M3FNUZ_f32() attributes {translation_info = #translation} {
+ %cst = arith.constant 0.000000e+00 : f32
+ %c0 = arith.constant 0 : index
+ %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<256x256xf8E4M3FNUZ>>
+ %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<256x256xf8E4M3FNUZ>>
+ %2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<256x256xf32>>
+ %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<256x256xf8E4M3FNUZ>> -> tensor<256x256xf8E4M3FNUZ>
+ %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<256x256xf8E4M3FNUZ>> -> tensor<256x256xf8E4M3FNUZ>
+ %5 = tensor.empty() : tensor<256x256xf32>
+ %6 = linalg.fill ins(%cst : f32) outs(%5 : tensor<256x256xf32>) -> tensor<256x256xf32>
+ %7 = linalg.matmul {lowering_config = #config} ins(%3, %4 : tensor<256x256xf8E4M3FNUZ>, tensor<256x256xf8E4M3FNUZ>) outs(%6 : tensor<256x256xf32>) -> tensor<256x256xf32>
+ flow.dispatch.tensor.store %7, %2, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : tensor<256x256xf32> -> !flow.dispatch.tensor<writeonly:tensor<256x256xf32>>
+ return
+ }
+ }
+}
+}
+
+// CHECK-LABEL: func @virtual_intrinsic_256x256x256_f8E4M3FNUZ_f32
+// CHECK-DAG: %[[ALLOC_LHS:.+]] = memref.alloc() : memref<32x136xf8E4M3FNUZ, #gpu.address_space<workgroup>>
+// CHECK-DAG: %[[ALLOC_RHS:.+]] = memref.alloc() : memref<128x40xf8E4M3FNUZ, #gpu.address_space<workgroup>>
+// CHECK: scf.for {{.*}} = %c0 to %c256 step %c128 iter_args(%[[ARG:.+]] = {{.*}}) -> (vector<1x1x1x1x4x1xf32>)
+
+// Validate that VMFMAs do 2 interleaved reads and combine them for every MFMA:
+
+// CHECK-COUNT-6: vector.transfer_read %[[ALLOC_LHS]]
+// CHECK: %[[SLICE_LHS_0:.+]] = vector.transfer_read %[[ALLOC_LHS]]
+// CHECK: %[[VECTOR_LHS_0:.+]] = vector.insert_strided_slice %[[SLICE_LHS_0]], %{{.*}}
+// CHECK: %[[SLICE_LHS_1:.+]] = vector.transfer_read %[[ALLOC_LHS]]
+// CHECK: %[[VECTOR_LHS_1:.+]] = vector.insert_strided_slice %[[SLICE_LHS_1]], %[[VECTOR_LHS_0]] {{.*}} : vector<1x4xf8E4M3FNUZ> into vector<1x4x1x2x1x4xf8E4M3FNUZ>
+
+// CHECK-COUNT-6: vector.transfer_read %[[ALLOC_RHS]]
+// CHECK: %[[SLICE_RHS_0:.+]] = vector.transfer_read %[[ALLOC_RHS]]
+// CHECK: %[[VECTOR_RHS_0:.+]] = vector.insert_strided_slice %[[SLICE_RHS_0]], %{{.*}}
+// CHECK: %[[SLICE_RHS_1:.+]] = vector.transfer_read %[[ALLOC_RHS]]
+// CHECK: %[[VECTOR_RHS_1:.+]] = vector.insert_strided_slice %[[SLICE_RHS_1]], %[[VECTOR_RHS_0]] {{.*}} : vector<4x1xf8E4M3FNUZ> into vector<4x1x2x1x4x1xf8E4M3FNUZ>
+
+// CHECK: %[[EXTRACT_LHS:.+]] = vector.extract %[[VECTOR_LHS_1]][{{.*}}, {{.*}}] : vector<1x2x1x4xf8E4M3FNUZ> from vector<1x4x1x2x1x4xf8E4M3FNUZ>
+// CHECK: %[[EXTRACT_RHS:.+]] = vector.extract %[[VECTOR_RHS_1]][{{.*}}, {{.*}}] : vector<2x1x4x1xf8E4M3FNUZ> from vector<4x1x2x1x4x1xf8E4M3FNUZ>
+
+// CHECK: %[[LHS_CAST:.+]] = vector.shape_cast %[[EXTRACT_LHS]] : vector<1x2x1x4xf8E4M3FNUZ> to vector<8xf8E4M3FNUZ>
+// CHECK: %[[RHS_CAST:.+]] = vector.shape_cast %[[EXTRACT_RHS]] : vector<2x1x4x1xf8E4M3FNUZ> to vector<8xf8E4M3FNUZ>
+// CHECK: amdgpu.mfma %[[LHS_CAST]] * %[[RHS_CAST]] + %{{.*}} {blocks = 1 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32}
+// CHECK-SAME: : vector<8xf8E4M3FNUZ>, vector<8xf8E4M3FNUZ>, vector<4xf32>
+
+// Ensure the right number of instructions is being generated.
+// CHECK-COUNT-3: amdgpu.mfma
+
+// CHECK: scf.yield
+
+// -----
+
#config = #iree_gpu.lowering_config<{workgroup = [1, 64, 0, 0, 64], reduction = [0, 0, 0, 64, 0], promote_operands = [0, 1, 2]}>
-#translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [128, 1, 1] subgroup_size = 64, {mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>, subgroup_m_count = 2, subgroup_n_count = 1>}>
+#translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [128, 1, 1] subgroup_size = 64>
#pipeline_layout = #hal.pipeline.layout<bindings = [
#hal.pipeline.binding<storage_buffer>,
@@ -753,8 +823,8 @@
affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>],
lowering_config = #config,
decomposition_config = {
- qk_attrs = {attention_qk_matmul, lowering_config = #iree_gpu.lowering_config<{promote_operands = [0, 1]}>},
- pv_attrs = {attention_pv_matmul, lowering_config = #iree_gpu.lowering_config<{promote_operands = [1]}>}
+ qk_attrs = {attention_qk_matmul, lowering_config = #iree_gpu.lowering_config<{mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>, subgroup_m_count = 2, subgroup_n_count = 1, promote_operands = [0, 1]}>},
+ pv_attrs = {attention_pv_matmul, lowering_config = #iree_gpu.lowering_config<{mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>, subgroup_m_count = 2, subgroup_n_count = 1, promote_operands = [1]}>}
}}
ins(%4, %5, %6, %cst : tensor<20x4096x64xf16>, tensor<20x4096x64xf16>, tensor<20x4096x64xf16>, f16) outs(%7 : tensor<20x4096x64xf16>) {
^bb0(%score: f32):
@@ -792,7 +862,7 @@
// -----
#config = #iree_gpu.lowering_config<{workgroup = [1, 1, 64, 0, 0, 64], reduction = [0, 0, 0, 0, 64, 0], promote_operands = [0, 1, 2]}>
-#translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [128, 1, 1] subgroup_size = 64, {mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>, subgroup_m_count = 2, subgroup_n_count = 1>}>
+#translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [128, 1, 1] subgroup_size = 64>
#pipeline_layout = #hal.pipeline.layout<bindings = [
#hal.pipeline.binding<storage_buffer>,
@@ -827,8 +897,8 @@
affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>],
lowering_config = #config,
decomposition_config = {
- qk_attrs = {attention_qk_matmul, lowering_config = #iree_gpu.lowering_config<{promote_operands = [0, 1]}>},
- pv_attrs = {attention_pv_matmul, lowering_config = #iree_gpu.lowering_config<{promote_operands = [1]}>}
+ qk_attrs = {attention_qk_matmul, lowering_config = #iree_gpu.lowering_config<{mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>, subgroup_m_count = 2, subgroup_n_count = 1, promote_operands = [0, 1]}>},
+ pv_attrs = {attention_pv_matmul, lowering_config = #iree_gpu.lowering_config<{mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>, subgroup_m_count = 2, subgroup_n_count = 1, promote_operands = [1]}>}
}}
ins(%4, %5, %6, %cst : tensor<24x64x4608x128xf16>, tensor<24x4608x128xf16>, tensor<24x4608x128xf16>, f16) outs(%8 : tensor<24x64x4608x128xf16>) {
^bb0(%score: f32):
@@ -860,7 +930,7 @@
// -----
#config = #iree_gpu.lowering_config<{workgroup = [1, 1, 128, 0, 0, 64], reduction = [0, 0, 0, 0, 32, 0], promote_operands = [0, 1, 2]}>
-#translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [256, 1, 1] subgroup_size = 64, {mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_F32_32x32x8_F16>, subgroup_m_count = 4, subgroup_n_count = 1>}>
+#translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [256, 1, 1] subgroup_size = 64>
#pipeline_layout = #hal.pipeline.layout<bindings = [
#hal.pipeline.binding<storage_buffer>,
@@ -895,8 +965,8 @@
affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>],
lowering_config = #config,
decomposition_config = {
- qk_attrs = {attention_qk_matmul, lowering_config = #iree_gpu.lowering_config<{promote_operands = [0, 1]}>},
- pv_attrs = {attention_pv_matmul, lowering_config = #iree_gpu.lowering_config<{promote_operands = [1]}>}
+ qk_attrs = {attention_qk_matmul, lowering_config = #iree_gpu.lowering_config<{mma_kind = #iree_gpu.mma_layout<MFMA_F32_32x32x8_F16>, subgroup_m_count = 4, subgroup_n_count = 1, promote_operands = [0, 1]}>},
+ pv_attrs = {attention_pv_matmul, lowering_config = #iree_gpu.lowering_config<{mma_kind = #iree_gpu.mma_layout<MFMA_F32_32x32x8_F16>, subgroup_m_count = 4, subgroup_n_count = 1, promote_operands = [1]}>}
}}
ins(%4, %5, %6, %cst : tensor<24x64x4608x128xf16>, tensor<24x4608x128xf16>, tensor<24x4608x128xf16>, f16) outs(%8 : tensor<24x64x4608x128xf16>) {
^bb0(%score: f32):
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/config_custom_op.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/config_custom_op.mlir
index 032bd68..2dd57f3 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/config_custom_op.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/config_custom_op.mlir
@@ -34,14 +34,13 @@
}
// CHECK: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[64, 64, 0]]>
// CHECK: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [256, 1, 1] subgroup_size = 64,
-// CHECK-SAME: mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_F32_16x16x4_F32>, subgroup_m_count = 2, subgroup_n_count = 2>
// CHECK: func @custom_op
// CHECK-SAME: translation_info = #[[TRANSLATION]]
// CHECK: iree_linalg_ext.custom_op
// CHECK-SAME: lowering_config = #[[CONFIG]]
// CHECK: ^bb
// CHECK: linalg.matmul
-// CHECK-SAME: lowering_config = #iree_gpu.lowering_config<{promote_operands = [0, 1], reduction = [0, 0, 32], workgroup = [64, 64, 0]}>
+// CHECK-SAME: lowering_config = #iree_gpu.lowering_config<{mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x4_F32>, promote_operands = [0, 1], reduction = [0, 0, 32], subgroup_m_count = 2 : i64, subgroup_n_count = 2 : i64, workgroup = [64, 64, 0]}>
// CHECK: iree_linalg_ext.yield
// -----
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/configure_tensor_layout.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/configure_tensor_layout.mlir
index 6c96d15..937174d 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/configure_tensor_layout.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/configure_tensor_layout.mlir
@@ -2,10 +2,7 @@
#translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute
workgroup_size = [64, 1, 1]
- subgroup_size = 64,
- {mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_F32_32x32x8_F16>,
- subgroup_m_count = 1,
- subgroup_n_count = 1>}>
+ subgroup_size = 64>
#maps = [
affine_map<(m, n, k) -> (m, k)>,
@@ -15,7 +12,9 @@
#traits = {
indexing_maps = #maps,
- iterator_types = ["parallel", "parallel", "reduction"]
+ iterator_types = ["parallel", "parallel", "reduction"],
+ lowering_config = #iree_gpu.lowering_config<{mma_kind = #iree_gpu.mma_layout<MFMA_F32_32x32x8_F16>,
+ subgroup_m_count = 1, subgroup_n_count = 1}>
}
func.func @matmul_96x64x16_mfma(%lhs: tensor<96x16xf16>,
@@ -53,10 +52,7 @@
#translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute
workgroup_size = [64, 1, 1]
- subgroup_size = 64,
- {mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<WMMA_F32_16x16x16_F16>,
- subgroup_m_count = 1,
- subgroup_n_count = 1>}>
+ subgroup_size = 64>
#maps = [
affine_map<(m, n, k) -> (m, k)>,
@@ -66,7 +62,9 @@
#traits = {
indexing_maps = #maps,
- iterator_types = ["parallel", "parallel", "reduction"]
+ iterator_types = ["parallel", "parallel", "reduction"],
+ lowering_config = #iree_gpu.lowering_config<{mma_kind = #iree_gpu.mma_layout<WMMA_F32_16x16x16_F16>,
+ subgroup_m_count = 1 : i64, subgroup_n_count = 1 : i64}>
}
func.func @matmul_96x64x16_wmma(%lhs: tensor<96x16xf16>,
@@ -117,7 +115,9 @@
#traits = {
indexing_maps = #maps,
- iterator_types = ["parallel", "parallel", "reduction"]
+ iterator_types = ["parallel", "parallel", "reduction"],
+ lowering_config = #iree_gpu.lowering_config<{mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>,
+ subgroup_m_count = 4, subgroup_n_count = 1}>
}
func.func @matmul_128x64x16_multi_subgroup(%lhs: tensor<128x16xf16>,
@@ -155,10 +155,7 @@
#translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute
workgroup_size = [64, 1, 1]
- subgroup_size = 64,
- {mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>,
- subgroup_m_count = 2,
- subgroup_n_count = 2>}>
+ subgroup_size = 64>
#maps = [
affine_map<(bm, bn, m, n, k) -> (bm, m, k)>,
@@ -169,7 +166,9 @@
#traits = {
indexing_maps = #maps,
iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"],
- lowering_config = #iree_gpu.lowering_config<{promote_operands = [0]}>
+ lowering_config = #iree_gpu.lowering_config<{promote_operands = [0],
+ mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>,
+ subgroup_m_count = 2, subgroup_n_count = 2}>
}
func.func @packed_matmul_128x128x128(%lhs: tensor<8x16x16xf16>,
@@ -205,13 +204,9 @@
// -----
-// TODO: We shouldn't have to specify mma_schedule here.
#translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute
workgroup_size = [64, 1, 1]
- subgroup_size = 64,
- {mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>,
- subgroup_m_count = 1,
- subgroup_n_count = 1>}>
+ subgroup_size = 64>
func.func @linalg_copy(%in : tensor<16x16x16xf16>) -> tensor<16x16x16xf16>
attributes { translation_info = #translation } {
@@ -233,10 +228,7 @@
#translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute
workgroup_size = [64, 1, 1]
- subgroup_size = 64,
- {mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>,
- subgroup_m_count = 1,
- subgroup_n_count = 1>}>
+ subgroup_size = 64>
#map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2)>
#map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/conv_pipeline_test_rocm.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/conv_pipeline_test_rocm.mlir
deleted file mode 100644
index b33502e..0000000
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/conv_pipeline_test_rocm.mlir
+++ /dev/null
@@ -1,53 +0,0 @@
-// RUN: iree-opt --split-input-file --iree-gpu-test-target=gfx1100 \
-// RUN: --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(builtin.module(iree-llvmgpu-select-lowering-strategy, func.func(iree-llvmgpu-lower-executable-target,canonicalize)))))' \
-// RUN: %s | FileCheck %s
-
-#pipeline_layout = #hal.pipeline.layout<bindings = [
- #hal.pipeline.binding<storage_buffer>,
- #hal.pipeline.binding<storage_buffer>,
- #hal.pipeline.binding<storage_buffer>,
- #hal.pipeline.binding<storage_buffer>
-]>
-hal.executable private @conv_nchw_dispatch_1 {
- hal.executable.variant public @rocm_hsaco_fb target(<"rocm", "rocm-hsaco-fb">) {
- hal.executable.export public @conv_2d_nchw_fchw_2x320x64x64x320x3x3_f16 ordinal(0) layout(#pipeline_layout) attributes {
- translation_info = #iree_codegen.translation_info<LLVMGPUVectorize workgroup_size = [16, 2, 1]>
- } {
- ^bb0(%arg0: !hal.device):
- %x, %y, %z = flow.dispatch.workgroup_count_from_slice
- hal.return %x, %y, %z : index, index, index
- }
- builtin.module {
- func.func @conv_2d_nchw_fchw_2x320x64x64x320x3x3_f16() {
- %cst = arith.constant 0.000000e+00 : f16
- %c0 = arith.constant 0 : index
- %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<2x320x130x130xf16>>
- %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<320x320x3x3xf16>>
- %2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<320xf16>>
- %3 = hal.interface.binding.subspan layout(#pipeline_layout) binding(3) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<2x320x64x64xf16>>
- %4 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0, 0], sizes = [2, 320, 130, 130], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<2x320x130x130xf16>> -> tensor<2x320x130x130xf16>
- %5 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0, 0], sizes = [320, 320, 3, 3], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<320x320x3x3xf16>> -> tensor<320x320x3x3xf16>
- %6 = flow.dispatch.tensor.load %2, offsets = [0], sizes = [320], strides = [1] : !flow.dispatch.tensor<readonly:tensor<320xf16>> -> tensor<320xf16>
- %7 = tensor.empty() : tensor<2x320x64x64xf16>
- %8 = linalg.fill {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[1, 1, 8, 64, 4, 1, 1], [0, 0, 1, 0]]>} ins(%cst : f16) outs(%7 : tensor<2x320x64x64xf16>) -> tensor<2x320x64x64xf16>
- %9 = linalg.conv_2d_nchw_fchw {dilations = dense<1> : vector<2xi64>, lowering_config = #iree_codegen.lowering_config<tile_sizes = [[1, 1, 8, 64, 4, 1, 1], [0, 0, 1, 0]]>, strides = dense<2> : vector<2xi64>} ins(%4, %5 : tensor<2x320x130x130xf16>, tensor<320x320x3x3xf16>) outs(%8 : tensor<2x320x64x64xf16>) -> tensor<2x320x64x64xf16>
- %10 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d1)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%9, %6 : tensor<2x320x64x64xf16>, tensor<320xf16>) outs(%7 : tensor<2x320x64x64xf16>) attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[1, 1, 8, 64, 4, 1, 1], [0, 0, 1, 0]]>} {
- ^bb0(%in: f16, %in_0: f16, %out: f16):
- %11 = arith.addf %in, %in_0 : f16
- linalg.yield %11 : f16
- } -> tensor<2x320x64x64xf16>
- flow.dispatch.tensor.store %10, %3, offsets = [0, 0, 0, 0], sizes = [2, 320, 64, 64], strides = [1, 1, 1, 1] : tensor<2x320x64x64xf16> -> !flow.dispatch.tensor<writeonly:tensor<2x320x64x64xf16>>
- return
- }
- }
- }
-}
-
-// TODO: This test reflects a bug related to how the convolution is bufferized
-// for the LLVMGPUVectorize pipeline, meaning these local memory allocations are
-// not desired. This test should be dropped once the extra buffers have been
-// eliminated.
-
-// CHECK-LABEL: func @conv_2d_nchw_fchw_2x320x64x64x320x3x3_f16
-// CHECK-COUNT-3: memref.alloca() : memref<1x1x1x4xf16, #gpu.address_space<private>>
-// CHECK-COUNT-3: memref.copy %{{.*}}, %{{.*}} : memref<1x1x1x4xf16, #gpu.address_space<private>> to memref<{{.*}} #hal.descriptor_type<storage_buffer>>
diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp
index 83ac234..b14bb08 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp
@@ -105,7 +105,7 @@
return APInt(16, b0 | (b1 << 8));
} else if (b0 == b4 && b1 == b5 && b2 == b6 && b3 == b7) {
// 0xAABBCCDDAABBCCDD : i64 => 0xAABBCCDD : i32
- return APInt(32, b0 | (b1 << 8) | (b2 << 16) | (b3 << 32));
+ return APInt(32, b0 | (b1 << 8) | (b2 << 16) | (b3 << 24));
}
return pattern;
}
diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/test/tensor_folding.mlir b/compiler/src/iree/compiler/Dialect/Stream/IR/test/tensor_folding.mlir
index 08ad4a1..05ef820 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/IR/test/tensor_folding.mlir
+++ b/compiler/src/iree/compiler/Dialect/Stream/IR/test/tensor_folding.mlir
@@ -145,7 +145,7 @@
util.func private @NarrowSplatPatternI64ToI32() -> !stream.resource<*> {
%c100 = arith.constant 100 : index
%pattern = arith.constant 0xAABBCCDDAABBCCDD : i64
- // CHECK: stream.tensor.splat %c12307677_i32
+ // CHECK: stream.tensor.splat %c-1430532899_i32
%0 = stream.tensor.splat %pattern : i64 -> tensor<2x2xf32> in !stream.resource<*>{%c100}
util.return %0 : !stream.resource<*>
}
diff --git a/tests/e2e/matmul/CMakeLists.txt b/tests/e2e/matmul/CMakeLists.txt
index f229434..a70cb98 100644
--- a/tests/e2e/matmul/CMakeLists.txt
+++ b/tests/e2e/matmul/CMakeLists.txt
@@ -1368,7 +1368,7 @@
iree_generated_e2e_runner_test(
NAME
- e2e_matmul_rocm_f16_large_cdna3_mfma
+ e2e_matmul_cdna3_vecdistmfma_f16
TEST_TYPE
matmul
GENERATOR
@@ -1396,7 +1396,7 @@
iree_generated_e2e_runner_test(
NAME
- e2e_matmul_rocm_f32_large_cdna3_mfma
+ e2e_matmul_cdna3_vecdistmfma_f32
TEST_TYPE
matmul
GENERATOR
@@ -1424,7 +1424,7 @@
iree_generated_e2e_runner_test(
NAME
- e2e_matmul_rocm_f16_large_cdna3_mfma_tb
+ e2e_matmul_cdna3_vecdistmfma_tb_f16
TEST_TYPE
matmul
GENERATOR
@@ -1458,7 +1458,7 @@
iree_generated_e2e_runner_test(
NAME
- e2e_matmul_rocm_f8_large_cdna3_mfma
+ e2e_matmul_cdna3_vecdistmfma_f8E4M3FNUZ
TEST_TYPE
matmul
GENERATOR
@@ -1486,7 +1486,7 @@
iree_generated_e2e_runner_test(
NAME
- e2e_matmul_rocm_i8_large_cdna3_mfma_tb
+ e2e_matmul_cdna3_vecdistmfma_i8
TEST_TYPE
matmul
GENERATOR
@@ -1515,35 +1515,7 @@
iree_generated_e2e_runner_test(
NAME
- e2e_matmul_cdna_experimental_dt_f32_f32
- TEST_TYPE
- matmul
- GENERATOR
- "generate_e2e_matmul_tests.py"
- GENERATOR_ARGS
- "--lhs_rhs_type=f32"
- "--acc_type=f32"
- TEST_RUNNER
- iree_tools_testing_e2e_iree-e2e-matmul-test
- TARGET_BACKENDS
- "rocm"
- DRIVERS
- "hip"
- COMPILER_FLAGS
- ${IREE_HIP_TEST_COMPILER_FLAGS}
- "--iree-opt-data-tiling"
- "--iree-global-opt-enable-early-materialization=false"
- LABELS
- "noasan"
- "nomsan"
- "notsan"
- "noubsan"
- "requires-gpu-cdna3"
-)
-
-iree_generated_e2e_runner_test(
- NAME
- e2e_matmul_rocm_f16_cdna3_mfma_data_tiled
+ e2e_matmul_cdna3_dt_f16
TEST_TYPE
matmul
GENERATOR
@@ -1572,7 +1544,7 @@
iree_generated_e2e_runner_test(
NAME
- e2e_matmul_rocm_bf16_cdna3_mfma_data_tiled
+ e2e_matmul_cdna3_dt_bf16
TEST_TYPE
matmul
GENERATOR
@@ -1601,7 +1573,7 @@
iree_generated_e2e_runner_test(
NAME
- e2e_matmul_rocm_i8_cdna3_mfma_data_tiled
+ e2e_matmul_cdna3_dt_i8
TEST_TYPE
matmul
GENERATOR
@@ -1630,7 +1602,7 @@
iree_generated_e2e_runner_test(
NAME
- e2e_matmul_rocm_f32_cdna3_mfma_data_tiled
+ e2e_matmul_cdna3_dt_f32
TEST_TYPE
matmul
GENERATOR
@@ -1659,7 +1631,7 @@
iree_generated_e2e_runner_test(
NAME
- e2e_matmul_rocm_f8E5M2FNUZ_cdna3_mfma_data_tiled
+ e2e_matmul_cdna3_dt_f8E5M2FNUZ
TEST_TYPE
matmul
GENERATOR
@@ -1688,7 +1660,7 @@
iree_generated_e2e_runner_test(
NAME
- e2e_matmul_rocm_f8E4M3FNUZ_cdna3_mfma_data_tiled
+ e2e_matmul_cdna3_dt_f8E4M3FNUZ
TEST_TYPE
matmul
GENERATOR
@@ -1816,33 +1788,4 @@
"requires-gpu-rdna3"
)
-iree_generated_e2e_runner_test(
- NAME
- e2e_matmul_rdna3_experimental_dt_f32_f32
- TEST_TYPE
- matmul
- GENERATOR
- "generate_e2e_matmul_tests.py"
- GENERATOR_ARGS
- "--lhs_rhs_type=f32"
- "--acc_type=f32"
- "--shapes=small"
- TEST_RUNNER
- iree_tools_testing_e2e_iree-e2e-matmul-test
- TARGET_BACKENDS
- "rocm"
- DRIVERS
- "hip"
- COMPILER_FLAGS
- ${IREE_HIP_TEST_COMPILER_FLAGS}
- "--iree-opt-data-tiling"
- "--iree-global-opt-enable-early-materialization=false"
- LABELS
- "noasan"
- "nomsan"
- "notsan"
- "noubsan"
- "requires-gpu-rdna3"
-)
-
endif()
diff --git a/tests/e2e/matmul/generate_e2e_matmul_tests.py b/tests/e2e/matmul/generate_e2e_matmul_tests.py
index dd387f3..9ced22c 100644
--- a/tests/e2e/matmul/generate_e2e_matmul_tests.py
+++ b/tests/e2e/matmul/generate_e2e_matmul_tests.py
@@ -135,9 +135,6 @@
requested_pipeline = self.dispatch_lowering_pass_pipeline
compiler_pipeline = requested_pipeline
- mma_schedule = ""
- if self.mma_schedule is not None:
- mma_schedule = "{}".format(self.mma_schedule)
subgroup_size_str = ""
if self.subgroup_size is not None:
subgroup_size_str = f"subgroup_size = {self.subgroup_size}"
@@ -145,11 +142,13 @@
return (
"#iree_codegen.compilation_info<\n"
f" lowering_config = #iree_gpu.lowering_config<{{"
+ f" mma_kind = #iree_gpu.mma_layout<{self.mma_schedule.intrinsic}>, "
+ f" subgroup_m_count = {self.mma_schedule.m_count}, "
+ f" subgroup_n_count = {self.mma_schedule.n_count}, "
f" workgroup = {self.workgroup_tile}, "
f" reduction = {self.reduction_tile} }}>,\n"
f" translation_info = <{compiler_pipeline} {self.workgroup_size_str()}\n"
- f" {subgroup_size_str},\n"
- f" {{ {mma_schedule} }}>>\n"
+ f" {subgroup_size_str}>>\n"
)
@@ -351,6 +350,8 @@
MMASchedule("VMFMA_F32_16x16x32_F16", 4, 2, 1, 2, 4),
MMASchedule("VMFMA_F32_32x32x16_F16", 1, 1, 1, 1, 1),
MMASchedule("VMFMA_F32_32x32x16_F16", 4, 2, 1, 2, 4),
+ MMASchedule("VMFMA_F32_16x16x32_F8E4M3FNUZ", 1, 1, 1, 1, 1),
+ MMASchedule("VMFMA_F32_16x16x32_F8E4M3FNUZ", 4, 1, 4, 1, 1),
]
elif intrinsic == "WMMA":
schedules = [
@@ -400,6 +401,7 @@
schedule.intrinsic == "VMFMA_F32_16x16x32_F16"
or schedule.intrinsic == "MFMA_I32_16x16x32_I8"
or schedule.intrinsic == "MFMA_F32_16x16x32_F8E4M3FNUZ"
+ or schedule.intrinsic == "VMFMA_F32_16x16x32_F8E4M3FNUZ"
):
wg_tile_m = schedule.m_count * schedule.m_tile_count * 16
wg_tile_n = schedule.n_count * schedule.n_tile_count * 16
diff --git a/tests/external/iree-test-suites/onnx_ops/onnx_ops_cpu_llvm_sync.json b/tests/external/iree-test-suites/onnx_ops/onnx_ops_cpu_llvm_sync.json
index 7242c40..2b845a5 100644
--- a/tests/external/iree-test-suites/onnx_ops/onnx_ops_cpu_llvm_sync.json
+++ b/tests/external/iree-test-suites/onnx_ops/onnx_ops_cpu_llvm_sync.json
@@ -134,7 +134,6 @@
"onnx/node/generated/test_gridsample_volumetric_bilinear_align_corners_1",
"onnx/node/generated/test_gridsample_volumetric_nearest_align_corners_0",
"onnx/node/generated/test_gridsample_volumetric_nearest_align_corners_1",
- "onnx/node/generated/test_if",
"onnx/node/generated/test_image_decoder_decode_bmp_rgb",
"onnx/node/generated/test_image_decoder_decode_jpeg2k_rgb",
"onnx/node/generated/test_image_decoder_decode_jpeg_bgr",
@@ -379,8 +378,6 @@
"onnx/node/generated/test_sce_none_weights_expanded",
"onnx/node/generated/test_sce_none_weights_log_prob",
"onnx/node/generated/test_sce_none_weights_log_prob_expanded",
- "onnx/node/generated/test_shape_clip_start",
- "onnx/node/generated/test_shape_end_negative_1",
"onnx/node/generated/test_slice",
"onnx/node/generated/test_slice_default_steps",
"onnx/node/generated/test_slice_end_out_of_bounds",
@@ -392,6 +389,9 @@
"onnx/node/generated/test_softsign_example",
"onnx/node/generated/test_stft",
"onnx/node/generated/test_stft_with_window",
+ "onnx/node/generated/test_tfidfvectorizer_tf_batch_onlybigrams_skip0",
+ "onnx/node/generated/test_tfidfvectorizer_tf_batch_onlybigrams_skip5",
+ "onnx/node/generated/test_tfidfvectorizer_tf_batch_uniandbigrams_skip5",
"onnx/node/generated/test_training_dropout",
"onnx/node/generated/test_training_dropout_default",
"onnx/node/generated/test_training_dropout_default_mask",
@@ -438,10 +438,7 @@
"onnx/node/generated/test_reduce_min_empty_set",
"onnx/node/generated/test_reduce_sum_empty_set_non_reduced_axis_zero",
"onnx/node/generated/test_resize_downsample_scales_linear_align_corners",
- "onnx/node/generated/test_scan9_sum",
- "onnx/node/generated/test_scan_sum",
"onnx/node/generated/test_shape_end_1",
- "onnx/node/generated/test_shape_start_1",
"onnx/node/generated/test_shape_start_1_end_2",
"onnx/node/generated/test_shape_start_1_end_negative_1",
"onnx/node/generated/test_shape_start_negative_1",
diff --git a/tests/external/iree-test-suites/onnx_ops/onnx_ops_gpu_rocm_rdna3.json b/tests/external/iree-test-suites/onnx_ops/onnx_ops_gpu_rocm_rdna3.json
index a8713b6..11beff6 100644
--- a/tests/external/iree-test-suites/onnx_ops/onnx_ops_gpu_rocm_rdna3.json
+++ b/tests/external/iree-test-suites/onnx_ops/onnx_ops_gpu_rocm_rdna3.json
@@ -138,7 +138,6 @@
"onnx/node/generated/test_gridsample_volumetric_bilinear_align_corners_1",
"onnx/node/generated/test_gridsample_volumetric_nearest_align_corners_0",
"onnx/node/generated/test_gridsample_volumetric_nearest_align_corners_1",
- "onnx/node/generated/test_if",
"onnx/node/generated/test_image_decoder_decode_bmp_rgb",
"onnx/node/generated/test_image_decoder_decode_jpeg2k_rgb",
"onnx/node/generated/test_image_decoder_decode_jpeg_bgr",
@@ -380,8 +379,6 @@
"onnx/node/generated/test_sce_none_weights_expanded",
"onnx/node/generated/test_sce_none_weights_log_prob",
"onnx/node/generated/test_sce_none_weights_log_prob_expanded",
- "onnx/node/generated/test_shape_clip_start",
- "onnx/node/generated/test_shape_end_negative_1",
"onnx/node/generated/test_slice",
"onnx/node/generated/test_slice_default_steps",
"onnx/node/generated/test_slice_end_out_of_bounds",
@@ -488,11 +485,11 @@
"onnx/node/generated/test_reduce_sum_square_default_axes_keepdims_random",
"onnx/node/generated/test_reduce_sum_square_default_axes_keepdims_random_expanded",
"onnx/node/generated/test_resize_downsample_scales_linear_align_corners",
- "onnx/node/generated/test_scan9_sum",
- "onnx/node/generated/test_scan_sum",
"onnx/node/generated/test_shape",
"onnx/node/generated/test_shape_clip_end",
+ "onnx/node/generated/test_shape_clip_start",
"onnx/node/generated/test_shape_end_1",
+ "onnx/node/generated/test_shape_end_negative_1",
"onnx/node/generated/test_shape_example",
"onnx/node/generated/test_shape_start_1",
"onnx/node/generated/test_shape_start_1_end_2",
diff --git a/tests/external/iree-test-suites/onnx_ops/onnx_ops_gpu_vulkan.json b/tests/external/iree-test-suites/onnx_ops/onnx_ops_gpu_vulkan.json
index eb2dee2..eafe678 100644
--- a/tests/external/iree-test-suites/onnx_ops/onnx_ops_gpu_vulkan.json
+++ b/tests/external/iree-test-suites/onnx_ops/onnx_ops_gpu_vulkan.json
@@ -178,7 +178,6 @@
"onnx/node/generated/test_gridsample_volumetric_nearest_align_corners_0",
"onnx/node/generated/test_gridsample_volumetric_nearest_align_corners_1",
"onnx/node/generated/test_gridsample_zeros_padding",
- "onnx/node/generated/test_if",
"onnx/node/generated/test_image_decoder_decode_bmp_rgb",
"onnx/node/generated/test_image_decoder_decode_jpeg2k_rgb",
"onnx/node/generated/test_image_decoder_decode_jpeg_bgr",
@@ -444,8 +443,6 @@
"onnx/node/generated/test_sce_none_weights_expanded",
"onnx/node/generated/test_sce_none_weights_log_prob",
"onnx/node/generated/test_sce_none_weights_log_prob_expanded",
- "onnx/node/generated/test_shape_clip_start",
- "onnx/node/generated/test_shape_end_negative_1",
"onnx/node/generated/test_slice",
"onnx/node/generated/test_slice_default_steps",
"onnx/node/generated/test_slice_end_out_of_bounds",
@@ -534,6 +531,7 @@
"onnx/node/generated/test_einsum_sum",
"onnx/node/generated/test_einsum_transpose",
"onnx/node/generated/test_eyelike_with_dtype",
+ "onnx/node/generated/test_if",
"onnx/node/generated/test_isinf_float16",
"onnx/node/generated/test_isnan_float16",
"onnx/node/generated/test_lstm_with_peepholes",
@@ -593,10 +591,7 @@
"onnx/node/generated/test_reduce_sum_square_default_axes_keepdims_example_expanded",
"onnx/node/generated/test_reduce_sum_square_default_axes_keepdims_random",
"onnx/node/generated/test_reduce_sum_square_default_axes_keepdims_random_expanded",
- "onnx/node/generated/test_scan9_sum",
- "onnx/node/generated/test_scan_sum",
"onnx/node/generated/test_shape_end_1",
- "onnx/node/generated/test_shape_start_1",
"onnx/node/generated/test_shape_start_1_end_2",
"onnx/node/generated/test_shape_start_1_end_negative_1",
"onnx/node/generated/test_shape_start_negative_1",
diff --git a/third_party/llvm-project b/third_party/llvm-project
index 854df1e..8323ca8 160000
--- a/third_party/llvm-project
+++ b/third_party/llvm-project
@@ -1 +1 @@
-Subproject commit 854df1eecb66c5e034a28f18a01cfadba29cfc42
+Subproject commit 8323ca8956aec45713231e06768a0b330f83cce1
diff --git a/third_party/torch-mlir b/third_party/torch-mlir
index 1570c15..6aa4696 160000
--- a/third_party/torch-mlir
+++ b/third_party/torch-mlir
@@ -1 +1 @@
-Subproject commit 1570c151e9491e2ca77687a4fd457ae00491b02e
+Subproject commit 6aa46967b69a01a46d56146250978d08e243e75e