"""Smoke test for the path-generation logic of ``ModelPathManager``.

Run this file directly. It creates a scratch model directory and checks the
identifiers and version directories produced for store, product and global
training requests. It also checks version incrementing. The scratch directory
is deleted again at the end.
"""
import os
import sys
import shutil
import json

# Add the project root to sys.path so the ``server`` package can be imported
# when this script is executed directly.
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
sys.path.insert(0, project_root)

from server.utils.file_save import ModelPathManager


def _get_paths(path_manager, payload):
    """Call ``path_manager.get_model_paths`` with a request-style payload.

    ``training_mode`` and ``model_type`` are passed as explicit keyword
    arguments. They are therefore removed from a copy of the payload first,
    because passing them twice would raise ``TypeError``. All remaining keys
    are forwarded unchanged, and the caller's dict is not modified.
    """
    kwargs = dict(payload)
    training_mode = kwargs.pop('training_mode')
    model_type = kwargs.pop('model_type')
    return path_manager.get_model_paths(
        training_mode=training_mode, model_type=model_type, **kwargs
    )


def _show_paths(paths):
    """Print the identifier and version directory of a generated path dict."""
    print(f" - Identifier: {paths['identifier']}")
    print(f" - Version Dir: {paths['version_dir']}")


def _assert_paths(base_dir, paths, expected_identifier, *dir_parts):
    """Assert that ``paths`` has the expected identifier and version directory.

    The expected directory is ``os.path.join(base_dir, *dir_parts)``. Both
    sides are normalised with ``os.path.normpath``, so separator differences
    between platforms do not cause false failures.
    """
    actual_identifier = paths['identifier']
    assert actual_identifier == expected_identifier, (
        f"identifier: expected {expected_identifier!r}, got {actual_identifier!r}"
    )
    expected_dir = os.path.normpath(os.path.join(base_dir, *dir_parts))
    actual_dir = os.path.normpath(paths['version_dir'])
    assert actual_dir == expected_dir, (
        f"version_dir: expected {expected_dir!r}, got {actual_dir!r}"
    )


def _check_store_training(path_manager, base_dir, model_type):
    """Check store-training paths; return the 'all products' payload and its paths."""
    print("\n--- 🧪 1. 按店铺训练 (Store Training) ---")

    # a) Store training - all products
    print("\n[1a] 场景: 店铺训练 - 所有药品")
    store_payload_all = {
        'store_id': 'S001',
        'model_type': model_type,
        'training_mode': 'store',
        'product_scope': 'all',
    }
    paths_store_all = _get_paths(path_manager, store_payload_all)
    _show_paths(paths_store_all)
    _assert_paths(base_dir, paths_store_all,
                  f"store_S001_products_all_{model_type}",
                  'store', 'S001_all', model_type, 'v1')

    # b) Store training - several specific products (expected to be hashed)
    print("\n[1b] 场景: 店铺训练 - 特定药品 (使用哈希)")
    paths_store_specific = _get_paths(path_manager, {
        'store_id': 'S002',
        'model_type': model_type,
        'training_mode': 'store',
        'product_scope': 'specific',
        'product_ids': ['P001', 'P005', 'P002'],
    })
    hashed_ids = path_manager._hash_ids(['P001', 'P005', 'P002'])
    print(f" - Hashed IDs: {hashed_ids}")
    _show_paths(paths_store_specific)
    _assert_paths(base_dir, paths_store_specific,
                  f"store_S002_products_{hashed_ids}_{model_type}",
                  'store', f'S002_{hashed_ids}', model_type, 'v1')

    # c) Store training - a single specific product (ID expected verbatim)
    print("\n[1c] 场景: 店铺训练 - 单个指定药品")
    paths_store_single_product = _get_paths(path_manager, {
        'store_id': 'S003',
        'model_type': model_type,
        'training_mode': 'store',
        'product_scope': 'specific',
        'product_ids': ['P789'],
    })
    _show_paths(paths_store_single_product)
    _assert_paths(base_dir, paths_store_single_product,
                  f"store_S003_products_P789_{model_type}",
                  'store', 'S003_P789', model_type, 'v1')

    return store_payload_all, paths_store_all


