ShopTRAINING/test/fix_model_names.py
2025-07-02 11:05:23 +08:00

39 lines
1.2 KiB
Python

#!/usr/bin/env python3
"""
Rename model files so that they follow the API naming convention.
"""
import os
import glob
def fix_model_names(saved_models_dir="saved_models", rename_map=None):
    """Rename legacy model files so they follow the API naming convention.

    Each entry of the rename map is applied inside ``saved_models_dir``.
    A source file that is missing is reported and skipped. A rename whose
    destination already exists is reported as a failure and skipped, so an
    existing model is never overwritten. Afterwards the ``.pth`` files
    present in the directory are listed.

    Args:
        saved_models_dir: Directory that holds the model files. Defaults to
            ``"saved_models"``, resolved against the current working
            directory.
        rename_map: Mapping of current file name to new file name. When
            ``None``, the built-in mapping for product ``P001`` is used.

    Returns:
        None. Progress and the final listing are printed to stdout.
    """
    if rename_map is None:
        rename_map = {
            "P001_mlstm_model.pt": "mlstm_product_P001_v1.pth",
            "P001_model.pt": "kan_optimized_product_P001_v1.pth",
        }

    print("开始重命名模型文件:")
    for old_name, new_name in rename_map.items():
        old_path = os.path.join(saved_models_dir, old_name)
        new_path = os.path.join(saved_models_dir, new_name)

        if not os.path.exists(old_path):
            print(f"文件不存在: {old_name}")
            continue

        # os.rename silently replaces an existing destination on POSIX but
        # raises on Windows. Refuse explicitly so behavior is the same on
        # every platform and an existing model file is never clobbered.
        if os.path.exists(new_path):
            print(f"重命名失败: {old_name} -> 目标文件已存在: {new_name}")
            continue

        try:
            os.rename(old_path, new_path)
        except OSError as e:
            print(f"重命名失败: {old_name} -> {e}")
        else:
            print(f"重命名成功: {old_name} -> {new_name}")

    # Report the result. The directory part is escaped so glob
    # metacharacters in the path are matched literally, and the listing is
    # sorted so the output order is deterministic.
    pattern = os.path.join(glob.escape(saved_models_dir), "*.pth")
    model_files = sorted(glob.glob(pattern))
    print(f"\n当前saved_models目录包含 {len(model_files)} 个.pth模型文件:")
    for model in model_files:
        print(f" - {os.path.basename(model)}")
# Script entry point: perform the renaming only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    fix_model_names()