259 lines
12 KiB
Python
259 lines
12 KiB
Python
|
import os
|
|||
|
import sys
|
|||
|
import shutil
|
|||
|
import json
|
|||
|
|
|||
|
# Put the project root (the directory one level above this script's folder) at
# the front of sys.path so the `server` package resolves when run directly.
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, project_root)
|
|||
|
|
|||
|
from server.utils.file_save import ModelPathManager
|
|||
|
|
|||
|
def _get_paths(path_manager, payload):
    """Ask ``path_manager`` for model paths using a request-style payload.

    ``training_mode`` and ``model_type`` are removed from a shallow copy of the
    payload and passed as explicit keyword arguments; every remaining key is
    forwarded unchanged. The caller's dict is not modified, so it can be reused.
    """
    params = dict(payload)
    training_mode = params.pop('training_mode')
    model_type = params.pop('model_type')
    return path_manager.get_model_paths(
        training_mode=training_mode, model_type=model_type, **params
    )


def _print_paths(paths):
    """Print the identifier and version directory of a generated path set."""
    print(f" - Identifier: {paths['identifier']}")
    print(f" - Version Dir: {paths['version_dir']}")


def _assert_identifier(expected, paths):
    """Assert ``paths['identifier'] == expected``, reporting both values on failure."""
    actual = paths['identifier']
    assert expected == actual, (
        f"identifier mismatch: expected {expected!r}, got {actual!r}"
    )


def _assert_version_dir(expected_path, paths):
    """Assert ``paths['version_dir']`` equals *expected_path* after path normalisation."""
    actual = paths['version_dir']
    assert os.path.normpath(expected_path) == os.path.normpath(actual), (
        f"version_dir mismatch: expected {expected_path!r}, got {actual!r}"
    )


def _test_store_training(path_manager, base_dir, model_type):
    """Check store-mode identifiers and directories (scenarios 1a-1c).

    Returns:
        The 1a payload and its generated paths; the versioning test reuses them.
    """
    print("\n--- 🧪 1. 按店铺训练 (Store Training) ---")

    # a) Store training - all products
    print("\n[1a] 场景: 店铺训练 - 所有药品")
    store_payload_all = {
        'store_id': 'S001',
        'model_type': model_type,
        'training_mode': 'store',
        'product_scope': 'all',
    }
    paths_store_all = _get_paths(path_manager, store_payload_all)
    _print_paths(paths_store_all)
    _assert_identifier(f"store_S001_products_all_{model_type}", paths_store_all)
    _assert_version_dir(
        os.path.join(base_dir, 'store', 'S001_all', model_type, 'v1'), paths_store_all
    )

    # b) Store training - several specific products (expected to be hashed)
    print("\n[1b] 场景: 店铺训练 - 特定药品 (使用哈希)")
    store_payload_specific = {
        'store_id': 'S002',
        'model_type': model_type,
        'training_mode': 'store',
        'product_scope': 'specific',
        'product_ids': ['P001', 'P005', 'P002'],
    }
    paths_store_specific = _get_paths(path_manager, store_payload_specific)
    # Fresh list literal, so the hash input cannot have been mutated by the call above.
    hashed_ids = path_manager._hash_ids(['P001', 'P005', 'P002'])
    print(f" - Hashed IDs: {hashed_ids}")
    _print_paths(paths_store_specific)
    _assert_identifier(
        f"store_S002_products_{hashed_ids}_{model_type}", paths_store_specific
    )
    _assert_version_dir(
        os.path.join(base_dir, 'store', f'S002_{hashed_ids}', model_type, 'v1'),
        paths_store_specific,
    )

    # c) Store training - a single specific product (expected verbatim, not hashed)
    print("\n[1c] 场景: 店铺训练 - 单个指定药品")
    store_payload_single_product = {
        'store_id': 'S003',
        'model_type': model_type,
        'training_mode': 'store',
        'product_scope': 'specific',
        'product_ids': ['P789'],
    }
    paths_store_single_product = _get_paths(path_manager, store_payload_single_product)
    _print_paths(paths_store_single_product)
    _assert_identifier(
        f"store_S003_products_P789_{model_type}", paths_store_single_product
    )
    _assert_version_dir(
        os.path.join(base_dir, 'store', 'S003_P789', model_type, 'v1'),
        paths_store_single_product,
    )

    return store_payload_all, paths_store_all


