Add the XGBoost training algorithm; training completes, but prediction loading still has problems.

The data file storage structure is changed to the following:

### 1.2. File Storage Locations

- **Final artifacts**: all final models, metadata files, loss plots, etc. are stored directly under the `saved_models/` root directory.
- **Process files**: all checkpoint files produced during training are stored under `saved_models/checkpoints/`.

### 1.3. Filename Generation Rules

1. **Build the logical path**: derive the logical path from the training parameters (mode, scope, type, version).
   - *Example*: `product/P001_all/mlstm/v2`
2. **Generate the filename prefix**: replace every `/` in the logical path with `_`.
   - *Example*: `product_P001_all_mlstm_v2`
3. **Append the file-type suffix**: add a suffix describing the file type to the prefix.
   - `_model.pth`
   - `_loss_curve.png`
   - `_checkpoint_best.pth`
   - `_checkpoint_epoch_{N}.pth`

#### Complete examples

- **Final model**: `saved_models/product_P001_all_mlstm_v2_model.pth`
- **Best checkpoint**: `saved_models/checkpoints/product_P001_all_mlstm_v2_checkpoint_best.pth`
- **Epoch 50 checkpoint**: `saved_models/checkpoints/product_P001_all_mlstm_v2_checkpoint_epoch_50.pth`
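The prefix rule above is mechanical enough to sketch in a few lines. Below is a minimal illustration; `build_filename_prefix` is a hypothetical helper, not the project's actual `ModelPathManager` code:

```python
import os

def build_filename_prefix(training_mode: str, scope: str, model_type: str, version: int) -> str:
    """Flatten the logical path (mode/scope/type/version) into a filename prefix."""
    logical_path = f"{training_mode}/{scope}/{model_type}/v{version}"  # e.g. product/P001_all/mlstm/v2
    return logical_path.replace("/", "_")                              # e.g. product_P001_all_mlstm_v2

prefix = build_filename_prefix("product", "P001_all", "mlstm", 2)
final_model = os.path.join("saved_models", f"{prefix}_model.pth")
best_ckpt = os.path.join("saved_models", "checkpoints", f"{prefix}_checkpoint_best.pth")
epoch_ckpt = os.path.join("saved_models", "checkpoints", f"{prefix}_checkpoint_epoch_50.pth")
# final_model -> saved_models/product_P001_all_mlstm_v2_model.pth
```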
parent 28bae35783
commit 87df49f764
@@ -56,3 +56,4 @@ tzdata==2025.2
 werkzeug==3.1.3
 win32-setctime==1.2.0
 wsproto==1.2.0
+xgboost
@@ -46,6 +46,7 @@ from trainers.mlstm_trainer import train_product_model_with_mlstm
 from trainers.kan_trainer import train_product_model_with_kan
 from trainers.tcn_trainer import train_product_model_with_tcn
 from trainers.transformer_trainer import train_product_model_with_transformer
+from trainers.xgboost_trainer import train_product_model_with_xgboost

 # Import prediction functions
 from predictors.model_predictor import load_model_and_predict
@@ -942,7 +943,7 @@ def get_all_training_tasks():
                 'type': 'object',
                 'properties': {
                     'product_id': {'type': 'string', 'description': '例如 P001'},
-                    'model_type': {'type': 'string', 'enum': ['mlstm', 'transformer', 'kan', 'optimized_kan', 'tcn'], 'description': '要训练的模型类型'},
+                    'model_type': {'type': 'string', 'enum': ['mlstm', 'transformer', 'kan', 'optimized_kan', 'tcn', 'xgboost'], 'description': '要训练的模型类型'},
                     'store_id': {'type': 'string', 'description': '店铺ID,如 S001。为空时使用全局聚合数据'},
                     'epochs': {'type': 'integer', 'default': 50, 'description': '训练轮次'}
                 },
@@ -1006,7 +1007,7 @@ def start_training():
             pass

     # Check that the model type is valid
-    valid_model_types = ['mlstm', 'kan', 'optimized_kan', 'transformer', 'tcn']
+    valid_model_types = ['mlstm', 'kan', 'optimized_kan', 'transformer', 'tcn', 'xgboost']
     if model_type not in valid_model_types:
         return jsonify({'error': '无效的模型类型'}), 400

@@ -1425,7 +1426,7 @@ def get_training_status(task_id):
                 'type': 'object',
                 'properties': {
                     'product_id': {'type': 'string'},
-                    'model_type': {'type': 'string', 'enum': ['mlstm', 'transformer', 'kan', 'optimized_kan', 'tcn']},
+                    'model_type': {'type': 'string', 'enum': ['mlstm', 'transformer', 'kan', 'optimized_kan', 'tcn', 'xgboost']},
                     'store_id': {'type': 'string', 'description': '店铺ID,如 S001。为空时使用全局模型'},
                     'version': {'type': 'string'},
                     'training_mode': {'type': 'string', 'enum': ['product', 'store', 'global'], 'default': 'product'},
@@ -2110,7 +2111,7 @@ def delete_prediction(prediction_id):
             'in': 'query',
             'type': 'string',
             'required': False,
-            'description': "按模型类型筛选 (mlstm, kan, transformer, tcn)"
+            'description': "按模型类型筛选 (mlstm, kan, transformer, tcn, xgboost)"
         },
         {
             'name': 'page',
@@ -2172,7 +2173,7 @@ def list_models():
         in: query
         type: string
         required: false
-        description: "按模型类型筛选 (mlstm, kan, transformer, tcn)"
+        description: "按模型类型筛选 (mlstm, kan, transformer, tcn, xgboost)"
       - name: store_id
         in: query
         type: string
@@ -2775,7 +2776,7 @@ def run_prediction(model_type, product_id, model_id, future_days, start_date, ve
             'in': 'path',
             'type': 'string',
             'required': True,
-            'description': '模型类型,例如mlstm, kan, transformer, tcn, optimized_kan'
+            'description': '模型类型,例如mlstm, kan, transformer, tcn, optimized_kan, xgboost'
         },
         {
             'name': 'product_id',
@@ -3687,6 +3688,12 @@ def get_model_types():
             'name': 'TCN',
             'description': '时间卷积网络,适合处理长序列和平行计算',
             'tag_type': 'danger'
         },
+        {
+            'id': 'xgboost',
+            'name': 'XGBoost',
+            'description': '一种高效的梯度提升决策树模型,广泛用于各种预测任务。',
+            'tag_type': 'success'
+        }
     ]

@@ -15,7 +15,8 @@ from trainers import (
     train_product_model_with_mlstm,
     train_product_model_with_kan,
     train_product_model_with_tcn,
-    train_product_model_with_transformer
+    train_product_model_with_transformer,
+    train_product_model_with_xgboost
 )
 from predictors.model_predictor import load_model_and_predict
 from utils.data_utils import prepare_data, prepare_sequences
@@ -263,6 +264,16 @@ class PharmacyPredictor:
                     task_id=task_id,
                     path_info=path_info
                 )
+            elif model_type == 'xgboost':
+                metrics, _ = train_product_model_with_xgboost(
+                    product_id=product_id,
+                    store_id=store_id,
+                    epochs=epochs,
+                    socketio=socketio,
+                    task_id=task_id,
+                    version=version,
+                    path_info=path_info
+                )
             else:
                 log_message(f"不支持的模型类型: {model_type}", 'error')
                 return None
@@ -8,8 +8,10 @@ import pandas as pd
 import numpy as np
 from datetime import datetime, timedelta
 import matplotlib.pyplot as plt
+import xgboost as xgb
 from sklearn.preprocessing import MinMaxScaler
 import sklearn.preprocessing._data  # Added so that pickled MinMaxScaler instances can be deserialized
+import joblib

 from models.transformer_model import TimeSeriesTransformer
 from models.slstm_model import sLSTM as ScalarLSTM
@@ -30,7 +32,7 @@ def load_model_and_predict(product_id, model_type, model_path=None, store_id=Non

     参数:
     product_id: 产品ID
-    model_type: 模型类型 ('transformer', 'mlstm', 'kan', 'tcn', 'optimized_kan')
+    model_type: 模型类型 ('transformer', 'mlstm', 'kan', 'tcn', 'optimized_kan', 'xgboost')
     model_path: 模型的完整文件路径
     store_id: 店铺ID,为None时使用全局模型
     future_days: 预测未来天数
@@ -118,177 +120,241 @@ def load_model_and_predict(product_id, model_type, model_path=None, store_id=Non
     # Load the model and its configuration
     try:
-        # First try loading with weights_only=False
-        try:
-            print("尝试使用 weights_only=False 加载模型")
-            checkpoint = torch.load(model_path, map_location=DEVICE, weights_only=False)
-        except Exception as e:
-            print(f"使用weights_only=False加载失败: {str(e)}")
-            print("尝试使用默认参数加载模型")
-            checkpoint = torch.load(model_path, map_location=DEVICE)
-
-        print(f"模型加载成功,检查checkpoint类型: {type(checkpoint)}")
-        if isinstance(checkpoint, dict):
-            print(f"checkpoint包含的键: {list(checkpoint.keys())}")
-        else:
-            print(f"checkpoint不是字典类型,而是: {type(checkpoint)}")
-            return None
+        if model_type == 'xgboost':
+            if not os.path.exists(model_path):
+                print(f"XGBoost模型文件不存在: {model_path}")
+                return None
+            # Load the metadata
+            metadata = joblib.load(model_path)
+            model_file_path = metadata['model_file']
+
+            if not os.path.exists(model_file_path):
+                print(f"引用的XGBoost模型文件不存在: {model_file_path}")
+                return None
+
+            # Load the native Booster model
+            model = xgb.Booster()
+            model.load_model(model_file_path)
+
+            config = metadata['config']
+            metrics = metadata['metrics']
+            scaler_X = metadata['scaler_X']
+            scaler_y = metadata['scaler_y']
+            print("XGBoost原生模型及元数据加载成功")
+        else:
+            try:
+                print("尝试使用 weights_only=False 加载模型")
+                checkpoint = torch.load(model_path, map_location=DEVICE, weights_only=False)
+            except Exception as e:
+                print(f"使用weights_only=False加载失败: {str(e)}")
+                print("尝试使用默认参数加载模型")
+                checkpoint = torch.load(model_path, map_location=DEVICE)
+
+            print(f"模型加载成功,检查checkpoint类型: {type(checkpoint)}")
+            if isinstance(checkpoint, dict):
+                print(f"checkpoint包含的键: {list(checkpoint.keys())}")
+            else:
+                print(f"checkpoint不是字典类型,而是: {type(checkpoint)}")
+                return None
     except Exception as e:
         print(f"加载模型失败: {str(e)}")
         return None

-    # Check for and fetch the configuration
-    if 'config' not in checkpoint:
-        print("模型文件中没有配置信息")
-        return None
-
-    config = checkpoint['config']
-    print(f"模型配置: {config}")
-
-    # Check for and fetch the scalers
-    if 'scaler_X' not in checkpoint or 'scaler_y' not in checkpoint:
-        print("模型文件中没有缩放器信息")
-        return None
-
-    scaler_X = checkpoint['scaler_X']
-    scaler_y = checkpoint['scaler_y']
-
-    # Create the model instance
-    try:
-        if model_type == 'transformer':
-            model = TimeSeriesTransformer(
-                num_features=config['input_dim'],
-                d_model=config['hidden_size'],
-                nhead=config['num_heads'],
-                num_encoder_layers=config['num_layers'],
-                dim_feedforward=config['hidden_size'] * 2,
-                dropout=config['dropout'],
-                output_sequence_length=config['output_dim'],
-                seq_length=config['sequence_length'],
-                batch_size=32
-            ).to(DEVICE)
-        elif model_type == 'slstm':
-            model = ScalarLSTM(
-                input_dim=config['input_dim'],
-                hidden_dim=config['hidden_size'],
-                output_dim=config['output_dim'],
-                num_layers=config['num_layers'],
-                dropout=config['dropout']
-            ).to(DEVICE)
-        elif model_type == 'mlstm':
-            # Read config params, falling back to defaults when missing
-            embed_dim = config.get('embed_dim', 32)
-            dense_dim = config.get('dense_dim', 32)
-            num_heads = config.get('num_heads', 4)
-            num_blocks = config.get('num_blocks', 3)
-
-            model = MatrixLSTM(
-                num_features=config['input_dim'],
-                hidden_size=config['hidden_size'],
-                mlstm_layers=config['num_layers'],
-                embed_dim=embed_dim,
-                dense_dim=dense_dim,
-                num_heads=num_heads,
-                dropout_rate=config['dropout'],
-                num_blocks=num_blocks,
-                output_sequence_length=config['output_dim']
-            ).to(DEVICE)
-        elif model_type == 'kan':
-            model = KANForecaster(
-                input_features=config['input_dim'],
-                hidden_sizes=[config['hidden_size'], config['hidden_size']*2, config['hidden_size']],
-                output_sequence_length=config['output_dim']
-            ).to(DEVICE)
-        elif model_type == 'optimized_kan':
-            model = OptimizedKANForecaster(
-                input_features=config['input_dim'],
-                hidden_sizes=[config['hidden_size'], config['hidden_size']*2, config['hidden_size']],
-                output_sequence_length=config['output_dim']
-            ).to(DEVICE)
-        elif model_type == 'tcn':
-            model = TCNForecaster(
-                num_features=config['input_dim'],
-                output_sequence_length=config['output_dim'],
-                num_channels=[config['hidden_size']] * config['num_layers'],
-                kernel_size=3,
-                dropout=config['dropout']
-            ).to(DEVICE)
-        else:
-            print(f"不支持的模型类型: {model_type}")
-            return None
-
-        print(f"模型实例创建成功: {type(model)}")
-    except Exception as e:
-        print(f"创建模型实例失败: {str(e)}")
-        return None
-
-    # Load the model parameters
-    try:
-        model.load_state_dict(checkpoint['model_state_dict'])
-        model.eval()
-        print("模型参数加载成功")
-    except Exception as e:
-        print(f"加载模型参数失败: {str(e)}")
-        return None
-
-    # Prepare the input data
-    try:
-        features = ['sales', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']
-        sequence_length = config['sequence_length']
-
-        # Use the most recent sequence_length days as input
-        recent_data = product_df.iloc[-sequence_length:].copy()
-
-        # If a start date is given, use data from that date onward
-        if start_date:
-            if isinstance(start_date, str):
-                start_date = datetime.strptime(start_date, '%Y-%m-%d')
-            recent_data = product_df[product_df['date'] >= start_date].iloc[:sequence_length].copy()
-            if len(recent_data) < sequence_length:
-                print(f"警告: 从指定日期 {start_date} 开始的数据少于所需的 {sequence_length} 天")
-                # Top up with earlier data
-                missing_days = sequence_length - len(recent_data)
-                additional_data = product_df[product_df['date'] < start_date].iloc[-missing_days:].copy()
-                recent_data = pd.concat([additional_data, recent_data]).reset_index(drop=True)
-
-        print(f"输入数据准备完成,形状: {recent_data.shape}")
-    except Exception as e:
-        print(f"准备输入数据失败: {str(e)}")
-        return None
-
-    # Normalize the input data
-    try:
-        X = recent_data[features].values
-        X_scaled = scaler_X.transform(X)
-
-        # Convert to the model's input format
-        X_input = torch.tensor(X_scaled.reshape(1, sequence_length, -1), dtype=torch.float32).to(DEVICE)
-        print(f"输入张量准备完成,形状: {X_input.shape}")
-    except Exception as e:
-        print(f"归一化输入数据失败: {str(e)}")
-        return None
-
-    # Run the prediction
-    try:
-        with torch.no_grad():
-            y_pred_scaled = model(X_input).cpu().numpy()
-        print(f"原始预测输出形状: {y_pred_scaled.shape}")
-
-        # Fix up TCN/Transformer/mLSTM/KAN outputs so the shape is correct
-        if model_type in ['tcn', 'transformer', 'mlstm', 'kan', 'optimized_kan'] and len(y_pred_scaled.shape) == 3:
-            y_pred_scaled = y_pred_scaled.squeeze(-1)
-            print(f"处理后的预测输出形状: {y_pred_scaled.shape}")
-
-        # Inverse-transform the predictions
-        y_pred = scaler_y.inverse_transform(y_pred_scaled.reshape(-1, 1)).flatten()
-        print(f"反归一化后的预测结果: {y_pred}")
-
-        # Generate the prediction dates
-        last_date = recent_data['date'].iloc[-1]
-        pred_dates = [(last_date + timedelta(days=i+1)) for i in range(len(y_pred))]
-        print(f"预测日期: {pred_dates}")
-    except Exception as e:
-        print(f"执行预测失败: {str(e)}")
-        return None
+    # XGBoost takes a different code path
+    if model_type == 'xgboost':
+        look_back = config['look_back']
+        features = config['features']
+
+        # Prepare the input data
+        recent_data = product_df.iloc[-look_back:].copy()
+
+        predictions = []
+        current_input_df = recent_data[features].copy()
+
+        for _ in range(future_days):
+            # Normalize the input window and flatten it
+            input_scaled = scaler_X.transform(current_input_df.values)
+            input_vector = input_scaled.flatten().reshape(1, -1)
+
+            # Predict the scaled value
+            dpredict = xgb.DMatrix(input_vector)
+            prediction_scaled = model.predict(dpredict)
+
+            # Inverse-transform to obtain the real prediction
+            prediction = scaler_y.inverse_transform(prediction_scaled.reshape(-1, 1)).flatten()[0]
+            predictions.append(prediction)
+
+            # Update the input window for the next prediction:
+            # build a new row carrying the predicted value
+            new_row_values = current_input_df.iloc[-1].copy()
+            new_row_values['sales'] = prediction
+            # More sophisticated future-feature generation could go here
+            # (e.g. updating weekday, month, etc. from the new date)
+
+            new_row_df = pd.DataFrame([new_row_values], columns=features)
+
+            # Slide the window
+            current_input_df = pd.concat([current_input_df.iloc[1:], new_row_df], ignore_index=True)
+
+        # Generate the prediction dates
+        last_date = recent_data['date'].iloc[-1]
+        pred_dates = [last_date + timedelta(days=i+1) for i in range(future_days)]
+
+        y_pred = np.array(predictions)
+
+    else:  # Original PyTorch-model logic
+        # Check for and fetch the configuration
+        if 'config' not in checkpoint:
+            print("模型文件中没有配置信息")
+            return None
+
+        config = checkpoint['config']
+        print(f"模型配置: {config}")
+
+        # Check for and fetch the scalers
+        if 'scaler_X' not in checkpoint or 'scaler_y' not in checkpoint:
+            print("模型文件中没有缩放器信息")
+            return None
+
+        scaler_X = checkpoint['scaler_X']
+        scaler_y = checkpoint['scaler_y']
+
+        # Create the model instance
+        try:
+            if model_type == 'transformer':
+                model = TimeSeriesTransformer(
+                    num_features=config['input_dim'],
+                    d_model=config['hidden_size'],
+                    nhead=config['num_heads'],
+                    num_encoder_layers=config['num_layers'],
+                    dim_feedforward=config['hidden_size'] * 2,
+                    dropout=config['dropout'],
+                    output_sequence_length=config['output_dim'],
+                    seq_length=config['sequence_length'],
+                    batch_size=32
+                ).to(DEVICE)
+            elif model_type == 'slstm':
+                model = ScalarLSTM(
+                    input_dim=config['input_dim'],
+                    hidden_dim=config['hidden_size'],
+                    output_dim=config['output_dim'],
+                    num_layers=config['num_layers'],
+                    dropout=config['dropout']
+                ).to(DEVICE)
+            elif model_type == 'mlstm':
+                # Read config params, falling back to defaults when missing
+                embed_dim = config.get('embed_dim', 32)
+                dense_dim = config.get('dense_dim', 32)
+                num_heads = config.get('num_heads', 4)
+                num_blocks = config.get('num_blocks', 3)
+
+                model = MatrixLSTM(
+                    num_features=config['input_dim'],
+                    hidden_size=config['hidden_size'],
+                    mlstm_layers=config['num_layers'],
+                    embed_dim=embed_dim,
+                    dense_dim=dense_dim,
+                    num_heads=num_heads,
+                    dropout_rate=config['dropout'],
+                    num_blocks=num_blocks,
+                    output_sequence_length=config['output_dim']
+                ).to(DEVICE)
+            elif model_type == 'kan':
+                model = KANForecaster(
+                    input_features=config['input_dim'],
+                    hidden_sizes=[config['hidden_size'], config['hidden_size']*2, config['hidden_size']],
+                    output_sequence_length=config['output_dim']
+                ).to(DEVICE)
+            elif model_type == 'optimized_kan':
+                model = OptimizedKANForecaster(
+                    input_features=config['input_dim'],
+                    hidden_sizes=[config['hidden_size'], config['hidden_size']*2, config['hidden_size']],
+                    output_sequence_length=config['output_dim']
+                ).to(DEVICE)
+            elif model_type == 'tcn':
+                model = TCNForecaster(
+                    num_features=config['input_dim'],
+                    output_sequence_length=config['output_dim'],
+                    num_channels=[config['hidden_size']] * config['num_layers'],
+                    kernel_size=3,
+                    dropout=config['dropout']
+                ).to(DEVICE)
+            else:
+                print(f"不支持的模型类型: {model_type}")
+                return None
+
+            print(f"模型实例创建成功: {type(model)}")
+        except Exception as e:
+            print(f"创建模型实例失败: {str(e)}")
+            return None
+
+        # Load the model parameters
+        try:
+            model.load_state_dict(checkpoint['model_state_dict'])
+            model.eval()
+            print("模型参数加载成功")
+        except Exception as e:
+            print(f"加载模型参数失败: {str(e)}")
+            return None
+
+        # Prepare the input data
+        try:
+            features = ['sales', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']
+            sequence_length = config['sequence_length']
+
+            # Use the most recent sequence_length days as input
+            recent_data = product_df.iloc[-sequence_length:].copy()
+
+            # If a start date is given, use data from that date onward
+            if start_date:
+                if isinstance(start_date, str):
+                    start_date = datetime.strptime(start_date, '%Y-%m-%d')
+                recent_data = product_df[product_df['date'] >= start_date].iloc[:sequence_length].copy()
+                if len(recent_data) < sequence_length:
+                    print(f"警告: 从指定日期 {start_date} 开始的数据少于所需的 {sequence_length} 天")
+                    # Top up with earlier data
+                    missing_days = sequence_length - len(recent_data)
+                    additional_data = product_df[product_df['date'] < start_date].iloc[-missing_days:].copy()
+                    recent_data = pd.concat([additional_data, recent_data]).reset_index(drop=True)
+
+            print(f"输入数据准备完成,形状: {recent_data.shape}")
+        except Exception as e:
+            print(f"准备输入数据失败: {str(e)}")
+            return None
+
+        # Normalize the input data
+        try:
+            X = recent_data[features].values
+            X_scaled = scaler_X.transform(X)
+
+            # Convert to the model's input format
+            X_input = torch.tensor(X_scaled.reshape(1, sequence_length, -1), dtype=torch.float32).to(DEVICE)
+            print(f"输入张量准备完成,形状: {X_input.shape}")
+        except Exception as e:
+            print(f"归一化输入数据失败: {str(e)}")
+            return None
+
+        # Run the prediction
+        try:
+            with torch.no_grad():
+                y_pred_scaled = model(X_input).cpu().numpy()
+            print(f"原始预测输出形状: {y_pred_scaled.shape}")
+
+            # Fix up TCN/Transformer/mLSTM/KAN outputs so the shape is correct
+            if model_type in ['tcn', 'transformer', 'mlstm', 'kan', 'optimized_kan'] and len(y_pred_scaled.shape) == 3:
+                y_pred_scaled = y_pred_scaled.squeeze(-1)
+                print(f"处理后的预测输出形状: {y_pred_scaled.shape}")
+
+            # Inverse-transform the predictions
+            y_pred = scaler_y.inverse_transform(y_pred_scaled.reshape(-1, 1)).flatten()
+            print(f"反归一化后的预测结果: {y_pred}")
+
+            # Generate the prediction dates
+            last_date = recent_data['date'].iloc[-1]
+            pred_dates = [(last_date + timedelta(days=i+1)) for i in range(len(y_pred))]
+            print(f"预测日期: {pred_dates}")
+        except Exception as e:
+            print(f"执行预测失败: {str(e)}")
+            return None

     # Build the prediction-result DataFrame
     try:
@@ -348,4 +414,4 @@ def load_model_and_predict(product_id, model_type, model_path=None, store_id=Non
         print(f"预测过程中出现未捕获的异常: {str(e)}")
         import traceback
         traceback.print_exc()
-        return None
+        return None
@@ -6,6 +6,7 @@ from .mlstm_trainer import train_product_model_with_mlstm
 from .kan_trainer import train_product_model_with_kan
 from .tcn_trainer import train_product_model_with_tcn
 from .transformer_trainer import train_product_model_with_transformer
+from .xgboost_trainer import train_product_model_with_xgboost

 # Default training function
 from .mlstm_trainer import train_product_model_with_mlstm as train_product_model
@@ -15,5 +16,6 @@ __all__ = [
     'train_product_model_with_mlstm',
     'train_product_model_with_kan',
     'train_product_model_with_tcn',
-    'train_product_model_with_transformer'
+    'train_product_model_with_transformer',
+    'train_product_model_with_xgboost'
 ]
server/trainers/xgboost_trainer.py (new file, 296 lines)
@@ -0,0 +1,296 @@
import xgboost as xgb
import numpy as np
import pandas as pd
import os
import joblib
from xgboost.callback import EarlyStopping
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
from sklearn.preprocessing import MinMaxScaler

# Import the correct utility functions and configuration from the project
from utils.multi_store_data_utils import load_multi_store_data
from core.config import DEFAULT_DATA_PATH
from utils.file_save import ModelPathManager
from analysis.metrics import evaluate_model

# Refactored callback compatible with the native XGBoost API
class EpochCheckpointCallback(xgb.callback.TrainingCallback):
    def __init__(self, save_period, payload, base_path):
        super().__init__()
        self.save_period = save_period
        self.payload = payload
        self.base_path = base_path
        self.best_score = float('inf')

    def _save_checkpoint(self, model, path_suffix):
        """Helper that saves a model + metadata checkpoint."""
        metadata_path = self.base_path.replace('_model.pth', f'_{path_suffix}.pth')
        model_file_path = metadata_path.replace('.pth', '.xgb')

        # Make sure the directory exists
        os.makedirs(os.path.dirname(metadata_path), exist_ok=True)

        # Save the native Booster model
        model.save_model(model_file_path)

        # Update the model-file reference in the payload
        self.payload['model_file'] = model_file_path
        joblib.dump(self.payload, metadata_path)

        print(f"[Checkpoint] 已保存检查点到: {metadata_path}")

    def after_iteration(self, model, epoch, evals_log):
        # Current validation score (assumes 'test' is the validation set)
        current_score = evals_log['test']['rmse'][-1]

        # Save the best model
        if current_score < self.best_score:
            self.best_score = current_score
            self._save_checkpoint(model, 'checkpoint_best')

        # Save periodic checkpoints
        if (epoch + 1) % self.save_period == 0:
            self._save_checkpoint(model, f'checkpoint_epoch_{epoch + 1}')

        return False  # Continue training

def create_dataset(data, look_back=7):
    """
    Convert time-series data into a supervised-learning format.
    :param data: input DataFrame containing features and target.
    :param look_back: size of the time window used for prediction.
    :return: X (features), y (target)
    """
    X, y = [], []
    feature_columns = [col for col in data.columns if col != 'date']

    for i in range(len(data) - look_back):
        # Flatten all features within the look_back window
        features = data[feature_columns].iloc[i:(i + look_back)].values.flatten()
        X.append(features)
        # The target is the first sales value after the window
        y.append(data['sales'].iloc[i + look_back])

    return np.array(X), np.array(y)

def train_product_model_with_xgboost(
    product_id,
    store_id=None,
    epochs=100,  # n_estimators is the more natural notion for XGBoost
    look_back=7,
    socketio=None,
    task_id=None,
    version='v1',
    path_info=None,
    **kwargs):
    """
    Train a product sales forecasting model with XGBoost.
    """

    def emit_progress(message, progress=None):
        if socketio and task_id:
            payload = {'task_id': task_id, 'message': message}
            if progress is not None:
                payload['progress'] = progress
            socketio.emit('training_update', payload, namespace='/api/training', room=task_id)
        print(f"[{task_id}] {message}")

    try:
        model_path = None
        emit_progress("开始XGBoost模型训练...", 0)

        # 1. Load the data
        # Use the correct loader and take the path from config
        full_df = load_multi_store_data(DEFAULT_DATA_PATH)

        # Filter by store_id and product_id
        if store_id:
            df = full_df[(full_df['product_id'] == product_id) & (full_df['store_id'] == store_id)].copy()
        else:
            # Without a store_id, aggregate the product's data across all stores
            df = full_df[full_df['product_id'] == product_id].groupby('date').agg({
                'sales': 'sum',
                'weekday': 'first',
                'month': 'first',
                'is_holiday': 'max',
                'is_weekend': 'max',
                'is_promotion': 'max',
                'temperature': 'mean'
            }).reset_index()

        if df.empty:
            raise ValueError(f"加载的数据为空 (product: {product_id}, store: {store_id}),无法进行训练。")

        # Make sure the data is sorted by date
        df = df.sort_values('date').reset_index(drop=True)

        emit_progress("数据加载完成。", 10)

        # 2. Build the dataset
        features_to_use = ['sales', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']
        # Make sure every required feature exists
        for col in features_to_use:
            if col not in df.columns:
                # Fill missing features with 0
                df[col] = 0

        df_features = df[['date'] + features_to_use]

        X, y = create_dataset(df_features, look_back)
        if X.shape[0] == 0:
            raise ValueError("创建数据集后样本数量为0,请检查数据量和look_back参数。")

        emit_progress(f"数据集创建完成,样本数: {X.shape[0]}", 20)

        # 3. Train/test split
        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

        # Scale the data
        scaler_X = MinMaxScaler()
        X_train_scaled = scaler_X.fit_transform(X_train)
        X_test_scaled = scaler_X.transform(X_test)

        scaler_y = MinMaxScaler()
        y_train_scaled = scaler_y.fit_transform(y_train.reshape(-1, 1))
        # y_test is not scaled; it is compared against inverse-transformed predictions

        emit_progress("数据划分和缩放完成。", 30)

        # 4. Switch to the native XGBoost API
        params = {
            'learning_rate': kwargs.get('learning_rate', 0.1),
            'max_depth': kwargs.get('max_depth', 5),
            'subsample': kwargs.get('subsample', 0.8),
            'colsample_bytree': kwargs.get('colsample_bytree', 0.8),
            'objective': 'reg:squarederror',
            'eval_metric': 'rmse',
            'random_state': 42
        }

        dtrain = xgb.DMatrix(X_train_scaled, label=y_train_scaled.ravel())
        dtest = xgb.DMatrix(X_test_scaled, label=scaler_y.transform(y_test.reshape(-1, 1)).ravel())

        emit_progress("开始模型训练...", 40)

        # Define the evaluation sets
        evals = [(dtrain, 'train'), (dtest, 'test')]

        # Set up callbacks
        callbacks = []
        checkpoint_interval = kwargs.get('checkpoint_interval', 10)  # Save every 10 rounds by default
        if path_info and path_info.get('model_path') and checkpoint_interval > 0:
            # Payload used for saving; the model file reference is updated inside the callback
            checkpoint_payload = {
                'metrics': {},  # Checkpoints do not store final metrics
                'scaler_X': scaler_X,
                'scaler_y': scaler_y,
                'config': {
                    'look_back': look_back,
                    'features': features_to_use,
                    'product_id': product_id,
                    'store_id': store_id,
                    'version': version
                }
            }
            checkpoint_callback = EpochCheckpointCallback(
                save_period=checkpoint_interval,
                payload=checkpoint_payload,
                base_path=path_info['model_path']
            )
            callbacks.append(checkpoint_callback)

        # Add early stopping (save_best removed)
        callbacks.append(EarlyStopping(rounds=10))

        # Container for the evaluation results
        evals_result = {}

        model = xgb.train(
            params=params,
            dtrain=dtrain,
            num_boost_round=epochs,
            evals=evals,
            callbacks=callbacks,
            evals_result=evals_result,
            verbose_eval=False
        )
        emit_progress("模型训练完成。", 80)

        # Plot and save the loss curve
        if path_info and path_info.get('model_path'):
            try:
                loss_curve_path = path_info['model_path'].replace('_model.pth', '_loss_curve.png')
                results = evals_result
                train_rmse = results['train']['rmse']
                test_rmse = results['test']['rmse']
                num_epochs = len(train_rmse)
                x_axis = range(0, num_epochs)

                fig, ax = plt.subplots(figsize=(10, 6))
                ax.plot(x_axis, train_rmse, label='Train')
                ax.plot(x_axis, test_rmse, label='Test')
                ax.legend()
                plt.ylabel('RMSE')
                plt.xlabel('Epoch')
                plt.title('XGBoost RMSE Loss Curve')
                plt.savefig(loss_curve_path)
                plt.close(fig)
                emit_progress(f"损失曲线图已保存到: {loss_curve_path}")
            except Exception as e:
                emit_progress(f"警告: 绘制损失曲线失败: {str(e)}")

        # 5. Evaluate the model
        dtest_pred = xgb.DMatrix(X_test_scaled)
        y_pred_scaled = model.predict(dtest_pred)
        y_pred = scaler_y.inverse_transform(y_pred_scaled.reshape(-1, 1)).flatten()

        metrics = {
            'RMSE': np.sqrt(mean_squared_error(y_test, y_pred)),
            'MAE': mean_absolute_error(y_test, y_pred),
            'R2': r2_score(y_test, y_pred)
        }
        emit_progress(f"模型评估完成: {metrics}", 90)

        # 6. Save the model (native API style)
        if path_info and path_info.get('model_path'):
            metadata_path = path_info['model_path']
            # Save the native Booster model with an .xgb extension
            model_file_path = metadata_path.replace('.pth', '.xgb')

            # Make sure the directory exists
            os.makedirs(os.path.dirname(metadata_path), exist_ok=True)

            # Save the Booster with the native method
            model.save_model(model_file_path)
            emit_progress(f"原生XGBoost模型已保存到: {model_file_path}")

            # Save the metadata (including the model file path)
            metadata_payload = {
                'model_file': model_file_path,  # Reference to the model file
                'metrics': metrics,
                'scaler_X': scaler_X,
                'scaler_y': scaler_y,
                'config': {
                    'look_back': look_back,
                    'features': features_to_use,
                    'product_id': product_id,
                    'store_id': store_id,
                    'version': version
                }
            }
            joblib.dump(metadata_payload, metadata_path)
            model_path = metadata_path  # Make sure model_path is assigned
            emit_progress(f"模型元数据已保存到: {metadata_path}", 100)
        else:
            emit_progress("警告: 未提供path_info,模型未保存。", 100)

        return metrics, model_path

    except Exception as e:
        emit_progress(f"XGBoost训练失败: {str(e)}", 100)
        import traceback
        traceback.print_exc()
        return {'error': str(e)}, None
@@ -167,11 +167,25 @@ class TrainingWorker:
                     })

                     if metrics:
-                        self.progress_queue.put({
-                            'task_id': task.task_id,
-                            'log_type': 'info',
-                            'message': f"📊 训练指标: MSE={metrics.get('mse', 'N/A'):.4f}, RMSE={metrics.get('rmse', 'N/A'):.4f}"
-                        })
+                        if 'error' in metrics:
+                            self.progress_queue.put({
+                                'task_id': task.task_id,
+                                'log_type': 'error',
+                                'message': f"❌ 训练返回错误: {metrics['error']}"
+                            })
+                        else:
+                            # Only format the metrics when there is no error
+                            mse_val = metrics.get('mse', 'N/A')
+                            rmse_val = metrics.get('rmse', 'N/A')
+
+                            mse_str = f"{mse_val:.4f}" if isinstance(mse_val, (int, float)) else mse_val
+                            rmse_str = f"{rmse_val:.4f}" if isinstance(rmse_val, (int, float)) else rmse_val
+
+                            self.progress_queue.put({
+                                'task_id': task.task_id,
+                                'log_type': 'info',
+                                'message': f"📊 训练指标: MSE={mse_str}, RMSE={rmse_str}"
+                            })
                 except ImportError as e:
                     training_logger.error(f"❌ 导入训练器失败: {e}")
                     # Return mock training results for testing
@@ -382,12 +396,13 @@ class TrainingProcessManager:
                 # Only save version info when training succeeded (metrics are valid)
                 if task.metrics and task.metrics.get('r2', -1) >= 0:
                     if task.path_info:
-                        identifier = task.path_info.get('identifier')
+                        # Make sure the correct, normalized identifier is used
+                        version_control_identifier = task.path_info.get('identifier')
                         version = task.path_info.get('version')
-                        if identifier and version:
+                        if version_control_identifier and version:
                             try:
-                                self.path_manager.save_version_info(identifier, version)
-                                self.logger.info(f"✅ 版本信息已更新: identifier={identifier}, version={version}")
+                                self.path_manager.save_version_info(version_control_identifier, version)
+                                self.logger.info(f"✅ 版本信息已更新: identifier={version_control_identifier}, version={version}")
+                                task.version = version  # Key fix: store the version number on the task object
                             except Exception as e:
                                 self.logger.error(f"❌ 更新版本文件失败: {e}")
@@ -398,12 +413,12 @@ class TrainingProcessManager:
         if self.websocket_callback:
             try:
                 if action == 'complete':
-                    # Get the version number from the task info
+                    # Get the authoritative version number from the task object
                    version = None
                     with self.lock:
                         task = self.tasks.get(task_id)
-                        if task and task.path_info:
-                            version = task.path_info.get('version')
+                        if task:
+                            version = task.version

                     # Training complete - send the completion status
                     self.websocket_callback('training_update', {
test/verify_save_logic.py (new file, 118 lines)
@@ -0,0 +1,118 @@
import unittest
import os
import shutil
import sys

# Add the project root to sys.path to resolve module imports.
# This lets the test script run directly without complex path configuration.
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
if project_root not in sys.path:
    sys.path.insert(0, project_root)

from server.utils.file_save import ModelPathManager

class TestModelPathManager(unittest.TestCase):
    """
    Verify that ModelPathManager strictly follows the flat file-storage convention.
    """
    def setUp(self):
        """Set up the test environment before each test case."""
        self.test_base_dir = 'test_saved_models'
        # Clean up any previous test directories and files
        if os.path.exists(self.test_base_dir):
            shutil.rmtree(self.test_base_dir)
        self.path_manager = ModelPathManager(base_dir=self.test_base_dir)

    def tearDown(self):
        """Clean up the test environment after each test case."""
        if os.path.exists(self.test_base_dir):
            shutil.rmtree(self.test_base_dir)

    def test_product_mode_path_generation(self):
        """Path generation in 'product' mode should follow the convention."""
        print("\n--- 测试 'product' 模式 ---")
        params = {
            'training_mode': 'product',
            'model_type': 'mlstm',
            'product_id': 'P001',
            'store_id': 'all'
        }

        # First call: the version should be 1
        paths_v1 = self.path_manager.get_model_paths(**params)

        # Check the version number
        self.assertEqual(paths_v1['version'], 1)

        # Check the filename prefix
        expected_prefix_v1 = 'product_P001_all_mlstm_v1'
        self.assertEqual(paths_v1['filename_prefix'], expected_prefix_v1)

        # Check the full path of each file
        self.assertEqual(paths_v1['model_path'], os.path.join(self.test_base_dir, f'{expected_prefix_v1}_model.pth'))
        self.assertEqual(paths_v1['metadata_path'], os.path.join(self.test_base_dir, f'{expected_prefix_v1}_metadata.json'))
        self.assertEqual(paths_v1['loss_curve_path'], os.path.join(self.test_base_dir, f'{expected_prefix_v1}_loss_curve.png'))

        # Check the checkpoint paths
        checkpoint_dir = os.path.join(self.test_base_dir, 'checkpoints')
        self.assertEqual(paths_v1['checkpoint_dir'], checkpoint_dir)
        self.assertEqual(paths_v1['best_checkpoint_path'], os.path.join(checkpoint_dir, f'{expected_prefix_v1}_checkpoint_best.pth'))
        self.assertEqual(paths_v1['epoch_checkpoint_template'], os.path.join(checkpoint_dir, f'{expected_prefix_v1}_checkpoint_epoch_{{N}}.pth'))

        print(f"生成的文件名前缀: {paths_v1['filename_prefix']}")
        print(f"生成的模型路径: {paths_v1['model_path']}")
        print("验证通过!")

        # Simulate a successful training run to trigger the version bump
        self.path_manager.save_version_info(paths_v1['identifier'], paths_v1['version'])

        # Second call: the version should be 2
        paths_v2 = self.path_manager.get_model_paths(**params)
        self.assertEqual(paths_v2['version'], 2)
        expected_prefix_v2 = 'product_P001_all_mlstm_v2'
        self.assertEqual(paths_v2['filename_prefix'], expected_prefix_v2)
        print(f"\n版本递增后,生成的文件名前缀: {paths_v2['filename_prefix']}")
        print("版本递增验证通过!")

    def test_store_mode_path_generation_with_hash(self):
        """Path generation in 'store' mode should hash the product IDs."""
        print("\n--- 测试 'store' 模式 (多药品ID哈希) ---")
        params = {
            'training_mode': 'store',
            'model_type': 'kan',
            'store_id': 'S008',
            'product_scope': 'specific',
            'product_ids': ['P002', 'P005', 'P003']  # Deliberately out of order
        }

        paths = self.path_manager.get_model_paths(**params)

        # The hash should be deterministic: the ID list is sorted before hashing
        expected_hash = self.path_manager._hash_ids(sorted(['P002', 'P005', 'P003']))
        expected_prefix = f'store_S008_{expected_hash}_kan_v1'

        self.assertEqual(paths['filename_prefix'], expected_prefix)
        self.assertEqual(paths['model_path'], os.path.join(self.test_base_dir, f'{expected_prefix}_model.pth'))
        print(f"生成的文件名前缀: {paths['filename_prefix']}")
        print("验证通过!")

    def test_global_mode_path_generation(self):
        """Path generation in 'global' mode."""
        print("\n--- 测试 'global' 模式 ---")
        params = {
            'training_mode': 'global',
            'model_type': 'transformer',
            'training_scope': 'all',
            'aggregation_method': 'mean'
        }

        paths = self.path_manager.get_model_paths(**params)

        expected_prefix = 'global_all_agg_mean_transformer_v1'
        self.assertEqual(paths['filename_prefix'], expected_prefix)
        self.assertEqual(paths['model_path'], os.path.join(self.test_base_dir, f'{expected_prefix}_model.pth'))
        print(f"生成的文件名前缀: {paths['filename_prefix']}")
        print("验证通过!")

if __name__ == '__main__':
    unittest.main()
@@ -1,6 +1,6 @@
 ### Start from the root directory
 **1**: `uv venv`
-**2**: `uv pip install loguru numpy pandas torch matplotlib flask flask_cors flask_socketio flasgger scikit-learn tqdm pytorch_tcn pyarrow`
+**2**: `uv pip install loguru numpy pandas torch matplotlib flask flask_cors flask_socketio flasgger scikit-learn tqdm pytorch_tcn pyarrow xgboost -i https://pypi.tuna.tsinghua.edu.cn/simple`
 **3**: `uv run .\server\api.py`
 ### UI
 **1**: `npm install` `npm run dev`
xz模型添加流程.md (new file, 61 lines)
@@ -0,0 +1,61 @@
# Standard Process for Adding a New Model to the System

This document summarizes the standard process for adding a new forecasting model (using XGBoost as the example) to this project, as a clear, reusable roadmap for future development work.

---

### Step 1: Create the model trainer

This is the core step: it implements the training logic for the new model.

1. **Create a new file**: under [`server/trainers/`](server/trainers/), create a new Python file, e.g. `new_model_trainer.py`.

2. **Define the training function**: in that file, define a core training function that follows the project's standard signature, accepting parameters such as `product_id`, `store_id`, `epochs`, and `path_info`.

3. **Implement the function body** (a sketch of the save contract follows this list):
   * **Data loading**: use [`utils.multi_store_data_utils.load_multi_store_data`](server/utils/multi_store_data_utils.py) to load the data, filtering by `product_id` and `store_id`.
   * **Preprocessing**: convert the time-series data into a supervised-learning format. For a model like XGBoost this means building a "sliding window" (as in our `create_dataset` function).
   * **Scaling (critical)**: you **must** normalize the features (`X`) and the target (`y`) with `sklearn.preprocessing.MinMaxScaler`, creating and fitting two scalers, `scaler_X` and `scaler_y`.
   * **Training**: initialize your new model and train it on the **normalized** data.
   * **Loss curve (optional but recommended)**: if the model supports it, capture the training- and validation-set losses during training, plot them with `matplotlib`, and save the figure as `..._loss_curve.png`.
   * **Checkpoints (optional but recommended)**: if the model supports callbacks, implement a custom callback that saves checkpoints at a configurable epoch interval (`..._checkpoint_epoch_{N}.pth`).
   * **Evaluation**: compute the evaluation metrics (RMSE, R2, etc.) on the **inverse-transformed** predictions.
   * **Saving (critical)**:
     * Build a dictionary (payload) that **must** contain: `'model'` (the trained model object), `'config'` (the training configuration), `'scaler_X'` (the feature scaler), and `'scaler_y'` (the target scaler).
     * Save this dictionary to the path given by `path_info['model_path']`, using the appropriate library (`torch.save` for PyTorch models, `joblib.dump` for others such as XGBoost). **Always use the unified `.pth` extension for the filename.**
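To make the payload contract concrete, here is a minimal sketch of the save step for a joblib-style model; `save_model_payload` is an illustrative helper, not project code. Note that the XGBoost trainer added in this very commit deviates slightly from this contract: it writes the native booster to a separate `.xgb` file and stores a `'model_file'` reference in the payload instead of the model object itself.

```python
import joblib

def save_model_payload(model, config, scaler_X, scaler_y, model_path):
    """Illustrative sketch of the Step-1 payload contract."""
    payload = {
        'model': model,        # the trained model object
        'config': config,      # training configuration (look_back, features, ...)
        'scaler_X': scaler_X,  # feature scaler, needed again at prediction time
        'scaler_y': scaler_y,  # target scaler, needed to invert predictions
    }
    # joblib.dump for non-PyTorch models, torch.save for PyTorch models;
    # the filename keeps the unified .pth extension either way.
    joblib.dump(payload, model_path)
```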
---

### Step 2: Integrate the trainer into the system

1. **Register the trainer**: open [`server/trainers/__init__.py`](server/trainers/__init__.py).
   * At the top of the file, import the training function from your new trainer file, e.g. `from .new_model_trainer import train_product_model_with_new_model`.
   * Add the new function name to the `__all__` list at the bottom of the file.

2. **Add dispatch logic**: open [`server/core/predictor.py`](server/core/predictor.py).
   * In the `train_model` method, find the `if/elif` block and add a new `elif model_type == 'new_model':` branch that calls your new training function.

---

### Step 3: Implement the prediction logic

1. **Modify the predictor**: open [`server/predictors/model_predictor.py`](server/predictors/model_predictor.py).
2. **Add a prediction branch**: in the `load_model_and_predict` function, find the `if/elif` block and add a new `elif model_type == 'new_model':` branch.
3. **Implement the branch** (see the sketch after this list):
   * Load the `.pth` model file with the same library that saved it (e.g. `joblib.load`).
   * From the loaded dictionary, you **must** extract `model`, `config`, `scaler_X`, and `scaler_y`.
   * Prepare the input data for prediction (e.g. the most recent N days).
   * Before predicting, you **must** normalize the input with `scaler_X.transform`.
   * After the model produces its output, you **must** apply `scaler_y.inverse_transform` to recover real-valued predictions.
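A compact sketch of this load-and-predict contract, assuming the payload layout from Step 1; the function name and window handling are illustrative, not the actual `load_model_and_predict` branch:

```python
import joblib
import numpy as np
import pandas as pd

def predict_with_saved_model(model_path: str, product_df: pd.DataFrame):
    """Illustrative sketch: load payload, scale inputs, predict, unscale outputs."""
    payload = joblib.load(model_path)              # same library that saved it
    model = payload['model']
    config = payload['config']                     # carries look_back and the feature list
    scaler_X, scaler_y = payload['scaler_X'], payload['scaler_y']

    # Last look_back days, flattened the same way the trainer built its windows
    recent = product_df[config['features']].iloc[-config['look_back']:]
    X = scaler_X.transform(recent.values).flatten().reshape(1, -1)

    y_scaled = model.predict(X)                    # model-specific predict call
    return scaler_y.inverse_transform(np.asarray(y_scaled).reshape(-1, 1)).flatten()
```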
---

### Step 4: Update the API and dependencies

1. **Update the API endpoints**: open [`server/api.py`](server/api.py).
   * Add your new model ID (e.g. `'new_model'`) to the `valid_model_types` list in the `/api/training` route (the `start_training` function).
   * Add your new model's description to the list returned by the `/api/model_types` route (the `get_model_types` function) so it shows up in the frontend.

2. **Update dependencies**: open [`requirements.txt`](requirements.txt) and add the Python libraries your new model needs (e.g. `xgboost`).

Following these four steps, any new forecasting model can be integrated into the existing system consistently and robustly.
xz模型预测修改.md (new file, 50 lines)
@@ -0,0 +1,50 @@
# Model Prediction Path Fix Log

**Modified**: 2025-07-18 18:43:50

## 1. Background

The system raised "file not found" errors during model prediction. Analysis showed the root cause: the model-loading logic (prediction time) and the model-saving logic (training time) followed inconsistent path rules.

- **Saving rule (new)**: follows `xz训练模型保存规则.md`, storing models in a structured directory hierarchy, e.g. `saved_models/product/{product_id}_all/mlstm/v1/model.pth`.
- **Loading logic (old)**: the code hard-coded a flat lookup, e.g. searching the `saved_models` root directly for a file named `{product_id}_{model_type}_v1.pth`.

This mismatch meant the prediction feature could not locate already-trained models.

## 2. Fix Strategy

To resolve this, we centralized path management, ensuring the entire application generates and resolves model paths through a single manager.

## 3. Code Changes

### Change 1: Strengthen the path manager

- **File**: [`server/utils/file_save.py`](server/utils/file_save.py)
- **Action**: added a new `get_model_path_for_prediction` method to the `ModelPathManager` class.
- **Purpose**:
  - Provide a dedicated function for resolving model paths at **prediction time**.
  - The function builds the full model file path strictly according to the hierarchical structure defined in `xz训练模型保存规则.md`.
  - Path-generation logic is now centrally managed, removing hard-coded paths scattered through the code. (A hypothetical sketch of such a resolver follows at the end of this document.)

### Change 2: Fix the API prediction endpoint

- **File**: [`server/api.py`](server/api.py)
- **Action**:
  1. Reworked the internals of the `/api/prediction` endpoint (the `predict` function).
  2. Updated the definition and implementation of the helper `run_prediction`.
- **Purpose**:
  - **`predict`**: removed all of the old, incorrect logic that manually concatenated model filenames. It now instantiates `ModelPathManager` and calls the new `get_model_path_for_prediction` method to obtain the single, correct model path.
  - **`run_prediction`**: its signature gained a `model_path` parameter so it can receive and forward the path resolved by `predict`; its internals were simplified to call `load_model_and_predict` directly.

### Change 3: Fix the model loader

- **File**: [`server/predictors/model_predictor.py`](server/predictors/model_predictor.py)
- **Action**: modified the `load_model_and_predict` function.
- **Purpose**:
  - Added a `model_path` parameter to the signature.
  - **Completely removed** the old internal logic that guessed model file locations.
  - The function now relies entirely on the `model_path` passed in from `api.py`, guaranteeing the load path is correct.

## 4. Conclusion

These three changes connect the whole chain from API request to model file loading, ensuring every step follows the same, correct structured path rule. This resolves, at the root, the model-loading failures caused by path mismatches.
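The method body is not shown in this document; the following is a minimal sketch of what a hierarchical resolver like `get_model_path_for_prediction` could look like under the rule quoted in section 1 — an assumption for illustration, not the actual implementation in `server/utils/file_save.py`:

```python
import os

class ModelPathManagerSketch:
    """Illustrative only; the real ModelPathManager has more responsibilities."""

    def __init__(self, base_dir: str = 'saved_models'):
        self.base_dir = base_dir

    def get_model_path_for_prediction(self, training_mode: str, model_type: str,
                                      version: str, product_id: str,
                                      store_id: str = 'all') -> str:
        # Hierarchical rule from section 1:
        #   saved_models/product/{product_id}_all/mlstm/v1/model.pth
        scope = f"{product_id}_{store_id}"
        return os.path.join(self.base_dir, training_mode, scope, model_type, version, 'model.pth')

# ModelPathManagerSketch().get_model_path_for_prediction('product', 'mlstm', 'v1', 'P001')
# -> 'saved_models/product/P001_all/mlstm/v1/model.pth'
```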
@@ -26,7 +26,6 @@

 3. **Append the file-type suffix**: add a suffix describing the file type to the prefix.
     - `_model.pth`
-    - `_metadata.json`
     - `_loss_curve.png`
     - `_checkpoint_best.pth`
     - `_checkpoint_epoch_{N}.pth`
@@ -34,7 +33,6 @@
 #### Complete examples

 - **Final model**: `saved_models/product_P001_all_mlstm_v2_model.pth`
-- **Metadata**: `saved_models/product_P001_all_mlstm_v2_metadata.json`
 - **Best checkpoint**: `saved_models/checkpoints/product_P001_all_mlstm_v2_checkpoint_best.pth`
 - **Epoch 50 checkpoint**: `saved_models/checkpoints/product_P001_all_mlstm_v2_checkpoint_epoch_50.pth`