ShopTRAINING/server/trainers/kan_trainer.py
xz2000 18f505a090 # 修改记录日志 (日期: 2025-07-16) ---未改全
## 1. 训练流程与模型保存逻辑修复 (重大)

- **背景**: 用户报告在“按店铺”和“按药品”模式下,如果选择了特定的子集(如为某个店铺选择特定药品),生成的模型范围 (`scope`) 不正确,始终为 `_all`。此外,所有模型都被错误地保存到 `global` 目录下,且在某些模式下训练会失败。
- **根本原因**:
    1.  `server/core/predictor.py` 中负责准备训练参数的内部函数 (`_prepare_product_params`, `_prepare_store_params`) 逻辑有误,未能正确处理传入的 `product_ids` 和 `store_ids` 列表来构建详细的 `scope`。
    2.  各个训练器 (`server/trainers/*.py`) 内部的日志记录和元数据生成逻辑不统一,且过于依赖 `product_id`,导致在全局或店铺模式下信息展示不清晰。

- **修复方案**:
    - **`server/core/predictor.py`**:
        - **重构 `_prepare_product_params` 和 `_prepare_store_params`**: 修改了这两个函数,使其能够正确使用 `product_ids` 和 `store_ids` 列表。现在,当选择特定范围时,会生成更具描述性的 `scope`,例如 `S001_specific_P001_P002`。
        - **结果**: 确保了传递给模型管理器的 `scope` 是准确且详细的,从而使模型能够根据训练范围被保存到正确的、独立的文件夹中。

    - **`server/trainers/*.py` (mlstm, kan, tcn, transformer)**:
        - **标准化日志与元数据**: 对所有四个训练器文件进行了统一修改。引入了一个通用的 `training_description` 变量,该变量整合了 `training_mode`、`scope` 和 `aggregation_method`。
        - **更新输出**: 修改了所有训练器中的日志消息、图表标题和 `metadata.json` 的生成逻辑,使其全部使用这个标准的 `training_description`。
        - **结果**: 确保了无论在哪种训练模式下,前端收到的日志、保存的图表和元数据都具有一致、清晰的格式,便于调试和结果追溯。

- **总体影响**: 此次修复从根本上解决了模型训练范围处理和模型保存路径的错误问题,使整个训练系统在所有模式下都能可靠、一致地运行。

---

## 2. 核心 Bug 修复

### 文件: `server/core/predictor.py`

- **问题**: 在 `train_model` 方法中调用内部辅助函数 `_prepare_training_params` 时,没有正确传递 `product_ids` 和 `store_ids` 参数,导致在 `_prepare_training_params` 内部发生 `NameError`。
- **修复**:
    - 修正了 `train_model` 方法内部对 `_prepare_training_params` 的调用,确保 `product_ids` 和 `store_ids` 被显式传递。
    - 此前已修复 `train_model` 的函数签名,使其能正确接收 `store_ids`。
- **结果**: 彻底解决了训练流程中的参数传递问题,根除了由此引发的 `NameError`。

## 3. 代码清理与重构

### 文件: `server/api.py`

- **内容**: 移除了在 `start_training` API 端点中遗留的旧版、基于线程(`threading.Thread`)的训练逻辑。
- **原因**: 该代码块已被新的、基于多进程(`multiprocessing`)的 `TrainingProcessManager` 完全取代。旧代码中包含了大量用于调试的 `thread_safe_print` 日志,已无用处。
- **结果**: `start_training` 端点的逻辑变得更加清晰,只负责参数校验和向 `TrainingProcessManager` 提交任务。

### 文件: `server/utils/training_process_manager.py`

- **内容**: 在 `TrainingWorker` 的 `run_training_task` 方法中,移除了一个用于模拟训练进度的 `for` 循环。
- **原因**: 该循环包含 `time.sleep(1)`,仅用于在没有实际训练逻辑时模拟进度更新,现在实际的训练器会通过回调函数报告真实进度,因此该模拟代码不再需要。
- **结果**: `TrainingWorker` 现在直接调用实际的训练器,不再有模拟延迟,代码更贴近生产环境。
2025-07-16 16:51:38 +08:00

