ShopTRAINING/server/check_gpu.py

84 lines
2.7 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
GPU状态检查工具
此脚本用于检查PyTorch是否能够识别和使用GPU
"""
import torch
import sys
def check_gpu():
"""检查GPU可用性并打印详细信息"""
print("=" * 40)
print("📊 PyTorch GPU 支持检查")
print("=" * 40)
# 检查PyTorch版本
print(f"PyTorch版本: {torch.__version__}")
# 检查CUDA是否可用
cuda_available = torch.cuda.is_available()
print(f"CUDA是否可用: {cuda_available}")
if cuda_available:
# 打印CUDA版本
print(f"CUDA版本: {torch.version.cuda}")
# 获取可用的GPU数量
gpu_count = torch.cuda.device_count()
print(f"可用GPU数量: {gpu_count}")
# 打印每个GPU的信息
for i in range(gpu_count):
print(f"\nGPU {i} 信息:")
print(f" 名称: {torch.cuda.get_device_name(i)}")
print(f" 计算能力: {torch.cuda.get_device_capability(i)}")
# 获取GPU内存信息
try:
mem_info = torch.cuda.get_device_properties(i).total_memory / (1024**3)
print(f" 总内存: {mem_info:.2f} GB")
except:
print(" 无法获取内存信息")
# 进行简单的GPU测试
print("\n执行GPU测试...")
try:
# 创建一个在GPU上的张量
x = torch.rand(1000, 1000).cuda()
y = torch.rand(1000, 1000).cuda()
# 计时矩阵乘法操作
import time
start = time.time()
z = torch.matmul(x, y)
torch.cuda.synchronize() # 等待GPU操作完成
end = time.time()
print(f"1000x1000矩阵乘法耗时: {(end-start)*1000:.2f} ms")
print("✅ GPU测试成功!")
except Exception as e:
print(f"❌ GPU测试失败: {str(e)}")
else:
print("\n❌ 未检测到支持CUDA的GPU")
print("如果你的计算机有NVIDIA GPU请确保:")
print("1. 已安装正确的NVIDIA驱动程序")
print("2. 已安装与PyTorch兼容的CUDA版本")
print("3. 安装了正确的PyTorch CUDA版本")
print("\n建议:")
if cuda_available:
print("✅ 你的系统已准备好使用GPU加速药店销售预测模型")
else:
print("🔄 考虑使用GPU版本安装脚本来启用GPU加速:")
if sys.platform == "win32":
print(" 运行 install_dependencies.bat 并选择GPU选项")
else:
print(" 运行 ./install_dependencies.sh 并选择GPU选项")
print("=" * 40)
if __name__ == "__main__":
check_gpu()