84 lines
2.7 KiB
Python
84 lines
2.7 KiB
Python
#!/usr/bin/env python
|
||
# -*- coding: utf-8 -*-
|
||
|
||
"""
|
||
GPU状态检查工具
|
||
此脚本用于检查PyTorch是否能够识别和使用GPU
|
||
"""
|
||
|
||
import torch
|
||
import sys
|
||
|
||
def check_gpu():
|
||
"""检查GPU可用性并打印详细信息"""
|
||
print("=" * 40)
|
||
print("📊 PyTorch GPU 支持检查")
|
||
print("=" * 40)
|
||
|
||
# 检查PyTorch版本
|
||
print(f"PyTorch版本: {torch.__version__}")
|
||
|
||
# 检查CUDA是否可用
|
||
cuda_available = torch.cuda.is_available()
|
||
print(f"CUDA是否可用: {cuda_available}")
|
||
|
||
if cuda_available:
|
||
# 打印CUDA版本
|
||
print(f"CUDA版本: {torch.version.cuda}")
|
||
|
||
# 获取可用的GPU数量
|
||
gpu_count = torch.cuda.device_count()
|
||
print(f"可用GPU数量: {gpu_count}")
|
||
|
||
# 打印每个GPU的信息
|
||
for i in range(gpu_count):
|
||
print(f"\nGPU {i} 信息:")
|
||
print(f" 名称: {torch.cuda.get_device_name(i)}")
|
||
print(f" 计算能力: {torch.cuda.get_device_capability(i)}")
|
||
|
||
# 获取GPU内存信息
|
||
try:
|
||
mem_info = torch.cuda.get_device_properties(i).total_memory / (1024**3)
|
||
print(f" 总内存: {mem_info:.2f} GB")
|
||
except:
|
||
print(" 无法获取内存信息")
|
||
|
||
# 进行简单的GPU测试
|
||
print("\n执行GPU测试...")
|
||
try:
|
||
# 创建一个在GPU上的张量
|
||
x = torch.rand(1000, 1000).cuda()
|
||
y = torch.rand(1000, 1000).cuda()
|
||
|
||
# 计时矩阵乘法操作
|
||
import time
|
||
start = time.time()
|
||
z = torch.matmul(x, y)
|
||
torch.cuda.synchronize() # 等待GPU操作完成
|
||
end = time.time()
|
||
|
||
print(f"1000x1000矩阵乘法耗时: {(end-start)*1000:.2f} ms")
|
||
print("✅ GPU测试成功!")
|
||
except Exception as e:
|
||
print(f"❌ GPU测试失败: {str(e)}")
|
||
else:
|
||
print("\n❌ 未检测到支持CUDA的GPU")
|
||
print("如果你的计算机有NVIDIA GPU,请确保:")
|
||
print("1. 已安装正确的NVIDIA驱动程序")
|
||
print("2. 已安装与PyTorch兼容的CUDA版本")
|
||
print("3. 安装了正确的PyTorch CUDA版本")
|
||
|
||
print("\n建议:")
|
||
if cuda_available:
|
||
print("✅ 你的系统已准备好使用GPU加速药店销售预测模型")
|
||
else:
|
||
print("🔄 考虑使用GPU版本安装脚本来启用GPU加速:")
|
||
if sys.platform == "win32":
|
||
print(" 运行 install_dependencies.bat 并选择GPU选项")
|
||
else:
|
||
print(" 运行 ./install_dependencies.sh 并选择GPU选项")
|
||
|
||
print("=" * 40)
|
||
|
||
if __name__ == "__main__":
|
||
check_gpu() |