Merge pull request 'lyf-dev-req0001' (#2) from lyf-dev-req0001 into lyf-dev

Reviewed-on: #2
This commit is contained in:
yuanfeiliao 2025-07-18 13:28:04 +08:00
commit ada4e8e108
9 changed files with 281 additions and 618 deletions

Binary file not shown.

View File

@ -56,7 +56,7 @@ from analysis.metrics import evaluate_model, compare_models
# 导入配置和版本管理
from core.config import (
DEFAULT_MODEL_DIR, WEBSOCKET_NAMESPACE,
get_model_versions, get_latest_model_version, get_next_model_version,
get_model_versions,
get_model_file_path, save_model_version_info
)
@ -1560,7 +1560,7 @@ def predict():
prediction_result = run_prediction(model_type, product_id, model_id, future_days, start_date, version, store_id, training_mode)
if prediction_result is None:
return jsonify({"status": "error", "error": "预测失败预测器返回None"}), 500
return jsonify({"status": "error", "error": "模型文件未找到或加载失败"}), 404
# 添加版本信息到预测结果
prediction_result['version'] = version
@ -3782,9 +3782,14 @@ def get_model_types():
def get_model_versions_api(product_id, model_type):
"""获取模型版本列表API"""
try:
versions = get_model_versions(product_id, model_type)
latest_version = get_latest_model_version(product_id, model_type)
from utils.model_manager import model_manager
result = model_manager.list_models(product_id=product_id, model_type=model_type)
models = result.get('models', [])
versions = sorted(list(set(m['version'] for m in models)), key=lambda v: (v != 'best', v))
latest_version = versions[0] if versions else None
return jsonify({
"status": "success",
"data": {

View File

@ -71,48 +71,6 @@ TRAINING_UPDATE_INTERVAL = 1 # 训练进度更新间隔(秒)
# 创建模型保存目录
os.makedirs(DEFAULT_MODEL_DIR, exist_ok=True)
def get_next_model_version(product_id: str, model_type: str) -> str:
"""
获取指定产品和模型类型的下一个版本号
Args:
product_id: 产品ID
model_type: 模型类型
Returns:
下一个版本号格式如 'v2', 'v3'
"""
# 新格式:带版本号的文件
pattern_new = f"{model_type}_model_product_{product_id}_v*.pth"
existing_files_new = glob.glob(os.path.join(DEFAULT_MODEL_DIR, pattern_new))
# 旧格式:不带版本号的文件(兼容性支持)
pattern_old = f"{model_type}_model_product_{product_id}.pth"
old_file_path = os.path.join(DEFAULT_MODEL_DIR, pattern_old)
has_old_format = os.path.exists(old_file_path)
# 如果没有任何格式的文件,返回默认版本
if not existing_files_new and not has_old_format:
return DEFAULT_VERSION
# 提取新格式文件的版本号
versions = []
for file_path in existing_files_new:
filename = os.path.basename(file_path)
version_match = re.search(rf"_v(\d+)\.pth$", filename)
if version_match:
versions.append(int(version_match.group(1)))
# 如果存在旧格式文件将其视为v1
if has_old_format:
versions.append(1)
print(f"检测到旧格式模型文件: {old_file_path}将其视为版本v1")
if versions:
next_version_num = max(versions) + 1
return f"v{next_version_num}"
else:
return DEFAULT_VERSION
def get_model_file_path(product_id: str, model_type: str, version: str) -> str:
"""
@ -134,13 +92,8 @@ def get_model_file_path(product_id: str, model_type: str, version: str) -> str:
# 修正直接使用唯一的product_id它可能包含store_前缀来构建文件名
# 文件名示例: transformer_17002608_epoch_best.pth 或 transformer_store_01010023_epoch_best.pth
# 针对 KAN 和 optimized_kan使用 model_manager 的命名约定
if model_type in ['kan', 'optimized_kan']:
# 格式: {model_type}_product_{product_id}_{version}.pth
# 注意KAN trainer 保存时product_id 就是 model_identifier
filename = f"{model_type}_product_{product_id}_{version}.pth"
else:
# 其他模型使用 _epoch_ 约定
filename = f"{model_type}_{product_id}_epoch_{version}.pth"
# 统一所有模型的命名格式
filename = f"{model_type}_product_{product_id}_{version}.pth"
# 修正直接在根模型目录查找不再使用checkpoints子目录
return os.path.join(DEFAULT_MODEL_DIR, filename)
@ -155,67 +108,32 @@ def get_model_versions(product_id: str, model_type: str) -> list:
Returns:
版本列表按版本号排序
"""
# 直接使用传入的product_id构建搜索模式
# 搜索模式,匹配 "transformer_product_17002608_epoch_50.pth" 或 "transformer_product_17002608_epoch_best.pth"
# 修正直接使用唯一的product_id它可能包含store_前缀来构建搜索模式
# 扩展搜索模式以兼容多种命名约定
patterns = [
f"{model_type}_{product_id}_epoch_*.pth", # 原始格式 (e.g., transformer_123_epoch_best.pth)
f"{model_type}_product_{product_id}_*.pth" # KAN/ModelManager格式 (e.g., kan_product_123_v1.pth)
]
# 统一使用新的命名约定进行搜索
pattern = os.path.join(DEFAULT_MODEL_DIR, f"{model_type}_product_{product_id}_*.pth")
existing_files = glob.glob(pattern)
existing_files = []
for pattern in patterns:
search_path = os.path.join(DEFAULT_MODEL_DIR, pattern)
existing_files.extend(glob.glob(search_path))
# 旧格式(兼容性支持)
pattern_old = f"{model_type}_model_product_{product_id}.pth"
old_file_path = os.path.join(DEFAULT_MODEL_DIR, pattern_old)
if os.path.exists(old_file_path):
existing_files.append(old_file_path)
versions = set() # 使用集合避免重复
versions = set()
# 从找到的文件中提取版本信息
for file_path in existing_files:
filename = os.path.basename(file_path)
# 尝试匹配 _epoch_ 格式
version_match_epoch = re.search(r"_epoch_(.+)\.pth$", filename)
if version_match_epoch:
versions.add(version_match_epoch.group(1))
continue
# 尝试匹配 _product_..._v 格式 (KAN)
version_match_kan = re.search(r"_product_.+_v(\d+)\.pth$", filename)
if version_match_kan:
versions.add(f"v{version_match_kan.group(1)}")
continue
# 严格匹配 _v<number> 或 'best'
match = re.search(r'_(v\d+|best)\.pth$', filename)
if match:
versions.add(match.group(1))
# 尝试匹配旧的 _model_product_ 格式
if pattern_old in filename:
versions.add("v1_legacy") # 添加一个特殊标识
print(f"检测到旧格式模型文件: {old_file_path},将其视为版本 v1_legacy")
continue
# 按数字版本降序排序,'best'始终在最前
def sort_key(v):
if v == 'best':
return -1 # 'best' is always first
if v.startswith('v'):
return int(v[1:])
return float('inf') # Should not happen
# 转换为列表并排序
sorted_versions = sorted(list(versions))
sorted_versions = sorted(list(versions), key=sort_key, reverse=True)
return sorted_versions
def get_latest_model_version(product_id: str, model_type: str) -> str:
"""
获取指定产品和模型类型的最新版本
Args:
product_id: 产品ID
model_type: 模型类型
Returns:
最新版本号如果没有则返回None
"""
versions = get_model_versions(product_id, model_type)
return versions[-1] if versions else None
def save_model_version_info(product_id: str, model_type: str, version: str, file_path: str, metrics: dict = None):
"""

View File

@ -257,11 +257,11 @@ def train_product_model_with_kan(product_id, model_identifier, product_df=None,
model_data=best_model_data,
product_id=model_identifier,
model_type=model_type_name,
version='best',
store_id=store_id,
training_mode=training_mode,
aggregation_method=aggregation_method,
product_name=product_name
product_name=product_name,
version='best' # 显式覆盖版本为'best'
)
if (epoch + 1) % 10 == 0:
@ -335,15 +335,18 @@ def train_product_model_with_kan(product_id, model_identifier, product_df=None,
'loss_curve_path': loss_curve_path
}
model_path = model_manager.save_model(
# 保存最终模型,让 model_manager 自动处理版本号
final_model_path, final_version = model_manager.save_model(
model_data=model_data,
product_id=model_identifier,
model_type=model_type_name,
version=f'final_epoch_{epochs}',
store_id=store_id,
training_mode=training_mode,
aggregation_method=aggregation_method,
product_name=product_name
# 注意此处不传递version参数由管理器自动生成
)
return model, metrics
print(f"最终模型已保存,版本: {final_version}, 路径: {final_model_path}")
return model, metrics

View File

@ -20,85 +20,10 @@ from utils.multi_store_data_utils import get_store_product_sales_data, aggregate
from utils.visualization import plot_loss_curve
from analysis.metrics import evaluate_model
from core.config import (
DEVICE, DEFAULT_MODEL_DIR, LOOK_BACK, FORECAST_HORIZON,
get_next_model_version, get_model_file_path, get_latest_model_version
DEVICE, DEFAULT_MODEL_DIR, LOOK_BACK, FORECAST_HORIZON
)
from utils.training_progress import progress_manager
def save_checkpoint(checkpoint_data: dict, epoch_or_label, model_identifier: str,
model_type: str, model_dir: str, store_id=None,
training_mode: str = 'product', aggregation_method=None):
"""
保存训练检查点
Args:
checkpoint_data: 检查点数据
epoch_or_label: epoch编号或标签'best'
product_id: 产品ID
model_type: 模型类型
model_dir: 模型保存目录
store_id: 店铺ID
training_mode: 训练模式
aggregation_method: 聚合方法
"""
# 创建检查点目录
# 直接在模型根目录保存,不再创建子目录
checkpoint_dir = model_dir
os.makedirs(checkpoint_dir, exist_ok=True)
# 修正直接使用product_id作为唯一标识符因为它已经包含了store_前缀或药品ID
filename = f"{model_type}_{model_identifier}_epoch_{epoch_or_label}.pth"
checkpoint_path = os.path.join(checkpoint_dir, filename)
# 保存检查点
torch.save(checkpoint_data, checkpoint_path)
print(f"[mLSTM] 检查点已保存: {checkpoint_path}", flush=True)
return checkpoint_path
def load_checkpoint(product_id: str, model_type: str, epoch_or_label,
model_dir: str, store_id=None, training_mode: str = 'product',
aggregation_method=None):
"""
加载训练检查点
Args:
product_id: 产品ID
model_type: 模型类型
epoch_or_label: epoch编号或标签
model_dir: 模型保存目录
store_id: 店铺ID
training_mode: 训练模式
aggregation_method: 聚合方法
Returns:
checkpoint_data: 检查点数据如果未找到返回None
"""
checkpoint_dir = os.path.join(model_dir, 'checkpoints')
# 生成检查点文件名
if training_mode == 'store' and store_id:
filename = f"{model_type}_store_{store_id}_{product_id}_epoch_{epoch_or_label}.pth"
elif training_mode == 'global' and aggregation_method:
filename = f"{model_type}_global_{product_id}_{aggregation_method}_epoch_{epoch_or_label}.pth"
else:
filename = f"{model_type}_product_{product_id}_epoch_{epoch_or_label}.pth"
checkpoint_path = os.path.join(checkpoint_dir, filename)
if os.path.exists(checkpoint_path):
try:
checkpoint_data = torch.load(checkpoint_path, map_location=DEVICE)
print(f"[mLSTM] 检查点已加载: {checkpoint_path}", flush=True)
return checkpoint_data
except Exception as e:
print(f"[mLSTM] 加载检查点失败: {e}", flush=True)
return None
else:
print(f"[mLSTM] 检查点文件不存在: {checkpoint_path}", flush=True)
return None
from utils.model_manager import model_manager
def train_product_model_with_mlstm(
product_id,
@ -173,15 +98,9 @@ def train_product_model_with_mlstm(
emit_progress("开始mLSTM模型训练...")
# 确定版本号
if version is None:
if continue_training:
version = get_latest_model_version(product_id, 'mlstm')
if version is None:
version = get_next_model_version(product_id, 'mlstm')
else:
version = get_next_model_version(product_id, 'mlstm')
emit_progress(f"开始训练 mLSTM 模型版本 {version}")
emit_progress(f"开始训练 mLSTM 模型")
if version:
emit_progress(f"使用指定版本: {version}")
# 初始化训练进度管理器(如果还未初始化)
if socketio and task_id:
@ -234,7 +153,7 @@ def train_product_model_with_mlstm(
print(f"[mLSTM] 使用mLSTM模型训练产品 '{product_name}' (ID: {product_id}) 的销售预测模型", flush=True)
print(f"[mLSTM] 训练范围: {training_scope}", flush=True)
print(f"[mLSTM] 版本: {version}", flush=True)
# print(f"[mLSTM] 版本: {version}", flush=True) # Version is now handled by model_manager
print(f"[mLSTM] 使用设备: {DEVICE}", flush=True)
print(f"[mLSTM] 模型将保存到目录: {model_dir}", flush=True)
print(f"[mLSTM] 数据量: {len(product_df)} 条记录", flush=True)
@ -323,16 +242,8 @@ def train_product_model_with_mlstm(
# 如果是继续训练,加载现有模型
if continue_training and version != 'v1':
try:
existing_model_path = get_model_file_path(product_id, 'mlstm', version)
if os.path.exists(existing_model_path):
checkpoint = torch.load(existing_model_path, map_location=DEVICE)
model.load_state_dict(checkpoint['model_state_dict'])
print(f"加载现有模型: {existing_model_path}")
emit_progress(f"加载现有模型版本 {version} 进行继续训练")
except Exception as e:
print(f"无法加载现有模型,将重新开始训练: {e}")
emit_progress("无法加载现有模型,重新开始训练")
# TODO: Implement continue_training logic with the new model_manager
pass
# 将模型移动到设备上
model = model.to(DEVICE)
@ -451,21 +362,23 @@ def train_product_model_with_mlstm(
}
}
# 保存检查点
save_checkpoint(checkpoint_data, epoch + 1, model_identifier, 'mlstm',
model_dir, store_id, training_mode, aggregation_method)
# 如果是最佳模型,额外保存一份
if test_loss < best_loss:
best_loss = test_loss
save_checkpoint(checkpoint_data, 'best', model_identifier, 'mlstm',
model_dir, store_id, training_mode, aggregation_method)
model_manager.save_model(
model_data=checkpoint_data,
product_id=model_identifier,
model_type='mlstm',
store_id=store_id,
training_mode=training_mode,
aggregation_method=aggregation_method,
product_name=product_name,
version='best'
)
emit_progress(f"💾 保存最佳模型检查点 (epoch {epoch+1}, test_loss: {test_loss:.4f})")
epochs_no_improve = 0
else:
epochs_no_improve += 1
emit_progress(f"💾 保存训练检查点 epoch_{epoch+1}")
if (epoch + 1) % 10 == 0:
print(f"Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}", flush=True)
@ -524,7 +437,6 @@ def train_product_model_with_mlstm(
# 计算评估指标
metrics = evaluate_model(test_true_inv, test_pred_inv)
metrics['training_time'] = training_time
metrics['version'] = version
# 打印评估指标
print("\n模型评估指标:")
@ -576,10 +488,15 @@ def train_product_model_with_mlstm(
}
}
# 保存最终模型使用epoch标识
final_model_path = save_checkpoint(
final_model_data, f"final_epoch_{epochs}", model_identifier, 'mlstm',
model_dir, store_id, training_mode, aggregation_method
# 保存最终模型,让 model_manager 自动处理版本号
final_model_path, final_version = model_manager.save_model(
model_data=final_model_data,
product_id=model_identifier,
model_type='mlstm',
store_id=store_id,
training_mode=training_mode,
aggregation_method=aggregation_method,
product_name=product_name
)
# 发送训练完成消息
@ -591,9 +508,10 @@ def train_product_model_with_mlstm(
'mape': metrics['mape'],
'training_time': training_time,
'final_epoch': epochs,
'model_path': final_model_path
'model_path': final_model_path,
'version': final_version
}
emit_progress(f"✅ mLSTM模型训练完成最终epoch: {epochs} 已保存", progress=100, metrics=final_metrics)
emit_progress(f"✅ mLSTM模型训练完成版本 {final_version} 已保存", progress=100, metrics=final_metrics)
return model, metrics, epochs, final_model_path
return model, metrics, epochs, final_model_path

View File

@ -20,39 +20,7 @@ from utils.visualization import plot_loss_curve
from analysis.metrics import evaluate_model
from core.config import DEVICE, DEFAULT_MODEL_DIR, LOOK_BACK, FORECAST_HORIZON
from utils.training_progress import progress_manager
def save_checkpoint(checkpoint_data: dict, epoch_or_label, model_identifier: str,
model_type: str, model_dir: str, store_id=None,
training_mode: str = 'product', aggregation_method=None):
"""
保存训练检查点
Args:
checkpoint_data: 检查点数据
epoch_or_label: epoch编号或标签'best'
product_id: 产品ID
model_type: 模型类型
model_dir: 模型保存目录
store_id: 店铺ID
training_mode: 训练模式
aggregation_method: 聚合方法
"""
# 创建检查点目录
# 直接在模型根目录保存,不再创建子目录
checkpoint_dir = model_dir
os.makedirs(checkpoint_dir, exist_ok=True)
# 生成检查点文件名
# 修正直接使用product_id作为唯一标识符因为它已经包含了store_前缀或药品ID
filename = f"{model_type}_{model_identifier}_epoch_{epoch_or_label}.pth"
checkpoint_path = os.path.join(checkpoint_dir, filename)
# 保存检查点
torch.save(checkpoint_data, checkpoint_path)
print(f"[TCN] 检查点已保存: {checkpoint_path}", flush=True)
return checkpoint_path
from utils.model_manager import model_manager
def train_product_model_with_tcn(
product_id,
@ -72,21 +40,6 @@ def train_product_model_with_tcn(
):
"""
使用TCN模型训练产品销售预测模型
参数:
product_id: 产品ID
epochs: 训练轮次
model_dir: 模型保存目录默认使用配置中的DEFAULT_MODEL_DIR
version: 指定版本号如果为None则自动生成
socketio: WebSocket对象用于实时反馈
task_id: 训练任务ID
continue_training: 是否继续训练现有模型
返回:
model: 训练好的模型
metrics: 模型评估指标
version: 实际使用的版本号
model_path: 模型文件路径
"""
def emit_progress(message, progress=None, metrics=None):
@ -103,62 +56,21 @@ def train_product_model_with_tcn(
data['metrics'] = metrics
socketio.emit('training_progress', data, namespace='/training')
# 确定版本号
if version is None:
from core.config import get_latest_model_version, get_next_model_version
if continue_training:
version = get_latest_model_version(product_id, 'tcn')
if version is None:
version = get_next_model_version(product_id, 'tcn')
else:
version = get_next_model_version(product_id, 'tcn')
emit_progress(f"开始训练 TCN 模型")
emit_progress(f"开始训练 TCN 模型版本 {version}")
# 如果没有传入product_df则根据训练模式加载数据
if product_df is None:
from utils.multi_store_data_utils import load_multi_store_data, get_store_product_sales_data, aggregate_multi_store_data
try:
if training_mode == 'store' and store_id:
# 加载特定店铺的数据
product_df = get_store_product_sales_data(
store_id,
product_id,
'pharmacy_sales_multi_store.csv'
)
training_scope = f"店铺 {store_id}"
elif training_mode == 'global':
# 聚合所有店铺的数据
product_df = aggregate_multi_store_data(
product_id,
aggregation_method=aggregation_method,
file_path='pharmacy_sales_multi_store.csv'
)
training_scope = f"全局聚合({aggregation_method})"
else:
# 默认:加载所有店铺的产品数据
product_df = load_multi_store_data('pharmacy_sales_multi_store.csv', product_id=product_id)
training_scope = "所有店铺"
except Exception as e:
print(f"多店铺数据加载失败: {e}")
# 后备方案:尝试原始数据
df = pd.read_excel('pharmacy_sales.xlsx')
product_df = df[df['product_id'] == product_id].sort_values('date')
training_scope = "原始数据"
from utils.multi_store_data_utils import aggregate_multi_store_data
product_df = aggregate_multi_store_data(
product_id=product_id,
aggregation_method=aggregation_method
)
training_scope = f"全局聚合({aggregation_method})"
else:
# 如果传入了product_df直接使用
if training_mode == 'store' and store_id:
training_scope = f"店铺 {store_id}"
elif training_mode == 'global':
training_scope = f"全局聚合({aggregation_method})"
else:
training_scope = "所有店铺"
training_scope = "所有店铺"
if product_df.empty:
raise ValueError(f"产品 {product_id} 没有可用的销售数据")
# 数据量检查
min_required_samples = sequence_length + forecast_horizon
if len(product_df) < min_required_samples:
error_msg = (
@ -166,10 +78,6 @@ def train_product_model_with_tcn(
f"当前配置需要: {min_required_samples} 天数据 (LOOK_BACK={sequence_length} + FORECAST_HORIZON={forecast_horizon})\n"
f"实际数据量: {len(product_df)}\n"
f"产品ID: {product_id}, 训练模式: {training_mode}\n"
f"建议解决方案:\n"
f"1. 生成更多数据: uv run generate_multi_store_data.py\n"
f"2. 调整配置参数: 减小 LOOK_BACK 或 FORECAST_HORIZON\n"
f"3. 使用全局训练模式聚合更多数据"
)
print(error_msg)
emit_progress(f"训练失败:数据不足 ({len(product_df)}/{min_required_samples} 天)")
@ -180,48 +88,39 @@ def train_product_model_with_tcn(
print(f"使用TCN模型训练产品 '{product_name}' (ID: {product_id}) 的销售预测模型")
print(f"训练范围: {training_scope}")
print(f"版本: {version}")
print(f"使用设备: {DEVICE}")
print(f"模型将保存到目录: {model_dir}")
emit_progress(f"训练产品: {product_name} (ID: {product_id})")
# 创建特征和目标变量
features = ['sales', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']
# 预处理数据
X = product_df[features].values
y = product_df[['sales']].values # 保持为二维数组
y = product_df[['sales']].values
# 设置数据预处理阶段
progress_manager.set_stage("data_preprocessing", 0)
emit_progress("数据预处理中...")
# 归一化数据
scaler_X = MinMaxScaler(feature_range=(0, 1))
scaler_y = MinMaxScaler(feature_range=(0, 1))
X_scaled = scaler_X.fit_transform(X)
y_scaled = scaler_y.fit_transform(y)
# 划分训练集和测试集80% 训练20% 测试)
train_size = int(len(X_scaled) * 0.8)
X_train, X_test = X_scaled[:train_size], X_scaled[train_size:]
y_train, y_test = y_scaled[:train_size], y_scaled[train_size:]
progress_manager.set_stage("data_preprocessing", 50)
# 创建时间序列数据
trainX, trainY = create_dataset(X_train, y_train, sequence_length, forecast_horizon)
testX, testY = create_dataset(X_test, y_test, sequence_length, forecast_horizon)
# 转换为PyTorch的Tensor
trainX_tensor = torch.Tensor(trainX)
trainY_tensor = torch.Tensor(trainY)
testX_tensor = torch.Tensor(testX)
testY_tensor = torch.Tensor(testY)
# 创建数据加载器
train_dataset = PharmacyDataset(trainX_tensor, trainY_tensor)
test_dataset = PharmacyDataset(testX_tensor, testY_tensor)
@ -229,7 +128,6 @@ def train_product_model_with_tcn(
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
# 更新进度管理器的批次信息
total_batches = len(train_loader)
total_samples = len(train_dataset)
progress_manager.total_batches_per_epoch = total_batches
@ -238,7 +136,6 @@ def train_product_model_with_tcn(
progress_manager.set_stage("data_preprocessing", 100)
# 初始化TCN模型
input_dim = X_train.shape[1]
output_dim = forecast_horizon
hidden_size = 64
@ -254,21 +151,8 @@ def train_product_model_with_tcn(
dropout=dropout_rate
)
# 如果是继续训练,加载现有模型
if continue_training and version != 'v1':
try:
from core.config import get_model_file_path
existing_model_path = get_model_file_path(product_id, 'tcn', version)
if os.path.exists(existing_model_path):
checkpoint = torch.load(existing_model_path, map_location=DEVICE)
model.load_state_dict(checkpoint['model_state_dict'])
print(f"加载现有模型: {existing_model_path}")
emit_progress(f"加载现有模型版本 {version} 进行继续训练")
except Exception as e:
print(f"无法加载现有模型,将重新开始训练: {e}")
emit_progress("无法加载现有模型,重新开始训练")
# TODO: Implement continue_training logic with the new model_manager
# 将模型移动到设备上
model = model.to(DEVICE)
criterion = nn.MSELoss()
@ -276,20 +160,17 @@ def train_product_model_with_tcn(
emit_progress("开始模型训练...")
# 训练模型
train_losses = []
test_losses = []
start_time = time.time()
# 配置检查点保存
checkpoint_interval = max(1, epochs // 10) # 每10%进度保存一次最少每1个epoch
checkpoint_interval = max(1, epochs // 10)
best_loss = float('inf')
progress_manager.set_stage("model_training", 0)
emit_progress(f"开始训练 - 总epoch: {epochs}, 检查点间隔: {checkpoint_interval}")
for epoch in range(epochs):
# 开始新的轮次
progress_manager.start_epoch(epoch)
model.train()
@ -298,43 +179,34 @@ def train_product_model_with_tcn(
for batch_idx, (X_batch, y_batch) in enumerate(train_loader):
X_batch, y_batch = X_batch.to(DEVICE), y_batch.to(DEVICE)
# 确保目标张量有正确的形状 (batch_size, forecast_horizon, 1)
if y_batch.dim() == 2:
y_batch = y_batch.unsqueeze(-1)
# 前向传播
outputs = model(X_batch)
# 确保输出和目标形状匹配
loss = criterion(outputs, y_batch)
# 反向传播和优化
optimizer.zero_grad()
loss.backward()
optimizer.step()
epoch_loss += loss.item()
# 更新批次进度每10个批次更新一次
if batch_idx % 10 == 0 or batch_idx == len(train_loader) - 1:
current_lr = optimizer.param_groups[0]['lr']
progress_manager.update_batch(batch_idx, loss.item(), current_lr)
# 计算训练损失
train_loss = epoch_loss / len(train_loader)
train_losses.append(train_loss)
# 设置验证阶段
progress_manager.set_stage("validation", 0)
# 在测试集上评估
model.eval()
test_loss = 0
with torch.no_grad():
for batch_idx, (X_batch, y_batch) in enumerate(test_loader):
X_batch, y_batch = X_batch.to(DEVICE), y_batch.to(DEVICE)
# 确保目标张量有正确的形状
if y_batch.dim() == 2:
y_batch = y_batch.unsqueeze(-1)
@ -342,7 +214,6 @@ def train_product_model_with_tcn(
loss = criterion(outputs, y_batch)
test_loss += loss.item()
# 更新验证进度
if batch_idx % 5 == 0 or batch_idx == len(test_loader) - 1:
val_progress = (batch_idx / len(test_loader)) * 100
progress_manager.set_stage("validation", val_progress)
@ -350,10 +221,8 @@ def train_product_model_with_tcn(
test_loss = test_loss / len(test_loader)
test_losses.append(test_loss)
# 完成当前轮次
progress_manager.finish_epoch(train_loss, test_loss)
# 发送训练进度(保持与旧系统的兼容性)
if (epoch + 1) % 5 == 0 or epoch == epochs - 1:
progress = ((epoch + 1) / epochs) * 100
current_metrics = {
@ -365,7 +234,6 @@ def train_product_model_with_tcn(
emit_progress(f"Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}",
progress=progress, metrics=current_metrics)
# 定期保存检查点
if (epoch + 1) % checkpoint_interval == 0 or epoch == epochs - 1:
checkpoint_data = {
'epoch': epoch + 1,
@ -399,30 +267,28 @@ def train_product_model_with_tcn(
}
}
# 保存检查点
save_checkpoint(checkpoint_data, epoch + 1, model_identifier, 'tcn',
model_dir, store_id, training_mode, aggregation_method)
# 如果是最佳模型,额外保存一份
if test_loss < best_loss:
best_loss = test_loss
save_checkpoint(checkpoint_data, 'best', model_identifier, 'tcn',
model_dir, store_id, training_mode, aggregation_method)
model_manager.save_model(
model_data=checkpoint_data,
product_id=model_identifier,
model_type='tcn',
store_id=store_id,
training_mode=training_mode,
aggregation_method=aggregation_method,
product_name=product_name,
version='best'
)
emit_progress(f"💾 保存最佳模型检查点 (epoch {epoch+1}, test_loss: {test_loss:.4f})")
emit_progress(f"💾 保存训练检查点 epoch_{epoch+1}")
if (epoch + 1) % 10 == 0:
print(f"Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}")
# 计算训练时间
training_time = time.time() - start_time
# 设置模型保存阶段
progress_manager.set_stage("model_saving", 0)
emit_progress("训练完成,正在保存模型...")
# 绘制损失曲线并保存到模型目录
loss_curve_path = plot_loss_curve(
train_losses,
test_losses,
@ -432,23 +298,17 @@ def train_product_model_with_tcn(
)
print(f"损失曲线已保存到: {loss_curve_path}")
# 评估模型
model.eval()
with torch.no_grad():
# 确保测试数据的形状正确
test_pred = model(testX_tensor.to(DEVICE))
# 将输出转换为二维数组 [samples, forecast_horizon]
test_pred = test_pred.squeeze(-1).cpu().numpy()
# 反归一化预测结果和真实值
test_pred_inv = scaler_y.inverse_transform(test_pred.reshape(-1, 1)).flatten()
test_true_inv = scaler_y.inverse_transform(testY.reshape(-1, 1)).flatten()
# 计算评估指标
metrics = evaluate_model(test_true_inv, test_pred_inv)
metrics['training_time'] = training_time
# 打印评估指标
print("\n模型评估指标:")
print(f"MSE: {metrics['mse']:.4f}")
print(f"RMSE: {metrics['rmse']:.4f}")
@ -457,9 +317,8 @@ def train_product_model_with_tcn(
print(f"MAPE: {metrics['mape']:.2f}%")
print(f"训练时间: {training_time:.2f}")
# 保存最终训练完成的模型基于最终epoch
final_model_data = {
'epoch': epochs, # 最终epoch
'epoch': epochs,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'train_loss': train_losses[-1],
@ -495,10 +354,14 @@ def train_product_model_with_tcn(
progress_manager.set_stage("model_saving", 50)
# 保存最终模型使用epoch标识
final_model_path = save_checkpoint(
final_model_data, f"final_epoch_{epochs}", model_identifier, 'tcn',
model_dir, store_id, training_mode, aggregation_method
final_model_path, final_version = model_manager.save_model(
model_data=final_model_data,
product_id=model_identifier,
model_type='tcn',
store_id=store_id,
training_mode=training_mode,
aggregation_method=aggregation_method,
product_name=product_name
)
progress_manager.set_stage("model_saving", 100)
@ -510,9 +373,10 @@ def train_product_model_with_tcn(
'r2': metrics['r2'],
'mape': metrics['mape'],
'training_time': training_time,
'final_epoch': epochs
'final_epoch': epochs,
'version': final_version
}
emit_progress(f"模型训练完成!最终epoch: {epochs}", progress=100, metrics=final_metrics)
emit_progress(f"模型训练完成!版本 {final_version} 已保存", progress=100, metrics=final_metrics)
return model, metrics, epochs, final_model_path
return model, metrics, epochs, final_model_path

View File

@ -21,43 +21,11 @@ from utils.multi_store_data_utils import get_store_product_sales_data, aggregate
from utils.visualization import plot_loss_curve
from analysis.metrics import evaluate_model
from core.config import (
DEVICE, DEFAULT_MODEL_DIR, LOOK_BACK, FORECAST_HORIZON,
get_next_model_version, get_model_file_path, get_latest_model_version
DEVICE, DEFAULT_MODEL_DIR, LOOK_BACK, FORECAST_HORIZON
)
from utils.training_progress import progress_manager
from utils.model_manager import model_manager
def save_checkpoint(checkpoint_data: dict, epoch_or_label, model_identifier: str,
model_type: str, model_dir: str, store_id=None,
training_mode: str = 'product', aggregation_method=None):
"""
保存训练检查点
Args:
checkpoint_data: 检查点数据
epoch_or_label: epoch编号或标签'best'
product_id: 产品ID
model_type: 模型类型
model_dir: 模型保存目录
store_id: 店铺ID
training_mode: 训练模式
aggregation_method: 聚合方法
"""
# 直接在模型根目录保存,不再创建子目录
checkpoint_dir = model_dir
os.makedirs(checkpoint_dir, exist_ok=True)
# 修正直接使用product_id作为唯一标识符因为它已经包含了store_前缀或药品ID
filename = f"{model_type}_{model_identifier}_epoch_{epoch_or_label}.pth"
checkpoint_path = os.path.join(checkpoint_dir, filename)
# 保存检查点
torch.save(checkpoint_data, checkpoint_path)
print(f"[Transformer] 检查点已保存: {checkpoint_path}", flush=True)
return checkpoint_path
def train_product_model_with_transformer(
product_id,
model_identifier,
@ -79,23 +47,8 @@ def train_product_model_with_transformer(
):
"""
使用Transformer模型训练产品销售预测模型
参数:
product_id: 产品ID
epochs: 训练轮次
model_dir: 模型保存目录默认使用配置中的DEFAULT_MODEL_DIR
version: 指定版本号如果为None则自动生成
socketio: WebSocket对象用于实时反馈
task_id: 训练任务ID
continue_training: 是否继续训练现有模型
返回:
model: 训练好的模型
metrics: 模型评估指标
version: 实际使用的版本号
"""
# WebSocket进度反馈函数
def emit_progress(message, progress=None, metrics=None):
"""发送训练进度到前端"""
if socketio and task_id:
@ -110,18 +63,15 @@ def train_product_model_with_transformer(
data['metrics'] = metrics
socketio.emit('training_progress', data, namespace='/training')
print(f"[{time.strftime('%H:%M:%S')}] {message}", flush=True)
# 强制刷新输出缓冲区
import sys
sys.stdout.flush()
sys.stderr.flush()
emit_progress("开始Transformer模型训练...")
# 获取训练进度管理器实例
try:
from utils.training_progress import progress_manager
except ImportError:
# 如果无法导入,创建一个空的管理器以避免错误
class DummyProgressManager:
def set_stage(self, *args, **kwargs): pass
def start_training(self, *args, **kwargs): pass
@ -131,50 +81,19 @@ def train_product_model_with_transformer(
def finish_training(self, *args, **kwargs): pass
progress_manager = DummyProgressManager()
# 如果没有传入product_df则根据训练模式加载数据
if product_df is None:
from utils.multi_store_data_utils import load_multi_store_data, get_store_product_sales_data, aggregate_multi_store_data
try:
if training_mode == 'store' and store_id:
# 加载特定店铺的数据
product_df = get_store_product_sales_data(
store_id,
product_id,
'pharmacy_sales_multi_store.csv'
)
training_scope = f"店铺 {store_id}"
elif training_mode == 'global':
# 聚合所有店铺的数据
product_df = aggregate_multi_store_data(
product_id,
aggregation_method=aggregation_method,
file_path='pharmacy_sales_multi_store.csv'
)
training_scope = f"全局聚合({aggregation_method})"
else:
# 默认:加载所有店铺的产品数据
product_df = load_multi_store_data('pharmacy_sales_multi_store.csv', product_id=product_id)
training_scope = "所有店铺"
except Exception as e:
print(f"多店铺数据加载失败: {e}")
# 后备方案:尝试原始数据
df = pd.read_excel('pharmacy_sales.xlsx')
product_df = df[df['product_id'] == product_id].sort_values('date')
training_scope = "原始数据"
from utils.multi_store_data_utils import aggregate_multi_store_data
product_df = aggregate_multi_store_data(
product_id=product_id,
aggregation_method=aggregation_method
)
training_scope = f"全局聚合({aggregation_method})"
else:
# 如果传入了product_df直接使用
if training_mode == 'store' and store_id:
training_scope = f"店铺 {store_id}"
elif training_mode == 'global':
training_scope = f"全局聚合({aggregation_method})"
else:
training_scope = "所有店铺"
training_scope = "所有店铺"
if product_df.empty:
raise ValueError(f"产品 {product_id} 没有可用的销售数据")
# 数据量检查
min_required_samples = sequence_length + forecast_horizon
if len(product_df) < min_required_samples:
error_msg = (
@ -182,10 +101,6 @@ def train_product_model_with_transformer(
f"当前配置需要: {min_required_samples} 天数据 (LOOK_BACK={sequence_length} + FORECAST_HORIZON={forecast_horizon})\n"
f"实际数据量: {len(product_df)}\n"
f"产品ID: {product_id}, 训练模式: {training_mode}\n"
f"建议解决方案:\n"
f"1. 生成更多数据: uv run generate_multi_store_data.py\n"
f"2. 调整配置参数: 减小 LOOK_BACK 或 FORECAST_HORIZON\n"
f"3. 使用全局训练模式聚合更多数据"
)
print(error_msg)
raise ValueError(error_msg)
@ -197,18 +112,14 @@ def train_product_model_with_transformer(
print(f"[Device] 使用设备: {DEVICE}", flush=True)
print(f"[Model] 模型将保存到目录: {model_dir}", flush=True)
# 创建特征和目标变量
features = ['sales', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']
# 设置数据预处理阶段
progress_manager.set_stage("data_preprocessing", 0)
emit_progress("数据预处理中...")
# 预处理数据
X = product_df[features].values
y = product_df[['sales']].values # 保持为二维数组
y = product_df[['sales']].values
# 归一化数据
scaler_X = MinMaxScaler(feature_range=(0, 1))
scaler_y = MinMaxScaler(feature_range=(0, 1))
@ -217,24 +128,20 @@ def train_product_model_with_transformer(
progress_manager.set_stage("data_preprocessing", 40)
# 划分训练集和测试集80% 训练20% 测试)
train_size = int(len(X_scaled) * 0.8)
X_train, X_test = X_scaled[:train_size], X_scaled[train_size:]
y_train, y_test = y_scaled[:train_size], y_scaled[train_size:]
# 创建时间序列数据
trainX, trainY = create_dataset(X_train, y_train, sequence_length, forecast_horizon)
testX, testY = create_dataset(X_test, y_test, sequence_length, forecast_horizon)
progress_manager.set_stage("data_preprocessing", 70)
# 转换为PyTorch的Tensor
trainX_tensor = torch.Tensor(trainX)
trainY_tensor = torch.Tensor(trainY)
testX_tensor = torch.Tensor(testX)
testY_tensor = torch.Tensor(testY)
# 创建数据加载器
train_dataset = PharmacyDataset(trainX_tensor, trainY_tensor)
test_dataset = PharmacyDataset(testX_tensor, testY_tensor)
@ -242,7 +149,6 @@ def train_product_model_with_transformer(
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
# 更新进度管理器的批次信息
total_batches = len(train_loader)
total_samples = len(train_dataset)
progress_manager.total_batches_per_epoch = total_batches
@ -252,7 +158,6 @@ def train_product_model_with_transformer(
progress_manager.set_stage("data_preprocessing", 100)
emit_progress("数据预处理完成,开始模型训练...")
# 初始化Transformer模型
input_dim = X_train.shape[1]
output_dim = forecast_horizon
hidden_size = 64
@ -272,20 +177,17 @@ def train_product_model_with_transformer(
batch_size=batch_size
)
# 将模型移动到设备上
model = model.to(DEVICE)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=patience // 2, factor=0.5)
# 训练模型
train_losses = []
test_losses = []
start_time = time.time()
# 配置检查点保存
checkpoint_interval = max(1, epochs // 10) # 每10%进度保存一次最少每1个epoch
checkpoint_interval = max(1, epochs // 10)
best_loss = float('inf')
epochs_no_improve = 0
@ -293,7 +195,6 @@ def train_product_model_with_transformer(
emit_progress(f"开始训练 - 总epoch: {epochs}, 检查点间隔: {checkpoint_interval}, 耐心值: {patience}")
for epoch in range(epochs):
# 开始新的轮次
progress_manager.start_epoch(epoch)
model.train()
@ -302,12 +203,9 @@ def train_product_model_with_transformer(
for batch_idx, (X_batch, y_batch) in enumerate(train_loader):
X_batch, y_batch = X_batch.to(DEVICE), y_batch.to(DEVICE)
# 确保目标张量有正确的形状
# 前向传播
outputs = model(X_batch)
loss = criterion(outputs, y_batch)
# 反向传播和优化
optimizer.zero_grad()
loss.backward()
if clip_norm:
@ -316,31 +214,25 @@ def train_product_model_with_transformer(
epoch_loss += loss.item()
# 更新批次进度
if batch_idx % 5 == 0 or batch_idx == len(train_loader) - 1:
current_lr = optimizer.param_groups[0]['lr']
progress_manager.update_batch(batch_idx, loss.item(), current_lr)
# 计算训练损失
train_loss = epoch_loss / len(train_loader)
train_losses.append(train_loss)
# 设置验证阶段
progress_manager.set_stage("validation", 0)
# 在测试集上评估
model.eval()
test_loss = 0
with torch.no_grad():
for batch_idx, (X_batch, y_batch) in enumerate(test_loader):
X_batch, y_batch = X_batch.to(DEVICE), y_batch.to(DEVICE)
# 确保目标张量有正确的形状
outputs = model(X_batch)
loss = criterion(outputs, y_batch)
test_loss += loss.item()
# 更新验证进度
if batch_idx % 3 == 0 or batch_idx == len(test_loader) - 1:
val_progress = (batch_idx / len(test_loader)) * 100
progress_manager.set_stage("validation", val_progress)
@ -348,13 +240,10 @@ def train_product_model_with_transformer(
test_loss = test_loss / len(test_loader)
test_losses.append(test_loss)
# 更新学习率
scheduler.step(test_loss)
# 完成当前轮次
progress_manager.finish_epoch(train_loss, test_loss)
# 发送训练进度
if (epoch + 1) % 5 == 0 or epoch == epochs - 1:
progress = ((epoch + 1) / epochs) * 100
current_metrics = {
@ -366,7 +255,6 @@ def train_product_model_with_transformer(
emit_progress(f"Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}",
progress=progress, metrics=current_metrics)
# 定期保存检查点
if (epoch + 1) % checkpoint_interval == 0 or epoch == epochs - 1:
checkpoint_data = {
'epoch': epoch + 1,
@ -399,38 +287,35 @@ def train_product_model_with_transformer(
}
}
# 保存检查点
save_checkpoint(checkpoint_data, epoch + 1, model_identifier, 'transformer',
model_dir, store_id, training_mode, aggregation_method)
# 如果是最佳模型,额外保存一份
if test_loss < best_loss:
best_loss = test_loss
save_checkpoint(checkpoint_data, 'best', model_identifier, 'transformer',
model_dir, store_id, training_mode, aggregation_method)
model_manager.save_model(
model_data=checkpoint_data,
product_id=model_identifier,
model_type='transformer',
store_id=store_id,
training_mode=training_mode,
aggregation_method=aggregation_method,
product_name=product_name,
version='best'
)
emit_progress(f"💾 保存最佳模型检查点 (epoch {epoch+1}, test_loss: {test_loss:.4f})")
epochs_no_improve = 0
else:
epochs_no_improve += 1
emit_progress(f"💾 保存训练检查点 epoch_{epoch+1}")
if (epoch + 1) % 10 == 0:
print(f"📊 Epoch {epoch+1}/{epochs}, 训练损失: {train_loss:.4f}, 测试损失: {test_loss:.4f}", flush=True)
# 提前停止逻辑
if epochs_no_improve >= patience:
emit_progress(f"连续 {patience} 个epoch测试损失未改善提前停止训练。")
break
# 计算训练时间
training_time = time.time() - start_time
# 设置模型保存阶段
progress_manager.set_stage("model_saving", 0)
emit_progress("训练完成,正在保存模型...")
# 绘制损失曲线并保存到模型目录
loss_curve_path = plot_loss_curve(
train_losses,
test_losses,
@ -440,21 +325,17 @@ def train_product_model_with_transformer(
)
print(f"📈 损失曲线已保存到: {loss_curve_path}", flush=True)
# 评估模型
model.eval()
with torch.no_grad():
test_pred = model(testX_tensor.to(DEVICE)).cpu().numpy()
test_true = testY
# 反归一化预测结果和真实值
test_pred_inv = scaler_y.inverse_transform(test_pred)
test_true_inv = scaler_y.inverse_transform(test_true)
# 计算评估指标
metrics = evaluate_model(test_true_inv, test_pred_inv)
metrics['training_time'] = training_time
# 打印评估指标
print(f"\n📊 模型评估指标:", flush=True)
print(f" MSE: {metrics['mse']:.4f}", flush=True)
print(f" RMSE: {metrics['rmse']:.4f}", flush=True)
@ -463,9 +344,8 @@ def train_product_model_with_transformer(
print(f" MAPE: {metrics['mape']:.2f}%", flush=True)
print(f" ⏱️ 训练时间: {training_time:.2f}", flush=True)
# 保存最终训练完成的模型基于最终epoch
final_model_data = {
'epoch': epochs, # 最终epoch
'epoch': epochs,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'train_loss': train_losses[-1],
@ -500,10 +380,14 @@ def train_product_model_with_transformer(
progress_manager.set_stage("model_saving", 50)
# 保存最终模型使用epoch标识
final_model_path = save_checkpoint(
final_model_data, f"final_epoch_{epochs}", model_identifier, 'transformer',
model_dir, store_id, training_mode, aggregation_method
final_model_path, final_version = model_manager.save_model(
model_data=final_model_data,
product_id=model_identifier,
model_type='transformer',
store_id=store_id,
training_mode=training_mode,
aggregation_method=aggregation_method,
product_name=product_name
)
progress_manager.set_stage("model_saving", 100)
@ -511,7 +395,6 @@ def train_product_model_with_transformer(
print(f"💾 模型已保存到 {final_model_path}", flush=True)
# 准备最终返回的指标
final_metrics = {
'mse': metrics['mse'],
'rmse': metrics['rmse'],
@ -519,7 +402,8 @@ def train_product_model_with_transformer(
'r2': metrics['r2'],
'mape': metrics['mape'],
'training_time': training_time,
'final_epoch': epochs
'final_epoch': epochs,
'version': final_version
}
return model, final_metrics, epochs
return model, final_metrics, epochs

View File

@ -8,6 +8,7 @@ import json
import torch
import glob
from datetime import datetime
import re
from typing import List, Dict, Optional, Tuple
from core.config import DEFAULT_MODEL_DIR
@ -24,9 +25,30 @@ class ModelManager:
if not os.path.exists(self.model_dir):
os.makedirs(self.model_dir)
def generate_model_filename(self,
product_id: str,
model_type: str,
def _get_next_version(self, product_id: str, model_type: str, store_id: Optional[str] = None, training_mode: str = 'product') -> int:
"""获取下一个模型版本号 (纯数字)"""
search_pattern = self.generate_model_filename(
product_id=product_id,
model_type=model_type,
version='v*',
store_id=store_id,
training_mode=training_mode
)
full_search_path = os.path.join(self.model_dir, search_pattern)
existing_files = glob.glob(full_search_path)
max_version = 0
for f in existing_files:
match = re.search(r'_v(\d+)\.pth$', os.path.basename(f))
if match:
max_version = max(max_version, int(match.group(1)))
return max_version + 1
def generate_model_filename(self,
product_id: str,
model_type: str,
version: str,
store_id: Optional[str] = None,
training_mode: str = 'product',
@ -36,7 +58,7 @@ class ModelManager:
格式规范:
- 产品模式: {model_type}_product_{product_id}_{version}.pth
- 店铺模式: {model_type}_store_{store_id}_{product_id}_{version}.pth
- 店铺模式: {model_type}_store_{store_id}_{product_id}_{version}.pth
- 全局模式: {model_type}_global_{product_id}_{aggregation_method}_{version}.pth
"""
if training_mode == 'store' and store_id:
@ -47,33 +69,33 @@ class ModelManager:
# 默认产品模式
return f"{model_type}_product_{product_id}_{version}.pth"
def save_model(self,
def save_model(self,
model_data: dict,
product_id: str,
model_type: str,
version: str,
model_type: str,
store_id: Optional[str] = None,
training_mode: str = 'product',
aggregation_method: Optional[str] = None,
product_name: Optional[str] = None) -> str:
product_name: Optional[str] = None,
version: Optional[str] = None) -> Tuple[str, str]:
"""
保存模型到统一位置
保存模型到统一位置并自动管理版本
参数:
model_data: 包含模型状态和配置的字典
product_id: 产品ID
model_type: 模型类型
version: 版本号
store_id: 店铺ID (可选)
training_mode: 训练模式
aggregation_method: 聚合方法 (可选)
product_name: 产品名称 (可选)
...
version: (可选) 如果提供则覆盖自动版本控制 ( 'best')
返回:
模型文件路径
(模型文件路径, 使用的版本号)
"""
if version is None:
next_version_num = self._get_next_version(product_id, model_type, store_id, training_mode)
version_str = f"v{next_version_num}"
else:
version_str = version
filename = self.generate_model_filename(
product_id, model_type, version, store_id, training_mode, aggregation_method
product_id, model_type, version_str, store_id, training_mode, aggregation_method
)
# 统一保存到根目录,避免复杂的子目录结构
@ -86,7 +108,7 @@ class ModelManager:
'product_id': product_id,
'product_name': product_name or product_id,
'model_type': model_type,
'version': version,
'version': version_str,
'store_id': store_id,
'training_mode': training_mode,
'aggregation_method': aggregation_method,
@ -99,7 +121,7 @@ class ModelManager:
torch.save(enhanced_model_data, model_path)
print(f"模型已保存: {model_path}")
return model_path
return model_path, version_str
def list_models(self,
product_id: Optional[str] = None,

View File

@ -1,101 +1,150 @@
# 新需求开发流程
# 新需求开发标准流程
本文档旨在提供一个标准、安全的新功能开发工作流,涵盖从创建功能分支到最终合并回主开发分支的完整步骤。
本文档旨在提供一个标准、安全、高效的新功能开发工作流,涵盖从创建功能分支到最终合并回主开发分支的完整步骤,并融入日常开发的最佳实践
## 核心流程
## 核心开发理念
采用功能分支Feature Branch的工作模式主要步骤如下
1. **同步与切换**:将远程创建的新功能分支(如 `lyf-dev-req0001`)同步到本地并切换。
2. **开发与提交**:在本地功能分支上进行开发,并频繁提交改动。
3. **推送与备份**:将本地的改动推送到远程功能分支。
4. **合并回主线**:当功能开发测试完成后,将其合并回主开发分支(如 `lyf-dev`)。
5. **清理分支**:合并完成后,清理已完成使命的功能分支。
- **主分支保护**: `lyf-dev` 是团队的主开发分支,应始终保持稳定和可部署状态。所有新功能开发都必须在独立的功能分支中进行。
- **功能分支**: 每个新需求(如 `req0001`)都对应一个功能分支(如 `lyf-dev-req0001`)。分支命名应清晰、有意义。
- **小步快跑**: 频繁提交Commit、频繁推送Push、频繁与主线同步`rebase``merge`)。这能有效减少后期合并的难度和风险。
- **清晰的历史**: 保持 Git 提交历史的可读性方便代码审查Code Review和问题追溯。
---
## 详细命令使用步骤
## 每日工作第一步:同步最新代码
### 步骤一:同步远程分支到本地
**无论你昨天工作到哪里,每天开始新一天的工作时,请务必执行以下步骤。这是保证团队高效协作、避免合并冲突的基石。**
假设您的同事已经在远程仓库基于 `lyf-dev` 创建了 `lyf-dev-req0001` 分支。
1. **获取远程所有最新信息**
这个命令会把远程仓库的新分支信息下载到你的本地,但不会做任何修改。
1. **更新主开发分支 `lyf-dev`**
```bash
git fetch origin
# 切换到主开发分支
git checkout lyf-dev
# 从远程拉取最新代码,--prune 会清理远程已删除的分支引用
git pull origin lyf-dev --prune
```
2. **创建并切换到本地功能分支**
这个命令会在本地创建一个名为 `lyf-dev-req0001` 的新分支,并自动设置它跟踪远程的 `origin/lyf-dev-req0001` 分支。
2. **同步你的功能分支 (团队选择一种方案)**
将主分支的最新代码同步到你的功能分支,有两种主流方案,请团队根据偏好选择其一。
---
### 方案一 (推荐): 使用 `rebase` 保持历史清爽
此方案会让你的分支提交历史保持为一条直线,非常清晰。
```bash
# 切换回你正在开发的功能分支(例如 lyf-dev-req0001
git checkout lyf-dev-req0001
# 使用 rebase 将 lyf-dev 的最新更新同步到你的分支
git rebase lyf-dev
```
现在,您已经处于一个干净、独立的功能分支上,可以开始开发了。
- **优点**: 最终的提交历史非常干净、线性,便于代码审查和问题追溯。
- **缺点**: 重写了提交历史,需要使用 `git push --force-with-lease` 强制推送。
- **冲突解决**:
1. 手动修改冲突文件。
2. 执行 `git add <冲突文件>`
3. 执行 `git rebase --continue`
4. 若想中止,执行 `git rebase --abort`
### 步骤二:在功能分支上开发
---
### 方案二: 使用 `merge` 保留完整历史
在这个分支上进行的所有修改都与 `lyf-dev` 无关,可以放心操作。
此方案会忠实记录每一次合并操作,不修改历史提交
1. **修改代码**:根据需求添加、修改文件。
2. **提交改动**
建议小步快跑,完成一个小的功能点就提交一次。
```bash
# 将所有修改添加到暂存区
# 切换回你正在开发的功能分支(例如 lyf-dev-req0001
git checkout lyf-dev-req0001
# 将最新的 lyf-dev 合并到你当前的分支
git merge lyf-dev
```
- **优点**: 操作安全,不修改历史,推送时无需强制。
- **缺点**: 会在功能分支中产生额外的合并提交记录 (e.g., "Merge branch 'lyf-dev' into ..."),使历史记录变得复杂。
- **冲突解决**:
1. 手动修改冲突文件。
2. 执行 `git add <冲突文件>`
3. 执行 `git commit` 完成合并。
---
## 完整开发流程
### 1. 开始新需求:创建功能分支
**当你需要开启一个全新的功能开发时:**
1. **确保 `lyf-dev` 已是最新**
(此步骤已在“每日工作第一步”中完成,此处作为提醒)
2. **从 `lyf-dev` 创建并切换到新分支**
假设新需求编号是 `req0002`
```bash
# 这会从最新的 lyf-dev 创建 lyf-dev-req0002 分支并切换过去
git checkout -b lyf-dev-req0002
```
### 2. 日常开发:提交与推送
**在你的功能分支上(如 `lyf-dev-req0002`)进行开发:**
1. **编码与本地提交**
完成一个小的、完整的功能点后,就进行一次提交。
```bash
# 查看修改状态
git status
# 添加所有相关文件到暂存区
git add .
# 提交并撰写清晰的说明
git commit -m "feat: 完成用户登录接口"
# 提交并撰写清晰的说明feat: 功能, fix: 修复, docs: 文档等)
git commit -m "feat: 实现用户认证模块"
```
3. **推送到远程功能分支**
为了备份代码或与他人协作,需要将本地的提交推送到远程。
2. **推送改动到远程备份**
为了代码安全和方便团队协作,应频繁将本地提交推送到远程。
```bash
git push origin lyf-dev-req0001
# -u 参数会设置本地分支跟踪远程分支,后续只需 git push 即可
git push -u origin lyf-dev-req0002
```
### 步骤三:合并功能到主开发分支 (`lyf-dev`)
### 3. 功能完成:合并回主线
当新功能开发完成,并且经过充分测试后,执行以下步骤将其合并回 `lyf-dev`
**当功能开发完成并通过测试后,将其合并回 `lyf-dev`**
1. **切换回主开发分支**
1. **最后一次同步**
在正式合并前,做最后一次同步,确保分支包含了 `lyf-dev` 的所有最新内容。
(重复“每日工作第一步”中的同步流程)
2. **切换到主分支并拉取最新代码**
```bash
git checkout lyf-dev
```
2. **确保 `lyf-dev` 是最新的**
在合并前,务必先从远程拉取 `lyf-dev` 的最新代码,以防他人在此期间有更新。
```bash
git pull origin lyf-dev
```
3. **合并功能分支**
这是最关键的一步,将 `lyf-dev-req0001` 的所有新功能合并进来。
3. **合并功能分支**
我们使用 `--no-ff` (No Fast-forward) 参数来创建合并提交,这样可以清晰地记录“合并了一个功能”这个行为。
```bash
git merge lyf-dev-req0001
# --no-ff 会创建一个新的合并提交,保留分支历史
git merge --no-ff lyf-dev-req0002
```
- **无冲突**Git 会自动完成合并。
- **有冲突 (Conflict)**Git 会提示你哪些文件冲突了。你需要手动解决这些文件中的冲突,然后执行 `git add .``git commit` 来完成合并。
如果同步工作做得好,这一步通常不会有冲突。
4. **推送合并后的 `lyf-dev`**
将本地合并好的 `lyf-dev` 分支推送到远程仓库。
4. **推送合并后的主分支**
```bash
git push origin lyf-dev
```
### 步骤四:清理分支(可选)
### 4. 清理工作
合并完成后,功能分支的历史使命就完成了。为了保持仓库的整洁,可以删除它。
**合并完成后,功能分支的历史使命就完成了**
1. **删除远程分支**
1. **删除远程分支**
```bash
git push origin --delete lyf-dev-req0001
git push origin --delete lyf-dev-req0002
```
2. **删除本地分支**
2. **删除本地分支**
```bash
git branch -d lyf-dev-req0001
git branch -d lyf-dev-req0002
```
遵循以上流程,可以确保的开发工作流清晰、安全且高效。
遵循以上流程,可以确保团队的开发工作流清晰、安全且高效。