def _test_product_training(path_manager, base_dir, model_type):
    """Check product-mode identifiers and directories (scenarios 2a-2b)."""
    print("\n--- 🧪 2. 按药品训练 (Product Training) ---")

    # a) Product training - all stores
    print("\n[2a] 场景: 药品训练 - 所有店铺")
    product_payload_all = {
        'product_id': 'P123',
        'model_type': model_type,
        'training_mode': 'product',
        'store_id': None,  # explicitly exercise the store_id=None case
    }
    paths_product_all = _get_paths(path_manager, product_payload_all)
    _print_paths(paths_product_all)
    _assert_identifier(f"product_P123_scope_all_{model_type}", paths_product_all)
    _assert_version_dir(
        os.path.join(base_dir, 'product', 'P123_all', model_type, 'v1'), paths_product_all
    )

    # b) Product training - one specific store
    print("\n[2b] 场景: 药品训练 - 特定店铺")
    product_payload_specific = {
        'product_id': 'P456',
        'store_id': 'S003',
        'model_type': model_type,
        'training_mode': 'product',
    }
    paths_product_specific = _get_paths(path_manager, product_payload_specific)
    _print_paths(paths_product_specific)
    _assert_identifier(f"product_P456_scope_S003_{model_type}", paths_product_specific)
    _assert_version_dir(
        os.path.join(base_dir, 'product', 'P456_S003', model_type, 'v1'),
        paths_product_specific,
    )


def _test_global_training(path_manager, base_dir, model_type):
    """Check global-mode identifiers and directories (scenarios 3a-3d)."""
    print("\n--- 🧪 3. 全局训练 (Global Training) ---")

    # a) Global training - all data
    print("\n[3a] 场景: 全局训练 - 所有数据")
    global_payload_all = {
        'model_type': model_type,
        'training_mode': 'global',
        'training_scope': 'all',
        'aggregation_method': 'sum',
    }
    expected_all_identifier = f"global_all_agg_sum_{model_type}"
    expected_all_dir = os.path.join(base_dir, 'global', 'all', 'sum', model_type, 'v1')
    paths_global_all = _get_paths(path_manager, global_payload_all)
    _print_paths(paths_global_all)
    _assert_identifier(expected_all_identifier, paths_global_all)
    _assert_version_dir(expected_all_dir, paths_global_all)

    # a2) 'all_stores_all_products' must map to exactly the same identifier and
    #     directory as 'all' (the expectation is named explicitly, not left over from 3a).
    print("\n[3a2] 场景: 全局训练 - 所有数据 (使用 'all_stores_all_products')")
    global_payload_all_alt = {
        'model_type': model_type,
        'training_mode': 'global',
        'training_scope': 'all_stores_all_products',
        'aggregation_method': 'sum',
    }
    paths_global_all_alt = _get_paths(path_manager, global_payload_all_alt)
    _assert_identifier(expected_all_identifier, paths_global_all_alt)
    _assert_version_dir(expected_all_dir, paths_global_all_alt)

    # b) Global training - custom scope with several IDs (expected to be hashed)
    print("\n[3b] 场景: 全局训练 - 自定义范围 (使用哈希)")
    global_payload_custom = {
        'model_type': model_type,
        'training_mode': 'global',
        'training_scope': 'custom',
        'aggregation_method': 'mean',
        'store_ids': ['S001', 'S003'],
        'product_ids': ['P001', 'P002'],
    }
    paths_global_custom = _get_paths(path_manager, global_payload_custom)
    s_hash = path_manager._hash_ids(['S001', 'S003'])
    p_hash = path_manager._hash_ids(['P001', 'P002'])
    print(f" - Store Hash: {s_hash}, Product Hash: {p_hash}")
    _print_paths(paths_global_custom)
    _assert_identifier(
        f"global_custom_s_{s_hash}_p_{p_hash}_agg_mean_{model_type}", paths_global_custom
    )
    _assert_version_dir(
        os.path.join(base_dir, 'global', 'custom', s_hash, p_hash, 'mean', model_type, 'v1'),
        paths_global_custom,
    )

    # c) Global training - a single selected store
    print("\n[3c] 场景: 全局训练 - 单个店铺")
    global_payload_single_store = {
        'model_type': model_type,
        'training_mode': 'global',
        'training_scope': 'selected_stores',
        'aggregation_method': 'mean',
        'store_ids': ['S007'],
    }
    paths_global_single_store = _get_paths(path_manager, global_payload_single_store)
    _print_paths(paths_global_single_store)
    _assert_identifier(f"global_stores_S007_agg_mean_{model_type}", paths_global_single_store)
    _assert_version_dir(
        os.path.join(base_dir, 'global', 'stores', 'S007', 'mean', model_type, 'v1'),
        paths_global_single_store,
    )

    # d) Global training - custom scope with one store and one product (no hashing expected)
    print("\n[3d] 场景: 全局训练 - 自定义范围 (单ID)")
    global_payload_custom_single = {
        'model_type': model_type,
        'training_mode': 'global',
        'training_scope': 'custom',
        'aggregation_method': 'mean',
        'store_ids': ['S008'],
        'product_ids': ['P888'],
    }
    paths_global_custom_single = _get_paths(path_manager, global_payload_custom_single)
    _print_paths(paths_global_custom_single)
    _assert_identifier(
        f"global_custom_s_S008_p_P888_agg_mean_{model_type}", paths_global_custom_single
    )
    _assert_version_dir(
        os.path.join(base_dir, 'global', 'custom', 'S008', 'P888', 'mean', model_type, 'v1'),
        paths_global_custom_single,
    )


