Created
May 8, 2025 07:44
-
-
Save danieldk/74bb6bf40d7b0c215a187fcd664029b4 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
cmake_minimum_required(VERSION 3.26)
project(activation LANGUAGES CXX)

# Backend to build the kernel for. Only "cuda" and "rocm" lead to an
# extension target further down; any other value ends processing early.
set(TARGET_DEVICE "cuda" CACHE STRING "Target device backend for kernel")

# Architectures this kernel may be compiled for. The architectures that are
# detected/requested at configure time are intersected with these lists.
set(CUDA_SUPPORTED_ARCHS
  7.0 7.2 7.5
  8.0 8.6 8.7 8.9
  9.0
  10.0 10.1
  12.0
)
set(HIP_SUPPORTED_ARCHS
  gfx906 gfx908 gfx90a
  gfx940 gfx941 gfx942
  gfx1030
  gfx1100 gfx1101
)

# Set CMAKE_INSTALL_LOCAL_ONLY inside the generated install script so that
# install rules from subdirectories (e.g. fetched dependencies) are skipped.
install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" ALL_COMPONENTS)

# Create FetchContent's base directory up front (a no-op when it already
# exists) and report where it lives.
include(FetchContent)
file(MAKE_DIRECTORY "${FETCHCONTENT_BASE_DIR}")
message(STATUS "FetchContent base directory: ${FETCHCONTENT_BASE_DIR}")

# Project build helpers. This file presumably provides append_cmake_prefix_path,
# cuda_archs_loose_intersection, override_gpu_arches,
# define_gpu_extension_target, etc. Not visible here; verify against
# cmake/utils.cmake.
include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake)
# Locate Python and Torch.
# Python needs the interpreter, the development headers, and the stable-ABI
# module component (Development.SABIModule), which the extension target below
# relies on via USE_SABI.
if(DEFINED Python_EXECUTABLE)
  # An interpreter was passed in explicitly (e.g. from setup.py). Search
  # without REQUIRED so that failure can name the interpreter that did not
  # provide the requested components.
  find_package(Python COMPONENTS Development Development.SABIModule Interpreter)
  if(NOT Python_FOUND)
    # Python_EXECUTABLE is the hint variable checked above. The previous
    # message interpolated the undefined ${EXECUTABLE} and printed nothing.
    message(FATAL_ERROR "Unable to find python matching: ${Python_EXECUTABLE}.")
  endif()
else()
  find_package(Python REQUIRED COMPONENTS Development Development.SABIModule Interpreter)
endif()

# Presumably adds the path reported by `torch.utils.cmake_prefix_path` to
# CMAKE_PREFIX_PATH, so that find_package(Torch) can locate Torch's CMake
# config (helper defined in cmake/utils.cmake; confirm there).
append_cmake_prefix_path("torch" "torch.utils.cmake_prefix_path")
find_package(Torch REQUIRED)
# Only the GPU backends build an extension in this file. For any other
# TARGET_DEVICE, stop processing here.
if(NOT ("${TARGET_DEVICE}" STREQUAL "cuda" OR "${TARGET_DEVICE}" STREQUAL "rocm"))
  return()
endif()

# Choose the GPU language. HIP takes precedence when available, then CUDA.
# With neither available there is nothing to compile the kernels with.
# HIP_FOUND / CUDA_FOUND are not set in this file; they are expected to come
# from find_package(Torch) and its dependencies.
if(HIP_FOUND)
  set(GPU_LANG "HIP")
  # Importing torch recognizes and sets up some HIP/ROCm configuration, but
  # it does not make CMake understand the .hip extension. HIP has to be
  # enabled explicitly for .hip files to be recognized automatically.
  enable_language(HIP)
elseif(CUDA_FOUND)
  set(GPU_LANG "CUDA")
else()
  message(FATAL_ERROR "Can't find CUDA or HIP installation.")
