#!/usr/bin/env python3
"""
|
|
重命名模型文件以符合API命名规范
|
|
"""
|
|
import glob
import os


def fix_model_names(saved_models_dir="saved_models", rename_map=None):
    """Rename legacy model files in *saved_models_dir* to the API naming scheme.

    Each ``old_name -> new_name`` pair in *rename_map* is renamed inside
    *saved_models_dir*. A pair is skipped (with a printed message) when the
    source file is missing or when the target name already exists, so an
    existing model file is never overwritten. Afterwards the ``.pth`` files
    present in the directory are printed.

    Args:
        saved_models_dir: Directory holding the model files. Defaults to
            ``"saved_models"``, resolved relative to the current working
            directory.
        rename_map: Mapping of old file name to new file name (names only,
            not paths). Defaults to the built-in P001 mapping.

    Returns:
        Sorted list of ``.pth`` file names (basenames) found in
        *saved_models_dir* after the renames.
    """
    if rename_map is None:
        # Default mapping: legacy P001 checkpoints -> API-style names.
        # Built per call so no mutable default is shared between calls.
        rename_map = {
            "P001_mlstm_model.pt": "mlstm_product_P001_v1.pth",
            "P001_model.pt": "kan_optimized_product_P001_v1.pth",
        }

    print("开始重命名模型文件:")

    for old_name, new_name in rename_map.items():
        old_path = os.path.join(saved_models_dir, old_name)
        new_path = os.path.join(saved_models_dir, new_name)

        if not os.path.exists(old_path):
            print(f"文件不存在: {old_name}")
            continue

        # os.rename silently replaces an existing target on POSIX (and raises
        # on Windows), so check first to avoid clobbering a real model file.
        if os.path.exists(new_path):
            print(f"目标文件已存在, 跳过: {old_name} -> {new_name}")
            continue

        try:
            os.rename(old_path, new_path)
        except OSError as e:
            print(f"重命名失败: {old_name} -> {e}")
        else:
            print(f"重命名成功: {old_name} -> {new_name}")

    # Escape the directory so glob metacharacters in its name match literally,
    # and sort because glob's order depends on the filesystem.
    pattern = os.path.join(glob.escape(saved_models_dir), "*.pth")
    model_files = sorted(os.path.basename(path) for path in glob.glob(pattern))

    print(f"\n当前saved_models目录包含 {len(model_files)} 个.pth模型文件:")
    for model in model_files:
        print(f" - {model}")

    return model_files


if __name__ == "__main__":
|
|
fix_model_names() |