blob: 8fa126354a38f7a2d580d62157f7125744787181 [file] [log] [blame]
# Copyright 2022 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
import enum, shutil, pickle
from library import *
from matmul import *
from batch_matmul import *
from split_k_matmul import *
from pathlib import Path
###############################################################################
class EmitSourceMLIR:
    """Emitters for the operation MLIR source files.

    Used as a context manager. Entering opens
    `<operation_path>/<operation name>.mlir` and writes a header comment
    followed by every configuration's compilation-info text. `emit()` then
    appends one dispatch per (operation, configuration) pair. Exiting closes
    the file.
    """

    def __init__(self, operation_path, dispatch_collection):
        """Selects the emitters for the collection's operation kind.

        Args:
          operation_path: Directory (a pathlib.Path) the .mlir file is
            written into.
          dispatch_collection: Collection providing `operation`,
            `configuration_list`, and `get_dispatches()`.

        Raises:
          KeyError: If the operation kind has no registered emitter.
        """
        self.operation_path = operation_path
        self.dispatch_collection = dispatch_collection
        self.operation = dispatch_collection.operation
        self.operation_kind = self.operation.operation_kind
        self.configuration_list = dispatch_collection.configuration_list
        self.operation_filepath = self.operation_path.joinpath(
            self.operation.name()).with_suffix(".mlir")
        # The file handle is created in __enter__.
        self.operation_file = None

        # Operation kind -> emitter class for the configuration attributes.
        mlir_configuration_emitter = {
            OperationKind.Matmul: EmitMatmulCompilationInfo,
            OperationKind.SplitkMatmul: EmitMatmulCompilationInfo,
            OperationKind.BatchMatmul: EmitMatmulCompilationInfo,
        }
        self.configuration_emitter = mlir_configuration_emitter[
            self.operation_kind]()

        # Operation kind -> emitter class for the dispatch functions.
        mlir_dispatch_emitter = {
            OperationKind.Matmul: EmitLinalgMatmulDispatch,
            OperationKind.SplitkMatmul: EmitLinalgMatmulDispatch,
            OperationKind.BatchMatmul: EmitLinalgBatchMatmulDispatch,
        }
        self.dispatch_emitter = mlir_dispatch_emitter[self.operation_kind]()

    def __enter__(self):
        """Opens the .mlir file and writes the header and configurations."""
        self.operation_file = open(self.operation_filepath, "w",
                                   encoding="utf-8")
        try:
            # End the header comment with a newline. Otherwise, if the emitted
            # configuration text does not start with one, it would land on
            # the `//` comment line.
            self.operation_file.write(
                f'// Filename: {self.operation_filepath}\n')
            # Emit all the configuration attribute tags.
            for configuration in self.configuration_list:
                self.operation_file.write(
                    self.configuration_emitter.emit(configuration))
        except BaseException:
            # __exit__ is not invoked when __enter__ raises, so close the
            # handle here to avoid leaking it.
            self.operation_file.close()
            raise
        return self

    def emit(self):
        """Emit the op func.func for each dispatch (operation + configuration)."""
        for dispatch in self.dispatch_collection.get_dispatches():
            print(
                f" Emitting tuning configuration : {dispatch.configuration.name()}"
            )
            self.operation_file.write(self.dispatch_emitter.emit(dispatch))

    def __exit__(self, exc_type, exc_value, traceback):
        """Closes the .mlir file. Exceptions are not suppressed."""
        self.operation_file.close()
