Add the XGBoost training algorithm; training completes, but prediction loading still has issues

The data-file storage structure has been changed as follows:

### 1.2. File Storage Locations

-   **Final artifacts**: all final models, metadata files, loss plots, etc. are stored together under the `saved_models/` root directory.
-   **Intermediate files**: all checkpoint files produced during training are stored under `saved_models/checkpoints/`.

### 1.3. Filename Generation Rules

1.  **Build the logical path**: determine the logical path from the training parameters (mode, scope, type, version).
    -   *Example*: `product/P001_all/mlstm/v2`

2.  **Generate the filename prefix**: replace every `/` in the logical path with `_`.
    -   *Example*: `product_P001_all_mlstm_v2`

3.  **Append the file suffix**: append a suffix describing the file type to the prefix.
    -   `_model.pth`
    -   `_loss_curve.png`
    -   `_checkpoint_best.pth`
    -   `_checkpoint_epoch_{N}.pth`

#### **Complete examples:**

-   **Final model**: `saved_models/product_P001_all_mlstm_v2_model.pth`
-   **Best checkpoint**: `saved_models/checkpoints/product_P001_all_mlstm_v2_checkpoint_best.pth`
-   **Epoch 50 checkpoint**: `saved_models/checkpoints/product_P001_all_mlstm_v2_checkpoint_epoch_50.pth`
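
The naming rule can be captured in a few lines of Python. This is only a sketch; `build_model_filename` is a hypothetical helper, not a function in the codebase:

```python
import os

def build_model_filename(training_mode: str, scope: str, model_type: str,
                         version: str, suffix: str = "_model.pth") -> str:
    """Hypothetical helper illustrating the naming rule above."""
    # 1. Build the logical path from the training parameters.
    logical_path = f"{training_mode}/{scope}/{model_type}/{version}"
    # 2. Replace every '/' with '_' to get the flat filename prefix.
    prefix = logical_path.replace("/", "_")
    # 3. Append the file-type suffix; checkpoints live in their own subdirectory.
    base_dir = "saved_models/checkpoints" if "checkpoint" in suffix else "saved_models"
    return os.path.join(base_dir, prefix + suffix)

# build_model_filename("product", "P001_all", "mlstm", "v2")
#   -> 'saved_models/product_P001_all_mlstm_v2_model.pth'
```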
Commit: 87df49f764 (parent 28bae35783)
Author: xz2000
Date: 2025-07-21 18:47:02 +08:00
13 changed files with 809 additions and 184 deletions

Binary file not shown.

View File

@@ -56,3 +56,4 @@ tzdata==2025.2
werkzeug==3.1.3
win32-setctime==1.2.0
wsproto==1.2.0
xgboost

View File

@@ -46,6 +46,7 @@ from trainers.mlstm_trainer import train_product_model_with_mlstm
from trainers.kan_trainer import train_product_model_with_kan
from trainers.tcn_trainer import train_product_model_with_tcn
from trainers.transformer_trainer import train_product_model_with_transformer
from trainers.xgboost_trainer import train_product_model_with_xgboost

# Import prediction functions
from predictors.model_predictor import load_model_and_predict

@@ -942,7 +943,7 @@ def get_all_training_tasks():
            'type': 'object',
            'properties': {
                'product_id': {'type': 'string', 'description': '例如 P001'},
                'model_type': {'type': 'string', 'enum': ['mlstm', 'transformer', 'kan', 'optimized_kan', 'tcn', 'xgboost'], 'description': '要训练的模型类型'},
                'store_id': {'type': 'string', 'description': '店铺ID,如 S001。为空时使用全局聚合数据'},
                'epochs': {'type': 'integer', 'default': 50, 'description': '训练轮次'}
            },

@@ -1006,7 +1007,7 @@ def start_training():
        pass
    # Check that the requested model type is valid
    valid_model_types = ['mlstm', 'kan', 'optimized_kan', 'transformer', 'tcn', 'xgboost']
    if model_type not in valid_model_types:
        return jsonify({'error': '无效的模型类型'}), 400

@@ -1425,7 +1426,7 @@ def get_training_status(task_id):
            'type': 'object',
            'properties': {
                'product_id': {'type': 'string'},
                'model_type': {'type': 'string', 'enum': ['mlstm', 'transformer', 'kan', 'optimized_kan', 'tcn', 'xgboost']},
                'store_id': {'type': 'string', 'description': '店铺ID,如 S001。为空时使用全局模型'},
                'version': {'type': 'string'},
                'training_mode': {'type': 'string', 'enum': ['product', 'store', 'global'], 'default': 'product'},

@@ -2110,7 +2111,7 @@ def delete_prediction(prediction_id):
            'in': 'query',
            'type': 'string',
            'required': False,
            'description': "按模型类型筛选 (mlstm, kan, transformer, tcn, xgboost)"
        },
        {
            'name': 'page',

@@ -2172,7 +2173,7 @@ def list_models():
          in: query
          type: string
          required: false
          description: "按模型类型筛选 (mlstm, kan, transformer, tcn, xgboost)"
        - name: store_id
          in: query
          type: string

@@ -2775,7 +2776,7 @@ def run_prediction(model_type, product_id, model_id, future_days, start_date, ve
            'in': 'path',
            'type': 'string',
            'required': True,
            'description': '模型类型,例如mlstm, kan, transformer, tcn, optimized_kan, xgboost'
        },
        {
            'name': 'product_id',

@@ -3687,6 +3688,12 @@ def get_model_types():
            'name': 'TCN',
            'description': '时间卷积网络,适合处理长序列和平行计算',
            'tag_type': 'danger'
        },
        {
            'id': 'xgboost',
            'name': 'XGBoost',
            'description': '一种高效的梯度提升决策树模型,广泛用于各种预测任务。',
            'tag_type': 'success'
        }
    ]

View File

@@ -15,7 +15,8 @@ from trainers import (
    train_product_model_with_mlstm,
    train_product_model_with_kan,
    train_product_model_with_tcn,
    train_product_model_with_transformer,
    train_product_model_with_xgboost
)
from predictors.model_predictor import load_model_and_predict
from utils.data_utils import prepare_data, prepare_sequences

@@ -263,6 +264,16 @@ class PharmacyPredictor:
                    task_id=task_id,
                    path_info=path_info
                )
            elif model_type == 'xgboost':
                metrics, _ = train_product_model_with_xgboost(
                    product_id=product_id,
                    store_id=store_id,
                    epochs=epochs,
                    socketio=socketio,
                    task_id=task_id,
                    version=version,
                    path_info=path_info
                )
            else:
                log_message(f"不支持的模型类型: {model_type}", 'error')
                return None

View File

@@ -8,8 +8,10 @@ import pandas as pd
import numpy as np
from datetime import datetime, timedelta
import matplotlib.pyplot as plt
import xgboost as xgb
from sklearn.preprocessing import MinMaxScaler
import sklearn.preprocessing._data  # Needed so that MinMaxScaler can be deserialized
import joblib

from models.transformer_model import TimeSeriesTransformer
from models.slstm_model import sLSTM as ScalarLSTM

@@ -30,7 +32,7 @@ def load_model_and_predict(product_id, model_type, model_path=None, store_id=Non
    Args:
        product_id: product ID
        model_type: model type ('transformer', 'mlstm', 'kan', 'tcn', 'optimized_kan', 'xgboost')
        model_path: full file path of the model
        store_id: store ID; None means the global model is used
        future_days: number of days to forecast
@@ -118,177 +120,241 @@ def load_model_and_predict(product_id, model_type, model_path=None, store_id=Non
    # Load the model and its configuration
    try:
        # First try loading with weights_only=False
        if model_type == 'xgboost':
            if not os.path.exists(model_path):
                print(f"XGBoost模型文件不存在: {model_path}")
                return None
            # Load the metadata
            metadata = joblib.load(model_path)
            model_file_path = metadata['model_file']
            if not os.path.exists(model_file_path):
                print(f"引用的XGBoost模型文件不存在: {model_file_path}")
                return None
            # Load the native Booster model
            model = xgb.Booster()
            model.load_model(model_file_path)
            config = metadata['config']
            metrics = metadata['metrics']
            scaler_X = metadata['scaler_X']
            scaler_y = metadata['scaler_y']
            print("XGBoost原生模型及元数据加载成功")
        else:
            try:
                print("尝试使用 weights_only=False 加载模型")
                checkpoint = torch.load(model_path, map_location=DEVICE, weights_only=False)
            except Exception as e:
                print(f"使用weights_only=False加载失败: {str(e)}")
                print("尝试使用默认参数加载模型")
                checkpoint = torch.load(model_path, map_location=DEVICE)
            print(f"模型加载成功,检查checkpoint类型: {type(checkpoint)}")
            if isinstance(checkpoint, dict):
                print(f"checkpoint包含的键: {list(checkpoint.keys())}")
            else:
                print(f"checkpoint不是字典类型,而是: {type(checkpoint)}")
                return None
    except Exception as e:
        print(f"加载模型失败: {str(e)}")
        return None

    # XGBoost is handled differently
    if model_type == 'xgboost':
        look_back = config['look_back']
        features = config['features']
        # Prepare the input data
        recent_data = product_df.iloc[-look_back:].copy()

        predictions = []
        current_input_df = recent_data[features].copy()

        for _ in range(future_days):
            # Scale the input data and flatten it
            input_scaled = scaler_X.transform(current_input_df.values)
            input_vector = input_scaled.flatten().reshape(1, -1)

            # Predict the scaled value
            dpredict = xgb.DMatrix(input_vector)
            prediction_scaled = model.predict(dpredict)

            # Inverse-transform to get the real predicted value
            prediction = scaler_y.inverse_transform(prediction_scaled.reshape(-1, 1)).flatten()[0]
            predictions.append(prediction)

            # Slide the input window for the next prediction:
            # create a new row that contains the real predicted value
            new_row_values = current_input_df.iloc[-1].copy()
            new_row_values['sales'] = prediction
            # More elaborate future-feature logic could go here (e.g. updating weekday, month, etc. from the new date)

            new_row_df = pd.DataFrame([new_row_values], columns=features)

            # Roll the window forward
            current_input_df = pd.concat([current_input_df.iloc[1:], new_row_df], ignore_index=True)

        # Generate the prediction dates
        last_date = recent_data['date'].iloc[-1]
        pred_dates = [last_date + timedelta(days=i+1) for i in range(future_days)]

        y_pred = np.array(predictions)

    else:  # Original PyTorch model logic
        # Check for and read the configuration
        if 'config' not in checkpoint:
            print("模型文件中没有配置信息")
            return None
        config = checkpoint['config']
        print(f"模型配置: {config}")

        # Check for and read the scalers
        if 'scaler_X' not in checkpoint or 'scaler_y' not in checkpoint:
            print("模型文件中没有缩放器信息")
            return None
        scaler_X = checkpoint['scaler_X']
        scaler_y = checkpoint['scaler_y']

        # Create the model instance
        try:
            if model_type == 'transformer':
                model = TimeSeriesTransformer(
                    num_features=config['input_dim'],
                    d_model=config['hidden_size'],
                    nhead=config['num_heads'],
                    num_encoder_layers=config['num_layers'],
                    dim_feedforward=config['hidden_size'] * 2,
                    dropout=config['dropout'],
                    output_sequence_length=config['output_dim'],
                    seq_length=config['sequence_length'],
                    batch_size=32
                ).to(DEVICE)
            elif model_type == 'slstm':
                model = ScalarLSTM(
                    input_dim=config['input_dim'],
                    hidden_dim=config['hidden_size'],
                    output_dim=config['output_dim'],
                    num_layers=config['num_layers'],
                    dropout=config['dropout']
                ).to(DEVICE)
            elif model_type == 'mlstm':
                # Read config parameters, falling back to defaults when absent
                embed_dim = config.get('embed_dim', 32)
                dense_dim = config.get('dense_dim', 32)
                num_heads = config.get('num_heads', 4)
                num_blocks = config.get('num_blocks', 3)
                model = MatrixLSTM(
                    num_features=config['input_dim'],
                    hidden_size=config['hidden_size'],
                    mlstm_layers=config['num_layers'],
                    embed_dim=embed_dim,
                    dense_dim=dense_dim,
                    num_heads=num_heads,
                    dropout_rate=config['dropout'],
                    num_blocks=num_blocks,
                    output_sequence_length=config['output_dim']
                ).to(DEVICE)
            elif model_type == 'kan':
                model = KANForecaster(
                    input_features=config['input_dim'],
                    hidden_sizes=[config['hidden_size'], config['hidden_size']*2, config['hidden_size']],
                    output_sequence_length=config['output_dim']
                ).to(DEVICE)
            elif model_type == 'optimized_kan':
                model = OptimizedKANForecaster(
                    input_features=config['input_dim'],
                    hidden_sizes=[config['hidden_size'], config['hidden_size']*2, config['hidden_size']],
                    output_sequence_length=config['output_dim']
                ).to(DEVICE)
            elif model_type == 'tcn':
                model = TCNForecaster(
                    num_features=config['input_dim'],
                    output_sequence_length=config['output_dim'],
                    num_channels=[config['hidden_size']] * config['num_layers'],
                    kernel_size=3,
                    dropout=config['dropout']
                ).to(DEVICE)
            else:
                print(f"不支持的模型类型: {model_type}")
                return None
            print(f"模型实例创建成功: {type(model)}")
        except Exception as e:
            print(f"创建模型实例失败: {str(e)}")
            return None

        # Load the model weights
        try:
            model.load_state_dict(checkpoint['model_state_dict'])
            model.eval()
            print("模型参数加载成功")
        except Exception as e:
            print(f"加载模型参数失败: {str(e)}")
            return None

        # Prepare the input data
        try:
            features = ['sales', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']
            sequence_length = config['sequence_length']
            # Use the most recent sequence_length days as input
            recent_data = product_df.iloc[-sequence_length:].copy()
            # If a start date was given, use data from that date onward
            if start_date:
                if isinstance(start_date, str):
                    start_date = datetime.strptime(start_date, '%Y-%m-%d')
                recent_data = product_df[product_df['date'] >= start_date].iloc[:sequence_length].copy()
                if len(recent_data) < sequence_length:
                    print(f"警告: 从指定日期 {start_date} 开始的数据少于所需的 {sequence_length} 天")
                    # Pad with earlier data
                    missing_days = sequence_length - len(recent_data)
                    additional_data = product_df[product_df['date'] < start_date].iloc[-missing_days:].copy()
                    recent_data = pd.concat([additional_data, recent_data]).reset_index(drop=True)
            print(f"输入数据准备完成,形状: {recent_data.shape}")
        except Exception as e:
            print(f"准备输入数据失败: {str(e)}")
            return None

        # Scale the input data
        try:
            X = recent_data[features].values
            X_scaled = scaler_X.transform(X)
            # Convert to the model's input format
            X_input = torch.tensor(X_scaled.reshape(1, sequence_length, -1), dtype=torch.float32).to(DEVICE)
            print(f"输入张量准备完成,形状: {X_input.shape}")
        except Exception as e:
            print(f"归一化输入数据失败: {str(e)}")
            return None

        # Predict
        try:
            with torch.no_grad():
                y_pred_scaled = model(X_input).cpu().numpy()
                print(f"原始预测输出形状: {y_pred_scaled.shape}")
            # Squeeze the output of TCN, Transformer, mLSTM and KAN models to the expected shape
            if model_type in ['tcn', 'transformer', 'mlstm', 'kan', 'optimized_kan'] and len(y_pred_scaled.shape) == 3:
                y_pred_scaled = y_pred_scaled.squeeze(-1)
                print(f"处理后的预测输出形状: {y_pred_scaled.shape}")
            # Inverse-transform the predictions
            y_pred = scaler_y.inverse_transform(y_pred_scaled.reshape(-1, 1)).flatten()
            print(f"反归一化后的预测结果: {y_pred}")
            # Generate the prediction dates
            last_date = recent_data['date'].iloc[-1]
            pred_dates = [(last_date + timedelta(days=i+1)) for i in range(len(y_pred))]
            print(f"预测日期: {pred_dates}")
        except Exception as e:
            print(f"执行预测失败: {str(e)}")
            return None

    # Build the prediction results DataFrame
    try:

@@ -348,4 +414,4 @@ def load_model_and_predict(product_id, model_type, model_path=None, store_id=Non
        print(f"预测过程中出现未捕获的异常: {str(e)}")
        import traceback
        traceback.print_exc()
        return None

View File

@@ -6,6 +6,7 @@ from .mlstm_trainer import train_product_model_with_mlstm
from .kan_trainer import train_product_model_with_kan
from .tcn_trainer import train_product_model_with_tcn
from .transformer_trainer import train_product_model_with_transformer
from .xgboost_trainer import train_product_model_with_xgboost

# Default training function
from .mlstm_trainer import train_product_model_with_mlstm as train_product_model

@@ -15,5 +16,6 @@ __all__ = [
    'train_product_model_with_mlstm',
    'train_product_model_with_kan',
    'train_product_model_with_tcn',
    'train_product_model_with_transformer',
    'train_product_model_with_xgboost'
]

View File

@@ -0,0 +1,296 @@
import numpy as np
import pandas as pd
import os
import joblib
import xgboost as xgb
from xgboost.callback import EarlyStopping
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
from sklearn.preprocessing import MinMaxScaler

# Import the correct utility functions and configuration from the project
from utils.multi_store_data_utils import load_multi_store_data
from core.config import DEFAULT_DATA_PATH
from utils.file_save import ModelPathManager
from analysis.metrics import evaluate_model
# Refactored callback compatible with the native XGBoost API
class EpochCheckpointCallback(xgb.callback.TrainingCallback):
    def __init__(self, save_period, payload, base_path):
        super().__init__()
        self.save_period = save_period
        self.payload = payload
        self.base_path = base_path
        self.best_score = float('inf')

    def _save_checkpoint(self, model, path_suffix):
        """Helper that saves a model + metadata checkpoint."""
        metadata_path = self.base_path.replace('_model.pth', f'_{path_suffix}.pth')
        model_file_path = metadata_path.replace('.pth', '.xgb')
        # Make sure the directory exists
        os.makedirs(os.path.dirname(metadata_path), exist_ok=True)
        # Save the native Booster model
        model.save_model(model_file_path)
        # Update the model file reference in the payload
        self.payload['model_file'] = model_file_path
        joblib.dump(self.payload, metadata_path)
        print(f"[Checkpoint] 已保存检查点到: {metadata_path}")

    def after_iteration(self, model, epoch, evals_log):
        # Current validation score ('test' is assumed to be the validation set)
        current_score = evals_log['test']['rmse'][-1]
        # Save the best model
        if current_score < self.best_score:
            self.best_score = current_score
            self._save_checkpoint(model, 'checkpoint_best')
        # Save periodic checkpoints
        if (epoch + 1) % self.save_period == 0:
            self._save_checkpoint(model, f'checkpoint_epoch_{epoch + 1}')
        return False  # Continue training
def create_dataset(data, look_back=7):
    """
    Convert time-series data into a supervised-learning format.
    :param data: input DataFrame containing features and target
    :param look_back: size of the time window used for prediction
    :return: X (features), y (target)
    """
    X, y = [], []
    feature_columns = [col for col in data.columns if col != 'date']
    for i in range(len(data) - look_back):
        # Flatten all features within the look_back window
        features = data[feature_columns].iloc[i:(i + look_back)].values.flatten()
        X.append(features)
        # The target is the first sales value after the window
        y.append(data['sales'].iloc[i + look_back])
    return np.array(X), np.array(y)
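
# Worked example (sketch) of create_dataset: with look_back=2 and feature
# columns ['sales', 'weekday'], the window over rows 0-1 is flattened into one
# sample and the target is the sales value of row 2:
#
#   X[0] = [sales_0, weekday_0, sales_1, weekday_1],  y[0] = sales_2
#
# so X has shape (len(data) - look_back, look_back * n_features).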
def train_product_model_with_xgboost(
    product_id,
    store_id=None,
    epochs=100,  # For XGBoost this is effectively n_estimators
    look_back=7,
    socketio=None,
    task_id=None,
    version='v1',
    path_info=None,
    **kwargs):
    """
    Train a product sales forecasting model with XGBoost.
    """

    def emit_progress(message, progress=None):
        if socketio and task_id:
            payload = {'task_id': task_id, 'message': message}
            if progress is not None:
                payload['progress'] = progress
            socketio.emit('training_update', payload, namespace='/api/training', room=task_id)
        print(f"[{task_id}] {message}")

    try:
        model_path = None
        emit_progress("开始XGBoost模型训练...", 0)

        # 1. Load the data (using the correct loader and the path from config)
        full_df = load_multi_store_data(DEFAULT_DATA_PATH)

        # Filter by store_id and product_id
        if store_id:
            df = full_df[(full_df['product_id'] == product_id) & (full_df['store_id'] == store_id)].copy()
        else:
            # Without a store_id, aggregate the product's data across all stores
            df = full_df[full_df['product_id'] == product_id].groupby('date').agg({
                'sales': 'sum',
                'weekday': 'first',
                'month': 'first',
                'is_holiday': 'max',
                'is_weekend': 'max',
                'is_promotion': 'max',
                'temperature': 'mean'
            }).reset_index()

        if df.empty:
            raise ValueError(f"加载的数据为空 (product: {product_id}, store: {store_id}),无法进行训练。")

        # Make sure the data is sorted by date
        df = df.sort_values('date').reset_index(drop=True)
        emit_progress("数据加载完成。", 10)

        # 2. Build the dataset
        features_to_use = ['sales', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']
        # Make sure all required features exist
        for col in features_to_use:
            if col not in df.columns:
                # Fill missing features with 0
                df[col] = 0
        df_features = df[['date'] + features_to_use]

        X, y = create_dataset(df_features, look_back)
        if X.shape[0] == 0:
            raise ValueError("创建数据集后样本数量为0,请检查数据量和look_back参数。")
        emit_progress(f"数据集创建完成,样本数: {X.shape[0]}", 20)

        # 3. Split into training and test sets
        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

        # Scale the data
        scaler_X = MinMaxScaler()
        X_train_scaled = scaler_X.fit_transform(X_train)
        X_test_scaled = scaler_X.transform(X_test)

        scaler_y = MinMaxScaler()
        y_train_scaled = scaler_y.fit_transform(y_train.reshape(-1, 1))
        # y_test is not scaled; it is compared against inverse-transformed predictions
        emit_progress("数据划分和缩放完成。", 30)

        # 4. Switch to the native XGBoost API
        params = {
            'learning_rate': kwargs.get('learning_rate', 0.1),
            'max_depth': kwargs.get('max_depth', 5),
            'subsample': kwargs.get('subsample', 0.8),
            'colsample_bytree': kwargs.get('colsample_bytree', 0.8),
            'objective': 'reg:squarederror',
            'eval_metric': 'rmse',
            'random_state': 42
        }
        dtrain = xgb.DMatrix(X_train_scaled, label=y_train_scaled.ravel())
        dtest = xgb.DMatrix(X_test_scaled, label=scaler_y.transform(y_test.reshape(-1, 1)).ravel())

        emit_progress("开始模型训练...", 40)

        # Define the evaluation sets
        evals = [(dtrain, 'train'), (dtest, 'test')]

        # Prepare the callbacks
        callbacks = []
        checkpoint_interval = kwargs.get('checkpoint_interval', 10)  # Save every 10 rounds by default
        if path_info and path_info.get('model_path') and checkpoint_interval > 0:
            # Payload used for saving; the model object is injected dynamically in the callback
            checkpoint_payload = {
                'metrics': {},  # Checkpoints do not store final metrics
                'scaler_X': scaler_X,
                'scaler_y': scaler_y,
                'config': {
                    'look_back': look_back,
                    'features': features_to_use,
                    'product_id': product_id,
                    'store_id': store_id,
                    'version': version
                }
            }
            checkpoint_callback = EpochCheckpointCallback(
                save_period=checkpoint_interval,
                payload=checkpoint_payload,
                base_path=path_info['model_path']
            )
            callbacks.append(checkpoint_callback)

        # Add the early-stopping callback (save_best removed)
        callbacks.append(EarlyStopping(rounds=10))

        # Container for the evaluation results
        evals_result = {}

        model = xgb.train(
            params=params,
            dtrain=dtrain,
            num_boost_round=epochs,
            evals=evals,
            callbacks=callbacks,
            evals_result=evals_result,
            verbose_eval=False
        )
        emit_progress("模型训练完成。", 80)

        # Plot and save the loss curve
        if path_info and path_info.get('model_path'):
            try:
                loss_curve_path = path_info['model_path'].replace('_model.pth', '_loss_curve.png')
                results = evals_result
                train_rmse = results['train']['rmse']
                test_rmse = results['test']['rmse']
                num_epochs = len(train_rmse)
                x_axis = range(0, num_epochs)

                fig, ax = plt.subplots(figsize=(10, 6))
                ax.plot(x_axis, train_rmse, label='Train')
                ax.plot(x_axis, test_rmse, label='Test')
                ax.legend()
                plt.ylabel('RMSE')
                plt.xlabel('Epoch')
                plt.title('XGBoost RMSE Loss Curve')
                plt.savefig(loss_curve_path)
                plt.close(fig)
                emit_progress(f"损失曲线图已保存到: {loss_curve_path}")
            except Exception as e:
                emit_progress(f"警告: 绘制损失曲线失败: {str(e)}")

        # 5. Evaluate the model
        dtest_pred = xgb.DMatrix(X_test_scaled)
        y_pred_scaled = model.predict(dtest_pred)
        y_pred = scaler_y.inverse_transform(y_pred_scaled.reshape(-1, 1)).flatten()

        metrics = {
            'RMSE': np.sqrt(mean_squared_error(y_test, y_pred)),
            'MAE': mean_absolute_error(y_test, y_pred),
            'R2': r2_score(y_test, y_pred)
        }
        emit_progress(f"模型评估完成: {metrics}", 90)

        # 6. Save the model (native API style)
        if path_info and path_info.get('model_path'):
            metadata_path = path_info['model_path']
            # Save the native Booster model with the .xgb extension
            model_file_path = metadata_path.replace('.pth', '.xgb')
            # Make sure the directory exists
            os.makedirs(os.path.dirname(metadata_path), exist_ok=True)

            # Save the Booster model with the native method
            model.save_model(model_file_path)
            emit_progress(f"原生XGBoost模型已保存到: {model_file_path}")

            # Save the metadata (including the model file path)
            metadata_payload = {
                'model_file': model_file_path,  # Reference to the model file
                'metrics': metrics,
                'scaler_X': scaler_X,
                'scaler_y': scaler_y,
                'config': {
                    'look_back': look_back,
                    'features': features_to_use,
                    'product_id': product_id,
                    'store_id': store_id,
                    'version': version
                }
            }
            joblib.dump(metadata_payload, metadata_path)
            model_path = metadata_path  # Make sure model_path is assigned
            emit_progress(f"模型元数据已保存到: {metadata_path}", 100)
        else:
            emit_progress("警告: 未提供path_info,模型未保存。", 100)

        return metrics, model_path
    except Exception as e:
        emit_progress(f"XGBoost训练失败: {str(e)}", 100)
        import traceback
        traceback.print_exc()
        return {'error': str(e)}, None

View File

@ -167,11 +167,25 @@ class TrainingWorker:
}) })
if metrics: if metrics:
self.progress_queue.put({ if 'error' in metrics:
'task_id': task.task_id, self.progress_queue.put({
'log_type': 'info', 'task_id': task.task_id,
'message': f"📊 训练指标: MSE={metrics.get('mse', 'N/A'):.4f}, RMSE={metrics.get('rmse', 'N/A'):.4f}" 'log_type': 'error',
}) 'message': f"❌ 训练返回错误: {metrics['error']}"
})
else:
# 只有在没有错误时才格式化指标
mse_val = metrics.get('mse', 'N/A')
rmse_val = metrics.get('rmse', 'N/A')
mse_str = f"{mse_val:.4f}" if isinstance(mse_val, (int, float)) else mse_val
rmse_str = f"{rmse_val:.4f}" if isinstance(rmse_val, (int, float)) else rmse_val
self.progress_queue.put({
'task_id': task.task_id,
'log_type': 'info',
'message': f"📊 训练指标: MSE={mse_str}, RMSE={rmse_str}"
})
except ImportError as e: except ImportError as e:
training_logger.error(f"❌ 导入训练器失败: {e}") training_logger.error(f"❌ 导入训练器失败: {e}")
# 返回模拟的训练结果用于测试 # 返回模拟的训练结果用于测试
@ -382,12 +396,13 @@ class TrainingProcessManager:
# 只有在训练成功metrics有效时才保存版本信息 # 只有在训练成功metrics有效时才保存版本信息
if task.metrics and task.metrics.get('r2', -1) >= 0: if task.metrics and task.metrics.get('r2', -1) >= 0:
if task.path_info: if task.path_info:
identifier = task.path_info.get('identifier') # 确保使用正确的、经过规范化处理的标识符
version_control_identifier = task.path_info.get('identifier')
version = task.path_info.get('version') version = task.path_info.get('version')
if identifier and version: if version_control_identifier and version:
try: try:
self.path_manager.save_version_info(identifier, version) self.path_manager.save_version_info(version_control_identifier, version)
self.logger.info(f"✅ 版本信息已更新: identifier={identifier}, version={version}") self.logger.info(f"✅ 版本信息已更新: identifier={version_control_identifier}, version={version}")
task.version = version # 关键修复:将版本号保存到任务对象中 task.version = version # 关键修复:将版本号保存到任务对象中
except Exception as e: except Exception as e:
self.logger.error(f"❌ 更新版本文件失败: {e}") self.logger.error(f"❌ 更新版本文件失败: {e}")
@ -398,12 +413,12 @@ class TrainingProcessManager:
if self.websocket_callback: if self.websocket_callback:
try: try:
if action == 'complete': if action == 'complete':
# 从任务信息中获取版本号 # 从任务对象中获取权威的版本号
version = None version = None
with self.lock: with self.lock:
task = self.tasks.get(task_id) task = self.tasks.get(task_id)
if task and task.path_info: if task:
version = task.path_info.get('version') version = task.version
# 训练完成 - 发送完成状态 # 训练完成 - 发送完成状态
self.websocket_callback('training_update', { self.websocket_callback('training_update', {

test/verify_save_logic.py Normal file
View File

@@ -0,0 +1,118 @@
import unittest
import os
import shutil
import sys

# Add the project root to sys.path to resolve module imports.
# This lets the test script run directly without complex path configuration.
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
if project_root not in sys.path:
    sys.path.insert(0, project_root)

from server.utils.file_save import ModelPathManager

class TestModelPathManager(unittest.TestCase):
    """
    Verify that ModelPathManager strictly follows the flat file-storage spec.
    """
    def setUp(self):
        """Set up the test environment before each test case."""
        self.test_base_dir = 'test_saved_models'
        # Clean up directories and files left over from previous test runs
        if os.path.exists(self.test_base_dir):
            shutil.rmtree(self.test_base_dir)
        self.path_manager = ModelPathManager(base_dir=self.test_base_dir)

    def tearDown(self):
        """Clean up the test environment after each test case."""
        if os.path.exists(self.test_base_dir):
            shutil.rmtree(self.test_base_dir)

    def test_product_mode_path_generation(self):
        """Test that path generation in 'product' mode follows the spec."""
        print("\n--- 测试 'product' 模式 ---")
        params = {
            'training_mode': 'product',
            'model_type': 'mlstm',
            'product_id': 'P001',
            'store_id': 'all'
        }
        # On the first call the version should be 1
        paths_v1 = self.path_manager.get_model_paths(**params)

        # Check the version number
        self.assertEqual(paths_v1['version'], 1)

        # Check the filename prefix
        expected_prefix_v1 = 'product_P001_all_mlstm_v1'
        self.assertEqual(paths_v1['filename_prefix'], expected_prefix_v1)

        # Check the full paths of the individual files
        self.assertEqual(paths_v1['model_path'], os.path.join(self.test_base_dir, f'{expected_prefix_v1}_model.pth'))
        self.assertEqual(paths_v1['metadata_path'], os.path.join(self.test_base_dir, f'{expected_prefix_v1}_metadata.json'))
        self.assertEqual(paths_v1['loss_curve_path'], os.path.join(self.test_base_dir, f'{expected_prefix_v1}_loss_curve.png'))

        # Check the checkpoint paths
        checkpoint_dir = os.path.join(self.test_base_dir, 'checkpoints')
        self.assertEqual(paths_v1['checkpoint_dir'], checkpoint_dir)
        self.assertEqual(paths_v1['best_checkpoint_path'], os.path.join(checkpoint_dir, f'{expected_prefix_v1}_checkpoint_best.pth'))
        self.assertEqual(paths_v1['epoch_checkpoint_template'], os.path.join(checkpoint_dir, f'{expected_prefix_v1}_checkpoint_epoch_{{N}}.pth'))
        print(f"生成的文件名前缀: {paths_v1['filename_prefix']}")
        print(f"生成的模型路径: {paths_v1['model_path']}")
        print("验证通过!")

        # Simulate a successful training run to trigger the version increment
        self.path_manager.save_version_info(paths_v1['identifier'], paths_v1['version'])

        # On the second call the version should be 2
        paths_v2 = self.path_manager.get_model_paths(**params)
        self.assertEqual(paths_v2['version'], 2)
        expected_prefix_v2 = 'product_P001_all_mlstm_v2'
        self.assertEqual(paths_v2['filename_prefix'], expected_prefix_v2)
        print(f"\n版本递增后,生成的文件名前缀: {paths_v2['filename_prefix']}")
        print("版本递增验证通过!")

    def test_store_mode_path_generation_with_hash(self):
        """Test path generation with hashing in 'store' mode (multiple product IDs)."""
        print("\n--- 测试 'store' 模式 (多药品ID哈希) ---")
        params = {
            'training_mode': 'store',
            'model_type': 'kan',
            'store_id': 'S008',
            'product_scope': 'specific',
            'product_ids': ['P002', 'P005', 'P003']  # Deliberately out of order
        }
        paths = self.path_manager.get_model_paths(**params)

        # The hash is deterministic because the ID list is sorted before hashing
        expected_hash = self.path_manager._hash_ids(sorted(['P002', 'P005', 'P003']))
        expected_prefix = f'store_S008_{expected_hash}_kan_v1'
        self.assertEqual(paths['filename_prefix'], expected_prefix)
        self.assertEqual(paths['model_path'], os.path.join(self.test_base_dir, f'{expected_prefix}_model.pth'))
        print(f"生成的文件名前缀: {paths['filename_prefix']}")
        print("验证通过!")

    def test_global_mode_path_generation(self):
        """Test path generation in 'global' mode."""
        print("\n--- 测试 'global' 模式 ---")
        params = {
            'training_mode': 'global',
            'model_type': 'transformer',
            'training_scope': 'all',
            'aggregation_method': 'mean'
        }
        paths = self.path_manager.get_model_paths(**params)

        expected_prefix = 'global_all_agg_mean_transformer_v1'
        self.assertEqual(paths['filename_prefix'], expected_prefix)
        self.assertEqual(paths['model_path'], os.path.join(self.test_base_dir, f'{expected_prefix}_model.pth'))
        print(f"生成的文件名前缀: {paths['filename_prefix']}")
        print("验证通过!")

if __name__ == '__main__':
    unittest.main()

View File

@@ -1,6 +1,6 @@
### Start from the repository root
**1**: `uv venv`
**2**: `uv pip install loguru numpy pandas torch matplotlib flask flask_cors flask_socketio flasgger scikit-learn tqdm pytorch_tcn pyarrow xgboost -i https://pypi.tuna.tsinghua.edu.cn/simple`
**3**: `uv run .\server\api.py`
### UI
**1**: `npm install` `npm run dev`

xz模型添加流程.md Normal file
View File

@ -0,0 +1,61 @@
# 为系统添加新模型的标准流程
本文档总结了向本项目中添加一个新的预测模型以XGBoost为例的标准流程旨在为未来的开发工作提供清晰、可复用的路线图。
---
### 第1步创建模型训练器
这是最核心的一步,负责实现新模型的训练逻辑。
1. **创建新文件**:在 [`server/trainers/`](server/trainers/) 目录下创建一个新的Python文件例如 `new_model_trainer.py`
2. **定义训练函数**:在该文件中,定义一个核心的训练函数,遵循项目的标准签名,接收 `product_id`, `store_id`, `epochs`, `path_info` 等参数。
3. **实现函数内部逻辑**
* **数据加载**:使用 [`utils.multi_store_data_utils.load_multi_store_data`](server/utils/multi_store_data_utils.py) 加载数据,并根据 `product_id``store_id` 进行筛选。
* **数据预处理**将时间序列数据转换为监督学习格式。对于像XGBoost这样的模型这意味着创建一个“滑动窗口”如我们实现的 `create_dataset` 函数)。
* **数据缩放 (关键)****必须**使用 `sklearn.preprocessing.MinMaxScaler` 对特征 (`X`) 和目标 (`y`) 进行归一化。创建并训练 `scaler_X``scaler_y` 两个缩放器。
* **模型训练**:初始化您的新模型,并使用**归一化后**的数据进行训练。
* **生成损失曲线 (可选但推荐)**:如果模型支持,在训练过程中捕获训练集和验证集的损失,然后使用 `matplotlib` 绘制损失曲线图,并将其保存为 `..._loss_curve.png`
* **保存检查点 (可选但推荐)**如果模型支持回调Callbacks可以实现一个自定义回调函数用于按指定轮次间隔保存模型检查点 (`..._checkpoint_epoch_{N}.pth`)。
* **模型评估**:使用**反归一化后**的预测结果来计算评估指标RMSE, R2等
* **模型保存 (关键)**
* 创建一个字典payload**必须**包含以下内容:`'model'` (训练好的模型对象), `'config'` (训练配置), `'scaler_X'` (特征缩放器), 和 `'scaler_y'` (目标缩放器)。
* 使用正确的库PyTorch模型用 `torch.save`其他模型如XGBoost用 `joblib.dump`)将这个字典保存到 `path_info['model_path']` 指定的路径。**文件名统一使用 `.pth` 扩展名**。
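
The payload convention can be summarized in a short sketch. `save_model_payload` is a hypothetical helper, not a function in the codebase; it merely illustrates the required keys:

```python
import joblib

def save_model_payload(model, config, scaler_X, scaler_y, metrics, model_path):
    """Hypothetical helper illustrating the payload convention above."""
    payload = {
        'model': model,        # trained model object (native XGBoost Boosters store a 'model_file' reference instead)
        'config': config,      # training configuration needed to rebuild inputs at prediction time
        'scaler_X': scaler_X,  # feature scaler, reused to normalize prediction inputs
        'scaler_y': scaler_y,  # target scaler, reused to inverse-transform predictions
        'metrics': metrics,    # final evaluation metrics
    }
    joblib.dump(payload, model_path)  # model_path = path_info['model_path'], ends in _model.pth
```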
---
### Step 2: Integrate the trainer into the system

1. **Register the trainer**: open [`server/trainers/__init__.py`](server/trainers/__init__.py).
    * At the top of the file, import the training function from the new trainer file, e.g. `from .new_model_trainer import train_product_model_with_new_model`.
    * Add the new function's name to the `__all__` list at the bottom of the file.
2. **Add dispatch logic**: open [`server/core/predictor.py`](server/core/predictor.py).
    * In the `train_model` method, find the `if/elif` block and add a new `elif model_type == 'new_model':` branch that calls the new training function; a dispatch sketch follows this list.
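
A minimal sketch of the dispatch. The `xgboost` import is real; `train_product_model_with_new_model` and `'new_model'` are hypothetical placeholders for the model being added:

```python
from trainers import train_product_model_with_xgboost
# Hypothetical import for the model being added:
# from trainers import train_product_model_with_new_model

def dispatch_training(model_type, **train_kwargs):
    """Sketch of the if/elif dispatch inside PharmacyPredictor.train_model."""
    if model_type == 'xgboost':
        return train_product_model_with_xgboost(**train_kwargs)
    # elif model_type == 'new_model':
    #     return train_product_model_with_new_model(**train_kwargs)
    raise ValueError(f"Unsupported model type: {model_type}")
```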
---
### Step 3: Implement the prediction logic

1. **Modify the predictor**: open [`server/predictors/model_predictor.py`](server/predictors/model_predictor.py).
2. **Add a prediction branch**: in the `load_model_and_predict` function, find the `if/elif` block and add a new `elif model_type == 'new_model':` branch.
3. **Implement the branch body** (sketched after this list):
    * Load the `.pth` model file with the same library that saved it (e.g. `joblib.load`).
    * From the loaded dictionary you **must** extract `model`, `config`, `scaler_X`, and `scaler_y`.
    * Prepare the prediction input (e.g. the most recent N days of data).
    * Before predicting, you **must** normalize the input with `scaler_X.transform`.
    * After the model produces its output, you **must** inverse-transform it with `scaler_y.inverse_transform` to recover real-valued predictions.
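
A sketch of such a branch, under the assumptions above: the payload was saved with `joblib`, the model exposes a scikit-learn style `predict`, and `predict_with_saved_model` is a hypothetical name:

```python
import joblib
import numpy as np

def predict_with_saved_model(model_path, recent_window):
    """recent_window: ndarray of shape (look_back, n_features)."""
    payload = joblib.load(model_path)  # same library as at save time
    model = payload['model']
    scaler_X = payload['scaler_X']
    scaler_y = payload['scaler_y']

    # Normalize exactly as during training, then flatten the window into one sample.
    x = scaler_X.transform(recent_window).flatten().reshape(1, -1)
    y_scaled = model.predict(x)

    # Inverse-transform to recover real-valued predictions.
    return scaler_y.inverse_transform(np.asarray(y_scaled).reshape(-1, 1)).flatten()
```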
---
### Step 4: Update the API and dependencies

1. **Update the API endpoints**: open [`server/api.py`](server/api.py).
    * Add the new model ID (e.g. `'new_model'`) to the `valid_model_types` list in the `/api/training` route (the `start_training` function).
    * Add the new model's description to the list returned by the `/api/model_types` route (the `get_model_types` function) so that it appears in the frontend.
2. **Update dependencies**: add the Python libraries the new model needs (e.g. `xgboost`) to [`requirements.txt`](requirements.txt).

Following these four steps, any new forecasting model can be integrated into the existing system consistently and robustly.

xz模型预测修改.md Normal file
View File

@ -0,0 +1,50 @@
# 模型预测路径修复记录
**修改时间**: 2025-07-18 18:43:50
## 1. 问题背景
系统在进行模型预测时出现“文件未找到”的错误。经分析,根本原因是模型加载逻辑(预测时)与模型保存逻辑(训练时)遵循了不一致的路径规则。
- **保存规则 (新)**: 遵循 `xz训练模型保存规则.md`,将模型保存在结构化的层级目录中,例如 `saved_models/product/{product_id}_all/mlstm/v1/model.pth`
- **加载逻辑 (旧)**: 代码中硬编码了扁平化的文件路径查找方式,例如在 `saved_models` 根目录下直接查找名为 `{product_id}_{model_type}_v1.pth` 的文件。
这种不匹配导致预测功能无法定位到已经训练好的模型。
## 2. 修复方案
为了解决此问题,我们采取了集中化路径管理的策略,确保整个应用程序都通过一个统一的管理器来生成和获取模型路径。
## 3. 代码修改详情
### 第一处修改:增强路径管理器
- **文件**: [`server/utils/file_save.py`](server/utils/file_save.py)
- **操作**: 在 `ModelPathManager` 类中新增了 `get_model_path_for_prediction` 方法。
- **目的**:
- 提供一个专门用于**预测时**获取模型路径的函数。
- 该函数严格按照 `xz训练模型保存规则.md` 中定义的层级结构来构建模型文件的完整路径。
- 这使得路径生成逻辑被集中管理,避免了代码各处的硬编码。
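
A minimal sketch of the resolution rule, assuming the hierarchical layout quoted above; the real method lives on `ModelPathManager` and its exact signature may differ:

```python
import os

def get_model_path_for_prediction(base_dir, training_mode, scope, model_type, version):
    """Resolve e.g. saved_models/product/P001_all/mlstm/v1/model.pth."""
    return os.path.join(base_dir, training_mode, scope,
                        model_type, f"v{version}", "model.pth")
```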
### Change 2: Fix the API prediction endpoint

- **File**: [`server/api.py`](server/api.py)
- **Action**:
    1. Reworked the internals of the `/api/prediction` endpoint (the `predict` function).
    2. Changed the definition and implementation of the helper function `run_prediction`.
- **Purpose**:
    - **`predict`**: removed all of the old, incorrect logic that assembled model filenames by hand. It now instantiates `ModelPathManager` and calls its new `get_model_path_for_prediction` method to obtain the single correct model path.
    - **`run_prediction`**: its signature gained a `model_path` parameter so it can receive, and pass down, the path resolved by `predict`; its internals were simplified to call `load_model_and_predict` directly.

### Change 3: Fix the model loader

- **File**: [`server/predictors/model_predictor.py`](server/predictors/model_predictor.py)
- **Action**: changed the `load_model_and_predict` function.
- **Purpose**:
    - Updated the signature to accept a `model_path` parameter.
    - **Completely removed** the function's internal logic for guessing the model file's location.
    - The function now relies entirely on the `model_path` passed in from `api.py`, guaranteeing the load path is correct.

## 4. Conclusion

Together, these three changes connect the whole chain from API request to model file loading, ensuring every stage follows the same structured path rules. This fixes, at the root, the model-loading failures caused by mismatched paths.

View File

@@ -26,7 +26,6 @@
 3. **Append the file suffix**: append a suffix describing the file type to the prefix.
     - `_model.pth`
-    - `_metadata.json`
     - `_loss_curve.png`
     - `_checkpoint_best.pth`
     - `_checkpoint_epoch_{N}.pth`

@@ -34,7 +33,6 @@
 #### **Complete examples:**
 - **Final model**: `saved_models/product_P001_all_mlstm_v2_model.pth`
-- **Metadata**: `saved_models/product_P001_all_mlstm_v2_metadata.json`
 - **Best checkpoint**: `saved_models/checkpoints/product_P001_all_mlstm_v2_checkpoint_best.pth`
 - **Epoch 50 checkpoint**: `saved_models/checkpoints/product_P001_all_mlstm_v2_checkpoint_epoch_50.pth`