2025-06-18 06:39:41 +08:00
|
|
|
|
"""
|
2025-07-15 20:09:05 +08:00
|
|
|
|
药店销售预测系统 - 核心预测器类 (已重构)
|
|
|
|
|
支持多店铺销售预测功能,并完全集成新的ModelManager
|
2025-06-18 06:39:41 +08:00
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
import os
|
|
|
|
|
import pandas as pd
|
|
|
|
|
import time
|
|
|
|
|
from datetime import datetime
|
|
|
|
|
|
|
|
|
|
from trainers import (
|
|
|
|
|
train_product_model_with_mlstm,
|
|
|
|
|
train_product_model_with_kan,
|
|
|
|
|
train_product_model_with_tcn,
|
|
|
|
|
train_product_model_with_transformer
|
|
|
|
|
)
|
|
|
|
|
from predictors.model_predictor import load_model_and_predict
|
2025-07-02 11:05:23 +08:00
|
|
|
|
from utils.multi_store_data_utils import (
|
2025-07-15 20:09:05 +08:00
|
|
|
|
load_multi_store_data,
|
2025-07-02 11:05:23 +08:00
|
|
|
|
get_store_product_sales_data,
|
|
|
|
|
aggregate_multi_store_data
|
|
|
|
|
)
|
2025-06-18 06:39:41 +08:00
|
|
|
|
from core.config import DEVICE, DEFAULT_MODEL_DIR, DEFAULT_DATA_PATH
|
2025-07-15 20:09:05 +08:00
|
|
|
|
from utils.model_manager import model_manager
|
2025-06-18 06:39:41 +08:00
|
|
|
|
|
|
|
|
|
class PharmacyPredictor:
    """Core class of the pharmacy sales forecasting system.

    Responsibilities:
      * load the multi-store sales data (best-effort, at construction time);
      * translate the three training modes ('product', 'store', 'global') and
        their scope options into aggregation parameters and a scope name;
      * dispatch training to the model-specific trainer functions;
      * run predictions from a saved model version directory;
      * list / delete model versions through the shared ``model_manager``.
    """

    # Model types handled by the KAN trainer; 'optimized_kan' always runs it
    # with use_optimized=True.
    _KAN_MODEL_TYPES = ('kan', 'optimized_kan')

    # Model types trained by ``compare_models`` when no explicit list is given.
    _COMPARE_MODEL_TYPES = ('tcn', 'mlstm', 'transformer', 'kan', 'optimized_kan')

    def __init__(self, data_path=None, model_dir=None):
        """Initialise the predictor.

        Args:
            data_path: Path of the multi-store sales data; falls back to
                ``DEFAULT_DATA_PATH`` when falsy.
            model_dir: Directory for model artefacts, created if missing;
                falls back to ``DEFAULT_MODEL_DIR`` when falsy (same rule as
                ``data_path``).

        Loading the data is best-effort: on failure the error is printed and
        ``self.data`` is set to ``None``, which makes ``train_model`` refuse
        to run.
        """
        self.data_path = data_path if data_path else DEFAULT_DATA_PATH
        self.model_dir = model_dir if model_dir else DEFAULT_MODEL_DIR
        self.device = DEVICE

        # exist_ok avoids the race between an exists() check and makedirs().
        os.makedirs(self.model_dir, exist_ok=True)

        print(f"使用设备: {self.device}")

        try:
            self.data = load_multi_store_data(self.data_path)
            print(f"已加载多店铺数据,来源: {self.data_path}")
        except Exception as e:
            print(f"加载数据失败: {e}")
            self.data = None

    # ------------------------------------------------------------------
    # Small helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _as_id_list(ids):
        """Return *ids* as a list; a bare string is one ID, not a sequence of characters."""
        if isinstance(ids, str):
            return [ids]
        return list(ids)

    @staticmethod
    def _join_ids(ids):
        """Join IDs with '_' for a scope name; non-string IDs (e.g. ints) are str()-ed."""
        return '_'.join(str(i) for i in ids)

    @staticmethod
    def _display_path(path):
        """Return *path* relative to the CWD with '/' separators, or None if *path* is None."""
        if path is None:
            return None
        try:
            rel = os.path.relpath(path, os.getcwd())
        except ValueError:
            # On Windows relpath() raises ValueError when the path and the CWD
            # are on different drives; report the absolute path instead.
            rel = os.path.abspath(path)
        return rel.replace('\\', '/')

    # ------------------------------------------------------------------
    # Training-parameter preparation
    # ------------------------------------------------------------------
    def _prepare_product_params(self, product_id, store_scope='all', **kwargs):
        """Prepare aggregation parameters for the 'product' training mode.

        Args:
            product_id: The product to train on (required).
            store_scope: 'all' (default; ``None`` means the same), 'specific'
                (use ``kwargs['store_ids']``), or any other value, which is
                taken to be a single store ID.
            **kwargs: May carry ``store_ids`` for the 'specific' scope; other
                keys are ignored.

        Returns:
            dict with ``agg_store_id`` (list of store IDs, or ``None`` for all
            stores), ``agg_product_id`` (``[product_id]``) and ``final_scope``.

        Raises:
            ValueError: if ``product_id`` is missing, or ``store_ids`` is
                missing for the 'specific' scope.
        """
        if not product_id:
            raise ValueError("进行 'product' 模式训练时,必须提供 product_id。")

        if store_scope is None:
            store_scope = 'all'

        agg_store_id = None
        final_scope_suffix = store_scope

        if store_scope == 'specific':
            store_ids = kwargs.get('store_ids')
            if not store_ids:
                raise ValueError("当 store_scope 为 'specific' 时, 必须提供 store_ids 列表。")
            agg_store_id = self._as_id_list(store_ids)
            final_scope_suffix = f"specific_{self._join_ids(agg_store_id)}"
        elif store_scope != 'all':
            # Any other value is taken to be a single store ID.
            agg_store_id = [store_scope]

        return {
            'agg_store_id': agg_store_id,
            'agg_product_id': [product_id],
            'final_scope': f"{product_id}_{final_scope_suffix}",
        }

    def _prepare_store_params(self, store_id, product_scope='all', **kwargs):
        """Prepare aggregation parameters for the 'store' training mode.

        Args:
            store_id: The store to train on (required).
            product_scope: 'all' (default; ``None`` means the same),
                'specific' (use ``kwargs['product_ids']``), or any other
                value, which is taken to be a single product ID.
            **kwargs: May carry ``product_ids`` for the 'specific' scope;
                other keys are ignored.

        Returns:
            dict with ``agg_store_id`` (``[store_id]``), ``agg_product_id``
            (list of product IDs, or ``None`` for all products) and
            ``final_scope``.

        Raises:
            ValueError: if ``store_id`` is missing, or ``product_ids`` is
                missing for the 'specific' scope.
        """
        if not store_id:
            raise ValueError("进行 'store' 模式训练时,必须提供 store_id。")

        if product_scope is None:
            product_scope = 'all'

        agg_product_id = None
        final_scope_suffix = product_scope

        if product_scope == 'specific':
            product_ids = kwargs.get('product_ids')
            if not product_ids:
                raise ValueError("当 product_scope 为 'specific' 时, 必须提供 product_ids 列表。")
            agg_product_id = self._as_id_list(product_ids)
            final_scope_suffix = f"specific_{self._join_ids(agg_product_id)}"
        elif product_scope != 'all':
            # Any other value is taken to be a single product ID.
            agg_product_id = [product_scope]

        return {
            'agg_store_id': [store_id],
            'agg_product_id': agg_product_id,
            'final_scope': f"{store_id}_{final_scope_suffix}",
        }

    def _prepare_global_params(self, global_scope='all', store_ids=None, product_ids=None, **kwargs):
        """Prepare aggregation parameters for the 'global' training mode.

        Args:
            global_scope: One of 'all' (default; ``None`` means the same),
                'selected_stores', 'selected_products' or 'custom'.
            store_ids: Store IDs, required for 'selected_stores' and 'custom'.
            product_ids: Product IDs, required for 'selected_products' and
                'custom'.
            **kwargs: Ignored (lets the dispatcher forward every option).

        Returns:
            dict with ``agg_store_id`` / ``agg_product_id`` (lists, or
            ``None`` meaning "no filter") and ``final_scope``, e.g. 'all',
            'stores/S1_S2', 'products/P1', 'custom/S1/P1_P2'.

        Raises:
            ValueError: for an unknown scope or missing required IDs.
        """
        if global_scope is None:
            global_scope = 'all'

        agg_store_id, agg_product_id = None, None

        if global_scope == 'all':
            final_scope = 'all'
        elif global_scope == 'selected_stores':
            if not store_ids:
                raise ValueError("global_scope 为 'selected_stores' 时必须提供 store_ids。")
            agg_store_id = self._as_id_list(store_ids)
            final_scope = f"stores/{self._join_ids(agg_store_id)}"
        elif global_scope == 'selected_products':
            if not product_ids:
                raise ValueError("global_scope 为 'selected_products' 时必须提供 product_ids。")
            agg_product_id = self._as_id_list(product_ids)
            final_scope = f"products/{self._join_ids(agg_product_id)}"
        elif global_scope == 'custom':
            if not store_ids or not product_ids:
                raise ValueError("global_scope 为 'custom' 时必须提供 store_ids 和 product_ids。")
            agg_store_id = self._as_id_list(store_ids)
            agg_product_id = self._as_id_list(product_ids)
            final_scope = f"custom/{self._join_ids(agg_store_id)}/{self._join_ids(agg_product_id)}"
        else:
            raise ValueError(f"不支持的 global_scope: '{global_scope}'")

        return {
            'agg_store_id': agg_store_id,
            'agg_product_id': agg_product_id,
            'final_scope': final_scope,
        }

    def _prepare_training_params(self, training_mode, **kwargs):
        """Dispatch to the parameter-preparation helper for *training_mode*.

        All remaining keyword arguments are forwarded; each helper picks the
        ones it needs and ignores the rest.

        Raises:
            ValueError: for an unsupported training mode, or whatever the
                selected helper raises.
        """
        if training_mode == 'product':
            return self._prepare_product_params(**kwargs)
        elif training_mode == 'store':
            return self._prepare_store_params(**kwargs)
        elif training_mode == 'global':
            return self._prepare_global_params(**kwargs)
        else:
            raise ValueError(f"不支持的 training_mode: '{training_mode}'")

    # ------------------------------------------------------------------
    # Training
    # ------------------------------------------------------------------
    def train_model(self, **kwargs):
        """Train a forecasting model and return its evaluation metrics.

        Recognised keyword arguments (an explicit ``None`` is treated the same
        as an absent key, i.e. the default is used):
            product_id, store_id: target product / store (mode dependent).
            model_type: 'transformer' (default), 'mlstm', 'tcn', 'kan' or
                'optimized_kan'.
            epochs (100), learning_rate (0.001), patience (10).
            use_optimized (False): for 'kan', select the optimized variant;
                'optimized_kan' always uses it.
            training_mode: 'product' (default), 'store' or 'global'.
            aggregation_method ('sum'): used for data aggregation and passed
                to the trainer.
            product_scope, store_scope, global_scope ('all'), product_ids,
            store_ids: scope selection, see the ``_prepare_*_params`` helpers.
            socketio, task_id, progress_callback: progress-reporting hooks
                forwarded to the trainer; ``progress_callback`` also receives
                this method's log messages as ``{'log_type', 'message'}``.

        Returns:
            The trainer's metrics dict, augmented with ``model_type``,
            ``version``, ``model_path``, ``training_mode``, ``scope`` and
            ``aggregation_method``; or ``None`` if no data is loaded,
            parameter validation / data preparation fails, the model type is
            unsupported, the trainer raises, or it returns empty metrics.
            Errors are logged, not raised.
        """
        import traceback

        def arg(name, default=None):
            # An explicit None (e.g. an optional request field forwarded as-is)
            # means "use the default", exactly like a missing key.
            value = kwargs.get(name)
            return default if value is None else value

        product_id = arg('product_id')
        model_type = arg('model_type', 'transformer')
        epochs = arg('epochs', 100)
        learning_rate = arg('learning_rate', 0.001)
        use_optimized = arg('use_optimized', False)
        store_id = arg('store_id')
        training_mode = arg('training_mode', 'product')
        aggregation_method = arg('aggregation_method', 'sum')
        product_scope = arg('product_scope', 'all')
        store_scope = arg('store_scope', 'all')
        global_scope = arg('global_scope', 'all')
        product_ids = arg('product_ids')
        store_ids = arg('store_ids')
        socketio = arg('socketio')
        task_id = arg('task_id')
        progress_callback = arg('progress_callback')
        patience = arg('patience', 10)

        def log_message(message, log_type='info'):
            """Print *message* and forward it to ``progress_callback`` if one was given."""
            print(f"[{log_type.upper()}] {message}", flush=True)
            if progress_callback:
                try:
                    progress_callback({'log_type': log_type, 'message': message})
                except Exception as e:
                    # A failing callback must not abort training.
                    print(f"[ERROR] 进度回调失败: {e}", flush=True)

        # self.data is only an availability check here; the training data is
        # re-read from self.data_path by aggregate_multi_store_data below.
        if self.data is None:
            log_message("没有可用的数据,请先加载或生成数据", 'error')
            return None

        try:
            params = self._prepare_training_params(
                training_mode=training_mode,
                product_id=product_id,
                store_id=store_id,
                product_scope=product_scope,
                store_scope=store_scope,
                global_scope=global_scope,
                product_ids=product_ids,
                store_ids=store_ids,
            )

            product_data = aggregate_multi_store_data(
                store_id=params['agg_store_id'],
                product_id=params['agg_product_id'],
                aggregation_method=aggregation_method,
                file_path=self.data_path,
            )

            if product_data is None or product_data.empty:
                raise ValueError(f"聚合后数据为空,无法继续训练。模式: {training_mode}, Scope: {params['final_scope']}")
        except ValueError as e:
            log_message(f"参数校验或数据准备失败: {e}", 'error')
            return None
        except Exception as e:
            log_message(f"数据准备过程中发生未知错误: {e}", 'error')
            log_message(traceback.format_exc(), 'error')
            return None

        trainers = {
            'transformer': train_product_model_with_transformer,
            'mlstm': train_product_model_with_mlstm,
            'tcn': train_product_model_with_tcn,
            'kan': train_product_model_with_kan,
            'optimized_kan': train_product_model_with_kan,
        }

        if model_type not in trainers:
            log_message(f"不支持的模型类型: {model_type}", 'error')
            return None

        trainer_func = trainers[model_type]

        trainer_args = {
            "product_df": product_data,
            "training_mode": training_mode,
            "aggregation_method": aggregation_method,
            "scope": params['final_scope'],
            "epochs": epochs,
            "socketio": socketio,
            "task_id": task_id,
            "progress_callback": progress_callback,
            "patience": patience,
            "learning_rate": learning_rate,
        }

        if model_type in self._KAN_MODEL_TYPES:
            # Honour an explicit use_optimized=True for plain 'kan';
            # 'optimized_kan' always selects the optimized variant.
            trainer_args['use_optimized'] = bool(use_optimized) or model_type == 'optimized_kan'

        # Forward the target identifiers to the trainer when they are known.
        if product_id:
            trainer_args['product_id'] = product_id
        if store_id:
            trainer_args['store_id'] = store_id

        try:
            log_message(f"🤖 开始调用 {model_type} 训练器 with scope: '{params['final_scope']}'")

            _model, metrics, version, model_version_path = trainer_func(**trainer_args)

            log_message(f"✅ {model_type} 训练器成功返回", 'success')

            if not metrics:
                log_message("⚠️ 训练器返回的metrics为空", 'warning')
                return None

            metrics.update({
                'model_type': model_type,
                'version': version,
                'model_path': self._display_path(model_version_path),
                'training_mode': training_mode,
                'scope': params['final_scope'],
                'aggregation_method': aggregation_method,
            })
            log_message(f"📈 最终返回的metrics: {metrics}", 'success')
            return metrics

        except Exception as e:
            log_message(f"模型训练过程中发生严重错误: {e}\n{traceback.format_exc()}", 'error')
            return None

    # ------------------------------------------------------------------
    # Prediction and model management
    # ------------------------------------------------------------------
    def predict(self, model_version_path, future_days=7, start_date=None, analyze_result=False):
        """Predict with a trained model, addressed directly by its version path.

        Args:
            model_version_path: Path of the saved model version.
            future_days: Number of days to forecast (default 7).
            start_date: Optional forecast start date, forwarded as-is.
            analyze_result: Forwarded to ``load_model_and_predict``.

        Returns:
            Whatever ``load_model_and_predict`` returns.

        Raises:
            FileNotFoundError: if the path is empty/None or does not exist.
        """
        if not model_version_path or not os.path.exists(model_version_path):
            raise FileNotFoundError(f"指定的模型路径不存在: {model_version_path}")

        return load_model_and_predict(
            model_version_path=model_version_path,
            future_days=future_days,
            start_date=start_date,
            analyze_result=analyze_result,
        )

    def list_models(self, **kwargs):
        """List available model versions.

        Delegates to ``model_manager.list_models``; supported filter keyword
        arguments are model_type, training_mode, scope and version.
        """
        return model_manager.list_models(**kwargs)

    def delete_model(self, model_version_path):
        """Delete the given model version via ``model_manager.delete_model_version``."""
        return model_manager.delete_model_version(model_version_path)

    def compare_models(self, product_id, epochs=50, *, model_types=None, **kwargs):
        """Train several model types on the same data and print a comparison.

        Args:
            product_id: Product to train on, passed to ``train_model``.
            epochs: Epochs per model (default 50).
            model_types: Iterable of model types to compare; defaults to
                tcn, mlstm, transformer, kan and optimized_kan.
            **kwargs: Extra ``train_model`` options. A ``model_type`` entry is
                ignored because the model type is set per iteration.

        Returns:
            dict mapping model type to its metrics dict, ``{}`` when training
            returned nothing, or ``{'error': ...}`` when it raised.
        """
        if model_types is None:
            model_types = self._COMPARE_MODEL_TYPES

        # Passing model_type both explicitly and via **kwargs would raise
        # TypeError("got multiple values") for every model.
        train_kwargs = {key: value for key, value in kwargs.items() if key != 'model_type'}

        results = {}
        for model_type in model_types:
            print(f"\n{'='*20} 训练模型: {model_type.upper()} {'='*20}")
            try:
                metrics = self.train_model(
                    product_id=product_id,
                    model_type=model_type,
                    epochs=epochs,
                    **train_kwargs
                )
                results[model_type] = metrics if metrics else {}
            except Exception as e:
                print(f"训练 {model_type} 模型失败: {e}")
                results[model_type] = {'error': str(e)}

        # Print the comparison table.
        print(f"\n{'='*25} 模型性能比较 {'='*25}")

        rows = [
            {
                'Model': model.upper(),
                'RMSE': metrics.get('rmse'),
                'R²': metrics.get('r2'),
                'MAPE (%)': metrics.get('mape'),
                'Time (s)': metrics.get('training_time'),
            }
            for model, metrics in results.items()
            if metrics and 'rmse' in metrics
        ]

        if not rows:
            print("没有可供比较的模型结果。")
            return results

        comparison_df = pd.DataFrame(rows).set_index('Model')
        print(comparison_df.to_string(float_format="%.4f"))

        return results
|