324 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
药店销售预测系统 - KAN模型训练函数 (已重构)
"""
import os
import time
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from sklearn.preprocessing import MinMaxScaler
import matplotlib.pyplot as plt
from tqdm import tqdm
from datetime import datetime
from models.kan_model import KANForecaster
from models.optimized_kan_forecaster import OptimizedKANForecaster
from utils.data_utils import create_dataset, PharmacyDataset
from analysis.metrics import evaluate_model
from core.config import DEVICE, LOOK_BACK, FORECAST_HORIZON
from utils.model_manager import model_manager
from typing import Any
def convert_numpy_types(obj: Any) -> Any:
"""
递归地将字典或列表中的Numpy数值类型转换为Python原生类型。
"""
if isinstance(obj, dict):
return {k: convert_numpy_types(v) for k, v in obj.items()}
elif isinstance(obj, list):
return [convert_numpy_types(elem) for elem in obj]
elif isinstance(obj, np.floating):
return float(obj)
elif isinstance(obj, np.integer):
return int(obj)
elif isinstance(obj, np.ndarray):
return obj.tolist()
return obj
def train_product_model_with_kan(
product_df,
product_id=None,
store_id=None,
training_mode='product',
aggregation_method='sum',
scope=None,
epochs=50,
use_optimized=False,
socketio=None,
task_id=None,
progress_callback=None,
patience=10,
learning_rate=0.001
):
"""
使用KAN模型训练产品销售预测模型 (已适配新的ModelManager)
"""
def emit_progress(message, progress=None, metrics=None):
"""发送训练进度到前端"""
progress_data = {
'task_id': task_id,
'message': f"[KAN] {message}",
'timestamp': time.time()
}
if progress is not None:
progress_data['progress'] = progress
if metrics is not None:
progress_data['metrics'] = metrics
if progress_callback:
try:
progress_callback(progress_data)
except Exception as e:
print(f"[KAN] 进度回调失败: {e}")
if socketio and task_id:
try:
socketio.emit('training_progress', progress_data, namespace='/training')
except Exception as e:
print(f"[KAN] WebSocket发送失败: {e}")
print(f"[KAN] {message}", flush=True)
emit_progress("开始KAN模型训练...")
# 1. 确定模型标识符和版本
model_type = 'optimized_kan' if use_optimized else 'kan'
# 直接使用从 predictor 传递过来的、已经构建好的 scope
if scope is None:
# 作为后备如果scope未提供则根据旧逻辑构建不推荐
if training_mode == 'store':
current_product_id = product_id if product_id and product_id not in ['unknown', 'all'] else 'all'
scope = f"{store_id}_{current_product_id}"
elif training_mode == 'product':
scope = f"{product_id}_{store_id or 'all'}"
elif training_mode == 'global':
scope = product_id if product_id else "all"
emit_progress(f"警告: Scope未由调用方提供已自动构建为 '{scope}'", 'warning')
model_identifier = model_manager.get_model_identifier(model_type, training_mode, scope, aggregation_method)
version = model_manager.get_next_version_number(model_identifier)
emit_progress(f"开始训练 {model_type} 模型 v{version}")
# 2. 获取模型版本路径
model_version_path = model_manager.get_model_version_path(
model_type=model_type,
training_mode=training_mode,
scope=scope,
version=version,
aggregation_method=aggregation_method
)
emit_progress(f"模型将保存到: {model_version_path}")
# 3. 数据加载和预处理
if product_df is None:
# 此处保留了原有的数据加载逻辑作为后备
from utils.multi_store_data_utils import load_multi_store_data, get_store_product_sales_data, aggregate_multi_store_data
try:
if training_mode == 'store' and store_id:
product_df = get_store_product_sales_data(store_id, product_id, 'pharmacy_sales_multi_store.csv')
elif training_mode == 'global':
product_df = aggregate_multi_store_data(product_id, aggregation_method=aggregation_method, file_path='pharmacy_sales_multi_store.csv')
else:
product_df = load_multi_store_data('pharmacy_sales_multi_store.csv', product_id=product_id)
except Exception as e:
emit_progress(f"多店铺数据加载失败: {e}, 尝试后备方案...")
df = pd.read_excel('pharmacy_sales.xlsx')
product_df = df[df['product_id'] == product_id].sort_values('date')
# 根据训练模式和参数动态生成更详细的描述
if training_mode == 'store':
training_scope = f"店铺 {store_id}"
if scope and 'specific' in scope:
training_scope += " (指定药品)"
else:
training_scope += " (所有药品)"
elif training_mode == 'global':
training_scope = f"全局聚合({aggregation_method})"
else: # product 模式
training_scope = f"药品 {product_id}"
if scope and 'specific' in scope:
training_scope += " (指定店铺)"
elif store_id:
training_scope += f" (店铺 {store_id})"
else:
training_scope += " (所有店铺)"
min_required_samples = LOOK_BACK + FORECAST_HORIZON
if len(product_df) < min_required_samples:
error_msg = f"数据不足: 需要 {min_required_samples} 天, 实际 {len(product_df)} 天。"
emit_progress(f"训练失败:{error_msg}")
raise ValueError(error_msg)
product_df = product_df.sort_values('date')
if product_id:
product_name = product_df['product_name'].iloc[0]
else:
product_name = f"Aggregated Model ({training_mode}/{scope})"
print_product_id = product_id if product_id else "N/A"
emit_progress(f"训练产品: '{product_name}' (ID: {print_product_id}) - {training_scope}")
emit_progress(f"使用设备: {DEVICE}, 数据量: {len(product_df)}")
features = ['sales', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']
X = product_df[features].values
y = product_df[['sales']].values
scaler_X = MinMaxScaler(feature_range=(0, 1))
scaler_y = MinMaxScaler(feature_range=(0, 1))
X_scaled = scaler_X.fit_transform(X)
y_scaled = scaler_y.fit_transform(y)
train_size = int(len(X_scaled) * 0.8)
X_train, X_test = X_scaled[:train_size], X_scaled[train_size:]
y_train, y_test = y_scaled[:train_size], y_scaled[train_size:]
trainX, trainY = create_dataset(X_train, y_train, LOOK_BACK, FORECAST_HORIZON)
testX, testY = create_dataset(X_test, y_test, LOOK_BACK, FORECAST_HORIZON)
train_loader = DataLoader(PharmacyDataset(torch.Tensor(trainX), torch.Tensor(trainY)), batch_size=32, shuffle=True)
test_loader = DataLoader(PharmacyDataset(torch.Tensor(testX), torch.Tensor(testY)), batch_size=32, shuffle=False)
# 4. 模型初始化
input_dim = X_train.shape[1]
output_dim = FORECAST_HORIZON
hidden_size = 64
if use_optimized:
model = OptimizedKANForecaster(input_features=input_dim, hidden_sizes=[hidden_size, hidden_size*2, hidden_size], output_sequence_length=output_dim)
else:
model = KANForecaster(input_features=input_dim, hidden_sizes=[hidden_size, hidden_size*2, hidden_size], output_sequence_length=output_dim)
model = model.to(DEVICE)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
emit_progress("数据预处理完成,开始模型训练...", progress=10)
# 5. 训练循环
train_losses, test_losses = [], []
start_time = time.time()
best_loss = float('inf')
epochs_no_improve = 0
for epoch in range(epochs):
model.train()
epoch_loss = 0
for X_batch, y_batch in train_loader:
X_batch, y_batch = X_batch.to(DEVICE), y_batch.to(DEVICE)
if y_batch.dim() == 2: y_batch = y_batch.unsqueeze(-1)
outputs = model(X_batch)
if outputs.dim() == 2: outputs = outputs.unsqueeze(-1)
loss = criterion(outputs, y_batch)
if hasattr(model, 'regularization_loss'):
loss = loss + model.regularization_loss() * 0.01
optimizer.zero_grad()
loss.backward()
optimizer.step()
epoch_loss += loss.item()
train_loss = epoch_loss / len(train_loader)
train_losses.append(train_loss)
model.eval()
test_loss = 0
with torch.no_grad():
for X_batch, y_batch in test_loader:
X_batch, y_batch = X_batch.to(DEVICE), y_batch.to(DEVICE)
if y_batch.dim() == 2: y_batch = y_batch.unsqueeze(-1)
outputs = model(X_batch)
if outputs.dim() == 2: outputs = outputs.unsqueeze(-1)
loss = criterion(outputs, y_batch)
test_loss += loss.item()
test_loss /= len(test_loader)
test_losses.append(test_loss)
progress_percentage = 10 + ((epoch + 1) / epochs) * 85
emit_progress(f"Epoch {epoch+1}/{epochs} - Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}", progress=progress_percentage)
if test_loss < best_loss:
best_loss = test_loss
epochs_no_improve = 0
checkpoint_data = {
'epoch': epoch + 1,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'scaler_X': scaler_X,
'scaler_y': scaler_y,
}
model_manager.save_model_artifact(checkpoint_data, "checkpoint_best.pth", model_version_path)
emit_progress(f"💾 保存最佳模型检查点 (epoch {epoch+1}, test_loss: {test_loss:.4f})")
else:
epochs_no_improve += 1
if epochs_no_improve >= patience:
emit_progress(f"连续 {patience} 个epoch测试损失未改善提前停止训练。")
break
training_time = time.time() - start_time
# 6. 保存产物和评估
loss_fig = plt.figure(figsize=(10, 6))
plt.plot(train_losses, label='Training Loss')
plt.plot(test_losses, label='Test Loss')
plt.title(f'{model_type} 损失曲线 - {product_name} (v{version}) - {training_scope}')
plt.xlabel('Epoch'); plt.ylabel('Loss'); plt.legend(); plt.grid(True)
model_manager.save_model_artifact(loss_fig, "loss_curve.png", model_version_path)
plt.close(loss_fig)
emit_progress(f"损失曲线已保存到: {os.path.join(model_version_path, 'loss_curve.png')}")
model.eval()
with torch.no_grad():
testX_tensor = torch.Tensor(testX).to(DEVICE)
test_pred = model(testX_tensor).cpu().numpy()
if len(test_pred.shape) == 3: test_pred = test_pred.squeeze(-1)
test_pred_inv = scaler_y.inverse_transform(test_pred.reshape(-1, FORECAST_HORIZON))
test_true_inv = scaler_y.inverse_transform(testY.reshape(-1, FORECAST_HORIZON))
metrics = evaluate_model(test_true_inv.flatten(), test_pred_inv.flatten())
metrics['training_time'] = training_time
# 解决 'Object of type float32 is not JSON serializable' 错误
metrics = convert_numpy_types(metrics)
emit_progress(f"模型评估完成 - RMSE: {metrics['rmse']:.4f}, R²: {metrics['r2']:.4f}")
# 7. 保存最终模型和元数据
final_model_data = {
'model_state_dict': model.state_dict(),
'scaler_X': scaler_X,
'scaler_y': scaler_y,
}
model_manager.save_model_artifact(final_model_data, "model.pth", model_version_path)
metadata = {
'product_id': product_id if product_id else scope, 'product_name': product_name, 'model_type': model_type,
'version': f'v{version}', 'training_mode': training_mode, 'scope': scope,
'aggregation_method': aggregation_method, 'training_scope_description': training_scope,
'product_scope': '所有药品' if not product_id or product_id == 'all' else product_name,
'timestamp': datetime.now().isoformat(), 'metrics': metrics,
'config': {
'input_dim': input_dim, 'output_dim': output_dim,
'hidden_sizes': [hidden_size, hidden_size*2, hidden_size],
'sequence_length': LOOK_BACK, 'forecast_horizon': FORECAST_HORIZON,
'use_optimized': use_optimized
}
}
model_manager.save_model_artifact(metadata, "metadata.json", model_version_path)
# 8. 更新版本文件
model_manager.update_version(model_identifier, version)
emit_progress(f"{model_type}模型 v{version} 训练完成!", progress=100, metrics=metrics)
return model, metrics, version, model_version_path