ShopTRAINING/install/install_torch_gpu.sh

27 lines
922 B
Bash
Raw Permalink Normal View History

#!/bin/bash
echo "安装PyTorch GPU版本通过官方源"
echo "==================================="
echo ""
echo "请选择CUDA版本:"
echo "1. CUDA 12.1 (适用于较新的NVIDIA GPU)"
echo "2. CUDA 11.8 (适用于较旧的NVIDIA GPU)"
echo ""
read -p "请输入选项 (1/2): " choice
if [ "$choice" = "1" ]; then
echo "正在安装PyTorch CUDA 12.1版本..."
pip install torch==2.7.1 --index-url https://download.pytorch.org/whl/cu121
elif [ "$choice" = "2" ]; then
echo "正在安装PyTorch CUDA 11.8版本..."
pip install torch==2.7.1 --index-url https://download.pytorch.org/whl/cu118
else
echo "无效的选项!"
exit 1
fi
echo ""
echo "验证PyTorch GPU支持状态..."
python -c "import torch; print('CUDA是否可用:', torch.cuda.is_available()); print('PyTorch版本:', torch.__version__); print('CUDA版本:', torch.version.cuda if torch.cuda.is_available() else '无')"