endif()
# Decide which GPU architectures to target and collect GPU compiler flags.
#
# CUDA: architectures are taken from the flags Torch configured, filtered by
# CUDA_SUPPORTED_ARCHS, and applied per source file further down.
# GPU_ARCHES is left unset on this path.
#
# HIP: architectures are overridden target-wide through GPU_ARCHES.
if(GPU_LANG STREQUAL "CUDA")
  # Helpers from cmake/utils.cmake. Presumably they strip the architecture
  # flags from the global CUDA flags and return them, so that architectures
  # can be chosen per file instead.
  clear_cuda_arches(CUDA_ARCH_FLAGS)
  extract_unique_cuda_archs_ascending(CUDA_ARCHS "${CUDA_ARCH_FLAGS}")
  message(STATUS "CUDA target architectures: ${CUDA_ARCHS}")

  # Filter the target architectures by the supported archs, since some
  # files are built for all of CUDA_ARCHS.
  cuda_archs_loose_intersection(CUDA_ARCHS "${CUDA_SUPPORTED_ARCHS}" "${CUDA_ARCHS}")
  message(STATUS "CUDA supported target architectures: ${CUDA_ARCHS}")
  if("${CUDA_ARCHS}" STREQUAL "")
    # Previously this went unnoticed. Surface it, because the kernel sources
    # are then configured with an empty architecture list.
    message(WARNING
      "None of the requested CUDA architectures are supported by this kernel "
      "(supported: ${CUDA_SUPPORTED_ARCHS}).")
  endif()

  # Optional multi-threaded nvcc compilation. The language is already known
  # to be CUDA here, so only NVCC_THREADS needs checking.
  if(NVCC_THREADS)
    list(APPEND GPU_FLAGS "--threads=${NVCC_THREADS}")
  endif()
else()
  # GPU_LANG can only be "HIP" here, because language selection fails for
  # anything else. The former separate `else()` branch duplicated this one
  # and was unreachable.
  set(ROCM_ARCHS "${HIP_SUPPORTED_ARCHS}")
  # TODO: remove this once we can set specific archs per source file set.
  override_gpu_arches(GPU_ARCHES
    ${GPU_LANG}
    "${${GPU_LANG}_SUPPORTED_ARCHS}")
endif()

# Append the compiler flags reported by get_torch_gpu_compiler_flags
# (cmake/utils.cmake) for the selected GPU language.
get_torch_gpu_compiler_flags(TORCH_GPU_FLAGS ${GPU_LANG})
list(APPEND GPU_FLAGS ${TORCH_GPU_FLAGS})
# Source lists for the extension.
# - TORCH_activation_SRC: the Torch op bindings, compiled for every backend.
# - activation_SRC: the activation kernels and their headers.
set(TORCH_activation_SRC
  torch-ext/torch_binding.cpp
  torch-ext/torch_binding.h
)
set(activation_SRC
  activation/activation_kernels.cu
  activation/cuda_compat.h
  activation/dispatch_utils.h
)

list(APPEND SRC "${TORCH_activation_SRC}")

if(GPU_LANG STREQUAL "CUDA")
  # Restrict the kernel to the architectures it supports. Presumably
  # set_gencode_flags_for_srcs then attaches matching per-source arch flags
  # (helper in cmake/utils.cmake).
  cuda_archs_loose_intersection(activation_ARCHS "${CUDA_SUPPORTED_ARCHS}" "${CUDA_ARCHS}")
  message(STATUS "Capabilities for kernel activation: ${activation_ARCHS}")
  set_gencode_flags_for_srcs(SRCS "${activation_SRC}" CUDA_ARCHS "${activation_ARCHS}")
  list(APPEND SRC "${activation_SRC}")
endif()
# NOTE(review): on the HIP path activation_SRC is never appended to SRC, so a
# ROCm build compiles only the Torch bindings. Confirm this is intended
# (e.g. the kernel only declares a CUDA backend).
# Define the Python extension module from the collected sources and flags.
# It is built against the Python stable ABI (USE_SABI 3). ARCHITECTURES is
# empty on the CUDA path, where arch flags are set per source instead.
define_gpu_extension_target(
  _activation_psnp6q5y4k4wg
  DESTINATION _activation_psnp6q5y4k4wg
  LANGUAGE ${GPU_LANG}
  SOURCES ${SRC}
  COMPILE_FLAGS ${GPU_FLAGS}
  ARCHITECTURES ${GPU_ARCHES}
  USE_SABI 3
  WITH_SOABI
)

# Link the C++ standard library statically into the extension. This is
# presumably so the module does not depend on the host's libstdc++ version;
# confirm. The flag is GNU-toolchain style and is passed unconditionally.
target_link_options(_activation_psnp6q5y4k4wg PRIVATE -static-libstdc++)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment