# Copyright 2023 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
| |
| from library import * |
| from dispatch import * |
| from matmul import MatmulOperation, MatmulCompilationInfo, CudaMatmulGenerator |
| |
| |
class CudaSplitKMatmulGenerator(CudaMatmulGenerator):
  """SplitK matmul dispatch generator for the LLVM GPU CUDA backend.

  Enumerates split-k matmul operations over a predefined set of problem
  shapes and split-k slice counts, pairing each operation with the
  compilation configurations supported by the CUDA backend.
  """

  def __init__(self, args):
    """Initializes the splitK matmul generator.

    Args:
      args: parsed command-line arguments; this class reads
        `args.default_config` (and whatever the base class consumes).
    """
    super().__init__(args)

    # Predefined matmul problem shapes for splitK matmul.
    # Presumably [M, N, K] — the large last dimension (12288) is what
    # motivates splitting the reduction; TODO confirm against MatmulOperation.
    self.matmul_shapes = [[128, 128, 12288]]

    # Predefined split_k_slices list for splitK matmul.
    self.split_k_slices = [2, 4, 16, 18]

    # SplitK matmul dispatches collection list, filled by the
    # _cuda_matmul_tensor_cores_* methods below.
    self.dispatches_collection_list = []

  def _append_matmul_dispatch_collection(self, matmul_shapes, split_k_slices,
                                         data_type, configuration_list):
    """Appends one dispatch collection per matmul_shape x split_k_slice pair.

    Args:
      matmul_shapes: list of matmul problem shapes to enumerate.
      split_k_slices: list of split-k slice counts to enumerate.
      data_type: three-element list of DataType for [lhs, rhs, result].
      configuration_list: candidate MatmulCompilationInfo list to filter.
    """
    # Create dispatches collection for each matmul_shape x split_k_slice x
    # configuration list.
    for matmul_shape in matmul_shapes:
      for split_k_slice in split_k_slices:
        operation = MatmulOperation(
            matmul_shape,
            TensorDescription(data_type[0], LayoutType.RowMajor),
            TensorDescription(data_type[1], LayoutType.RowMajor),
            TensorDescription(data_type[2], LayoutType.RowMajor),
            1,  # batch_count
            split_k_slice,
            OperationKind.SplitkMatmul)

        # Filter out configurations that are not supported by the LLVM GPU
        # CUDA backend.
        supported_configuration_list = self._cuda_supported_configuration_list(
            operation, configuration_list)

        # Add default configuration if enabled.
        # NOTE(review): tagged OperationKind.SplitkMatmul for consistency with
        # the operation constructed above; the original passed
        # OperationKind.Matmul here, a likely copy-paste from the base
        # matmul generator.
        if self.args.default_config:
          supported_configuration_list.append(
              MatmulCompilationInfo([], [], OperationKind.SplitkMatmul,
                                    CompilationConfigType.Default))

        # Append the dispatch collection.
        self.dispatches_collection_list.append(
            DispatchCollection(operation, supported_configuration_list))

  def _cuda_matmul_tensor_cores_f16(self):
    """Appends split-k matmul dispatches for GPU TensorCore F16 data type."""
    configuration_list = self._get_matmul_custom_compilation_info_list(
        self.tile_descriptions_tensor_cores_f16, self.translation_infos,
        OperationKind.SplitkMatmul)
    data_type = [DataType.f16, DataType.f16, DataType.f16]
    self._append_matmul_dispatch_collection(self.matmul_shapes,
                                            self.split_k_slices, data_type,
                                            configuration_list)

  def _cuda_matmul_tensor_cores_f32(self):
    """Appends split-k matmul dispatches for GPU TensorCore F32 data type."""
    configuration_list = self._get_matmul_custom_compilation_info_list(
        self.tile_descriptions_tensor_cores_f32, self.translation_infos,
        OperationKind.SplitkMatmul)
    data_type = [DataType.f32, DataType.f32, DataType.f32]
    self._append_matmul_dispatch_collection(self.matmul_shapes,
                                            self.split_k_slices, data_type,
                                            configuration_list)