ShopTRAINING/test/copy_models.py
2025-07-02 11:05:23 +08:00

42 lines
1.3 KiB
Python

#!/usr/bin/env python3
"""
将模型文件从server/models目录复制到saved_models目录
"""
import os
import shutil
import glob
def copy_models(source_dir="server/models", target_dir="saved_models"):
    """Copy all ``.pt`` / ``.pth`` files found under *source_dir* into *target_dir*.

    The source tree is searched recursively, and every match is copied flat
    into *target_dir* under its base name. Files that share a base name in
    different sub-directories therefore overwrite one another; the one
    visited last wins. A failure to copy one file is reported on stdout and
    does not stop the remaining copies.

    Args:
        source_dir: Directory searched (recursively) for model files.
        target_dir: Directory that receives the copies; created if missing.

    Returns:
        None. Progress and a final listing of *target_dir* are printed.
    """
    # Make sure the destination exists before any copy is attempted.
    os.makedirs(target_dir, exist_ok=True)

    # With recursive=True the "**" component also matches zero directories,
    # so these two patterns already include files sitting directly in
    # source_dir. Adding separate non-recursive "*.pt" / "*.pth" globs would
    # list (and copy, and count) every top-level file twice.
    model_files = []
    for extension in ("pt", "pth"):
        model_files.extend(
            glob.glob(f"{source_dir}/**/*.{extension}", recursive=True)
        )

    print(f"找到 {len(model_files)} 个模型文件:")

    for model_file in model_files:
        target_path = os.path.join(target_dir, os.path.basename(model_file))
        try:
            shutil.copy2(model_file, target_path)
        except OSError as e:
            # Best effort: shutil's copy errors (including SameFileError)
            # derive from OSError; report and continue with the next file.
            print(f"复制失败: {model_file} -> {e}")
        else:
            print(f"复制成功: {model_file} -> {target_path}")

    # Report what the destination directory now holds.
    # NOTE(review): "*.p*" is broader than .pt/.pth (it would also match
    # e.g. .pkl or .png); kept as-is in case that is intentional - confirm.
    # The message text names "saved_models" literally, whatever target_dir is.
    saved_models = glob.glob(f"{target_dir}/*.p*")
    print(f"\n当前saved_models目录包含 {len(saved_models)} 个模型文件:")
    for model in saved_models:
        print(f" - {os.path.basename(model)}")
# Script entry point: perform the copy only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    copy_models()