def _check_product_training(path_manager, base_dir, model_type):
    """Check product-training paths, both across all stores and for one store."""
    print("\n--- 🧪 2. 按药品训练 (Product Training) ---")

    # a) Product training - all stores
    print("\n[2a] 场景: 药品训练 - 所有店铺")
    paths_product_all = _get_paths(path_manager, {
        'product_id': 'P123',
        'model_type': model_type,
        'training_mode': 'product',
        'store_id': None,  # explicitly exercise the None case
    })
    _show_paths(paths_product_all)
    _assert_paths(base_dir, paths_product_all,
                  f"product_P123_scope_all_{model_type}",
                  'product', 'P123_all', model_type, 'v1')

    # b) Product training - a specific store
    print("\n[2b] 场景: 药品训练 - 特定店铺")
    paths_product_specific = _get_paths(path_manager, {
        'product_id': 'P456',
        'store_id': 'S003',
        'model_type': model_type,
        'training_mode': 'product',
    })
    _show_paths(paths_product_specific)
    _assert_paths(base_dir, paths_product_specific,
                  f"product_P456_scope_S003_{model_type}",
                  'product', 'P456_S003', model_type, 'v1')


def _check_global_training(path_manager, base_dir, model_type):
    """Check global-training paths for the 'all', 'custom' and 'selected_stores' scopes."""
    print("\n--- 🧪 3. 全局训练 (Global Training) ---")

    # a) Global training - all data
    print("\n[3a] 场景: 全局训练 - 所有数据")
    paths_global_all = _get_paths(path_manager, {
        'model_type': model_type,
        'training_mode': 'global',
        'training_scope': 'all',
        'aggregation_method': 'sum',
    })
    _show_paths(paths_global_all)
    _assert_paths(base_dir, paths_global_all,
                  f"global_all_agg_sum_{model_type}",
                  'global', 'all', 'sum', model_type, 'v1')

    # a2) The 'all_stores_all_products' scope must map to the same place as 'all'
    print("\n[3a2] 场景: 全局训练 - 所有数据 (使用 'all_stores_all_products')")
    paths_global_all_alt = _get_paths(path_manager, {
        'model_type': model_type,
        'training_mode': 'global',
        'training_scope': 'all_stores_all_products',
        'aggregation_method': 'sum',
    })
    _show_paths(paths_global_all_alt)
    _assert_paths(base_dir, paths_global_all_alt,
                  f"global_all_agg_sum_{model_type}",
                  'global', 'all', 'sum', model_type, 'v1')

    # b) Global training - custom scope with several IDs (expected to be hashed)
    print("\n[3b] 场景: 全局训练 - 自定义范围 (使用哈希)")
    paths_global_custom = _get_paths(path_manager, {
        'model_type': model_type,
        'training_mode': 'global',
        'training_scope': 'custom',
        'aggregation_method': 'mean',
        'store_ids': ['S001', 'S003'],
        'product_ids': ['P001', 'P002'],
    })
    s_hash = path_manager._hash_ids(['S001', 'S003'])
    p_hash = path_manager._hash_ids(['P001', 'P002'])
    print(f" - Store Hash: {s_hash}, Product Hash: {p_hash}")
    _show_paths(paths_global_custom)
    _assert_paths(base_dir, paths_global_custom,
                  f"global_custom_s_{s_hash}_p_{p_hash}_agg_mean_{model_type}",
                  'global', 'custom', s_hash, p_hash, 'mean', model_type, 'v1')

    # c) Global training - a single selected store
    print("\n[3c] 场景: 全局训练 - 单个店铺")
    paths_global_single_store = _get_paths(path_manager, {
        'model_type': model_type,
        'training_mode': 'global',
        'training_scope': 'selected_stores',
        'aggregation_method': 'mean',
        'store_ids': ['S007'],
    })
    _show_paths(paths_global_single_store)
    _assert_paths(base_dir, paths_global_single_store,
                  f"global_stores_S007_agg_mean_{model_type}",
                  'global', 'stores', 'S007', 'mean', model_type, 'v1')

    # d) Global training - custom scope with single IDs (expected verbatim)
    print("\n[3d] 场景: 全局训练 - 自定义范围 (单ID)")
    paths_global_custom_single = _get_paths(path_manager, {
        'model_type': model_type,
        'training_mode': 'global',
        'training_scope': 'custom',
        'aggregation_method': 'mean',
        'store_ids': ['S008'],
        'product_ids': ['P888'],
    })
    _show_paths(paths_global_custom_single)
    _assert_paths(base_dir, paths_global_custom_single,
                  f"global_custom_s_S008_p_P888_agg_mean_{model_type}",
                  'global', 'custom', 'S008', 'P888', 'mean', model_type, 'v1')


