[CUDA] Set attribute for function using large shared memory (#8634)
diff --git a/iree/hal/cuda/dynamic_symbol_tables.h b/iree/hal/cuda/dynamic_symbol_tables.h
index aa2ee87..9b3f5c9 100644
--- a/iree/hal/cuda/dynamic_symbol_tables.h
+++ b/iree/hal/cuda/dynamic_symbol_tables.h
@@ -49,6 +49,7 @@
CUstream)
CU_PFN_DECL(cuMemcpyAsync, CUdeviceptr, CUdeviceptr, size_t, CUstream)
CU_PFN_DECL(cuMemcpyHtoDAsync_v2, CUdeviceptr, const void*, size_t, CUstream)
+CU_PFN_DECL(cuFuncSetAttribute, CUfunction, CUfunction_attribute, int)
CU_PFN_DECL(cuLaunchKernel, CUfunction, unsigned int, unsigned int,
unsigned int, unsigned int, unsigned int, unsigned int,
unsigned int, CUstream, void**, void**)
diff --git a/iree/hal/cuda/native_executable.c b/iree/hal/cuda/native_executable.c
index 184a4ef..5046595 100644
--- a/iree/hal/cuda/native_executable.c
+++ b/iree/hal/cuda/native_executable.c
@@ -101,6 +101,14 @@
status = CU_RESULT_TO_STATUS(
context->syms, cuModuleGetFunction(&function, module, entry_name),
"cuModuleGetFunction");
+ if (iree_status_is_ok(status)) {
+ status = CU_RESULT_TO_STATUS(
+ context->syms,
+ cuFuncSetAttribute(function,
+ CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
+ shared_memory_sizes[i]),
+ "cuFuncSetAttribute");
+ }
executable->entry_functions[i].cu_function = function;
executable->entry_functions[i].block_size_x = block_sizes_vec[i].x;
executable->entry_functions[i].block_size_y = block_sizes_vec[i].y;