# Copyright 2022 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
import numpy as np

from library import *
from dispatch import *
from matmul import MatmulOperation, MatmulCompilationInfo, CudaMatmulGenerator


class BatchMatmulOperation(MatmulOperation):
    """Data structure to describe a batch matrix multiplication operation."""

    def __init__(self, bmm_shape, lhs, rhs, result):
        assert len(bmm_shape) == 4, "Batch matmul shape must be [batch_count, M, N, K]"
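        # bmm_shape is ordered [batch_count, M, N, K]: the base MatmulOperation
        # receives the 3-D matmul shape [M, N, K] plus the batch count.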
        super().__init__(
            bmm_shape[1:], lhs, rhs, result, bmm_shape[0], 1, OperationKind.BatchMatmul
        )

    def name(self):
        return (
            f"{OperationKindNames[self.operation_kind]}_"
            f"{self.batch_count}x{self.M}x{self.N}x{self.K}_"
            f"{DataTypeName[self.lhs.datatype]}{ShortLayoutTypeName[self.lhs.layout]}_"
            f"{DataTypeName[self.rhs.datatype]}{ShortLayoutTypeName[self.rhs.layout]}_"
            f"{DataTypeName[self.result.datatype]}{ShortLayoutTypeName[self.result.layout]}"
        )

    def lhs_npy_shape(self):
        return f"{self.batch_count}x{super().lhs_npy_shape()}"

    def rhs_npy_shape(self):
        return f"{self.batch_count}x{super().rhs_npy_shape()}"

    def result_npy_shape(self):
        return f"{self.batch_count}x{super().result_npy_shape()}"


class EmitLinalgBatchMatmulDispatch:
    """Emitters for the `linalg.batch_matmul` dispatch."""

    def __init__(self):
        self.mlir_dialect = MlirDialect.Linalg

        self.linalg_row_row_matmul_template = """
// Dispatch linalg.batch_matmul row-row layout
func.func @${operation_name}_${compilation_info_name}(
  %lhs: tensor<${batch_count}x${problem_m}x${problem_k}x${datatype_lhs}>,
  %rhs: tensor<${batch_count}x${problem_k}x${problem_n}x${datatype_rhs}>) -> tensor<${batch_count}x${problem_m}x${problem_n}x${datatype_result}>
{
  %c0 = arith.constant 0.0 : ${datatype_result}
  %init = tensor.empty() : tensor<${batch_count}x${problem_m}x${problem_n}x${datatype_result}>
  %initial_result = linalg.fill ins(%c0 : ${datatype_result}) outs(%init : tensor<${batch_count}x${problem_m}x${problem_n}x${datatype_result}>) -> tensor<${batch_count}x${problem_m}x${problem_n}x${datatype_result}>
  %result = linalg.batch_matmul ${compilation_info_attribute}
            ins(%lhs, %rhs: tensor<${batch_count}x${problem_m}x${problem_k}x${datatype_lhs}>, tensor<${batch_count}x${problem_k}x${problem_n}x${datatype_rhs}>)
            outs(%initial_result: tensor<${batch_count}x${problem_m}x${problem_n}x${datatype_result}>) -> tensor<${batch_count}x${problem_m}x${problem_n}x${datatype_result}>
  return %result : tensor<${batch_count}x${problem_m}x${problem_n}x${datatype_result}>
}
"""

    def emit(self, dispatch):
        """Emits the batch matmul dispatch as linalg MLIR for a single compilation info."""
        compilation_info_attribute_template = (
            """{compilation_info = #${compilation_info_name}}"""
        )
        compilation_info_attribute_str = SubstituteTemplate(
            compilation_info_attribute_template,
            {"compilation_info_name": dispatch.configuration.name()},
        )
        compilation_info_attribute = (
            compilation_info_attribute_str
            if dispatch.configuration.config_type != CompilationConfigType.Default
            else ""
        )

        values = {
            "operation_name": dispatch.operation.name(),
            "compilation_info_attribute": compilation_info_attribute,
            "batch_count": str(dispatch.operation.batch_count),
            "problem_m": str(dispatch.operation.M),
            "problem_n": str(dispatch.operation.N),
            "problem_k": str(dispatch.operation.K),
            "datatype_lhs": DataTypeName[dispatch.operation.lhs.datatype],
            "datatype_rhs": DataTypeName[dispatch.operation.rhs.datatype],
            "datatype_result": DataTypeName[dispatch.operation.result.datatype],
            "compilation_info_name": dispatch.configuration.name(),
        }

        return SubstituteTemplate(self.linalg_row_row_matmul_template, values)
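
    # For example, a 16x512x64x512 f16 dispatch with a non-default
    # configuration emits (roughly; the exact names depend on the operation
    # and configuration name tables):
    #
    #   func.func @batch_matmul_16x512x64x512_f16t_f16t_f16t_<config>(
    #       %lhs: tensor<16x512x512xf16>,
    #       %rhs: tensor<16x512x64xf16>) -> tensor<16x512x64xf16> { ... }
    #
    # with `{compilation_info = #<config>}` attached to linalg.batch_matmul;
    # the default configuration omits the attribute.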


class ReferenceBatchMatmulOp(ReferenceOpInterface):
    """Reference implementation for the batch matmul operation in numpy."""

    def __init__(self, bmm_operation, op_reference_cache_path, dist_lhs, dist_rhs):
        self.bmm_operation = bmm_operation
        self.op_reference_cache_path = op_reference_cache_path

        # Create the reference cache directory (including parents) if needed.
        self.op_reference_cache_path.mkdir(parents=True, exist_ok=True)

        # Problem shape.
        self.batch_count = bmm_operation.batch_count
        self.M = bmm_operation.M
        self.N = bmm_operation.N
        self.K = bmm_operation.K

        # Data types of the input and result matrices.
        self.dtype_lhs = DataTypeNumPyTag[bmm_operation.lhs.datatype]
        self.dtype_rhs = DataTypeNumPyTag[bmm_operation.rhs.datatype]
        self.dtype_result = DataTypeNumPyTag[bmm_operation.result.datatype]

        # Distributions of the input tensors.
        self.dist_lhs = dist_lhs
        self.dist_rhs = dist_rhs

        # Filename for the left-hand-side input tensor.
        self.filename_lhs = (
            "batch_count{batch_count}xm{problem_m}xk{problem_k}_"
            "{tensor_description}_{dist}_lhs.npy".format(
                batch_count=self.batch_count,
                problem_m=self.M,
                problem_k=self.K,
                tensor_description=self.bmm_operation.lhs.name(),
                dist=DistributionName[self.dist_lhs],
            )
        )

        # Filename for the right-hand-side input tensor.
        self.filename_rhs = (
            "batch_count{batch_count}xk{problem_k}xn{problem_n}_"
            "{tensor_description}_{dist}_rhs.npy".format(
                batch_count=self.batch_count,
                problem_k=self.K,
                problem_n=self.N,
                tensor_description=self.bmm_operation.rhs.name(),
                dist=DistributionName[self.dist_rhs],
            )
        )

        # Filename for the reference result tensor.
        self.filename_reference_result = (
            "batch_count{batch_count}xm{problem_m}xn{problem_n}_"
            "{tensor_description}_reference_result.npy".format(
                batch_count=self.batch_count,
                problem_m=self.M,
                problem_n=self.N,
                tensor_description=self.bmm_operation.result.name(),
            )
        )

        # Filepaths for the input and output files.
        self.filepath_lhs = self.op_reference_cache_path.joinpath(self.filename_lhs)
        self.filepath_rhs = self.op_reference_cache_path.joinpath(self.filename_rhs)
        self.filepath_reference_result = self.op_reference_cache_path.joinpath(
            self.filename_reference_result
        )

    def get_input_filepaths(self):
        """Returns the list of input file paths."""
        return [self.filepath_lhs, self.filepath_rhs]

    def get_output_filepaths(self):
        """Returns the list of expected output file paths."""
        return [self.filepath_reference_result]

    def __call__(self):
        """Generates input data, runs the reference np.matmul, and saves the .npy files to the output directory."""
        # Generate the input data as np.array for the batch matmul operation.
        lhs_np_array = get_np_array(
            self.bmm_operation.lhs, (self.batch_count, self.M, self.K), self.dist_lhs
        )
        rhs_np_array = get_np_array(
            self.bmm_operation.rhs, (self.batch_count, self.K, self.N), self.dist_rhs
        )

        # Run the reference np.matmul and generate the result np.array.
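        # np.matmul interprets stacked inputs batch-wise: multiplying
        # (batch, M, K) by (batch, K, N) arrays yields a (batch, M, N) result.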
        result = np.matmul(lhs_np_array, rhs_np_array)

        # Save the input data as np.array for the batch matmul operation.
        np.save(self.filepath_lhs, np.array(lhs_np_array, dtype=self.dtype_lhs))
        np.save(self.filepath_rhs, np.array(rhs_np_array, dtype=self.dtype_rhs))

        # Save the expected result as an np.array.
        np.save(
            self.filepath_reference_result, np.array(result, dtype=self.dtype_result)
        )
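
    # Illustrative usage (a sketch; assumes a pathlib.Path cache directory and
    # that the Distribution enum from library.py provides a random-values
    # member as listed in DistributionName):
    #
    #   ref_op = ReferenceBatchMatmulOp(
    #       bmm_operation, Path("./reference_cache"),
    #       Distribution.Random, Distribution.Random)
    #   ref_op()  # writes the lhs/rhs inputs and the reference result as .npy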


##############################################################################
class CudaBatchMatmulGenerator(CudaMatmulGenerator):
    """Batch matmul dispatch generator class."""

    def __init__(self, args):
        """Initializes the batch matmul dispatch generator."""
        super().__init__(args)

        # Predefined batch matmul problem shapes.
        self.batch_matmul_shapes = [[16, 512, 64, 512]]

        # Batch matmul dispatches collection.
        self.dispatches_collection_list = []

    def _append_matmul_dispatch_collection(
        self, bmm_shapes, data_type, configuration_list
    ):
        """Appends a batch matmul dispatch collection for each problem shape using the given configuration list."""

        # Create a dispatch collection for each problem shape with the configuration list.
        for bmm_shape in bmm_shapes:
            operation = BatchMatmulOperation(
                bmm_shape,
                TensorDescription(data_type[0], LayoutType.RowMajor),
                TensorDescription(data_type[1], LayoutType.RowMajor),
                TensorDescription(data_type[2], LayoutType.RowMajor),
            )

            # Filter out configurations that are not supported by the LLVM GPU CUDA backend.
            supported_configuration_list = self._cuda_supported_configuration_list(
                operation, configuration_list
            )

            # Add the default configuration if enabled.
            if self.args.default_config:
                supported_configuration_list.append(
                    MatmulCompilationInfo(
                        [], [], OperationKind.BatchMatmul, CompilationConfigType.Default
                    )
                )

            # Append the dispatches collection.
            self.dispatches_collection_list.append(
                DispatchCollection(operation, supported_configuration_list)
            )

    def _cuda_matmul_tensor_cores_f16(self):
        """Appends batch matmul dispatches for the GPU TensorCore F16 data type."""
        configuration_list = self._get_matmul_custom_compilation_info_list(
            self.tile_descriptions_tensor_cores_f16,
            self.translation_infos,
            OperationKind.BatchMatmul,
        )
        data_type = [DataType.f16, DataType.f16, DataType.f16]
        self._append_matmul_dispatch_collection(
            self.batch_matmul_shapes, data_type, configuration_list
        )

    def _cuda_matmul_tensor_cores_f32(self):
        """Appends batch matmul dispatches for the GPU TensorCore F32 data type."""
        configuration_list = self._get_matmul_custom_compilation_info_list(
            self.tile_descriptions_tensor_cores_f32,
            self.translation_infos,
            OperationKind.BatchMatmul,
        )
        data_type = [DataType.f32, DataType.f32, DataType.f32]
        self._append_matmul_dispatch_collection(
            self.batch_matmul_shapes, data_type, configuration_list
        )

    def generate(self):
        """Generates and returns the list of batch matmul dispatch collections."""
        self._cuda_matmul_tensor_cores_f16()
        self._cuda_matmul_tensor_cores_f32()
        return self.dispatches_collection_list
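
    # Illustrative end-to-end flow (a sketch; `args` is the parsed command-line
    # namespace expected by CudaMatmulGenerator and, for this class, must at
    # least carry `default_config`):
    #
    #   generator = CudaBatchMatmulGenerator(args)
    #   collections = generator.generate()  # list of DispatchCollection objects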