def _check_versioning(path_manager, base_dir, model_type, payload, first_paths):
    """Check that saving a version makes the next request for ``payload`` get v2.

    ``payload`` and ``first_paths`` are expected to be the store 'S001' /
    'all products' request from section 1 and its v1 paths.
    """
    print("\n--- 🧪 4. 版本管理测试 ---")
    print("\n[4a] 场景: 多次调用同一训练,版本号递增")

    # First training run: persist its version.
    path_manager.save_version_info(first_paths['identifier'], first_paths['version'])
    print(f" - 保存版本: {first_paths['identifier']} -> v{first_paths['version']}")

    # Second training run with the same payload.
    second_paths = _get_paths(path_manager, payload)
    print(f" - 获取新版本: {second_paths['identifier']} -> v{second_paths['version']}")
    assert second_paths['version'] == 2, (
        f"version: expected 2, got {second_paths['version']!r}"
    )
    _assert_paths(base_dir, second_paths, first_paths['identifier'],
                  'store', 'S001_all', model_type, 'v2')

    # Only the explicitly saved version (v1) should be persisted so far.
    with open(path_manager.versions_file, 'r', encoding='utf-8') as f:
        versions_data = json.load(f)
    print(f" - versions.json 内容: {versions_data}")
    assert versions_data[first_paths['identifier']] == 1


def run_tests():
    """Run all ModelPathManager path-generation checks.

    The scratch model directory ``test_saved_models`` is resolved relative to
    the current working directory. It is wiped before the run and removed in a
    ``finally`` block afterwards, so a failing assertion does not leave stale
    files behind.

    Raises:
        AssertionError: if a generated identifier, version directory or
            version number differs from what is expected.
    """
    test_base_dir = 'test_saved_models'
    if os.path.exists(test_base_dir):
        shutil.rmtree(test_base_dir)  # clear leftovers from an earlier run

    model_type = 'mlstm'
    try:
        path_manager = ModelPathManager(base_dir=test_base_dir)

        print("="*50)
        print("🚀 开始测试 ModelPathManager 路径生成逻辑...")
        print(f"测试根目录: {os.path.abspath(test_base_dir)}")
        print("="*50)

        store_payload_all, paths_store_all = _check_store_training(
            path_manager, test_base_dir, model_type)
        _check_product_training(path_manager, test_base_dir, model_type)
        _check_global_training(path_manager, test_base_dir, model_type)
        _check_versioning(path_manager, test_base_dir, model_type,
                          store_payload_all, paths_store_all)

        # Newline first, then a single 50-char rule ("\n=" * 50 would
        # repeat the newline 50 times and print 50 separate lines).
        print("\n" + "="*50)
        print("✅ 所有测试用例通过!")
        print("="*50)
    finally:
        # Always clean up, even if an assertion above failed; guard against
        # the directory never having been created.
        if os.path.exists(test_base_dir):
            shutil.rmtree(test_base_dir)
            print(f"🗑️ 测试目录 '{test_base_dir}' 已清理。")


if __name__ == '__main__':
    run_tests()