ShopTRAINING/test/verify_save_logic.py
xz2000 87df49f764 添加训练算法模拟 XGBoost,训练可以完成,预测读取还有问题
数据文件保存结构改为:

### 1.2. 文件存储位置

-   **最终产物**: 所有最终模型、元数据文件、损失图等,统一存放在 `saved_models/` 根目录下。
-   **过程文件**: 所有训练过程中的检查点文件,统一存放在 `saved_models/checkpoints/` 目录下。

### 1.3. 文件名生成规则

1.  **构建逻辑路径**: 根据训练参数(模式、范围、类型、版本)确定逻辑路径。
    -   *示例*: `product/P001_all/mlstm/v2`

2.  **生成文件名前缀**: 将逻辑路径中的所有 `/` 替换为 `_`。
    -   *示例*: `product_P001_all_mlstm_v2`

3.  **拼接文件后缀**: 在前缀后加上描述文件类型的后缀。
    -   `_model.pth`
    -   `_loss_curve.png`
    -   `_checkpoint_best.pth`
    -   `_checkpoint_epoch_{N}.pth`

#### **完整示例:**

-   **最终模型**: `saved_models/product_P001_all_mlstm_v2_model.pth`
-   **最佳检查点**: `saved_models/checkpoints/product_P001_all_mlstm_v2_checkpoint_best.pth`
-   **Epoch 50 检查点**: `saved_models/checkpoints/product_P001_all_mlstm_v2_checkpoint_epoch_50.pth`
2025-07-21 18:47:27 +08:00

118 lines
5.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import unittest
import os
import shutil
import sys
# Put the project root (the parent of this file's directory) at the front of
# sys.path so the project-local imports below resolve when this script is run
# directly, without any extra path configuration.
project_root = os.path.abspath(
    os.path.join(os.path.dirname(__file__), os.pardir)
)
if project_root not in sys.path:
    sys.path.insert(0, project_root)
