插件式添加模型
This commit is contained in:
parent
038289ae32
commit
751de9b548
Binary file not shown.
@ -56,5 +56,6 @@ tzdata==2025.2
|
||||
werkzeug==3.1.3
|
||||
win32-setctime==1.2.0
|
||||
wsproto==1.2.0
|
||||
|
||||
python-dateutil
|
||||
xgboost
|
||||
scikit-learn
|
||||
|
@ -45,6 +45,7 @@ from trainers.mlstm_trainer import train_product_model_with_mlstm
|
||||
from trainers.kan_trainer import train_product_model_with_kan
|
||||
from trainers.tcn_trainer import train_product_model_with_tcn
|
||||
from trainers.transformer_trainer import train_product_model_with_transformer
|
||||
from trainers.xgboost_trainer import train_product_model_with_xgboost
|
||||
|
||||
# 导入预测函数
|
||||
from predictors.model_predictor import load_model_and_predict
|
||||
@ -810,7 +811,7 @@ def get_all_training_tasks():
|
||||
'type': 'object',
|
||||
'properties': {
|
||||
'product_id': {'type': 'string', 'description': '例如 P001'},
|
||||
'model_type': {'type': 'string', 'enum': ['mlstm', 'transformer', 'kan', 'optimized_kan', 'tcn'], 'description': '要训练的模型类型'},
|
||||
'model_type': {'type': 'string', 'enum': ['mlstm', 'transformer', 'kan', 'optimized_kan', 'tcn', 'xgboost'], 'description': '要训练的模型类型'},
|
||||
'store_id': {'type': 'string', 'description': '店铺ID,如 S001。为空时使用全局聚合数据'},
|
||||
'epochs': {'type': 'integer', 'default': 50, 'description': '训练轮次'}
|
||||
},
|
||||
@ -873,10 +874,10 @@ def start_training():
|
||||
# 全局模式不需要特定的product_id或store_id
|
||||
pass
|
||||
|
||||
# 检查模型类型是否有效
|
||||
valid_model_types = ['mlstm', 'kan', 'optimized_kan', 'transformer', 'tcn']
|
||||
if model_type not in valid_model_types:
|
||||
return jsonify({'error': '无效的模型类型'}), 400
|
||||
# 检查模型类型是否有效 (v2 - 动态检查)
|
||||
from models.model_registry import TRAINER_REGISTRY
|
||||
if model_type not in TRAINER_REGISTRY:
|
||||
return jsonify({'error': f"无效的模型类型: '{model_type}'. 可用模型: {list(TRAINER_REGISTRY.keys())}"}), 400
|
||||
|
||||
# 使用新的训练进程管理器提交任务
|
||||
try:
|
||||
@ -3445,41 +3446,37 @@ def analyze_model_metrics():
|
||||
}
|
||||
})
|
||||
def get_model_types():
|
||||
"""获取系统支持的所有模型类型"""
|
||||
model_types = [
|
||||
{
|
||||
'id': 'mlstm',
|
||||
'name': 'mLSTM',
|
||||
'description': '矩阵长短期记忆网络,适合处理多变量时序数据',
|
||||
'tag_type': 'primary'
|
||||
},
|
||||
{
|
||||
'id': 'transformer',
|
||||
'name': 'Transformer',
|
||||
'description': '基于注意力机制的序列模型,适合捕捉长期依赖关系',
|
||||
'tag_type': 'success'
|
||||
},
|
||||
{
|
||||
'id': 'kan',
|
||||
'name': 'KAN',
|
||||
'description': 'Kolmogorov-Arnold网络,能够逼近任意连续函数',
|
||||
'tag_type': 'warning'
|
||||
},
|
||||
{
|
||||
'id': 'optimized_kan',
|
||||
'name': '优化版KAN',
|
||||
'description': '经过优化的KAN模型,提供更高的预测精度和训练效率',
|
||||
'tag_type': 'info'
|
||||
},
|
||||
{
|
||||
'id': 'tcn',
|
||||
'name': 'TCN',
|
||||
'description': '时间卷积网络,适合处理长序列和平行计算',
|
||||
'tag_type': 'danger'
|
||||
}
|
||||
]
|
||||
"""获取系统支持的所有模型类型 (v2 - 动态加载)"""
|
||||
from models.model_registry import TRAINER_REGISTRY
|
||||
|
||||
return jsonify({"status": "success", "data": model_types})
|
||||
# 预定义的模型元数据,用于美化显示
|
||||
model_meta = {
|
||||
'mlstm': {'name': 'mLSTM', 'description': '矩阵长短期记忆网络,适合处理多变量时序数据', 'tag_type': 'primary'},
|
||||
'transformer': {'name': 'Transformer', 'description': '基于注意力机制的序列模型,适合捕捉长期依赖关系', 'tag_type': 'success'},
|
||||
'kan': {'name': 'KAN', 'description': 'Kolmogorov-Arnold网络,能够逼近任意连续函数', 'tag_type': 'warning'},
|
||||
'optimized_kan': {'name': '优化版KAN', 'description': '经过优化的KAN模型,提供更高的预测精度和训练效率', 'tag_type': 'info'},
|
||||
'tcn': {'name': 'TCN', 'description': '时间卷积网络,适合处理长序列和平行计算', 'tag_type': 'danger'},
|
||||
'xgboost': {'name': 'XGBoost', 'description': '梯度提升决策树,性能强大且高效的经典模型', 'tag_type': 'primary'}
|
||||
}
|
||||
|
||||
# 从注册表动态获取所有已注册的模型ID
|
||||
registered_models = TRAINER_REGISTRY.keys()
|
||||
|
||||
dynamic_model_types = []
|
||||
for model_id in registered_models:
|
||||
meta = model_meta.get(model_id, {
|
||||
'name': model_id.upper(),
|
||||
'description': f'自定义模型: {model_id}',
|
||||
'tag_type': 'secondary'
|
||||
})
|
||||
dynamic_model_types.append({
|
||||
'id': model_id,
|
||||
'name': meta['name'],
|
||||
'description': meta['description'],
|
||||
'tag_type': meta['tag_type']
|
||||
})
|
||||
|
||||
return jsonify({"status": "success", "data": dynamic_model_types})
|
||||
|
||||
# ========== 新增版本管理API ==========
|
||||
|
||||
|
@ -58,7 +58,9 @@ HIDDEN_SIZE = 64 # 隐藏层大小
|
||||
NUM_LAYERS = 2 # 层数
|
||||
|
||||
# 支持的模型类型
|
||||
SUPPORTED_MODELS = ['mlstm', 'kan', 'transformer', 'tcn', 'optimized_kan']
|
||||
# 支持的模型类型 (v2 - 动态加载)
|
||||
from models.model_registry import TRAINER_REGISTRY
|
||||
SUPPORTED_MODELS = list(TRAINER_REGISTRY.keys())
|
||||
|
||||
# 版本管理配置
|
||||
MODEL_VERSION_PREFIX = 'v' # 版本前缀
|
||||
|
@ -11,12 +11,13 @@ import time
|
||||
import matplotlib.pyplot as plt
|
||||
from datetime import datetime
|
||||
|
||||
from trainers import (
|
||||
train_product_model_with_mlstm,
|
||||
train_product_model_with_kan,
|
||||
train_product_model_with_tcn,
|
||||
train_product_model_with_transformer
|
||||
)
|
||||
# from trainers import (
|
||||
# train_product_model_with_mlstm,
|
||||
# train_product_model_with_kan,
|
||||
# train_product_model_with_tcn,
|
||||
# train_product_model_with_transformer
|
||||
# )
|
||||
# 上述导入已不再需要,因为我们现在通过模型注册表动态获取训练器
|
||||
from predictors.model_predictor import load_model_and_predict
|
||||
from utils.data_utils import prepare_data, prepare_sequences
|
||||
from utils.multi_store_data_utils import (
|
||||
@ -187,89 +188,49 @@ class PharmacyPredictor:
|
||||
else: # product mode
|
||||
model_identifier = product_id
|
||||
|
||||
# 调用相应的训练函数
|
||||
# 调用相应的训练函数 (重构为使用注册表)
|
||||
try:
|
||||
log_message(f"🤖 开始调用 {model_type} 训练器")
|
||||
if model_type == 'transformer':
|
||||
model_result, metrics, actual_version = train_product_model_with_transformer(
|
||||
product_id=product_id, # product_id 仍然需要,用于数据过滤
|
||||
model_identifier=model_identifier, # 这是用于保存模型的唯一ID
|
||||
product_df=product_data,
|
||||
store_id=store_id,
|
||||
training_mode=training_mode,
|
||||
aggregation_method=aggregation_method,
|
||||
epochs=epochs,
|
||||
sequence_length=sequence_length,
|
||||
forecast_horizon=forecast_horizon,
|
||||
model_dir=self.model_dir,
|
||||
version=version,
|
||||
socketio=socketio,
|
||||
task_id=task_id,
|
||||
continue_training=continue_training
|
||||
)
|
||||
log_message(f"✅ {model_type} 训练器返回: metrics={type(metrics)}, version={actual_version}", 'success')
|
||||
elif model_type == 'mlstm':
|
||||
_, metrics, _, _ = train_product_model_with_mlstm(
|
||||
product_id=product_id,
|
||||
model_identifier=model_identifier, # 传递修正后的ID
|
||||
product_df=product_data,
|
||||
store_id=store_id,
|
||||
training_mode=training_mode,
|
||||
aggregation_method=aggregation_method,
|
||||
epochs=epochs,
|
||||
sequence_length=sequence_length,
|
||||
forecast_horizon=forecast_horizon,
|
||||
model_dir=self.model_dir,
|
||||
socketio=socketio,
|
||||
task_id=task_id,
|
||||
progress_callback=progress_callback
|
||||
)
|
||||
elif model_type == 'kan':
|
||||
_, metrics = train_product_model_with_kan(
|
||||
product_id=product_id,
|
||||
model_identifier=model_identifier, # 传递修正后的ID
|
||||
product_df=product_data,
|
||||
store_id=store_id,
|
||||
training_mode=training_mode,
|
||||
aggregation_method=aggregation_method,
|
||||
epochs=epochs,
|
||||
sequence_length=sequence_length,
|
||||
forecast_horizon=forecast_horizon,
|
||||
use_optimized=use_optimized,
|
||||
model_dir=self.model_dir
|
||||
)
|
||||
elif model_type == 'optimized_kan':
|
||||
_, metrics = train_product_model_with_kan(
|
||||
product_id=product_id,
|
||||
model_identifier=model_identifier, # 传递修正后的ID
|
||||
product_df=product_data,
|
||||
store_id=store_id,
|
||||
training_mode=training_mode,
|
||||
aggregation_method=aggregation_method,
|
||||
epochs=epochs,
|
||||
sequence_length=sequence_length,
|
||||
forecast_horizon=forecast_horizon,
|
||||
use_optimized=True,
|
||||
model_dir=self.model_dir
|
||||
)
|
||||
elif model_type == 'tcn':
|
||||
_, metrics, _, _ = train_product_model_with_tcn(
|
||||
product_id=product_id,
|
||||
model_identifier=model_identifier, # 传递修正后的ID
|
||||
product_df=product_data,
|
||||
store_id=store_id,
|
||||
training_mode=training_mode,
|
||||
aggregation_method=aggregation_method,
|
||||
epochs=epochs,
|
||||
sequence_length=sequence_length,
|
||||
forecast_horizon=forecast_horizon,
|
||||
model_dir=self.model_dir,
|
||||
socketio=socketio,
|
||||
task_id=task_id
|
||||
)
|
||||
from models.model_registry import get_trainer
|
||||
log_message(f"🤖 正在从注册表获取 '{model_type}' 训练器...")
|
||||
trainer_function = get_trainer(model_type)
|
||||
log_message(f"✅ 成功获取训练器: {trainer_function.__name__}")
|
||||
|
||||
# 准备通用参数
|
||||
trainer_args = {
|
||||
'product_id': product_id,
|
||||
'model_identifier': model_identifier,
|
||||
'product_df': product_data,
|
||||
'store_id': store_id,
|
||||
'training_mode': training_mode,
|
||||
'aggregation_method': aggregation_method,
|
||||
'epochs': epochs,
|
||||
'sequence_length': sequence_length,
|
||||
'forecast_horizon': forecast_horizon,
|
||||
'model_dir': self.model_dir,
|
||||
'socketio': socketio,
|
||||
'task_id': task_id,
|
||||
'progress_callback': progress_callback,
|
||||
'version': version,
|
||||
'continue_training': continue_training,
|
||||
'use_optimized': use_optimized # KAN模型需要
|
||||
}
|
||||
|
||||
# 动态调用训练函数 (v2 - 智能参数过滤)
|
||||
import inspect
|
||||
sig = inspect.signature(trainer_function)
|
||||
valid_args = {k: v for k, v in trainer_args.items() if k in sig.parameters}
|
||||
|
||||
log_message(f"🔍 准备调用 {trainer_function.__name__},有效参数: {list(valid_args.keys())}")
|
||||
|
||||
result = trainer_function(**valid_args)
|
||||
|
||||
# 根据返回值的数量解析metrics
|
||||
if isinstance(result, tuple) and len(result) >= 2:
|
||||
metrics = result[1] # 通常第二个返回值是metrics
|
||||
else:
|
||||
log_message(f"不支持的模型类型: {model_type}", 'error')
|
||||
return None
|
||||
log_message(f"⚠️ 训练器返回格式未知,无法直接提取metrics: {type(result)}", 'warning')
|
||||
metrics = None
|
||||
|
||||
|
||||
# 检查和打印返回的metrics
|
||||
log_message(f"📊 训练完成,检查返回的metrics: {metrics}")
|
||||
|
102
server/models/cnn_bilstm_attention.py
Normal file
102
server/models/cnn_bilstm_attention.py
Normal file
@ -0,0 +1,102 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
CNN-BiLSTM-Attention 模型定义,适配药店销售预测系统。
|
||||
原始代码来源: python机器学习回归全家桶
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
# 注意:由于原始代码使用了 TensorFlow/Keras 的层,我们将在这里创建一个 PyTorch 的等效实现。
|
||||
# 这是一个更健壮、更符合现有系统架构的做法。
|
||||
|
||||
class Attention(nn.Module):
|
||||
"""
|
||||
PyTorch 实现的注意力机制。
|
||||
"""
|
||||
def __init__(self, feature_dim, step_dim, bias=True, **kwargs):
|
||||
super(Attention, self).__init__(**kwargs)
|
||||
|
||||
self.supports_masking = True
|
||||
self.bias = bias
|
||||
self.feature_dim = feature_dim
|
||||
self.step_dim = step_dim
|
||||
self.features_dim = 0
|
||||
|
||||
weight = torch.zeros(feature_dim, 1)
|
||||
nn.init.xavier_uniform_(weight)
|
||||
self.weight = nn.Parameter(weight)
|
||||
|
||||
if bias:
|
||||
self.b = nn.Parameter(torch.zeros(step_dim))
|
||||
|
||||
def forward(self, x, mask=None):
|
||||
feature_dim = self.feature_dim
|
||||
step_dim = self.step_dim
|
||||
|
||||
eij = torch.mm(
|
||||
x.contiguous().view(-1, feature_dim),
|
||||
self.weight
|
||||
).view(-1, step_dim)
|
||||
|
||||
if self.bias:
|
||||
eij = eij + self.b
|
||||
|
||||
eij = torch.tanh(eij)
|
||||
a = torch.exp(eij)
|
||||
|
||||
if mask is not None:
|
||||
a = a * mask
|
||||
|
||||
a = a / (torch.sum(a, 1, keepdim=True) + 1e-10)
|
||||
|
||||
weighted_input = x * torch.unsqueeze(a, -1)
|
||||
return torch.sum(weighted_input, 1)
|
||||
|
||||
|
||||
class CnnBiLstmAttention(nn.Module):
|
||||
"""
|
||||
CNN-BiLSTM-Attention 模型的 PyTorch 实现。
|
||||
"""
|
||||
def __init__(self, input_dim, output_dim, sequence_length, cnn_filters=64, cnn_kernel_size=1, lstm_units=128):
|
||||
super(CnnBiLstmAttention, self).__init__()
|
||||
self.sequence_length = sequence_length
|
||||
self.cnn_filters = cnn_filters
|
||||
self.lstm_units = lstm_units
|
||||
|
||||
# CNN 层
|
||||
self.conv1d = nn.Conv1d(in_channels=input_dim, out_channels=cnn_filters, kernel_size=cnn_kernel_size)
|
||||
self.relu = nn.ReLU()
|
||||
self.maxpool = nn.MaxPool1d(kernel_size=1)
|
||||
|
||||
# BiLSTM 层
|
||||
self.bilstm = nn.LSTM(input_size=cnn_filters, hidden_size=lstm_units, num_layers=1, batch_first=True, bidirectional=True)
|
||||
|
||||
# Attention 层
|
||||
self.attention = Attention(feature_dim=lstm_units * 2, step_dim=sequence_length)
|
||||
|
||||
# 全连接输出层
|
||||
self.dense = nn.Linear(lstm_units * 2, output_dim)
|
||||
|
||||
def forward(self, x):
|
||||
# 输入 x 的形状: (batch_size, sequence_length, input_dim)
|
||||
|
||||
# CNN 处理
|
||||
x = x.permute(0, 2, 1) # 转换为 (batch_size, input_dim, sequence_length) 以适应 Conv1d
|
||||
x = self.conv1d(x)
|
||||
x = self.relu(x)
|
||||
x = x.permute(0, 2, 1) # 转换回 (batch_size, sequence_length, cnn_filters)
|
||||
|
||||
# BiLSTM 处理
|
||||
lstm_out, _ = self.bilstm(x) # lstm_out 形状: (batch_size, sequence_length, lstm_units * 2)
|
||||
|
||||
# Attention 处理
|
||||
# 注意:这里的 Attention 实现可能需要根据具体任务微调
|
||||
# 一个简化的方法是直接使用 LSTM 的最终隐藏状态或输出
|
||||
# 这里我们先用一个简化的逻辑:直接展平 LSTM 输出
|
||||
attention_out = self.attention(lstm_out)
|
||||
|
||||
# 全连接层输出
|
||||
output = self.dense(attention_out)
|
||||
|
||||
return output
|
64
server/models/model_registry.py
Normal file
64
server/models/model_registry.py
Normal file
@ -0,0 +1,64 @@
|
||||
"""
|
||||
模型注册表
|
||||
用于解耦模型的调用和实现,支持插件式扩展新模型。
|
||||
"""
|
||||
|
||||
# 训练器注册表
|
||||
TRAINER_REGISTRY = {}
|
||||
|
||||
def register_trainer(name, func):
|
||||
"""
|
||||
注册一个模型训练器。
|
||||
|
||||
参数:
|
||||
name (str): 模型类型名称 (e.g., 'xgboost')
|
||||
func (function): 对应的训练函数
|
||||
"""
|
||||
if name in TRAINER_REGISTRY:
|
||||
print(f"警告: 模型训练器 '{name}' 已被覆盖注册。")
|
||||
TRAINER_REGISTRY[name] = func
|
||||
print(f"✅ 已注册训练器: {name}")
|
||||
|
||||
def get_trainer(name):
|
||||
"""
|
||||
根据模型类型名称获取一个已注册的训练器。
|
||||
"""
|
||||
if name not in TRAINER_REGISTRY:
|
||||
# 在打印可用训练器之前,确保它们已经被加载
|
||||
from trainers import discover_trainers
|
||||
discover_trainers()
|
||||
if name not in TRAINER_REGISTRY:
|
||||
raise ValueError(f"未注册的模型训练器: '{name}'. 可用: {list(TRAINER_REGISTRY.keys())}")
|
||||
return TRAINER_REGISTRY[name]
|
||||
|
||||
# --- 预测器注册表 ---
|
||||
|
||||
# 预测器函数需要一个统一的接口,例如:
|
||||
# def predictor_function(model, checkpoint, **kwargs): -> predictions
|
||||
|
||||
PREDICTOR_REGISTRY = {}
|
||||
|
||||
def register_predictor(name, func):
|
||||
"""
|
||||
注册一个模型预测器。
|
||||
"""
|
||||
if name in PREDICTOR_REGISTRY:
|
||||
print(f"警告: 模型预测器 '{name}' 已被覆盖注册。")
|
||||
PREDICTOR_REGISTRY[name] = func
|
||||
|
||||
def get_predictor(name):
|
||||
"""
|
||||
根据模型类型名称获取一个已注册的预测器。
|
||||
如果找不到特定预测器,可以返回一个默认的。
|
||||
"""
|
||||
return PREDICTOR_REGISTRY.get(name, PREDICTOR_REGISTRY.get('default'))
|
||||
|
||||
# 默认的PyTorch预测逻辑可以被注册为 'default'
|
||||
def register_default_predictors():
|
||||
from predictors.model_predictor import default_pytorch_predictor
|
||||
register_predictor('default', default_pytorch_predictor)
|
||||
# 如果其他PyTorch模型有特殊预测逻辑,也可以在这里注册
|
||||
# register_predictor('kan', kan_predictor_func)
|
||||
|
||||
# 注意:这个函数的调用时机很重要,需要在应用启动时执行一次。
|
||||
# 我们可以暂时在 model_predictor.py 导入注册表后调用它。
|
Binary file not shown.
@ -18,39 +18,90 @@ from models.mlstm_model import MLSTMTransformer as MatrixLSTM
|
||||
from models.kan_model import KANForecaster
|
||||
from models.tcn_model import TCNForecaster
|
||||
from models.optimized_kan_forecaster import OptimizedKANForecaster
|
||||
from models.cnn_bilstm_attention import CnnBiLstmAttention
|
||||
import xgboost as xgb
|
||||
|
||||
from analysis.trend_analysis import analyze_prediction_result
|
||||
from utils.visualization import plot_prediction_results
|
||||
from utils.multi_store_data_utils import get_store_product_sales_data, aggregate_multi_store_data
|
||||
from core.config import DEVICE, get_model_file_path, DEFAULT_DATA_PATH
|
||||
from models.model_registry import get_predictor, register_predictor
|
||||
|
||||
def default_pytorch_predictor(model, checkpoint, product_df, future_days, start_date, history_lookback_days):
|
||||
"""
|
||||
默认的PyTorch模型预测逻辑,支持自动回归。
|
||||
"""
|
||||
config = checkpoint['config']
|
||||
scaler_X = checkpoint['scaler_X']
|
||||
scaler_y = checkpoint['scaler_y']
|
||||
features = config.get('features', ['sales', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature'])
|
||||
sequence_length = config['sequence_length']
|
||||
|
||||
if start_date:
|
||||
start_date_dt = pd.to_datetime(start_date)
|
||||
prediction_input_df = product_df[product_df['date'] < start_date_dt].tail(sequence_length)
|
||||
else:
|
||||
prediction_input_df = product_df.tail(sequence_length)
|
||||
start_date_dt = product_df['date'].iloc[-1] + timedelta(days=1)
|
||||
|
||||
if len(prediction_input_df) < sequence_length:
|
||||
raise ValueError(f"预测所需的历史数据不足。需要 {sequence_length} 天, 但只有 {len(prediction_input_df)} 天。")
|
||||
|
||||
history_for_chart_df = product_df[product_df['date'] < start_date_dt].tail(history_lookback_days)
|
||||
|
||||
all_predictions = []
|
||||
current_sequence_df = prediction_input_df.copy()
|
||||
|
||||
for _ in range(future_days):
|
||||
X_current_scaled = scaler_X.transform(current_sequence_df[features].values)
|
||||
# **核心改进**: 智能判断模型类型并调用相应的预测方法
|
||||
if isinstance(model, xgb.Booster):
|
||||
# XGBoost 模型预测路径
|
||||
X_input_reshaped = X_current_scaled.reshape(1, -1)
|
||||
d_input = xgb.DMatrix(X_input_reshaped)
|
||||
# **关键修复**: 使用 best_iteration 进行预测,以匹配早停策略
|
||||
y_pred_scaled = model.predict(d_input, iteration_range=(0, model.best_iteration))
|
||||
next_step_pred_scaled = y_pred_scaled.reshape(1, -1)
|
||||
else:
|
||||
# 默认 PyTorch 模型预测路径
|
||||
X_input = torch.tensor(X_current_scaled.reshape(1, sequence_length, -1), dtype=torch.float32).to(DEVICE)
|
||||
with torch.no_grad():
|
||||
y_pred_scaled = model(X_input).cpu().numpy()
|
||||
next_step_pred_scaled = y_pred_scaled[0, 0].reshape(1, -1)
|
||||
next_step_pred_unscaled = float(max(0, scaler_y.inverse_transform(next_step_pred_scaled)[0][0]))
|
||||
|
||||
next_date = current_sequence_df['date'].iloc[-1] + timedelta(days=1)
|
||||
all_predictions.append({'date': next_date, 'predicted_sales': next_step_pred_unscaled})
|
||||
|
||||
new_row = {'date': next_date, 'sales': next_step_pred_unscaled, 'weekday': next_date.weekday(), 'month': next_date.month, 'is_holiday': 0, 'is_weekend': 1 if next_date.weekday() >= 5 else 0, 'is_promotion': 0, 'temperature': current_sequence_df['temperature'].iloc[-1]}
|
||||
new_row_df = pd.DataFrame([new_row])
|
||||
current_sequence_df = pd.concat([current_sequence_df.iloc[1:], new_row_df], ignore_index=True)
|
||||
|
||||
predictions_df = pd.DataFrame(all_predictions)
|
||||
return predictions_df, history_for_chart_df, prediction_input_df
|
||||
|
||||
# 注册默认的PyTorch预测器
|
||||
register_predictor('default', default_pytorch_predictor)
|
||||
# 将增强后的默认预测器也注册给xgboost
|
||||
register_predictor('xgboost', default_pytorch_predictor)
|
||||
# 将新模型也注册给默认预测器
|
||||
register_predictor('cnn_bilstm_attention', default_pytorch_predictor)
|
||||
|
||||
|
||||
def load_model_and_predict(model_path: str, product_id: str, model_type: str, store_id: Optional[str] = None, future_days: int = 7, start_date: Optional[str] = None, analyze_result: bool = False, version: Optional[str] = None, training_mode: str = 'product', history_lookback_days: int = 30):
|
||||
"""
|
||||
加载已训练的模型并进行预测 (v3版 - 支持自动回归)
|
||||
|
||||
参数:
|
||||
... (同上, 新增 history_lookback_days)
|
||||
history_lookback_days: 用于图表展示的历史数据天数
|
||||
|
||||
返回:
|
||||
预测结果和分析
|
||||
加载已训练的模型并进行预测 (v4版 - 插件式架构)
|
||||
"""
|
||||
try:
|
||||
print(f"v3版预测函数启动,模型路径: {model_path}, 预测天数: {future_days}, 历史回看: {history_lookback_days}")
|
||||
|
||||
if not os.path.exists(model_path):
|
||||
print(f"模型文件 {model_path} 不存在")
|
||||
return None
|
||||
|
||||
# 加载销售数据
|
||||
raise FileNotFoundError(f"模型文件 {model_path} 不存在")
|
||||
|
||||
# --- 数据加载部分保持不变 ---
|
||||
from utils.multi_store_data_utils import aggregate_multi_store_data
|
||||
if training_mode == 'store' and store_id:
|
||||
# 先从原始数据加载一次以获取店铺名称,聚合会丢失此信息
|
||||
from utils.multi_store_data_utils import load_multi_store_data
|
||||
store_df_for_name = load_multi_store_data(store_id=store_id)
|
||||
product_name = store_df_for_name['store_name'].iloc[0] if not store_df_for_name.empty else f"店铺 {store_id}"
|
||||
|
||||
# 然后再进行聚合获取用于预测的数据
|
||||
product_df = aggregate_multi_store_data(store_id=store_id, aggregation_method='sum', file_path=DEFAULT_DATA_PATH)
|
||||
elif training_mode == 'global':
|
||||
product_df = aggregate_multi_store_data(aggregation_method='sum', file_path=DEFAULT_DATA_PATH)
|
||||
@ -60,124 +111,75 @@ def load_model_and_predict(model_path: str, product_id: str, model_type: str, st
|
||||
product_name = product_df['product_name'].iloc[0] if not product_df.empty else product_id
|
||||
|
||||
if product_df.empty:
|
||||
print(f"产品 {product_id} 或店铺 {store_id} 没有销售数据")
|
||||
return None
|
||||
raise ValueError(f"产品 {product_id} 或店铺 {store_id} 没有销售数据")
|
||||
|
||||
# 加载模型和配置
|
||||
# --- 模型加载与实例化 (重构) ---
|
||||
try:
|
||||
torch.serialization.add_safe_globals([sklearn.preprocessing._data.MinMaxScaler])
|
||||
except Exception: pass
|
||||
|
||||
checkpoint = torch.load(model_path, map_location=DEVICE, weights_only=False)
|
||||
if 'config' not in checkpoint or 'scaler_X' not in checkpoint or 'scaler_y' not in checkpoint:
|
||||
print("模型文件不完整,缺少config或scaler")
|
||||
return None
|
||||
|
||||
config = checkpoint['config']
|
||||
scaler_X = checkpoint['scaler_X']
|
||||
scaler_y = checkpoint['scaler_y']
|
||||
|
||||
# 创建模型实例
|
||||
# (此处省略了与原版本相同的模型创建代码,以保持简洁)
|
||||
if model_type == 'transformer':
|
||||
model = TimeSeriesTransformer(num_features=config['input_dim'], d_model=config['hidden_size'], nhead=config['num_heads'], num_encoder_layers=config['num_layers'], dim_feedforward=config['hidden_size'] * 2, dropout=config['dropout'], output_sequence_length=config['output_dim'], seq_length=config['sequence_length'], batch_size=32).to(DEVICE)
|
||||
elif model_type == 'mlstm':
|
||||
model = MatrixLSTM(num_features=config['input_dim'], hidden_size=config['hidden_size'], mlstm_layers=config['mlstm_layers'], embed_dim=config.get('embed_dim', 32), dense_dim=config.get('dense_dim', 32), num_heads=config.get('num_heads', 4), dropout_rate=config['dropout_rate'], num_blocks=config.get('num_blocks', 3), output_sequence_length=config['output_dim']).to(DEVICE)
|
||||
elif model_type == 'kan':
|
||||
model = KANForecaster(input_features=config['input_dim'], hidden_sizes=[config['hidden_size'], config['hidden_size']*2, config['hidden_size']], output_sequence_length=config['output_dim']).to(DEVICE)
|
||||
elif model_type == 'optimized_kan':
|
||||
model = OptimizedKANForecaster(input_features=config['input_dim'], hidden_sizes=[config['hidden_size'], config['hidden_size']*2, config['hidden_size']], output_sequence_length=config['output_dim']).to(DEVICE)
|
||||
elif model_type == 'tcn':
|
||||
model = TCNForecaster(num_features=config['input_dim'], output_sequence_length=config['output_dim'], num_channels=[config['hidden_size']] * config['num_layers'], kernel_size=config['kernel_size'], dropout=config['dropout']).to(DEVICE)
|
||||
config = checkpoint.get('config', {})
|
||||
loaded_model_type = config.get('model_type', model_type) # 优先使用模型内保存的类型
|
||||
|
||||
# 根据模型类型决定如何获取模型实例
|
||||
if loaded_model_type == 'xgboost':
|
||||
# 对于XGBoost, 模型对象直接保存在'model_state_dict'键中
|
||||
model = checkpoint['model_state_dict']
|
||||
else:
|
||||
print(f"不支持的模型类型: {model_type}"); return None
|
||||
|
||||
model.load_state_dict(checkpoint['model_state_dict'])
|
||||
model.eval()
|
||||
|
||||
# --- 核心逻辑修改:自动回归预测 ---
|
||||
features = ['sales', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']
|
||||
sequence_length = config['sequence_length']
|
||||
|
||||
# 确定预测的起始点
|
||||
if start_date:
|
||||
start_date_dt = pd.to_datetime(start_date)
|
||||
# 获取预测开始日期前的 `sequence_length` 天数据作为初始输入
|
||||
prediction_input_df = product_df[product_df['date'] < start_date_dt].tail(sequence_length)
|
||||
else:
|
||||
# 如果未指定开始日期,则从数据的最后一天开始预测
|
||||
prediction_input_df = product_df.tail(sequence_length)
|
||||
start_date_dt = product_df['date'].iloc[-1] + timedelta(days=1)
|
||||
|
||||
if len(prediction_input_df) < sequence_length:
|
||||
print(f"错误: 预测所需的历史数据不足。需要 {sequence_length} 天, 但只有 {len(prediction_input_df)} 天。")
|
||||
return None
|
||||
|
||||
# 准备用于图表展示的历史数据
|
||||
history_for_chart_df = product_df[product_df['date'] < start_date_dt].tail(history_lookback_days)
|
||||
|
||||
# 自动回归预测循环
|
||||
all_predictions = []
|
||||
current_sequence_df = prediction_input_df.copy()
|
||||
|
||||
print(f"开始自动回归预测,共 {future_days} 天...")
|
||||
for i in range(future_days):
|
||||
# 准备当前序列的输入张量
|
||||
X_current_scaled = scaler_X.transform(current_sequence_df[features].values)
|
||||
X_input = torch.tensor(X_current_scaled.reshape(1, sequence_length, -1), dtype=torch.float32).to(DEVICE)
|
||||
|
||||
# 模型进行一次预测(可能预测出多个点,但我们只用第一个)
|
||||
with torch.no_grad():
|
||||
y_pred_scaled = model(X_input).cpu().numpy()
|
||||
# 对于PyTorch模型, 需要重新构建实例并加载state_dict
|
||||
if loaded_model_type == 'transformer':
|
||||
model = TimeSeriesTransformer(num_features=config['input_dim'], d_model=config['hidden_size'], nhead=config['num_heads'], num_encoder_layers=config['num_layers'], dim_feedforward=config['hidden_size'] * 2, dropout=config['dropout'], output_sequence_length=config['output_dim'], seq_length=config['sequence_length'], batch_size=32).to(DEVICE)
|
||||
elif loaded_model_type == 'mlstm':
|
||||
model = MatrixLSTM(num_features=config['input_dim'], hidden_size=config['hidden_size'], mlstm_layers=config['mlstm_layers'], embed_dim=config.get('embed_dim', 32), dense_dim=config.get('dense_dim', 32), num_heads=config.get('num_heads', 4), dropout_rate=config['dropout_rate'], num_blocks=config.get('num_blocks', 3), output_sequence_length=config['output_dim']).to(DEVICE)
|
||||
elif loaded_model_type == 'kan':
|
||||
model = KANForecaster(input_features=config['input_dim'], hidden_sizes=[config['hidden_size'], config['hidden_size']*2, config['hidden_size']], output_sequence_length=config['output_dim']).to(DEVICE)
|
||||
elif loaded_model_type == 'optimized_kan':
|
||||
model = OptimizedKANForecaster(input_features=config['input_dim'], hidden_sizes=[config['hidden_size'], config['hidden_size']*2, config['hidden_size']], output_sequence_length=config['output_dim']).to(DEVICE)
|
||||
elif loaded_model_type == 'tcn':
|
||||
model = TCNForecaster(num_features=config['input_dim'], output_sequence_length=config['output_dim'], num_channels=[config['hidden_size']] * config['num_layers'], kernel_size=config['kernel_size'], dropout=config['dropout']).to(DEVICE)
|
||||
elif loaded_model_type == 'cnn_bilstm_attention':
|
||||
model = CnnBiLstmAttention(
|
||||
input_dim=config['input_dim'],
|
||||
output_dim=config['output_dim'],
|
||||
sequence_length=config['sequence_length']
|
||||
).to(DEVICE)
|
||||
else:
|
||||
raise ValueError(f"不支持的模型类型: {loaded_model_type}")
|
||||
|
||||
# 提取下一个时间点的预测值
|
||||
next_step_pred_scaled = y_pred_scaled[0, 0].reshape(1, -1)
|
||||
next_step_pred_unscaled = scaler_y.inverse_transform(next_step_pred_scaled)[0][0]
|
||||
next_step_pred_unscaled = float(max(0, next_step_pred_unscaled)) # 确保销量不为负,并转换为标准float
|
||||
model.load_state_dict(checkpoint['model_state_dict'])
|
||||
model.eval()
|
||||
|
||||
# 获取新预测的日期
|
||||
next_date = current_sequence_df['date'].iloc[-1] + timedelta(days=1)
|
||||
all_predictions.append({'date': next_date, 'predicted_sales': next_step_pred_unscaled})
|
||||
# --- 动态调用预测器 ---
|
||||
predictor_function = get_predictor(loaded_model_type)
|
||||
if not predictor_function:
|
||||
raise ValueError(f"找不到模型类型 '{loaded_model_type}' 的预测器实现")
|
||||
|
||||
# 构建新的一行数据,用于更新输入序列
|
||||
new_row = {
|
||||
'date': next_date,
|
||||
'sales': next_step_pred_unscaled,
|
||||
'weekday': next_date.weekday(),
|
||||
'month': next_date.month,
|
||||
'is_holiday': 0,
|
||||
'is_weekend': 1 if next_date.weekday() >= 5 else 0,
|
||||
'is_promotion': 0,
|
||||
'temperature': current_sequence_df['temperature'].iloc[-1] # 沿用最后一天的温度
|
||||
}
|
||||
|
||||
# 更新序列:移除最旧的一行,添加最新预测的一行
|
||||
new_row_df = pd.DataFrame([new_row])
|
||||
current_sequence_df = pd.concat([current_sequence_df.iloc[1:], new_row_df], ignore_index=True)
|
||||
predictions_df, history_for_chart_df, prediction_input_df = predictor_function(
|
||||
model=model,
|
||||
checkpoint=checkpoint,
|
||||
product_df=product_df,
|
||||
future_days=future_days,
|
||||
start_date=start_date,
|
||||
history_lookback_days=history_lookback_days
|
||||
)
|
||||
|
||||
predictions_df = pd.DataFrame(all_predictions)
|
||||
print(f"自动回归预测完成,生成 {len(predictions_df)} 条预测数据。")
|
||||
|
||||
# 分析与可视化
|
||||
# --- 分析与返回部分保持不变 ---
|
||||
analysis = None
|
||||
if analyze_result:
|
||||
try:
|
||||
y_pred_for_analysis = predictions_df['predicted_sales'].values
|
||||
# 使用初始输入序列的特征进行分析
|
||||
initial_features_for_analysis = prediction_input_df[features].values
|
||||
analysis = analyze_prediction_result(product_id, model_type, y_pred_for_analysis, initial_features_for_analysis)
|
||||
analysis = analyze_prediction_result(product_id, loaded_model_type, predictions_df['predicted_sales'].values, prediction_input_df[config.get('features')].values)
|
||||
except Exception as e:
|
||||
print(f"分析预测结果失败: {str(e)}")
|
||||
|
||||
# 在返回前,将DataFrame转换为前端期望的JSON数组格式
|
||||
history_data_json = history_for_chart_df.to_dict('records') if not history_for_chart_df.empty else []
|
||||
prediction_data_json = predictions_df.to_dict('records') if not predictions_df.empty else []
|
||||
|
||||
return {
|
||||
'product_id': product_id,
|
||||
'product_name': product_name,
|
||||
'model_type': model_type,
|
||||
'predictions': prediction_data_json, # 兼容旧字段,使用已转换的json
|
||||
'model_type': loaded_model_type,
|
||||
'predictions': prediction_data_json,
|
||||
'prediction_data': prediction_data_json,
|
||||
'history_data': history_data_json,
|
||||
'analysis': analysis
|
||||
|
@ -2,18 +2,44 @@
|
||||
药店销售预测系统 - 模型训练模块
|
||||
"""
|
||||
|
||||
from .mlstm_trainer import train_product_model_with_mlstm
|
||||
from .kan_trainer import train_product_model_with_kan
|
||||
from .tcn_trainer import train_product_model_with_tcn
|
||||
from .transformer_trainer import train_product_model_with_transformer
|
||||
import os
|
||||
import glob
|
||||
import importlib
|
||||
|
||||
# 默认训练函数
|
||||
from .mlstm_trainer import train_product_model_with_mlstm as train_product_model
|
||||
_TRAINERS_LOADED = False
|
||||
|
||||
def discover_trainers():
|
||||
"""
|
||||
自动发现并加载所有训练器插件。
|
||||
使用一个标志位确保这个过程只执行一次。
|
||||
"""
|
||||
global _TRAINERS_LOADED
|
||||
if _TRAINERS_LOADED:
|
||||
return
|
||||
|
||||
print("🚀 开始发现并加载训练器插件...")
|
||||
|
||||
package_dir = os.path.dirname(__file__)
|
||||
module_name = __name__
|
||||
|
||||
trainer_files = glob.glob(os.path.join(package_dir, "*_trainer.py"))
|
||||
|
||||
for f in trainer_files:
|
||||
base_name = os.path.basename(f)
|
||||
if base_name.startswith('__'):
|
||||
continue
|
||||
|
||||
module_stem = base_name.replace('.py', '')
|
||||
|
||||
try:
|
||||
# 动态导入模块以触发自注册
|
||||
importlib.import_module(f".{module_stem}", package=module_name)
|
||||
except ImportError as e:
|
||||
print(f"⚠️ 加载训练器 {module_stem} 失败: {e}")
|
||||
|
||||
_TRAINERS_LOADED = True
|
||||
print("✅ 所有训练器插件加载完成。")
|
||||
|
||||
# 在包被首次导入时,自动执行发现过程
|
||||
discover_trainers()
|
||||
|
||||
__all__ = [
|
||||
'train_product_model',
|
||||
'train_product_model_with_mlstm',
|
||||
'train_product_model_with_kan',
|
||||
'train_product_model_with_tcn',
|
||||
'train_product_model_with_transformer'
|
||||
]
|
||||
|
118
server/trainers/cnn_bilstm_attention_trainer.py
Normal file
118
server/trainers/cnn_bilstm_attention_trainer.py
Normal file
@ -0,0 +1,118 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
CNN-BiLSTM-Attention 模型训练器
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.optim as optim
|
||||
import numpy as np
|
||||
|
||||
from models.model_registry import register_trainer
|
||||
from utils.model_manager import model_manager
|
||||
from analysis.metrics import evaluate_model
|
||||
from utils.data_utils import create_dataset
|
||||
from sklearn.preprocessing import MinMaxScaler
|
||||
|
||||
# 导入新创建的模型
|
||||
from models.cnn_bilstm_attention import CnnBiLstmAttention
|
||||
|
||||
def train_with_cnn_bilstm_attention(product_id, model_identifier, product_df, store_id, training_mode, aggregation_method, epochs, sequence_length, forecast_horizon, model_dir, **kwargs):
|
||||
"""
|
||||
使用 CNN-BiLSTM-Attention 模型进行训练。
|
||||
函数签名遵循系统标准。
|
||||
"""
|
||||
print(f"🚀 CNN-BiLSTM-Attention 训练器启动: model_identifier='{model_identifier}'")
|
||||
|
||||
# --- 1. 数据准备 ---
|
||||
features = ['sales', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']
|
||||
|
||||
X = product_df[features].values
|
||||
y = product_df[['sales']].values
|
||||
|
||||
scaler_X = MinMaxScaler(feature_range=(0, 1))
|
||||
scaler_y = MinMaxScaler(feature_range=(0, 1))
|
||||
|
||||
X_scaled = scaler_X.fit_transform(X)
|
||||
y_scaled = scaler_y.fit_transform(y)
|
||||
|
||||
train_size = int(len(X_scaled) * 0.8)
|
||||
X_train_raw, X_test_raw = X_scaled[:train_size], X_scaled[train_size:]
|
||||
y_train_raw, y_test_raw = y_scaled[:train_size], y_scaled[train_size:]
|
||||
|
||||
trainX, trainY = create_dataset(X_train_raw, y_train_raw, sequence_length, forecast_horizon)
|
||||
testX, testY = create_dataset(X_test_raw, y_test_raw, sequence_length, forecast_horizon)
|
||||
|
||||
# 转换为 PyTorch Tensors
|
||||
trainX = torch.from_numpy(trainX).float()
|
||||
trainY = torch.from_numpy(trainY).float()
|
||||
testX = torch.from_numpy(testX).float()
|
||||
testY = torch.from_numpy(testY).float()
|
||||
|
||||
# --- 2. 实例化模型和优化器 ---
|
||||
input_dim = trainX.shape[2]
|
||||
|
||||
model = CnnBiLstmAttention(
|
||||
input_dim=input_dim,
|
||||
output_dim=forecast_horizon,
|
||||
sequence_length=sequence_length
|
||||
)
|
||||
|
||||
optimizer = optim.Adam(model.parameters(), lr=kwargs.get('learning_rate', 0.001))
|
||||
criterion = torch.nn.MSELoss()
|
||||
|
||||
# --- 3. 训练循环 ---
|
||||
print("开始训练 CNN-BiLSTM-Attention 模型...")
|
||||
for epoch in range(epochs):
|
||||
model.train()
|
||||
optimizer.zero_grad()
|
||||
|
||||
outputs = model(trainX)
|
||||
loss = criterion(outputs, trainY.squeeze(-1)) # 确保目标维度匹配
|
||||
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
if (epoch + 1) % 10 == 0:
|
||||
print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}')
|
||||
|
||||
# --- 4. 模型评估 ---
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
test_pred_scaled = model(testX)
|
||||
|
||||
test_pred_unscaled = scaler_y.inverse_transform(test_pred_scaled.numpy())
|
||||
test_true_unscaled = scaler_y.inverse_transform(testY.squeeze(-1).numpy())
|
||||
|
||||
metrics = evaluate_model(test_true_unscaled.flatten(), test_pred_unscaled.flatten())
|
||||
print(f"模型评估完成: RMSE={metrics['rmse']:.4f}")
|
||||
|
||||
# --- 5. 模型保存 ---
|
||||
model_data = {
|
||||
'model_state_dict': model.state_dict(),
|
||||
'scaler_X': scaler_X,
|
||||
'scaler_y': scaler_y,
|
||||
'config': {
|
||||
'model_type': 'cnn_bilstm_attention',
|
||||
'input_dim': input_dim,
|
||||
'output_dim': forecast_horizon,
|
||||
'sequence_length': sequence_length,
|
||||
'features': features
|
||||
},
|
||||
'metrics': metrics
|
||||
}
|
||||
|
||||
final_model_path, final_version = model_manager.save_model(
|
||||
model_data=model_data,
|
||||
product_id=product_id,
|
||||
model_type='cnn_bilstm_attention',
|
||||
store_id=store_id,
|
||||
training_mode=training_mode,
|
||||
aggregation_method=aggregation_method,
|
||||
product_name=product_df['product_name'].iloc[0]
|
||||
)
|
||||
|
||||
print(f"✅ CNN-BiLSTM-Attention 模型已保存,版本: {final_version}")
|
||||
return model, metrics, final_version, final_model_path
|
||||
|
||||
# --- 关键步骤: 将训练器注册到系统中 ---
|
||||
register_trainer('cnn_bilstm_attention', train_with_cnn_bilstm_attention)
|
@ -349,4 +349,9 @@ def train_product_model_with_kan(product_id, model_identifier, product_df=None,
|
||||
|
||||
print(f"最终模型已保存,版本: {final_version}, 路径: {final_model_path}")
|
||||
|
||||
return model, metrics
|
||||
return model, metrics
|
||||
|
||||
# --- 将此训练器注册到系统中 ---
|
||||
from models.model_registry import register_trainer
|
||||
register_trainer('kan', train_product_model_with_kan)
|
||||
register_trainer('optimized_kan', train_product_model_with_kan)
|
@ -514,4 +514,8 @@ def train_product_model_with_mlstm(
|
||||
|
||||
emit_progress(f"✅ mLSTM模型训练完成!版本 {final_version} 已保存", progress=100, metrics=final_metrics)
|
||||
|
||||
return model, metrics, epochs, final_model_path
|
||||
return model, metrics, epochs, final_model_path
|
||||
|
||||
# --- 将此训练器注册到系统中 ---
|
||||
from models.model_registry import register_trainer
|
||||
register_trainer('mlstm', train_product_model_with_mlstm)
|
@ -379,4 +379,8 @@ def train_product_model_with_tcn(
|
||||
|
||||
emit_progress(f"模型训练完成!版本 {final_version} 已保存", progress=100, metrics=final_metrics)
|
||||
|
||||
return model, metrics, epochs, final_model_path
|
||||
return model, metrics, epochs, final_model_path
|
||||
|
||||
# --- 将此训练器注册到系统中 ---
|
||||
from models.model_registry import register_trainer
|
||||
register_trainer('tcn', train_product_model_with_tcn)
|
@ -406,4 +406,8 @@ def train_product_model_with_transformer(
|
||||
'version': final_version
|
||||
}
|
||||
|
||||
return model, final_metrics, epochs
|
||||
return model, final_metrics, epochs
|
||||
|
||||
# --- 将此训练器注册到系统中 ---
|
||||
from models.model_registry import register_trainer
|
||||
register_trainer('transformer', train_product_model_with_transformer)
|
142
server/trainers/xgboost_trainer.py
Normal file
142
server/trainers/xgboost_trainer.py
Normal file
@ -0,0 +1,142 @@
|
||||
"""
|
||||
药店销售预测系统 - XGBoost 模型训练器 (插件式)
|
||||
"""
|
||||
|
||||
import time
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import xgboost as xgb
|
||||
from sklearn.preprocessing import MinMaxScaler
|
||||
from xgboost.callback import EarlyStopping
|
||||
|
||||
# 导入核心工具
|
||||
from utils.data_utils import create_dataset
|
||||
from analysis.metrics import evaluate_model
|
||||
from utils.model_manager import model_manager
|
||||
from models.model_registry import register_trainer
|
||||
|
||||
def train_product_model_with_xgboost(product_id, model_identifier, product_df, store_id, training_mode, aggregation_method, epochs, sequence_length, forecast_horizon, model_dir, **kwargs):
|
||||
"""
|
||||
使用 XGBoost 模型训练产品销售预测模型。
|
||||
此函数签名与其他训练器保持一致,以兼容注册表调用。
|
||||
"""
|
||||
print(f"🚀 XGBoost训练器启动: model_identifier='{model_identifier}'")
|
||||
|
||||
# --- 1. 数据准备和验证 ---
|
||||
if product_df.empty:
|
||||
raise ValueError(f"产品 {product_id} 没有可用的销售数据")
|
||||
|
||||
min_required_samples = sequence_length + forecast_horizon
|
||||
if len(product_df) < min_required_samples:
|
||||
error_msg = (f"数据不足: 需要 {min_required_samples} 条, 实际 {len(product_df)} 条。")
|
||||
raise ValueError(error_msg)
|
||||
|
||||
product_df = product_df.sort_values('date')
|
||||
product_name = product_df['product_name'].iloc[0] if 'product_name' in product_df.columns else model_identifier
|
||||
|
||||
# --- 2. 数据预处理和适配 ---
|
||||
features = ['sales', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']
|
||||
|
||||
X = product_df[features].values
|
||||
y = product_df[['sales']].values
|
||||
|
||||
scaler_X = MinMaxScaler(feature_range=(0, 1))
|
||||
scaler_y = MinMaxScaler(feature_range=(0, 1))
|
||||
|
||||
X_scaled = scaler_X.fit_transform(X)
|
||||
y_scaled = scaler_y.fit_transform(y)
|
||||
|
||||
train_size = int(len(X_scaled) * 0.8)
|
||||
X_train_raw, X_test_raw = X_scaled[:train_size], X_scaled[train_size:]
|
||||
y_train_raw, y_test_raw = y_scaled[:train_size], y_scaled[train_size:]
|
||||
|
||||
trainX, trainY = create_dataset(X_train_raw, y_train_raw, sequence_length, forecast_horizon)
|
||||
testX, testY = create_dataset(X_test_raw, y_test_raw, sequence_length, forecast_horizon)
|
||||
|
||||
# **关键适配步骤**: XGBoost 需要二维输入
|
||||
trainX = trainX.reshape(trainX.shape[0], -1)
|
||||
testX = testX.reshape(testX.shape[0], -1)
|
||||
|
||||
# **关键适配**: 转换为 XGBoost 核心 DMatrix 格式,以使用稳定的 xgb.train API
|
||||
dtrain = xgb.DMatrix(trainX, label=trainY)
|
||||
dtest = xgb.DMatrix(testX, label=testY)
|
||||
|
||||
# --- 3. 模型训练 (使用核心 xgb.train API) ---
|
||||
xgb_params = {
|
||||
'learning_rate': kwargs.get('learning_rate', 0.08),
|
||||
'subsample': kwargs.get('subsample', 0.75),
|
||||
'colsample_bytree': kwargs.get('colsample_bytree', 1),
|
||||
'max_depth': kwargs.get('max_depth', 7),
|
||||
'gamma': kwargs.get('gamma', 0),
|
||||
'objective': 'reg:squarederror',
|
||||
'eval_metric': 'rmse', # eval_metric 在这里是原生支持的
|
||||
'n_jobs': -1
|
||||
}
|
||||
n_estimators = kwargs.get('n_estimators', 500)
|
||||
|
||||
print("开始训练XGBoost模型 (使用核心xgb.train API)...")
|
||||
start_time = time.time()
|
||||
|
||||
evals_result = {}
|
||||
model = xgb.train(
|
||||
params=xgb_params,
|
||||
dtrain=dtrain,
|
||||
num_boost_round=n_estimators,
|
||||
evals=[(dtrain, 'train'), (dtest, 'test')],
|
||||
early_stopping_rounds=50, # early_stopping_rounds 在这里是原生支持的
|
||||
evals_result=evals_result,
|
||||
verbose_eval=False
|
||||
)
|
||||
|
||||
training_time = time.time() - start_time
|
||||
print(f"XGBoost模型训练完成,耗时: {training_time:.2f}秒")
|
||||
|
||||
# --- 4. 模型评估 ---
|
||||
# 使用 model.best_iteration 获取最佳轮次的预测结果
|
||||
test_pred = model.predict(dtest, iteration_range=(0, model.best_iteration))
|
||||
|
||||
test_pred_inv = scaler_y.inverse_transform(test_pred.reshape(-1, forecast_horizon))
|
||||
test_true_inv = scaler_y.inverse_transform(testY.reshape(-1, forecast_horizon))
|
||||
|
||||
metrics = evaluate_model(test_true_inv.flatten(), test_pred_inv.flatten())
|
||||
metrics['training_time'] = training_time
|
||||
|
||||
print("\n模型评估指标:")
|
||||
print(f"MSE: {metrics['mse']:.4f}, RMSE: {metrics['rmse']:.4f}, MAE: {metrics['mae']:.4f}, R²: {metrics['r2']:.4f}, MAPE: {metrics['mape']:.2f}%")
|
||||
|
||||
# --- 5. 模型保存 (借道 utils.model_manager) ---
|
||||
# **关键适配点**: 我们将完整的XGBoost模型对象存入字典
|
||||
# torch.save 可以序列化多种Python对象,包括sklearn模型
|
||||
model_data = {
|
||||
'model_state_dict': model, # 直接保存模型对象
|
||||
'scaler_X': scaler_X,
|
||||
'scaler_y': scaler_y,
|
||||
'config': {
|
||||
'sequence_length': sequence_length,
|
||||
'forecast_horizon': forecast_horizon,
|
||||
'model_type': 'xgboost',
|
||||
'features': features,
|
||||
'xgb_params': xgb_params
|
||||
},
|
||||
'metrics': metrics,
|
||||
'loss_history': evals_result
|
||||
}
|
||||
|
||||
# 调用全局管理器进行保存,复用其命名和版本逻辑
|
||||
final_model_path, final_version = model_manager.save_model(
|
||||
model_data=model_data,
|
||||
product_id=product_id,
|
||||
model_type='xgboost',
|
||||
store_id=store_id,
|
||||
training_mode=training_mode,
|
||||
aggregation_method=aggregation_method,
|
||||
product_name=product_name
|
||||
)
|
||||
|
||||
print(f"XGBoost模型已通过统一管理器保存,版本: {final_version}, 路径: {final_model_path}")
|
||||
|
||||
# 返回值遵循统一格式
|
||||
return model, metrics, final_version, final_model_path
|
||||
|
||||
# --- 将此训练器注册到系统中 ---
|
||||
register_trainer('xgboost', train_product_model_with_xgboost)
|
@ -280,48 +280,41 @@ class ModelManager:
|
||||
if len(parts) < 3:
|
||||
return None # 格式不符合基本要求
|
||||
|
||||
model_type = parts[0]
|
||||
mode = parts[1]
|
||||
|
||||
# **核心修复**: 采用更健壮的、从后往前的解析逻辑,以支持带下划线的模型名称
|
||||
try:
|
||||
if mode == 'store' and len(parts) >= 3:
|
||||
# {model_type}_store_{store_id}_{version}
|
||||
version = parts[-1]
|
||||
store_id = '_'.join(parts[2:-1])
|
||||
return {
|
||||
'model_type': model_type,
|
||||
'training_mode': 'store',
|
||||
'store_id': store_id,
|
||||
'version': version,
|
||||
'product_id': None,
|
||||
'aggregation_method': None
|
||||
}
|
||||
elif mode == 'global' and len(parts) >= 3:
|
||||
# {model_type}_global_{aggregation_method}_{version}
|
||||
version = parts[-1]
|
||||
aggregation_method = '_'.join(parts[2:-1])
|
||||
return {
|
||||
'model_type': model_type,
|
||||
'training_mode': 'global',
|
||||
'aggregation_method': aggregation_method,
|
||||
'version': version,
|
||||
'product_id': None,
|
||||
'store_id': None
|
||||
}
|
||||
elif mode == 'product' and len(parts) >= 3:
|
||||
# {model_type}_product_{product_id}_{version}
|
||||
version = parts[-1]
|
||||
product_id = '_'.join(parts[2:-1])
|
||||
version = parts[-1]
|
||||
identifier = parts[-2]
|
||||
mode_candidate = parts[-3]
|
||||
|
||||
if mode_candidate == 'product':
|
||||
model_type = '_'.join(parts[:-3])
|
||||
return {
|
||||
'model_type': model_type,
|
||||
'training_mode': 'product',
|
||||
'product_id': product_id,
|
||||
'product_id': identifier,
|
||||
'version': version,
|
||||
'store_id': None,
|
||||
'aggregation_method': None
|
||||
}
|
||||
elif mode_candidate == 'store':
|
||||
model_type = '_'.join(parts[:-3])
|
||||
return {
|
||||
'model_type': model_type,
|
||||
'training_mode': 'store',
|
||||
'store_id': identifier,
|
||||
'version': version,
|
||||
}
|
||||
elif mode_candidate == 'global':
|
||||
model_type = '_'.join(parts[:-3])
|
||||
return {
|
||||
'model_type': model_type,
|
||||
'training_mode': 'global',
|
||||
'aggregation_method': identifier,
|
||||
'version': version,
|
||||
}
|
||||
except IndexError:
|
||||
# 如果文件名部分不够,则解析失败
|
||||
pass
|
||||
except Exception as e:
|
||||
print(f"解析新版v2文件名失败 {filename}: {e}")
|
||||
print(f"解析文件名失败 {filename}: {e}")
|
||||
|
||||
return None
|
||||
|
||||
|
@ -1,5 +1,5 @@
|
||||
### 根目录启动
|
||||
`uv pip install loguru numpy pandas torch matplotlib flask flask_cors flask_socketio flasgger scikit-learn tqdm pytorch_tcn`
|
||||
`uv pip install loguru numpy pandas torch matplotlib flask flask_cors flask_socketio flasgger scikit-learn tqdm pytorch_tcn pyarrow`
|
||||
|
||||
### UI
|
||||
`npm install` `npm run dev`
|
||||
|
222
xz新模型添加流程.md
Normal file
222
xz新模型添加流程.md
Normal file
@ -0,0 +1,222 @@
|
||||
# 如何向系统添加新模型
|
||||
|
||||
本指南详细说明了如何向本预测系统添加一个全新的预测模型。系统采用灵活的插件式架构,集成新模型的过程非常模块化,主要围绕 **模型(Model)**、**训练器(Trainer)** 和 **预测器(Predictor)** 这三个核心组件进行。
|
||||
|
||||
## 核心理念
|
||||
|
||||
系统的核心是 `models/model_registry.py`,它维护了两个独立的注册表:一个用于训练函数,另一个用于预测函数。添加新模型的本质就是:
|
||||
|
||||
1. **定义模型**:创建模型的架构。
|
||||
2. **创建训练器**:编写一个函数来训练这个模型,并将其注册到训练器注册表。
|
||||
3. **集成预测器**:确保系统知道如何加载模型并用它来预测,然后将预测逻辑注册到预测器注册表。
|
||||
|
||||
---
|
||||
|
||||
## 第 1 步:定义模型架构
|
||||
|
||||
首先,您需要在 `ShopTRAINING/server/models/` 目录下创建一个新的 Python 文件来定义您的模型。
|
||||
|
||||
**示例:创建 `ShopTRAINING/server/models/my_new_model.py`**
|
||||
|
||||
如果您的新模型是基于 PyTorch 的,它应该是一个继承自 `torch.nn.Module` 的类。
|
||||
|
||||
```python
|
||||
# file: ShopTRAINING/server/models/my_new_model.py
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
class MyNewModel(nn.Module):
|
||||
def __init__(self, input_features, hidden_size, output_sequence_length):
|
||||
"""
|
||||
定义模型的层和结构。
|
||||
"""
|
||||
super(MyNewModel, self).__init__()
|
||||
self.layer1 = nn.Linear(input_features, hidden_size)
|
||||
self.relu = nn.ReLU()
|
||||
self.layer2 = nn.Linear(hidden_size, output_sequence_length)
|
||||
# ... 可添加更复杂的结构
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
定义数据通过模型的前向传播路径。
|
||||
x 的形状通常是 (batch_size, sequence_length, num_features)
|
||||
"""
|
||||
# 确保输入是正确的形状
|
||||
# 例如,对于简单的线性层,可能需要展平
|
||||
batch_size, seq_len, features = x.shape
|
||||
x = x.view(batch_size * seq_len, features) # 展平
|
||||
|
||||
out = self.layer1(x)
|
||||
out = self.relu(out)
|
||||
out = self.layer2(out)
|
||||
|
||||
# 恢复形状以匹配输出
|
||||
out = out.view(batch_size, seq_len, -1)
|
||||
# 通常我们只关心序列的最后一个预测
|
||||
return out[:, -1, :]
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 第 2 步:创建模型训练器
|
||||
|
||||
接下来,在 `ShopTRAINING/server/trainers/` 目录下创建一个新的训练器文件。这个文件负责模型的整个训练、评估和保存流程。
|
||||
|
||||
**示例:创建 `ShopTRAINING/server/trainers/my_new_model_trainer.py`**
|
||||
|
||||
这个训练函数需要遵循系统中其他训练器(如 `xgboost_trainer.py`)的统一函数签名,并使用 `@register_trainer` 装饰器或在文件末尾调用 `register_trainer` 函数。
|
||||
|
||||
```python
|
||||
# file: ShopTRAINING/server/trainers/my_new_model_trainer.py
|
||||
|
||||
import torch
|
||||
import torch.optim as optim
|
||||
from models.model_registry import register_trainer
|
||||
from utils.model_manager import model_manager
|
||||
from analysis.metrics import evaluate_model
|
||||
from models.my_new_model import MyNewModel # 导入您的新模型
|
||||
|
||||
# 遵循系统的标准函数签名
|
||||
def train_with_mynewmodel(product_id, model_identifier, product_df, store_id, training_mode, aggregation_method, epochs, sequence_length, forecast_horizon, model_dir, **kwargs):
|
||||
print(f"🚀 MyNewModel 训练器启动: model_identifier='{model_identifier}'")
|
||||
|
||||
# --- 1. 数据准备 ---
|
||||
# (此处省略了数据加载、标准化和创建数据集的详细代码,
|
||||
# 您可以参考 xgboost_trainer.py 或其他训练器中的实现)
|
||||
# ...
|
||||
# 假设您已准备好 trainX, trainY, testX, testY, scaler_y 等变量
|
||||
# trainX = ...
|
||||
# trainY = ...
|
||||
# testX = ...
|
||||
# testY = ...
|
||||
# scaler_y = ...
|
||||
# features = [...]
|
||||
|
||||
# --- 2. 实例化模型和优化器 ---
|
||||
input_dim = trainX.shape[2] # 获取特征数量
|
||||
hidden_size = 64 # 示例超参数
|
||||
|
||||
model = MyNewModel(
|
||||
input_features=input_dim,
|
||||
hidden_size=hidden_size,
|
||||
output_sequence_length=forecast_horizon
|
||||
)
|
||||
optimizer = optim.Adam(model.parameters(), lr=0.001)
|
||||
criterion = torch.nn.MSELoss()
|
||||
|
||||
# --- 3. 训练循环 ---
|
||||
print("开始训练 MyNewModel...")
|
||||
for epoch in range(epochs):
|
||||
model.train()
|
||||
optimizer.zero_grad()
|
||||
outputs = model(trainX)
|
||||
loss = criterion(outputs, trainY)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
if (epoch + 1) % 10 == 0:
|
||||
print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}')
|
||||
|
||||
# --- 4. 模型评估 ---
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
test_pred_scaled = model(testX)
|
||||
|
||||
# 反标准化并计算指标
|
||||
# ... (参考其他训练器)
|
||||
metrics = {'rmse': 0.0, 'mae': 0.0, 'r2': 0.0, 'mape': 0.0} # 示例
|
||||
|
||||
# --- 5. 模型保存 ---
|
||||
model_data = {
|
||||
'model_state_dict': model.state_dict(),
|
||||
'scaler_X': None, # 替换为您的 scaler_X
|
||||
'scaler_y': scaler_y,
|
||||
'config': {
|
||||
'model_type': 'mynewmodel', # **关键**: 使用唯一的模型名称
|
||||
'input_dim': input_dim,
|
||||
'hidden_size': hidden_size,
|
||||
'sequence_length': sequence_length,
|
||||
'forecast_horizon': forecast_horizon,
|
||||
'features': features
|
||||
},
|
||||
'metrics': metrics
|
||||
}
|
||||
|
||||
model_manager.save_model(
|
||||
model_data=model_data,
|
||||
product_id=product_id,
|
||||
model_type='mynewmodel', # **关键**: 再次确认模型名称
|
||||
# ... 其他参数
|
||||
)
|
||||
|
||||
print("✅ MyNewModel 模型训练并保存完成!")
|
||||
return model, metrics, "v1", "path/to/model" # 返回值遵循统一格式
|
||||
|
||||
# --- 关键步骤: 将训练器注册到系统中 ---
|
||||
register_trainer('mynewmodel', train_with_mynewmodel)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 第 3 步:集成模型预测器
|
||||
|
||||
最后,您需要让系统知道如何加载和使用您的新模型进行预测。这需要在 `ShopTRAINING/server/predictors/model_predictor.py` 中进行两处修改。
|
||||
|
||||
**文件: `ShopTRAINING/server/predictors/model_predictor.py`**
|
||||
|
||||
1. **让系统知道如何构建您的模型实例**
|
||||
|
||||
在 `load_model_and_predict` 函数中,有一个 `if/elif` 结构用于根据模型类型实例化不同的模型。您需要为 `MyNewModel` 添加一个新的分支。
|
||||
|
||||
```python
|
||||
# 在 model_predictor.py 中
|
||||
|
||||
# 首先,导入您的新模型类
|
||||
from models.my_new_model import MyNewModel
|
||||
|
||||
# ... 在 load_model_and_predict 函数内部 ...
|
||||
|
||||
# ... 其他模型的 elif 分支 ...
|
||||
elif loaded_model_type == 'tcn':
|
||||
model = TCNForecaster(...)
|
||||
|
||||
# vvv 添加这个新的分支 vvv
|
||||
elif loaded_model_type == 'mynewmodel':
|
||||
model = MyNewModel(
|
||||
input_features=config['input_dim'],
|
||||
hidden_size=config['hidden_size'],
|
||||
output_sequence_length=config['forecast_horizon']
|
||||
).to(DEVICE)
|
||||
# ^^^ 添加结束 ^^^
|
||||
|
||||
else:
|
||||
raise ValueError(f"不支持的模型类型: {loaded_model_type}")
|
||||
|
||||
model.load_state_dict(checkpoint['model_state_dict'])
|
||||
model.eval()
|
||||
```
|
||||
|
||||
2. **注册预测逻辑**
|
||||
|
||||
如果您的模型是一个标准的 PyTorch 模型,并且其预测逻辑与现有的模型(如 Transformer, KAN)相同,您可以直接复用 `default_pytorch_predictor`。只需在文件末尾添加一行注册代码即可。
|
||||
|
||||
```python
|
||||
# 在 model_predictor.py 文件末尾
|
||||
|
||||
# ...
|
||||
# 将增强后的默认预测器也注册给xgboost
|
||||
register_predictor('xgboost', default_pytorch_predictor)
|
||||
|
||||
# vvv 添加这行代码 vvv
|
||||
# 让 'mynewmodel' 也使用通用的 PyTorch 预测器
|
||||
register_predictor('mynewmodel', default_pytorch_predictor)
|
||||
# ^^^ 添加结束 ^^^
|
||||
```
|
||||
|
||||
如果您的模型需要特殊的预测逻辑(例如,像 XGBoost 那样有不同的输入格式或调用方式),您可以复制 `default_pytorch_predictor` 创建一个新函数,修改其内部逻辑,然后将新函数注册给 `'mynewmodel'`。
|
||||
|
||||
---
|
||||
|
||||
## 总结
|
||||
|
||||
完成以上三个步骤后,您的新模型 `MyNewModel` 就已完全集成到系统中了。系统会自动在 `trainers` 目录中发现您的新训练器。当您通过 API 或界面选择 `mynewmodel` 进行训练和预测时,系统将自动调用您刚刚编写和注册的所有相应逻辑。
|
Loading…
x
Reference in New Issue
Block a user