#!/bin/bash echo "安装PyTorch GPU版本(通过官方源)" echo "===================================" echo "" echo "请选择CUDA版本:" echo "1. CUDA 12.1 (适用于较新的NVIDIA GPU)" echo "2. CUDA 11.8 (适用于较旧的NVIDIA GPU)" echo "" read -p "请输入选项 (1/2): " choice if [ "$choice" = "1" ]; then echo "正在安装PyTorch CUDA 12.1版本..." pip install torch==2.7.1 --index-url https://download.pytorch.org/whl/cu121 elif [ "$choice" = "2" ]; then echo "正在安装PyTorch CUDA 11.8版本..." pip install torch==2.7.1 --index-url https://download.pytorch.org/whl/cu118 else echo "无效的选项!" exit 1 fi echo "" echo "验证PyTorch GPU支持状态..." python -c "import torch; print('CUDA是否可用:', torch.cuda.is_available()); print('PyTorch版本:', torch.__version__); print('CUDA版本:', torch.version.cuda if torch.cuda.is_available() else '无')"