lyf-dev-req0001 #2

Merged
yuanfeiliao merged 2 commits from lyf-dev-req0001 into lyf-dev 2025-07-18 13:28:05 +08:00
9 changed files with 281 additions and 618 deletions

Binary file not shown.

View File

@@ -56,7 +56,7 @@ from analysis.metrics import evaluate_model, compare_models
 # Import configuration and version management
 from core.config import (
     DEFAULT_MODEL_DIR, WEBSOCKET_NAMESPACE,
-    get_model_versions, get_latest_model_version, get_next_model_version,
+    get_model_versions,
     get_model_file_path, save_model_version_info
 )
@@ -1560,7 +1560,7 @@ def predict():
         prediction_result = run_prediction(model_type, product_id, model_id, future_days, start_date, version, store_id, training_mode)
         if prediction_result is None:
-            return jsonify({"status": "error", "error": "Prediction failed: the predictor returned None"}), 500
+            return jsonify({"status": "error", "error": "Model file not found or failed to load"}), 404

         # Attach version information to the prediction result
         prediction_result['version'] = version
@@ -3782,8 +3782,13 @@ def get_model_types():
 def get_model_versions_api(product_id, model_type):
     """API endpoint for listing model versions"""
     try:
-        versions = get_model_versions(product_id, model_type)
-        latest_version = get_latest_model_version(product_id, model_type)
+        from utils.model_manager import model_manager
+        result = model_manager.list_models(product_id=product_id, model_type=model_type)
+        models = result.get('models', [])
+        versions = sorted(list(set(m['version'] for m in models)), key=lambda v: (v != 'best', v))
+        latest_version = versions[0] if versions else None
         return jsonify({
             "status": "success",

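A quick aside on the sort key used in `get_model_versions_api` above: the tuple `(v != 'best', v)` sorts `False` before `True`, so `'best'` always lands first, while the remaining versions compare as plain strings (which means `'v10'` would order before `'v2'`). A minimal standalone sketch, with a hypothetical version list:

```python
# Hypothetical version list; mirrors the lambda in get_model_versions_api.
versions = ['v2', 'best', 'v10', 'v1']

ordered = sorted(set(versions), key=lambda v: (v != 'best', v))
print(ordered)  # ['best', 'v1', 'v10', 'v2'] -- 'best' first, the rest in string order

latest = ordered[0] if ordered else None
print(latest)   # 'best'
```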
View File

@@ -71,48 +71,6 @@ TRAINING_UPDATE_INTERVAL = 1  # training progress update interval (seconds)
 # Create the model save directory
 os.makedirs(DEFAULT_MODEL_DIR, exist_ok=True)

-def get_next_model_version(product_id: str, model_type: str) -> str:
-    """
-    Get the next version number for the given product and model type
-
-    Args:
-        product_id: product ID
-        model_type: model type
-
-    Returns:
-        The next version number, e.g. 'v2', 'v3'
-    """
-    # New format: files carrying a version number
-    pattern_new = f"{model_type}_model_product_{product_id}_v*.pth"
-    existing_files_new = glob.glob(os.path.join(DEFAULT_MODEL_DIR, pattern_new))
-
-    # Old format: files without a version number (compatibility support)
-    pattern_old = f"{model_type}_model_product_{product_id}.pth"
-    old_file_path = os.path.join(DEFAULT_MODEL_DIR, pattern_old)
-    has_old_format = os.path.exists(old_file_path)
-
-    # If no files exist in either format, return the default version
-    if not existing_files_new and not has_old_format:
-        return DEFAULT_VERSION
-
-    # Extract version numbers from new-format files
-    versions = []
-    for file_path in existing_files_new:
-        filename = os.path.basename(file_path)
-        version_match = re.search(rf"_v(\d+)\.pth$", filename)
-        if version_match:
-            versions.append(int(version_match.group(1)))
-
-    # If an old-format file exists, treat it as v1
-    if has_old_format:
-        versions.append(1)
-        print(f"Detected old-format model file: {old_file_path}, treating it as version v1")
-
-    if versions:
-        next_version_num = max(versions) + 1
-        return f"v{next_version_num}"
-    else:
-        return DEFAULT_VERSION

 def get_model_file_path(product_id: str, model_type: str, version: str) -> str:
     """
@@ -134,13 +92,8 @@ def get_model_file_path(product_id: str, model_type: str, version: str) -> str:
     # Fix: build the filename directly from the unique product_id (it may carry a store_ prefix)
     # Example filenames: transformer_17002608_epoch_best.pth or transformer_store_01010023_epoch_best.pth
     # For KAN and optimized_kan, use model_manager's naming convention
-    if model_type in ['kan', 'optimized_kan']:
-        # Format: {model_type}_product_{product_id}_{version}.pth
-        # Note: when the KAN trainer saves, product_id is already the model_identifier
-        filename = f"{model_type}_product_{product_id}_{version}.pth"
-    else:
-        # Other models use the _epoch_ convention
-        filename = f"{model_type}_{product_id}_epoch_{version}.pth"
+    # Unify the naming format across all models
+    filename = f"{model_type}_product_{product_id}_{version}.pth"
     # Fix: look in the root model directory, no longer using a checkpoints subdirectory
     return os.path.join(DEFAULT_MODEL_DIR, filename)
@@ -155,67 +108,32 @@ def get_model_versions(product_id: str, model_type: str) -> list:
     Returns:
         A list of versions, sorted by version number
     """
-    # Build the search pattern directly from the given product_id
-    # Matches e.g. "transformer_product_17002608_epoch_50.pth" or "transformer_product_17002608_epoch_best.pth"
-    # Fix: build the search pattern directly from the unique product_id (it may carry a store_ prefix)
-    # Widen the search patterns for compatibility with several naming conventions
-    patterns = [
-        f"{model_type}_{product_id}_epoch_*.pth",  # original format (e.g., transformer_123_epoch_best.pth)
-        f"{model_type}_product_{product_id}_*.pth"  # KAN/ModelManager format (e.g., kan_product_123_v1.pth)
-    ]
-    existing_files = []
-    for pattern in patterns:
-        search_path = os.path.join(DEFAULT_MODEL_DIR, pattern)
-        existing_files.extend(glob.glob(search_path))
-
-    # Old format (compatibility support)
-    pattern_old = f"{model_type}_model_product_{product_id}.pth"
-    old_file_path = os.path.join(DEFAULT_MODEL_DIR, pattern_old)
-    if os.path.exists(old_file_path):
-        existing_files.append(old_file_path)
-
-    versions = set()  # use a set to avoid duplicates
-
-    # Extract version information from the files that were found
+    # Search using the new, unified naming convention only
+    pattern = os.path.join(DEFAULT_MODEL_DIR, f"{model_type}_product_{product_id}_*.pth")
+    existing_files = glob.glob(pattern)
+
+    versions = set()
+
     for file_path in existing_files:
         filename = os.path.basename(file_path)
-        # Try the _epoch_ format
-        version_match_epoch = re.search(r"_epoch_(.+)\.pth$", filename)
-        if version_match_epoch:
-            versions.add(version_match_epoch.group(1))
-            continue
-
-        # Try the _product_..._v format (KAN)
-        version_match_kan = re.search(r"_product_.+_v(\d+)\.pth$", filename)
-        if version_match_kan:
-            versions.add(f"v{version_match_kan.group(1)}")
-            continue
-
-        # Try the old _model_product_ format
-        if pattern_old in filename:
-            versions.add("v1_legacy")  # add a special marker
-            print(f"Detected old-format model file: {old_file_path}, treating it as version v1_legacy")
-            continue
-
-    # Convert to a list and sort
-    sorted_versions = sorted(list(versions))
+        # Strictly match _v<number> or 'best'
+        match = re.search(r'_(v\d+|best)\.pth$', filename)
+        if match:
+            versions.add(match.group(1))
+
+    # Sort numeric versions in descending order, with 'best' always first
+    def sort_key(v):
+        if v == 'best':
+            return float('inf')  # with reverse=True, 'best' is always first
+        if v.startswith('v'):
+            return int(v[1:])
+        return -1  # should not happen
+
+    sorted_versions = sorted(list(versions), key=sort_key, reverse=True)
+
     return sorted_versions

-def get_latest_model_version(product_id: str, model_type: str) -> str:
-    """
-    Get the latest version for the given product and model type
-
-    Args:
-        product_id: product ID
-        model_type: model type
-
-    Returns:
-        The latest version number, or None if there is none
-    """
-    versions = get_model_versions(product_id, model_type)
-    return versions[-1] if versions else None

 def save_model_version_info(product_id: str, model_type: str, version: str, file_path: str, metrics: dict = None):
     """

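As a sanity check on the stricter matching that the rewritten `get_model_versions` relies on, here is a small sketch (all filenames are made up): after the glob narrows candidates to `{model_type}_product_{product_id}_*.pth`, the regex keeps only `_v<number>` and `_best` suffixes, so files left over from the old `_epoch_` convention are ignored rather than misread as versions.

```python
import re

# Hypothetical filenames that the glob pattern could pick up.
filenames = [
    'tcn_product_17002608_v1.pth',        # -> 'v1'
    'tcn_product_17002608_v12.pth',       # -> 'v12'
    'tcn_product_17002608_best.pth',      # -> 'best'
    'tcn_product_17002608_epoch_50.pth',  # old epoch-style suffix -> ignored
]

versions = set()
for name in filenames:
    match = re.search(r'_(v\d+|best)\.pth$', name)  # same pattern as in the diff
    if match:
        versions.add(match.group(1))

print(sorted(versions))  # ['best', 'v1', 'v12']
```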
View File

@@ -257,11 +257,11 @@ def train_product_model_with_kan(product_id, model_identifier, product_df=None,
                     model_data=best_model_data,
                     product_id=model_identifier,
                     model_type=model_type_name,
-                    version='best',
                     store_id=store_id,
                     training_mode=training_mode,
                     aggregation_method=aggregation_method,
-                    product_name=product_name
+                    product_name=product_name,
+                    version='best'  # explicitly override the version with 'best'
                 )

         if (epoch + 1) % 10 == 0:
@@ -335,15 +335,18 @@ def train_product_model_with_kan(product_id, model_identifier, product_df=None,
         'loss_curve_path': loss_curve_path
     }

-    model_path = model_manager.save_model(
+    # Save the final model and let model_manager assign the version number automatically
+    final_model_path, final_version = model_manager.save_model(
         model_data=model_data,
         product_id=model_identifier,
         model_type=model_type_name,
-        version=f'final_epoch_{epochs}',
         store_id=store_id,
         training_mode=training_mode,
         aggregation_method=aggregation_method,
         product_name=product_name
+        # Note: no version argument is passed here; the manager generates one automatically
     )
+    print(f"Final model saved, version: {final_version}, path: {final_model_path}")

     return model, metrics

View File

@@ -20,85 +20,10 @@ from utils.multi_store_data_utils import get_store_product_sales_data, aggregate_multi_store_data
 from utils.visualization import plot_loss_curve
 from analysis.metrics import evaluate_model
 from core.config import (
-    DEVICE, DEFAULT_MODEL_DIR, LOOK_BACK, FORECAST_HORIZON,
-    get_next_model_version, get_model_file_path, get_latest_model_version
+    DEVICE, DEFAULT_MODEL_DIR, LOOK_BACK, FORECAST_HORIZON
 )
 from utils.training_progress import progress_manager
+from utils.model_manager import model_manager

-def save_checkpoint(checkpoint_data: dict, epoch_or_label, model_identifier: str,
-                    model_type: str, model_dir: str, store_id=None,
-                    training_mode: str = 'product', aggregation_method=None):
-    """
-    Save a training checkpoint
-
-    Args:
-        checkpoint_data: checkpoint data
-        epoch_or_label: epoch number or a label such as 'best'
-        product_id: product ID
-        model_type: model type
-        model_dir: model save directory
-        store_id: store ID
-        training_mode: training mode
-        aggregation_method: aggregation method
-    """
-    # Create the checkpoint directory
-    # Save directly in the model root directory; subdirectories are no longer created
-    checkpoint_dir = model_dir
-    os.makedirs(checkpoint_dir, exist_ok=True)
-
-    # Fix: use product_id directly as the unique identifier, since it already carries the store_ prefix or the drug ID
-    filename = f"{model_type}_{model_identifier}_epoch_{epoch_or_label}.pth"
-    checkpoint_path = os.path.join(checkpoint_dir, filename)
-
-    # Save the checkpoint
-    torch.save(checkpoint_data, checkpoint_path)
-    print(f"[mLSTM] Checkpoint saved: {checkpoint_path}", flush=True)
-
-    return checkpoint_path
-
-def load_checkpoint(product_id: str, model_type: str, epoch_or_label,
-                    model_dir: str, store_id=None, training_mode: str = 'product',
-                    aggregation_method=None):
-    """
-    Load a training checkpoint
-
-    Args:
-        product_id: product ID
-        model_type: model type
-        epoch_or_label: epoch number or label
-        model_dir: model save directory
-        store_id: store ID
-        training_mode: training mode
-        aggregation_method: aggregation method
-
-    Returns:
-        checkpoint_data: the checkpoint data, or None if not found
-    """
-    checkpoint_dir = os.path.join(model_dir, 'checkpoints')
-
-    # Build the checkpoint filename
-    if training_mode == 'store' and store_id:
-        filename = f"{model_type}_store_{store_id}_{product_id}_epoch_{epoch_or_label}.pth"
-    elif training_mode == 'global' and aggregation_method:
-        filename = f"{model_type}_global_{product_id}_{aggregation_method}_epoch_{epoch_or_label}.pth"
-    else:
-        filename = f"{model_type}_product_{product_id}_epoch_{epoch_or_label}.pth"
-
-    checkpoint_path = os.path.join(checkpoint_dir, filename)
-
-    if os.path.exists(checkpoint_path):
-        try:
-            checkpoint_data = torch.load(checkpoint_path, map_location=DEVICE)
-            print(f"[mLSTM] Checkpoint loaded: {checkpoint_path}", flush=True)
-            return checkpoint_data
-        except Exception as e:
-            print(f"[mLSTM] Failed to load checkpoint: {e}", flush=True)
-            return None
-    else:
-        print(f"[mLSTM] Checkpoint file does not exist: {checkpoint_path}", flush=True)
-        return None

 def train_product_model_with_mlstm(
     product_id,
@@ -173,15 +98,9 @@ def train_product_model_with_mlstm(
     emit_progress("Starting mLSTM model training...")

     # Determine the version number
-    if version is None:
-        if continue_training:
-            version = get_latest_model_version(product_id, 'mlstm')
-            if version is None:
-                version = get_next_model_version(product_id, 'mlstm')
-        else:
-            version = get_next_model_version(product_id, 'mlstm')
-
-    emit_progress(f"Starting training of mLSTM model, version {version}")
+    emit_progress(f"Starting training of the mLSTM model")
+    if version:
+        emit_progress(f"Using the specified version: {version}")

     # Initialize the training progress manager (if it has not been initialized yet)
     if socketio and task_id:
@@ -234,7 +153,7 @@ def train_product_model_with_mlstm(
     print(f"[mLSTM] Training a sales prediction model for product '{product_name}' (ID: {product_id}) with mLSTM", flush=True)
     print(f"[mLSTM] Training scope: {training_scope}", flush=True)
-    print(f"[mLSTM] Version: {version}", flush=True)
+    # print(f"[mLSTM] Version: {version}", flush=True)  # Version is now handled by model_manager
     print(f"[mLSTM] Using device: {DEVICE}", flush=True)
     print(f"[mLSTM] Models will be saved to directory: {model_dir}", flush=True)
     print(f"[mLSTM] Data volume: {len(product_df)} records", flush=True)
@@ -323,16 +242,8 @@ def train_product_model_with_mlstm(
     # If continuing training, load the existing model
     if continue_training and version != 'v1':
-        try:
-            existing_model_path = get_model_file_path(product_id, 'mlstm', version)
-            if os.path.exists(existing_model_path):
-                checkpoint = torch.load(existing_model_path, map_location=DEVICE)
-                model.load_state_dict(checkpoint['model_state_dict'])
-                print(f"Loaded existing model: {existing_model_path}")
-                emit_progress(f"Loaded existing model version {version} to continue training")
-        except Exception as e:
-            print(f"Could not load the existing model, training will restart from scratch: {e}")
-            emit_progress("Could not load the existing model, restarting training")
+        # TODO: Implement continue_training logic with the new model_manager
+        pass

     # Move the model to the device
     model = model.to(DEVICE)
@@ -451,22 +362,24 @@ def train_product_model_with_mlstm(
                 }
             }

-            # Save the checkpoint
-            save_checkpoint(checkpoint_data, epoch + 1, model_identifier, 'mlstm',
-                            model_dir, store_id, training_mode, aggregation_method)
-
             # If this is the best model so far, save an extra copy
             if test_loss < best_loss:
                 best_loss = test_loss
-                save_checkpoint(checkpoint_data, 'best', model_identifier, 'mlstm',
-                                model_dir, store_id, training_mode, aggregation_method)
+                model_manager.save_model(
+                    model_data=checkpoint_data,
+                    product_id=model_identifier,
+                    model_type='mlstm',
+                    store_id=store_id,
+                    training_mode=training_mode,
+                    aggregation_method=aggregation_method,
+                    product_name=product_name,
+                    version='best'
+                )
                 emit_progress(f"💾 Saved best model checkpoint (epoch {epoch+1}, test_loss: {test_loss:.4f})")
                 epochs_no_improve = 0
             else:
                 epochs_no_improve += 1
-
-            emit_progress(f"💾 Saved training checkpoint epoch_{epoch+1}")

         if (epoch + 1) % 10 == 0:
             print(f"Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}", flush=True)
@@ -524,7 +437,6 @@ def train_product_model_with_mlstm(
     # Compute evaluation metrics
     metrics = evaluate_model(test_true_inv, test_pred_inv)
     metrics['training_time'] = training_time
-    metrics['version'] = version

     # Print evaluation metrics
     print("\nModel evaluation metrics:")
@@ -576,10 +488,15 @@ def train_product_model_with_mlstm(
         }
     }

-    # Save the final model, identified by its final epoch
-    final_model_path = save_checkpoint(
-        final_model_data, f"final_epoch_{epochs}", model_identifier, 'mlstm',
-        model_dir, store_id, training_mode, aggregation_method
+    # Save the final model and let model_manager assign the version number automatically
+    final_model_path, final_version = model_manager.save_model(
+        model_data=final_model_data,
+        product_id=model_identifier,
+        model_type='mlstm',
+        store_id=store_id,
+        training_mode=training_mode,
+        aggregation_method=aggregation_method,
+        product_name=product_name
     )

     # Send the training-complete message
@@ -591,9 +508,10 @@ def train_product_model_with_mlstm(
         'mape': metrics['mape'],
         'training_time': training_time,
         'final_epoch': epochs,
-        'model_path': final_model_path
+        'model_path': final_model_path,
+        'version': final_version
     }

-    emit_progress(f"✅ mLSTM model training complete! Final epoch: {epochs}, model saved", progress=100, metrics=final_metrics)
+    emit_progress(f"✅ mLSTM model training complete! Version {final_version} saved", progress=100, metrics=final_metrics)
     return model, metrics, epochs, final_model_path

View File

@@ -20,39 +20,7 @@ from utils.visualization import plot_loss_curve
 from analysis.metrics import evaluate_model
 from core.config import DEVICE, DEFAULT_MODEL_DIR, LOOK_BACK, FORECAST_HORIZON
 from utils.training_progress import progress_manager
+from utils.model_manager import model_manager

-def save_checkpoint(checkpoint_data: dict, epoch_or_label, model_identifier: str,
-                    model_type: str, model_dir: str, store_id=None,
-                    training_mode: str = 'product', aggregation_method=None):
-    """
-    Save a training checkpoint
-
-    Args:
-        checkpoint_data: checkpoint data
-        epoch_or_label: epoch number or a label such as 'best'
-        product_id: product ID
-        model_type: model type
-        model_dir: model save directory
-        store_id: store ID
-        training_mode: training mode
-        aggregation_method: aggregation method
-    """
-    # Save directly in the model root directory; subdirectories are no longer created
-    checkpoint_dir = model_dir
-    os.makedirs(checkpoint_dir, exist_ok=True)
-
-    # Build the checkpoint filename
-    # Fix: use product_id directly as the unique identifier, since it already carries the store_ prefix or the drug ID
-    filename = f"{model_type}_{model_identifier}_epoch_{epoch_or_label}.pth"
-    checkpoint_path = os.path.join(checkpoint_dir, filename)
-
-    # Save the checkpoint
-    torch.save(checkpoint_data, checkpoint_path)
-    print(f"[TCN] Checkpoint saved: {checkpoint_path}", flush=True)
-
-    return checkpoint_path

 def train_product_model_with_tcn(
     product_id,
@@ -72,21 +40,6 @@ def train_product_model_with_tcn(
 ):
     """
     Train a product sales prediction model with a TCN model
-
-    Parameters:
-        product_id: product ID
-        epochs: number of training epochs
-        model_dir: model save directory, defaults to DEFAULT_MODEL_DIR from the configuration
-        version: specified version number; if None, one is generated automatically
-        socketio: WebSocket object for real-time feedback
-        task_id: training task ID
-        continue_training: whether to continue training an existing model
-
-    Returns:
-        model: the trained model
-        metrics: model evaluation metrics
-        version: the version number actually used
-        model_path: path to the model file
     """

     def emit_progress(message, progress=None, metrics=None):
@@ -103,62 +56,21 @@ def train_product_model_with_tcn(
             data['metrics'] = metrics
         socketio.emit('training_progress', data, namespace='/training')

-    # Determine the version number
-    if version is None:
-        from core.config import get_latest_model_version, get_next_model_version
-        if continue_training:
-            version = get_latest_model_version(product_id, 'tcn')
-            if version is None:
-                version = get_next_model_version(product_id, 'tcn')
-        else:
-            version = get_next_model_version(product_id, 'tcn')
-
-    emit_progress(f"Starting training of TCN model, version {version}")
+    emit_progress(f"Starting training of the TCN model")

-    # If no product_df was passed in, load data according to the training mode
     if product_df is None:
-        from utils.multi_store_data_utils import load_multi_store_data, get_store_product_sales_data, aggregate_multi_store_data
-
-        try:
-            if training_mode == 'store' and store_id:
-                # Load data for the specific store
-                product_df = get_store_product_sales_data(
-                    store_id,
-                    product_id,
-                    'pharmacy_sales_multi_store.csv'
-                )
-                training_scope = f"store {store_id}"
-            elif training_mode == 'global':
-                # Aggregate data across all stores
-                product_df = aggregate_multi_store_data(
-                    product_id,
-                    aggregation_method=aggregation_method,
-                    file_path='pharmacy_sales_multi_store.csv'
-                )
-                training_scope = f"global aggregation ({aggregation_method})"
-            else:
-                # Default: load the product's data from all stores
-                product_df = load_multi_store_data('pharmacy_sales_multi_store.csv', product_id=product_id)
-                training_scope = "all stores"
-        except Exception as e:
-            print(f"Multi-store data loading failed: {e}")
-            # Fallback: try the original data
-            df = pd.read_excel('pharmacy_sales.xlsx')
-            product_df = df[df['product_id'] == product_id].sort_values('date')
-            training_scope = "original data"
+        from utils.multi_store_data_utils import aggregate_multi_store_data
+        product_df = aggregate_multi_store_data(
+            product_id=product_id,
+            aggregation_method=aggregation_method
+        )
+        training_scope = f"global aggregation ({aggregation_method})"
     else:
-        # If product_df was passed in, use it directly
-        if training_mode == 'store' and store_id:
-            training_scope = f"store {store_id}"
-        elif training_mode == 'global':
-            training_scope = f"global aggregation ({aggregation_method})"
-        else:
-            training_scope = "all stores"
+        training_scope = "all stores"

     if product_df.empty:
         raise ValueError(f"No sales data available for product {product_id}")

-    # Data volume check
     min_required_samples = sequence_length + forecast_horizon
     if len(product_df) < min_required_samples:
         error_msg = (
@@ -166,10 +78,6 @@ def train_product_model_with_tcn(
             f"Current configuration requires: {min_required_samples} days of data (LOOK_BACK={sequence_length} + FORECAST_HORIZON={forecast_horizon})\n"
             f"Actual data volume: {len(product_df)} days\n"
             f"Product ID: {product_id}, training mode: {training_mode}\n"
-            f"Suggested solutions:\n"
-            f"1. Generate more data: uv run generate_multi_store_data.py\n"
-            f"2. Adjust the configuration: reduce LOOK_BACK or FORECAST_HORIZON\n"
-            f"3. Use the global training mode to aggregate more data"
         )
         print(error_msg)
         emit_progress(f"Training failed: insufficient data ({len(product_df)}/{min_required_samples} days)")
@@ -180,48 +88,39 @@ def train_product_model_with_tcn(

     print(f"Training a sales prediction model for product '{product_name}' (ID: {product_id}) with a TCN model")
     print(f"Training scope: {training_scope}")
-    print(f"Version: {version}")
     print(f"Using device: {DEVICE}")
     print(f"Models will be saved to directory: {model_dir}")

     emit_progress(f"Training product: {product_name} (ID: {product_id})")

-    # Create features and the target variable
     features = ['sales', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']

-    # Preprocess the data
     X = product_df[features].values
-    y = product_df[['sales']].values  # keep as a 2D array
+    y = product_df[['sales']].values

-    # Set the data-preprocessing stage
     progress_manager.set_stage("data_preprocessing", 0)
     emit_progress("Preprocessing data...")

-    # Normalize the data
     scaler_X = MinMaxScaler(feature_range=(0, 1))
     scaler_y = MinMaxScaler(feature_range=(0, 1))

     X_scaled = scaler_X.fit_transform(X)
     y_scaled = scaler_y.fit_transform(y)

-    # Split into training and test sets (80% training, 20% test)
     train_size = int(len(X_scaled) * 0.8)
     X_train, X_test = X_scaled[:train_size], X_scaled[train_size:]
     y_train, y_test = y_scaled[:train_size], y_scaled[train_size:]

     progress_manager.set_stage("data_preprocessing", 50)

-    # Create the time-series data
     trainX, trainY = create_dataset(X_train, y_train, sequence_length, forecast_horizon)
     testX, testY = create_dataset(X_test, y_test, sequence_length, forecast_horizon)

-    # Convert to PyTorch tensors
     trainX_tensor = torch.Tensor(trainX)
     trainY_tensor = torch.Tensor(trainY)
     testX_tensor = torch.Tensor(testX)
     testY_tensor = torch.Tensor(testY)

-    # Create the data loaders
     train_dataset = PharmacyDataset(trainX_tensor, trainY_tensor)
     test_dataset = PharmacyDataset(testX_tensor, testY_tensor)
@@ -229,7 +128,6 @@ def train_product_model_with_tcn(
     train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
     test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

-    # Update the progress manager's batch information
     total_batches = len(train_loader)
     total_samples = len(train_dataset)
     progress_manager.total_batches_per_epoch = total_batches
@@ -238,7 +136,6 @@ def train_product_model_with_tcn(
     progress_manager.set_stage("data_preprocessing", 100)

-    # Initialize the TCN model
     input_dim = X_train.shape[1]
     output_dim = forecast_horizon
     hidden_size = 64
@@ -254,21 +151,8 @@ def train_product_model_with_tcn(
         dropout=dropout_rate
     )

-    # If continuing training, load the existing model
-    if continue_training and version != 'v1':
-        try:
-            from core.config import get_model_file_path
-            existing_model_path = get_model_file_path(product_id, 'tcn', version)
-            if os.path.exists(existing_model_path):
-                checkpoint = torch.load(existing_model_path, map_location=DEVICE)
-                model.load_state_dict(checkpoint['model_state_dict'])
-                print(f"Loaded existing model: {existing_model_path}")
-                emit_progress(f"Loaded existing model version {version} to continue training")
-        except Exception as e:
-            print(f"Could not load the existing model, training will restart from scratch: {e}")
-            emit_progress("Could not load the existing model, restarting training")
-
-    # Move the model to the device
+    # TODO: Implement continue_training logic with the new model_manager
+
     model = model.to(DEVICE)

     criterion = nn.MSELoss()
@@ -276,20 +160,17 @@ def train_product_model_with_tcn(

     emit_progress("Starting model training...")

-    # Train the model
     train_losses = []
     test_losses = []
     start_time = time.time()

-    # Configure checkpoint saving
-    checkpoint_interval = max(1, epochs // 10)  # save every 10% of progress, at least once per epoch
+    checkpoint_interval = max(1, epochs // 10)
     best_loss = float('inf')

     progress_manager.set_stage("model_training", 0)
     emit_progress(f"Starting training - total epochs: {epochs}, checkpoint interval: {checkpoint_interval}")

     for epoch in range(epochs):
-        # Start a new epoch
         progress_manager.start_epoch(epoch)

         model.train()
@@ -298,43 +179,34 @@ def train_product_model_with_tcn(
         for batch_idx, (X_batch, y_batch) in enumerate(train_loader):
             X_batch, y_batch = X_batch.to(DEVICE), y_batch.to(DEVICE)

-            # Make sure the target tensor has the right shape (batch_size, forecast_horizon, 1)
             if y_batch.dim() == 2:
                 y_batch = y_batch.unsqueeze(-1)

-            # Forward pass
             outputs = model(X_batch)

-            # Make sure the output and target shapes match
             loss = criterion(outputs, y_batch)

-            # Backward pass and optimization
             optimizer.zero_grad()
             loss.backward()
             optimizer.step()

             epoch_loss += loss.item()

-            # Update batch progress (every 10 batches)
             if batch_idx % 10 == 0 or batch_idx == len(train_loader) - 1:
                 current_lr = optimizer.param_groups[0]['lr']
                 progress_manager.update_batch(batch_idx, loss.item(), current_lr)

-        # Compute the training loss
         train_loss = epoch_loss / len(train_loader)
         train_losses.append(train_loss)

-        # Set the validation stage
         progress_manager.set_stage("validation", 0)

-        # Evaluate on the test set
         model.eval()
         test_loss = 0

         with torch.no_grad():
             for batch_idx, (X_batch, y_batch) in enumerate(test_loader):
                 X_batch, y_batch = X_batch.to(DEVICE), y_batch.to(DEVICE)

-                # Make sure the target tensor has the right shape
                 if y_batch.dim() == 2:
                     y_batch = y_batch.unsqueeze(-1)
@@ -342,7 +214,6 @@ def train_product_model_with_tcn(
                 loss = criterion(outputs, y_batch)
                 test_loss += loss.item()

-                # Update validation progress
                 if batch_idx % 5 == 0 or batch_idx == len(test_loader) - 1:
                     val_progress = (batch_idx / len(test_loader)) * 100
                     progress_manager.set_stage("validation", val_progress)
@@ -350,10 +221,8 @@ def train_product_model_with_tcn(
         test_loss = test_loss / len(test_loader)
         test_losses.append(test_loss)

-        # Finish the current epoch
         progress_manager.finish_epoch(train_loss, test_loss)

-        # Send training progress (kept for compatibility with the old system)
         if (epoch + 1) % 5 == 0 or epoch == epochs - 1:
             progress = ((epoch + 1) / epochs) * 100
             current_metrics = {
@@ -365,7 +234,6 @@ def train_product_model_with_tcn(
             emit_progress(f"Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}",
                           progress=progress, metrics=current_metrics)

-        # Save checkpoints periodically
         if (epoch + 1) % checkpoint_interval == 0 or epoch == epochs - 1:
             checkpoint_data = {
                 'epoch': epoch + 1,
@@ -399,30 +267,28 @@ def train_product_model_with_tcn(
                 }
             }

-            # Save the checkpoint
-            save_checkpoint(checkpoint_data, epoch + 1, model_identifier, 'tcn',
-                            model_dir, store_id, training_mode, aggregation_method)
-
-            # If this is the best model so far, save an extra copy
             if test_loss < best_loss:
                 best_loss = test_loss
-                save_checkpoint(checkpoint_data, 'best', model_identifier, 'tcn',
-                                model_dir, store_id, training_mode, aggregation_method)
+                model_manager.save_model(
+                    model_data=checkpoint_data,
+                    product_id=model_identifier,
+                    model_type='tcn',
+                    store_id=store_id,
+                    training_mode=training_mode,
+                    aggregation_method=aggregation_method,
+                    product_name=product_name,
+                    version='best'
+                )
                 emit_progress(f"💾 Saved best model checkpoint (epoch {epoch+1}, test_loss: {test_loss:.4f})")

-            emit_progress(f"💾 Saved training checkpoint epoch_{epoch+1}")

         if (epoch + 1) % 10 == 0:
             print(f"Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}")

-    # Compute the training time
     training_time = time.time() - start_time

-    # Set the model-saving stage
     progress_manager.set_stage("model_saving", 0)
     emit_progress("Training complete, saving the model...")

-    # Plot the loss curve and save it to the model directory
     loss_curve_path = plot_loss_curve(
         train_losses,
         test_losses,
@@ -432,23 +298,17 @@ def train_product_model_with_tcn(
     )
     print(f"Loss curve saved to: {loss_curve_path}")

-    # Evaluate the model
     model.eval()
     with torch.no_grad():
-        # Make sure the test data has the right shape
         test_pred = model(testX_tensor.to(DEVICE))

-        # Convert the output to a 2D array [samples, forecast_horizon]
         test_pred = test_pred.squeeze(-1).cpu().numpy()

-    # Denormalize the predictions and the ground truth
     test_pred_inv = scaler_y.inverse_transform(test_pred.reshape(-1, 1)).flatten()
     test_true_inv = scaler_y.inverse_transform(testY.reshape(-1, 1)).flatten()

-    # Compute evaluation metrics
     metrics = evaluate_model(test_true_inv, test_pred_inv)
     metrics['training_time'] = training_time

-    # Print evaluation metrics
     print("\nModel evaluation metrics:")
     print(f"MSE: {metrics['mse']:.4f}")
     print(f"RMSE: {metrics['rmse']:.4f}")
@@ -457,9 +317,8 @@ def train_product_model_with_tcn(
     print(f"MAPE: {metrics['mape']:.2f}%")
     print(f"Training time: {training_time:.2f} seconds")

-    # Save the final trained model (based on the final epoch)
     final_model_data = {
-        'epoch': epochs,  # final epoch
+        'epoch': epochs,
         'model_state_dict': model.state_dict(),
         'optimizer_state_dict': optimizer.state_dict(),
         'train_loss': train_losses[-1],
@@ -495,10 +354,14 @@ def train_product_model_with_tcn(

     progress_manager.set_stage("model_saving", 50)

-    # Save the final model, identified by its final epoch
-    final_model_path = save_checkpoint(
-        final_model_data, f"final_epoch_{epochs}", model_identifier, 'tcn',
-        model_dir, store_id, training_mode, aggregation_method
+    final_model_path, final_version = model_manager.save_model(
+        model_data=final_model_data,
+        product_id=model_identifier,
+        model_type='tcn',
+        store_id=store_id,
+        training_mode=training_mode,
+        aggregation_method=aggregation_method,
+        product_name=product_name
     )

     progress_manager.set_stage("model_saving", 100)
@@ -510,9 +373,10 @@ def train_product_model_with_tcn(
         'r2': metrics['r2'],
         'mape': metrics['mape'],
         'training_time': training_time,
-        'final_epoch': epochs
+        'final_epoch': epochs,
+        'version': final_version
     }

-    emit_progress(f"Model training complete! Final epoch: {epochs}", progress=100, metrics=final_metrics)
+    emit_progress(f"Model training complete! Version {final_version} saved", progress=100, metrics=final_metrics)
     return model, metrics, epochs, final_model_path

View File

@@ -21,43 +21,11 @@ from utils.multi_store_data_utils import get_store_product_sales_data, aggregate_multi_store_data
 from utils.visualization import plot_loss_curve
 from analysis.metrics import evaluate_model
 from core.config import (
-    DEVICE, DEFAULT_MODEL_DIR, LOOK_BACK, FORECAST_HORIZON,
-    get_next_model_version, get_model_file_path, get_latest_model_version
+    DEVICE, DEFAULT_MODEL_DIR, LOOK_BACK, FORECAST_HORIZON
 )
 from utils.training_progress import progress_manager
 from utils.model_manager import model_manager

-def save_checkpoint(checkpoint_data: dict, epoch_or_label, model_identifier: str,
-                    model_type: str, model_dir: str, store_id=None,
-                    training_mode: str = 'product', aggregation_method=None):
-    """
-    Save a training checkpoint
-
-    Args:
-        checkpoint_data: checkpoint data
-        epoch_or_label: epoch number or a label such as 'best'
-        product_id: product ID
-        model_type: model type
-        model_dir: model save directory
-        store_id: store ID
-        training_mode: training mode
-        aggregation_method: aggregation method
-    """
-    # Save directly in the model root directory; subdirectories are no longer created
-    checkpoint_dir = model_dir
-    os.makedirs(checkpoint_dir, exist_ok=True)
-
-    # Fix: use product_id directly as the unique identifier, since it already carries the store_ prefix or the drug ID
-    filename = f"{model_type}_{model_identifier}_epoch_{epoch_or_label}.pth"
-    checkpoint_path = os.path.join(checkpoint_dir, filename)
-
-    # Save the checkpoint
-    torch.save(checkpoint_data, checkpoint_path)
-    print(f"[Transformer] Checkpoint saved: {checkpoint_path}", flush=True)
-
-    return checkpoint_path

 def train_product_model_with_transformer(
     product_id,
     model_identifier,
@@ -79,23 +47,8 @@ def train_product_model_with_transformer(
 ):
     """
     Train a product sales prediction model with a Transformer model
-
-    Parameters:
-        product_id: product ID
-        epochs: number of training epochs
-        model_dir: model save directory, defaults to DEFAULT_MODEL_DIR from the configuration
-        version: specified version number; if None, one is generated automatically
-        socketio: WebSocket object for real-time feedback
-        task_id: training task ID
-        continue_training: whether to continue training an existing model
-
-    Returns:
-        model: the trained model
-        metrics: model evaluation metrics
-        version: the version number actually used
     """

-    # WebSocket progress feedback function
     def emit_progress(message, progress=None, metrics=None):
         """Send training progress to the front end"""
         if socketio and task_id:
@@ -110,18 +63,15 @@ def train_product_model_with_transformer(
             data['metrics'] = metrics
         socketio.emit('training_progress', data, namespace='/training')
         print(f"[{time.strftime('%H:%M:%S')}] {message}", flush=True)
-        # Force-flush the output buffers
         import sys
         sys.stdout.flush()
         sys.stderr.flush()

     emit_progress("Starting Transformer model training...")

-    # Get the training progress manager instance
     try:
         from utils.training_progress import progress_manager
     except ImportError:
-        # If it cannot be imported, create a no-op manager to avoid errors
         class DummyProgressManager:
             def set_stage(self, *args, **kwargs): pass
             def start_training(self, *args, **kwargs): pass
@@ -131,50 +81,19 @@ def train_product_model_with_transformer(
             def finish_training(self, *args, **kwargs): pass
         progress_manager = DummyProgressManager()

-    # If no product_df was passed in, load data according to the training mode
     if product_df is None:
-        from utils.multi_store_data_utils import load_multi_store_data, get_store_product_sales_data, aggregate_multi_store_data
-
-        try:
-            if training_mode == 'store' and store_id:
-                # Load data for the specific store
-                product_df = get_store_product_sales_data(
-                    store_id,
-                    product_id,
-                    'pharmacy_sales_multi_store.csv'
-                )
-                training_scope = f"store {store_id}"
-            elif training_mode == 'global':
-                # Aggregate data across all stores
-                product_df = aggregate_multi_store_data(
-                    product_id,
-                    aggregation_method=aggregation_method,
-                    file_path='pharmacy_sales_multi_store.csv'
-                )
-                training_scope = f"global aggregation ({aggregation_method})"
-            else:
-                # Default: load the product's data from all stores
-                product_df = load_multi_store_data('pharmacy_sales_multi_store.csv', product_id=product_id)
-                training_scope = "all stores"
-        except Exception as e:
-            print(f"Multi-store data loading failed: {e}")
-            # Fallback: try the original data
-            df = pd.read_excel('pharmacy_sales.xlsx')
-            product_df = df[df['product_id'] == product_id].sort_values('date')
-            training_scope = "original data"
+        from utils.multi_store_data_utils import aggregate_multi_store_data
+        product_df = aggregate_multi_store_data(
+            product_id=product_id,
+            aggregation_method=aggregation_method
+        )
+        training_scope = f"global aggregation ({aggregation_method})"
     else:
-        # If product_df was passed in, use it directly
-        if training_mode == 'store' and store_id:
-            training_scope = f"store {store_id}"
-        elif training_mode == 'global':
-            training_scope = f"global aggregation ({aggregation_method})"
-        else:
-            training_scope = "all stores"
+        training_scope = "all stores"

     if product_df.empty:
         raise ValueError(f"No sales data available for product {product_id}")

-    # Data volume check
     min_required_samples = sequence_length + forecast_horizon
     if len(product_df) < min_required_samples:
         error_msg = (
@@ -182,10 +101,6 @@ def train_product_model_with_transformer(
             f"Current configuration requires: {min_required_samples} days of data (LOOK_BACK={sequence_length} + FORECAST_HORIZON={forecast_horizon})\n"
             f"Actual data volume: {len(product_df)} days\n"
             f"Product ID: {product_id}, training mode: {training_mode}\n"
-            f"Suggested solutions:\n"
-            f"1. Generate more data: uv run generate_multi_store_data.py\n"
-            f"2. Adjust the configuration: reduce LOOK_BACK or FORECAST_HORIZON\n"
-            f"3. Use the global training mode to aggregate more data"
         )
         print(error_msg)
         raise ValueError(error_msg)
@@ -197,18 +112,14 @@ def train_product_model_with_transformer(
     print(f"[Device] Using device: {DEVICE}", flush=True)
     print(f"[Model] Models will be saved to directory: {model_dir}", flush=True)

-    # Create features and the target variable
     features = ['sales', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']

-    # Set the data-preprocessing stage
     progress_manager.set_stage("data_preprocessing", 0)
     emit_progress("Preprocessing data...")

-    # Preprocess the data
     X = product_df[features].values
-    y = product_df[['sales']].values  # keep as a 2D array
+    y = product_df[['sales']].values

-    # Normalize the data
     scaler_X = MinMaxScaler(feature_range=(0, 1))
     scaler_y = MinMaxScaler(feature_range=(0, 1))
@@ -217,24 +128,20 @@ def train_product_model_with_transformer(

     progress_manager.set_stage("data_preprocessing", 40)

-    # Split into training and test sets (80% training, 20% test)
     train_size = int(len(X_scaled) * 0.8)
     X_train, X_test = X_scaled[:train_size], X_scaled[train_size:]
     y_train, y_test = y_scaled[:train_size], y_scaled[train_size:]

-    # Create the time-series data
     trainX, trainY = create_dataset(X_train, y_train, sequence_length, forecast_horizon)
     testX, testY = create_dataset(X_test, y_test, sequence_length, forecast_horizon)

     progress_manager.set_stage("data_preprocessing", 70)

-    # Convert to PyTorch tensors
     trainX_tensor = torch.Tensor(trainX)
     trainY_tensor = torch.Tensor(trainY)
     testX_tensor = torch.Tensor(testX)
     testY_tensor = torch.Tensor(testY)

-    # Create the data loaders
     train_dataset = PharmacyDataset(trainX_tensor, trainY_tensor)
     test_dataset = PharmacyDataset(testX_tensor, testY_tensor)
@@ -242,7 +149,6 @@ def train_product_model_with_transformer(
     train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
     test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

-    # Update the progress manager's batch information
     total_batches = len(train_loader)
     total_samples = len(train_dataset)
     progress_manager.total_batches_per_epoch = total_batches
@@ -252,7 +158,6 @@ def train_product_model_with_transformer(
     progress_manager.set_stage("data_preprocessing", 100)
     emit_progress("Data preprocessing complete, starting model training...")

-    # Initialize the Transformer model
     input_dim = X_train.shape[1]
     output_dim = forecast_horizon
     hidden_size = 64
@@ -272,20 +177,17 @@ def train_product_model_with_transformer(
         batch_size=batch_size
     )

-    # Move the model to the device
     model = model.to(DEVICE)

     criterion = nn.MSELoss()
     optimizer = optim.Adam(model.parameters(), lr=learning_rate)
     scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=patience // 2, factor=0.5)

-    # Train the model
     train_losses = []
     test_losses = []
     start_time = time.time()

-    # Configure checkpoint saving
-    checkpoint_interval = max(1, epochs // 10)  # save every 10% of progress, at least once per epoch
+    checkpoint_interval = max(1, epochs // 10)
     best_loss = float('inf')
     epochs_no_improve = 0
@@ -293,7 +195,6 @@ def train_product_model_with_transformer(
     emit_progress(f"Starting training - total epochs: {epochs}, checkpoint interval: {checkpoint_interval}, patience: {patience}")

     for epoch in range(epochs):
-        # Start a new epoch
         progress_manager.start_epoch(epoch)

         model.train()
@@ -302,12 +203,9 @@ def train_product_model_with_transformer(
         for batch_idx, (X_batch, y_batch) in enumerate(train_loader):
             X_batch, y_batch = X_batch.to(DEVICE), y_batch.to(DEVICE)

-            # Make sure the target tensor has the right shape
-            # Forward pass
             outputs = model(X_batch)
             loss = criterion(outputs, y_batch)

-            # Backward pass and optimization
             optimizer.zero_grad()
             loss.backward()
             if clip_norm:
@@ -316,31 +214,25 @@ def train_product_model_with_transformer(
             epoch_loss += loss.item()

-            # Update batch progress
             if batch_idx % 5 == 0 or batch_idx == len(train_loader) - 1:
                 current_lr = optimizer.param_groups[0]['lr']
                 progress_manager.update_batch(batch_idx, loss.item(), current_lr)

-        # Compute the training loss
         train_loss = epoch_loss / len(train_loader)
         train_losses.append(train_loss)

-        # Set the validation stage
         progress_manager.set_stage("validation", 0)

-        # Evaluate on the test set
         model.eval()
         test_loss = 0

         with torch.no_grad():
             for batch_idx, (X_batch, y_batch) in enumerate(test_loader):
                 X_batch, y_batch = X_batch.to(DEVICE), y_batch.to(DEVICE)

-                # Make sure the target tensor has the right shape
                 outputs = model(X_batch)
                 loss = criterion(outputs, y_batch)
                 test_loss += loss.item()

-                # Update validation progress
                 if batch_idx % 3 == 0 or batch_idx == len(test_loader) - 1:
                     val_progress = (batch_idx / len(test_loader)) * 100
                     progress_manager.set_stage("validation", val_progress)
@@ -348,13 +240,10 @@ def train_product_model_with_transformer(
         test_loss = test_loss / len(test_loader)
         test_losses.append(test_loss)

-        # Update the learning rate
         scheduler.step(test_loss)

-        # Finish the current epoch
         progress_manager.finish_epoch(train_loss, test_loss)

-        # Send training progress
         if (epoch + 1) % 5 == 0 or epoch == epochs - 1:
             progress = ((epoch + 1) / epochs) * 100
             current_metrics = {
@@ -366,7 +255,6 @@ def train_product_model_with_transformer(
             emit_progress(f"Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}",
                           progress=progress, metrics=current_metrics)

-        # Save checkpoints periodically
         if (epoch + 1) % checkpoint_interval == 0 or epoch == epochs - 1:
             checkpoint_data = {
                 'epoch': epoch + 1,
@@ -399,38 +287,35 @@ def train_product_model_with_transformer(
                 }
             }

-            # Save the checkpoint
-            save_checkpoint(checkpoint_data, epoch + 1, model_identifier, 'transformer',
-                            model_dir, store_id, training_mode, aggregation_method)
-
-            # If this is the best model so far, save an extra copy
             if test_loss < best_loss:
                 best_loss = test_loss
-                save_checkpoint(checkpoint_data, 'best', model_identifier, 'transformer',
-                                model_dir, store_id, training_mode, aggregation_method)
+                model_manager.save_model(
+                    model_data=checkpoint_data,
+                    product_id=model_identifier,
+                    model_type='transformer',
+                    store_id=store_id,
+                    training_mode=training_mode,
+                    aggregation_method=aggregation_method,
+                    product_name=product_name,
+                    version='best'
+                )
                 emit_progress(f"💾 Saved best model checkpoint (epoch {epoch+1}, test_loss: {test_loss:.4f})")
                 epochs_no_improve = 0
             else:
                 epochs_no_improve += 1

-            emit_progress(f"💾 Saved training checkpoint epoch_{epoch+1}")

         if (epoch + 1) % 10 == 0:
             print(f"📊 Epoch {epoch+1}/{epochs}, train loss: {train_loss:.4f}, test loss: {test_loss:.4f}", flush=True)

-        # Early stopping logic
         if epochs_no_improve >= patience:
             emit_progress(f"Test loss has not improved for {patience} consecutive epochs; stopping training early.")
             break

-    # Compute the training time
     training_time = time.time() - start_time

-    # Set the model-saving stage
     progress_manager.set_stage("model_saving", 0)
     emit_progress("Training complete, saving the model...")

-    # Plot the loss curve and save it to the model directory
     loss_curve_path = plot_loss_curve(
         train_losses,
         test_losses,
@@ -440,21 +325,17 @@ def train_product_model_with_transformer(
     )
     print(f"📈 Loss curve saved to: {loss_curve_path}", flush=True)

-    # Evaluate the model
     model.eval()
     with torch.no_grad():
         test_pred = model(testX_tensor.to(DEVICE)).cpu().numpy()
         test_true = testY

-    # Denormalize the predictions and the ground truth
     test_pred_inv = scaler_y.inverse_transform(test_pred)
     test_true_inv = scaler_y.inverse_transform(test_true)

-    # Compute evaluation metrics
     metrics = evaluate_model(test_true_inv, test_pred_inv)
     metrics['training_time'] = training_time

-    # Print evaluation metrics
     print(f"\n📊 Model evaluation metrics:", flush=True)
     print(f"   MSE: {metrics['mse']:.4f}", flush=True)
     print(f"   RMSE: {metrics['rmse']:.4f}", flush=True)
@@ -463,9 +344,8 @@ def train_product_model_with_transformer(
     print(f"   MAPE: {metrics['mape']:.2f}%", flush=True)
     print(f"   ⏱️ Training time: {training_time:.2f} seconds", flush=True)

-    # Save the final trained model (based on the final epoch)
     final_model_data = {
-        'epoch': epochs,  # final epoch
+        'epoch': epochs,
         'model_state_dict': model.state_dict(),
         'optimizer_state_dict': optimizer.state_dict(),
         'train_loss': train_losses[-1],
@@ -500,10 +380,14 @@ def train_product_model_with_transformer(

     progress_manager.set_stage("model_saving", 50)

-    # Save the final model, identified by its final epoch
-    final_model_path = save_checkpoint(
-        final_model_data, f"final_epoch_{epochs}", model_identifier, 'transformer',
-        model_dir, store_id, training_mode, aggregation_method
+    final_model_path, final_version = model_manager.save_model(
+        model_data=final_model_data,
+        product_id=model_identifier,
+        model_type='transformer',
+        store_id=store_id,
+        training_mode=training_mode,
+        aggregation_method=aggregation_method,
+        product_name=product_name
     )

     progress_manager.set_stage("model_saving", 100)
@@ -511,7 +395,6 @@ def train_product_model_with_transformer(
     print(f"💾 Model saved to {final_model_path}", flush=True)

-    # Prepare the final metrics to return
     final_metrics = {
         'mse': metrics['mse'],
         'rmse': metrics['rmse'],
@@ -519,7 +402,8 @@ def train_product_model_with_transformer(
         'r2': metrics['r2'],
         'mape': metrics['mape'],
         'training_time': training_time,
-        'final_epoch': epochs
+        'final_epoch': epochs,
+        'version': final_version
     }

     return model, final_metrics, epochs

View File

@@ -8,6 +8,7 @@ import json
 import torch
 import glob
 from datetime import datetime
+import re
 from typing import List, Dict, Optional, Tuple

 from core.config import DEFAULT_MODEL_DIR
@@ -24,6 +25,27 @@ class ModelManager:
         if not os.path.exists(self.model_dir):
             os.makedirs(self.model_dir)

+    def _get_next_version(self, product_id: str, model_type: str, store_id: Optional[str] = None, training_mode: str = 'product') -> int:
+        """Get the next model version number (a plain integer)"""
+        search_pattern = self.generate_model_filename(
+            product_id=product_id,
+            model_type=model_type,
+            version='v*',
+            store_id=store_id,
+            training_mode=training_mode
+        )
+        full_search_path = os.path.join(self.model_dir, search_pattern)
+        existing_files = glob.glob(full_search_path)
+
+        max_version = 0
+        for f in existing_files:
+            match = re.search(r'_v(\d+)\.pth$', os.path.basename(f))
+            if match:
+                max_version = max(max_version, int(match.group(1)))
+
+        return max_version + 1
+
     def generate_model_filename(self,
                                 product_id: str,
                                 model_type: str,
@@ -51,29 +73,29 @@
                    model_data: dict,
                    product_id: str,
                    model_type: str,
-                   version: str,
                    store_id: Optional[str] = None,
                    training_mode: str = 'product',
                    aggregation_method: Optional[str] = None,
-                   product_name: Optional[str] = None) -> str:
+                   product_name: Optional[str] = None,
+                   version: Optional[str] = None) -> Tuple[str, str]:
         """
-        Save a model to the unified location
+        Save a model to the unified location and manage versions automatically

         Parameters:
-            model_data: dictionary containing the model state and configuration
-            product_id: product ID
-            model_type: model type
-            version: version number
-            store_id: store ID (optional)
-            training_mode: training mode
-            aggregation_method: aggregation method (optional)
-            product_name: product name (optional)
+            ...
+            version: (optional) if provided, overrides automatic versioning (e.g. 'best')

         Returns:
-            Path to the model file
+            (model file path, version used)
         """
+        if version is None:
+            next_version_num = self._get_next_version(product_id, model_type, store_id, training_mode)
+            version_str = f"v{next_version_num}"
+        else:
+            version_str = version
+
         filename = self.generate_model_filename(
-            product_id, model_type, version, store_id, training_mode, aggregation_method
+            product_id, model_type, version_str, store_id, training_mode, aggregation_method
         )

         # Save everything in the root directory to avoid a complex subdirectory structure
@@ -86,7 +108,7 @@
             'product_id': product_id,
             'product_name': product_name or product_id,
             'model_type': model_type,
-            'version': version,
+            'version': version_str,
             'store_id': store_id,
             'training_mode': training_mode,
             'aggregation_method': aggregation_method,
@@ -99,7 +121,7 @@
         torch.save(enhanced_model_data, model_path)
         print(f"Model saved: {model_path}")

-        return model_path
+        return model_path, version_str

     def list_models(self,
                     product_id: Optional[str] = None,

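Taken together, the intended calling pattern for the reworked `save_model` looks roughly like the sketch below (product ID and payload are illustrative): omitting `version` makes `_get_next_version` scan the existing `_v<n>.pth` files and hand out the next number, while an explicit `version='best'` bypasses the counter and overwrites a fixed best-model slot, which is how the trainers above use it.

```python
from utils.model_manager import model_manager

model_data = {'model_state_dict': {}, 'config': {}}  # illustrative payload

# No version given: _get_next_version globs "<type>_product_<id>_v*.pth"
# and returns max + 1, so the first save becomes 'v1', the next 'v2', ...
path, ver = model_manager.save_model(
    model_data=model_data,
    product_id='17002608',  # illustrative product ID
    model_type='tcn',
)
print(path, ver)

# An explicit version skips auto-numbering; trainers pass 'best' so the
# rolling best checkpoint always overwrites the same file.
path, ver = model_manager.save_model(
    model_data=model_data,
    product_id='17002608',
    model_type='tcn',
    version='best',
)
print(path, ver)
```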
View File

@@ -1,101 +1,150 @@
-# New-requirement development workflow
+# Standard workflow for new-requirement development

-This document provides a standard, safe workflow for developing new features, covering the complete steps from creating a feature branch to finally merging back into the main development branch.
+This document provides a standard, safe, and efficient workflow for developing new features, covering the complete steps from creating a feature branch to finally merging back into the main development branch, together with best practices for day-to-day development.

-## Core workflow
+## Core development principles

-The feature-branch working model is used; the main steps are:
-
-1. **Sync and switch**: sync the remotely created feature branch (e.g. `lyf-dev-req0001`) to your local machine and switch to it.
-2. **Develop and commit**: develop on the local feature branch and commit changes frequently.
-3. **Push and back up**: push local changes to the remote feature branch.
-4. **Merge back into the mainline**: once the feature is developed and tested, merge it back into the main development branch (e.g. `lyf-dev`).
-5. **Clean up branches**: after merging, delete the feature branch that has served its purpose.
+- **Protect the main branch**: `lyf-dev` is the team's main development branch and should always stay stable and deployable. All new feature development must happen on a separate feature branch.
+- **Feature branches**: each new requirement (e.g. `req0001`) maps to one feature branch (e.g. `lyf-dev-req0001`). Branch names should be clear and meaningful.
+- **Small, frequent steps**: commit often, push often, and sync with the mainline often (`rebase` or `merge`). This greatly reduces the difficulty and risk of later merges.
+- **A clean history**: keep the Git commit history readable, to make code review and troubleshooting easier.

 ---

-## Detailed command walkthrough
+## First thing every day: sync the latest code

-### Step 1: sync the remote branch locally
+**Wherever you left off yesterday, start each new day of work with the steps below. This is the foundation of efficient team collaboration and the way to avoid merge conflicts.**

-Assume a colleague has already created the `lyf-dev-req0001` branch from `lyf-dev` in the remote repository.
+1. **Update the main development branch `lyf-dev`**

-1. **Fetch the latest state of the remote**
-   This command downloads information about new remote branches to your local machine without modifying anything.
    ```bash
-   git fetch origin
+   # Switch to the main development branch
+   git checkout lyf-dev
+   # Pull the latest code from the remote; --prune cleans up references to branches deleted on the remote
+   git pull origin lyf-dev --prune
    ```

-2. **Create and switch to the local feature branch**
-   This command creates a new local branch named `lyf-dev-req0001` and automatically sets it to track the remote `origin/lyf-dev-req0001` branch.
+2. **Sync your feature branch (the team picks one of the two approaches)**
+   To bring the latest mainline code into your feature branch there are two mainstream approaches; the team should settle on one of them.
+
+---
+
+### Approach 1 (recommended): keep the history clean with `rebase`
+
+This approach keeps your branch's commit history a single straight line, which is very easy to read.

    ```bash
+   # Switch back to the feature branch you are working on (e.g. lyf-dev-req0001)
    git checkout lyf-dev-req0001
+   # Rebase to bring the latest lyf-dev updates into your branch
+   git rebase lyf-dev
    ```

-   You are now on a clean, isolated feature branch and can start developing.
+- **Pros**: the final commit history is clean and linear, which makes code review and troubleshooting easier.
+- **Cons**: it rewrites commit history, so you must force-push with `git push --force-with-lease`.
+- **Resolving conflicts**:
+  1. Edit the conflicting files by hand.
+  2. Run `git add <conflicting file>`.
+  3. Run `git rebase --continue`.
+  4. To abort, run `git rebase --abort`.

-### Step 2: develop on the feature branch
+---
+
+### Approach 2: keep the full history with `merge`

-All changes made on this branch are independent of `lyf-dev`, so you can work without worry.
+This approach faithfully records every merge and does not rewrite existing commits.

-1. **Modify the code**: add and change files as the requirement demands.
-2. **Commit the changes**
-   Small, frequent steps are recommended: commit as soon as a small piece of functionality is done.
+   ```bash
+   # Switch back to the feature branch you are working on (e.g. lyf-dev-req0001)
+   git checkout lyf-dev-req0001
+   # Merge the latest lyf-dev into your current branch
+   git merge lyf-dev
+   ```
+
+- **Pros**: safe; it does not rewrite history, so no force-push is needed.
+- **Cons**: it adds extra merge commits to the feature branch (e.g. "Merge branch 'lyf-dev' into ..."), which clutters the history.
+- **Resolving conflicts**:
+  1. Edit the conflicting files by hand.
+  2. Run `git add <conflicting file>`.
+  3. Run `git commit` to finish the merge.
+
+---
+
+## The complete development workflow
+
+### 1. Starting a new requirement: create a feature branch
+
+**When you are about to start developing a brand-new feature:**
+
+1. **Make sure `lyf-dev` is up to date**
+   (Already done in "First thing every day"; repeated here as a reminder.)
+
+2. **Create and switch to a new branch off `lyf-dev`**
+   Assume the new requirement number is `req0002`:
+
+   ```bash
+   # This creates lyf-dev-req0002 from the latest lyf-dev and switches to it
+   git checkout -b lyf-dev-req0002
+   ```
+
+### 2. Day-to-day development: commit and push
+
+**While developing on your feature branch (e.g. `lyf-dev-req0002`):**
+
+1. **Code and commit locally**
+   Commit whenever a small, self-contained piece of functionality is done.
+
    ```bash
-   # Stage all changes
+   # Check what has changed
+   git status
+   # Stage all related files
    git add .
-   # Commit with a clear message
-   git commit -m "feat: implement the user login endpoint"
+   # Commit with a clear message (feat: feature, fix: bugfix, docs: documentation, etc.)
+   git commit -m "feat: implement the user authentication module"
    ```

-3. **Push to the remote feature branch**
-   Push your local commits to the remote to back up your code or collaborate with others.
+2. **Push your changes to the remote as a backup**
+   To keep the code safe and make team collaboration easier, push local commits to the remote frequently.
+
    ```bash
-   git push origin lyf-dev-req0001
+   # -u sets the local branch to track the remote branch; afterwards a plain git push suffices
+   git push -u origin lyf-dev-req0002
    ```

-### Step 3: merge the feature into the main development branch (`lyf-dev`)
+### 3. Finishing the feature: merge back into the mainline

-Once the new feature is complete and thoroughly tested, run the following steps to merge it back into `lyf-dev`.
+**Once the feature is complete and has passed testing, merge it back into `lyf-dev`:**

-1. **Switch back to the main development branch**
+1. **One last sync**
+   Sync one final time before merging to make sure the branch contains everything that is new on `lyf-dev`.
+   (Repeat the sync procedure from "First thing every day".)
+
+2. **Switch to the main branch and pull the latest code**
+
    ```bash
    git checkout lyf-dev
-   ```
-
-2. **Make sure `lyf-dev` is up to date**
-   Before merging, always pull the latest `lyf-dev` from the remote in case others have pushed updates in the meantime.
-   ```bash
    git pull origin lyf-dev
    ```

 3. **Merge the feature branch**
-   This is the key step: it brings in all the new functionality from `lyf-dev-req0001`.
+   We use the `--no-ff` (no fast-forward) flag to create a merge commit, which clearly records the act of "merging a feature".
+
    ```bash
-   git merge lyf-dev-req0001
+   # --no-ff creates a new merge commit and preserves the branch history
+   git merge --no-ff lyf-dev-req0002
    ```

-   - **No conflicts**: Git completes the merge automatically.
-   - **Conflicts**: Git tells you which files conflict. Resolve the conflicts in those files by hand, then run `git add .` and `git commit` to finish the merge.
+   If the syncing has been done well, this step usually produces no conflicts.

-4. **Push the merged `lyf-dev`**
-   Push the locally merged `lyf-dev` branch to the remote repository.
+4. **Push the merged main branch**
+
    ```bash
    git push origin lyf-dev
    ```

-### Step 4: clean up branches (optional)
+### 4. Cleaning up

-After the merge, the feature branch has served its purpose. You can delete it to keep the repository tidy.
+**After the merge, the feature branch has served its purpose.**

 1. **Delete the remote branch**
    ```bash
-   git push origin --delete lyf-dev-req0001
+   git push origin --delete lyf-dev-req0002
    ```

 2. **Delete the local branch**
    ```bash
-   git branch -d lyf-dev-req0001
+   git branch -d lyf-dev-req0002
    ```

-Following this workflow keeps your development workflow clear, safe, and efficient.
+Following this workflow keeps the team's development workflow clear, safe, and efficient.