52 lines
2.0 KiB
CMake
52 lines
2.0 KiB
CMake
cmake_minimum_required(VERSION 3.27)

include(FetchContent)

# Selects which prebuilt libtorch variant is downloaded. The value is later
# mapped to the device tag used in download.pytorch.org paths; values other
# than the ones listed in the help string are rejected with a FATAL_ERROR.
set(LIBTORCH_PLATFORM "none" CACHE STRING "Determines libtorch platform version to download (CUDA11.8, CUDA12.1, CUDA12.8, ROCm6.1 or none).")
# Publish the accepted values so cmake-gui / ccmake offer a drop-down instead
# of a free-text field (does not change or force the user's chosen value).
set_property(CACHE LIBTORCH_PLATFORM PROPERTY STRINGS none CUDA11.8 CUDA12.1 CUDA12.8 ROCm6.1)
|
|
|
|
# Map the user-facing LIBTORCH_PLATFORM choice onto the device tag used in
# download.pytorch.org directory and archive names (cpu, cu118, cu121, cu128,
# rocm6.1).
# The expansion is quoted so that an empty value still parses and falls into
# the FATAL_ERROR branch, and so that a value such as "none" is never
# dereferenced again as a variable name.
if("${LIBTORCH_PLATFORM}" STREQUAL "none")
  set(LIBTORCH_DEVICE "cpu")
elseif("${LIBTORCH_PLATFORM}" STREQUAL "CUDA11.8")
  set(LIBTORCH_DEVICE "cu118")
elseif("${LIBTORCH_PLATFORM}" STREQUAL "CUDA12.1")
  set(LIBTORCH_DEVICE "cu121")
elseif("${LIBTORCH_PLATFORM}" STREQUAL "CUDA12.8")
  set(LIBTORCH_DEVICE "cu128")
elseif("${LIBTORCH_PLATFORM}" STREQUAL "ROCm6.1")
  set(LIBTORCH_DEVICE "rocm6.1")
else()
  # Report the rejected value so misspellings are easy to spot.
  message(FATAL_ERROR "Invalid libtorch platform '${LIBTORCH_PLATFORM}', must be either CUDA11.8, CUDA12.1, CUDA12.8, ROCm6.1 or none.")
endif()
|
|
|
|
# PyTorch release whose prebuilt libtorch archive is downloaded.
set(PYTORCH_VERSION "2.4.0")

# Build the archive path relative to https://download.pytorch.org/libtorch/
# for the current OS and LIBTORCH_DEVICE. All variable expansions are quoted so
# an empty CMAKE_SYSTEM_NAME reaches the "Unsupported" error instead of
# producing an if() syntax error.
if("${CMAKE_SYSTEM_NAME}" STREQUAL "Windows")
  # ROCm builds of libtorch are published for Linux only; the Windows URL for a
  # rocm device tag does not exist, so fail before attempting the download.
  if("${LIBTORCH_DEVICE}" MATCHES "^rocm")
    message(FATAL_ERROR "ROCm libtorch binaries are only available for Linux, not ${CMAKE_SYSTEM_NAME}.")
  endif()
  set(LIBTORCH_URL "${LIBTORCH_DEVICE}/libtorch-win-shared-with-deps-${PYTORCH_VERSION}%2B${LIBTORCH_DEVICE}.zip")
elseif("${CMAKE_SYSTEM_NAME}" STREQUAL "Linux")
  # NOTE(review): "libtorch-shared-with-deps" appears to be the pre-cxx11-ABI
  # archive for this release ("libtorch-cxx11-abi-shared-with-deps" being the
  # other variant) — confirm this matches the ABI of the rest of the build.
  set(LIBTORCH_URL "${LIBTORCH_DEVICE}/libtorch-shared-with-deps-${PYTORCH_VERSION}%2B${LIBTORCH_DEVICE}.zip")
elseif("${CMAKE_SYSTEM_NAME}" STREQUAL "Darwin")
  # Only a CPU archive is fetched on macOS; any GPU selection is downgraded.
  if(NOT "${LIBTORCH_DEVICE}" STREQUAL "cpu")
    message(WARNING "MacOS binaries support CPU version only, using it instead.")
    set(LIBTORCH_DEVICE "cpu")
  endif()
  # NOTE(review): only the arm64 archive is fetched; Intel (x86_64) Macs are not
  # handled here — confirm whether they need to be supported.
  set(LIBTORCH_URL "cpu/libtorch-macos-arm64-${PYTORCH_VERSION}.zip")
else()
  message(FATAL_ERROR "Unsupported CMake System Name '${CMAKE_SYSTEM_NAME}' (expected 'Windows', 'Linux' or 'Darwin')")
endif()

# CUDA 12.8 (cu128) libtorch archives are first published with PyTorch 2.7.0;
# for older releases the URL built above does not exist and the download would
# fail with an opaque HTTP error. Checked after the OS switch so that macOS,
# which downgrades to cpu above, keeps its warn-and-continue behavior.
if("${LIBTORCH_DEVICE}" STREQUAL "cu128" AND "${PYTORCH_VERSION}" VERSION_LESS "2.7.0")
  message(FATAL_ERROR "libtorch for CUDA12.8 requires PyTorch >= 2.7.0, but PYTORCH_VERSION is ${PYTORCH_VERSION}. Choose a different LIBTORCH_PLATFORM.")
endif()
|
|
|
|
# Report the full URL (not just the relative archive path) being fetched.
message(STATUS "Downloading libtorch version ${PYTORCH_VERSION} for ${LIBTORCH_DEVICE} on ${CMAKE_SYSTEM_NAME} from https://download.pytorch.org/libtorch/${LIBTORCH_URL}...")

# Download and extract the selected archive into <top-level source dir>/libtorch.
# NOTE(review): DOWNLOAD_DIR and SOURCE_DIR are the same directory inside the
# source tree — presumably so the extracted copy survives build-dir wipes;
# confirm this is intended.
FetchContent_Declare(
  libtorch
  PREFIX libtorch
  DOWNLOAD_DIR "${CMAKE_SOURCE_DIR}/libtorch"
  SOURCE_DIR "${CMAKE_SOURCE_DIR}/libtorch"
  URL "https://download.pytorch.org/libtorch/${LIBTORCH_URL}"
)

FetchContent_MakeAvailable(libtorch)

message(STATUS "Downloaded libtorch.")

# Search the directory FetchContent actually populated. libtorch_SOURCE_DIR
# equals the SOURCE_DIR above, unless the user redirected the content with
# FETCHCONTENT_SOURCE_DIR_LIBTORCH — in that case the hard-coded path would be
# empty and Torch must be found in the override directory instead.
find_package(Torch REQUIRED PATHS "${libtorch_SOURCE_DIR}")
|