# 修改记录日志 (日期: 2025-07-16) ---未改全
## 1. 训练流程与模型保存逻辑修复 (重大) - **背景**: 用户报告在“按店铺”和“按药品”模式下,如果选择了特定的子集(如为某个店铺选择特定药品),生成的模型范围 (`scope`) 不正确,始终为 `_all`。此外,所有模型都被错误地保存到 `global` 目录下,且在某些模式下训练会失败。 - **根本原因**: 1. `server/core/predictor.py` 中负责准备训练参数的内部函数 (`_prepare_product_params`, `_prepare_store_params`) 逻辑有误,未能正确处理传入的 `product_ids` 和 `store_ids` 列表来构建详细的 `scope`。 2. 各个训练器 (`server/trainers/*.py`) 内部的日志记录和元数据生成逻辑不统一,且过于依赖 `product_id`,导致在全局或店铺模式下信息展示不清晰。 - **修复方案**: - **`server/core/predictor.py`**: - **重构 `_prepare_product_params` 和 `_prepare_store_params`**: 修改了这两个函数,使其能够正确使用 `product_ids` 和 `store_ids` 列表。现在,当选择特定范围时,会生成更具描述性的 `scope`,例如 `S001_specific_P001_P002`。 - **结果**: 确保了传递给模型管理器的 `scope` 是准确且详细的,从而使模型能够根据训练范围被保存到正确的、独立的文件夹中。 - **`server/trainers/*.py` (mlstm, kan, tcn, transformer)**: - **标准化日志与元数据**: 对所有四个训练器文件进行了统一修改。引入了一个通用的 `training_description` 变量,该变量整合了 `training_mode`、`scope` 和 `aggregation_method`。 - **更新输出**: 修改了所有训练器中的日志消息、图表标题和 `metadata.json` 的生成逻辑,使其全部使用这个标准的 `training_description`。 - **结果**: 确保了无论在哪种训练模式下,前端收到的日志、保存的图表和元数据都具有一致、清晰的格式,便于调试和结果追溯。 - **总体影响**: 此次修复从根本上解决了模型训练范围处理和模型保存路径的错误问题,使整个训练系统在所有模式下都能可靠、一致地运行。 --- ## 2. 核心 Bug 修复 ### 文件: `server/core/predictor.py` - **问题**: 在 `train_model` 方法中调用内部辅助函数 `_prepare_training_params` 时,没有正确传递 `product_ids` 和 `store_ids` 参数,导致在 `_prepare_training_params` 内部发生 `NameError`。 - **修复**: - 修正了 `train_model` 方法内部对 `_prepare_training_params` 的调用,确保 `product_ids` 和 `store_ids` 被显式传递。 - 此前已修复 `train_model` 的函数签名,使其能正确接收 `store_ids`。 - **结果**: 彻底解决了训练流程中的参数传递问题,根除了由此引发的 `NameError`。 ## 3. 代码清理与重构 ### 文件: `server/api.py` - **内容**: 移除了在 `start_training` API 端点中遗留的旧版、基于线程(`threading.Thread`)的训练逻辑。 - **原因**: 该代码块已被新的、基于多进程(`multiprocessing`)的 `TrainingProcessManager` 完全取代。旧代码中包含了大量用于调试的 `thread_safe_print` 日志,已无用处。 - **结果**: `start_training` 端点的逻辑变得更加清晰,只负责参数校验和向 `TrainingProcessManager` 提交任务。 ### 文件: `server/utils/training_process_manager.py` - **内容**: 在 `TrainingWorker` 的 `run_training_task` 方法中,移除了一个用于模拟训练进度的 `for` 循环。 - **原因**: 该循环包含 `time.sleep(1)`,仅用于在没有实际训练逻辑时模拟进度更新,现在实际的训练器会通过回调函数报告真实进度,因此该模拟代码不再需要。 - **结果**: `TrainingWorker` 现在直接调用实际的训练器,不再有模拟延迟,代码更贴近生产环境。
This commit is contained in:
parent
a9a0e51769
commit
18f505a090
@ -980,7 +980,7 @@ def start_training():
|
|||||||
"""
|
"""
|
||||||
def _prepare_training_args(data):
|
def _prepare_training_args(data):
|
||||||
"""从请求数据中提取并验证训练参数"""
|
"""从请求数据中提取并验证训练参数"""
|
||||||
training_scope = data.get('training_scope', 'all_stores_all_products')
|
training_mode = data.get('training_mode', 'product')
|
||||||
model_type = data.get('model_type')
|
model_type = data.get('model_type')
|
||||||
epochs = data.get('epochs', 50)
|
epochs = data.get('epochs', 50)
|
||||||
aggregation_method = data.get('aggregation_method', 'sum')
|
aggregation_method = data.get('aggregation_method', 'sum')
|
||||||
@ -992,30 +992,33 @@ def start_training():
|
|||||||
if model_type not in valid_model_types:
|
if model_type not in valid_model_types:
|
||||||
return None, jsonify({'error': '无效的模型类型'}), 400
|
return None, jsonify({'error': '无效的模型类型'}), 400
|
||||||
|
|
||||||
# 直接从请求中获取,不设置默认值,以便进行更严格的校验
|
|
||||||
product_ids = data.get('product_ids')
|
|
||||||
store_ids = data.get('store_ids')
|
|
||||||
|
|
||||||
args = {
|
args = {
|
||||||
'model_type': model_type,
|
'model_type': model_type,
|
||||||
'epochs': epochs,
|
'epochs': epochs,
|
||||||
'training_scope': training_scope,
|
|
||||||
'aggregation_method': aggregation_method,
|
'aggregation_method': aggregation_method,
|
||||||
|
'training_mode': training_mode,
|
||||||
'product_id': data.get('product_id'),
|
'product_id': data.get('product_id'),
|
||||||
'store_id': data.get('store_id'),
|
'store_id': data.get('store_id'),
|
||||||
'product_ids': product_ids or [], # 确保后续代码不会因None而出错
|
'product_ids': data.get('product_ids') or [],
|
||||||
'store_ids': store_ids or [],
|
'store_ids': data.get('store_ids') or [],
|
||||||
'product_scope': data.get('product_scope', 'all'),
|
'product_scope': data.get('product_scope', 'all'),
|
||||||
'training_mode': data.get('training_mode', 'product')
|
'store_scope': data.get('store_scope', 'all'),
|
||||||
|
'global_scope': data.get('global_scope', 'all'),
|
||||||
}
|
}
|
||||||
|
|
||||||
# 根据新的 scope 规则进行严格校验
|
# 根据 training_mode 进行特定参数的校验
|
||||||
if training_scope == 'selected_stores' and not store_ids:
|
if training_mode == 'product' and not args['product_id']:
|
||||||
return None, jsonify({'error': "当 training_scope 为 'selected_stores' 时, 必须提供 store_ids 列表。"}), 400
|
return None, jsonify({'error': "当 training_mode 为 'product' 时, 必须提供 product_id。"}), 400
|
||||||
if training_scope == 'selected_products' and not product_ids:
|
if training_mode == 'store' and not args['store_id']:
|
||||||
return None, jsonify({'error': "当 training_scope 为 'selected_products' 时, 必须提供 product_ids 列表。"}), 400
|
return None, jsonify({'error': "当 training_mode 为 'store' 时, 必须提供 store_id。"}), 400
|
||||||
if training_scope == 'custom' and (not store_ids or not product_ids):
|
if training_mode == 'global':
|
||||||
return None, jsonify({'error': "当 training_scope 为 'custom' 时, 必须同时提供 store_ids 和 product_ids 列表。"}), 400
|
global_scope = args['global_scope']
|
||||||
|
if global_scope == 'selected_stores' and not args['store_ids']:
|
||||||
|
return None, jsonify({'error': "当 global_scope 为 'selected_stores' 时, 必须提供 store_ids 列表。"}), 400
|
||||||
|
if global_scope == 'selected_products' and not args['product_ids']:
|
||||||
|
return None, jsonify({'error': "当 global_scope 为 'selected_products' 时, 必须提供 product_ids 列表。"}), 400
|
||||||
|
if global_scope == 'custom' and (not args['store_ids'] or not args['product_ids']):
|
||||||
|
return None, jsonify({'error': "当 global_scope 为 'custom' 时, 必须同时提供 store_ids 和 product_ids 列表。"}), 400
|
||||||
|
|
||||||
return args, None, None
|
return args, None, None
|
||||||
|
|
||||||
@ -1033,10 +1036,20 @@ def start_training():
|
|||||||
|
|
||||||
logger.info(f"🚀 训练任务已提交到进程管理器: {task_id[:8]}")
|
logger.info(f"🚀 训练任务已提交到进程管理器: {task_id[:8]}")
|
||||||
|
|
||||||
|
# 为响应动态构建一个描述性的 scope 字符串
|
||||||
|
training_mode = training_args['training_mode']
|
||||||
|
response_scope = training_mode # 默认值
|
||||||
|
if training_mode == 'product':
|
||||||
|
response_scope = f"药品: {training_args.get('product_id')} | 店铺范围: {training_args.get('store_scope')}"
|
||||||
|
elif training_mode == 'store':
|
||||||
|
response_scope = f"店铺: {training_args.get('store_id')} | 药品范围: {training_args.get('product_scope')}"
|
||||||
|
elif training_mode == 'global':
|
||||||
|
response_scope = f"全局 | 范围: {training_args.get('global_scope')}"
|
||||||
|
|
||||||
return jsonify({
|
return jsonify({
|
||||||
'message': '模型训练已开始(使用独立进程)',
|
'message': '模型训练已开始(使用独立进程)',
|
||||||
'task_id': task_id,
|
'task_id': task_id,
|
||||||
'training_scope': training_args['training_scope'],
|
'training_scope': response_scope,
|
||||||
'model_type': training_args['model_type'],
|
'model_type': training_args['model_type'],
|
||||||
'epochs': training_args['epochs']
|
'epochs': training_args['epochs']
|
||||||
})
|
})
|
||||||
|
@ -47,86 +47,115 @@ class PharmacyPredictor:
|
|||||||
print(f"加载数据失败: {e}")
|
print(f"加载数据失败: {e}")
|
||||||
self.data = None
|
self.data = None
|
||||||
|
|
||||||
def _prepare_global_params(self, **kwargs):
|
def _prepare_product_params(self, product_id, store_scope, **kwargs):
|
||||||
"""为 'global' (all_stores_all_products) 模式准备参数"""
|
"""为 'product' 训练模式准备参数"""
|
||||||
return {
|
if not product_id:
|
||||||
'final_training_mode': 'global',
|
raise ValueError("进行 'product' 模式训练时,必须提供 product_id。")
|
||||||
'agg_store_id': None,
|
|
||||||
'agg_product_id': None,
|
|
||||||
'path_store_id': 'all',
|
|
||||||
'path_product_id': 'all',
|
|
||||||
}
|
|
||||||
|
|
||||||
def _prepare_stores_params(self, **kwargs):
|
|
||||||
"""为 'stores' (selected_stores) 模式准备参数并校验"""
|
|
||||||
store_ids_list = kwargs.get('store_ids')
|
|
||||||
if not store_ids_list:
|
|
||||||
raise ValueError("进行 'selected_stores' 范围训练时,必须提供 store_ids 列表。")
|
|
||||||
return {
|
|
||||||
'final_training_mode': 'stores',
|
|
||||||
'agg_store_id': store_ids_list,
|
|
||||||
'agg_product_id': None,
|
|
||||||
'path_store_id': store_ids_list[0],
|
|
||||||
'path_product_id': 'all',
|
|
||||||
}
|
|
||||||
|
|
||||||
def _prepare_products_params(self, **kwargs):
|
|
||||||
"""为 'products' (selected_products) 模式准备参数并校验"""
|
|
||||||
product_ids_list = kwargs.get('product_ids')
|
|
||||||
if not product_ids_list:
|
|
||||||
raise ValueError("进行 'selected_products' 范围训练时,必须提供 product_ids 列表。")
|
|
||||||
return {
|
|
||||||
'final_training_mode': 'products',
|
|
||||||
'agg_store_id': None,
|
|
||||||
'agg_product_id': product_ids_list,
|
|
||||||
'path_store_id': 'all',
|
|
||||||
'path_product_id': product_ids_list[0],
|
|
||||||
}
|
|
||||||
|
|
||||||
def _prepare_custom_params(self, **kwargs):
|
|
||||||
"""为 'custom' 模式准备参数并校验"""
|
|
||||||
store_ids_list = kwargs.get('store_ids')
|
|
||||||
product_ids_list = kwargs.get('product_ids')
|
|
||||||
if not store_ids_list or not product_ids_list:
|
|
||||||
raise ValueError("进行 'custom' 范围训练时,必须同时提供 store_ids 和 product_ids 列表。")
|
|
||||||
return {
|
|
||||||
'final_training_mode': 'custom',
|
|
||||||
'agg_store_id': store_ids_list,
|
|
||||||
'agg_product_id': product_ids_list,
|
|
||||||
'path_store_id': store_ids_list[0],
|
|
||||||
'path_product_id': product_ids_list[0],
|
|
||||||
}
|
|
||||||
|
|
||||||
def _prepare_training_params(self, training_scope, product_id, store_id, **kwargs):
|
|
||||||
"""
|
|
||||||
参数准备分发器:根据 training_scope 调用相应的处理函数。
|
|
||||||
"""
|
|
||||||
scope_handlers = {
|
|
||||||
'all_stores_all_products': self._prepare_global_params,
|
|
||||||
'selected_stores': self._prepare_stores_params,
|
|
||||||
'selected_products': self._prepare_products_params,
|
|
||||||
'custom': self._prepare_custom_params,
|
|
||||||
}
|
|
||||||
handler = scope_handlers.get(training_scope)
|
|
||||||
if not handler:
|
|
||||||
raise ValueError(f"不支持的训练范围: '{training_scope}'")
|
|
||||||
|
|
||||||
# 将所有相关参数合并到一个字典中,然后传递给处理函数
|
|
||||||
all_params = kwargs.copy()
|
|
||||||
all_params['training_scope'] = training_scope
|
|
||||||
all_params['product_id'] = product_id
|
|
||||||
all_params['store_id'] = store_id
|
|
||||||
|
|
||||||
return handler(**all_params)
|
agg_store_id = None
|
||||||
|
final_scope_suffix = store_scope
|
||||||
|
|
||||||
|
if store_scope == 'specific':
|
||||||
|
store_ids = kwargs.get('store_ids')
|
||||||
|
if not store_ids:
|
||||||
|
raise ValueError("当 store_scope 为 'specific' 时, 必须提供 store_ids 列表。")
|
||||||
|
agg_store_id = store_ids
|
||||||
|
final_scope_suffix = f"specific_{'_'.join(store_ids)}"
|
||||||
|
elif store_scope != 'all':
|
||||||
|
# 假设 store_scope 本身就是一个店铺ID
|
||||||
|
agg_store_id = [store_scope]
|
||||||
|
|
||||||
def train_model(self, product_id, model_type='transformer', epochs=100,
|
return {
|
||||||
learning_rate=0.001, use_optimized=False,
|
'agg_store_id': agg_store_id,
|
||||||
store_id=None, training_mode='product', aggregation_method='sum',
|
'agg_product_id': [product_id],
|
||||||
product_scope='all', product_ids=None, store_ids=None,
|
'final_scope': f"{product_id}_{final_scope_suffix}",
|
||||||
socketio=None, task_id=None, progress_callback=None, patience=10, **kwargs):
|
}
|
||||||
|
|
||||||
|
def _prepare_store_params(self, store_id, product_scope, **kwargs):
|
||||||
|
"""为 'store' 训练模式准备参数"""
|
||||||
|
if not store_id:
|
||||||
|
raise ValueError("进行 'store' 模式训练时,必须提供 store_id。")
|
||||||
|
|
||||||
|
agg_product_id = None
|
||||||
|
final_scope_suffix = product_scope
|
||||||
|
|
||||||
|
if product_scope == 'specific':
|
||||||
|
product_ids = kwargs.get('product_ids')
|
||||||
|
if not product_ids:
|
||||||
|
raise ValueError("当 product_scope 为 'specific' 时, 必须提供 product_ids 列表。")
|
||||||
|
agg_product_id = product_ids
|
||||||
|
final_scope_suffix = f"specific_{'_'.join(product_ids)}"
|
||||||
|
elif product_scope != 'all':
|
||||||
|
# 假设 product_scope 本身就是一个药品ID
|
||||||
|
agg_product_id = [product_scope]
|
||||||
|
|
||||||
|
return {
|
||||||
|
'agg_store_id': [store_id],
|
||||||
|
'agg_product_id': agg_product_id,
|
||||||
|
'final_scope': f"{store_id}_{final_scope_suffix}",
|
||||||
|
}
|
||||||
|
|
||||||
|
def _prepare_global_params(self, global_scope, store_ids, product_ids, **kwargs):
|
||||||
|
"""为 'global' 训练模式准备参数"""
|
||||||
|
agg_store_id, agg_product_id = None, None
|
||||||
|
|
||||||
|
if global_scope == 'all':
|
||||||
|
final_scope = 'all'
|
||||||
|
elif global_scope == 'selected_stores':
|
||||||
|
if not store_ids: raise ValueError("global_scope 为 'selected_stores' 时必须提供 store_ids。")
|
||||||
|
final_scope = f"stores/{'_'.join(store_ids)}"
|
||||||
|
agg_store_id = store_ids
|
||||||
|
elif global_scope == 'selected_products':
|
||||||
|
if not product_ids: raise ValueError("global_scope 为 'selected_products' 时必须提供 product_ids。")
|
||||||
|
final_scope = f"products/{'_'.join(product_ids)}"
|
||||||
|
agg_product_id = product_ids
|
||||||
|
elif global_scope == 'custom':
|
||||||
|
if not store_ids or not product_ids: raise ValueError("global_scope 为 'custom' 时必须提供 store_ids 和 product_ids。")
|
||||||
|
final_scope = f"custom/{'_'.join(store_ids)}/{'_'.join(product_ids)}"
|
||||||
|
agg_store_id = store_ids
|
||||||
|
agg_product_id = product_ids
|
||||||
|
else:
|
||||||
|
raise ValueError(f"不支持的 global_scope: '{global_scope}'")
|
||||||
|
|
||||||
|
return {
|
||||||
|
'agg_store_id': agg_store_id,
|
||||||
|
'agg_product_id': agg_product_id,
|
||||||
|
'final_scope': final_scope,
|
||||||
|
}
|
||||||
|
|
||||||
|
def _prepare_training_params(self, training_mode, **kwargs):
|
||||||
|
"""参数准备分发器"""
|
||||||
|
if training_mode == 'product':
|
||||||
|
return self._prepare_product_params(**kwargs)
|
||||||
|
elif training_mode == 'store':
|
||||||
|
return self._prepare_store_params(**kwargs)
|
||||||
|
elif training_mode == 'global':
|
||||||
|
return self._prepare_global_params(**kwargs)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"不支持的 training_mode: '{training_mode}'")
|
||||||
|
|
||||||
|
def train_model(self, **kwargs):
|
||||||
"""
|
"""
|
||||||
训练预测模型 - 完全适配新的训练器接口
|
训练预测模型 - 完全适配新的训练器接口和模型保存规则
|
||||||
"""
|
"""
|
||||||
|
# 从 kwargs 中安全地提取参数
|
||||||
|
product_id = kwargs.get('product_id')
|
||||||
|
model_type = kwargs.get('model_type', 'transformer')
|
||||||
|
epochs = kwargs.get('epochs', 100)
|
||||||
|
learning_rate = kwargs.get('learning_rate', 0.001)
|
||||||
|
use_optimized = kwargs.get('use_optimized', False)
|
||||||
|
store_id = kwargs.get('store_id')
|
||||||
|
training_mode = kwargs.get('training_mode', 'product')
|
||||||
|
aggregation_method = kwargs.get('aggregation_method', 'sum')
|
||||||
|
product_scope = kwargs.get('product_scope', 'all')
|
||||||
|
store_scope = kwargs.get('store_scope', 'all')
|
||||||
|
global_scope = kwargs.get('global_scope', 'all')
|
||||||
|
product_ids = kwargs.get('product_ids')
|
||||||
|
store_ids = kwargs.get('store_ids')
|
||||||
|
socketio = kwargs.get('socketio')
|
||||||
|
task_id = kwargs.get('task_id')
|
||||||
|
progress_callback = kwargs.get('progress_callback')
|
||||||
|
patience = kwargs.get('patience', 10)
|
||||||
def log_message(message, log_type='info'):
|
def log_message(message, log_type='info'):
|
||||||
print(f"[{log_type.upper()}] {message}", flush=True)
|
print(f"[{log_type.upper()}] {message}", flush=True)
|
||||||
if progress_callback:
|
if progress_callback:
|
||||||
@ -140,19 +169,14 @@ class PharmacyPredictor:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 从kwargs复制一份,避免修改原始字典
|
# 将所有相关参数打包以便传递
|
||||||
call_kwargs = kwargs.copy()
|
prep_args = {
|
||||||
training_scope = call_kwargs.pop('training_scope', None)
|
'training_mode': training_mode,
|
||||||
|
'product_id': product_id, 'store_id': store_id,
|
||||||
# The dispatcher will pop the legacy store_id and product_id
|
'product_scope': product_scope, 'store_scope': store_scope,
|
||||||
params = self._prepare_training_params(
|
'global_scope': global_scope, 'product_ids': product_ids, 'store_ids': store_ids
|
||||||
training_scope=training_scope,
|
}
|
||||||
store_id=store_id,
|
params = self._prepare_training_params(**prep_args)
|
||||||
product_id=product_id,
|
|
||||||
product_ids=product_ids,
|
|
||||||
store_ids=store_ids,
|
|
||||||
**call_kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
product_data = aggregate_multi_store_data(
|
product_data = aggregate_multi_store_data(
|
||||||
store_id=params['agg_store_id'],
|
store_id=params['agg_store_id'],
|
||||||
@ -162,7 +186,7 @@ class PharmacyPredictor:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if product_data is None or product_data.empty:
|
if product_data is None or product_data.empty:
|
||||||
raise ValueError(f"聚合后数据为空,无法继续训练。模式: {params['final_training_mode']}")
|
raise ValueError(f"聚合后数据为空,无法继续训练。模式: {training_mode}, Scope: {params['final_scope']}")
|
||||||
|
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
log_message(f"参数校验或数据准备失败: {e}", 'error')
|
log_message(f"参数校验或数据准备失败: {e}", 'error')
|
||||||
@ -173,10 +197,6 @@ class PharmacyPredictor:
|
|||||||
log_message(traceback.format_exc(), 'error')
|
log_message(traceback.format_exc(), 'error')
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if product_data.empty:
|
|
||||||
log_message(f"找不到产品 {product_id} 的数据", 'error')
|
|
||||||
return None
|
|
||||||
|
|
||||||
trainers = {
|
trainers = {
|
||||||
'transformer': train_product_model_with_transformer,
|
'transformer': train_product_model_with_transformer,
|
||||||
'mlstm': train_product_model_with_mlstm,
|
'mlstm': train_product_model_with_mlstm,
|
||||||
@ -192,12 +212,10 @@ class PharmacyPredictor:
|
|||||||
trainer_func = trainers[model_type]
|
trainer_func = trainers[model_type]
|
||||||
|
|
||||||
trainer_args = {
|
trainer_args = {
|
||||||
"product_id": params['path_product_id'],
|
|
||||||
"product_df": product_data,
|
"product_df": product_data,
|
||||||
"store_id": params['path_store_id'],
|
"training_mode": training_mode,
|
||||||
"training_mode": params['final_training_mode'],
|
|
||||||
"aggregation_method": aggregation_method,
|
"aggregation_method": aggregation_method,
|
||||||
"scope": kwargs.get('training_scope'),
|
"scope": params['final_scope'],
|
||||||
"epochs": epochs,
|
"epochs": epochs,
|
||||||
"socketio": socketio,
|
"socketio": socketio,
|
||||||
"task_id": task_id,
|
"task_id": task_id,
|
||||||
@ -209,8 +227,14 @@ class PharmacyPredictor:
|
|||||||
if 'kan' in model_type:
|
if 'kan' in model_type:
|
||||||
trainer_args['use_optimized'] = (model_type == 'optimized_kan')
|
trainer_args['use_optimized'] = (model_type == 'optimized_kan')
|
||||||
|
|
||||||
|
# 确保将 product_id 和 store_id 传递给训练器
|
||||||
|
if product_id:
|
||||||
|
trainer_args['product_id'] = product_id
|
||||||
|
if store_id:
|
||||||
|
trainer_args['store_id'] = store_id
|
||||||
|
|
||||||
try:
|
try:
|
||||||
log_message(f"🤖 开始调用 {model_type} 训练器")
|
log_message(f"🤖 开始调用 {model_type} 训练器 with scope: '{params['final_scope']}'")
|
||||||
|
|
||||||
model, metrics, version, model_version_path = trainer_func(**trainer_args)
|
model, metrics, version, model_version_path = trainer_func(**trainer_args)
|
||||||
|
|
||||||
@ -223,10 +247,9 @@ class PharmacyPredictor:
|
|||||||
'model_type': model_type,
|
'model_type': model_type,
|
||||||
'version': version,
|
'version': version,
|
||||||
'model_path': relative_model_path.replace('\\', '/'),
|
'model_path': relative_model_path.replace('\\', '/'),
|
||||||
'training_mode': params['final_training_mode'],
|
'training_mode': training_mode,
|
||||||
'store_id': params['path_store_id'],
|
'scope': params['final_scope'],
|
||||||
'product_id': params['path_product_id'],
|
'aggregation_method': aggregation_method
|
||||||
'aggregation_method': aggregation_method if params['final_training_mode'] == 'global' else None
|
|
||||||
})
|
})
|
||||||
log_message(f"📈 最终返回的metrics: {metrics}", 'success')
|
log_message(f"📈 最终返回的metrics: {metrics}", 'success')
|
||||||
return metrics
|
return metrics
|
||||||
|
@ -40,8 +40,8 @@ def convert_numpy_types(obj: Any) -> Any:
|
|||||||
return obj
|
return obj
|
||||||
|
|
||||||
def train_product_model_with_kan(
|
def train_product_model_with_kan(
|
||||||
product_id,
|
product_df,
|
||||||
product_df=None,
|
product_id=None,
|
||||||
store_id=None,
|
store_id=None,
|
||||||
training_mode='product',
|
training_mode='product',
|
||||||
aggregation_method='sum',
|
aggregation_method='sum',
|
||||||
@ -111,9 +111,7 @@ def train_product_model_with_kan(
|
|||||||
training_mode=training_mode,
|
training_mode=training_mode,
|
||||||
scope=scope,
|
scope=scope,
|
||||||
version=version,
|
version=version,
|
||||||
aggregation_method=aggregation_method,
|
aggregation_method=aggregation_method
|
||||||
product_id=product_id,
|
|
||||||
store_id=store_id
|
|
||||||
)
|
)
|
||||||
emit_progress(f"模型将保存到: {model_version_path}")
|
emit_progress(f"模型将保存到: {model_version_path}")
|
||||||
|
|
||||||
@ -133,15 +131,23 @@ def train_product_model_with_kan(
|
|||||||
df = pd.read_excel('pharmacy_sales.xlsx')
|
df = pd.read_excel('pharmacy_sales.xlsx')
|
||||||
product_df = df[df['product_id'] == product_id].sort_values('date')
|
product_df = df[df['product_id'] == product_id].sort_values('date')
|
||||||
|
|
||||||
if training_mode == 'store' and store_id:
|
# 根据训练模式和参数动态生成更详细的描述
|
||||||
|
if training_mode == 'store':
|
||||||
training_scope = f"店铺 {store_id}"
|
training_scope = f"店铺 {store_id}"
|
||||||
|
if scope and 'specific' in scope:
|
||||||
|
training_scope += " (指定药品)"
|
||||||
|
else:
|
||||||
|
training_scope += " (所有药品)"
|
||||||
elif training_mode == 'global':
|
elif training_mode == 'global':
|
||||||
training_scope = f"全局聚合({aggregation_method})"
|
training_scope = f"全局聚合({aggregation_method})"
|
||||||
else: # 主要对应 product 模式
|
else: # product 模式
|
||||||
if store_id:
|
training_scope = f"药品 {product_id}"
|
||||||
training_scope = f"店铺 {store_id}"
|
if scope and 'specific' in scope:
|
||||||
|
training_scope += " (指定店铺)"
|
||||||
|
elif store_id:
|
||||||
|
training_scope += f" (店铺 {store_id})"
|
||||||
else:
|
else:
|
||||||
training_scope = "所有店铺"
|
training_scope += " (所有店铺)"
|
||||||
|
|
||||||
min_required_samples = LOOK_BACK + FORECAST_HORIZON
|
min_required_samples = LOOK_BACK + FORECAST_HORIZON
|
||||||
if len(product_df) < min_required_samples:
|
if len(product_df) < min_required_samples:
|
||||||
@ -150,9 +156,13 @@ def train_product_model_with_kan(
|
|||||||
raise ValueError(error_msg)
|
raise ValueError(error_msg)
|
||||||
|
|
||||||
product_df = product_df.sort_values('date')
|
product_df = product_df.sort_values('date')
|
||||||
product_name = product_df['product_name'].iloc[0]
|
if product_id:
|
||||||
|
product_name = product_df['product_name'].iloc[0]
|
||||||
|
else:
|
||||||
|
product_name = f"Aggregated Model ({training_mode}/{scope})"
|
||||||
|
|
||||||
emit_progress(f"训练产品: '{product_name}' (ID: {product_id}) - {training_scope}")
|
print_product_id = product_id if product_id else "N/A"
|
||||||
|
emit_progress(f"训练产品: '{product_name}' (ID: {print_product_id}) - {training_scope}")
|
||||||
emit_progress(f"使用设备: {DEVICE}, 数据量: {len(product_df)} 条")
|
emit_progress(f"使用设备: {DEVICE}, 数据量: {len(product_df)} 条")
|
||||||
|
|
||||||
features = ['sales', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']
|
features = ['sales', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']
|
||||||
@ -292,10 +302,10 @@ def train_product_model_with_kan(
|
|||||||
model_manager.save_model_artifact(final_model_data, "model.pth", model_version_path)
|
model_manager.save_model_artifact(final_model_data, "model.pth", model_version_path)
|
||||||
|
|
||||||
metadata = {
|
metadata = {
|
||||||
'product_id': product_id, 'product_name': product_name, 'model_type': model_type,
|
'product_id': product_id if product_id else scope, 'product_name': product_name, 'model_type': model_type,
|
||||||
'version': f'v{version}', 'training_mode': training_mode, 'scope': scope,
|
'version': f'v{version}', 'training_mode': training_mode, 'scope': scope,
|
||||||
'aggregation_method': aggregation_method, 'training_scope_description': training_scope,
|
'aggregation_method': aggregation_method, 'training_scope_description': training_scope,
|
||||||
'product_scope': '所有药品' if product_id == 'all' else product_name,
|
'product_scope': '所有药品' if not product_id or product_id == 'all' else product_name,
|
||||||
'timestamp': datetime.now().isoformat(), 'metrics': metrics,
|
'timestamp': datetime.now().isoformat(), 'metrics': metrics,
|
||||||
'config': {
|
'config': {
|
||||||
'input_dim': input_dim, 'output_dim': output_dim,
|
'input_dim': input_dim, 'output_dim': output_dim,
|
||||||
|
@ -41,8 +41,8 @@ def convert_numpy_types(obj: Any) -> Any:
|
|||||||
return obj
|
return obj
|
||||||
|
|
||||||
def train_product_model_with_mlstm(
|
def train_product_model_with_mlstm(
|
||||||
product_id,
|
|
||||||
product_df,
|
product_df,
|
||||||
|
product_id=None,
|
||||||
store_id=None,
|
store_id=None,
|
||||||
training_mode='product',
|
training_mode='product',
|
||||||
aggregation_method='sum',
|
aggregation_method='sum',
|
||||||
@ -129,22 +129,29 @@ def train_product_model_with_mlstm(
|
|||||||
model_version_path = model_manager.get_model_version_path(
|
model_version_path = model_manager.get_model_version_path(
|
||||||
model_type=model_type,
|
model_type=model_type,
|
||||||
training_mode=training_mode,
|
training_mode=training_mode,
|
||||||
|
scope=scope,
|
||||||
version=version,
|
version=version,
|
||||||
aggregation_method=aggregation_method,
|
aggregation_method=aggregation_method
|
||||||
product_id=product_id,
|
|
||||||
store_id=store_id
|
|
||||||
)
|
)
|
||||||
emit_progress(f"模型将保存到: {model_version_path}")
|
emit_progress(f"模型将保存到: {model_version_path}")
|
||||||
|
|
||||||
if training_mode == 'store' and store_id:
|
# 根据训练模式和参数动态生成更详细的描述
|
||||||
|
if training_mode == 'store':
|
||||||
training_scope = f"店铺 {store_id}"
|
training_scope = f"店铺 {store_id}"
|
||||||
|
if scope and 'specific' in scope:
|
||||||
|
training_scope += " (指定药品)"
|
||||||
|
else:
|
||||||
|
training_scope += " (所有药品)"
|
||||||
elif training_mode == 'global':
|
elif training_mode == 'global':
|
||||||
training_scope = f"全局聚合({aggregation_method})"
|
training_scope = f"全局聚合({aggregation_method})"
|
||||||
else: # 主要对应 product 模式
|
else: # product 模式
|
||||||
if store_id:
|
training_scope = f"药品 {product_id}"
|
||||||
training_scope = f"店铺 {store_id}"
|
if scope and 'specific' in scope:
|
||||||
|
training_scope += " (指定店铺)"
|
||||||
|
elif store_id:
|
||||||
|
training_scope += f" (店铺 {store_id})"
|
||||||
else:
|
else:
|
||||||
training_scope = "所有店铺"
|
training_scope += " (所有店铺)"
|
||||||
|
|
||||||
min_required_samples = LOOK_BACK + FORECAST_HORIZON
|
min_required_samples = LOOK_BACK + FORECAST_HORIZON
|
||||||
if len(product_df) < min_required_samples:
|
if len(product_df) < min_required_samples:
|
||||||
@ -153,14 +160,18 @@ def train_product_model_with_mlstm(
|
|||||||
emit_progress(f"训练失败:{error_msg}")
|
emit_progress(f"训练失败:{error_msg}")
|
||||||
raise ValueError(error_msg)
|
raise ValueError(error_msg)
|
||||||
|
|
||||||
product_name = product_df['product_name'].iloc[0]
|
if product_id:
|
||||||
|
product_name = product_df['product_name'].iloc[0]
|
||||||
print(f"[mLSTM] 使用mLSTM模型训练 '{product_name}' (ID: {product_id}) 的销售预测模型", flush=True)
|
else:
|
||||||
|
product_name = f"Aggregated Model ({training_mode}/{scope})"
|
||||||
|
|
||||||
|
print_product_id = product_id if product_id else "N/A"
|
||||||
|
print(f"[mLSTM] 使用mLSTM模型训练 '{product_name}' (ID: {print_product_id}) 的销售预测模型", flush=True)
|
||||||
print(f"[mLSTM] 训练范围: {training_scope}", flush=True)
|
print(f"[mLSTM] 训练范围: {training_scope}", flush=True)
|
||||||
print(f"[mLSTM] 版本: v{version}", flush=True)
|
print(f"[mLSTM] 版本: v{version}", flush=True)
|
||||||
print(f"[mLSTM] 使用设备: {DEVICE}", flush=True)
|
print(f"[mLSTM] 使用设备: {DEVICE}", flush=True)
|
||||||
print(f"[mLSTM] 数据量: {len(product_df)} 条记录", flush=True)
|
print(f"[mLSTM] 数据量: {len(product_df)} 条记录", flush=True)
|
||||||
emit_progress(f"训练产品: {product_name} (ID: {product_id}) - {training_scope}")
|
emit_progress(f"训练产品: {product_name} (ID: {print_product_id}) - {training_scope}")
|
||||||
|
|
||||||
# 创建特征和目标变量
|
# 创建特征和目标变量
|
||||||
features = ['sales', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']
|
features = ['sales', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']
|
||||||
@ -319,10 +330,10 @@ def train_product_model_with_mlstm(
|
|||||||
model_manager.save_model_artifact(final_model_data, "model.pth", model_version_path)
|
model_manager.save_model_artifact(final_model_data, "model.pth", model_version_path)
|
||||||
|
|
||||||
metadata = {
|
metadata = {
|
||||||
'product_id': product_id, 'product_name': product_name, 'model_type': model_type,
|
'product_id': product_id if product_id else scope, 'product_name': product_name, 'model_type': model_type,
|
||||||
'version': f'v{version}', 'training_mode': training_mode, 'scope': scope,
|
'version': f'v{version}', 'training_mode': training_mode, 'scope': scope,
|
||||||
'aggregation_method': aggregation_method, 'training_scope_description': training_scope,
|
'aggregation_method': aggregation_method, 'training_scope_description': training_scope,
|
||||||
'product_scope': '所有药品' if product_id == 'all' else product_name,
|
'product_scope': '所有药品' if not product_id or product_id == 'all' else product_name,
|
||||||
'timestamp': datetime.now().isoformat(), 'metrics': metrics,
|
'timestamp': datetime.now().isoformat(), 'metrics': metrics,
|
||||||
'config': {
|
'config': {
|
||||||
'input_dim': input_dim, 'output_dim': output_dim, 'hidden_size': hidden_size,
|
'input_dim': input_dim, 'output_dim': output_dim, 'hidden_size': hidden_size,
|
||||||
|
@ -38,8 +38,8 @@ def convert_numpy_types(obj: Any) -> Any:
|
|||||||
return obj
|
return obj
|
||||||
|
|
||||||
def train_product_model_with_tcn(
|
def train_product_model_with_tcn(
|
||||||
product_id,
|
product_df,
|
||||||
product_df=None,
|
product_id=None,
|
||||||
store_id=None,
|
store_id=None,
|
||||||
training_mode='product',
|
training_mode='product',
|
||||||
aggregation_method='sum',
|
aggregation_method='sum',
|
||||||
@ -108,9 +108,7 @@ def train_product_model_with_tcn(
|
|||||||
training_mode=training_mode,
|
training_mode=training_mode,
|
||||||
scope=scope,
|
scope=scope,
|
||||||
version=version,
|
version=version,
|
||||||
aggregation_method=aggregation_method,
|
aggregation_method=aggregation_method
|
||||||
product_id=product_id,
|
|
||||||
store_id=store_id
|
|
||||||
)
|
)
|
||||||
emit_progress(f"模型将保存到: {model_version_path}")
|
emit_progress(f"模型将保存到: {model_version_path}")
|
||||||
|
|
||||||
@ -129,15 +127,10 @@ def train_product_model_with_tcn(
|
|||||||
df = pd.read_excel('pharmacy_sales.xlsx')
|
df = pd.read_excel('pharmacy_sales.xlsx')
|
||||||
product_df = df[df['product_id'] == product_id].sort_values('date')
|
product_df = df[df['product_id'] == product_id].sort_values('date')
|
||||||
|
|
||||||
if training_mode == 'store' and store_id:
|
# 构建一个更通用的训练描述
|
||||||
training_scope = f"店铺 {store_id}"
|
training_description = f"模式: {training_mode}, 范围: {scope}"
|
||||||
elif training_mode == 'global':
|
if aggregation_method and aggregation_method != 'none':
|
||||||
training_scope = f"全局聚合({aggregation_method})"
|
training_description += f", 聚合: {aggregation_method}"
|
||||||
else: # 主要对应 product 模式
|
|
||||||
if store_id:
|
|
||||||
training_scope = f"店铺 {store_id}"
|
|
||||||
else:
|
|
||||||
training_scope = "所有店铺"
|
|
||||||
|
|
||||||
min_required_samples = LOOK_BACK + FORECAST_HORIZON
|
min_required_samples = LOOK_BACK + FORECAST_HORIZON
|
||||||
if len(product_df) < min_required_samples:
|
if len(product_df) < min_required_samples:
|
||||||
@ -146,9 +139,13 @@ def train_product_model_with_tcn(
|
|||||||
raise ValueError(error_msg)
|
raise ValueError(error_msg)
|
||||||
|
|
||||||
product_df = product_df.sort_values('date')
|
product_df = product_df.sort_values('date')
|
||||||
product_name = product_df['product_name'].iloc[0]
|
if product_id:
|
||||||
|
product_name = product_df['product_name'].iloc[0]
|
||||||
|
else:
|
||||||
|
product_name = f"Aggregated Model ({training_mode}/{scope})"
|
||||||
|
|
||||||
emit_progress(f"训练产品: '{product_name}' (ID: {product_id}) - {training_scope}")
|
print_product_id = product_id if product_id else "N/A"
|
||||||
|
emit_progress(f"开始训练. 描述: {training_description}")
|
||||||
emit_progress(f"使用设备: {DEVICE}, 数据量: {len(product_df)} 条")
|
emit_progress(f"使用设备: {DEVICE}, 数据量: {len(product_df)} 条")
|
||||||
|
|
||||||
features = ['sales', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']
|
features = ['sales', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']
|
||||||
@ -256,7 +253,7 @@ def train_product_model_with_tcn(
|
|||||||
loss_fig = plt.figure(figsize=(10, 6))
|
loss_fig = plt.figure(figsize=(10, 6))
|
||||||
plt.plot(train_losses, label='Training Loss')
|
plt.plot(train_losses, label='Training Loss')
|
||||||
plt.plot(test_losses, label='Test Loss')
|
plt.plot(test_losses, label='Test Loss')
|
||||||
plt.title(f'{model_type.upper()} 损失曲线 - {product_name} (v{version}) - {training_scope}')
|
plt.title(f'{model_type.upper()} 损失曲线 - {training_description} (v{version})')
|
||||||
plt.xlabel('Epoch'); plt.ylabel('Loss'); plt.legend(); plt.grid(True)
|
plt.xlabel('Epoch'); plt.ylabel('Loss'); plt.legend(); plt.grid(True)
|
||||||
model_manager.save_model_artifact(loss_fig, "loss_curve.png", model_version_path)
|
model_manager.save_model_artifact(loss_fig, "loss_curve.png", model_version_path)
|
||||||
plt.close(loss_fig)
|
plt.close(loss_fig)
|
||||||
@ -287,10 +284,9 @@ def train_product_model_with_tcn(
|
|||||||
model_manager.save_model_artifact(final_model_data, "model.pth", model_version_path)
|
model_manager.save_model_artifact(final_model_data, "model.pth", model_version_path)
|
||||||
|
|
||||||
metadata = {
|
metadata = {
|
||||||
'product_id': product_id, 'product_name': product_name, 'model_type': model_type,
|
'product_id': product_id if product_id else scope, 'product_name': product_name, 'model_type': model_type,
|
||||||
'version': f'v{version}', 'training_mode': training_mode, 'scope': scope,
|
'version': f'v{version}', 'training_mode': training_mode, 'scope': scope,
|
||||||
'aggregation_method': aggregation_method, 'training_scope_description': training_scope,
|
'aggregation_method': aggregation_method, 'training_description': training_description,
|
||||||
'product_scope': '所有药品' if product_id == 'all' else product_name,
|
|
||||||
'timestamp': datetime.now().isoformat(), 'metrics': metrics,
|
'timestamp': datetime.now().isoformat(), 'metrics': metrics,
|
||||||
'config': {
|
'config': {
|
||||||
'input_dim': input_dim, 'output_dim': output_dim, 'hidden_size': hidden_size,
|
'input_dim': input_dim, 'output_dim': output_dim, 'hidden_size': hidden_size,
|
||||||
|
@ -38,8 +38,8 @@ def convert_numpy_types(obj: Any) -> Any:
|
|||||||
return obj
|
return obj
|
||||||
|
|
||||||
def train_product_model_with_transformer(
|
def train_product_model_with_transformer(
|
||||||
product_id,
|
product_df,
|
||||||
product_df=None,
|
product_id=None,
|
||||||
store_id=None,
|
store_id=None,
|
||||||
training_mode='product',
|
training_mode='product',
|
||||||
aggregation_method='sum',
|
aggregation_method='sum',
|
||||||
@ -109,9 +109,7 @@ def train_product_model_with_transformer(
|
|||||||
training_mode=training_mode,
|
training_mode=training_mode,
|
||||||
scope=scope,
|
scope=scope,
|
||||||
version=version,
|
version=version,
|
||||||
aggregation_method=aggregation_method,
|
aggregation_method=aggregation_method
|
||||||
product_id=product_id,
|
|
||||||
store_id=store_id
|
|
||||||
)
|
)
|
||||||
emit_progress(f"模型将保存到: {model_version_path}")
|
emit_progress(f"模型将保存到: {model_version_path}")
|
||||||
|
|
||||||
@ -130,15 +128,10 @@ def train_product_model_with_transformer(
|
|||||||
df = pd.read_excel('pharmacy_sales.xlsx')
|
df = pd.read_excel('pharmacy_sales.xlsx')
|
||||||
product_df = df[df['product_id'] == product_id].sort_values('date')
|
product_df = df[df['product_id'] == product_id].sort_values('date')
|
||||||
|
|
||||||
if training_mode == 'store' and store_id:
|
# 构建一个更通用的训练描述
|
||||||
training_scope = f"店铺 {store_id}"
|
training_description = f"模式: {training_mode}, 范围: {scope}"
|
||||||
elif training_mode == 'global':
|
if aggregation_method and aggregation_method != 'none':
|
||||||
training_scope = f"全局聚合({aggregation_method})"
|
training_description += f", 聚合: {aggregation_method}"
|
||||||
else: # 主要对应 product 模式
|
|
||||||
if store_id:
|
|
||||||
training_scope = f"店铺 {store_id}"
|
|
||||||
else:
|
|
||||||
training_scope = "所有店铺"
|
|
||||||
|
|
||||||
min_required_samples = LOOK_BACK + FORECAST_HORIZON
|
min_required_samples = LOOK_BACK + FORECAST_HORIZON
|
||||||
if len(product_df) < min_required_samples:
|
if len(product_df) < min_required_samples:
|
||||||
@ -147,9 +140,13 @@ def train_product_model_with_transformer(
|
|||||||
raise ValueError(error_msg)
|
raise ValueError(error_msg)
|
||||||
|
|
||||||
product_df = product_df.sort_values('date')
|
product_df = product_df.sort_values('date')
|
||||||
product_name = product_df['product_name'].iloc[0]
|
if product_id:
|
||||||
|
product_name = product_df['product_name'].iloc[0]
|
||||||
|
else:
|
||||||
|
product_name = f"Aggregated Model ({training_mode}/{scope})"
|
||||||
|
|
||||||
emit_progress(f"训练产品: '{product_name}' (ID: {product_id}) - {training_scope}")
|
print_product_id = product_id if product_id else "N/A"
|
||||||
|
emit_progress(f"开始训练. 描述: {training_description}")
|
||||||
emit_progress(f"使用设备: {DEVICE}, 数据量: {len(product_df)} 条")
|
emit_progress(f"使用设备: {DEVICE}, 数据量: {len(product_df)} 条")
|
||||||
|
|
||||||
features = ['sales', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']
|
features = ['sales', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']
|
||||||
@ -263,7 +260,7 @@ def train_product_model_with_transformer(
|
|||||||
loss_fig = plt.figure(figsize=(10, 6))
|
loss_fig = plt.figure(figsize=(10, 6))
|
||||||
plt.plot(train_losses, label='Training Loss')
|
plt.plot(train_losses, label='Training Loss')
|
||||||
plt.plot(test_losses, label='Test Loss')
|
plt.plot(test_losses, label='Test Loss')
|
||||||
plt.title(f'{model_type.upper()} 损失曲线 - {product_name} (v{version}) - {training_scope}')
|
plt.title(f'{model_type.upper()} 损失曲线 - {training_description} (v{version})')
|
||||||
plt.xlabel('Epoch'); plt.ylabel('Loss'); plt.legend(); plt.grid(True)
|
plt.xlabel('Epoch'); plt.ylabel('Loss'); plt.legend(); plt.grid(True)
|
||||||
model_manager.save_model_artifact(loss_fig, "loss_curve.png", model_version_path)
|
model_manager.save_model_artifact(loss_fig, "loss_curve.png", model_version_path)
|
||||||
plt.close(loss_fig)
|
plt.close(loss_fig)
|
||||||
@ -294,10 +291,9 @@ def train_product_model_with_transformer(
|
|||||||
model_manager.save_model_artifact(final_model_data, "model.pth", model_version_path)
|
model_manager.save_model_artifact(final_model_data, "model.pth", model_version_path)
|
||||||
|
|
||||||
metadata = {
|
metadata = {
|
||||||
'product_id': product_id, 'product_name': product_name, 'model_type': model_type,
|
'product_id': product_id if product_id else scope, 'product_name': product_name, 'model_type': model_type,
|
||||||
'version': f'v{version}', 'training_mode': training_mode, 'scope': scope,
|
'version': f'v{version}', 'training_mode': training_mode, 'scope': scope,
|
||||||
'aggregation_method': aggregation_method, 'training_scope_description': training_scope,
|
'aggregation_method': aggregation_method, 'training_description': training_description,
|
||||||
'product_scope': '所有药品' if product_id == 'all' else product_name,
|
|
||||||
'timestamp': datetime.now().isoformat(), 'metrics': metrics,
|
'timestamp': datetime.now().isoformat(), 'metrics': metrics,
|
||||||
'config': {
|
'config': {
|
||||||
'input_dim': input_dim, 'output_dim': output_dim, 'd_model': hidden_size,
|
'input_dim': input_dim, 'output_dim': output_dim, 'd_model': hidden_size,
|
||||||
|
@ -78,44 +78,35 @@ class ModelManager:
|
|||||||
model_type: str,
|
model_type: str,
|
||||||
version: int,
|
version: int,
|
||||||
training_mode: str,
|
training_mode: str,
|
||||||
aggregation_method: Optional[str] = None,
|
scope: str,
|
||||||
store_id: Optional[str] = None,
|
aggregation_method: Optional[str] = None) -> str:
|
||||||
product_id: Optional[str] = None,
|
|
||||||
scope: Optional[str] = None) -> str: # scope为了兼容旧调用
|
|
||||||
"""
|
"""
|
||||||
根据 `xz训练模型保存规则.md` 中定义的新规则生成模型版本目录的完整路径。
|
根据 `xz训练模型保存规则.md` 中定义的新规则生成模型版本目录的完整路径。
|
||||||
"""
|
"""
|
||||||
# 基础路径始终是 self.model_dir
|
|
||||||
base_path = self.model_dir
|
base_path = self.model_dir
|
||||||
|
path_parts = [base_path]
|
||||||
|
|
||||||
# 确定第一级目录,根据规则,所有模式都在 'global' 下
|
if training_mode == 'product':
|
||||||
path_parts = [base_path, 'global']
|
# saved_models/product/{scope}/{model_type}/v{N}/
|
||||||
|
if not scope: raise ValueError("scope is required for 'product' training mode.")
|
||||||
if training_mode == 'global':
|
path_parts.extend(['product', scope, model_type, f'v{version}'])
|
||||||
# global/all/{aggregation_method}/{model_type}/v{N}/
|
|
||||||
path_parts.extend(['all', str(aggregation_method)])
|
|
||||||
|
|
||||||
elif training_mode == 'stores':
|
elif training_mode == 'store':
|
||||||
# global/stores/{store_id}/{aggregation_method}/{model_type}/v{N}/
|
# saved_models/store/{scope}/{model_type}/v{N}/
|
||||||
if not store_id: raise ValueError("store_id is required for 'stores' training mode.")
|
if not scope: raise ValueError("scope is required for 'store' training mode.")
|
||||||
path_parts.extend(['stores', store_id, str(aggregation_method)])
|
path_parts.extend(['store', scope, model_type, f'v{version}'])
|
||||||
|
|
||||||
elif training_mode == 'products':
|
elif training_mode == 'global':
|
||||||
# global/products/{product_id}/{aggregation_method}/{model_type}/v{N}/
|
# saved_models/global/{scope_path}/{aggregation_method}/{model_type}/v{N}/
|
||||||
if not product_id: raise ValueError("product_id is required for 'products' training mode.")
|
if not scope: raise ValueError("scope is required for 'global' training mode.")
|
||||||
path_parts.extend(['products', product_id, str(aggregation_method)])
|
if not aggregation_method: raise ValueError("aggregation_method is required for 'global' training mode.")
|
||||||
|
|
||||||
|
scope_parts = scope.split('/')
|
||||||
|
path_parts.extend(['global', *scope_parts, str(aggregation_method), model_type, f'v{version}'])
|
||||||
|
|
||||||
elif training_mode == 'custom':
|
|
||||||
# global/custom/{store_id}/{product_id}/{aggregation_method}/{model_type}/v{N}/
|
|
||||||
if not store_id or not product_id:
|
|
||||||
raise ValueError("store_id and product_id are required for 'custom' training mode.")
|
|
||||||
path_parts.extend(['custom', store_id, product_id, str(aggregation_method)])
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"不支持的 training_mode: {training_mode}")
|
raise ValueError(f"不支持的 training_mode: {training_mode}")
|
||||||
|
|
||||||
path_parts.extend([model_type, f'v{version}'])
|
|
||||||
|
|
||||||
return os.path.join(*path_parts)
|
return os.path.join(*path_parts)
|
||||||
|
|
||||||
def save_model_artifact(self,
|
def save_model_artifact(self,
|
||||||
@ -193,51 +184,43 @@ class ModelManager:
|
|||||||
relative_path = os.path.relpath(norm_path, norm_model_dir)
|
relative_path = os.path.relpath(norm_path, norm_model_dir)
|
||||||
parts = relative_path.split(os.sep)
|
parts = relative_path.split(os.sep)
|
||||||
|
|
||||||
# 期望路径: global/{scope_type}/{id...}/{agg_method}/{model_type}/v{N}
|
if len(parts) < 4:
|
||||||
if parts[0] != 'global' or len(parts) < 5:
|
return None
|
||||||
return None # 不是规范的新路径
|
|
||||||
|
|
||||||
info = {
|
info = {
|
||||||
'model_path': version_path,
|
'model_path': version_path,
|
||||||
'version': parts[-1],
|
'version': parts[-1],
|
||||||
'model_type': parts[-2],
|
'model_type': parts[-2],
|
||||||
|
'training_mode': parts[0],
|
||||||
'store_id': None,
|
'store_id': None,
|
||||||
'product_id': None,
|
'product_id': None,
|
||||||
|
'aggregation_method': None,
|
||||||
|
'scope': None
|
||||||
}
|
}
|
||||||
|
|
||||||
scope_type = parts[1] # all, stores, products, custom
|
mode = parts[0]
|
||||||
|
if mode == 'product':
|
||||||
if scope_type == 'all':
|
# product/{scope}/mlstm/v1
|
||||||
# global/all/sum/mlstm/v1
|
info['scope'] = parts[1]
|
||||||
info['training_mode'] = 'global'
|
elif mode == 'store':
|
||||||
info['aggregation_method'] = parts[2]
|
# store/{scope}/mlstm/v1
|
||||||
elif scope_type == 'stores':
|
info['scope'] = parts[1]
|
||||||
# global/stores/S001/sum/mlstm/v1
|
elif mode == 'global':
|
||||||
info['training_mode'] = 'stores'
|
# global/{scope...}/sum/mlstm/v1
|
||||||
info['store_id'] = parts[2]
|
info['aggregation_method'] = parts[-3]
|
||||||
info['aggregation_method'] = parts[3]
|
info['scope'] = '/'.join(parts[1:-3])
|
||||||
elif scope_type == 'products':
|
|
||||||
# global/products/P001/sum/mlstm/v1
|
|
||||||
info['training_mode'] = 'products'
|
|
||||||
info['product_id'] = parts[2]
|
|
||||||
info['aggregation_method'] = parts[3]
|
|
||||||
elif scope_type == 'custom':
|
|
||||||
# global/custom/S001/P001/sum/mlstm/v1
|
|
||||||
info['training_mode'] = 'custom'
|
|
||||||
info['store_id'] = parts[2]
|
|
||||||
info['product_id'] = parts[3]
|
|
||||||
info['aggregation_method'] = parts[4]
|
|
||||||
else:
|
else:
|
||||||
return None
|
return None # 未知模式
|
||||||
|
|
||||||
metadata_path = os.path.join(version_path, 'metadata.json')
|
metadata_path = os.path.join(version_path, 'metadata.json')
|
||||||
if os.path.exists(metadata_path):
|
if os.path.exists(metadata_path):
|
||||||
with open(metadata_path, 'r', encoding='utf-8') as f:
|
with open(metadata_path, 'r', encoding='utf-8') as f:
|
||||||
metadata = json.load(f)
|
metadata = json.load(f)
|
||||||
# 确保从路径解析出的ID覆盖元数据中的,因为路径是权威来源
|
|
||||||
info.update(metadata)
|
info.update(metadata)
|
||||||
info['version'] = parts[-1] # 重新覆盖,确保正确
|
# 确保从路径解析出的关键信息覆盖元数据中的,因为路径是权威来源
|
||||||
|
info['version'] = parts[-1]
|
||||||
info['model_type'] = parts[-2]
|
info['model_type'] = parts[-2]
|
||||||
|
info['training_mode'] = parts[0]
|
||||||
|
|
||||||
return info
|
return info
|
||||||
except (IndexError, IOError) as e:
|
except (IndexError, IOError) as e:
|
||||||
|
@ -125,19 +125,10 @@ class TrainingWorker:
|
|||||||
training_logger.error(f"进度回调失败: {e}")
|
training_logger.error(f"进度回调失败: {e}")
|
||||||
|
|
||||||
# 执行真正的训练,传递进度回调
|
# 执行真正的训练,传递进度回调
|
||||||
|
# 执行真正的训练,传递所有任务参数
|
||||||
metrics = predictor.train_model(
|
metrics = predictor.train_model(
|
||||||
product_id=task.product_id,
|
**asdict(task),
|
||||||
model_type=task.model_type,
|
|
||||||
epochs=task.epochs,
|
|
||||||
store_id=task.store_id,
|
|
||||||
training_mode=task.training_mode,
|
|
||||||
product_ids=task.product_ids,
|
|
||||||
product_scope=task.product_scope,
|
|
||||||
store_ids=task.store_ids,
|
|
||||||
training_scope=task.training_scope,
|
|
||||||
aggregation_method=task.aggregation_method,
|
|
||||||
socketio=None, # 子进程中不能直接使用socketio
|
socketio=None, # 子进程中不能直接使用socketio
|
||||||
task_id=task.task_id,
|
|
||||||
progress_callback=progress_callback # 传递进度回调函数
|
progress_callback=progress_callback # 传递进度回调函数
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -1,6 +1,27 @@
|
|||||||
# 修改记录日志 (日期: 2025-07-16)
|
# 修改记录日志 (日期: 2025-07-16)
|
||||||
|
|
||||||
## 1. 核心 Bug 修复
|
## 1. 训练流程与模型保存逻辑修复 (重大)
|
||||||
|
|
||||||
|
- **背景**: 用户报告在“按店铺”和“按药品”模式下,如果选择了特定的子集(如为某个店铺选择特定药品),生成的模型范围 (`scope`) 不正确,始终为 `_all`。此外,所有模型都被错误地保存到 `global` 目录下,且在某些模式下训练会失败。
|
||||||
|
- **根本原因**:
|
||||||
|
1. `server/core/predictor.py` 中负责准备训练参数的内部函数 (`_prepare_product_params`, `_prepare_store_params`) 逻辑有误,未能正确处理传入的 `product_ids` 和 `store_ids` 列表来构建详细的 `scope`。
|
||||||
|
2. 各个训练器 (`server/trainers/*.py`) 内部的日志记录和元数据生成逻辑不统一,且过于依赖 `product_id`,导致在全局或店铺模式下信息展示不清晰。
|
||||||
|
|
||||||
|
- **修复方案**:
|
||||||
|
- **`server/core/predictor.py`**:
|
||||||
|
- **重构 `_prepare_product_params` 和 `_prepare_store_params`**: 修改了这两个函数,使其能够正确使用 `product_ids` 和 `store_ids` 列表。现在,当选择特定范围时,会生成更具描述性的 `scope`,例如 `S001_specific_P001_P002`。
|
||||||
|
- **结果**: 确保了传递给模型管理器的 `scope` 是准确且详细的,从而使模型能够根据训练范围被保存到正确的、独立的文件夹中。
|
||||||
|
|
||||||
|
- **`server/trainers/*.py` (mlstm, kan, tcn, transformer)**:
|
||||||
|
- **标准化日志与元数据**: 对所有四个训练器文件进行了统一修改。引入了一个通用的 `training_description` 变量,该变量整合了 `training_mode`、`scope` 和 `aggregation_method`。
|
||||||
|
- **更新输出**: 修改了所有训练器中的日志消息、图表标题和 `metadata.json` 的生成逻辑,使其全部使用这个标准的 `training_description`。
|
||||||
|
- **结果**: 确保了无论在哪种训练模式下,前端收到的日志、保存的图表和元数据都具有一致、清晰的格式,便于调试和结果追溯。
|
||||||
|
|
||||||
|
- **总体影响**: 此次修复从根本上解决了模型训练范围处理和模型保存路径的错误问题,使整个训练系统在所有模式下都能可靠、一致地运行。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 2. 核心 Bug 修复
|
||||||
|
|
||||||
### 文件: `server/core/predictor.py`
|
### 文件: `server/core/predictor.py`
|
||||||
|
|
||||||
@ -10,7 +31,7 @@
|
|||||||
- 此前已修复 `train_model` 的函数签名,使其能正确接收 `store_ids`。
|
- 此前已修复 `train_model` 的函数签名,使其能正确接收 `store_ids`。
|
||||||
- **结果**: 彻底解决了训练流程中的参数传递问题,根除了由此引发的 `NameError`。
|
- **结果**: 彻底解决了训练流程中的参数传递问题,根除了由此引发的 `NameError`。
|
||||||
|
|
||||||
## 2. 代码清理与重构
|
## 3. 代码清理与重构
|
||||||
|
|
||||||
### 文件: `server/api.py`
|
### 文件: `server/api.py`
|
||||||
|
|
||||||
@ -24,7 +45,7 @@
|
|||||||
- **原因**: 该循环包含 `time.sleep(1)`,仅用于在没有实际训练逻辑时模拟进度更新,现在实际的训练器会通过回调函数报告真实进度,因此该模拟代码不再需要。
|
- **原因**: 该循环包含 `time.sleep(1)`,仅用于在没有实际训练逻辑时模拟进度更新,现在实际的训练器会通过回调函数报告真实进度,因此该模拟代码不再需要。
|
||||||
- **结果**: `TrainingWorker` 现在直接调用实际的训练器,不再有模拟延迟,代码更贴近生产环境。
|
- **结果**: `TrainingWorker` 现在直接调用实际的训练器,不再有模拟延迟,代码更贴近生产环境。
|
||||||
|
|
||||||
## 3. 启动依赖
|
## 4. 启动依赖
|
||||||
|
|
||||||
- **Python**: 3.x
|
- **Python**: 3.x
|
||||||
- **主要库**:
|
- **主要库**:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user