import unittest import os import shutil import sys # 将项目根目录添加到 sys.path,以解决模块导入问题 # 这使得测试脚本可以直接运行,而无需复杂的路径配置 project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) if project_root not in sys.path: sys.path.insert(0, project_root) from server.utils.file_save import ModelPathManager class TestModelPathManager(unittest.TestCase): """ 测试 ModelPathManager 是否严格遵循扁平化文件存储规范。 """ def setUp(self): """在每个测试用例开始前,设置测试环境。""" self.test_base_dir = 'test_saved_models' # 清理之前的测试目录和文件 if os.path.exists(self.test_base_dir): shutil.rmtree(self.test_base_dir) self.path_manager = ModelPathManager(base_dir=self.test_base_dir) def tearDown(self): """在每个测试用例结束后,清理测试环境。""" if os.path.exists(self.test_base_dir): shutil.rmtree(self.test_base_dir) def test_product_mode_path_generation(self): """测试 'product' 模式下的路径生成是否符合规范。""" print("\n--- 测试 'product' 模式 ---") params = { 'training_mode': 'product', 'model_type': 'mlstm', 'product_id': 'P001', 'store_id': 'all' } # 第一次调用,版本应为 1 paths_v1 = self.path_manager.get_model_paths(**params) # 验证版本号 self.assertEqual(paths_v1['version'], 1) # 验证文件名前缀 expected_prefix_v1 = 'product_P001_all_mlstm_v1' self.assertEqual(paths_v1['filename_prefix'], expected_prefix_v1) # 验证各个文件的完整路径 self.assertEqual(paths_v1['model_path'], os.path.join(self.test_base_dir, f'{expected_prefix_v1}_model.pth')) self.assertEqual(paths_v1['metadata_path'], os.path.join(self.test_base_dir, f'{expected_prefix_v1}_metadata.json')) self.assertEqual(paths_v1['loss_curve_path'], os.path.join(self.test_base_dir, f'{expected_prefix_v1}_loss_curve.png')) # 验证检查点路径 checkpoint_dir = os.path.join(self.test_base_dir, 'checkpoints') self.assertEqual(paths_v1['checkpoint_dir'], checkpoint_dir) self.assertEqual(paths_v1['best_checkpoint_path'], os.path.join(checkpoint_dir, f'{expected_prefix_v1}_checkpoint_best.pth')) self.assertEqual(paths_v1['epoch_checkpoint_template'], os.path.join(checkpoint_dir, f'{expected_prefix_v1}_checkpoint_epoch_{{N}}.pth')) print(f"生成的文件名前缀: {paths_v1['filename_prefix']}") print(f"生成的模型路径: {paths_v1['model_path']}") print("验证通过!") # 模拟一次成功的训练,以触发版本递增 self.path_manager.save_version_info(paths_v1['identifier'], paths_v1['version']) # 第二次调用,版本应为 2 paths_v2 = self.path_manager.get_model_paths(**params) self.assertEqual(paths_v2['version'], 2) expected_prefix_v2 = 'product_P001_all_mlstm_v2' self.assertEqual(paths_v2['filename_prefix'], expected_prefix_v2) print(f"\n版本递增后,生成的文件名前缀: {paths_v2['filename_prefix']}") print("版本递增验证通过!") def test_store_mode_path_generation_with_hash(self): """测试 'store' 模式下使用哈希的路径生成。""" print("\n--- 测试 'store' 模式 (多药品ID哈希) ---") params = { 'training_mode': 'store', 'model_type': 'kan', 'store_id': 'S008', 'product_scope': 'specific', 'product_ids': ['P002', 'P005', 'P003'] # 顺序故意打乱 } paths = self.path_manager.get_model_paths(**params) # 哈希值应该是固定的,因为ID列表会先排序再哈希 expected_hash = self.path_manager._hash_ids(sorted(['P002', 'P005', 'P003'])) expected_prefix = f'store_S008_{expected_hash}_kan_v1' self.assertEqual(paths['filename_prefix'], expected_prefix) self.assertEqual(paths['model_path'], os.path.join(self.test_base_dir, f'{expected_prefix}_model.pth')) print(f"生成的文件名前缀: {paths['filename_prefix']}") print("验证通过!") def test_global_mode_path_generation(self): """测试 'global' 模式下的路径生成。""" print("\n--- 测试 'global' 模式 ---") params = { 'training_mode': 'global', 'model_type': 'transformer', 'training_scope': 'all', 'aggregation_method': 'mean' } paths = self.path_manager.get_model_paths(**params) expected_prefix = 'global_all_agg_mean_transformer_v1' self.assertEqual(paths['filename_prefix'], expected_prefix) self.assertEqual(paths['model_path'], os.path.join(self.test_base_dir, f'{expected_prefix}_model.pth')) print(f"生成的文件名前缀: {paths['filename_prefix']}") print("验证通过!") if __name__ == '__main__': unittest.main()