Skip to content

Commit 96ed3cd

Browse files
authored
chore: map compute_cap from GPU name (#276)
1 parent a7e128b commit 96ed3cd

File tree

1 file changed

+50
-1
lines changed

1 file changed

+50
-1
lines changed

sagemaker-entrypoint-cuda-all.sh

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,26 @@
11
#!/bin/bash
22

3+
# verlte A B
# Succeeds (status 0) only when version B sorts strictly before version A
# under `sort -V` ordering. Note that, despite the name, identical versions
# yield a non-zero status.
#   $1 - reference version (upper bound)
#   $2 - version being checked
verlte() {
    local upper="$1"
    local candidate="$2"
    local lowest

    # Equal versions never pass the check.
    if [ "$upper" = "$candidate" ]; then
        return 1
    fi

    # Version-sort both values; the candidate passes when it is the smaller one.
    lowest="$(echo -e "$upper\n$candidate" | sort -V | head -n1)"
    [ "$candidate" = "$lowest" ]
}
6+
7+
# Prepend the CUDA forward-compatibility libraries to LD_LIBRARY_PATH, but only
# when the compat package is present and the installed NVIDIA kernel driver is
# strictly older than the driver version the compat package targets.
if [ -f /usr/local/cuda/compat/libcuda.so.1 ]; then
    # Split the symlink target on "." and keep fields 3 onward as the version
    # (e.g. libcuda.so.535.104.05 -> 535.104.05, assuming the usual naming).
    CUDA_COMPAT_MAX_DRIVER_VERSION=$(readlink /usr/local/cuda/compat/libcuda.so.1 | cut -d"." -f 3-)
    echo "CUDA compat package requires Nvidia driver ≤${CUDA_COMPAT_MAX_DRIVER_VERSION}"
    cat /proc/driver/nvidia/version
    # Best effort: an unreadable version file leaves the variable empty.
    NVIDIA_DRIVER_VERSION=$(sed -n 's/^NVRM.*Kernel Module *\([0-9.]*\).*$/\1/p' /proc/driver/nvidia/version 2>/dev/null || true)
    echo "Current installed Nvidia driver version is ${NVIDIA_DRIVER_VERSION}"
    # Branch on verlte's exit status. It prints nothing, so wrapping it in
    # `[ $(...) ]` would always evaluate to false. An empty driver version is
    # rejected up front because it would sort lowest and pass the comparison.
    if [ -n "$NVIDIA_DRIVER_VERSION" ] && verlte "$CUDA_COMPAT_MAX_DRIVER_VERSION" "$NVIDIA_DRIVER_VERSION"; then
        echo "Setup CUDA compatibility libs path to LD_LIBRARY_PATH"
        # Only add the ":" separator when LD_LIBRARY_PATH already has content,
        # so no empty (current-directory) entry is introduced.
        export LD_LIBRARY_PATH="/usr/local/cuda/compat${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}"
        echo "$LD_LIBRARY_PATH"
    else
        echo "Skip CUDA compat libs setup as newer Nvidia driver is installed"
    fi
else
    echo "Skip CUDA compat libs setup as package not found"
fi
23+
324
if [[ -z "${HF_MODEL_ID}" ]]; then
425
echo "HF_MODEL_ID must be set"
526
exit 1
@@ -15,9 +36,37 @@ if ! command -v nvidia-smi &> /dev/null; then
1536
exit 1
1637
fi
1738

39+
# Query GPU name using nvidia-smi
# The raw CSV is captured first so that the status being checked is
# nvidia-smi's own; in a pipeline, $? would only reflect the last stage (awk).
if gpu_query_output=$(nvidia-smi --query-gpu=gpu_name --format=csv); then
    # Line 1 of the CSV output is the header; line 2 is the first GPU listed.
    gpu_name=$(awk 'NR==2' <<< "$gpu_query_output")
    echo "Query gpu_name succeeded. Printing output: $gpu_name"
else
    # Keep going with an empty name; the failure is only reported here.
    gpu_name=""
    echo "Error: $gpu_query_output"
    echo "Query gpu_name failed"
fi
47+
48+
# Function to get compute capability based on GPU name
# Arguments: $1 - GPU name string (e.g. as reported by nvidia-smi)
# Outputs:   the compute capability without the dot ("86", "80", "90") on
#            stdout; names matching none of the known models yield "80".
get_compute_cap() {
    # Local copy so the caller's global gpu_name is never overwritten.
    local name="$1"

    # Patterns are tried in order; the first substring match wins.
    case "$name" in
        *A10G*) echo "86" ;;
        *A100*) echo "80" ;;
        *H100*) echo "90" ;;
        *)      echo "80" ;; # Default compute capability
    esac
}
65+
1866
# Resolve compute_cap: an explicitly provided CUDA_COMPUTE_CAP takes
# precedence; otherwise it is derived from the GPU name held in gpu_name.
if [[ -n "${CUDA_COMPUTE_CAP}" ]]; then
    compute_cap=$CUDA_COMPUTE_CAP
else
    compute_cap=$(get_compute_cap "$gpu_name")
    echo "the compute_cap is $compute_cap"
fi

0 commit comments

Comments
 (0)