def _test_versioning(path_manager, base_dir, model_type, store_payload_all, paths_store_all):
    """Check that saving a version makes the next request for the same model get v2."""
    print("\n--- 🧪 4. 版本管理测试 ---")
    print("\n[4a] 场景: 多次调用同一训练,版本号递增")

    # First training run: record the version handed out in scenario 1a.
    path_manager.save_version_info(paths_store_all['identifier'], paths_store_all['version'])
    print(f" - 保存版本: {paths_store_all['identifier']} -> v{paths_store_all['version']}")

    # Second training run with the identical request.
    paths_store_all_v2 = _get_paths(path_manager, store_payload_all)
    print(f" - 获取新版本: {paths_store_all_v2['identifier']} -> v{paths_store_all_v2['version']}")
    assert paths_store_all_v2['version'] == 2, (
        f"version mismatch: expected 2, got {paths_store_all_v2['version']!r}"
    )
    _assert_version_dir(
        os.path.join(base_dir, 'store', 'S001_all', model_type, 'v2'), paths_store_all_v2
    )

    # The persisted versions file should hold only the saved (first) version.
    with open(path_manager.versions_file, 'r', encoding='utf-8') as f:
        versions_data = json.load(f)
    print(f" - versions.json 内容: {versions_data}")
    assert versions_data[paths_store_all['identifier']] == 1, (
        f"unexpected saved version: {versions_data.get(paths_store_all['identifier'])!r}"
    )


def run_tests():
    """Run every ModelPathManager path-generation test.

    ``test_saved_models`` (relative to the current working directory) is used
    as the model root. Any existing copy is deleted before the run. The
    directory is removed again afterwards, even when a scenario fails, so a
    failed run does not leave stale version state for the next one.

    Raises:
        AssertionError: if a generated identifier, version directory or
            version number differs from the expected value.
    """
    # --- Test setup ---
    test_base_dir = 'test_saved_models'
    if os.path.exists(test_base_dir):
        shutil.rmtree(test_base_dir)  # clear leftovers from a previous run

    try:
        path_manager = ModelPathManager(base_dir=test_base_dir)
        model_type = 'mlstm'

        print("=" * 50)
        print("🚀 开始测试 ModelPathManager 路径生成逻辑...")
        print(f"测试根目录: {os.path.abspath(test_base_dir)}")
        print("=" * 50)

        store_payload_all, paths_store_all = _test_store_training(
            path_manager, test_base_dir, model_type
        )
        _test_product_training(path_manager, test_base_dir, model_type)
        _test_global_training(path_manager, test_base_dir, model_type)
        _test_versioning(
            path_manager, test_base_dir, model_type, store_payload_all, paths_store_all
        )

        # Blank line followed by a single rule (the old `"\n=" * 50` printed 50 lines).
        print("\n" + "=" * 50)
        print("✅ 所有测试用例通过!")
        print("=" * 50)
    finally:
        # --- Cleanup --- runs on success and failure alike. The isdir guard keeps
        # a missing directory from masking the original assertion error.
        if os.path.isdir(test_base_dir):
            shutil.rmtree(test_base_dir)
            print(f"🗑️ 测试目录 '{test_base_dir}' 已清理。")
|
|||
|
|
|||
|
if __name__ == "__main__":
    # Executed as a script (not imported): run the whole path-generation suite.
    run_tests()
|