blob: 937f44f1a53460128cf27bc55d44f6a42fb74688 [file] [log] [blame]
# Copyright 2023 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from library import *
from dispatch import *
from matmul import MatmulOperation, MatmulCompilationInfo, CudaMatmulGenerator
class CudaSplitKMatmulGenerator(CudaMatmulGenerator):
    """Dispatch generator for split-K matmul operations on CUDA.

    Extends ``CudaMatmulGenerator`` with a fixed list of problem shapes and
    split-K slice counts, and gathers the resulting ``DispatchCollection``
    objects in ``self.dispatches_collection_list``.
    """

    def __init__(self, args):
        """Set up the predefined split-K problem space.

        Args:
            args: Generator arguments, forwarded unchanged to the base class
                constructor.
        """
        super().__init__(args)

        # Problem shapes for which split-K dispatches are generated.
        # presumably each entry is [M, N, K] -- TODO confirm against
        # MatmulOperation.
        self.matmul_shapes = [[128, 128, 12288]]

        # Slice counts tried for every problem shape.
        # NOTE(review): 18 does not evenly divide 12288; confirm it is
        # intentional.
        self.split_k_slices = [2, 4, 16, 18]

        # Filled in by _append_matmul_dispatch_collection().
        self.dispatches_collection_list = []

    def _append_matmul_dispatch_collection(self, matmul_shapes, split_k_slices,
                                           data_type, configuration_list):
        """Append one dispatch collection per (shape, slice count) pair.

        Args:
            matmul_shapes: Iterable of problem shapes, each passed as the
                first argument of ``MatmulOperation``.
            split_k_slices: Iterable of split-K slice counts.
            data_type: Indexable with at least three element types; entries
                0, 1 and 2 become the three row-major tensor descriptions
                handed to ``MatmulOperation``, in that order.
            configuration_list: Candidate compilation configurations; they
                are filtered for each operation before use.
        """
        for shape in matmul_shapes:
            for slice_count in split_k_slices:
                # All three tensors use a row-major layout.
                first_tensor = TensorDescription(data_type[0],
                                                 LayoutType.RowMajor)
                second_tensor = TensorDescription(data_type[1],
                                                  LayoutType.RowMajor)
                third_tensor = TensorDescription(data_type[2],
                                                 LayoutType.RowMajor)
                batch_count = 1
                split_k_operation = MatmulOperation(
                    shape, first_tensor, second_tensor, third_tensor,
                    batch_count, slice_count, OperationKind.SplitkMatmul)

                # Keep only the configurations the LLVM GPU CUDA backend
                # supports for this operation.
                usable_configurations = (
                    self._cuda_supported_configuration_list(
                        split_k_operation, configuration_list))

                # Optionally add the default configuration as well.
                # NOTE(review): it is tagged OperationKind.Matmul rather
                # than OperationKind.SplitkMatmul -- confirm this is
                # intended.
                if self.args.default_config:
                    default_configuration = MatmulCompilationInfo(
                        [], [], OperationKind.Matmul,
                        CompilationConfigType.Default)
                    usable_configurations.append(default_configuration)

                collection = DispatchCollection(split_k_operation,
                                                usable_configurations)
                self.dispatches_collection_list.append(collection)

    def _cuda_matmul_tensor_cores_f16(self):
        """Append the split-K matmul dispatches for TensorCore F16."""
        f16_configurations = self._get_matmul_custom_compilation_info_list(
            self.tile_descriptions_tensor_cores_f16, self.translation_infos,
            OperationKind.SplitkMatmul)
        # The same element type is used for all three tensors.
        element_types = [DataType.f16] * 3
        self._append_matmul_dispatch_collection(
            self.matmul_shapes, self.split_k_slices, element_types,
            f16_configurations)

    def _cuda_matmul_tensor_cores_f32(self):
        """Append the split-K matmul dispatches for TensorCore F32."""
        f32_configurations = self._get_matmul_custom_compilation_info_list(
            self.tile_descriptions_tensor_cores_f32, self.translation_infos,
            OperationKind.SplitkMatmul)
        # The same element type is used for all three tensors.
        element_types = [DataType.f32] * 3
        self._append_matmul_dispatch_collection(
            self.matmul_shapes, self.split_k_slices, element_types,
            f32_configurations)