插件式添加模型

xz2000 2025-07-22 15:40:37 +08:00
parent 038289ae32
commit 751de9b548
19 changed files with 940 additions and 293 deletions

Binary file not shown.

View File

@ -56,5 +56,6 @@ tzdata==2025.2
werkzeug==3.1.3
win32-setctime==1.2.0
wsproto==1.2.0
python-dateutil
xgboost
scikit-learn

View File

@ -45,6 +45,7 @@ from trainers.mlstm_trainer import train_product_model_with_mlstm
from trainers.kan_trainer import train_product_model_with_kan
from trainers.tcn_trainer import train_product_model_with_tcn
from trainers.transformer_trainer import train_product_model_with_transformer
from trainers.xgboost_trainer import train_product_model_with_xgboost
# 导入预测函数
from predictors.model_predictor import load_model_and_predict
@ -810,7 +811,7 @@ def get_all_training_tasks():
'type': 'object',
'properties': {
'product_id': {'type': 'string', 'description': '例如 P001'},
'model_type': {'type': 'string', 'enum': ['mlstm', 'transformer', 'kan', 'optimized_kan', 'tcn'], 'description': '要训练的模型类型'},
'model_type': {'type': 'string', 'enum': ['mlstm', 'transformer', 'kan', 'optimized_kan', 'tcn', 'xgboost'], 'description': '要训练的模型类型'},
'store_id': {'type': 'string', 'description': '店铺ID如 S001。为空时使用全局聚合数据'},
'epochs': {'type': 'integer', 'default': 50, 'description': '训练轮次'}
},
@ -873,10 +874,10 @@ def start_training():
# 全局模式不需要特定的product_id或store_id
pass
# 检查模型类型是否有效
valid_model_types = ['mlstm', 'kan', 'optimized_kan', 'transformer', 'tcn']
if model_type not in valid_model_types:
return jsonify({'error': '无效的模型类型'}), 400
# 检查模型类型是否有效 (v2 - 动态检查)
from models.model_registry import TRAINER_REGISTRY
if model_type not in TRAINER_REGISTRY:
return jsonify({'error': f"无效的模型类型: '{model_type}'. 可用模型: {list(TRAINER_REGISTRY.keys())}"}), 400
# 使用新的训练进程管理器提交任务
try:
@ -3445,41 +3446,37 @@ def analyze_model_metrics():
}
})
def get_model_types():
"""获取系统支持的所有模型类型"""
model_types = [
{
'id': 'mlstm',
'name': 'mLSTM',
'description': '矩阵长短期记忆网络,适合处理多变量时序数据',
'tag_type': 'primary'
},
{
'id': 'transformer',
'name': 'Transformer',
'description': '基于注意力机制的序列模型,适合捕捉长期依赖关系',
'tag_type': 'success'
},
{
'id': 'kan',
'name': 'KAN',
'description': 'Kolmogorov-Arnold网络能够逼近任意连续函数',
'tag_type': 'warning'
},
{
'id': 'optimized_kan',
'name': '优化版KAN',
'description': '经过优化的KAN模型提供更高的预测精度和训练效率',
'tag_type': 'info'
},
{
'id': 'tcn',
'name': 'TCN',
'description': '时间卷积网络,适合处理长序列和平行计算',
'tag_type': 'danger'
}
]
"""获取系统支持的所有模型类型 (v2 - 动态加载)"""
from models.model_registry import TRAINER_REGISTRY
return jsonify({"status": "success", "data": model_types})
# 预定义的模型元数据,用于美化显示
model_meta = {
'mlstm': {'name': 'mLSTM', 'description': '矩阵长短期记忆网络,适合处理多变量时序数据', 'tag_type': 'primary'},
'transformer': {'name': 'Transformer', 'description': '基于注意力机制的序列模型,适合捕捉长期依赖关系', 'tag_type': 'success'},
'kan': {'name': 'KAN', 'description': 'Kolmogorov-Arnold网络,能够逼近任意连续函数', 'tag_type': 'warning'},
'optimized_kan': {'name': '优化版KAN', 'description': '经过优化的KAN模型,提供更高的预测精度和训练效率', 'tag_type': 'info'},
'tcn': {'name': 'TCN', 'description': '时间卷积网络,适合处理长序列和并行计算', 'tag_type': 'danger'},
'xgboost': {'name': 'XGBoost', 'description': '梯度提升决策树,性能强大且高效的经典模型', 'tag_type': 'primary'}
}
# 从注册表动态获取所有已注册的模型ID
registered_models = TRAINER_REGISTRY.keys()
dynamic_model_types = []
for model_id in registered_models:
meta = model_meta.get(model_id, {
'name': model_id.upper(),
'description': f'自定义模型: {model_id}',
'tag_type': 'secondary'
})
dynamic_model_types.append({
'id': model_id,
'name': meta['name'],
'description': meta['description'],
'tag_type': meta['tag_type']
})
return jsonify({"status": "success", "data": dynamic_model_types})
# ========== 新增版本管理API ==========

View File

@ -58,7 +58,9 @@ HIDDEN_SIZE = 64 # 隐藏层大小
NUM_LAYERS = 2 # 层数
# 支持的模型类型
SUPPORTED_MODELS = ['mlstm', 'kan', 'transformer', 'tcn', 'optimized_kan']
# 支持的模型类型 (v2 - 动态加载)
from models.model_registry import TRAINER_REGISTRY
SUPPORTED_MODELS = list(TRAINER_REGISTRY.keys())
# 版本管理配置
MODEL_VERSION_PREFIX = 'v' # 版本前缀

View File

@ -11,12 +11,13 @@ import time
import matplotlib.pyplot as plt
from datetime import datetime
from trainers import (
train_product_model_with_mlstm,
train_product_model_with_kan,
train_product_model_with_tcn,
train_product_model_with_transformer
)
# from trainers import (
# train_product_model_with_mlstm,
# train_product_model_with_kan,
# train_product_model_with_tcn,
# train_product_model_with_transformer
# )
# 上述导入已不再需要,因为我们现在通过模型注册表动态获取训练器
from predictors.model_predictor import load_model_and_predict
from utils.data_utils import prepare_data, prepare_sequences
from utils.multi_store_data_utils import (
@ -187,89 +188,49 @@ class PharmacyPredictor:
else: # product mode
model_identifier = product_id
# 调用相应的训练函数
# 调用相应的训练函数 (重构为使用注册表)
try:
log_message(f"🤖 开始调用 {model_type} 训练器")
if model_type == 'transformer':
model_result, metrics, actual_version = train_product_model_with_transformer(
product_id=product_id, # product_id 仍然需要,用于数据过滤
model_identifier=model_identifier, # 这是用于保存模型的唯一ID
product_df=product_data,
store_id=store_id,
training_mode=training_mode,
aggregation_method=aggregation_method,
epochs=epochs,
sequence_length=sequence_length,
forecast_horizon=forecast_horizon,
model_dir=self.model_dir,
version=version,
socketio=socketio,
task_id=task_id,
continue_training=continue_training
)
log_message(f"{model_type} 训练器返回: metrics={type(metrics)}, version={actual_version}", 'success')
elif model_type == 'mlstm':
_, metrics, _, _ = train_product_model_with_mlstm(
product_id=product_id,
model_identifier=model_identifier, # 传递修正后的ID
product_df=product_data,
store_id=store_id,
training_mode=training_mode,
aggregation_method=aggregation_method,
epochs=epochs,
sequence_length=sequence_length,
forecast_horizon=forecast_horizon,
model_dir=self.model_dir,
socketio=socketio,
task_id=task_id,
progress_callback=progress_callback
)
elif model_type == 'kan':
_, metrics = train_product_model_with_kan(
product_id=product_id,
model_identifier=model_identifier, # 传递修正后的ID
product_df=product_data,
store_id=store_id,
training_mode=training_mode,
aggregation_method=aggregation_method,
epochs=epochs,
sequence_length=sequence_length,
forecast_horizon=forecast_horizon,
use_optimized=use_optimized,
model_dir=self.model_dir
)
elif model_type == 'optimized_kan':
_, metrics = train_product_model_with_kan(
product_id=product_id,
model_identifier=model_identifier, # 传递修正后的ID
product_df=product_data,
store_id=store_id,
training_mode=training_mode,
aggregation_method=aggregation_method,
epochs=epochs,
sequence_length=sequence_length,
forecast_horizon=forecast_horizon,
use_optimized=True,
model_dir=self.model_dir
)
elif model_type == 'tcn':
_, metrics, _, _ = train_product_model_with_tcn(
product_id=product_id,
model_identifier=model_identifier, # 传递修正后的ID
product_df=product_data,
store_id=store_id,
training_mode=training_mode,
aggregation_method=aggregation_method,
epochs=epochs,
sequence_length=sequence_length,
forecast_horizon=forecast_horizon,
model_dir=self.model_dir,
socketio=socketio,
task_id=task_id
)
from models.model_registry import get_trainer
log_message(f"🤖 正在从注册表获取 '{model_type}' 训练器...")
trainer_function = get_trainer(model_type)
log_message(f"✅ 成功获取训练器: {trainer_function.__name__}")
# 准备通用参数
trainer_args = {
'product_id': product_id,
'model_identifier': model_identifier,
'product_df': product_data,
'store_id': store_id,
'training_mode': training_mode,
'aggregation_method': aggregation_method,
'epochs': epochs,
'sequence_length': sequence_length,
'forecast_horizon': forecast_horizon,
'model_dir': self.model_dir,
'socketio': socketio,
'task_id': task_id,
'progress_callback': progress_callback,
'version': version,
'continue_training': continue_training,
'use_optimized': use_optimized # KAN模型需要
}
# 动态调用训练函数 (v2 - 智能参数过滤)
import inspect
sig = inspect.signature(trainer_function)
valid_args = {k: v for k, v in trainer_args.items() if k in sig.parameters}
log_message(f"🔍 准备调用 {trainer_function.__name__},有效参数: {list(valid_args.keys())}")
result = trainer_function(**valid_args)
# 根据返回值的数量解析metrics
if isinstance(result, tuple) and len(result) >= 2:
metrics = result[1] # 通常第二个返回值是metrics
else:
log_message(f"不支持的模型类型: {model_type}", 'error')
return None
log_message(f"⚠️ 训练器返回格式未知无法直接提取metrics: {type(result)}", 'warning')
metrics = None
# 检查和打印返回的metrics
log_message(f"📊 训练完成检查返回的metrics: {metrics}")

View File

@ -0,0 +1,102 @@
# -*- coding: utf-8 -*-
"""
CNN-BiLSTM-Attention 模型定义,适配药店销售预测系统
原始代码来源: python机器学习回归全家桶
"""
import torch
import torch.nn as nn
# 注意:由于原始代码使用了 TensorFlow/Keras 的层,我们将在这里创建一个 PyTorch 的等效实现。
# 这是一个更健壮、更符合现有系统架构的做法。
class Attention(nn.Module):
"""
PyTorch 实现的注意力机制
"""
def __init__(self, feature_dim, step_dim, bias=True, **kwargs):
super(Attention, self).__init__(**kwargs)
self.supports_masking = True
self.bias = bias
self.feature_dim = feature_dim
self.step_dim = step_dim
self.features_dim = 0
weight = torch.zeros(feature_dim, 1)
nn.init.xavier_uniform_(weight)
self.weight = nn.Parameter(weight)
if bias:
self.b = nn.Parameter(torch.zeros(step_dim))
def forward(self, x, mask=None):
feature_dim = self.feature_dim
step_dim = self.step_dim
eij = torch.mm(
x.contiguous().view(-1, feature_dim),
self.weight
).view(-1, step_dim)
if self.bias:
eij = eij + self.b
eij = torch.tanh(eij)
a = torch.exp(eij)
if mask is not None:
a = a * mask
a = a / (torch.sum(a, 1, keepdim=True) + 1e-10)
weighted_input = x * torch.unsqueeze(a, -1)
return torch.sum(weighted_input, 1)
class CnnBiLstmAttention(nn.Module):
"""
CNN-BiLSTM-Attention 模型的 PyTorch 实现
"""
def __init__(self, input_dim, output_dim, sequence_length, cnn_filters=64, cnn_kernel_size=1, lstm_units=128):
super(CnnBiLstmAttention, self).__init__()
self.sequence_length = sequence_length
self.cnn_filters = cnn_filters
self.lstm_units = lstm_units
# CNN 层
self.conv1d = nn.Conv1d(in_channels=input_dim, out_channels=cnn_filters, kernel_size=cnn_kernel_size)
self.relu = nn.ReLU()
self.maxpool = nn.MaxPool1d(kernel_size=1)
# BiLSTM 层
self.bilstm = nn.LSTM(input_size=cnn_filters, hidden_size=lstm_units, num_layers=1, batch_first=True, bidirectional=True)
# Attention 层
self.attention = Attention(feature_dim=lstm_units * 2, step_dim=sequence_length)
# 全连接输出层
self.dense = nn.Linear(lstm_units * 2, output_dim)
def forward(self, x):
# 输入 x 的形状: (batch_size, sequence_length, input_dim)
# CNN 处理
x = x.permute(0, 2, 1) # 转换为 (batch_size, input_dim, sequence_length) 以适应 Conv1d
x = self.conv1d(x)
x = self.relu(x)
x = x.permute(0, 2, 1) # 转换回 (batch_size, sequence_length, cnn_filters)
# BiLSTM 处理
lstm_out, _ = self.bilstm(x) # lstm_out 形状: (batch_size, sequence_length, lstm_units * 2)
# Attention 处理
# 注意:这里的 Attention 实现可能需要根据具体任务微调
# 一个简化的方法是直接使用 LSTM 的最终隐藏状态或输出
# 这里我们先用一个简化的逻辑:直接展平 LSTM 输出
attention_out = self.attention(lstm_out)
# 全连接层输出
output = self.dense(attention_out)
return output
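下面是一个可选的形状自检示意(数值均为示例,假设将其附加到本文件末尾运行;`torch` 已在文件顶部导入),用于确认前向传播输出维度为 (batch_size, output_dim):

```python
# 形状自检示意:用随机输入跑一次前向传播
if __name__ == '__main__':
    batch_size, seq_len, input_dim, output_dim = 4, 30, 7, 7
    demo_model = CnnBiLstmAttention(input_dim=input_dim, output_dim=output_dim, sequence_length=seq_len)
    dummy_input = torch.randn(batch_size, seq_len, input_dim)
    with torch.no_grad():
        demo_output = demo_model(dummy_input)
    print(demo_output.shape)  # 预期: torch.Size([4, 7])
```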

View File

@ -0,0 +1,64 @@
"""
模型注册表
用于解耦模型的调用和实现,支持插件式扩展新模型。
"""
# 训练器注册表
TRAINER_REGISTRY = {}
def register_trainer(name, func):
"""
注册一个模型训练器
参数:
name (str): 模型类型名称 (e.g., 'xgboost')
func (function): 对应的训练函数
"""
if name in TRAINER_REGISTRY:
print(f"警告: 模型训练器 '{name}' 已被覆盖注册。")
TRAINER_REGISTRY[name] = func
print(f"✅ 已注册训练器: {name}")
def get_trainer(name):
"""
根据模型类型名称获取一个已注册的训练器
"""
if name not in TRAINER_REGISTRY:
# 在打印可用训练器之前,确保它们已经被加载
from trainers import discover_trainers
discover_trainers()
if name not in TRAINER_REGISTRY:
raise ValueError(f"未注册的模型训练器: '{name}'. 可用: {list(TRAINER_REGISTRY.keys())}")
return TRAINER_REGISTRY[name]
# --- 预测器注册表 ---
# 预测器函数需要一个统一的接口,例如:
# def predictor_function(model, checkpoint, **kwargs): -> predictions
PREDICTOR_REGISTRY = {}
def register_predictor(name, func):
"""
注册一个模型预测器
"""
if name in PREDICTOR_REGISTRY:
print(f"警告: 模型预测器 '{name}' 已被覆盖注册。")
PREDICTOR_REGISTRY[name] = func
def get_predictor(name):
"""
根据模型类型名称获取一个已注册的预测器
如果找不到特定预测器,则返回默认的预测器。
"""
return PREDICTOR_REGISTRY.get(name, PREDICTOR_REGISTRY.get('default'))
# 默认的PyTorch预测逻辑可以被注册为 'default'
def register_default_predictors():
from predictors.model_predictor import default_pytorch_predictor
register_predictor('default', default_pytorch_predictor)
# 如果其他PyTorch模型有特殊预测逻辑,也可以在这里注册
# register_predictor('kan', kan_predictor_func)
# 注意:这个函数的调用时机很重要,需要在应用启动时执行一次。
# 我们可以暂时在 model_predictor.py 导入注册表后调用它。
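一个最小的使用示意(其中 'demo' 训练器仅为虚构示例,不属于本次提交;返回值格式按本提交约定的四元组给出):

```python
# 最小使用示意:注册一个训练器,然后按名称取回并调用
from models.model_registry import register_trainer, get_trainer

def train_demo_model(product_id, epochs=10, **kwargs):
    print(f"training demo model for {product_id}, epochs={epochs}")
    return None, {'rmse': 0.0}, 'v1', '/tmp/demo_model.pth'

register_trainer('demo', train_demo_model)   # 通常写在训练器模块末尾,模块被导入时即完成注册
trainer_fn = get_trainer('demo')             # 调用方只依赖名称,不依赖具体实现
trainer_fn(product_id='P001', epochs=5)
```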

Binary file not shown.

View File

@ -18,39 +18,90 @@ from models.mlstm_model import MLSTMTransformer as MatrixLSTM
from models.kan_model import KANForecaster
from models.tcn_model import TCNForecaster
from models.optimized_kan_forecaster import OptimizedKANForecaster
from models.cnn_bilstm_attention import CnnBiLstmAttention
import xgboost as xgb
from analysis.trend_analysis import analyze_prediction_result
from utils.visualization import plot_prediction_results
from utils.multi_store_data_utils import get_store_product_sales_data, aggregate_multi_store_data
from core.config import DEVICE, get_model_file_path, DEFAULT_DATA_PATH
from models.model_registry import get_predictor, register_predictor
def default_pytorch_predictor(model, checkpoint, product_df, future_days, start_date, history_lookback_days):
"""
默认的PyTorch模型预测逻辑,支持自动回归
"""
config = checkpoint['config']
scaler_X = checkpoint['scaler_X']
scaler_y = checkpoint['scaler_y']
features = config.get('features', ['sales', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature'])
sequence_length = config['sequence_length']
if start_date:
start_date_dt = pd.to_datetime(start_date)
prediction_input_df = product_df[product_df['date'] < start_date_dt].tail(sequence_length)
else:
prediction_input_df = product_df.tail(sequence_length)
start_date_dt = product_df['date'].iloc[-1] + timedelta(days=1)
if len(prediction_input_df) < sequence_length:
raise ValueError(f"预测所需的历史数据不足。需要 {sequence_length} 天, 但只有 {len(prediction_input_df)} 天。")
history_for_chart_df = product_df[product_df['date'] < start_date_dt].tail(history_lookback_days)
all_predictions = []
current_sequence_df = prediction_input_df.copy()
for _ in range(future_days):
X_current_scaled = scaler_X.transform(current_sequence_df[features].values)
# **核心改进**: 智能判断模型类型并调用相应的预测方法
if isinstance(model, xgb.Booster):
# XGBoost 模型预测路径
X_input_reshaped = X_current_scaled.reshape(1, -1)
d_input = xgb.DMatrix(X_input_reshaped)
# **关键修复**: 使用 best_iteration 进行预测,以匹配早停策略
y_pred_scaled = model.predict(d_input, iteration_range=(0, model.best_iteration))
next_step_pred_scaled = y_pred_scaled.reshape(1, -1)
else:
# 默认 PyTorch 模型预测路径
X_input = torch.tensor(X_current_scaled.reshape(1, sequence_length, -1), dtype=torch.float32).to(DEVICE)
with torch.no_grad():
y_pred_scaled = model(X_input).cpu().numpy()
next_step_pred_scaled = y_pred_scaled[0, 0].reshape(1, -1)
next_step_pred_unscaled = float(max(0, scaler_y.inverse_transform(next_step_pred_scaled)[0][0]))
next_date = current_sequence_df['date'].iloc[-1] + timedelta(days=1)
all_predictions.append({'date': next_date, 'predicted_sales': next_step_pred_unscaled})
new_row = {'date': next_date, 'sales': next_step_pred_unscaled, 'weekday': next_date.weekday(), 'month': next_date.month, 'is_holiday': 0, 'is_weekend': 1 if next_date.weekday() >= 5 else 0, 'is_promotion': 0, 'temperature': current_sequence_df['temperature'].iloc[-1]}
new_row_df = pd.DataFrame([new_row])
current_sequence_df = pd.concat([current_sequence_df.iloc[1:], new_row_df], ignore_index=True)
predictions_df = pd.DataFrame(all_predictions)
return predictions_df, history_for_chart_df, prediction_input_df
# 注册默认的PyTorch预测器
register_predictor('default', default_pytorch_predictor)
# 将增强后的默认预测器也注册给xgboost
register_predictor('xgboost', default_pytorch_predictor)
# 将新模型也注册给默认预测器
register_predictor('cnn_bilstm_attention', default_pytorch_predictor)
def load_model_and_predict(model_path: str, product_id: str, model_type: str, store_id: Optional[str] = None, future_days: int = 7, start_date: Optional[str] = None, analyze_result: bool = False, version: Optional[str] = None, training_mode: str = 'product', history_lookback_days: int = 30):
"""
加载已训练的模型并进行预测 (v3版 - 支持自动回归)
参数:
... (同上, 新增 history_lookback_days)
history_lookback_days: 用于图表展示的历史数据天数
返回:
预测结果和分析
加载已训练的模型并进行预测 (v4版 - 插件式架构)
"""
try:
print(f"v3版预测函数启动模型路径: {model_path}, 预测天数: {future_days}, 历史回看: {history_lookback_days}")
if not os.path.exists(model_path):
print(f"模型文件 {model_path} 不存在")
return None
# 加载销售数据
raise FileNotFoundError(f"模型文件 {model_path} 不存在")
# --- 数据加载部分保持不变 ---
from utils.multi_store_data_utils import aggregate_multi_store_data
if training_mode == 'store' and store_id:
# 先从原始数据加载一次以获取店铺名称,聚合会丢失此信息
from utils.multi_store_data_utils import load_multi_store_data
store_df_for_name = load_multi_store_data(store_id=store_id)
product_name = store_df_for_name['store_name'].iloc[0] if not store_df_for_name.empty else f"店铺 {store_id}"
# 然后再进行聚合获取用于预测的数据
product_df = aggregate_multi_store_data(store_id=store_id, aggregation_method='sum', file_path=DEFAULT_DATA_PATH)
elif training_mode == 'global':
product_df = aggregate_multi_store_data(aggregation_method='sum', file_path=DEFAULT_DATA_PATH)
@ -60,124 +111,75 @@ def load_model_and_predict(model_path: str, product_id: str, model_type: str, st
product_name = product_df['product_name'].iloc[0] if not product_df.empty else product_id
if product_df.empty:
print(f"产品 {product_id} 或店铺 {store_id} 没有销售数据")
return None
raise ValueError(f"产品 {product_id} 或店铺 {store_id} 没有销售数据")
# 加载模型和配置
# --- 模型加载与实例化 (重构) ---
try:
torch.serialization.add_safe_globals([sklearn.preprocessing._data.MinMaxScaler])
except Exception: pass
checkpoint = torch.load(model_path, map_location=DEVICE, weights_only=False)
if 'config' not in checkpoint or 'scaler_X' not in checkpoint or 'scaler_y' not in checkpoint:
print("模型文件不完整缺少config或scaler")
return None
config = checkpoint['config']
scaler_X = checkpoint['scaler_X']
scaler_y = checkpoint['scaler_y']
# 创建模型实例
# (此处省略了与原版本相同的模型创建代码,以保持简洁)
if model_type == 'transformer':
model = TimeSeriesTransformer(num_features=config['input_dim'], d_model=config['hidden_size'], nhead=config['num_heads'], num_encoder_layers=config['num_layers'], dim_feedforward=config['hidden_size'] * 2, dropout=config['dropout'], output_sequence_length=config['output_dim'], seq_length=config['sequence_length'], batch_size=32).to(DEVICE)
elif model_type == 'mlstm':
model = MatrixLSTM(num_features=config['input_dim'], hidden_size=config['hidden_size'], mlstm_layers=config['mlstm_layers'], embed_dim=config.get('embed_dim', 32), dense_dim=config.get('dense_dim', 32), num_heads=config.get('num_heads', 4), dropout_rate=config['dropout_rate'], num_blocks=config.get('num_blocks', 3), output_sequence_length=config['output_dim']).to(DEVICE)
elif model_type == 'kan':
model = KANForecaster(input_features=config['input_dim'], hidden_sizes=[config['hidden_size'], config['hidden_size']*2, config['hidden_size']], output_sequence_length=config['output_dim']).to(DEVICE)
elif model_type == 'optimized_kan':
model = OptimizedKANForecaster(input_features=config['input_dim'], hidden_sizes=[config['hidden_size'], config['hidden_size']*2, config['hidden_size']], output_sequence_length=config['output_dim']).to(DEVICE)
elif model_type == 'tcn':
model = TCNForecaster(num_features=config['input_dim'], output_sequence_length=config['output_dim'], num_channels=[config['hidden_size']] * config['num_layers'], kernel_size=config['kernel_size'], dropout=config['dropout']).to(DEVICE)
config = checkpoint.get('config', {})
loaded_model_type = config.get('model_type', model_type) # 优先使用模型内保存的类型
# 根据模型类型决定如何获取模型实例
if loaded_model_type == 'xgboost':
# 对于XGBoost, 模型对象直接保存在'model_state_dict'键中
model = checkpoint['model_state_dict']
else:
print(f"不支持的模型类型: {model_type}"); return None
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()
# --- 核心逻辑修改:自动回归预测 ---
features = ['sales', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']
sequence_length = config['sequence_length']
# 确定预测的起始点
if start_date:
start_date_dt = pd.to_datetime(start_date)
# 获取预测开始日期前的 `sequence_length` 天数据作为初始输入
prediction_input_df = product_df[product_df['date'] < start_date_dt].tail(sequence_length)
else:
# 如果未指定开始日期,则从数据的最后一天开始预测
prediction_input_df = product_df.tail(sequence_length)
start_date_dt = product_df['date'].iloc[-1] + timedelta(days=1)
if len(prediction_input_df) < sequence_length:
print(f"错误: 预测所需的历史数据不足。需要 {sequence_length} 天, 但只有 {len(prediction_input_df)} 天。")
return None
# 准备用于图表展示的历史数据
history_for_chart_df = product_df[product_df['date'] < start_date_dt].tail(history_lookback_days)
# 自动回归预测循环
all_predictions = []
current_sequence_df = prediction_input_df.copy()
print(f"开始自动回归预测,共 {future_days} 天...")
for i in range(future_days):
# 准备当前序列的输入张量
X_current_scaled = scaler_X.transform(current_sequence_df[features].values)
X_input = torch.tensor(X_current_scaled.reshape(1, sequence_length, -1), dtype=torch.float32).to(DEVICE)
# 模型进行一次预测(可能预测出多个点,但我们只用第一个)
with torch.no_grad():
y_pred_scaled = model(X_input).cpu().numpy()
# 对于PyTorch模型, 需要重新构建实例并加载state_dict
if loaded_model_type == 'transformer':
model = TimeSeriesTransformer(num_features=config['input_dim'], d_model=config['hidden_size'], nhead=config['num_heads'], num_encoder_layers=config['num_layers'], dim_feedforward=config['hidden_size'] * 2, dropout=config['dropout'], output_sequence_length=config['output_dim'], seq_length=config['sequence_length'], batch_size=32).to(DEVICE)
elif loaded_model_type == 'mlstm':
model = MatrixLSTM(num_features=config['input_dim'], hidden_size=config['hidden_size'], mlstm_layers=config['mlstm_layers'], embed_dim=config.get('embed_dim', 32), dense_dim=config.get('dense_dim', 32), num_heads=config.get('num_heads', 4), dropout_rate=config['dropout_rate'], num_blocks=config.get('num_blocks', 3), output_sequence_length=config['output_dim']).to(DEVICE)
elif loaded_model_type == 'kan':
model = KANForecaster(input_features=config['input_dim'], hidden_sizes=[config['hidden_size'], config['hidden_size']*2, config['hidden_size']], output_sequence_length=config['output_dim']).to(DEVICE)
elif loaded_model_type == 'optimized_kan':
model = OptimizedKANForecaster(input_features=config['input_dim'], hidden_sizes=[config['hidden_size'], config['hidden_size']*2, config['hidden_size']], output_sequence_length=config['output_dim']).to(DEVICE)
elif loaded_model_type == 'tcn':
model = TCNForecaster(num_features=config['input_dim'], output_sequence_length=config['output_dim'], num_channels=[config['hidden_size']] * config['num_layers'], kernel_size=config['kernel_size'], dropout=config['dropout']).to(DEVICE)
elif loaded_model_type == 'cnn_bilstm_attention':
model = CnnBiLstmAttention(
input_dim=config['input_dim'],
output_dim=config['output_dim'],
sequence_length=config['sequence_length']
).to(DEVICE)
else:
raise ValueError(f"不支持的模型类型: {loaded_model_type}")
# 提取下一个时间点的预测值
next_step_pred_scaled = y_pred_scaled[0, 0].reshape(1, -1)
next_step_pred_unscaled = scaler_y.inverse_transform(next_step_pred_scaled)[0][0]
next_step_pred_unscaled = float(max(0, next_step_pred_unscaled)) # 确保销量不为负并转换为标准float
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()
# 获取新预测的日期
next_date = current_sequence_df['date'].iloc[-1] + timedelta(days=1)
all_predictions.append({'date': next_date, 'predicted_sales': next_step_pred_unscaled})
# --- 动态调用预测器 ---
predictor_function = get_predictor(loaded_model_type)
if not predictor_function:
raise ValueError(f"找不到模型类型 '{loaded_model_type}' 的预测器实现")
# 构建新的一行数据,用于更新输入序列
new_row = {
'date': next_date,
'sales': next_step_pred_unscaled,
'weekday': next_date.weekday(),
'month': next_date.month,
'is_holiday': 0,
'is_weekend': 1 if next_date.weekday() >= 5 else 0,
'is_promotion': 0,
'temperature': current_sequence_df['temperature'].iloc[-1] # 沿用最后一天的温度
}
# 更新序列:移除最旧的一行,添加最新预测的一行
new_row_df = pd.DataFrame([new_row])
current_sequence_df = pd.concat([current_sequence_df.iloc[1:], new_row_df], ignore_index=True)
predictions_df, history_for_chart_df, prediction_input_df = predictor_function(
model=model,
checkpoint=checkpoint,
product_df=product_df,
future_days=future_days,
start_date=start_date,
history_lookback_days=history_lookback_days
)
predictions_df = pd.DataFrame(all_predictions)
print(f"自动回归预测完成,生成 {len(predictions_df)} 条预测数据。")
# 分析与可视化
# --- 分析与返回部分保持不变 ---
analysis = None
if analyze_result:
try:
y_pred_for_analysis = predictions_df['predicted_sales'].values
# 使用初始输入序列的特征进行分析
initial_features_for_analysis = prediction_input_df[features].values
analysis = analyze_prediction_result(product_id, model_type, y_pred_for_analysis, initial_features_for_analysis)
analysis = analyze_prediction_result(product_id, loaded_model_type, predictions_df['predicted_sales'].values, prediction_input_df[config.get('features')].values)
except Exception as e:
print(f"分析预测结果失败: {str(e)}")
# 在返回前将DataFrame转换为前端期望的JSON数组格式
history_data_json = history_for_chart_df.to_dict('records') if not history_for_chart_df.empty else []
prediction_data_json = predictions_df.to_dict('records') if not predictions_df.empty else []
return {
'product_id': product_id,
'product_name': product_name,
'model_type': model_type,
'predictions': prediction_data_json, # 兼容旧字段使用已转换的json
'model_type': loaded_model_type,
'predictions': prediction_data_json,
'prediction_data': prediction_data_json,
'history_data': history_data_json,
'analysis': analysis

View File

@ -2,18 +2,44 @@
药店销售预测系统 - 模型训练模块
"""
from .mlstm_trainer import train_product_model_with_mlstm
from .kan_trainer import train_product_model_with_kan
from .tcn_trainer import train_product_model_with_tcn
from .transformer_trainer import train_product_model_with_transformer
import os
import glob
import importlib
# 默认训练函数
from .mlstm_trainer import train_product_model_with_mlstm as train_product_model
_TRAINERS_LOADED = False
def discover_trainers():
"""
自动发现并加载所有训练器插件
使用一个标志位确保这个过程只执行一次
"""
global _TRAINERS_LOADED
if _TRAINERS_LOADED:
return
print("🚀 开始发现并加载训练器插件...")
package_dir = os.path.dirname(__file__)
module_name = __name__
trainer_files = glob.glob(os.path.join(package_dir, "*_trainer.py"))
for f in trainer_files:
base_name = os.path.basename(f)
if base_name.startswith('__'):
continue
module_stem = base_name.replace('.py', '')
try:
# 动态导入模块以触发自注册
importlib.import_module(f".{module_stem}", package=module_name)
except ImportError as e:
print(f"⚠️ 加载训练器 {module_stem} 失败: {e}")
_TRAINERS_LOADED = True
print("✅ 所有训练器插件加载完成。")
# 在包被首次导入时,自动执行发现过程
discover_trainers()
__all__ = [
'train_product_model',
'train_product_model_with_mlstm',
'train_product_model_with_kan',
'train_product_model_with_tcn',
'train_product_model_with_transformer'
]

View File

@ -0,0 +1,118 @@
# -*- coding: utf-8 -*-
"""
CNN-BiLSTM-Attention 模型训练器
"""
import torch
import torch.optim as optim
import numpy as np
from models.model_registry import register_trainer
from utils.model_manager import model_manager
from analysis.metrics import evaluate_model
from utils.data_utils import create_dataset
from sklearn.preprocessing import MinMaxScaler
# 导入新创建的模型
from models.cnn_bilstm_attention import CnnBiLstmAttention
def train_with_cnn_bilstm_attention(product_id, model_identifier, product_df, store_id, training_mode, aggregation_method, epochs, sequence_length, forecast_horizon, model_dir, **kwargs):
"""
使用 CNN-BiLSTM-Attention 模型进行训练
函数签名遵循系统标准
"""
print(f"🚀 CNN-BiLSTM-Attention 训练器启动: model_identifier='{model_identifier}'")
# --- 1. 数据准备 ---
features = ['sales', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']
X = product_df[features].values
y = product_df[['sales']].values
scaler_X = MinMaxScaler(feature_range=(0, 1))
scaler_y = MinMaxScaler(feature_range=(0, 1))
X_scaled = scaler_X.fit_transform(X)
y_scaled = scaler_y.fit_transform(y)
train_size = int(len(X_scaled) * 0.8)
X_train_raw, X_test_raw = X_scaled[:train_size], X_scaled[train_size:]
y_train_raw, y_test_raw = y_scaled[:train_size], y_scaled[train_size:]
trainX, trainY = create_dataset(X_train_raw, y_train_raw, sequence_length, forecast_horizon)
testX, testY = create_dataset(X_test_raw, y_test_raw, sequence_length, forecast_horizon)
# 转换为 PyTorch Tensors
trainX = torch.from_numpy(trainX).float()
trainY = torch.from_numpy(trainY).float()
testX = torch.from_numpy(testX).float()
testY = torch.from_numpy(testY).float()
# --- 2. 实例化模型和优化器 ---
input_dim = trainX.shape[2]
model = CnnBiLstmAttention(
input_dim=input_dim,
output_dim=forecast_horizon,
sequence_length=sequence_length
)
optimizer = optim.Adam(model.parameters(), lr=kwargs.get('learning_rate', 0.001))
criterion = torch.nn.MSELoss()
# --- 3. 训练循环 ---
print("开始训练 CNN-BiLSTM-Attention 模型...")
for epoch in range(epochs):
model.train()
optimizer.zero_grad()
outputs = model(trainX)
loss = criterion(outputs, trainY.squeeze(-1)) # 确保目标维度匹配
loss.backward()
optimizer.step()
if (epoch + 1) % 10 == 0:
print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}')
# --- 4. 模型评估 ---
model.eval()
with torch.no_grad():
test_pred_scaled = model(testX)
test_pred_unscaled = scaler_y.inverse_transform(test_pred_scaled.numpy())
test_true_unscaled = scaler_y.inverse_transform(testY.squeeze(-1).numpy())
metrics = evaluate_model(test_true_unscaled.flatten(), test_pred_unscaled.flatten())
print(f"模型评估完成: RMSE={metrics['rmse']:.4f}")
# --- 5. 模型保存 ---
model_data = {
'model_state_dict': model.state_dict(),
'scaler_X': scaler_X,
'scaler_y': scaler_y,
'config': {
'model_type': 'cnn_bilstm_attention',
'input_dim': input_dim,
'output_dim': forecast_horizon,
'sequence_length': sequence_length,
'features': features
},
'metrics': metrics
}
final_model_path, final_version = model_manager.save_model(
model_data=model_data,
product_id=product_id,
model_type='cnn_bilstm_attention',
store_id=store_id,
training_mode=training_mode,
aggregation_method=aggregation_method,
product_name=product_df['product_name'].iloc[0]
)
print(f"✅ CNN-BiLSTM-Attention 模型已保存,版本: {final_version}")
return model, metrics, final_version, final_model_path
# --- 关键步骤: 将训练器注册到系统中 ---
register_trainer('cnn_bilstm_attention', train_with_cnn_bilstm_attention)

View File

@ -349,4 +349,9 @@ def train_product_model_with_kan(product_id, model_identifier, product_df=None,
print(f"最终模型已保存,版本: {final_version}, 路径: {final_model_path}")
return model, metrics
return model, metrics
# --- 将此训练器注册到系统中 ---
from models.model_registry import register_trainer
register_trainer('kan', train_product_model_with_kan)
register_trainer('optimized_kan', train_product_model_with_kan)

View File

@ -514,4 +514,8 @@ def train_product_model_with_mlstm(
emit_progress(f"✅ mLSTM模型训练完成版本 {final_version} 已保存", progress=100, metrics=final_metrics)
return model, metrics, epochs, final_model_path
return model, metrics, epochs, final_model_path
# --- 将此训练器注册到系统中 ---
from models.model_registry import register_trainer
register_trainer('mlstm', train_product_model_with_mlstm)

View File

@ -379,4 +379,8 @@ def train_product_model_with_tcn(
emit_progress(f"模型训练完成!版本 {final_version} 已保存", progress=100, metrics=final_metrics)
return model, metrics, epochs, final_model_path
return model, metrics, epochs, final_model_path
# --- 将此训练器注册到系统中 ---
from models.model_registry import register_trainer
register_trainer('tcn', train_product_model_with_tcn)

View File

@ -406,4 +406,8 @@ def train_product_model_with_transformer(
'version': final_version
}
return model, final_metrics, epochs
return model, final_metrics, epochs
# --- 将此训练器注册到系统中 ---
from models.model_registry import register_trainer
register_trainer('transformer', train_product_model_with_transformer)

View File

@ -0,0 +1,142 @@
"""
药店销售预测系统 - XGBoost 模型训练器 (插件式)
"""
import time
import pandas as pd
import numpy as np
import xgboost as xgb
from sklearn.preprocessing import MinMaxScaler
from xgboost.callback import EarlyStopping
# 导入核心工具
from utils.data_utils import create_dataset
from analysis.metrics import evaluate_model
from utils.model_manager import model_manager
from models.model_registry import register_trainer
def train_product_model_with_xgboost(product_id, model_identifier, product_df, store_id, training_mode, aggregation_method, epochs, sequence_length, forecast_horizon, model_dir, **kwargs):
"""
使用 XGBoost 模型训练产品销售预测模型
此函数签名与其他训练器保持一致,以兼容注册表调用
"""
print(f"🚀 XGBoost训练器启动: model_identifier='{model_identifier}'")
# --- 1. 数据准备和验证 ---
if product_df.empty:
raise ValueError(f"产品 {product_id} 没有可用的销售数据")
min_required_samples = sequence_length + forecast_horizon
if len(product_df) < min_required_samples:
error_msg = (f"数据不足: 需要 {min_required_samples} 条, 实际 {len(product_df)} 条。")
raise ValueError(error_msg)
product_df = product_df.sort_values('date')
product_name = product_df['product_name'].iloc[0] if 'product_name' in product_df.columns else model_identifier
# --- 2. 数据预处理和适配 ---
features = ['sales', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']
X = product_df[features].values
y = product_df[['sales']].values
scaler_X = MinMaxScaler(feature_range=(0, 1))
scaler_y = MinMaxScaler(feature_range=(0, 1))
X_scaled = scaler_X.fit_transform(X)
y_scaled = scaler_y.fit_transform(y)
train_size = int(len(X_scaled) * 0.8)
X_train_raw, X_test_raw = X_scaled[:train_size], X_scaled[train_size:]
y_train_raw, y_test_raw = y_scaled[:train_size], y_scaled[train_size:]
trainX, trainY = create_dataset(X_train_raw, y_train_raw, sequence_length, forecast_horizon)
testX, testY = create_dataset(X_test_raw, y_test_raw, sequence_length, forecast_horizon)
# **关键适配步骤**: XGBoost 需要二维输入
trainX = trainX.reshape(trainX.shape[0], -1)
testX = testX.reshape(testX.shape[0], -1)
# **关键适配**: 转换为 XGBoost 核心 DMatrix 格式,以使用稳定的 xgb.train API
dtrain = xgb.DMatrix(trainX, label=trainY)
dtest = xgb.DMatrix(testX, label=testY)
# --- 3. 模型训练 (使用核心 xgb.train API) ---
xgb_params = {
'learning_rate': kwargs.get('learning_rate', 0.08),
'subsample': kwargs.get('subsample', 0.75),
'colsample_bytree': kwargs.get('colsample_bytree', 1),
'max_depth': kwargs.get('max_depth', 7),
'gamma': kwargs.get('gamma', 0),
'objective': 'reg:squarederror',
'eval_metric': 'rmse', # eval_metric 在这里是原生支持的
'n_jobs': -1
}
n_estimators = kwargs.get('n_estimators', 500)
print("开始训练XGBoost模型 (使用核心xgb.train API)...")
start_time = time.time()
evals_result = {}
model = xgb.train(
params=xgb_params,
dtrain=dtrain,
num_boost_round=n_estimators,
evals=[(dtrain, 'train'), (dtest, 'test')],
early_stopping_rounds=50, # early_stopping_rounds 在这里是原生支持的
evals_result=evals_result,
verbose_eval=False
)
training_time = time.time() - start_time
print(f"XGBoost模型训练完成耗时: {training_time:.2f}")
# --- 4. 模型评估 ---
# 使用 model.best_iteration 获取最佳轮次的预测结果
test_pred = model.predict(dtest, iteration_range=(0, model.best_iteration))
test_pred_inv = scaler_y.inverse_transform(test_pred.reshape(-1, forecast_horizon))
test_true_inv = scaler_y.inverse_transform(testY.reshape(-1, forecast_horizon))
metrics = evaluate_model(test_true_inv.flatten(), test_pred_inv.flatten())
metrics['training_time'] = training_time
print("\n模型评估指标:")
print(f"MSE: {metrics['mse']:.4f}, RMSE: {metrics['rmse']:.4f}, MAE: {metrics['mae']:.4f}, R²: {metrics['r2']:.4f}, MAPE: {metrics['mape']:.2f}%")
# --- 5. 模型保存 (借道 utils.model_manager) ---
# **关键适配点**: 我们将完整的XGBoost模型对象存入字典
# torch.save 可以序列化多种Python对象,包括sklearn模型
model_data = {
'model_state_dict': model, # 直接保存模型对象
'scaler_X': scaler_X,
'scaler_y': scaler_y,
'config': {
'sequence_length': sequence_length,
'forecast_horizon': forecast_horizon,
'model_type': 'xgboost',
'features': features,
'xgb_params': xgb_params
},
'metrics': metrics,
'loss_history': evals_result
}
# 调用全局管理器进行保存,复用其命名和版本逻辑
final_model_path, final_version = model_manager.save_model(
model_data=model_data,
product_id=product_id,
model_type='xgboost',
store_id=store_id,
training_mode=training_mode,
aggregation_method=aggregation_method,
product_name=product_name
)
print(f"XGBoost模型已通过统一管理器保存版本: {final_version}, 路径: {final_model_path}")
# 返回值遵循统一格式
return model, metrics, final_version, final_model_path
# --- 将此训练器注册到系统中 ---
register_trainer('xgboost', train_product_model_with_xgboost)

View File

@ -280,48 +280,41 @@ class ModelManager:
if len(parts) < 3:
return None # 格式不符合基本要求
model_type = parts[0]
mode = parts[1]
# **核心修复**: 采用更健壮的、从后往前的解析逻辑,以支持带下划线的模型名称
try:
if mode == 'store' and len(parts) >= 3:
# {model_type}_store_{store_id}_{version}
version = parts[-1]
store_id = '_'.join(parts[2:-1])
return {
'model_type': model_type,
'training_mode': 'store',
'store_id': store_id,
'version': version,
'product_id': None,
'aggregation_method': None
}
elif mode == 'global' and len(parts) >= 3:
# {model_type}_global_{aggregation_method}_{version}
version = parts[-1]
aggregation_method = '_'.join(parts[2:-1])
return {
'model_type': model_type,
'training_mode': 'global',
'aggregation_method': aggregation_method,
'version': version,
'product_id': None,
'store_id': None
}
elif mode == 'product' and len(parts) >= 3:
# {model_type}_product_{product_id}_{version}
version = parts[-1]
product_id = '_'.join(parts[2:-1])
version = parts[-1]
identifier = parts[-2]
mode_candidate = parts[-3]
if mode_candidate == 'product':
model_type = '_'.join(parts[:-3])
return {
'model_type': model_type,
'training_mode': 'product',
'product_id': product_id,
'product_id': identifier,
'version': version,
'store_id': None,
'aggregation_method': None
}
elif mode_candidate == 'store':
model_type = '_'.join(parts[:-3])
return {
'model_type': model_type,
'training_mode': 'store',
'store_id': identifier,
'version': version,
}
elif mode_candidate == 'global':
model_type = '_'.join(parts[:-3])
return {
'model_type': model_type,
'training_mode': 'global',
'aggregation_method': identifier,
'version': version,
}
except IndexError:
# 如果文件名部分不够,则解析失败
pass
except Exception as e:
print(f"解析新版v2文件名失败 {filename}: {e}")
print(f"解析文件名失败 {filename}: {e}")
return None
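按这套"从后往前"的解析规则,几个文件名会得到如下结果(文件名均为虚构举例,且假设不含扩展名,仅用于说明):

```python
# 解析规则示意:
#   'cnn_bilstm_attention_product_P001_v2' -> model_type='cnn_bilstm_attention',
#       training_mode='product', product_id='P001', version='v2'(带下划线的模型名不再被截断)
#   'xgboost_store_S001_v1'  -> model_type='xgboost', training_mode='store', store_id='S001'
#   'mlstm_global_sum_v3'    -> model_type='mlstm', training_mode='global', aggregation_method='sum'
parts = 'cnn_bilstm_attention_product_P001_v2'.split('_')
version, identifier, mode = parts[-1], parts[-2], parts[-3]
model_type = '_'.join(parts[:-3])
print(model_type, mode, identifier, version)  # cnn_bilstm_attention product P001 v2
```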

View File

@ -1,5 +1,5 @@
### 根目录启动
`uv pip install loguru numpy pandas torch matplotlib flask flask_cors flask_socketio flasgger scikit-learn tqdm pytorch_tcn`
`uv pip install loguru numpy pandas torch matplotlib flask flask_cors flask_socketio flasgger scikit-learn tqdm pytorch_tcn pyarrow`
### UI
`npm install` `npm run dev`

xz新模型添加流程.md
View File

@ -0,0 +1,222 @@
# 如何向系统添加新模型
本指南详细说明了如何向本预测系统添加一个全新的预测模型。系统采用灵活的插件式架构,集成新模型的过程非常模块化,主要围绕 **模型 (Model)**、**训练器 (Trainer)** 和 **预测器 (Predictor)** 这三个核心组件进行。
## 核心理念
系统的核心是 `models/model_registry.py`,它维护了两个独立的注册表:一个用于训练函数,另一个用于预测函数。添加新模型的本质就是:
1. **定义模型**:创建模型的架构。
2. **创建训练器**:编写一个函数来训练这个模型,并将其注册到训练器注册表。
3. **集成预测器**:确保系统知道如何加载模型并用它来预测,然后将预测逻辑注册到预测器注册表。
---
## 第 1 步:定义模型架构
首先,您需要在 `ShopTRAINING/server/models/` 目录下创建一个新的 Python 文件来定义您的模型。
**示例:创建 `ShopTRAINING/server/models/my_new_model.py`**
如果您的新模型是基于 PyTorch 的,它应该是一个继承自 `torch.nn.Module` 的类。
```python
# file: ShopTRAINING/server/models/my_new_model.py
import torch
import torch.nn as nn
class MyNewModel(nn.Module):
def __init__(self, input_features, hidden_size, output_sequence_length):
"""
定义模型的层和结构。
"""
super(MyNewModel, self).__init__()
self.layer1 = nn.Linear(input_features, hidden_size)
self.relu = nn.ReLU()
self.layer2 = nn.Linear(hidden_size, output_sequence_length)
# ... 可添加更复杂的结构
def forward(self, x):
"""
定义数据通过模型的前向传播路径。
x 的形状通常是 (batch_size, sequence_length, num_features)
"""
# 确保输入是正确的形状
# 例如,对于简单的线性层,可能需要展平
batch_size, seq_len, features = x.shape
x = x.view(batch_size * seq_len, features) # 展平
out = self.layer1(x)
out = self.relu(out)
out = self.layer2(out)
# 恢复形状以匹配输出
out = out.view(batch_size, seq_len, -1)
# 通常我们只关心序列的最后一个预测
return out[:, -1, :]
```
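定义好模型后,可以先做一次简单的形状自检,确认前向传播输出符合预期(以下数值仅为示例,假设特征数为 7、序列长度为 30):

```python
# 形状自检示意:输出应为 (batch_size, output_sequence_length)
import torch
from models.my_new_model import MyNewModel

model = MyNewModel(input_features=7, hidden_size=64, output_sequence_length=7)
dummy = torch.randn(8, 30, 7)   # (batch_size, sequence_length, num_features)
print(model(dummy).shape)       # 预期: torch.Size([8, 7])
```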
---
## 第 2 步:创建模型训练器
接下来,在 `ShopTRAINING/server/trainers/` 目录下创建一个新的训练器文件。这个文件负责模型的整个训练、评估和保存流程。
**示例:创建 `ShopTRAINING/server/trainers/my_new_model_trainer.py`**
这个训练函数需要遵循系统中其他训练器(如 `xgboost_trainer.py`)的统一函数签名,并使用 `@register_trainer` 装饰器或在文件末尾调用 `register_trainer` 函数。
```python
# file: ShopTRAINING/server/trainers/my_new_model_trainer.py
import torch
import torch.optim as optim
from models.model_registry import register_trainer
from utils.model_manager import model_manager
from analysis.metrics import evaluate_model
from models.my_new_model import MyNewModel # 导入您的新模型
# 遵循系统的标准函数签名
def train_with_mynewmodel(product_id, model_identifier, product_df, store_id, training_mode, aggregation_method, epochs, sequence_length, forecast_horizon, model_dir, **kwargs):
print(f"🚀 MyNewModel 训练器启动: model_identifier='{model_identifier}'")
# --- 1. 数据准备 ---
# (此处省略了数据加载、标准化和创建数据集的详细代码,
# 您可以参考 xgboost_trainer.py 或其他训练器中的实现)
# ...
# 假设您已准备好 trainX, trainY, testX, testY, scaler_y 等变量
# trainX = ...
# trainY = ...
# testX = ...
# testY = ...
# scaler_y = ...
# features = [...]
# --- 2. 实例化模型和优化器 ---
input_dim = trainX.shape[2] # 获取特征数量
hidden_size = 64 # 示例超参数
model = MyNewModel(
input_features=input_dim,
hidden_size=hidden_size,
output_sequence_length=forecast_horizon
)
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = torch.nn.MSELoss()
# --- 3. 训练循环 ---
print("开始训练 MyNewModel...")
for epoch in range(epochs):
model.train()
optimizer.zero_grad()
outputs = model(trainX)
loss = criterion(outputs, trainY)
loss.backward()
optimizer.step()
if (epoch + 1) % 10 == 0:
print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}')
# --- 4. 模型评估 ---
model.eval()
with torch.no_grad():
test_pred_scaled = model(testX)
# 反标准化并计算指标
# ... (参考其他训练器)
metrics = {'rmse': 0.0, 'mae': 0.0, 'r2': 0.0, 'mape': 0.0} # 示例
# --- 5. 模型保存 ---
model_data = {
'model_state_dict': model.state_dict(),
'scaler_X': None, # 替换为您的 scaler_X
'scaler_y': scaler_y,
'config': {
'model_type': 'mynewmodel', # **关键**: 使用唯一的模型名称
'input_dim': input_dim,
'hidden_size': hidden_size,
'sequence_length': sequence_length,
'forecast_horizon': forecast_horizon,
'features': features
},
'metrics': metrics
}
model_manager.save_model(
model_data=model_data,
product_id=product_id,
model_type='mynewmodel', # **关键**: 再次确认模型名称
# ... 其他参数
)
print("✅ MyNewModel 模型训练并保存完成!")
return model, metrics, "v1", "path/to/model" # 返回值遵循统一格式
# --- 关键步骤: 将训练器注册到系统中 ---
register_trainer('mynewmodel', train_with_mynewmodel)
```
---
## 第 3 步:集成模型预测器
最后,您需要让系统知道如何加载和使用您的新模型进行预测。这需要在 `ShopTRAINING/server/predictors/model_predictor.py` 中进行两处修改。
**文件: `ShopTRAINING/server/predictors/model_predictor.py`**
1. **让系统知道如何构建您的模型实例**
在 `load_model_and_predict` 函数中,有一个 `if/elif` 结构用于根据模型类型实例化不同的模型。您需要为 `MyNewModel` 添加一个新的分支。
```python
# 在 model_predictor.py 中
# 首先,导入您的新模型类
from models.my_new_model import MyNewModel
# ... 在 load_model_and_predict 函数内部 ...
# ... 其他模型的 elif 分支 ...
elif loaded_model_type == 'tcn':
model = TCNForecaster(...)
# vvv 添加这个新的分支 vvv
elif loaded_model_type == 'mynewmodel':
model = MyNewModel(
input_features=config['input_dim'],
hidden_size=config['hidden_size'],
output_sequence_length=config['forecast_horizon']
).to(DEVICE)
# ^^^ 添加结束 ^^^
else:
raise ValueError(f"不支持的模型类型: {loaded_model_type}")
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()
```
2. **注册预测逻辑**
如果您的模型是一个标准的 PyTorch 模型,并且其预测逻辑与现有的模型(如 Transformer、KAN)相同,您可以直接复用 `default_pytorch_predictor`。只需在文件末尾添加一行注册代码即可。
```python
# 在 model_predictor.py 文件末尾
# ...
# 将增强后的默认预测器也注册给xgboost
register_predictor('xgboost', default_pytorch_predictor)
# vvv 添加这行代码 vvv
# 让 'mynewmodel' 也使用通用的 PyTorch 预测器
register_predictor('mynewmodel', default_pytorch_predictor)
# ^^^ 添加结束 ^^^
```
如果您的模型需要特殊的预测逻辑(例如,像 XGBoost 那样有不同的输入格式或调用方式),您可以复制 `default_pytorch_predictor` 创建一个新函数,修改其内部逻辑,然后将新函数注册给 `'mynewmodel'`。
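下面给出一个自定义预测器的骨架示意(`my_new_model_predictor` 为假设的示例名称;接口与 `default_pytorch_predictor` 保持一致,内部逻辑仅为占位,需按您的模型实际情况填充):

```python
# 示意:一个自定义预测器骨架(可放在 model_predictor.py 中)
import pandas as pd

def my_new_model_predictor(model, checkpoint, product_df, future_days, start_date, history_lookback_days):
    """与 default_pytorch_predictor 相同的参数约定,返回 (predictions_df, history_df, input_df)。"""
    config = checkpoint['config']
    scaler_X, scaler_y = checkpoint['scaler_X'], checkpoint['scaler_y']
    sequence_length = config['sequence_length']

    # 取预测起点之前的 sequence_length 天作为初始输入,history_lookback_days 天用于图表展示
    input_df = product_df.tail(sequence_length)
    history_df = product_df.tail(history_lookback_days)

    # TODO: 按 start_date 与 future_days 实现您模型的(自回归)预测循环,填充 predictions
    predictions = []  # 每个元素形如 {'date': ..., 'predicted_sales': ...}

    return pd.DataFrame(predictions), history_df, input_df

register_predictor('mynewmodel', my_new_model_predictor)
```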
---
## 总结
完成以上三个步骤后,您的新模型 `MyNewModel` 就已完全集成到系统中了。系统会自动在 `trainers` 目录中发现您的新训练器。当您通过 API 或界面选择 `mynewmodel` 进行训练和预测时,系统将自动调用您刚刚编写和注册的所有相应逻辑。