ShopTRAINING/install_torch_gpu.sh

27 lines
922 B
Bash
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env bash
# Interactive installer for a CUDA-enabled PyTorch build from the official
# PyTorch wheel index (download.pytorch.org), followed by a quick check of
# whether the installed build can see a GPU.
#
# Optional environment overrides:
#   TORCH_VERSION  torch version to install (default: 2.7.1)
#   PYTHON         interpreter to install into and verify with (default: python)
set -euo pipefail

readonly TORCH_VERSION="${TORCH_VERSION:-2.7.1}"
readonly PYTHON="${PYTHON:-python}"

echo "安装PyTorch GPU版本通过官方源"
echo "==================================="
echo ""
echo "请选择CUDA版本:"
echo "1. CUDA 12.6 (适用于较新的NVIDIA GPU)"
echo "2. CUDA 11.8 (适用于较旧的NVIDIA GPU)"
echo ""

# -r keeps backslashes literal. If stdin hits EOF, read returns non-zero;
# '|| true' lets an empty choice fall through to the invalid-option branch
# below instead of aborting silently under set -e.
choice=""
read -r -p "请输入选项 (1/2): " choice || true

# Map the menu choice to a display label and a wheel-index tag.
# PyTorch 2.7.x wheels are published for cu118, cu126 and cu128 only; the
# cu121 index has no 2.7.x builds, so torch==2.7.1 cannot resolve there.
case "$choice" in
  1)
    cuda_label="12.6"
    cuda_tag="cu126"
    ;;
  2)
    cuda_label="11.8"
    cuda_tag="cu118"
    ;;
  *)
    echo "无效的选项!" >&2
    exit 1
    ;;
esac

echo "正在安装PyTorch CUDA ${cuda_label}版本..."
# Use "$PYTHON -m pip" so the package is installed into the same interpreter
# that runs the verification step below; a bare `pip` on PATH may belong to
# a different Python installation.
if ! "$PYTHON" -m pip install "torch==${TORCH_VERSION}" \
  --index-url "https://download.pytorch.org/whl/${cuda_tag}"; then
  echo "PyTorch安装失败!" >&2
  exit 1
fi

echo ""
echo "验证PyTorch GPU支持状态..."
"$PYTHON" -c "import torch; print('CUDA是否可用:', torch.cuda.is_available()); print('PyTorch版本:', torch.__version__); print('CUDA版本:', torch.version.cuda if torch.cuda.is_available() else '无')"