1
1
#! /bin/bash
2
2
3
#######################################
# Compare two dotted version strings using `sort -V` ordering.
#
# NOTE: despite the name, the check is strict and reads "right-to-left":
# it succeeds only when $2 sorts strictly before $1.
#
# Arguments: $1 - reference version
#            $2 - version to test against the reference
# Returns:   0 if $2 is strictly older than $1,
#            1 if both are equal or $2 is newer.
#######################################
verlte() {
  # Identical versions are never "older".
  if [ "$1" = "$2" ]; then
    return 1
  fi

  local lowest
  # printf (not echo -e) so backslashes in the arguments are taken literally.
  lowest=$(printf '%s\n' "$1" "$2" | sort -V | head -n1)

  # Succeeds only when $2 is the smaller of the two versions.
  [ "$2" = "$lowest" ]
}
6
+
7
# Put the CUDA forward-compatibility libraries on LD_LIBRARY_PATH, but only
# when the compat package is present and the installed NVIDIA kernel driver is
# older than the driver version the compat package targets.
if [ -f /usr/local/cuda/compat/libcuda.so.1 ]; then
  # Take dot-separated fields 3+ of the symlink target as the version
  # (presumably the target is named libcuda.so.<driver version> -- confirm).
  CUDA_COMPAT_MAX_DRIVER_VERSION=$(readlink /usr/local/cuda/compat/libcuda.so.1 | cut -d"." -f 3-)
  echo "CUDA compat package requires Nvidia driver ≤${CUDA_COMPAT_MAX_DRIVER_VERSION}"
  cat /proc/driver/nvidia/version
  # Best effort: a missing/unreadable version file leaves the variable empty.
  NVIDIA_DRIVER_VERSION=$(sed -n 's/^NVRM.*Kernel Module *\([0-9.]*\).*$/\1/p' /proc/driver/nvidia/version 2>/dev/null || true)
  echo "Current installed Nvidia driver version is ${NVIDIA_DRIVER_VERSION}"
  # Branch on verlte's exit status. Wrapping the call in `[ $(...) ]` would
  # test its (always empty) stdout and therefore never take this branch.
  if verlte "$CUDA_COMPAT_MAX_DRIVER_VERSION" "$NVIDIA_DRIVER_VERSION"; then
    echo "Setup CUDA compatibility libs path to LD_LIBRARY_PATH"
    # Only append the previous value when it is non-empty, so no empty
    # path entry (trailing ':') ends up in LD_LIBRARY_PATH.
    export LD_LIBRARY_PATH="/usr/local/cuda/compat${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}"
    echo "$LD_LIBRARY_PATH"
  else
    echo "Skip CUDA compat libs setup as newer Nvidia driver is installed"
  fi
else
  echo "Skip CUDA compat libs setup as package not found"
fi
23
+
3
24
if [[ -z " ${HF_MODEL_ID} " ]]; then
4
25
echo " HF_MODEL_ID must be set"
5
26
exit 1
@@ -15,9 +36,37 @@ if ! command -v nvidia-smi &> /dev/null; then
15
36
exit 1
16
37
fi
17
38
39
# Query GPU name using nvidia-smi.
# The raw output is captured first so the success check reflects nvidia-smi
# itself; checking $? after `nvidia-smi | awk` would only report awk's status.
if smi_output=$(nvidia-smi --query-gpu=gpu_name --format=csv); then
  # Line 1 of the CSV output is the header; line 2 is the first GPU entry.
  gpu_name=$(printf '%s\n' "$smi_output" | awk 'NR==2')
  echo "Query gpu_name succeeded. Printing output: $gpu_name"
else
  echo "Error: $smi_output"
  echo "Query gpu_name failed"
  # Leave the name empty so later consumers see "unknown GPU".
  gpu_name=""
fi
47
+
48
#######################################
# Map a GPU product name to its CUDA compute capability (dot removed).
#
# Arguments: $1 - GPU name, matched by substring
# Outputs:   writes the compute capability to stdout:
#              contains "A10G" -> 86
#              contains "A100" -> 80
#              contains "H100" -> 90
#              anything else   -> 80 (default)
#######################################
get_compute_cap() {
  # `local` keeps this from overwriting the script-level gpu_name variable
  # when the function is called outside a command substitution.
  local gpu_name="$1"

  case "$gpu_name" in
    *A10G*) echo "86" ;;
    *A100*) echo "80" ;;
    *H100*) echo "90" ;;
    *) echo "80" ;; # Default compute capability
  esac
}
65
+
18
66
# Decide which compute capability to use: an explicitly provided
# CUDA_COMPUTE_CAP wins; otherwise derive it from the detected GPU name.
# `:-` keeps the check safe when the variable is unset under `set -u`.
if [[ -z "${CUDA_COMPUTE_CAP:-}" ]]; then
  compute_cap=$(get_compute_cap "$gpu_name")
  echo "the compute_cap is $compute_cap"
else
  compute_cap=$CUDA_COMPUTE_CAP
fi
0 commit comments