Compare commits

...

3 Commits

Author SHA1 Message Date
6f3240c723 按药品训练-预测跑通 2025-07-17 17:54:53 +08:00
e437658b9d 系统开发设计指南 2025-07-17 15:52:04 +08:00
ee9ba299fa 模型预测算法优化 2025-07-16 18:50:16 +08:00
14 changed files with 775 additions and 68 deletions

View File

@ -185,13 +185,14 @@ const startPrediction = async () => {
try {
predicting.value = true
const payload = {
training_mode: 'global', //
model_type: form.model_type,
version: form.version,
future_days: form.future_days,
start_date: form.start_date,
analyze_result: form.analyze_result
}
const response = await axios.post('/api/predict', payload)
const response = await axios.post('/api/prediction', payload)
if (response.data.status === 'success') {
predictionResult.value = response.data.data
ElMessage.success('预测完成!')

View File

@ -165,5 +165,36 @@
### 16:16 - 项目状态更新
- **状态**: **所有已知问题已修复**
- **确认**: 用户已确认“现在药品和店铺预测流程通了
- **确认**: 用户已确认“现在药品和店铺预测流程通了。
- **后续**: 将本次修复过程归档至本文档。
---
### 2025年7月16日 18:38 - 全模型预测功能通用性修复
**问题现象**:
在解决了 `Transformer` 模型的预测问题后,发现一个更深层次的系统性问题:在所有预测模式(按药品、按店铺、全局)中,只有 `Transformer` 算法可以成功预测并显示图表,而其他四种模型(`mLSTM`, `KAN`, `优化版KAN`, `TCN`)虽然能成功训练,但在预测时均会失败,并提示“没有可用于图表的数据”。
**根本原因深度分析**:
这个问题的核心在于**模型配置的持久化不完整且不统一**。
1. **Transformer 的“幸存”**: `Transformer` 模型的实现恰好不依赖于那些在保存时被遗漏的特定超参数,因此它能“幸存”下来。
2. **其他模型的“共性缺陷”**: 其他所有模型 (`mLSTM`, `TCN`, `KAN`) 在它们的构造函数中,都依赖于一些在训练时定义、但在保存到检查点文件 (`.pth`) 时**被遗漏的**关键结构性参数。
* **mLSTM**: 缺少 `mlstm_layers`, `embed_dim`, `dense_dim` 等参数。
* **TCN**: 缺少 `num_channels`, `kernel_size` 等参数。
* **KAN**: 缺少 `hidden_sizes` 列表。
3. **连锁失败**:
* 当 `server/predictors/model_predictor.py` 尝试加载这些模型的检查点文件时,它从 `checkpoint['config']` 中找不到实例化模型所必需的全部参数。
* 模型实例化失败,抛出 `KeyError``TypeError`
* 这个异常导致 `load_model_and_predict` 函数提前返回 `None`,最终导致返回给前端的响应中缺少 `history_data`,前端因此无法渲染图表。
**系统性、可扩展的解决方案**:
为了彻底解决这个问题,并为未来平稳地加入新算法,我们对所有非 Transformer 的训练器进行了标准化的、彻底的修复。
1. **修复 `mlstm_trainer.py`**: 在 `config` 字典中补全了 `mlstm_layers`, `embed_dim`, `dense_dim` 等所有缺失的参数。
2. **修复 `tcn_trainer.py`**: 在 `config` 字典中补全了 `num_channels`, `kernel_size` 等所有缺失的参数。
3. **修复 `kan_trainer.py`**: 在 `config` 字典中补全了 `hidden_sizes` 列表。
**结果**:
通过这次系统性的修复,我们确保了所有训练器在保存模型时,都会将完整的、可用于重新实例化模型的配置信息写入检查点文件。这从根本上解决了所有模型算法的预测失败问题,使得整个系统在处理不同算法时具有了通用性和健壮性。

Binary file not shown.

View File

@ -1508,35 +1508,40 @@ def predict():
"""
try:
data = request.json
product_id = data.get('product_id')
model_type = data.get('model_type')
store_id = data.get('store_id')
training_mode = 'store' if store_id else 'product'
version = data.get('version') # 新增版本参数
version = data.get('version')
future_days = int(data.get('future_days', 7))
start_date = data.get('start_date', '')
include_visualization = data.get('include_visualization', False)
scope_msg = f", store_id={store_id}" if store_id else ", 全局模型"
print(f"API接收到预测请求: product_id={product_id}, model_type={model_type}, version={version}{scope_msg}, future_days={future_days}, start_date={start_date}")
if not product_id or not model_type:
return jsonify({"status": "error", "error": "product_id 和 model_type 是必需的"}), 400
# 确定训练模式和标识符
training_mode = data.get('training_mode', 'product')
product_id = data.get('product_id')
store_id = data.get('store_id')
# 获取产品名称
product_name = get_product_name(product_id)
if not product_name:
product_name = product_id
# 根据训练模式构建模型标识符
if training_mode == 'store':
if training_mode == 'global':
# 全局模式:使用硬编码的标识符,并为预测函数设置占位符
model_identifier = "global_all_products_sum"
product_id = 'all_products'
product_name = "全局聚合数据"
elif training_mode == 'store':
# 店铺模式验证store_id并构建标识符
if not store_id:
return jsonify({"status": "error", "error": "店铺模式需要 store_id"}), 400
model_identifier = f"store_{store_id}"
# 对于店铺预测product_id实际上是store_id但我们需要一个药品ID来获取名称这里暂时用一个占位符
product_name = f"店铺 {store_id} 整体"
else:
else: # 默认为 'product' 模式
# 药品模式验证product_id并构建标识符
if not product_id:
return jsonify({"status": "error", "error": "药品模式需要 product_id"}), 400
model_identifier = product_id
product_name = get_product_name(product_id) or product_id
print(f"API接收到预测请求: mode={training_mode}, model_identifier='{model_identifier}', model_type='{model_type}', version='{version}'")
if not model_type:
return jsonify({"status": "error", "error": "model_type 是必需的"}), 400
# 获取模型版本
if not version:
version = get_latest_model_version(model_identifier, model_type)
@ -3818,7 +3823,9 @@ def get_store_model_versions_api(store_id, model_type):
def get_global_model_versions_api(model_type):
"""获取全局模型版本列表API"""
try:
model_identifier = "global"
# 全局模型的标识符是在训练时确定的,例如 'global_all_products_sum'
# 这里我们假设前端请求的是默认的全局模型
model_identifier = "global_all_products_sum"
versions = get_model_versions(model_identifier, model_type)
latest_version = get_latest_model_version(model_identifier, model_type)

View File

@ -133,7 +133,14 @@ def get_model_file_path(product_id: str, model_type: str, version: str) -> str:
# 修正直接使用唯一的product_id它可能包含store_前缀来构建文件名
# 文件名示例: transformer_17002608_epoch_best.pth 或 transformer_store_01010023_epoch_best.pth
filename = f"{model_type}_{product_id}_epoch_{version}.pth"
# 针对 KAN 和 optimized_kan使用 model_manager 的命名约定
if model_type in ['kan', 'optimized_kan']:
# 格式: {model_type}_product_{product_id}_{version}.pth
# 注意KAN trainer 保存时product_id 就是 model_identifier
filename = f"{model_type}_product_{product_id}_{version}.pth"
else:
# 其他模型使用 _epoch_ 约定
filename = f"{model_type}_{product_id}_epoch_{version}.pth"
# 修正直接在根模型目录查找不再使用checkpoints子目录
return os.path.join(DEFAULT_MODEL_DIR, filename)
@ -151,32 +158,46 @@ def get_model_versions(product_id: str, model_type: str) -> list:
# 直接使用传入的product_id构建搜索模式
# 搜索模式,匹配 "transformer_product_17002608_epoch_50.pth" 或 "transformer_product_17002608_epoch_best.pth"
# 修正直接使用唯一的product_id它可能包含store_前缀来构建搜索模式
pattern = f"{model_type}_{product_id}_epoch_*.pth"
# 修正直接在根模型目录查找不再使用checkpoints子目录
search_path = os.path.join(DEFAULT_MODEL_DIR, pattern)
existing_files = glob.glob(search_path)
# 扩展搜索模式以兼容多种命名约定
patterns = [
f"{model_type}_{product_id}_epoch_*.pth", # 原始格式 (e.g., transformer_123_epoch_best.pth)
f"{model_type}_product_{product_id}_*.pth" # KAN/ModelManager格式 (e.g., kan_product_123_v1.pth)
]
existing_files = []
for pattern in patterns:
search_path = os.path.join(DEFAULT_MODEL_DIR, pattern)
existing_files.extend(glob.glob(search_path))
# 旧格式(兼容性支持)
pattern_old = f"{model_type}_model_product_{product_id}.pth"
old_file_path = os.path.join(DEFAULT_MODEL_DIR, pattern_old)
has_old_format = os.path.exists(old_file_path)
old_file_path = os.path.join(DEFAULT_MODEL_DIR, pattern_old)
has_old_format = os.path.exists(old_file_path)
if os.path.exists(old_file_path):
existing_files.append(old_file_path)
versions = set() # 使用集合避免重复
# 从找到的文件中提取版本信息
for file_path in existing_files:
filename = os.path.basename(file_path)
# 匹配 _epoch_ 后面的内容作为版本
version_match = re.search(r"_epoch_(.+)\.pth$", filename)
if version_match:
versions.add(version_match.group(1))
# 如果存在旧格式文件将其视为v1
if has_old_format:
versions.add("v1_legacy") # 添加一个特殊标识
print(f"检测到旧格式模型文件: {old_file_path},将其视为版本 v1_legacy")
# 尝试匹配 _epoch_ 格式
version_match_epoch = re.search(r"_epoch_(.+)\.pth$", filename)
if version_match_epoch:
versions.add(version_match_epoch.group(1))
continue
# 尝试匹配 _product_..._v 格式 (KAN)
version_match_kan = re.search(r"_product_.+_v(\d+)\.pth$", filename)
if version_match_kan:
versions.add(f"v{version_match_kan.group(1)}")
continue
# 尝试匹配旧的 _model_product_ 格式
if pattern_old in filename:
versions.add("v1_legacy") # 添加一个特殊标识
print(f"检测到旧格式模型文件: {old_file_path},将其视为版本 v1_legacy")
continue
# 转换为列表并排序
sorted_versions = sorted(list(versions))

View File

@ -191,6 +191,7 @@ class PharmacyPredictor:
if model_type == 'transformer':
model_result, metrics, actual_version = train_product_model_with_transformer(
product_id=product_id,
model_identifier=model_identifier,
product_df=product_data,
store_id=store_id,
training_mode=training_mode,
@ -208,6 +209,7 @@ class PharmacyPredictor:
elif model_type == 'mlstm':
_, metrics, _, _ = train_product_model_with_mlstm(
product_id=product_id,
model_identifier=model_identifier,
product_df=product_data,
store_id=store_id,
training_mode=training_mode,
@ -223,6 +225,7 @@ class PharmacyPredictor:
elif model_type == 'kan':
_, metrics = train_product_model_with_kan(
product_id=product_id,
model_identifier=model_identifier,
product_df=product_data,
store_id=store_id,
training_mode=training_mode,
@ -236,6 +239,7 @@ class PharmacyPredictor:
elif model_type == 'optimized_kan':
_, metrics = train_product_model_with_kan(
product_id=product_id,
model_identifier=model_identifier,
product_df=product_data,
store_id=store_id,
training_mode=training_mode,
@ -249,6 +253,7 @@ class PharmacyPredictor:
elif model_type == 'tcn':
_, metrics, _, _ = train_product_model_with_tcn(
product_id=product_id,
model_identifier=model_identifier,
product_df=product_data,
store_id=store_id,
training_mode=training_mode,

View File

@ -113,11 +113,17 @@ def load_model_and_predict(product_id, model_type, store_id=None, future_days=7,
)
store_name = product_df['store_name'].iloc[0] if 'store_name' in product_df.columns and not product_df.empty else f"店铺{store_id}"
prediction_scope = f"店铺 '{store_name}' ({store_id})"
# 对于店铺模型,其“产品名称”就是店铺名称
product_name = store_name
elif training_mode == 'global':
# 全局模型:聚合所有数据
product_df = aggregate_multi_store_data(
aggregation_method='sum',
file_path=DEFAULT_DATA_PATH
)
prediction_scope = "全局聚合数据"
product_name = "全局销售数据"
else:
# 产品模型(默认):聚合该产品在所有店铺的数据
# 此时传入的product_id是真正的产品ID
product_df = aggregate_multi_store_data(
product_id=product_id,
aggregation_method='sum',
@ -210,11 +216,11 @@ def load_model_and_predict(product_id, model_type, store_id=None, future_days=7,
model = MatrixLSTM(
num_features=config['input_dim'],
hidden_size=config['hidden_size'],
mlstm_layers=config['num_layers'],
mlstm_layers=config['mlstm_layers'],
embed_dim=embed_dim,
dense_dim=dense_dim,
num_heads=num_heads,
dropout_rate=config['dropout'],
dropout_rate=config['dropout_rate'],
num_blocks=num_blocks,
output_sequence_length=config['output_dim']
).to(DEVICE)
@ -235,7 +241,7 @@ def load_model_and_predict(product_id, model_type, store_id=None, future_days=7,
num_features=config['input_dim'],
output_sequence_length=config['output_dim'],
num_channels=[config['hidden_size']] * config['num_layers'],
kernel_size=3,
kernel_size=config['kernel_size'],
dropout=config['dropout']
).to(DEVICE)
else:

View File

@ -21,7 +21,7 @@ from utils.visualization import plot_loss_curve
from analysis.metrics import evaluate_model
from core.config import DEVICE, DEFAULT_MODEL_DIR, LOOK_BACK, FORECAST_HORIZON
def train_product_model_with_kan(product_id, product_df=None, store_id=None, training_mode='product', aggregation_method='sum', epochs=50, sequence_length=LOOK_BACK, forecast_horizon=FORECAST_HORIZON, use_optimized=False, model_dir=DEFAULT_MODEL_DIR):
def train_product_model_with_kan(product_id, model_identifier, product_df=None, store_id=None, training_mode='product', aggregation_method='sum', epochs=50, sequence_length=LOOK_BACK, forecast_horizon=FORECAST_HORIZON, use_optimized=False, model_dir=DEFAULT_MODEL_DIR):
"""
使用KAN模型训练产品销售预测模型
@ -168,6 +168,7 @@ def train_product_model_with_kan(product_id, product_df=None, store_id=None, tra
train_losses = []
test_losses = []
start_time = time.time()
best_loss = float('inf')
for epoch in range(epochs):
model.train()
@ -225,6 +226,43 @@ def train_product_model_with_kan(product_id, product_df=None, store_id=None, tra
test_loss = test_loss / len(test_loader)
test_losses.append(test_loss)
# 检查是否为最佳模型
model_type_name = 'optimized_kan' if use_optimized else 'kan'
if test_loss < best_loss:
best_loss = test_loss
print(f"🎉 新的最佳模型发现在 epoch {epoch+1},测试损失: {test_loss:.4f}")
# 为保存最佳模型准备数据
best_model_data = {
'model_state_dict': model.state_dict(),
'scaler_X': scaler_X,
'scaler_y': scaler_y,
'config': {
'input_dim': input_dim,
'output_dim': output_dim,
'hidden_size': hidden_size,
'hidden_sizes': [hidden_size, hidden_size * 2, hidden_size],
'sequence_length': sequence_length,
'forecast_horizon': forecast_horizon,
'model_type': model_type_name,
'use_optimized': use_optimized
},
'epoch': epoch + 1
}
# 使用模型管理器保存 'best' 版本
from utils.model_manager import model_manager
model_manager.save_model(
model_data=best_model_data,
product_id=model_identifier,
model_type=model_type_name,
version='best',
store_id=store_id,
training_mode=training_mode,
aggregation_method=aggregation_method,
product_name=product_name
)
if (epoch + 1) % 10 == 0:
print(f"Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}")
@ -282,7 +320,7 @@ def train_product_model_with_kan(product_id, product_df=None, store_id=None, tra
'input_dim': input_dim,
'output_dim': output_dim,
'hidden_size': hidden_size,
'hidden_sizes': [hidden_size, hidden_size*2, hidden_size],
'hidden_sizes': [hidden_size, hidden_size * 2, hidden_size],
'sequence_length': sequence_length,
'forecast_horizon': forecast_horizon,
'model_type': model_type_name,
@ -299,9 +337,9 @@ def train_product_model_with_kan(product_id, product_df=None, store_id=None, tra
model_path = model_manager.save_model(
model_data=model_data,
product_id=product_id,
product_id=model_identifier,
model_type=model_type_name,
version='v1', # KAN训练器默认使用v1
version=f'final_epoch_{epochs}',
store_id=store_id,
training_mode=training_mode,
aggregation_method=aggregation_method,

View File

@ -25,8 +25,8 @@ from core.config import (
)
from utils.training_progress import progress_manager
def save_checkpoint(checkpoint_data: dict, epoch_or_label, product_id: str,
model_type: str, model_dir: str, store_id=None,
def save_checkpoint(checkpoint_data: dict, epoch_or_label, model_identifier: str,
model_type: str, model_dir: str, store_id=None,
training_mode: str = 'product', aggregation_method=None):
"""
保存训练检查点
@ -47,7 +47,7 @@ def save_checkpoint(checkpoint_data: dict, epoch_or_label, product_id: str,
os.makedirs(checkpoint_dir, exist_ok=True)
# 修正直接使用product_id作为唯一标识符因为它已经包含了store_前缀或药品ID
filename = f"{model_type}_{product_id}_epoch_{epoch_or_label}.pth"
filename = f"{model_type}_{model_identifier}_epoch_{epoch_or_label}.pth"
checkpoint_path = os.path.join(checkpoint_dir, filename)
@ -102,6 +102,7 @@ def load_checkpoint(product_id: str, model_type: str, epoch_or_label,
def train_product_model_with_mlstm(
product_id,
model_identifier,
product_df,
store_id=None,
training_mode='product',
@ -430,10 +431,11 @@ def train_product_model_with_mlstm(
'output_dim': output_dim,
'hidden_size': hidden_size,
'num_heads': num_heads,
'dropout': dropout_rate,
'dropout_rate': dropout_rate,
'num_blocks': num_blocks,
'embed_dim': embed_dim,
'dense_dim': dense_dim,
'mlstm_layers': 2, # 确保这个参数被保存
'sequence_length': sequence_length,
'forecast_horizon': forecast_horizon,
'model_type': 'mlstm'
@ -450,13 +452,13 @@ def train_product_model_with_mlstm(
}
# 保存检查点
save_checkpoint(checkpoint_data, epoch + 1, product_id, 'mlstm',
save_checkpoint(checkpoint_data, epoch + 1, model_identifier, 'mlstm',
model_dir, store_id, training_mode, aggregation_method)
# 如果是最佳模型,额外保存一份
if test_loss < best_loss:
best_loss = test_loss
save_checkpoint(checkpoint_data, 'best', product_id, 'mlstm',
save_checkpoint(checkpoint_data, 'best', model_identifier, 'mlstm',
model_dir, store_id, training_mode, aggregation_method)
emit_progress(f"💾 保存最佳模型检查点 (epoch {epoch+1}, test_loss: {test_loss:.4f})")
epochs_no_improve = 0
@ -551,10 +553,11 @@ def train_product_model_with_mlstm(
'output_dim': output_dim,
'hidden_size': hidden_size,
'num_heads': num_heads,
'dropout': dropout_rate,
'dropout_rate': dropout_rate,
'num_blocks': num_blocks,
'embed_dim': embed_dim,
'dense_dim': dense_dim,
'mlstm_layers': 2, # 确保这个参数被保存
'sequence_length': sequence_length,
'forecast_horizon': forecast_horizon,
'model_type': 'mlstm'
@ -575,7 +578,7 @@ def train_product_model_with_mlstm(
# 保存最终模型使用epoch标识
final_model_path = save_checkpoint(
final_model_data, f"final_epoch_{epochs}", product_id, 'mlstm',
final_model_data, f"final_epoch_{epochs}", model_identifier, 'mlstm',
model_dir, store_id, training_mode, aggregation_method
)

View File

@ -21,8 +21,8 @@ from analysis.metrics import evaluate_model
from core.config import DEVICE, DEFAULT_MODEL_DIR, LOOK_BACK, FORECAST_HORIZON
from utils.training_progress import progress_manager
def save_checkpoint(checkpoint_data: dict, epoch_or_label, product_id: str,
model_type: str, model_dir: str, store_id=None,
def save_checkpoint(checkpoint_data: dict, epoch_or_label, model_identifier: str,
model_type: str, model_dir: str, store_id=None,
training_mode: str = 'product', aggregation_method=None):
"""
保存训练检查点
@ -44,7 +44,7 @@ def save_checkpoint(checkpoint_data: dict, epoch_or_label, product_id: str,
# 生成检查点文件名
# 修正直接使用product_id作为唯一标识符因为它已经包含了store_前缀或药品ID
filename = f"{model_type}_{product_id}_epoch_{epoch_or_label}.pth"
filename = f"{model_type}_{model_identifier}_epoch_{epoch_or_label}.pth"
checkpoint_path = os.path.join(checkpoint_dir, filename)
@ -56,6 +56,7 @@ def save_checkpoint(checkpoint_data: dict, epoch_or_label, product_id: str,
def train_product_model_with_tcn(
product_id,
model_identifier,
product_df=None,
store_id=None,
training_mode='product',
@ -381,6 +382,7 @@ def train_product_model_with_tcn(
'output_dim': output_dim,
'hidden_size': hidden_size,
'num_layers': num_layers,
'num_channels': [hidden_size] * num_layers,
'dropout': dropout_rate,
'kernel_size': kernel_size,
'sequence_length': sequence_length,
@ -398,13 +400,13 @@ def train_product_model_with_tcn(
}
# 保存检查点
save_checkpoint(checkpoint_data, epoch + 1, product_id, 'tcn',
save_checkpoint(checkpoint_data, epoch + 1, model_identifier, 'tcn',
model_dir, store_id, training_mode, aggregation_method)
# 如果是最佳模型,额外保存一份
if test_loss < best_loss:
best_loss = test_loss
save_checkpoint(checkpoint_data, 'best', product_id, 'tcn',
save_checkpoint(checkpoint_data, 'best', model_identifier, 'tcn',
model_dir, store_id, training_mode, aggregation_method)
emit_progress(f"💾 保存最佳模型检查点 (epoch {epoch+1}, test_loss: {test_loss:.4f})")
@ -471,6 +473,7 @@ def train_product_model_with_tcn(
'output_dim': output_dim,
'hidden_size': hidden_size,
'num_layers': num_layers,
'num_channels': [hidden_size] * num_layers,
'dropout': dropout_rate,
'kernel_size': kernel_size,
'sequence_length': sequence_length,
@ -494,7 +497,7 @@ def train_product_model_with_tcn(
# 保存最终模型使用epoch标识
final_model_path = save_checkpoint(
final_model_data, f"final_epoch_{epochs}", product_id, 'tcn',
final_model_data, f"final_epoch_{epochs}", model_identifier, 'tcn',
model_dir, store_id, training_mode, aggregation_method
)

View File

@ -27,8 +27,8 @@ from core.config import (
from utils.training_progress import progress_manager
from utils.model_manager import model_manager
def save_checkpoint(checkpoint_data: dict, epoch_or_label, product_id: str,
model_type: str, model_dir: str, store_id=None,
def save_checkpoint(checkpoint_data: dict, epoch_or_label, model_identifier: str,
model_type: str, model_dir: str, store_id=None,
training_mode: str = 'product', aggregation_method=None):
"""
保存训练检查点
@ -48,7 +48,7 @@ def save_checkpoint(checkpoint_data: dict, epoch_or_label, product_id: str,
os.makedirs(checkpoint_dir, exist_ok=True)
# 修正直接使用product_id作为唯一标识符因为它已经包含了store_前缀或药品ID
filename = f"{model_type}_{product_id}_epoch_{epoch_or_label}.pth"
filename = f"{model_type}_{model_identifier}_epoch_{epoch_or_label}.pth"
checkpoint_path = os.path.join(checkpoint_dir, filename)
@ -60,6 +60,7 @@ def save_checkpoint(checkpoint_data: dict, epoch_or_label, product_id: str,
def train_product_model_with_transformer(
product_id,
model_identifier,
product_df=None,
store_id=None,
training_mode='product',
@ -399,13 +400,13 @@ def train_product_model_with_transformer(
}
# 保存检查点
save_checkpoint(checkpoint_data, epoch + 1, product_id, 'transformer',
save_checkpoint(checkpoint_data, epoch + 1, model_identifier, 'transformer',
model_dir, store_id, training_mode, aggregation_method)
# 如果是最佳模型,额外保存一份
if test_loss < best_loss:
best_loss = test_loss
save_checkpoint(checkpoint_data, 'best', product_id, 'transformer',
save_checkpoint(checkpoint_data, 'best', model_identifier, 'transformer',
model_dir, store_id, training_mode, aggregation_method)
emit_progress(f"💾 保存最佳模型检查点 (epoch {epoch+1}, test_loss: {test_loss:.4f})")
epochs_no_improve = 0
@ -501,7 +502,7 @@ def train_product_model_with_transformer(
# 保存最终模型使用epoch标识
final_model_path = save_checkpoint(
final_model_data, f"final_epoch_{epochs}", product_id, 'transformer',
final_model_data, f"final_epoch_{epochs}", model_identifier, 'transformer',
model_dir, store_id, training_mode, aggregation_method
)

View File

@ -0,0 +1,464 @@
# 系统调用逻辑与核心代码分析
本文档旨在详细阐述本销售预测系统的端到端调用链路,从系统启动、前端交互、后端处理,到最终的模型训练、预测和图表展示。
## 1. 系统启动
系统由两部分组成Vue.js前端和Flask后端。
### 1.1. 启动后端API服务
在项目根目录下,通过以下命令启动后端服务:
```bash
python server/api.py
```
该命令会启动一个Flask应用监听在 `http://localhost:5000`并提供所有API和WebSocket服务。
### 1.2. 启动前端开发服务器
进入 `UI` 目录,执行以下命令:
```bash
cd UI
npm install
npm run dev
```
这将启动Vite开发服务器通常在 `http://localhost:5173`,并自动打开浏览器访问前端页面。
## 2. 核心调用链路概览
以最核心的 **“按药品训练 -> 按药品预测”** 流程为例,其高层调用链路如下:
**训练流程:**
`前端UI` -> `POST /api/training` -> `api.py: start_training()` -> `TrainingManager` -> `后台进程` -> `predictor.py: train_model()` -> `[model]_trainer.py: train_product_model_with_*()` -> `保存模型.pth`
**预测流程:**
`前端UI` -> `POST /api/prediction` -> `api.py: predict()` -> `predictor.py: predict()` -> `model_predictor.py: load_model_and_predict()` -> `加载模型.pth` -> `返回预测JSON` -> `前端图表渲染`
## 3. 详细流程:按药品训练
此流程的目标是为特定药品训练一个专用的预测模型。
### 3.1. 前端交互与API请求
1. **用户操作**: 用户在 **“按药品训练”** 页面 ([`UI/src/views/training/ProductTrainingView.vue`](UI/src/views/training/ProductTrainingView.vue:1)) 选择一个药品、一个模型类型如Transformer、设置训练轮次Epochs然后点击 **“启动药品训练”** 按钮。
2. **触发函数**: 点击事件调用 [`startTraining`](UI/src/views/training/ProductTrainingView.vue:521) 方法。
3. **构建Payload**: `startTraining` 方法构建一个包含训练参数的 `payload` 对象。关键字段是 `training_mode: 'product'`,用于告知后端这是针对特定产品的训练。
*核心代码 ([`UI/src/views/training/ProductTrainingView.vue`](UI/src/views/training/ProductTrainingView.vue:521))*
```javascript
const startTraining = async () => {
// ... 表单验证 ...
trainingLoading.value = true;
try {
const endpoint = "/api/training";
const payload = {
product_id: form.product_id,
store_id: form.data_scope === 'global' ? null : form.store_id,
model_type: form.model_type,
epochs: form.epochs,
training_mode: 'product' // 标识这是药品训练模式
};
const response = await axios.post(endpoint, payload);
// ... 处理响应启动WebSocket监听 ...
}
// ... 错误处理 ...
};
```
4. **API请求**: 使用 `axios` 向后端 `POST /api/training` 发送请求。
### 3.2. 后端API接收与任务分发
1. **路由处理**: 后端 [`server/api.py`](server/api.py:1) 中的 [`@app.route('/api/training', methods=['POST'])`](server/api.py:933) 装饰器捕获该请求,并由 [`start_training()`](server/api.py:971) 函数处理。
2. **任务提交**: `start_training()` 函数解析请求中的JSON数据然后调用 `training_manager.submit_task()` 将训练任务提交到一个后台进程池中执行以避免阻塞API主线程。这使得API可以立即返回一个任务ID而训练在后台异步进行。
*核心代码 ([`server/api.py`](server/api.py:971))*
```python
@app.route('/api/training', methods=['POST'])
def start_training():
data = request.get_json()
training_mode = data.get('training_mode', 'product')
model_type = data.get('model_type')
epochs = data.get('epochs', 50)
product_id = data.get('product_id')
store_id = data.get('store_id')
if not model_type or (training_mode == 'product' and not product_id):
return jsonify({'error': '缺少必要参数'}), 400
try:
# 使用训练进程管理器提交任务
task_id = training_manager.submit_task(
product_id=product_id or "unknown",
model_type=model_type,
training_mode=training_mode,
store_id=store_id,
epochs=epochs
)
logger.info(f"🚀 训练任务已提交到进程管理器: {task_id[:8]}")
return jsonify({
'message': '模型训练已开始(使用独立进程)',
'task_id': task_id,
})
except Exception as e:
logger.error(f"❌ 提交训练任务失败: {str(e)}")
return jsonify({'error': f'启动训练任务失败: {str(e)}'}), 500
```
### 3.3. 核心训练逻辑
1. **调用核心预测器**: 后台进程最终会调用 [`server/core/predictor.py`](server/core/predictor.py:1) 中的 [`PharmacyPredictor.train_model()`](server/core/predictor.py:63) 方法。
2. **数据准备**: `train_model` 方法首先根据 `training_mode` (`'product'`) 和 `product_id` 从数据源加载并聚合所有店铺关于该药品的销售数据。
3. **分发到具体训练器**: 接着,它根据 `model_type` 调用相应的训练函数。例如,如果 `model_type``transformer`,它会调用 `train_product_model_with_transformer`
*核心代码 ([`server/core/predictor.py`](server/core/predictor.py:63))*
```python
class PharmacyPredictor:
def train_model(self, product_id, model_type='transformer', ..., training_mode='product', ...):
# ...
if training_mode == 'product':
product_data = self.data[self.data['product_id'] == product_id].copy()
# ...
# 根据训练模式构建模型标识符
model_identifier = product_id
try:
if model_type == 'transformer':
model_result, metrics, actual_version = train_product_model_with_transformer(
product_id=product_id,
model_identifier=model_identifier,
product_df=product_data,
# ... 其他参数 ...
)
# ... 其他模型的elif分支 ...
return metrics
except Exception as e:
# ... 错误处理 ...
return None
```
### 3.4. 模型训练与保存
1. **具体训练器**: 以 [`server/trainers/transformer_trainer.py`](server/trainers/transformer_trainer.py:1) 为例,`train_product_model_with_transformer` 函数执行以下步骤:
* **数据预处理**: 调用 `prepare_data``prepare_sequences` 将原始销售数据转换为模型可以理解的、带有时间序列特征的监督学习格式(输入序列和目标序列)。
* **模型实例化**: 创建 `TimeSeriesTransformer` 模型实例。
* **训练循环**: 执行指定的 `epochs` 次训练,计算损失并使用优化器更新模型权重。
* **进度更新**: 在训练过程中,通过 `socketio.emit` 向前端发送 `training_progress` 事件,实时更新进度条和日志。
* **模型保存**: 训练完成后,将模型权重 (`model.state_dict()`)、完整的模型配置 (`config`) 以及数据缩放器 (`scaler_X`, `scaler_y`) 打包成一个字典checkpoint并使用 `torch.save()` 保存到 `.pth` 文件中。文件名由 `get_model_file_path` 根据 `model_identifier``model_type``version` 统一生成。
*核心代码 ([`server/trainers/transformer_trainer.py`](server/trainers/transformer_trainer.py:33))*
```python
def train_product_model_with_transformer(...):
# ... 数据准备 ...
# 定义模型配置
config = {
'input_dim': input_dim,
'output_dim': forecast_horizon,
'hidden_size': hidden_size,
# ... 所有必要的超参数 ...
'model_type': 'transformer'
}
model = TimeSeriesTransformer(...)
# ... 训练循环 ...
# 保存模型
checkpoint = {
'model_state_dict': model.state_dict(),
'config': config,
'scaler_X': scaler_X,
'scaler_y': scaler_y,
'metrics': test_metrics
}
model_path = get_model_file_path(model_identifier, 'transformer', version)
torch.save(checkpoint, model_path)
return model, test_metrics, version
```
## 4. 详细流程:按药品预测
训练完成后,用户可以使用已保存的模型进行预测。
### 4.1. 前端交互与API请求
1. **用户操作**: 用户在 **“按药品预测”** 页面 ([`UI/src/views/prediction/ProductPredictionView.vue`](UI/src/views/prediction/ProductPredictionView.vue:1)) 选择同一个药品、对应的模型和版本,然后点击 **“开始预测”**。
2. **触发函数**: 点击事件调用 [`startPrediction`](UI/src/views/prediction/ProductPredictionView.vue:202) 方法。
3. **构建Payload**: 该方法构建一个包含预测参数的 `payload`
*核心代码 ([`UI/src/views/prediction/ProductPredictionView.vue`](UI/src/views/prediction/ProductPredictionView.vue:202))*
```javascript
const startPrediction = async () => {
try {
predicting.value = true
const payload = {
product_id: form.product_id,
model_type: form.model_type,
version: form.version,
future_days: form.future_days,
// training_mode is implicitly 'product' here
}
const response = await axios.post('/api/prediction', payload)
if (response.data.status === 'success') {
predictionResult.value = response.data
await nextTick()
renderChart()
}
// ... 错误处理 ...
}
// ...
}
```
4. **API请求**: 使用 `axios` 向后端 `POST /api/prediction` 发送请求。
### 4.2. 后端API接收与预测执行
1. **路由处理**: [`server/api.py`](server/api.py:1) 中的 [`@app.route('/api/prediction', methods=['POST'])`](server/api.py:1413) 捕获请求,由 [`predict()`](server/api.py:1469) 函数处理。
2. **调用核心预测器**: `predict()` 函数解析参数,然后调用 `run_prediction` 辅助函数,该函数内部再调用 [`server/core/predictor.py`](server/core/predictor.py:1) 中的 [`PharmacyPredictor.predict()`](server/core/predictor.py:295) 方法。
*核心代码 ([`server/api.py`](server/api.py:1469))*
```python
@app.route('/api/prediction', methods=['POST'])
def predict():
try:
data = request.json
# ... 解析参数 ...
training_mode = data.get('training_mode', 'product')
product_id = data.get('product_id')
# ...
# 根据模式确定模型标识符
if training_mode == 'product':
model_identifier = product_id
# ...
# 执行预测
prediction_result = run_prediction(model_type, product_id, model_id, ...)
# ... 格式化响应 ...
return jsonify(response_data)
except Exception as e:
# ... 错误处理 ...
```
3. **分发到模型加载器**: [`PharmacyPredictor.predict()`](server/core/predictor.py:295) 方法的主要作用是再次根据 `training_mode``product_id` 确定 `model_identifier`,然后将所有参数传递给 [`server/predictors/model_predictor.py`](server/predictors/model_predictor.py:1) 中的 [`load_model_and_predict()`](server/predictors/model_predictor.py:26) 函数。
*核心代码 ([`server/core/predictor.py`](server/core/predictor.py:295))*
```python
class PharmacyPredictor:
def predict(self, product_id, model_type, ..., training_mode='product', ...):
if training_mode == 'product':
model_identifier = product_id
# ...
return load_model_and_predict(
model_identifier,
model_type,
# ... 其他参数 ...
)
```
### 4.3. 模型加载与执行预测
[`load_model_and_predict()`](server/predictors/model_predictor.py:26) 是预测流程的核心,它执行以下步骤:
1. **定位模型文件**: 使用 `get_model_file_path` 根据 `product_id` (即 `model_identifier`), `model_type`, 和 `version` 找到之前保存的 `.pth` 模型文件。
2. **加载Checkpoint**: 使用 `torch.load()` 加载模型文件,得到包含 `model_state_dict`, `config`, 和 `scalers` 的字典。
3. **重建模型**: 根据加载的 `config` 中的超参数(如 `hidden_size`, `num_layers` 等),重新创建一个与训练时结构完全相同的模型实例。**这是我们之前修复的关键点,确保所有必要参数都被保存和加载。**
4. **加载权重**: 将加载的 `model_state_dict` 应用到新创建的模型实例上。
5. **准备输入数据**: 从数据源获取最新的 `sequence_length` 天的历史数据作为预测的输入。
6. **数据归一化**: 使用加载的 `scaler_X` 对输入数据进行归一化。
7. **执行预测**: 将归一化的数据输入模型 (`model(X_input)`),得到预测结果。
8. **反归一化**: 使用加载的 `scaler_y` 将模型的输出(预测值)反归一化,转换回原始的销售量尺度。
9. **构建结果**: 将预测值和对应的未来日期组合成一个DataFrame并连同历史数据一起返回。
*核心代码 ([`server/predictors/model_predictor.py`](server/predictors/model_predictor.py:26))*
```python
def load_model_and_predict(...):
# ... 找到模型文件路径 model_path ...
# 加载模型和配置
checkpoint = torch.load(model_path, map_location=DEVICE, weights_only=False)
config = checkpoint['config']
scaler_X = checkpoint['scaler_X']
scaler_y = checkpoint['scaler_y']
# 创建模型实例 (以Transformer为例)
model = TimeSeriesTransformer(
num_features=config['input_dim'],
d_model=config['hidden_size'],
# ... 使用config中的所有参数 ...
).to(DEVICE)
# 加载模型参数
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()
# ... 准备输入数据 ...
# 归一化输入数据
X_scaled = scaler_X.transform(X)
X_input = torch.tensor(X_scaled.reshape(1, sequence_length, -1), ...).to(DEVICE)
# 预测
with torch.no_grad():
y_pred_scaled = model(X_input).cpu().numpy()
# 反归一化
y_pred = scaler_y.inverse_transform(y_pred_scaled.reshape(-1, 1)).flatten()
# ... 构建返回结果 ...
return {
'predictions': predictions_df,
'history_data': recent_history,
# ...
}
```
### 4.4. 响应格式化与前端图表渲染
1. **API层格式化**: 在 [`server/api.py`](server/api.py:1) 的 [`predict()`](server/api.py:1469) 函数中,从 `load_model_and_predict` 返回的结果被精心格式化成前端期望的JSON结构该结构在顶层同时包含 `history_data``prediction_data` 两个数组。
2. **前端接收数据**: 前端 [`ProductPredictionView.vue`](UI/src/views/prediction/ProductPredictionView.vue:1) 在 `startPrediction` 方法中接收到这个JSON响应并将其存入 `predictionResult` ref。
3. **图表渲染**: [`renderChart()`](UI/src/views/prediction/ProductPredictionView.vue:232) 方法被调用。它从 `predictionResult.value` 中提取 `history_data``prediction_data`然后使用Chart.js库将这两部分数据绘制在同一个 `<canvas>` 上,历史数据为实线,预测数据为虚线,从而形成一个连续的趋势图。
*核心代码 ([`UI/src/views/prediction/ProductPredictionView.vue`](UI/src/views/prediction/ProductPredictionView.vue:232))*
```javascript
const renderChart = () => {
if (!chartCanvas.value || !predictionResult.value) return
// ...
// 后端直接提供 history_data 和 prediction_data
const historyData = predictionResult.value.history_data || []
const predictionData = predictionResult.value.prediction_data || []
const historyLabels = historyData.map(p => p.date)
const historySales = historyData.map(p => p.sales)
const predictionLabels = predictionData.map(p => p.date)
const predictionSales = predictionData.map(p => p.predicted_sales)
// ... 组合标签和数据,对齐数据点 ...
chart = new Chart(chartCanvas.value, {
type: 'line',
data: {
labels: allLabels,
datasets: [
{
label: '历史销量',
data: alignedHistorySales,
// ... 样式 ...
},
{
label: '预测销量',
data: alignedPredictionSales,
// ... 样式 ...
}
]
},
// ... Chart.js 配置 ...
})
}
```
至此,一个完整的“训练->预测->展示”的调用链路就完成了。
## 5. 模型保存规则与路径
为了确保模型的唯一性、可追溯性和可复现性,系统采用了一套严格的文件保存和命名规则。所有相关的逻辑都集中在 [`server/core/config.py`](server/core/config.py:1) 中。
### 5.1. 统一保存目录
所有训练产物包括模型权重、配置和数据缩放器Scalers都保存在项目根目录下的 `saved_models/` 文件夹中。
- **路径**: `PROJECT_ROOT/saved_models/`
- **定义**: 该路径由 [`server/core/config.py`](server/core/config.py:1) 中的 `DEFAULT_MODEL_DIR` 变量指定。
### 5.2. 文件命名规范
模型文件的命名遵循一个标准化的格式,以便在预测时能够被精确地定位和加载。该命名逻辑由 [`get_model_file_path()`](server/core/config.py:136) 函数统一管理。
**命名格式**: `{model_type}_{model_identifier}_epoch_{version}.pth`
**各部分说明**:
- `{model_type}`: 模型的算法类型。例如:`transformer`, `mlstm`, `tcn`, `kan`
- `{model_identifier}`: 模型的唯一业务标识符,它根据训练模式(`training_mode`)动态生成:
- **按药品训练 (`product`)**: 标识符就是 `product_id`
- *示例*: `transformer_17002608_epoch_best.pth`
- **按店铺训练 (`store`)**: 标识符是 `store_{store_id}`
- *示例*: `tcn_store_01010023_epoch_best.pth`
- **全局训练 (`global`)**: 标识符是固定的字符串 `'global'`
- *示例*: `mlstm_global_epoch_best.pth`
- `{version}`: 模型的版本。在训练过程中,通常会保存两个版本:
- `best`: 在验证集上表现最佳的模型。
- `{epoch_number}`: 训练完成时的最终模型,例如 `50`
前端的“版本”下拉框中显示的就是这些版本字符串。
### 5.3. Checkpoint文件内容
每个 `.pth` 文件都是一个PyTorch Checkpoint它是一个Python字典包含了重建和使用模型所需的所有信息。这是确保预测与训练环境一致的关键。
**Checkpoint结构**:
```python
checkpoint = {
# 1. 模型权重
'model_state_dict': model.state_dict(),
# 2. 完整的模型配置
'config': {
'input_dim': ...,
'hidden_size': ...,
'num_layers': ...,
'model_type': 'transformer',
# ... 其他所有重建模型所需的超参数 ...
},
# 3. 数据归一化缩放器
'scaler_X': scaler_X, # 用于输入特征
'scaler_y': scaler_y, # 用于目标值(销量)
# 4. (可选) 模型性能指标
'metrics': {'mse': 0.01, 'mae': 0.05, ...}
}
```
**核心优势**:
- **可复现性**: 通过保存完整的 `config`我们可以在预测时精确地重建出与训练时结构完全相同的模型实例避免了因模型结构不匹配导致的加载失败这是之前修复的一个核心BUG
- **数据一致性**: 保存 `scaler_X``scaler_y` 确保了在预测时使用与训练时完全相同的归一化/反归一化逻辑,保证了预测结果的正确性。

127
项目快速上手指南.md Normal file
View File

@ -0,0 +1,127 @@
# 项目快速上手指南 (面向新开发者)
欢迎加入项目本指南旨在帮助你快速理解项目的核心功能、技术架构和开发流程特别是为你一位Java背景的开发者提供清晰的切入点。
## 1. 项目是做什么的?(实现了什么功能)
这是一个基于历史销售数据的 **智能销售预测系统**
核心功能有三个全部通过Web界面操作
1. **模型训练**: 用户可以选择某个**药品**、某个**店铺**或**全局**数据然后选择一种机器学习算法如Transformer、mLSTM等进行训练最终生成一个预测模型。
2. **销售预测**: 使用已经训练好的模型,对未来的销量进行预测。
3. **结果可视化**: 将历史销量和预测销量在同一个图表中展示出来,方便用户直观地看到趋势。
简单来说,它就是一个 **"数据 -> 训练 -> 模型 -> 预测 -> 可视化"** 的完整闭环应用。
## 2. 用了什么技术?(技术栈)
你可以将这个项目的技术栈与Java世界进行类比
| 层面 | 本项目技术 | Java世界类比 | 说明 |
| :--- | :--- | :--- | :--- |
| **后端框架** | **Flask** | Spring Boot | 一个轻量级的Web框架用于提供API接口。 |
| **前端框架** | **Vue.js** | React / Angular | 用于构建用户交互界面的现代化JavaScript框架。 |
| **核心算法库** | **PyTorch** | (无直接对应) | 类似于Java的Deeplearning4j是实现深度学习算法的核心。 |
| **数据处理** | **Pandas** | (无直接对应) | Python中用于数据分析和处理的“瑞士军刀”可以看作是内存中的强大数据表格。 |
| **构建/打包** | **Vite** (前端) | Maven / Gradle | 前端项目的构建和依赖管理工具。 |
| **数据库** | **SQLite** | H2 / MySQL | 一个轻量级的本地文件数据库,用于记录预测历史等。 |
| **实时通信** | **Socket.IO** | WebSocket / STOMP | 用于后端在训练时向前端实时推送进度。 |
## 3. 系统架构是怎样的?(架构层级和设计)
本项目是经典的前后端分离架构,可以分为四个主要层次:
```
+------------------------------------------------------+
| 用户 (Browser) |
+------------------------------------------------------+
|
+------------------------------------------------------+
| 1. 前端层 (Frontend - Vue.js) |
| - Views (页面组件, e.g., ProductPredictionView.vue) |
| - API Calls (使用axios与后端通信) |
| - Charting (使用Chart.js进行图表渲染) |
+------------------------------------------------------+
| (HTTP/S, WebSocket)
+------------------------------------------------------+
| 2. 后端API层 (Backend API - Flask) |
| - api.py (类似Controller, 定义RESTful接口) |
| - 接收请求, 验证参数, 调用业务逻辑层 |
+------------------------------------------------------+
|
+------------------------------------------------------+
| 3. 业务逻辑层 (Business Logic - Python) |
| - core/predictor.py (类似Service层) |
| - 封装核心业务, 如“根据参数选择合适的训练器” |
+------------------------------------------------------+
|
+------------------------------------------------------+
| 4. 数据与模型层 (Data & Model - PyTorch/Pandas) |
| - trainers/*.py (具体的算法实现和训练逻辑) |
| - predictors/model_predictor.py (模型加载与预测逻辑) |
| - saved_models/ (存放训练好的.pth模型文件) |
| - data/ (存放原始数据.parquet文件) |
+------------------------------------------------------+
```
## 4. 关键执行流程
以最常见的“按药品预测”为例:
1. **前端**: 用户在页面上选择药品和模型点击“预测”按钮。Vue组件通过`axios`向后端发送一个POST请求到 `/api/prediction`
2. **API层**: `api.py` 接收到请求像一个Controller一样解析出药品ID、模型类型等参数。
3. **业务逻辑层**: `api.py` 调用 `core/predictor.py` 中的 `predict` 方法,将参数传递下去。这一层是业务的“调度中心”。
4. **模型层**: `core/predictor.py` 最终调用 `predictors/model_predictor.py` 中的 `load_model_and_predict` 函数。
5. **模型加载与执行**:
* 根据参数在 `saved_models/` 目录下找到对应的模型文件(例如 `transformer_17002608_epoch_best.pth`)。
* 加载文件,从中恢复出 **模型结构**、**模型权重** 和 **数据缩放器**
* 准备最新的历史数据作为输入,执行预测。
* 将预测结果返回。
6. **返回与渲染**: 结果逐层返回到`api.py`在这里被格式化为JSON然后发送给前端。前端接收到JSON后使用`Chart.js`将历史和预测数据画在图表上。
## 5. 如何添加一个新的算法?(开发者指南)
这是你最可能接触到的新功能开发。假设你要添加一个名为 `NewNet` 的新算法,你需要按以下步骤操作:
**目标**: 让 `NewNet` 出现在前端的“模型类型”下拉框中,并能成功训练和预测。
1. **创建训练器文件**:
* 在 `server/trainers/` 目录下,复制一份现有的训练器文件(例如 `tcn_trainer.py`)并重命名为 `newnet_trainer.py`
* 在 `newnet_trainer.py` 中:
* 定义你的 `NewNet` 模型类(继承自 `torch.nn.Module`)。
* 修改 `train_..._with_tcn` 函数,将其重命名为 `train_..._with_newnet`
* 在这个新函数里,确保实例化的是你的 `NewNet` 模型。
* **最关键的一步**: 在保存checkpoint时确保 `config` 字典里包含了重建 `NewNet` 所需的所有超参数(比如层数、节点数等)。
* **重要开发规范:参数命名规则**
为了防止在模型加载时出现参数不匹配的错误(例如 `KeyError: 'num_layers'`),我们制定了以下命名规范:
> **规则:** 对于特定于某个算法的超参数,其在 `config` 字典中的键名key必须以该算法的名称作为前缀或唯一标识。
**示例:**
* 对于 `mLSTM` 模型的层数,键名应为 `mlstm_layers`
* 对于 `TCN` 模型的通道数,键名可以是 `tcn_channels`
* 对于 `Transformer` 模型的编码器层数,键名可以是 `num_encoder_layers` 因为这在Transformer语境下是明确的
**加载模型时** ([`server/predictors/model_predictor.py`](server/predictors/model_predictor.py:1)),必须使用与保存时完全一致的键名来读取这些参数。遵循此规则可以从根本上杜绝因参数名不一致导致的模型加载失败问题。
2. **注册新模型**:
* 打开 `server/core/config.py` 文件。
* 找到 `SUPPORTED_MODELS` 列表。
* 在列表中添加你的新模型标识符 `'newnet'`
3. **接入业务逻辑层 (训练)**:
* 打开 `server/core/predictor.py` 文件。
* 在 `train_model` 方法中,找到 `if/elif` 模型选择逻辑。
* 添加一个新的 `elif model_type == 'newnet':` 分支,让它调用你在第一步中创建的 `train_..._with_newnet` 函数。
4. **接入模型层 (预测)**:
* 打开 `server/predictors/model_predictor.py` 文件。
* 在 `load_model_and_predict` 函数中,找到 `if/elif` 模型实例化逻辑。
* 添加一个新的 `elif model_type == 'newnet':` 分支,确保它能根据 `config` 正确地创建 `NewNet` 模型实例。
5. **更新前端界面**:
* 打开 `UI/src/views/training/``UI/src/views/prediction/` 目录下的相关Vue文件`ProductTrainingView.vue`)。
* 找到定义模型选项的地方(通常是一个数组或对象)。
* 添加 `{ label: '新网络模型 (NewNet)', value: 'newnet' }` 这样的新选项。
完成以上步骤后,重启服务,你就可以在界面上选择并使用你的新算法了。