数据文件保存结构改为:

### 1.2. 文件存储位置

- **最终产物**: 所有最终模型、元数据文件、损失图等,统一存放在 `saved_models/` 根目录下。
- **过程文件**: 所有训练过程中的检查点文件,统一存放在 `saved_models/checkpoints/` 目录下。

### 1.3. 文件名生成规则

1. **构建逻辑路径**: 根据训练参数(模式、范围、类型、版本)确定逻辑路径。
   - *示例*: `product/P001_all/mlstm/v2`
2. **生成文件名前缀**: 将逻辑路径中的所有 `/` 替换为 `_`。
   - *示例*: `product_P001_all_mlstm_v2`
3. **拼接文件后缀**: 在前缀后加上描述文件类型的后缀。
   - `_model.pth`
   - `_loss_curve.png`
   - `_checkpoint_best.pth`
   - `_checkpoint_epoch_{N}.pth`

#### **完整示例:**

- **最终模型**: `saved_models/product_P001_all_mlstm_v2_model.pth`
- **最佳检查点**: `saved_models/checkpoints/product_P001_all_mlstm_v2_checkpoint_best.pth`
- **Epoch 50 检查点**: `saved_models/checkpoints/product_P001_all_mlstm_v2_checkpoint_epoch_50.pth`
118 lines
5.2 KiB
Python
118 lines
5.2 KiB
Python
import unittest
|
||
import os
|
||
import shutil
|
||
import sys
|
||
|
||
# Put the project root (the parent of this file's directory) on sys.path so
# that project modules can be imported.
# This lets the test script be run directly, without extra path configuration.
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
if project_root not in sys.path:
    # Insert at the front so the project root is searched first.
    sys.path.insert(0, project_root)
|
||
|
||
from server.utils.file_save import ModelPathManager
|
||
|
||
class TestModelPathManager(unittest.TestCase):
    """Check that ModelPathManager follows the flat file-storage convention.

    Every test creates a manager rooted at a scratch directory
    (``test_saved_models``, relative to the current working directory) and
    compares the paths it returns with hand-built expectations.
    """

    def _remove_test_dir(self):
        """Delete the scratch directory tree if it currently exists."""
        if os.path.exists(self.test_base_dir):
            shutil.rmtree(self.test_base_dir)

    def setUp(self):
        """Prepare a clean scratch directory and a fresh manager per test."""
        self.test_base_dir = 'test_saved_models'
        # Clear out anything left behind by an earlier run.
        self._remove_test_dir()
        self.path_manager = ModelPathManager(base_dir=self.test_base_dir)

    def tearDown(self):
        """Remove the scratch directory after each test."""
        self._remove_test_dir()

    def test_product_mode_path_generation(self):
        """Verify 'product' mode paths, then the version bump to v2."""
        print("\n--- 测试 'product' 模式 ---")
        params = {
            'training_mode': 'product',
            'model_type': 'mlstm',
            'product_id': 'P001',
            'store_id': 'all',
        }

        # First request for this parameter set: the version should be 1.
        paths_v1 = self.path_manager.get_model_paths(**params)
        self.assertEqual(paths_v1['version'], 1)

        # The filename prefix is the logical path with '/' replaced by '_'.
        expected_prefix_v1 = 'product_P001_all_mlstm_v1'
        self.assertEqual(paths_v1['filename_prefix'], expected_prefix_v1)

        # Final artifacts are expected directly under the base directory.
        final_artifacts = (
            ('model_path', '_model.pth'),
            ('metadata_path', '_metadata.json'),
            ('loss_curve_path', '_loss_curve.png'),
        )
        for key, suffix in final_artifacts:
            self.assertEqual(
                paths_v1[key],
                os.path.join(self.test_base_dir, expected_prefix_v1 + suffix),
            )

        # Checkpoints are expected under the 'checkpoints' subdirectory.
        checkpoint_dir = os.path.join(self.test_base_dir, 'checkpoints')
        self.assertEqual(paths_v1['checkpoint_dir'], checkpoint_dir)
        self.assertEqual(
            paths_v1['best_checkpoint_path'],
            os.path.join(
                checkpoint_dir,
                f'{expected_prefix_v1}_checkpoint_best.pth',
            ),
        )
        # '{{N}}' renders as the literal '{N}' placeholder in the template.
        self.assertEqual(
            paths_v1['epoch_checkpoint_template'],
            os.path.join(
                checkpoint_dir,
                f'{expected_prefix_v1}_checkpoint_epoch_{{N}}.pth',
            ),
        )

        print(f"生成的文件名前缀: {paths_v1['filename_prefix']}")
        print(f"生成的模型路径: {paths_v1['model_path']}")
        print("验证通过!")

        # Record v1 as saved (standing in for a successful training run) so
        # that the next request is expected to receive the next version.
        self.path_manager.save_version_info(
            paths_v1['identifier'],
            paths_v1['version'],
        )

        # Second request with the same parameters: the version should be 2.
        paths_v2 = self.path_manager.get_model_paths(**params)
        self.assertEqual(paths_v2['version'], 2)
        expected_prefix_v2 = 'product_P001_all_mlstm_v2'
        self.assertEqual(paths_v2['filename_prefix'], expected_prefix_v2)
        print(f"\n版本递增后,生成的文件名前缀: {paths_v2['filename_prefix']}")
        print("版本递增验证通过!")

    def test_store_mode_path_generation_with_hash(self):
        """Verify 'store' mode paths when several product IDs are hashed."""
        print("\n--- 测试 'store' 模式 (多药品ID哈希) ---")
        params = {
            'training_mode': 'store',
            'model_type': 'kan',
            'store_id': 'S008',
            'product_scope': 'specific',
            'product_ids': ['P002', 'P005', 'P003'],  # deliberately unsorted
        }

        paths = self.path_manager.get_model_paths(**params)

        # The test expects the manager to sort the IDs before hashing, so the
        # expected hash is computed from the sorted list.
        expected_hash = self.path_manager._hash_ids(
            sorted(['P002', 'P005', 'P003'])
        )
        expected_prefix = f'store_S008_{expected_hash}_kan_v1'

        self.assertEqual(paths['filename_prefix'], expected_prefix)
        self.assertEqual(
            paths['model_path'],
            os.path.join(self.test_base_dir, f'{expected_prefix}_model.pth'),
        )
        print(f"生成的文件名前缀: {paths['filename_prefix']}")
        print("验证通过!")

    def test_global_mode_path_generation(self):
        """Verify the paths generated in 'global' mode."""
        print("\n--- 测试 'global' 模式 ---")
        params = {
            'training_mode': 'global',
            'model_type': 'transformer',
            'training_scope': 'all',
            'aggregation_method': 'mean',
        }

        paths = self.path_manager.get_model_paths(**params)

        expected_prefix = 'global_all_agg_mean_transformer_v1'
        self.assertEqual(paths['filename_prefix'], expected_prefix)
        self.assertEqual(
            paths['model_path'],
            os.path.join(self.test_base_dir, f'{expected_prefix}_model.pth'),
        )
        print(f"生成的文件名前缀: {paths['filename_prefix']}")
        print("验证通过!")
|
||
|
||
# Run the test suite when this file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()