blob: 7e53fe5c572314cf136eb569ed7d197455096a95 [file] [log] [blame]
# Copyright 2022 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from library import *
from dispatch import *
from matmul import MatmulOperation, MatmulCompilationInfo, CudaMatmulGenerator
class BatchMatmulOperation(MatmulOperation):
  """Describes a batched matrix multiplication: batch_count x (M x N x K)."""

  def __init__(self, bmm_shape, lhs, rhs, result):
    """Builds the operation from a 4D [batch, M, N, K] shape and operand descriptions."""
    assert len(bmm_shape) == 4, "Batch matmul shape must be 4D"
    batch_count, *mnk_shape = bmm_shape
    super().__init__(mnk_shape, lhs, rhs, result, batch_count, 1,
                     OperationKind.BatchMatmul)

  def name(self):
    """Procedural name encoding kind, problem shape, and operand type/layout."""
    shape_part = f'{self.batch_count}x{self.M}x{self.N}x{self.K}'
    operand_parts = [
        f'{DataTypeName[tensor.datatype]}{ShortLayoutTypeName[tensor.layout]}'
        for tensor in (self.lhs, self.rhs, self.result)
    ]
    return '_'.join([OperationKindNames[self.operation_kind], shape_part] +
                    operand_parts)

  def lhs_npy_shape(self):
    """LHS npy shape string with the leading batch dimension prepended."""
    return '{}x{}'.format(self.batch_count, super().lhs_npy_shape())

  def rhs_npy_shape(self):
    """RHS npy shape string with the leading batch dimension prepended."""
    return '{}x{}'.format(self.batch_count, super().rhs_npy_shape())

  def result_npy_shape(self):
    """Result npy shape string with the leading batch dimension prepended."""
    return '{}x{}'.format(self.batch_count, super().result_npy_shape())
class EmitLinalgBatchMatmulDispatch:
  """Emitters for the `linalg.batch_matmul` dispatch."""

  def __init__(self):
    # Target MLIR dialect for the emitted dispatch.
    self.mlir_dialect = MlirDialect.Linalg

    # MLIR template for a row-major x row-major batch matmul function.
    # The ${...} placeholders are substituted in emit() via SubstituteTemplate.
    self.linalg_row_row_matmul_template = """
// Dispatch linalg.batch_matmul row-row layout
func.func @${operation_name}_${compilation_info_name}(
%lhs: tensor<${batch_count}x${problem_m}x${problem_k}x${datatype_lhs}>,
%rhs: tensor<${batch_count}x${problem_k}x${problem_n}x${datatype_rhs}>) -> tensor<${batch_count}x${problem_m}x${problem_n}x${datatype_result}>
{
%c0 = arith.constant 0.0 : ${datatype_result}
%init = tensor.empty() : tensor<${batch_count}x${problem_m}x${problem_n}x${datatype_result}>
%inital_result = linalg.fill ins(%c0 : ${datatype_result}) outs(%init : tensor<${batch_count}x${problem_m}x${problem_n}x${datatype_result}>) -> tensor<${batch_count}x${problem_m}x${problem_n}x${datatype_result}>
%result = linalg.batch_matmul ${compilation_info_attribute}
ins(%lhs, %rhs: tensor<${batch_count}x${problem_m}x${problem_k}x${datatype_lhs}>, tensor<${batch_count}x${problem_k}x${problem_n}x${datatype_rhs}>)
outs(%inital_result: tensor<${batch_count}x${problem_m}x${problem_n}x${datatype_result}>) -> tensor<${batch_count}x${problem_m}x${problem_n}x${datatype_result}>
return %result : tensor<${batch_count}x${problem_m}x${problem_n}x${datatype_result}>
}
"""

  def emit(self, dispatch):
    """Emit the matmul operation in the MLIR dialect for a single compilation info"""
    # Attribute that pins the dispatch to a named #compilation_info entry.
    compilation_info_attribute_template = """{compilation_info = #${compilation_info_name}}"""
    compilation_info_attribute_str = SubstituteTemplate(
        compilation_info_attribute_template,
        {'compilation_info_name': dispatch.configuration.name()})
    # Default configurations carry no explicit attribute: the compiler
    # chooses the lowering configuration on its own.
    compilation_info_attribute = compilation_info_attribute_str \
        if dispatch.configuration.config_type != CompilationConfigType.Default else ""

    # Substitution map for the function template above.
    values = {
        'operation_name': dispatch.operation.name(),
        'compilation_info_attribute': compilation_info_attribute,
        'batch_count': str(dispatch.operation.batch_count),
        'problem_m': str(dispatch.operation.M),
        'problem_n': str(dispatch.operation.N),
        'problem_k': str(dispatch.operation.K),
        'datatype_lhs': DataTypeName[dispatch.operation.lhs.datatype],
        'datatype_rhs': DataTypeName[dispatch.operation.rhs.datatype],
        'datatype_result': DataTypeName[dispatch.operation.result.datatype],
        'compilation_info_name': dispatch.configuration.name()
    }
    return SubstituteTemplate(self.linalg_row_row_matmul_template, values)
