42 lines
1.3 KiB
Python
42 lines
1.3 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
将模型文件从server/models目录复制到saved_models目录
|
|
"""
|
|
import os
|
|
import shutil
|
|
import glob
|
|
|
|
def copy_models(source_dir="server/models", target_dir="saved_models",
                extensions=(".pt", ".pth")):
    """Copy model files found under *source_dir* into *target_dir*.

    Files whose names end with one of *extensions* are searched for
    recursively under ``source_dir``. They are copied into the top level of
    ``target_dir`` with ``shutil.copy2``, so the sub-directory structure is
    flattened. If two source files share a basename, only the first one (in
    sorted path order) is copied and the others are reported and skipped.
    Progress is printed to stdout.

    Args:
        source_dir: Directory to search. Defaults to ``"server/models"``,
            relative to the current working directory.
        target_dir: Destination directory. It is created if missing.
            Defaults to ``"saved_models"``.
        extensions: Iterable of filename suffixes treated as model files.
            A single string is accepted as one suffix.

    Returns:
        List of destination paths that were copied successfully.
    """
    if isinstance(extensions, str):
        # Guard against iterating a lone suffix character by character.
        extensions = (extensions,)
    extensions = tuple(extensions)

    # Ensure the target directory exists.
    os.makedirs(target_dir, exist_ok=True)

    # With recursive=True, "**" also matches zero directories, so this single
    # pattern per extension already covers files directly in source_dir.
    # (The old extra non-recursive globs listed those files twice.) A set keeps
    # the result duplicate-free even if extensions overlap, and sorting makes
    # the order deterministic. glob.escape protects paths containing [ * ?.
    base = glob.escape(source_dir)
    found = set()
    for ext in extensions:
        found.update(glob.glob(os.path.join(base, "**", f"*{ext}"),
                               recursive=True))
    model_files = sorted(found)

    print(f"找到 {len(model_files)} 个模型文件:")

    copied = []
    seen = {}  # destination basename -> source path already copied to it
    for model_file in model_files:
        filename = os.path.basename(model_file)
        target_path = os.path.join(target_dir, filename)

        # Flattening can map two source files to the same destination name;
        # skip the later one instead of silently overwriting the earlier copy.
        if filename in seen:
            print(f"跳过重名文件: {model_file} (与 {seen[filename]} 同名)")
            continue

        try:
            shutil.copy2(model_file, target_path)
        except OSError as e:  # shutil.Error / SameFileError subclass OSError
            print(f"复制失败: {model_file} -> {e}")
            continue
        seen[filename] = model_file
        copied.append(target_path)
        print(f"复制成功: {model_file} -> {target_path}")

    # Report what the target directory now holds. Match only the configured
    # model extensions; the old "*.p*" pattern also caught .py/.pkl etc.
    target_base = glob.escape(target_dir)
    saved_models = sorted({
        path
        for ext in extensions
        for path in glob.glob(os.path.join(target_base, f"*{ext}"))
    })
    print(f"\n当前saved_models目录包含 {len(saved_models)} 个模型文件:")
    for model in saved_models:
        print(f" - {os.path.basename(model)}")

    return copied
|
|
|
|
if __name__ == "__main__":
    # Run the copy only when executed as a script, not when imported.
    copy_models()