ShopTRAINING/test/verify_save_logic.py

118 lines
5.2 KiB
Python
Raw Normal View History

import unittest
import os
import shutil
import sys
# 将项目根目录添加到 sys.path以解决模块导入问题
# 这使得测试脚本可以直接运行,而无需复杂的路径配置
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
if project_root not in sys.path:
sys.path.insert(0, project_root)
from server.utils.file_save import ModelPathManager
class TestModelPathManager(unittest.TestCase):
"""
测试 ModelPathManager 是否严格遵循扁平化文件存储规范
"""
def setUp(self):
"""在每个测试用例开始前,设置测试环境。"""
self.test_base_dir = 'test_saved_models'
# 清理之前的测试目录和文件
if os.path.exists(self.test_base_dir):
shutil.rmtree(self.test_base_dir)
self.path_manager = ModelPathManager(base_dir=self.test_base_dir)
def tearDown(self):
"""在每个测试用例结束后,清理测试环境。"""
if os.path.exists(self.test_base_dir):
shutil.rmtree(self.test_base_dir)
def test_product_mode_path_generation(self):
"""测试 'product' 模式下的路径生成是否符合规范。"""
print("\n--- 测试 'product' 模式 ---")
params = {
'training_mode': 'product',
'model_type': 'mlstm',
'product_id': 'P001',
'store_id': 'all'
}
# 第一次调用,版本应为 1
paths_v1 = self.path_manager.get_model_paths(**params)
# 验证版本号
self.assertEqual(paths_v1['version'], 1)
# 验证文件名前缀
expected_prefix_v1 = 'product_P001_all_mlstm_v1'
self.assertEqual(paths_v1['filename_prefix'], expected_prefix_v1)
# 验证各个文件的完整路径
self.assertEqual(paths_v1['model_path'], os.path.join(self.test_base_dir, f'{expected_prefix_v1}_model.pth'))
self.assertEqual(paths_v1['metadata_path'], os.path.join(self.test_base_dir, f'{expected_prefix_v1}_metadata.json'))
self.assertEqual(paths_v1['loss_curve_path'], os.path.join(self.test_base_dir, f'{expected_prefix_v1}_loss_curve.png'))
# 验证检查点路径
checkpoint_dir = os.path.join(self.test_base_dir, 'checkpoints')
self.assertEqual(paths_v1['checkpoint_dir'], checkpoint_dir)
self.assertEqual(paths_v1['best_checkpoint_path'], os.path.join(checkpoint_dir, f'{expected_prefix_v1}_checkpoint_best.pth'))
self.assertEqual(paths_v1['epoch_checkpoint_template'], os.path.join(checkpoint_dir, f'{expected_prefix_v1}_checkpoint_epoch_{{N}}.pth'))
print(f"生成的文件名前缀: {paths_v1['filename_prefix']}")
print(f"生成的模型路径: {paths_v1['model_path']}")
print("验证通过!")
# 模拟一次成功的训练,以触发版本递增
self.path_manager.save_version_info(paths_v1['identifier'], paths_v1['version'])
# 第二次调用,版本应为 2
paths_v2 = self.path_manager.get_model_paths(**params)
self.assertEqual(paths_v2['version'], 2)
expected_prefix_v2 = 'product_P001_all_mlstm_v2'
self.assertEqual(paths_v2['filename_prefix'], expected_prefix_v2)
print(f"\n版本递增后,生成的文件名前缀: {paths_v2['filename_prefix']}")
print("版本递增验证通过!")
def test_store_mode_path_generation_with_hash(self):
"""测试 'store' 模式下使用哈希的路径生成。"""
print("\n--- 测试 'store' 模式 (多药品ID哈希) ---")
params = {
'training_mode': 'store',
'model_type': 'kan',
'store_id': 'S008',
'product_scope': 'specific',
'product_ids': ['P002', 'P005', 'P003'] # 顺序故意打乱
}
paths = self.path_manager.get_model_paths(**params)
# 哈希值应该是固定的因为ID列表会先排序再哈希
expected_hash = self.path_manager._hash_ids(sorted(['P002', 'P005', 'P003']))
expected_prefix = f'store_S008_{expected_hash}_kan_v1'
self.assertEqual(paths['filename_prefix'], expected_prefix)
self.assertEqual(paths['model_path'], os.path.join(self.test_base_dir, f'{expected_prefix}_model.pth'))
print(f"生成的文件名前缀: {paths['filename_prefix']}")
print("验证通过!")
def test_global_mode_path_generation(self):
"""测试 'global' 模式下的路径生成。"""
print("\n--- 测试 'global' 模式 ---")
params = {
'training_mode': 'global',
'model_type': 'transformer',
'training_scope': 'all',
'aggregation_method': 'mean'
}
paths = self.path_manager.get_model_paths(**params)
expected_prefix = 'global_all_agg_mean_transformer_v1'
self.assertEqual(paths['filename_prefix'], expected_prefix)
self.assertEqual(paths['model_path'], os.path.join(self.test_base_dir, f'{expected_prefix}_model.pth'))
print(f"生成的文件名前缀: {paths['filename_prefix']}")
print("验证通过!")
if __name__ == '__main__':
unittest.main()