class ReferenceBatchMatmulOp(ReferenceOpInterface):
  """Reference implementation for the batch matmul operation in numpy.

  Generates the input tensors, computes the expected result with np.matmul,
  and saves all three tensors as .npy files under `op_reference_cache_path`.
  """

  def __init__(self, bmm_operation, op_reference_cache_path, dist_lhs,
               dist_rhs):
    self.bmm_operation = bmm_operation
    self.op_reference_cache_path = op_reference_cache_path

    # `exist_ok=True` removes the exists()/mkdir() TOCTOU race when several
    # generator processes share the cache; `parents=True` creates missing
    # intermediate directories instead of raising FileNotFoundError.
    self.op_reference_cache_path.mkdir(parents=True, exist_ok=True)

    # Problem shape.
    self.batch_count = bmm_operation.batch_count
    self.M = bmm_operation.M
    self.N = bmm_operation.N
    self.K = bmm_operation.K

    # Numpy data types for the input and result matrices.
    self.dtype_lhs = DataTypeNumPyTag[bmm_operation.lhs.datatype]
    self.dtype_rhs = DataTypeNumPyTag[bmm_operation.rhs.datatype]
    self.dtype_result = DataTypeNumPyTag[bmm_operation.result.datatype]

    # Distributions used to generate the random input tensors.
    self.dist_lhs = dist_lhs
    self.dist_rhs = dist_rhs

    # Filename for the left hand side input tensor.
    self.filename_lhs = (
        f"batch_count{self.batch_count}xm{self.M}xk{self.K}_"
        f"{self.bmm_operation.lhs.name()}_"
        f"{DistributionName[self.dist_lhs]}_lhs.npy")

    # Filename for the right hand side input tensor.
    self.filename_rhs = (
        f"batch_count{self.batch_count}xk{self.K}xn{self.N}_"
        f"{self.bmm_operation.rhs.name()}_"
        f"{DistributionName[self.dist_rhs]}_rhs.npy")

    # Filename for the reference result tensor.
    self.filename_reference_result = (
        f"batch_count{self.batch_count}xm{self.M}xn{self.N}_"
        f"{self.bmm_operation.result.name()}_reference_result.npy")

    # Filepaths for input and output files.
    self.filepath_lhs = self.op_reference_cache_path.joinpath(self.filename_lhs)
    self.filepath_rhs = self.op_reference_cache_path.joinpath(self.filename_rhs)
    self.filepath_reference_result = self.op_reference_cache_path.joinpath(
        self.filename_reference_result)

  def get_input_filepaths(self):
    """Returns the list of input file paths."""
    return [self.filepath_lhs, self.filepath_rhs]

  def get_output_filepaths(self):
    """Returns the list of expected output file paths."""
    return [self.filepath_reference_result]

  def __call__(self):
    """Generates input data, runs reference numpy.matmul, and saves npy files to the output directory."""
    # Generate the input data as np.array for the matmul operation.
    lhs_np_array = get_np_array(self.bmm_operation.lhs,
                                (self.batch_count, self.M, self.K),
                                self.dist_lhs)
    rhs_np_array = get_np_array(self.bmm_operation.rhs,
                                (self.batch_count, self.K, self.N),
                                self.dist_rhs)

    # np.matmul broadcasts over the leading (batch) dimension for 3-D inputs.
    result = np.matmul(lhs_np_array, rhs_np_array)

    # Save the input data, cast to the operation's declared dtypes.
    np.save(self.filepath_lhs, np.array(lhs_np_array, dtype=self.dtype_lhs))
    np.save(self.filepath_rhs, np.array(rhs_np_array, dtype=self.dtype_rhs))

    # Save the expected result as an np.array.
    np.save(self.filepath_reference_result,
            np.array(result, dtype=self.dtype_result))
##############################################################################
class CudaBatchMatmulGenerator(CudaMatmulGenerator):
  """Generates batch matmul dispatches for the CUDA backend."""

  def __init__(self, args):
    """Initializes the batch matmul dispatch generator."""
    super().__init__(args)
    # Predefined batch matmul problem shapes ([batch, M, N, K]).
    self.batch_matmul_shapes = [[16, 512, 64, 512]]
    # Accumulated collections of generated batch matmul dispatches.
    self.dispatches_collection_list = []

  def _append_matmul_dispatch_collection(self, bmm_shapes, data_type,
                                         configuration_list):
    """Appends one DispatchCollection per problem shape using the given configurations."""
    row_major = LayoutType.RowMajor
    for shape in bmm_shapes:
      operation = BatchMatmulOperation(
          shape,
          TensorDescription(data_type[0], row_major),
          TensorDescription(data_type[1], row_major),
          TensorDescription(data_type[2], row_major))

      # Keep only configurations the LLVM GPU CUDA backend supports.
      supported_configurations = self._cuda_supported_configuration_list(
          operation, configuration_list)

      # Optionally include the compiler-default configuration.
      if self.args.default_config:
        supported_configurations.append(
            MatmulCompilationInfo([], [], OperationKind.BatchMatmul,
                                  CompilationConfigType.Default))

      self.dispatches_collection_list.append(
          DispatchCollection(operation, supported_configurations))

  def _cuda_matmul_tensor_cores_f16(self):
    """Appends batch matmul dispatches for the GPU TensorCore F16 data type."""
    configurations = self._get_matmul_custom_compilation_info_list(
        self.tile_descriptions_tensor_cores_f16, self.translation_infos,
        OperationKind.BatchMatmul)
    self._append_matmul_dispatch_collection(self.batch_matmul_shapes,
                                            [DataType.f16] * 3,
                                            configurations)

  def _cuda_matmul_tensor_cores_f32(self):
    """Appends batch matmul dispatches for the GPU TensorCore F32 data type."""
    configurations = self._get_matmul_custom_compilation_info_list(
        self.tile_descriptions_tensor_cores_f32, self.translation_infos,
        OperationKind.BatchMatmul)
    self._append_matmul_dispatch_collection(self.batch_matmul_shapes,
                                            [DataType.f32] * 3,
                                            configurations)