插件式添加模型
This commit is contained in:
parent
038289ae32
commit
751de9b548
Binary file not shown.
@ -56,5 +56,6 @@ tzdata==2025.2
|
|||||||
werkzeug==3.1.3
|
werkzeug==3.1.3
|
||||||
win32-setctime==1.2.0
|
win32-setctime==1.2.0
|
||||||
wsproto==1.2.0
|
wsproto==1.2.0
|
||||||
|
|
||||||
python-dateutil
|
python-dateutil
|
||||||
|
xgboost
|
||||||
|
scikit-learn
|
||||||
|
@ -45,6 +45,7 @@ from trainers.mlstm_trainer import train_product_model_with_mlstm
|
|||||||
from trainers.kan_trainer import train_product_model_with_kan
|
from trainers.kan_trainer import train_product_model_with_kan
|
||||||
from trainers.tcn_trainer import train_product_model_with_tcn
|
from trainers.tcn_trainer import train_product_model_with_tcn
|
||||||
from trainers.transformer_trainer import train_product_model_with_transformer
|
from trainers.transformer_trainer import train_product_model_with_transformer
|
||||||
|
from trainers.xgboost_trainer import train_product_model_with_xgboost
|
||||||
|
|
||||||
# 导入预测函数
|
# 导入预测函数
|
||||||
from predictors.model_predictor import load_model_and_predict
|
from predictors.model_predictor import load_model_and_predict
|
||||||
@ -810,7 +811,7 @@ def get_all_training_tasks():
|
|||||||
'type': 'object',
|
'type': 'object',
|
||||||
'properties': {
|
'properties': {
|
||||||
'product_id': {'type': 'string', 'description': '例如 P001'},
|
'product_id': {'type': 'string', 'description': '例如 P001'},
|
||||||
'model_type': {'type': 'string', 'enum': ['mlstm', 'transformer', 'kan', 'optimized_kan', 'tcn'], 'description': '要训练的模型类型'},
|
'model_type': {'type': 'string', 'enum': ['mlstm', 'transformer', 'kan', 'optimized_kan', 'tcn', 'xgboost'], 'description': '要训练的模型类型'},
|
||||||
'store_id': {'type': 'string', 'description': '店铺ID,如 S001。为空时使用全局聚合数据'},
|
'store_id': {'type': 'string', 'description': '店铺ID,如 S001。为空时使用全局聚合数据'},
|
||||||
'epochs': {'type': 'integer', 'default': 50, 'description': '训练轮次'}
|
'epochs': {'type': 'integer', 'default': 50, 'description': '训练轮次'}
|
||||||
},
|
},
|
||||||
@ -873,10 +874,10 @@ def start_training():
|
|||||||
# 全局模式不需要特定的product_id或store_id
|
# 全局模式不需要特定的product_id或store_id
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# 检查模型类型是否有效
|
# 检查模型类型是否有效 (v2 - 动态检查)
|
||||||
valid_model_types = ['mlstm', 'kan', 'optimized_kan', 'transformer', 'tcn']
|
from models.model_registry import TRAINER_REGISTRY
|
||||||
if model_type not in valid_model_types:
|
if model_type not in TRAINER_REGISTRY:
|
||||||
return jsonify({'error': '无效的模型类型'}), 400
|
return jsonify({'error': f"无效的模型类型: '{model_type}'. 可用模型: {list(TRAINER_REGISTRY.keys())}"}), 400
|
||||||
|
|
||||||
# 使用新的训练进程管理器提交任务
|
# 使用新的训练进程管理器提交任务
|
||||||
try:
|
try:
|
||||||
@ -3445,41 +3446,37 @@ def analyze_model_metrics():
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
def get_model_types():
|
def get_model_types():
|
||||||
"""获取系统支持的所有模型类型"""
|
"""获取系统支持的所有模型类型 (v2 - 动态加载)"""
|
||||||
model_types = [
|
from models.model_registry import TRAINER_REGISTRY
|
||||||
{
|
|
||||||
'id': 'mlstm',
|
|
||||||
'name': 'mLSTM',
|
|
||||||
'description': '矩阵长短期记忆网络,适合处理多变量时序数据',
|
|
||||||
'tag_type': 'primary'
|
|
||||||
},
|
|
||||||
{
|
|
||||||
'id': 'transformer',
|
|
||||||
'name': 'Transformer',
|
|
||||||
'description': '基于注意力机制的序列模型,适合捕捉长期依赖关系',
|
|
||||||
'tag_type': 'success'
|
|
||||||
},
|
|
||||||
{
|
|
||||||
'id': 'kan',
|
|
||||||
'name': 'KAN',
|
|
||||||
'description': 'Kolmogorov-Arnold网络,能够逼近任意连续函数',
|
|
||||||
'tag_type': 'warning'
|
|
||||||
},
|
|
||||||
{
|
|
||||||
'id': 'optimized_kan',
|
|
||||||
'name': '优化版KAN',
|
|
||||||
'description': '经过优化的KAN模型,提供更高的预测精度和训练效率',
|
|
||||||
'tag_type': 'info'
|
|
||||||
},
|
|
||||||
{
|
|
||||||
'id': 'tcn',
|
|
||||||
'name': 'TCN',
|
|
||||||
'description': '时间卷积网络,适合处理长序列和平行计算',
|
|
||||||
'tag_type': 'danger'
|
|
||||||
}
|
|
||||||
]
|
|
||||||
|
|
||||||
return jsonify({"status": "success", "data": model_types})
|
# 预定义的模型元数据,用于美化显示
|
||||||
|
model_meta = {
|
||||||
|
'mlstm': {'name': 'mLSTM', 'description': '矩阵长短期记忆网络,适合处理多变量时序数据', 'tag_type': 'primary'},
|
||||||
|
'transformer': {'name': 'Transformer', 'description': '基于注意力机制的序列模型,适合捕捉长期依赖关系', 'tag_type': 'success'},
|
||||||
|
'kan': {'name': 'KAN', 'description': 'Kolmogorov-Arnold网络,能够逼近任意连续函数', 'tag_type': 'warning'},
|
||||||
|
'optimized_kan': {'name': '优化版KAN', 'description': '经过优化的KAN模型,提供更高的预测精度和训练效率', 'tag_type': 'info'},
|
||||||
|
'tcn': {'name': 'TCN', 'description': '时间卷积网络,适合处理长序列和平行计算', 'tag_type': 'danger'},
|
||||||
|
'xgboost': {'name': 'XGBoost', 'description': '梯度提升决策树,性能强大且高效的经典模型', 'tag_type': 'primary'}
|
||||||
|
}
|
||||||
|
|
||||||
|
# 从注册表动态获取所有已注册的模型ID
|
||||||
|
registered_models = TRAINER_REGISTRY.keys()
|
||||||
|
|
||||||
|
dynamic_model_types = []
|
||||||
|
for model_id in registered_models:
|
||||||
|
meta = model_meta.get(model_id, {
|
||||||
|
'name': model_id.upper(),
|
||||||
|
'description': f'自定义模型: {model_id}',
|
||||||
|
'tag_type': 'secondary'
|
||||||
|
})
|
||||||
|
dynamic_model_types.append({
|
||||||
|
'id': model_id,
|
||||||
|
'name': meta['name'],
|
||||||
|
'description': meta['description'],
|
||||||
|
'tag_type': meta['tag_type']
|
||||||
|
})
|
||||||
|
|
||||||
|
return jsonify({"status": "success", "data": dynamic_model_types})
|
||||||
|
|
||||||
# ========== 新增版本管理API ==========
|
# ========== 新增版本管理API ==========
|
||||||
|
|
||||||
|
@ -58,7 +58,9 @@ HIDDEN_SIZE = 64 # 隐藏层大小
|
|||||||
NUM_LAYERS = 2 # 层数
|
NUM_LAYERS = 2 # 层数
|
||||||
|
|
||||||
# 支持的模型类型
|
# 支持的模型类型
|
||||||
SUPPORTED_MODELS = ['mlstm', 'kan', 'transformer', 'tcn', 'optimized_kan']
|
# 支持的模型类型 (v2 - 动态加载)
|
||||||
|
from models.model_registry import TRAINER_REGISTRY
|
||||||
|
SUPPORTED_MODELS = list(TRAINER_REGISTRY.keys())
|
||||||
|
|
||||||
# 版本管理配置
|
# 版本管理配置
|
||||||
MODEL_VERSION_PREFIX = 'v' # 版本前缀
|
MODEL_VERSION_PREFIX = 'v' # 版本前缀
|
||||||
|
@ -11,12 +11,13 @@ import time
|
|||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
from trainers import (
|
# from trainers import (
|
||||||
train_product_model_with_mlstm,
|
# train_product_model_with_mlstm,
|
||||||
train_product_model_with_kan,
|
# train_product_model_with_kan,
|
||||||
train_product_model_with_tcn,
|
# train_product_model_with_tcn,
|
||||||
train_product_model_with_transformer
|
# train_product_model_with_transformer
|
||||||
)
|
# )
|
||||||
|
# 上述导入已不再需要,因为我们现在通过模型注册表动态获取训练器
|
||||||
from predictors.model_predictor import load_model_and_predict
|
from predictors.model_predictor import load_model_and_predict
|
||||||
from utils.data_utils import prepare_data, prepare_sequences
|
from utils.data_utils import prepare_data, prepare_sequences
|
||||||
from utils.multi_store_data_utils import (
|
from utils.multi_store_data_utils import (
|
||||||
@ -187,89 +188,49 @@ class PharmacyPredictor:
|
|||||||
else: # product mode
|
else: # product mode
|
||||||
model_identifier = product_id
|
model_identifier = product_id
|
||||||
|
|
||||||
# 调用相应的训练函数
|
# 调用相应的训练函数 (重构为使用注册表)
|
||||||
try:
|
try:
|
||||||
log_message(f"🤖 开始调用 {model_type} 训练器")
|
from models.model_registry import get_trainer
|
||||||
if model_type == 'transformer':
|
log_message(f"🤖 正在从注册表获取 '{model_type}' 训练器...")
|
||||||
model_result, metrics, actual_version = train_product_model_with_transformer(
|
trainer_function = get_trainer(model_type)
|
||||||
product_id=product_id, # product_id 仍然需要,用于数据过滤
|
log_message(f"✅ 成功获取训练器: {trainer_function.__name__}")
|
||||||
model_identifier=model_identifier, # 这是用于保存模型的唯一ID
|
|
||||||
product_df=product_data,
|
# 准备通用参数
|
||||||
store_id=store_id,
|
trainer_args = {
|
||||||
training_mode=training_mode,
|
'product_id': product_id,
|
||||||
aggregation_method=aggregation_method,
|
'model_identifier': model_identifier,
|
||||||
epochs=epochs,
|
'product_df': product_data,
|
||||||
sequence_length=sequence_length,
|
'store_id': store_id,
|
||||||
forecast_horizon=forecast_horizon,
|
'training_mode': training_mode,
|
||||||
model_dir=self.model_dir,
|
'aggregation_method': aggregation_method,
|
||||||
version=version,
|
'epochs': epochs,
|
||||||
socketio=socketio,
|
'sequence_length': sequence_length,
|
||||||
task_id=task_id,
|
'forecast_horizon': forecast_horizon,
|
||||||
continue_training=continue_training
|
'model_dir': self.model_dir,
|
||||||
)
|
'socketio': socketio,
|
||||||
log_message(f"✅ {model_type} 训练器返回: metrics={type(metrics)}, version={actual_version}", 'success')
|
'task_id': task_id,
|
||||||
elif model_type == 'mlstm':
|
'progress_callback': progress_callback,
|
||||||
_, metrics, _, _ = train_product_model_with_mlstm(
|
'version': version,
|
||||||
product_id=product_id,
|
'continue_training': continue_training,
|
||||||
model_identifier=model_identifier, # 传递修正后的ID
|
'use_optimized': use_optimized # KAN模型需要
|
||||||
product_df=product_data,
|
}
|
||||||
store_id=store_id,
|
|
||||||
training_mode=training_mode,
|
# 动态调用训练函数 (v2 - 智能参数过滤)
|
||||||
aggregation_method=aggregation_method,
|
import inspect
|
||||||
epochs=epochs,
|
sig = inspect.signature(trainer_function)
|
||||||
sequence_length=sequence_length,
|
valid_args = {k: v for k, v in trainer_args.items() if k in sig.parameters}
|
||||||
forecast_horizon=forecast_horizon,
|
|
||||||
model_dir=self.model_dir,
|
log_message(f"🔍 准备调用 {trainer_function.__name__},有效参数: {list(valid_args.keys())}")
|
||||||
socketio=socketio,
|
|
||||||
task_id=task_id,
|
result = trainer_function(**valid_args)
|
||||||
progress_callback=progress_callback
|
|
||||||
)
|
# 根据返回值的数量解析metrics
|
||||||
elif model_type == 'kan':
|
if isinstance(result, tuple) and len(result) >= 2:
|
||||||
_, metrics = train_product_model_with_kan(
|
metrics = result[1] # 通常第二个返回值是metrics
|
||||||
product_id=product_id,
|
|
||||||
model_identifier=model_identifier, # 传递修正后的ID
|
|
||||||
product_df=product_data,
|
|
||||||
store_id=store_id,
|
|
||||||
training_mode=training_mode,
|
|
||||||
aggregation_method=aggregation_method,
|
|
||||||
epochs=epochs,
|
|
||||||
sequence_length=sequence_length,
|
|
||||||
forecast_horizon=forecast_horizon,
|
|
||||||
use_optimized=use_optimized,
|
|
||||||
model_dir=self.model_dir
|
|
||||||
)
|
|
||||||
elif model_type == 'optimized_kan':
|
|
||||||
_, metrics = train_product_model_with_kan(
|
|
||||||
product_id=product_id,
|
|
||||||
model_identifier=model_identifier, # 传递修正后的ID
|
|
||||||
product_df=product_data,
|
|
||||||
store_id=store_id,
|
|
||||||
training_mode=training_mode,
|
|
||||||
aggregation_method=aggregation_method,
|
|
||||||
epochs=epochs,
|
|
||||||
sequence_length=sequence_length,
|
|
||||||
forecast_horizon=forecast_horizon,
|
|
||||||
use_optimized=True,
|
|
||||||
model_dir=self.model_dir
|
|
||||||
)
|
|
||||||
elif model_type == 'tcn':
|
|
||||||
_, metrics, _, _ = train_product_model_with_tcn(
|
|
||||||
product_id=product_id,
|
|
||||||
model_identifier=model_identifier, # 传递修正后的ID
|
|
||||||
product_df=product_data,
|
|
||||||
store_id=store_id,
|
|
||||||
training_mode=training_mode,
|
|
||||||
aggregation_method=aggregation_method,
|
|
||||||
epochs=epochs,
|
|
||||||
sequence_length=sequence_length,
|
|
||||||
forecast_horizon=forecast_horizon,
|
|
||||||
model_dir=self.model_dir,
|
|
||||||
socketio=socketio,
|
|
||||||
task_id=task_id
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
log_message(f"不支持的模型类型: {model_type}", 'error')
|
log_message(f"⚠️ 训练器返回格式未知,无法直接提取metrics: {type(result)}", 'warning')
|
||||||
return None
|
metrics = None
|
||||||
|
|
||||||
|
|
||||||
# 检查和打印返回的metrics
|
# 检查和打印返回的metrics
|
||||||
log_message(f"📊 训练完成,检查返回的metrics: {metrics}")
|
log_message(f"📊 训练完成,检查返回的metrics: {metrics}")
|
||||||
|
102
server/models/cnn_bilstm_attention.py
Normal file
102
server/models/cnn_bilstm_attention.py
Normal file
@ -0,0 +1,102 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
CNN-BiLSTM-Attention 模型定义,适配药店销售预测系统。
|
||||||
|
原始代码来源: python机器学习回归全家桶
|
||||||
|
"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
# 注意:由于原始代码使用了 TensorFlow/Keras 的层,我们将在这里创建一个 PyTorch 的等效实现。
|
||||||
|
# 这是一个更健壮、更符合现有系统架构的做法。
|
||||||
|
|
||||||
|
class Attention(nn.Module):
    """Attention pooling over the time axis of a sequence tensor.

    Every time step is scored with a learned projection, the scores are
    squashed with ``tanh``, exponentiated and normalised across the sequence,
    and the step features are summed with those weights.

    Args:
        feature_dim: Size of the last (feature) axis of the input.
        step_dim: Number of time steps the layer is built for; it fixes the
            length of the learned per-step bias.
        bias: Whether to add a learned per-step bias to the raw scores.
        **kwargs: Forwarded unchanged to ``nn.Module.__init__``.
    """

    def __init__(self, feature_dim, step_dim, bias=True, **kwargs):
        super().__init__(**kwargs)

        # Attributes kept from the Keras layer this class was ported from;
        # nothing in this class reads them.
        self.supports_masking = True
        self.features_dim = 0

        self.bias = bias
        self.feature_dim = feature_dim
        self.step_dim = step_dim

        # Scoring vector: projects each step's features to one scalar score.
        self.weight = nn.Parameter(torch.empty(feature_dim, 1))
        nn.init.xavier_uniform_(self.weight)

        if bias:
            self.b = nn.Parameter(torch.zeros(step_dim))

    def forward(self, x, mask=None):
        """Pool ``x`` of shape ``(batch, step_dim, feature_dim)`` to ``(batch, feature_dim)``.

        Args:
            x: Input sequence tensor.
            mask: Optional tensor multiplied into the un-normalised weights;
                presumably shaped ``(batch, step_dim)`` with zeros on steps to
                ignore -- confirm against callers.

        Returns:
            The attention-weighted sum of the step features.

        Raises:
            ValueError: If ``x`` is not 3-D or its step/feature sizes differ
                from the sizes the layer was built with.
        """
        if x.dim() != 3 or x.size(1) != self.step_dim or x.size(2) != self.feature_dim:
            # Fail loudly: a blind reshape could otherwise regroup steps
            # across samples and broadcast to a plausible-looking result.
            raise ValueError(
                f"Attention expected input of shape (batch, {self.step_dim}, "
                f"{self.feature_dim}), got {tuple(x.shape)}"
            )

        # (batch, steps, features) @ (features, 1) -> (batch, steps)
        eij = torch.matmul(x, self.weight).squeeze(-1)

        if self.bias:
            eij = eij + self.b

        eij = torch.tanh(eij)
        a = torch.exp(eij)

        if mask is not None:
            a = a * mask

        # The epsilon keeps the division finite when every step is masked out.
        a = a / (torch.sum(a, 1, keepdim=True) + 1e-10)

        weighted_input = x * torch.unsqueeze(a, -1)
        return torch.sum(weighted_input, 1)
|
||||||
|
|
||||||
|
|
||||||
|
class CnnBiLstmAttention(nn.Module):
    """CNN -> BiLSTM -> attention -> dense forecaster.

    Args:
        input_dim: Number of input features per time step.
        output_dim: Size of the output produced by the final linear layer.
        sequence_length: Number of time steps in every input window.
        cnn_filters: Number of ``Conv1d`` output channels.
        cnn_kernel_size: ``Conv1d`` kernel width. The convolution is unpadded,
            so a width above 1 shortens the sequence seen by later layers.
        lstm_units: Hidden size of each LSTM direction.

    Raises:
        ValueError: If ``cnn_kernel_size`` is larger than ``sequence_length``.
    """

    def __init__(self, input_dim, output_dim, sequence_length, cnn_filters=64, cnn_kernel_size=1, lstm_units=128):
        super().__init__()

        # An unpadded stride-1 convolution yields L - k + 1 steps. The
        # attention layer must be sized for that length, not the raw window
        # length, or any kernel wider than 1 fails in forward().
        conv_steps = sequence_length - cnn_kernel_size + 1
        if conv_steps < 1:
            raise ValueError(
                f"cnn_kernel_size ({cnn_kernel_size}) must not exceed "
                f"sequence_length ({sequence_length})"
            )

        self.sequence_length = sequence_length
        self.cnn_filters = cnn_filters
        self.lstm_units = lstm_units

        # CNN stage
        self.conv1d = nn.Conv1d(in_channels=input_dim, out_channels=cnn_filters, kernel_size=cnn_kernel_size)
        self.relu = nn.ReLU()
        # Kept so the module's attributes stay as before; forward() never
        # applies it, and a width-1 pool would not change the values anyway.
        self.maxpool = nn.MaxPool1d(kernel_size=1)

        # BiLSTM stage (the two directions are concatenated -> 2 * lstm_units)
        self.bilstm = nn.LSTM(input_size=cnn_filters, hidden_size=lstm_units, num_layers=1, batch_first=True, bidirectional=True)

        # Attention pooling over the steps that survive the convolution
        self.attention = Attention(feature_dim=lstm_units * 2, step_dim=conv_steps)

        # Output projection
        self.dense = nn.Linear(lstm_units * 2, output_dim)

    def forward(self, x):
        """Map ``x`` of shape ``(batch_size, sequence_length, input_dim)`` to ``(batch_size, output_dim)``."""
        # Conv1d wants channels first: (batch, input_dim, sequence_length)
        x = x.permute(0, 2, 1)
        x = self.conv1d(x)
        x = self.relu(x)
        # Back to (batch, steps, cnn_filters) for the batch_first LSTM
        x = x.permute(0, 2, 1)

        # lstm_out: (batch, steps, lstm_units * 2)
        lstm_out, _ = self.bilstm(x)

        # Collapse the time axis with learned attention weights
        attention_out = self.attention(lstm_out)

        return self.dense(attention_out)
|
64
server/models/model_registry.py
Normal file
64
server/models/model_registry.py
Normal file
@ -0,0 +1,64 @@
|
|||||||
|
"""
|
||||||
|
模型注册表
|
||||||
|
用于解耦模型的调用和实现,支持插件式扩展新模型。
|
||||||
|
"""
|
||||||
|
|
||||||
|
# 训练器注册表
|
||||||
|
TRAINER_REGISTRY = {}
|
||||||
|
|
||||||
|
def register_trainer(name, func=None):
    """Register a training function under a model-type name.

    Can be called directly, ``register_trainer('xgboost', train_fn)``, or used
    as a decorator, ``@register_trainer('xgboost')``. Registering a name that
    already exists replaces the earlier entry and prints a warning.

    Args:
        name (str): Model type name (e.g. ``'xgboost'``).
        func (callable, optional): The training function. When omitted, a
            decorator that registers the decorated function is returned.

    Returns:
        The registered function, or a decorator when ``func`` is omitted.

    Raises:
        TypeError: If ``func`` is given but is not callable.
    """
    if func is None:
        def decorator(trainer):
            register_trainer(name, trainer)
            return trainer
        return decorator

    # Reject bad entries at registration time instead of when a training
    # job later tries to call them.
    if not callable(func):
        raise TypeError(f"trainer for '{name}' must be callable, got {type(func).__name__}")

    if name in TRAINER_REGISTRY:
        print(f"警告: 模型训练器 '{name}' 已被覆盖注册。")
    TRAINER_REGISTRY[name] = func
    print(f"✅ 已注册训练器: {name}")
    return func
|
||||||
|
|
||||||
|
def get_trainer(name):
    """Return the trainer registered for ``name``.

    On a miss, trainer discovery is triggered once and the lookup is retried,
    so the error message lists every trainer that could be loaded.

    Raises:
        ValueError: If no trainer is registered under ``name`` even after
            discovery has run.
    """
    if name in TRAINER_REGISTRY:
        return TRAINER_REGISTRY[name]

    # Imported here rather than at module level; the trainers package is only
    # needed on a registry miss.
    from trainers import discover_trainers
    discover_trainers()

    if name in TRAINER_REGISTRY:
        return TRAINER_REGISTRY[name]

    raise ValueError(f"未注册的模型训练器: '{name}'. 可用: {list(TRAINER_REGISTRY.keys())}")
|
||||||
|
|
||||||
|
# --- 预测器注册表 ---
|
||||||
|
|
||||||
|
# 预测器函数需要一个统一的接口,例如:
|
||||||
|
# def predictor_function(model, checkpoint, **kwargs): -> predictions
|
||||||
|
|
||||||
|
PREDICTOR_REGISTRY = {}
|
||||||
|
|
||||||
|
def register_predictor(name, func=None):
    """Register a prediction function under a model-type name.

    Can be called directly, ``register_predictor('default', predict_fn)``, or
    used as a decorator, ``@register_predictor('default')``. Registering a
    name that already exists replaces the earlier entry and prints a warning.

    Args:
        name (str): Model type name, or ``'default'`` for the fallback entry.
        func (callable, optional): The prediction function. When omitted, a
            decorator that registers the decorated function is returned.

    Returns:
        The registered function, or a decorator when ``func`` is omitted.

    Raises:
        TypeError: If ``func`` is given but is not callable.
    """
    if func is None:
        def decorator(predictor):
            register_predictor(name, predictor)
            return predictor
        return decorator

    # Reject bad entries at registration time instead of at prediction time.
    if not callable(func):
        raise TypeError(f"predictor for '{name}' must be callable, got {type(func).__name__}")

    if name in PREDICTOR_REGISTRY:
        print(f"警告: 模型预测器 '{name}' 已被覆盖注册。")
    PREDICTOR_REGISTRY[name] = func
    return func
|
||||||
|
|
||||||
|
def get_predictor(name):
    """Return the predictor registered for ``name``.

    Falls back to the entry registered as ``'default'`` when ``name`` has no
    predictor of its own, and to ``None`` when neither exists.
    """
    if name in PREDICTOR_REGISTRY:
        return PREDICTOR_REGISTRY[name]
    return PREDICTOR_REGISTRY.get('default')
|
||||||
|
|
||||||
|
# 默认的PyTorch预测逻辑可以被注册为 'default'
|
||||||
|
def register_default_predictors():
    """Register the generic PyTorch predictor as the ``'default'`` entry."""
    # Imported inside the function so this module does not import
    # predictors.model_predictor at load time.
    from predictors.model_predictor import default_pytorch_predictor as fallback_predictor

    register_predictor(name='default', func=fallback_predictor)
    # PyTorch models that need their own prediction logic can be registered
    # here as well, e.g.:
    # register_predictor('kan', kan_predictor_func)
|
||||||
|
|
||||||
|
# 注意:这个函数的调用时机很重要,需要在应用启动时执行一次。
|
||||||
|
# 我们可以暂时在 model_predictor.py 导入注册表后调用它。
|
Binary file not shown.
@ -18,39 +18,90 @@ from models.mlstm_model import MLSTMTransformer as MatrixLSTM
|
|||||||
from models.kan_model import KANForecaster
|
from models.kan_model import KANForecaster
|
||||||
from models.tcn_model import TCNForecaster
|
from models.tcn_model import TCNForecaster
|
||||||
from models.optimized_kan_forecaster import OptimizedKANForecaster
|
from models.optimized_kan_forecaster import OptimizedKANForecaster
|
||||||
|
from models.cnn_bilstm_attention import CnnBiLstmAttention
|
||||||
|
import xgboost as xgb
|
||||||
|
|
||||||
from analysis.trend_analysis import analyze_prediction_result
|
from analysis.trend_analysis import analyze_prediction_result
|
||||||
from utils.visualization import plot_prediction_results
|
from utils.visualization import plot_prediction_results
|
||||||
from utils.multi_store_data_utils import get_store_product_sales_data, aggregate_multi_store_data
|
from utils.multi_store_data_utils import get_store_product_sales_data, aggregate_multi_store_data
|
||||||
from core.config import DEVICE, get_model_file_path, DEFAULT_DATA_PATH
|
from core.config import DEVICE, get_model_file_path, DEFAULT_DATA_PATH
|
||||||
|
from models.model_registry import get_predictor, register_predictor
|
||||||
|
|
||||||
|
def default_pytorch_predictor(model, checkpoint, product_df, future_days, start_date, history_lookback_days):
    """Forecast daily sales autoregressively, one step at a time.

    The most recent ``sequence_length`` rows before the forecast start form
    the input window. Each iteration predicts the next day, appends it to the
    window as a synthetic row and drops the oldest row.

    Args:
        model: A callable PyTorch model or an ``xgb.Booster``.
        checkpoint: Mapping with ``'config'``, ``'scaler_X'`` and
            ``'scaler_y'``. ``config`` must hold ``'sequence_length'`` and may
            hold ``'features'``.
        product_df: Historical data with a ``'date'`` column and the feature
            columns; presumably sorted by ascending date -- confirm against
            the data loaders.
        future_days: Number of days to forecast.
        start_date: First day to forecast. If falsy, forecasting starts the
            day after the last row of ``product_df``.
        history_lookback_days: Number of history rows returned for charting.

    Returns:
        Tuple ``(predictions_df, history_for_chart_df, prediction_input_df)``.
        ``predictions_df`` has the columns ``date`` and ``predicted_sales``.

    Raises:
        ValueError: If fewer than ``sequence_length`` history rows exist
            before the forecast start.
    """
    config = checkpoint['config']
    scaler_X = checkpoint['scaler_X']
    scaler_y = checkpoint['scaler_y']
    features = config.get('features', ['sales', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature'])
    sequence_length = config['sequence_length']

    if start_date:
        start_date_dt = pd.to_datetime(start_date)
        prediction_input_df = product_df[product_df['date'] < start_date_dt].tail(sequence_length)
    else:
        prediction_input_df = product_df.tail(sequence_length)
        start_date_dt = product_df['date'].iloc[-1] + timedelta(days=1)

    if len(prediction_input_df) < sequence_length:
        raise ValueError(f"预测所需的历史数据不足。需要 {sequence_length} 天, 但只有 {len(prediction_input_df)} 天。")

    history_for_chart_df = product_df[product_df['date'] < start_date_dt].tail(history_lookback_days)

    # The model type cannot change between steps, so decide once.
    is_xgboost = isinstance(model, xgb.Booster)
    xgb_predict_kwargs = {}
    if is_xgboost:
        # best_iteration only exists when early stopping was used; without it
        # predict() runs on all trees. iteration_range is half-open, so the
        # end must be best_iteration + 1 to include the best tree itself.
        best_iteration = getattr(model, 'best_iteration', None)
        if best_iteration is not None:
            xgb_predict_kwargs['iteration_range'] = (0, int(best_iteration) + 1)

    all_predictions = []
    current_sequence_df = prediction_input_df.copy()

    for _ in range(future_days):
        X_current_scaled = scaler_X.transform(current_sequence_df[features].values)

        if is_xgboost:
            # Assumes the booster was trained on flattened windows
            # (sequence_length * n_features columns) -- confirm against the
            # XGBoost trainer.
            d_input = xgb.DMatrix(X_current_scaled.reshape(1, -1))
            y_pred_scaled = model.predict(d_input, **xgb_predict_kwargs)
            next_step_pred_scaled = y_pred_scaled.reshape(1, -1)
        else:
            X_input = torch.tensor(X_current_scaled.reshape(1, sequence_length, -1), dtype=torch.float32).to(DEVICE)
            with torch.no_grad():
                y_pred_scaled = model(X_input).cpu().numpy()
            # The model may emit several horizon steps; only the first is used.
            next_step_pred_scaled = y_pred_scaled[0, 0].reshape(1, -1)

        # Sales cannot be negative; also convert the numpy scalar to a float.
        next_step_pred_unscaled = float(max(0, scaler_y.inverse_transform(next_step_pred_scaled)[0][0]))

        next_date = current_sequence_df['date'].iloc[-1] + timedelta(days=1)
        all_predictions.append({'date': next_date, 'predicted_sales': next_step_pred_unscaled})

        # Synthetic row for the predicted day: calendar fields come from the
        # date, holiday/promotion flags are set to 0 and the last known
        # temperature is reused.
        new_row = {'date': next_date, 'sales': next_step_pred_unscaled, 'weekday': next_date.weekday(), 'month': next_date.month, 'is_holiday': 0, 'is_weekend': 1 if next_date.weekday() >= 5 else 0, 'is_promotion': 0, 'temperature': current_sequence_df['temperature'].iloc[-1]}
        new_row_df = pd.DataFrame([new_row])
        current_sequence_df = pd.concat([current_sequence_df.iloc[1:], new_row_df], ignore_index=True)

    # Explicit columns keep the frame's schema stable when future_days is 0.
    predictions_df = pd.DataFrame(all_predictions, columns=['date', 'predicted_sales'])
    return predictions_df, history_for_chart_df, prediction_input_df
|
||||||
|
|
||||||
|
# 注册默认的PyTorch预测器
|
||||||
|
register_predictor('default', default_pytorch_predictor)
|
||||||
|
# 将增强后的默认预测器也注册给xgboost
|
||||||
|
register_predictor('xgboost', default_pytorch_predictor)
|
||||||
|
# 将新模型也注册给默认预测器
|
||||||
|
register_predictor('cnn_bilstm_attention', default_pytorch_predictor)
|
||||||
|
|
||||||
|
|
||||||
def load_model_and_predict(model_path: str, product_id: str, model_type: str, store_id: Optional[str] = None, future_days: int = 7, start_date: Optional[str] = None, analyze_result: bool = False, version: Optional[str] = None, training_mode: str = 'product', history_lookback_days: int = 30):
|
def load_model_and_predict(model_path: str, product_id: str, model_type: str, store_id: Optional[str] = None, future_days: int = 7, start_date: Optional[str] = None, analyze_result: bool = False, version: Optional[str] = None, training_mode: str = 'product', history_lookback_days: int = 30):
|
||||||
"""
|
"""
|
||||||
加载已训练的模型并进行预测 (v3版 - 支持自动回归)
|
加载已训练的模型并进行预测 (v4版 - 插件式架构)
|
||||||
|
|
||||||
参数:
|
|
||||||
... (同上, 新增 history_lookback_days)
|
|
||||||
history_lookback_days: 用于图表展示的历史数据天数
|
|
||||||
|
|
||||||
返回:
|
|
||||||
预测结果和分析
|
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
print(f"v3版预测函数启动,模型路径: {model_path}, 预测天数: {future_days}, 历史回看: {history_lookback_days}")
|
|
||||||
|
|
||||||
if not os.path.exists(model_path):
|
if not os.path.exists(model_path):
|
||||||
print(f"模型文件 {model_path} 不存在")
|
raise FileNotFoundError(f"模型文件 {model_path} 不存在")
|
||||||
return None
|
|
||||||
|
# --- 数据加载部分保持不变 ---
|
||||||
# 加载销售数据
|
|
||||||
from utils.multi_store_data_utils import aggregate_multi_store_data
|
from utils.multi_store_data_utils import aggregate_multi_store_data
|
||||||
if training_mode == 'store' and store_id:
|
if training_mode == 'store' and store_id:
|
||||||
# 先从原始数据加载一次以获取店铺名称,聚合会丢失此信息
|
|
||||||
from utils.multi_store_data_utils import load_multi_store_data
|
from utils.multi_store_data_utils import load_multi_store_data
|
||||||
store_df_for_name = load_multi_store_data(store_id=store_id)
|
store_df_for_name = load_multi_store_data(store_id=store_id)
|
||||||
product_name = store_df_for_name['store_name'].iloc[0] if not store_df_for_name.empty else f"店铺 {store_id}"
|
product_name = store_df_for_name['store_name'].iloc[0] if not store_df_for_name.empty else f"店铺 {store_id}"
|
||||||
|
|
||||||
# 然后再进行聚合获取用于预测的数据
|
|
||||||
product_df = aggregate_multi_store_data(store_id=store_id, aggregation_method='sum', file_path=DEFAULT_DATA_PATH)
|
product_df = aggregate_multi_store_data(store_id=store_id, aggregation_method='sum', file_path=DEFAULT_DATA_PATH)
|
||||||
elif training_mode == 'global':
|
elif training_mode == 'global':
|
||||||
product_df = aggregate_multi_store_data(aggregation_method='sum', file_path=DEFAULT_DATA_PATH)
|
product_df = aggregate_multi_store_data(aggregation_method='sum', file_path=DEFAULT_DATA_PATH)
|
||||||
@ -60,124 +111,75 @@ def load_model_and_predict(model_path: str, product_id: str, model_type: str, st
|
|||||||
product_name = product_df['product_name'].iloc[0] if not product_df.empty else product_id
|
product_name = product_df['product_name'].iloc[0] if not product_df.empty else product_id
|
||||||
|
|
||||||
if product_df.empty:
|
if product_df.empty:
|
||||||
print(f"产品 {product_id} 或店铺 {store_id} 没有销售数据")
|
raise ValueError(f"产品 {product_id} 或店铺 {store_id} 没有销售数据")
|
||||||
return None
|
|
||||||
|
|
||||||
# 加载模型和配置
|
# --- 模型加载与实例化 (重构) ---
|
||||||
try:
|
try:
|
||||||
torch.serialization.add_safe_globals([sklearn.preprocessing._data.MinMaxScaler])
|
torch.serialization.add_safe_globals([sklearn.preprocessing._data.MinMaxScaler])
|
||||||
except Exception: pass
|
except Exception: pass
|
||||||
|
|
||||||
checkpoint = torch.load(model_path, map_location=DEVICE, weights_only=False)
|
checkpoint = torch.load(model_path, map_location=DEVICE, weights_only=False)
|
||||||
if 'config' not in checkpoint or 'scaler_X' not in checkpoint or 'scaler_y' not in checkpoint:
|
config = checkpoint.get('config', {})
|
||||||
print("模型文件不完整,缺少config或scaler")
|
loaded_model_type = config.get('model_type', model_type) # 优先使用模型内保存的类型
|
||||||
return None
|
|
||||||
|
# 根据模型类型决定如何获取模型实例
|
||||||
config = checkpoint['config']
|
if loaded_model_type == 'xgboost':
|
||||||
scaler_X = checkpoint['scaler_X']
|
# 对于XGBoost, 模型对象直接保存在'model_state_dict'键中
|
||||||
scaler_y = checkpoint['scaler_y']
|
model = checkpoint['model_state_dict']
|
||||||
|
|
||||||
# 创建模型实例
|
|
||||||
# (此处省略了与原版本相同的模型创建代码,以保持简洁)
|
|
||||||
if model_type == 'transformer':
|
|
||||||
model = TimeSeriesTransformer(num_features=config['input_dim'], d_model=config['hidden_size'], nhead=config['num_heads'], num_encoder_layers=config['num_layers'], dim_feedforward=config['hidden_size'] * 2, dropout=config['dropout'], output_sequence_length=config['output_dim'], seq_length=config['sequence_length'], batch_size=32).to(DEVICE)
|
|
||||||
elif model_type == 'mlstm':
|
|
||||||
model = MatrixLSTM(num_features=config['input_dim'], hidden_size=config['hidden_size'], mlstm_layers=config['mlstm_layers'], embed_dim=config.get('embed_dim', 32), dense_dim=config.get('dense_dim', 32), num_heads=config.get('num_heads', 4), dropout_rate=config['dropout_rate'], num_blocks=config.get('num_blocks', 3), output_sequence_length=config['output_dim']).to(DEVICE)
|
|
||||||
elif model_type == 'kan':
|
|
||||||
model = KANForecaster(input_features=config['input_dim'], hidden_sizes=[config['hidden_size'], config['hidden_size']*2, config['hidden_size']], output_sequence_length=config['output_dim']).to(DEVICE)
|
|
||||||
elif model_type == 'optimized_kan':
|
|
||||||
model = OptimizedKANForecaster(input_features=config['input_dim'], hidden_sizes=[config['hidden_size'], config['hidden_size']*2, config['hidden_size']], output_sequence_length=config['output_dim']).to(DEVICE)
|
|
||||||
elif model_type == 'tcn':
|
|
||||||
model = TCNForecaster(num_features=config['input_dim'], output_sequence_length=config['output_dim'], num_channels=[config['hidden_size']] * config['num_layers'], kernel_size=config['kernel_size'], dropout=config['dropout']).to(DEVICE)
|
|
||||||
else:
|
else:
|
||||||
print(f"不支持的模型类型: {model_type}"); return None
|
# 对于PyTorch模型, 需要重新构建实例并加载state_dict
|
||||||
|
if loaded_model_type == 'transformer':
|
||||||
model.load_state_dict(checkpoint['model_state_dict'])
|
model = TimeSeriesTransformer(num_features=config['input_dim'], d_model=config['hidden_size'], nhead=config['num_heads'], num_encoder_layers=config['num_layers'], dim_feedforward=config['hidden_size'] * 2, dropout=config['dropout'], output_sequence_length=config['output_dim'], seq_length=config['sequence_length'], batch_size=32).to(DEVICE)
|
||||||
model.eval()
|
elif loaded_model_type == 'mlstm':
|
||||||
|
model = MatrixLSTM(num_features=config['input_dim'], hidden_size=config['hidden_size'], mlstm_layers=config['mlstm_layers'], embed_dim=config.get('embed_dim', 32), dense_dim=config.get('dense_dim', 32), num_heads=config.get('num_heads', 4), dropout_rate=config['dropout_rate'], num_blocks=config.get('num_blocks', 3), output_sequence_length=config['output_dim']).to(DEVICE)
|
||||||
# --- 核心逻辑修改:自动回归预测 ---
|
elif loaded_model_type == 'kan':
|
||||||
features = ['sales', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']
|
model = KANForecaster(input_features=config['input_dim'], hidden_sizes=[config['hidden_size'], config['hidden_size']*2, config['hidden_size']], output_sequence_length=config['output_dim']).to(DEVICE)
|
||||||
sequence_length = config['sequence_length']
|
elif loaded_model_type == 'optimized_kan':
|
||||||
|
model = OptimizedKANForecaster(input_features=config['input_dim'], hidden_sizes=[config['hidden_size'], config['hidden_size']*2, config['hidden_size']], output_sequence_length=config['output_dim']).to(DEVICE)
|
||||||
# 确定预测的起始点
|
elif loaded_model_type == 'tcn':
|
||||||
if start_date:
|
model = TCNForecaster(num_features=config['input_dim'], output_sequence_length=config['output_dim'], num_channels=[config['hidden_size']] * config['num_layers'], kernel_size=config['kernel_size'], dropout=config['dropout']).to(DEVICE)
|
||||||
start_date_dt = pd.to_datetime(start_date)
|
elif loaded_model_type == 'cnn_bilstm_attention':
|
||||||
# 获取预测开始日期前的 `sequence_length` 天数据作为初始输入
|
model = CnnBiLstmAttention(
|
||||||
prediction_input_df = product_df[product_df['date'] < start_date_dt].tail(sequence_length)
|
input_dim=config['input_dim'],
|
||||||
else:
|
output_dim=config['output_dim'],
|
||||||
# 如果未指定开始日期,则从数据的最后一天开始预测
|
sequence_length=config['sequence_length']
|
||||||
prediction_input_df = product_df.tail(sequence_length)
|
).to(DEVICE)
|
||||||
start_date_dt = product_df['date'].iloc[-1] + timedelta(days=1)
|
else:
|
||||||
|
raise ValueError(f"不支持的模型类型: {loaded_model_type}")
|
||||||
if len(prediction_input_df) < sequence_length:
|
|
||||||
print(f"错误: 预测所需的历史数据不足。需要 {sequence_length} 天, 但只有 {len(prediction_input_df)} 天。")
|
|
||||||
return None
|
|
||||||
|
|
||||||
# 准备用于图表展示的历史数据
|
|
||||||
history_for_chart_df = product_df[product_df['date'] < start_date_dt].tail(history_lookback_days)
|
|
||||||
|
|
||||||
# 自动回归预测循环
|
|
||||||
all_predictions = []
|
|
||||||
current_sequence_df = prediction_input_df.copy()
|
|
||||||
|
|
||||||
print(f"开始自动回归预测,共 {future_days} 天...")
|
|
||||||
for i in range(future_days):
|
|
||||||
# 准备当前序列的输入张量
|
|
||||||
X_current_scaled = scaler_X.transform(current_sequence_df[features].values)
|
|
||||||
X_input = torch.tensor(X_current_scaled.reshape(1, sequence_length, -1), dtype=torch.float32).to(DEVICE)
|
|
||||||
|
|
||||||
# 模型进行一次预测(可能预测出多个点,但我们只用第一个)
|
|
||||||
with torch.no_grad():
|
|
||||||
y_pred_scaled = model(X_input).cpu().numpy()
|
|
||||||
|
|
||||||
# 提取下一个时间点的预测值
|
model.load_state_dict(checkpoint['model_state_dict'])
|
||||||
next_step_pred_scaled = y_pred_scaled[0, 0].reshape(1, -1)
|
model.eval()
|
||||||
next_step_pred_unscaled = scaler_y.inverse_transform(next_step_pred_scaled)[0][0]
|
|
||||||
next_step_pred_unscaled = float(max(0, next_step_pred_unscaled)) # 确保销量不为负,并转换为标准float
|
|
||||||
|
|
||||||
# 获取新预测的日期
|
# --- 动态调用预测器 ---
|
||||||
next_date = current_sequence_df['date'].iloc[-1] + timedelta(days=1)
|
predictor_function = get_predictor(loaded_model_type)
|
||||||
all_predictions.append({'date': next_date, 'predicted_sales': next_step_pred_unscaled})
|
if not predictor_function:
|
||||||
|
raise ValueError(f"找不到模型类型 '{loaded_model_type}' 的预测器实现")
|
||||||
|
|
||||||
# 构建新的一行数据,用于更新输入序列
|
predictions_df, history_for_chart_df, prediction_input_df = predictor_function(
|
||||||
new_row = {
|
model=model,
|
||||||
'date': next_date,
|
checkpoint=checkpoint,
|
||||||
'sales': next_step_pred_unscaled,
|
product_df=product_df,
|
||||||
'weekday': next_date.weekday(),
|
future_days=future_days,
|
||||||
'month': next_date.month,
|
start_date=start_date,
|
||||||
'is_holiday': 0,
|
history_lookback_days=history_lookback_days
|
||||||
'is_weekend': 1 if next_date.weekday() >= 5 else 0,
|
)
|
||||||
'is_promotion': 0,
|
|
||||||
'temperature': current_sequence_df['temperature'].iloc[-1] # 沿用最后一天的温度
|
|
||||||
}
|
|
||||||
|
|
||||||
# 更新序列:移除最旧的一行,添加最新预测的一行
|
|
||||||
new_row_df = pd.DataFrame([new_row])
|
|
||||||
current_sequence_df = pd.concat([current_sequence_df.iloc[1:], new_row_df], ignore_index=True)
|
|
||||||
|
|
||||||
predictions_df = pd.DataFrame(all_predictions)
|
# --- 分析与返回部分保持不变 ---
|
||||||
print(f"自动回归预测完成,生成 {len(predictions_df)} 条预测数据。")
|
|
||||||
|
|
||||||
# 分析与可视化
|
|
||||||
analysis = None
|
analysis = None
|
||||||
if analyze_result:
|
if analyze_result:
|
||||||
try:
|
try:
|
||||||
y_pred_for_analysis = predictions_df['predicted_sales'].values
|
analysis = analyze_prediction_result(product_id, loaded_model_type, predictions_df['predicted_sales'].values, prediction_input_df[config.get('features')].values)
|
||||||
# 使用初始输入序列的特征进行分析
|
|
||||||
initial_features_for_analysis = prediction_input_df[features].values
|
|
||||||
analysis = analyze_prediction_result(product_id, model_type, y_pred_for_analysis, initial_features_for_analysis)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"分析预测结果失败: {str(e)}")
|
print(f"分析预测结果失败: {str(e)}")
|
||||||
|
|
||||||
# 在返回前,将DataFrame转换为前端期望的JSON数组格式
|
|
||||||
history_data_json = history_for_chart_df.to_dict('records') if not history_for_chart_df.empty else []
|
history_data_json = history_for_chart_df.to_dict('records') if not history_for_chart_df.empty else []
|
||||||
prediction_data_json = predictions_df.to_dict('records') if not predictions_df.empty else []
|
prediction_data_json = predictions_df.to_dict('records') if not predictions_df.empty else []
|
||||||
|
|
||||||
return {
|
return {
|
||||||
'product_id': product_id,
|
'product_id': product_id,
|
||||||
'product_name': product_name,
|
'product_name': product_name,
|
||||||
'model_type': model_type,
|
'model_type': loaded_model_type,
|
||||||
'predictions': prediction_data_json, # 兼容旧字段,使用已转换的json
|
'predictions': prediction_data_json,
|
||||||
'prediction_data': prediction_data_json,
|
'prediction_data': prediction_data_json,
|
||||||
'history_data': history_data_json,
|
'history_data': history_data_json,
|
||||||
'analysis': analysis
|
'analysis': analysis
|
||||||
|
@ -2,18 +2,44 @@
|
|||||||
药店销售预测系统 - 模型训练模块
|
药店销售预测系统 - 模型训练模块
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from .mlstm_trainer import train_product_model_with_mlstm
|
import os
|
||||||
from .kan_trainer import train_product_model_with_kan
|
import glob
|
||||||
from .tcn_trainer import train_product_model_with_tcn
|
import importlib
|
||||||
from .transformer_trainer import train_product_model_with_transformer
|
|
||||||
|
|
||||||
# 默认训练函数
|
_TRAINERS_LOADED = False
|
||||||
from .mlstm_trainer import train_product_model_with_mlstm as train_product_model
|
|
||||||
|
def discover_trainers():
    """Discover and import every trainer plugin in this package, once.

    Each file in this package's directory whose name matches
    ``*_trainer.py`` is imported as a submodule of this package.  Importing
    is the whole point: a trainer module is expected to register itself as a
    side effect of being imported.

    The module-level ``_TRAINERS_LOADED`` flag is set after the first run so
    that later calls return immediately.

    A plugin that fails to import is reported on stdout and skipped; it never
    prevents the remaining plugins (or this package) from loading.

    Returns:
        None
    """
    global _TRAINERS_LOADED
    if _TRAINERS_LOADED:
        return

    print("🚀 开始发现并加载训练器插件...")

    package_dir = os.path.dirname(__file__)
    pattern = os.path.join(package_dir, "*_trainer.py")

    # glob() yields paths in arbitrary, filesystem-dependent order; sorting
    # makes the registration order reproducible across machines.
    for path in sorted(glob.glob(pattern)):
        base_name = os.path.basename(path)
        if base_name.startswith('__'):
            continue

        # splitext strips only the trailing extension, whereas
        # str.replace('.py', '') would rewrite every occurrence in the name.
        module_stem = os.path.splitext(base_name)[0]

        try:
            # Dynamic import triggers the module's self-registration.
            importlib.import_module(f".{module_stem}", package=__name__)
        except Exception as e:
            # Plugin boundary: a broken trainer may fail with more than
            # ImportError (SyntaxError, errors raised by its own top-level
            # code, ...).  Report and continue so one bad plugin cannot take
            # the whole package down.
            print(f"⚠️ 加载训练器 {module_stem} 失败: {e}")

    _TRAINERS_LOADED = True
    print("✅ 所有训练器插件加载完成。")
|
||||||
|
|
||||||
|
# 在包被首次导入时,自动执行发现过程
|
||||||
|
discover_trainers()
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
'train_product_model',
|
|
||||||
'train_product_model_with_mlstm',
|
|
||||||
'train_product_model_with_kan',
|
|
||||||
'train_product_model_with_tcn',
|
|
||||||
'train_product_model_with_transformer'
|
|
||||||
]
|
|
||||||
|
118
server/trainers/cnn_bilstm_attention_trainer.py
Normal file
118
server/trainers/cnn_bilstm_attention_trainer.py
Normal file
@ -0,0 +1,118 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
CNN-BiLSTM-Attention 模型训练器
|
||||||
|
"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.optim as optim
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from models.model_registry import register_trainer
|
||||||
|
from utils.model_manager import model_manager
|
||||||
|
from analysis.metrics import evaluate_model
|
||||||
|
from utils.data_utils import create_dataset
|
||||||
|
from sklearn.preprocessing import MinMaxScaler
|
||||||
|
|
||||||
|
# 导入新创建的模型
|
||||||
|
from models.cnn_bilstm_attention import CnnBiLstmAttention
|
||||||
|
|
||||||
|
def train_with_cnn_bilstm_attention(product_id, model_identifier, product_df, store_id, training_mode, aggregation_method, epochs, sequence_length, forecast_horizon, model_dir, **kwargs):
    """Train, evaluate and save a CNN-BiLSTM-Attention forecasting model.

    The signature follows the shared trainer signature so the function can be
    invoked through the trainer registry.

    Args:
        product_id: Product identifier, forwarded to ``model_manager.save_model``.
        model_identifier: Label used in log output and as the fallback
            product name when ``product_df`` has no ``product_name`` column.
        product_df: DataFrame containing the columns listed in ``features``
            below (``sales`` is also the prediction target).
        store_id: Forwarded to ``model_manager.save_model``.
        training_mode: Forwarded to ``model_manager.save_model``.
        aggregation_method: Forwarded to ``model_manager.save_model``.
        epochs: Number of full-batch optimisation steps.
        sequence_length: Window length passed to ``create_dataset`` and to the
            model constructor.
        forecast_horizon: Horizon passed to ``create_dataset``; also the
            model's ``output_dim``.
        model_dir: Accepted for signature compatibility; not used here.
        **kwargs: ``learning_rate`` (default ``0.001``) is read; everything
            else is ignored.

    Returns:
        Tuple ``(model, metrics, final_version, final_model_path)``.

    Raises:
        ValueError: If ``product_df`` is empty or has fewer than
            ``sequence_length + forecast_horizon`` rows.
    """
    print(f"🚀 CNN-BiLSTM-Attention 训练器启动: model_identifier='{model_identifier}'")

    # --- 1. Data validation and preparation ---
    # Fail early with a clear message instead of an opaque KeyError / shape
    # error further down (same checks as the XGBoost trainer).
    if product_df.empty:
        raise ValueError(f"产品 {product_id} 没有可用的销售数据")

    min_required_samples = sequence_length + forecast_horizon
    if len(product_df) < min_required_samples:
        error_msg = (f"数据不足: 需要 {min_required_samples} 条, 实际 {len(product_df)} 条。")
        raise ValueError(error_msg)

    # The product_name column is optional; fall back to the identifier.
    if 'product_name' in product_df.columns:
        product_name = product_df['product_name'].iloc[0]
    else:
        product_name = model_identifier

    features = ['sales', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']

    X = product_df[features].values
    y = product_df[['sales']].values

    scaler_X = MinMaxScaler(feature_range=(0, 1))
    scaler_y = MinMaxScaler(feature_range=(0, 1))

    X_scaled = scaler_X.fit_transform(X)
    y_scaled = scaler_y.fit_transform(y)

    # Chronological 80/20 split (rows are used in the order they arrive).
    train_size = int(len(X_scaled) * 0.8)
    X_train_raw, X_test_raw = X_scaled[:train_size], X_scaled[train_size:]
    y_train_raw, y_test_raw = y_scaled[:train_size], y_scaled[train_size:]

    trainX, trainY = create_dataset(X_train_raw, y_train_raw, sequence_length, forecast_horizon)
    testX, testY = create_dataset(X_test_raw, y_test_raw, sequence_length, forecast_horizon)

    # Convert to PyTorch tensors.
    trainX = torch.from_numpy(trainX).float()
    trainY = torch.from_numpy(trainY).float()
    testX = torch.from_numpy(testX).float()
    testY = torch.from_numpy(testY).float()

    # --- 2. Model and optimiser ---
    input_dim = trainX.shape[2]

    model = CnnBiLstmAttention(
        input_dim=input_dim,
        output_dim=forecast_horizon,
        sequence_length=sequence_length
    )

    optimizer = optim.Adam(model.parameters(), lr=kwargs.get('learning_rate', 0.001))
    criterion = torch.nn.MSELoss()

    # --- 3. Training loop (one full-batch step per epoch) ---
    print("开始训练 CNN-BiLSTM-Attention 模型...")
    for epoch in range(epochs):
        model.train()
        optimizer.zero_grad()

        outputs = model(trainX)
        # squeeze(-1) drops a trailing singleton dimension from the target so
        # it lines up with the model output.
        loss = criterion(outputs, trainY.squeeze(-1))

        loss.backward()
        optimizer.step()

        if (epoch + 1) % 10 == 0:
            print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}')

    # --- 4. Evaluation on the held-out windows ---
    model.eval()
    with torch.no_grad():
        test_pred_scaled = model(testX)

    # Undo target scaling so metrics are reported in original sales units.
    test_pred_unscaled = scaler_y.inverse_transform(test_pred_scaled.numpy())
    test_true_unscaled = scaler_y.inverse_transform(testY.squeeze(-1).numpy())

    metrics = evaluate_model(test_true_unscaled.flatten(), test_pred_unscaled.flatten())
    print(f"模型评估完成: RMSE={metrics['rmse']:.4f}")

    # --- 5. Persist weights, scalers, config and metrics ---
    model_data = {
        'model_state_dict': model.state_dict(),
        'scaler_X': scaler_X,
        'scaler_y': scaler_y,
        'config': {
            'model_type': 'cnn_bilstm_attention',
            'input_dim': input_dim,
            'output_dim': forecast_horizon,
            'sequence_length': sequence_length,
            'features': features
        },
        'metrics': metrics
    }

    final_model_path, final_version = model_manager.save_model(
        model_data=model_data,
        product_id=product_id,
        model_type='cnn_bilstm_attention',
        store_id=store_id,
        training_mode=training_mode,
        aggregation_method=aggregation_method,
        product_name=product_name
    )

    print(f"✅ CNN-BiLSTM-Attention 模型已保存,版本: {final_version}")
    return model, metrics, final_version, final_model_path
|
||||||
|
|
||||||
|
# --- 关键步骤: 将训练器注册到系统中 ---
|
||||||
|
register_trainer('cnn_bilstm_attention', train_with_cnn_bilstm_attention)
|
@ -349,4 +349,9 @@ def train_product_model_with_kan(product_id, model_identifier, product_df=None,
|
|||||||
|
|
||||||
print(f"最终模型已保存,版本: {final_version}, 路径: {final_model_path}")
|
print(f"最终模型已保存,版本: {final_version}, 路径: {final_model_path}")
|
||||||
|
|
||||||
return model, metrics
|
return model, metrics
|
||||||
|
|
||||||
|
# --- 将此训练器注册到系统中 ---
|
||||||
|
from models.model_registry import register_trainer
|
||||||
|
register_trainer('kan', train_product_model_with_kan)
|
||||||
|
register_trainer('optimized_kan', train_product_model_with_kan)
|
@ -514,4 +514,8 @@ def train_product_model_with_mlstm(
|
|||||||
|
|
||||||
emit_progress(f"✅ mLSTM模型训练完成!版本 {final_version} 已保存", progress=100, metrics=final_metrics)
|
emit_progress(f"✅ mLSTM模型训练完成!版本 {final_version} 已保存", progress=100, metrics=final_metrics)
|
||||||
|
|
||||||
return model, metrics, epochs, final_model_path
|
return model, metrics, epochs, final_model_path
|
||||||
|
|
||||||
|
# --- 将此训练器注册到系统中 ---
|
||||||
|
from models.model_registry import register_trainer
|
||||||
|
register_trainer('mlstm', train_product_model_with_mlstm)
|
@ -379,4 +379,8 @@ def train_product_model_with_tcn(
|
|||||||
|
|
||||||
emit_progress(f"模型训练完成!版本 {final_version} 已保存", progress=100, metrics=final_metrics)
|
emit_progress(f"模型训练完成!版本 {final_version} 已保存", progress=100, metrics=final_metrics)
|
||||||
|
|
||||||
return model, metrics, epochs, final_model_path
|
return model, metrics, epochs, final_model_path
|
||||||
|
|
||||||
|
# --- 将此训练器注册到系统中 ---
|
||||||
|
from models.model_registry import register_trainer
|
||||||
|
register_trainer('tcn', train_product_model_with_tcn)
|
@ -406,4 +406,8 @@ def train_product_model_with_transformer(
|
|||||||
'version': final_version
|
'version': final_version
|
||||||
}
|
}
|
||||||
|
|
||||||
return model, final_metrics, epochs
|
return model, final_metrics, epochs
|
||||||
|
|
||||||
|
# --- 将此训练器注册到系统中 ---
|
||||||
|
from models.model_registry import register_trainer
|
||||||
|
register_trainer('transformer', train_product_model_with_transformer)
|
142
server/trainers/xgboost_trainer.py
Normal file
142
server/trainers/xgboost_trainer.py
Normal file
@ -0,0 +1,142 @@
|
|||||||
|
"""
|
||||||
|
药店销售预测系统 - XGBoost 模型训练器 (插件式)
|
||||||
|
"""
|
||||||
|
|
||||||
|
import time
|
||||||
|
import pandas as pd
|
||||||
|
import numpy as np
|
||||||
|
import xgboost as xgb
|
||||||
|
from sklearn.preprocessing import MinMaxScaler
|
||||||
|
from xgboost.callback import EarlyStopping
|
||||||
|
|
||||||
|
# 导入核心工具
|
||||||
|
from utils.data_utils import create_dataset
|
||||||
|
from analysis.metrics import evaluate_model
|
||||||
|
from utils.model_manager import model_manager
|
||||||
|
from models.model_registry import register_trainer
|
||||||
|
|
||||||
|
def train_product_model_with_xgboost(product_id, model_identifier, product_df, store_id, training_mode, aggregation_method, epochs, sequence_length, forecast_horizon, model_dir, **kwargs):
    """Train, evaluate and save an XGBoost sales-forecasting model.

    The signature matches the other trainers so the function can be invoked
    through the trainer registry.

    Args:
        product_id: Product identifier, used in error messages and forwarded
            to ``model_manager.save_model``.
        model_identifier: Label used in log output and as the fallback
            product name when ``product_df`` has no ``product_name`` column.
        product_df: DataFrame with a ``date`` column (used for sorting) and
            the columns listed in ``features`` below (``sales`` is also the
            prediction target).
        store_id: Forwarded to ``model_manager.save_model``.
        training_mode: Forwarded to ``model_manager.save_model``.
        aggregation_method: Forwarded to ``model_manager.save_model``.
        epochs: Accepted for signature compatibility; not used here (the
            number of boosting rounds comes from ``n_estimators``).
        sequence_length: Window length passed to ``create_dataset``.
        forecast_horizon: Horizon passed to ``create_dataset``.
        model_dir: Accepted for signature compatibility; not used here.
        **kwargs: Optional hyper-parameters ``learning_rate`` (0.08),
            ``subsample`` (0.75), ``colsample_bytree`` (1), ``max_depth`` (7),
            ``gamma`` (0) and ``n_estimators`` (500).

    Returns:
        Tuple ``(model, metrics, final_version, final_model_path)`` where
        ``model`` is the trained booster returned by ``xgb.train``.

    Raises:
        ValueError: If ``product_df`` is empty, has fewer than
            ``sequence_length + forecast_horizon`` rows, or the train/test
            split leaves no windows to train or evaluate on.
    """
    print(f"🚀 XGBoost训练器启动: model_identifier='{model_identifier}'")

    # --- 1. Data validation ---
    if product_df.empty:
        raise ValueError(f"产品 {product_id} 没有可用的销售数据")

    min_required_samples = sequence_length + forecast_horizon
    if len(product_df) < min_required_samples:
        error_msg = (f"数据不足: 需要 {min_required_samples} 条, 实际 {len(product_df)} 条。")
        raise ValueError(error_msg)

    product_df = product_df.sort_values('date')
    product_name = product_df['product_name'].iloc[0] if 'product_name' in product_df.columns else model_identifier

    # --- 2. Pre-processing ---
    features = ['sales', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']

    X = product_df[features].values
    y = product_df[['sales']].values

    scaler_X = MinMaxScaler(feature_range=(0, 1))
    scaler_y = MinMaxScaler(feature_range=(0, 1))

    X_scaled = scaler_X.fit_transform(X)
    y_scaled = scaler_y.fit_transform(y)

    # Chronological 80/20 split.
    train_size = int(len(X_scaled) * 0.8)
    X_train_raw, X_test_raw = X_scaled[:train_size], X_scaled[train_size:]
    y_train_raw, y_test_raw = y_scaled[:train_size], y_scaled[train_size:]

    trainX, trainY = create_dataset(X_train_raw, y_train_raw, sequence_length, forecast_horizon)
    testX, testY = create_dataset(X_test_raw, y_test_raw, sequence_length, forecast_horizon)

    # The overall length check above does not guarantee that each split is
    # long enough to yield a window; report that clearly instead of failing
    # inside reshape / DMatrix.
    if len(trainX) == 0 or len(testX) == 0:
        raise ValueError(
            f"数据不足: 训练集或测试集无法生成任何样本窗口 "
            f"(训练样本 {len(trainX)}, 测试样本 {len(testX)})。"
        )

    # XGBoost needs 2-D input: flatten each window into a single row.
    trainX = trainX.reshape(trainX.shape[0], -1)
    testX = testX.reshape(testX.shape[0], -1)

    # Use the native DMatrix container required by the core xgb.train API.
    dtrain = xgb.DMatrix(trainX, label=trainY)
    dtest = xgb.DMatrix(testX, label=testY)

    # --- 3. Training (core xgb.train API) ---
    xgb_params = {
        'learning_rate': kwargs.get('learning_rate', 0.08),
        'subsample': kwargs.get('subsample', 0.75),
        'colsample_bytree': kwargs.get('colsample_bytree', 1),
        'max_depth': kwargs.get('max_depth', 7),
        'gamma': kwargs.get('gamma', 0),
        'objective': 'reg:squarederror',
        'eval_metric': 'rmse',
        'n_jobs': -1
    }
    n_estimators = kwargs.get('n_estimators', 500)

    print("开始训练XGBoost模型 (使用核心xgb.train API)...")
    start_time = time.time()

    evals_result = {}
    # NOTE(review): early stopping monitors the last entry of `evals`, i.e.
    # the same 'test' split that is used for the final metrics below, so the
    # reported metrics are optimistically biased. Consider a separate
    # validation split.
    model = xgb.train(
        params=xgb_params,
        dtrain=dtrain,
        num_boost_round=n_estimators,
        evals=[(dtrain, 'train'), (dtest, 'test')],
        early_stopping_rounds=50,
        evals_result=evals_result,
        verbose_eval=False
    )

    training_time = time.time() - start_time
    print(f"XGBoost模型训练完成,耗时: {training_time:.2f}秒")

    # --- 4. Evaluation ---
    # best_iteration is the 0-based index of the best boosting round and
    # iteration_range is half-open [begin, end), so the end must be
    # best_iteration + 1 to include the best round. (Without the +1 the best
    # tree is dropped, and best_iteration == 0 would yield (0, 0), which
    # XGBoost interprets as "use every tree".)
    test_pred = model.predict(dtest, iteration_range=(0, model.best_iteration + 1))

    # Undo target scaling so metrics are reported in original sales units.
    test_pred_inv = scaler_y.inverse_transform(test_pred.reshape(-1, forecast_horizon))
    test_true_inv = scaler_y.inverse_transform(testY.reshape(-1, forecast_horizon))

    metrics = evaluate_model(test_true_inv.flatten(), test_pred_inv.flatten())
    metrics['training_time'] = training_time

    print("\n模型评估指标:")
    print(f"MSE: {metrics['mse']:.4f}, RMSE: {metrics['rmse']:.4f}, MAE: {metrics['mae']:.4f}, R²: {metrics['r2']:.4f}, MAPE: {metrics['mape']:.2f}%")

    # --- 5. Persist through the shared model manager ---
    # The whole booster object is stored under the 'model_state_dict' key
    # (rather than a PyTorch state dict) so the shared save/versioning logic
    # can be reused unchanged.
    # NOTE(review): this presumably relies on the manager pickling arbitrary
    # Python objects — confirm, and only load such files from trusted sources.
    model_data = {
        'model_state_dict': model,
        'scaler_X': scaler_X,
        'scaler_y': scaler_y,
        'config': {
            'sequence_length': sequence_length,
            'forecast_horizon': forecast_horizon,
            'model_type': 'xgboost',
            'features': features,
            'xgb_params': xgb_params
        },
        'metrics': metrics,
        'loss_history': evals_result
    }

    final_model_path, final_version = model_manager.save_model(
        model_data=model_data,
        product_id=product_id,
        model_type='xgboost',
        store_id=store_id,
        training_mode=training_mode,
        aggregation_method=aggregation_method,
        product_name=product_name
    )

    print(f"XGBoost模型已通过统一管理器保存,版本: {final_version}, 路径: {final_model_path}")

    # Same return shape as the other registry-based trainers.
    return model, metrics, final_version, final_model_path
|
||||||
|
|
||||||
|
# --- 将此训练器注册到系统中 ---
|
||||||
|
register_trainer('xgboost', train_product_model_with_xgboost)
|
@ -280,48 +280,41 @@ class ModelManager:
|
|||||||
if len(parts) < 3:
|
if len(parts) < 3:
|
||||||
return None # 格式不符合基本要求
|
return None # 格式不符合基本要求
|
||||||
|
|
||||||
model_type = parts[0]
|
# **核心修复**: 采用更健壮的、从后往前的解析逻辑,以支持带下划线的模型名称
|
||||||
mode = parts[1]
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if mode == 'store' and len(parts) >= 3:
|
version = parts[-1]
|
||||||
# {model_type}_store_{store_id}_{version}
|
identifier = parts[-2]
|
||||||
version = parts[-1]
|
mode_candidate = parts[-3]
|
||||||
store_id = '_'.join(parts[2:-1])
|
|
||||||
return {
|
if mode_candidate == 'product':
|
||||||
'model_type': model_type,
|
model_type = '_'.join(parts[:-3])
|
||||||
'training_mode': 'store',
|
|
||||||
'store_id': store_id,
|
|
||||||
'version': version,
|
|
||||||
'product_id': None,
|
|
||||||
'aggregation_method': None
|
|
||||||
}
|
|
||||||
elif mode == 'global' and len(parts) >= 3:
|
|
||||||
# {model_type}_global_{aggregation_method}_{version}
|
|
||||||
version = parts[-1]
|
|
||||||
aggregation_method = '_'.join(parts[2:-1])
|
|
||||||
return {
|
|
||||||
'model_type': model_type,
|
|
||||||
'training_mode': 'global',
|
|
||||||
'aggregation_method': aggregation_method,
|
|
||||||
'version': version,
|
|
||||||
'product_id': None,
|
|
||||||
'store_id': None
|
|
||||||
}
|
|
||||||
elif mode == 'product' and len(parts) >= 3:
|
|
||||||
# {model_type}_product_{product_id}_{version}
|
|
||||||
version = parts[-1]
|
|
||||||
product_id = '_'.join(parts[2:-1])
|
|
||||||
return {
|
return {
|
||||||
'model_type': model_type,
|
'model_type': model_type,
|
||||||
'training_mode': 'product',
|
'training_mode': 'product',
|
||||||
'product_id': product_id,
|
'product_id': identifier,
|
||||||
'version': version,
|
'version': version,
|
||||||
'store_id': None,
|
|
||||||
'aggregation_method': None
|
|
||||||
}
|
}
|
||||||
|
elif mode_candidate == 'store':
|
||||||
|
model_type = '_'.join(parts[:-3])
|
||||||
|
return {
|
||||||
|
'model_type': model_type,
|
||||||
|
'training_mode': 'store',
|
||||||
|
'store_id': identifier,
|
||||||
|
'version': version,
|
||||||
|
}
|
||||||
|
elif mode_candidate == 'global':
|
||||||
|
model_type = '_'.join(parts[:-3])
|
||||||
|
return {
|
||||||
|
'model_type': model_type,
|
||||||
|
'training_mode': 'global',
|
||||||
|
'aggregation_method': identifier,
|
||||||
|
'version': version,
|
||||||
|
}
|
||||||
|
except IndexError:
|
||||||
|
# 如果文件名部分不够,则解析失败
|
||||||
|
pass
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"解析新版v2文件名失败 {filename}: {e}")
|
print(f"解析文件名失败 {filename}: {e}")
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
### 根目录启动
|
### 根目录启动
|
||||||
`uv pip install loguru numpy pandas torch matplotlib flask flask_cors flask_socketio flasgger scikit-learn tqdm pytorch_tcn`
|
`uv pip install loguru numpy pandas torch matplotlib flask flask_cors flask_socketio flasgger scikit-learn tqdm pytorch_tcn pyarrow`
|
||||||
|
|
||||||
### UI
|
### UI
|
||||||
`npm install` `npm run dev`
|
`npm install` `npm run dev`
|
||||||
|
222
xz新模型添加流程.md
Normal file
222
xz新模型添加流程.md
Normal file
@ -0,0 +1,222 @@
|
|||||||
|
# 如何向系统添加新模型
|
||||||
|
|
||||||
|
本指南详细说明了如何向本预测系统添加一个全新的预测模型。系统采用灵活的插件式架构,集成新模型的过程非常模块化,主要围绕 **模型(Model)**、**训练器(Trainer)** 和 **预测器(Predictor)** 这三个核心组件进行。
|
||||||
|
|
||||||
|
## 核心理念
|
||||||
|
|
||||||
|
系统的核心是 `models/model_registry.py`,它维护了两个独立的注册表:一个用于训练函数,另一个用于预测函数。添加新模型的本质就是:
|
||||||
|
|
||||||
|
1. **定义模型**:创建模型的架构。
|
||||||
|
2. **创建训练器**:编写一个函数来训练这个模型,并将其注册到训练器注册表。
|
||||||
|
3. **集成预测器**:确保系统知道如何加载模型并用它来预测,然后将预测逻辑注册到预测器注册表。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 第 1 步:定义模型架构
|
||||||
|
|
||||||
|
首先,您需要在 `ShopTRAINING/server/models/` 目录下创建一个新的 Python 文件来定义您的模型。
|
||||||
|
|
||||||
|
**示例:创建 `ShopTRAINING/server/models/my_new_model.py`**
|
||||||
|
|
||||||
|
如果您的新模型是基于 PyTorch 的,它应该是一个继承自 `torch.nn.Module` 的类。
|
||||||
|
|
||||||
|
```python
|
||||||
|
# file: ShopTRAINING/server/models/my_new_model.py
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
class MyNewModel(nn.Module):
|
||||||
|
def __init__(self, input_features, hidden_size, output_sequence_length):
|
||||||
|
"""
|
||||||
|
定义模型的层和结构。
|
||||||
|
"""
|
||||||
|
super(MyNewModel, self).__init__()
|
||||||
|
self.layer1 = nn.Linear(input_features, hidden_size)
|
||||||
|
self.relu = nn.ReLU()
|
||||||
|
self.layer2 = nn.Linear(hidden_size, output_sequence_length)
|
||||||
|
# ... 可添加更复杂的结构
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
"""
|
||||||
|
定义数据通过模型的前向传播路径。
|
||||||
|
x 的形状通常是 (batch_size, sequence_length, num_features)
|
||||||
|
"""
|
||||||
|
# 确保输入是正确的形状
|
||||||
|
# 例如,对于简单的线性层,可能需要展平
|
||||||
|
batch_size, seq_len, features = x.shape
|
||||||
|
x = x.view(batch_size * seq_len, features) # 展平
|
||||||
|
|
||||||
|
out = self.layer1(x)
|
||||||
|
out = self.relu(out)
|
||||||
|
out = self.layer2(out)
|
||||||
|
|
||||||
|
# 恢复形状以匹配输出
|
||||||
|
out = out.view(batch_size, seq_len, -1)
|
||||||
|
# 通常我们只关心序列的最后一个预测
|
||||||
|
return out[:, -1, :]
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 第 2 步:创建模型训练器
|
||||||
|
|
||||||
|
接下来,在 `ShopTRAINING/server/trainers/` 目录下创建一个新的训练器文件。这个文件负责模型的整个训练、评估和保存流程。注意:文件名必须以 `_trainer.py` 结尾,因为 `trainers/__init__.py` 中的自动发现逻辑只会导入匹配 `*_trainer.py` 的文件。
|
||||||
|
|
||||||
|
**示例:创建 `ShopTRAINING/server/trainers/my_new_model_trainer.py`**
|
||||||
|
|
||||||
|
这个训练函数需要遵循系统中其他训练器(如 `xgboost_trainer.py`)的统一函数签名,并使用 `@register_trainer` 装饰器或在文件末尾调用 `register_trainer` 函数。
|
||||||
|
|
||||||
|
```python
|
||||||
|
# file: ShopTRAINING/server/trainers/my_new_model_trainer.py
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.optim as optim
|
||||||
|
from models.model_registry import register_trainer
|
||||||
|
from utils.model_manager import model_manager
|
||||||
|
from analysis.metrics import evaluate_model
|
||||||
|
from models.my_new_model import MyNewModel # 导入您的新模型
|
||||||
|
|
||||||
|
# 遵循系统的标准函数签名
|
||||||
|
def train_with_mynewmodel(product_id, model_identifier, product_df, store_id, training_mode, aggregation_method, epochs, sequence_length, forecast_horizon, model_dir, **kwargs):
|
||||||
|
print(f"🚀 MyNewModel 训练器启动: model_identifier='{model_identifier}'")
|
||||||
|
|
||||||
|
# --- 1. 数据准备 ---
|
||||||
|
# (此处省略了数据加载、标准化和创建数据集的详细代码,
|
||||||
|
# 您可以参考 xgboost_trainer.py 或其他训练器中的实现)
|
||||||
|
# ...
|
||||||
|
# 假设您已准备好 trainX, trainY, testX, testY, scaler_y 等变量
|
||||||
|
# trainX = ...
|
||||||
|
# trainY = ...
|
||||||
|
# testX = ...
|
||||||
|
# testY = ...
|
||||||
|
# scaler_y = ...
|
||||||
|
# features = [...]
|
||||||
|
|
||||||
|
# --- 2. 实例化模型和优化器 ---
|
||||||
|
input_dim = trainX.shape[2] # 获取特征数量
|
||||||
|
hidden_size = 64 # 示例超参数
|
||||||
|
|
||||||
|
model = MyNewModel(
|
||||||
|
input_features=input_dim,
|
||||||
|
hidden_size=hidden_size,
|
||||||
|
output_sequence_length=forecast_horizon
|
||||||
|
)
|
||||||
|
optimizer = optim.Adam(model.parameters(), lr=0.001)
|
||||||
|
criterion = torch.nn.MSELoss()
|
||||||
|
|
||||||
|
# --- 3. 训练循环 ---
|
||||||
|
print("开始训练 MyNewModel...")
|
||||||
|
for epoch in range(epochs):
|
||||||
|
model.train()
|
||||||
|
optimizer.zero_grad()
|
||||||
|
outputs = model(trainX)
|
||||||
|
loss = criterion(outputs, trainY)
|
||||||
|
loss.backward()
|
||||||
|
optimizer.step()
|
||||||
|
if (epoch + 1) % 10 == 0:
|
||||||
|
print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}')
|
||||||
|
|
||||||
|
# --- 4. 模型评估 ---
|
||||||
|
model.eval()
|
||||||
|
with torch.no_grad():
|
||||||
|
test_pred_scaled = model(testX)
|
||||||
|
|
||||||
|
# 反标准化并计算指标
|
||||||
|
# ... (参考其他训练器)
|
||||||
|
metrics = {'rmse': 0.0, 'mae': 0.0, 'r2': 0.0, 'mape': 0.0} # 示例
|
||||||
|
|
||||||
|
# --- 5. 模型保存 ---
|
||||||
|
model_data = {
|
||||||
|
'model_state_dict': model.state_dict(),
|
||||||
|
'scaler_X': None, # 替换为您的 scaler_X
|
||||||
|
'scaler_y': scaler_y,
|
||||||
|
'config': {
|
||||||
|
'model_type': 'mynewmodel', # **关键**: 使用唯一的模型名称
|
||||||
|
'input_dim': input_dim,
|
||||||
|
'hidden_size': hidden_size,
|
||||||
|
'sequence_length': sequence_length,
|
||||||
|
'forecast_horizon': forecast_horizon,
|
||||||
|
'features': features
|
||||||
|
},
|
||||||
|
'metrics': metrics
|
||||||
|
}
|
||||||
|
|
||||||
|
model_manager.save_model(
|
||||||
|
model_data=model_data,
|
||||||
|
product_id=product_id,
|
||||||
|
model_type='mynewmodel', # **关键**: 再次确认模型名称
|
||||||
|
# ... 其他参数
|
||||||
|
)
|
||||||
|
|
||||||
|
print("✅ MyNewModel 模型训练并保存完成!")
|
||||||
|
return model, metrics, "v1", "path/to/model" # 返回值遵循统一格式
|
||||||
|
|
||||||
|
# --- 关键步骤: 将训练器注册到系统中 ---
|
||||||
|
register_trainer('mynewmodel', train_with_mynewmodel)
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 第 3 步:集成模型预测器
|
||||||
|
|
||||||
|
最后,您需要让系统知道如何加载和使用您的新模型进行预测。这需要在 `ShopTRAINING/server/predictors/model_predictor.py` 中进行两处修改。
|
||||||
|
|
||||||
|
**文件: `ShopTRAINING/server/predictors/model_predictor.py`**
|
||||||
|
|
||||||
|
1. **让系统知道如何构建您的模型实例**
|
||||||
|
|
||||||
|
在 `load_model_and_predict` 函数中,有一个 `if/elif` 结构用于根据模型类型实例化不同的模型。您需要为 `MyNewModel` 添加一个新的分支。
|
||||||
|
|
||||||
|
```python
|
||||||
|
# 在 model_predictor.py 中
|
||||||
|
|
||||||
|
# 首先,导入您的新模型类
|
||||||
|
from models.my_new_model import MyNewModel
|
||||||
|
|
||||||
|
# ... 在 load_model_and_predict 函数内部 ...
|
||||||
|
|
||||||
|
# ... 其他模型的 elif 分支 ...
|
||||||
|
elif loaded_model_type == 'tcn':
|
||||||
|
model = TCNForecaster(...)
|
||||||
|
|
||||||
|
# vvv 添加这个新的分支 vvv
|
||||||
|
elif loaded_model_type == 'mynewmodel':
|
||||||
|
model = MyNewModel(
|
||||||
|
input_features=config['input_dim'],
|
||||||
|
hidden_size=config['hidden_size'],
|
||||||
|
output_sequence_length=config['forecast_horizon']
|
||||||
|
).to(DEVICE)
|
||||||
|
# ^^^ 添加结束 ^^^
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise ValueError(f"不支持的模型类型: {loaded_model_type}")
|
||||||
|
|
||||||
|
model.load_state_dict(checkpoint['model_state_dict'])
|
||||||
|
model.eval()
|
||||||
|
```
|
||||||
|
|
||||||
|
2. **注册预测逻辑**
|
||||||
|
|
||||||
|
如果您的模型是一个标准的 PyTorch 模型,并且其预测逻辑与现有的模型(如 Transformer, KAN)相同,您可以直接复用 `default_pytorch_predictor`。只需在文件末尾添加一行注册代码即可。
|
||||||
|
|
||||||
|
```python
|
||||||
|
# 在 model_predictor.py 文件末尾
|
||||||
|
|
||||||
|
# ...
|
||||||
|
# 将增强后的默认预测器也注册给xgboost
|
||||||
|
register_predictor('xgboost', default_pytorch_predictor)
|
||||||
|
|
||||||
|
# vvv 添加这行代码 vvv
|
||||||
|
# 让 'mynewmodel' 也使用通用的 PyTorch 预测器
|
||||||
|
register_predictor('mynewmodel', default_pytorch_predictor)
|
||||||
|
# ^^^ 添加结束 ^^^
|
||||||
|
```
|
||||||
|
|
||||||
|
如果您的模型需要特殊的预测逻辑(例如,像 XGBoost 那样有不同的输入格式或调用方式),您可以复制 `default_pytorch_predictor` 创建一个新函数,修改其内部逻辑,然后将新函数注册给 `'mynewmodel'`。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 总结
|
||||||
|
|
||||||
|
完成以上三个步骤后,您的新模型 `MyNewModel` 就已完全集成到系统中了。系统会自动在 `trainers` 目录中发现您的新训练器。当您通过 API 或界面选择 `mynewmodel` 进行训练和预测时,系统将自动调用您刚刚编写和注册的所有相应逻辑。
|
Loading…
x
Reference in New Issue
Block a user