from server.utils.file_save import ModelPathManager
class TestModelPathManager(unittest.TestCase):
    """Check that ModelPathManager follows the flat file-storage convention.

    Every test works inside its own freshly created temporary directory, so
    running the suite never deletes a pre-existing ``test_saved_models``
    directory in the current working directory and does not depend on where
    the tests are launched from.
    """

    def setUp(self):
        """Create an isolated base directory path and a manager bound to it."""
        # Local import keeps this class self-contained with respect to the
        # file's top-level import list.
        import tempfile

        scratch_root = tempfile.mkdtemp(prefix='model_path_manager_test_')
        # Registered immediately so the scratch root is removed even if the
        # rest of setUp raises (tearDown is not run in that case).
        self.addCleanup(shutil.rmtree, scratch_root, ignore_errors=True)

        # The base directory itself is deliberately NOT created here: the
        # manager is constructed against a path that does not exist yet.
        self.test_base_dir = os.path.join(scratch_root, 'test_saved_models')
        self.path_manager = ModelPathManager(base_dir=self.test_base_dir)

    def tearDown(self):
        """Remove whatever the test wrote under the base directory."""
        if os.path.exists(self.test_base_dir):
            shutil.rmtree(self.test_base_dir)

    def test_product_mode_path_generation(self):
        """Paths generated in 'product' mode match the naming convention."""
        print("\n--- 测试 'product' 模式 ---")
        params = {
            'training_mode': 'product',
            'model_type': 'mlstm',
            'product_id': 'P001',
            'store_id': 'all',
        }

        # First call: the version is expected to be 1.
        paths_v1 = self.path_manager.get_model_paths(**params)

        # Version number.
        self.assertEqual(paths_v1['version'], 1)

        # File-name prefix.
        expected_prefix_v1 = 'product_P001_all_mlstm_v1'
        self.assertEqual(paths_v1['filename_prefix'], expected_prefix_v1)

        # Final artifacts live directly under the base directory.
        self.assertEqual(
            paths_v1['model_path'],
            os.path.join(self.test_base_dir, f'{expected_prefix_v1}_model.pth'),
        )
        self.assertEqual(
            paths_v1['metadata_path'],
            os.path.join(self.test_base_dir, f'{expected_prefix_v1}_metadata.json'),
        )
        self.assertEqual(
            paths_v1['loss_curve_path'],
            os.path.join(self.test_base_dir, f'{expected_prefix_v1}_loss_curve.png'),
        )

        # Checkpoints live under the 'checkpoints' sub-directory.
        checkpoint_dir = os.path.join(self.test_base_dir, 'checkpoints')
        self.assertEqual(paths_v1['checkpoint_dir'], checkpoint_dir)
        self.assertEqual(
            paths_v1['best_checkpoint_path'],
            os.path.join(checkpoint_dir, f'{expected_prefix_v1}_checkpoint_best.pth'),
        )
        # '{{N}}' renders as the literal placeholder '{N}' in the template.
        self.assertEqual(
            paths_v1['epoch_checkpoint_template'],
            os.path.join(checkpoint_dir, f'{expected_prefix_v1}_checkpoint_epoch_{{N}}.pth'),
        )

        print(f"生成的文件名前缀: {paths_v1['filename_prefix']}")
        print(f"生成的模型路径: {paths_v1['model_path']}")
        print("验证通过!")

        # Simulate a successful training run so that the version advances.
        self.path_manager.save_version_info(paths_v1['identifier'], paths_v1['version'])

        # Second call: the version is expected to be 2.
        paths_v2 = self.path_manager.get_model_paths(**params)
        self.assertEqual(paths_v2['version'], 2)
        expected_prefix_v2 = 'product_P001_all_mlstm_v2'
        self.assertEqual(paths_v2['filename_prefix'], expected_prefix_v2)

        print(f"\n版本递增后,生成的文件名前缀: {paths_v2['filename_prefix']}")
        print("版本递增验证通过!")

    def test_store_mode_path_generation_with_hash(self):
        """Paths generated in 'store' mode embed a hash of the product IDs."""
        print("\n--- 测试 'store' 模式 (多药品ID哈希) ---")
        params = {
            'training_mode': 'store',
            'model_type': 'kan',
            'store_id': 'S008',
            'product_scope': 'specific',
            'product_ids': ['P002', 'P005', 'P003'],  # deliberately out of order
        }

        paths = self.path_manager.get_model_paths(**params)

        # The expected hash is computed from the SORTED IDs while the request
        # passes them shuffled, so this only passes if the resulting prefix
        # does not depend on the input order.
        expected_hash = self.path_manager._hash_ids(sorted(['P002', 'P005', 'P003']))
        expected_prefix = f'store_S008_{expected_hash}_kan_v1'

        self.assertEqual(paths['filename_prefix'], expected_prefix)
        self.assertEqual(
            paths['model_path'],
            os.path.join(self.test_base_dir, f'{expected_prefix}_model.pth'),
        )

        print(f"生成的文件名前缀: {paths['filename_prefix']}")
        print("验证通过!")

    def test_global_mode_path_generation(self):
        """Paths generated in 'global' mode match the naming convention."""
        print("\n--- 测试 'global' 模式 ---")
        params = {
            'training_mode': 'global',
            'model_type': 'transformer',
            'training_scope': 'all',
            'aggregation_method': 'mean',
        }

        paths = self.path_manager.get_model_paths(**params)

        expected_prefix = 'global_all_agg_mean_transformer_v1'
        self.assertEqual(paths['filename_prefix'], expected_prefix)
        self.assertEqual(
            paths['model_path'],
            os.path.join(self.test_base_dir, f'{expected_prefix}_model.pth'),
        )

        print(f"生成的文件名前缀: {paths['filename_prefix']}")
        print("验证通过!")
# Run the suite when this file is executed directly rather than imported.
if __name__ == "__main__":
    unittest.main()