###############################################################################
class Manifest:
    """Manifest collects, filters, and stores dispatches in a data structure.

    Manifest organizes the dispatches in a dictionary keyed by operation kind
    (`dispatch_collection_map`).

    Usage:
      1. Create a manifest object with the command line arguments.
      2(a). Generate dispatches, append them to the manifest, and serialize
            them into a file.
      2(b). Load dispatches from a serialized file.

    ```python
    # generator.py usage:
    manifest = Manifest(args)
    manifest.initialize()

    # compile.py or profile.py usage:
    manifest = Manifest(args)
    manifest.load()
    ```
    """

    def __init__(self, args):
        """Configures the dispatch filters and output paths.

        Args:
          args: Parsed command line arguments. Reads `operation_kind` and
            `dispatches` (each either 'all' or a comma-separated list),
            `generated_dir`, and `mlir_dialect`.
        """
        self.args = args

        # Dictionary of operation kind to a list of dispatch collections. We
        # initialize the dictionary during the generation of dispatches and
        # serialize it to a file. The serialized file is used to load the
        # dispatches for compilation and profiling.
        # Datatype: OperationKind -> [DispatchCollection]
        self.dispatch_collection_map = {}

        # For operation kind-based filtering of dispatches (empty: no filter).
        self.operation_kind_enabled = []

        # For name-based filtering of dispatches (empty: no filter).
        self.dispatch_names = []
        self.ignore_dispatch_names = []

        if args.operation_kind != 'all':
            # Split once rather than once per candidate operation kind.
            requested_kinds = set(args.operation_kind.split(','))
            operations_kind_list = [
                OperationKind.Matmul,
                OperationKind.SplitkMatmul,
                OperationKind.BatchMatmul,
            ]
            self.operation_kind_enabled = [
                x for x in operations_kind_list
                if OperationKindNames[x] in requested_kinds
            ]

        if args.dispatches != 'all':
            # Drop empty entries produced by stray or trailing commas.
            self.dispatch_names = [x for x in args.dispatches.split(',') if x]

        # Paths to the generated directory (e.g. `./generated/linalg`).
        self.generated_path = Path(self.args.generated_dir, 'generated',
                                   self.args.mlir_dialect)

        # Create the directories in self.generated_path if they do not exist.
        self.generated_path.mkdir(parents=True, exist_ok=True)

        # Path to the serialized file.
        self.serialized_file_path = self.generated_path.joinpath(
            'serialized_file.pkl')

    def _filter_string_matches(self, filter_string, haystack):
        """Returns True if every '*'-separated piece of filter_string appears
        in haystack, in order and without overlapping.

        The match is not anchored. For example, "a*c" matches "xabcx".
        """
        substrings = filter_string.split('*')
        for sub in substrings:
            idx = haystack.find(sub)
            if idx < 0:
                return False
            # Continue searching after this piece so pieces match in order.
            haystack = haystack[idx + len(sub):]
        return True

    def is_enabled(self, dispatch):
        """Returns True if the dispatch passes every enabled filter.

        Filters combine with AND. When an operation kind filter is set, the
        dispatch's kind must be in it. When name filters are set, the dispatch
        name (operation + configuration) must match at least one of them.
        """
        # Operation kind filter. An empty list means every kind is enabled.
        operation_kind = dispatch.operation.operation_kind
        if (self.operation_kind_enabled and
                operation_kind not in self.operation_kind_enabled):
            return False

        # Name filter. An empty list means every name is enabled. Otherwise
        # the name must match at least one '*'-wildcard pattern.
        if self.dispatch_names:
            name = dispatch.name()
            return any(
                self._filter_string_matches(pattern, name)
                for pattern in self.dispatch_names)

        return True

    def append_dispatch_collection(self, dispatch_collection):
        """Filters one DispatchCollection and appends it to the manifest.

        The collection is appended only if at least one of its dispatches
        passes `is_enabled`.
        """
        operation_kind = dispatch_collection.operation.operation_kind
        self.dispatch_collection_map.setdefault(operation_kind, [])

        # Filter dispatches based on the filter criteria.
        filtered_dispatch_collection = DispatchCollection(
            dispatch_collection.operation, [])
        for dispatch in dispatch_collection.get_dispatches():
            if self.is_enabled(dispatch):
                filtered_dispatch_collection.append(dispatch)

        # Only append the filtered collection if it has an unfiltered
        # configuration.
        if filtered_dispatch_collection.configuration_list:
            self.dispatch_collection_map[operation_kind].append(
                filtered_dispatch_collection)

    def append(self, dispatch_collection_list):
        """Appends each DispatchCollection in the list to the manifest."""
        for dispatch_collection in dispatch_collection_list:
            self.append_dispatch_collection(dispatch_collection)

    def initialize(self):
        """Initializes the manifest by generating dispatches for the supported
        operations, then serializes it."""
        self.append(CudaMatmulGenerator(self.args).generate())
        self.append(CudaSplitKMatmulGenerator(self.args).generate())
        self.append(CudaBatchMatmulGenerator(self.args).generate())

        # Serialize the initialized manifest state.
        self.dump()

    def dump(self):
        """Serialize (dump) the self.dispatch_collection_map to a pickle file."""
        with open(self.serialized_file_path, 'wb') as f:
            pickle.dump(self.dispatch_collection_map, f)

    def load(self):
        """Deserialize (load) the self.dispatch_collection_map from a pickle file.

        Raises:
          ValueError: If the serialized file does not exist.
        """
        if not self.serialized_file_path.exists():
            raise ValueError(f"Could not find : {self.serialized_file_path}")

        # NOTE: pickle can execute arbitrary code while loading. This file is
        # expected to be produced locally by `dump()`; never point it at
        # untrusted input.
        with open(self.serialized_file_path, 'rb') as load_file:
            self.dispatch_collection_map = pickle.load(load_file)

    def emit(self):
        """Emits the operations in the Manifest to the build directory as MLIR
        source files.

        The operations are emitted in the dialect specified by the
        `mlir_dialect` flag.
        """
        # For each operation kind, create a directory and emit the operations
        # with all their configurations into separate per-operation
        # directories.
        for (operation_kind,
             dispatch_collection_list) in self.dispatch_collection_map.items():
            operation_kind_path = self.generated_path.joinpath(
                OperationKindNames[operation_kind])
            operation_kind_path.mkdir(parents=True, exist_ok=True)

            for dispatch_collection in dispatch_collection_list:
                operation_path = operation_kind_path.joinpath(
                    dispatch_collection.operation.name())
                operation_path.mkdir(parents=True, exist_ok=True)

                with EmitSourceMLIR(operation_path,
                                    dispatch_collection) as emit_mlir_source:
                    # Reuse the path the emitter already computed.
                    print(
                        f"[Generating]: {emit_mlir_source.operation_filepath}")
                    # Emit the MLIR source file for the operation with all of
                    # its configurations.
                    emit_mlir_source.emit()