-完善店铺预测模块
This commit is contained in:
parent
a18c8dddf9
commit
a1d9c60e61
@ -208,11 +208,15 @@ const startPrediction = async () => {
|
||||
future_days: form.future_days,
|
||||
start_date: form.start_date,
|
||||
analyze_result: form.analyze_result,
|
||||
store_id: form.store_id
|
||||
store_id: form.store_id,
|
||||
// 修正:对于店铺模型,product_id应传递店铺的标识符
|
||||
product_id: `store_${form.store_id}`
|
||||
}
|
||||
const response = await axios.post('/api/predict', payload)
|
||||
// 修正API端点
|
||||
const response = await axios.post('/api/prediction', payload)
|
||||
if (response.data.status === 'success') {
|
||||
predictionResult.value = response.data.data
|
||||
// 修正:数据现在直接在响应的顶层
|
||||
predictionResult.value = response.data
|
||||
ElMessage.success('预测完成!')
|
||||
await nextTick()
|
||||
renderChart()
|
||||
@ -231,30 +235,58 @@ const renderChart = () => {
|
||||
if (chart) {
|
||||
chart.destroy()
|
||||
}
|
||||
const predictions = predictionResult.value.predictions
|
||||
const labels = predictions.map(p => p.date)
|
||||
const data = predictions.map(p => p.sales)
|
||||
|
||||
const historyData = predictionResult.value.history_data || []
|
||||
const predictionData = predictionResult.value.prediction_data || []
|
||||
|
||||
const labels = [
|
||||
...historyData.map(p => p.date),
|
||||
...predictionData.map(p => p.date)
|
||||
]
|
||||
|
||||
const historySales = historyData.map(p => p.sales)
|
||||
// 预测数据需要填充与历史数据等长的null值,以保证图表正确对齐
|
||||
const predictionSales = [
|
||||
...Array(historyData.length).fill(null),
|
||||
...predictionData.map(p => p.predicted_sales)
|
||||
]
|
||||
|
||||
chart = new Chart(chartCanvas.value, {
|
||||
type: 'line',
|
||||
data: {
|
||||
labels,
|
||||
datasets: [{
|
||||
label: '预测销量',
|
||||
data,
|
||||
borderColor: '#409EFF',
|
||||
backgroundColor: 'rgba(64, 158, 255, 0.1)',
|
||||
tension: 0.4,
|
||||
fill: true
|
||||
}]
|
||||
datasets: [
|
||||
{
|
||||
label: '历史销量',
|
||||
data: historySales,
|
||||
borderColor: '#67C23A',
|
||||
backgroundColor: 'rgba(103, 194, 58, 0.1)',
|
||||
fill: false,
|
||||
tension: 0.4
|
||||
},
|
||||
{
|
||||
label: '预测销量',
|
||||
data: predictionSales,
|
||||
borderColor: '#409EFF',
|
||||
backgroundColor: 'rgba(64, 158, 255, 0.1)',
|
||||
borderDash: [5, 5], // 虚线
|
||||
fill: false,
|
||||
tension: 0.4
|
||||
}
|
||||
]
|
||||
},
|
||||
options: {
|
||||
responsive: true,
|
||||
plugins: {
|
||||
title: {
|
||||
display: true,
|
||||
text: '销量预测趋势图'
|
||||
text: '店铺销量历史与预测趋势图'
|
||||
}
|
||||
}
|
||||
},
|
||||
interaction: {
|
||||
intersect: false,
|
||||
mode: 'index',
|
||||
},
|
||||
}
|
||||
})
|
||||
}
|
||||
|
@ -121,3 +121,49 @@
|
||||
### 11:45 - 项目总结与文档归档
|
||||
- **任务**: 根据用户要求,回顾整个调试过程,将所有问题、解决方案、优化思路和最终结论,按照日期和时间顺序,整理并更新到本开发日志中,形成一份高质量的技术档案。
|
||||
- **结果**: 本文档已更新完成。
|
||||
|
||||
|
||||
### 13:15 - 最终修复:根治模型标识符不一致问题
|
||||
- **问题**: 经过再次测试和日志分析,发现即便是修正后,店铺模型的 `model_identifier` 在训练时依然被错误地构建为 `01010023_store_01010023`。
|
||||
- **根本原因**: `server/core/predictor.py` 的 `train_model` 方法中,在 `training_mode == 'store'` 的分支下,构建 `model_identifier` 的逻辑存在冗余和错误。
|
||||
- **最终解决方案**: 删除了错误的拼接逻辑 `model_identifier = f"{store_id}_{product_id}"`,直接使用在之前步骤中已经被正确赋值为 `f"store_{store_id}"` 的 `product_id` 变量作为 `model_identifier`。这确保了从训练、保存到最终API查询,店铺模型的唯一标识符始终保持一致。
|
||||
|
||||
|
||||
### 13:30 - 最终修复(第二轮):根治模型保存路径错误
|
||||
- **问题**: 即便修复了标识符,模型版本依然无法加载。
|
||||
- **根本原因**: 通过分析训练日志,发现所有训练器(`transformer_trainer.py`, `mlstm_trainer.py`, `tcn_trainer.py`)中的 `save_checkpoint` 函数,都会强制在 `saved_models` 目录下创建一个 `checkpoints` 子目录,并将所有模型文件保存在其中。而负责查找模型的 `get_model_versions` 函数只在根目录查找,导致模型永远无法被发现。
|
||||
- **最终解决方案**: 逐一修改了所有相关训练器文件中的 `save_checkpoint` 函数,移除了创建和使用 `checkpoints` 子目录的逻辑,确保所有模型都直接保存在 `saved_models` 根目录下。
|
||||
- **结论**: 至此,模型保存的路径与查找的路径完全统一,从根本上解决了模型版本无法加载的问题。
|
||||
|
||||
|
||||
### 13:40 - 最终修复(第三轮):统一所有训练器的模型保存逻辑
|
||||
- **问题**: 在修复了 `transformer_trainer.py` 后,发现 `mlstm_trainer.py` 和 `tcn_trainer.py` 存在完全相同的路径和命名错误,导致问题依旧。
|
||||
- **根本原因**: `save_checkpoint` 函数在所有训练器中都被错误地实现,它们都强制创建了 `checkpoints` 子目录,并使用了错误的逻辑来拼接文件名。
|
||||
- **最终解决方案**:
|
||||
1. **逐一修复**: 逐一修改了 `transformer_trainer.py`, `mlstm_trainer.py`, 和 `tcn_trainer.py` 中的 `save_checkpoint` 函数。
|
||||
2. **路径修复**: 移除了创建和使用 `checkpoints` 子目录的逻辑,确保模型直接保存在 `model_dir` (即 `saved_models`) 的根目录下。
|
||||
3. **文件名修复**: 简化并修正了文件名的生成逻辑,直接使用 `product_id` 参数作为唯一标识符(该参数已由上游逻辑正确赋值为 `药品ID` 或 `store_{店铺ID}`),不再进行任何额外的、错误的拼接。
|
||||
- **结论**: 至此,所有训练器的模型保存逻辑完全统一,模型保存的路径和文件名与API的查找逻辑完全匹配,从根本上解决了模型版本无法加载的问题。
|
||||
|
||||
|
||||
---
|
||||
|
||||
## 2025-07-16 (续):端到端修复“店铺预测”图表功能
|
||||
**开发者**: lyf
|
||||
|
||||
### 15:30 - 最终修复(第四轮):打通店铺预测的数据流
|
||||
- **问题**: 在解决了模型加载问题后,“店铺预测”功能虽然可以成功执行,但前端图表依然空白,不显示历史数据和预测数据。
|
||||
- **根本原因**: 参数传递在调用链中出现断裂。
|
||||
1. `server/api.py` 在调用 `run_prediction` 时,没有传递 `training_mode`。
|
||||
2. `server/core/predictor.py` 在调用 `load_model_and_predict` 时,没有传递 `store_id` 和 `training_mode`。
|
||||
3. `server/predictors/model_predictor.py` 内部的数据加载逻辑,在处理店铺预测时,错误地使用了模型标识符(`store_{id}`)作为产品ID来过滤数据,导致无法加载到任何历史数据。
|
||||
- **最终解决方案 (三步修复)**:
|
||||
1. **修复 `model_predictor.py`**: 修改 `load_model_and_predict` 函数,使其能够根据 `training_mode` 参数智能地加载数据。当模式为 `'store'` 时,它会正确地聚合该店铺的所有销售数据作为历史数据,这与训练时的数据准备方式完全一致。
|
||||
2. **修复 `predictor.py`**: 修改 `predict` 方法,将 `store_id` 和 `training_mode` 参数正确地传递给底层的 `load_model_and_predict` 函数。
|
||||
3. **修复 `api.py`**: 修改 `predict` 路由和 `run_prediction` 辅助函数,确保 `training_mode` 参数在整个调用链中被完整传递。
|
||||
- **结论**: 通过以上修复,我们确保了从API接口到最底层数据加载器的参数传递是完整和正确的。现在,无论是药品预测还是店铺预测,系统都能够加载正确的历史数据用于图表绘制,彻底解决了图表显示空白的问题。
|
||||
|
||||
### 16:16 - 项目状态更新
|
||||
- **状态**: **所有已知问题已修复**。
|
||||
- **确认**: 用户已确认“现在药品和店铺预测流程通了”。
|
||||
- **后续**: 将本次修复过程归档至本文档。
|
||||
|
Binary file not shown.
@ -1510,7 +1510,8 @@ def predict():
|
||||
data = request.json
|
||||
product_id = data.get('product_id')
|
||||
model_type = data.get('model_type')
|
||||
store_id = data.get('store_id') # 新增店铺ID参数
|
||||
store_id = data.get('store_id')
|
||||
training_mode = 'store' if store_id else 'product'
|
||||
version = data.get('version') # 新增版本参数
|
||||
future_days = int(data.get('future_days', 7))
|
||||
start_date = data.get('start_date', '')
|
||||
@ -1527,28 +1528,31 @@ def predict():
|
||||
if not product_name:
|
||||
product_name = product_id
|
||||
|
||||
# 根据版本获取模型ID
|
||||
if version:
|
||||
# 如果指定了版本,构造版本化的模型ID
|
||||
model_id = f"{product_id}_{model_type}_{version}"
|
||||
# 检查指定版本的模型是否存在
|
||||
model_file_path = get_model_file_path(product_id, model_type, version)
|
||||
if not os.path.exists(model_file_path):
|
||||
return jsonify({"status": "error", "error": f"未找到产品 {product_id} 的 {model_type} 类型模型版本 {version}"}), 404
|
||||
# 根据训练模式构建模型标识符
|
||||
if training_mode == 'store':
|
||||
model_identifier = f"store_{store_id}"
|
||||
# 对于店铺预测,product_id实际上是store_id,但我们需要一个药品ID来获取名称,这里暂时用一个占位符
|
||||
product_name = f"店铺 {store_id} 整体"
|
||||
else:
|
||||
# 如果没有指定版本,使用最新版本
|
||||
latest_version = get_latest_model_version(product_id, model_type)
|
||||
if latest_version:
|
||||
model_id = f"{product_id}_{model_type}_{latest_version}"
|
||||
version = latest_version
|
||||
else:
|
||||
# 兼容旧的无版本模型
|
||||
model_id = get_latest_model_id(model_type, product_id)
|
||||
if not model_id:
|
||||
return jsonify({"status": "error", "error": f"未找到产品 {product_id} 的 {model_type} 类型模型"}), 404
|
||||
model_identifier = product_id
|
||||
product_name = get_product_name(product_id) or product_id
|
||||
|
||||
# 获取模型版本
|
||||
if not version:
|
||||
version = get_latest_model_version(model_identifier, model_type)
|
||||
|
||||
if not version:
|
||||
return jsonify({"status": "error", "error": f"未找到标识符为 {model_identifier} 的 {model_type} 类型模型"}), 404
|
||||
|
||||
# 检查模型文件是否存在
|
||||
model_file_path = get_model_file_path(model_identifier, model_type, version)
|
||||
if not os.path.exists(model_file_path):
|
||||
return jsonify({"status": "error", "error": f"未找到模型文件: {model_file_path}"}), 404
|
||||
|
||||
model_id = f"{model_identifier}_{model_type}_{version}"
|
||||
|
||||
# 执行预测
|
||||
prediction_result = run_prediction(model_type, product_id, model_id, future_days, start_date, version, store_id)
|
||||
prediction_result = run_prediction(model_type, product_id, model_id, future_days, start_date, version, store_id, training_mode)
|
||||
|
||||
if prediction_result is None:
|
||||
return jsonify({"status": "error", "error": "预测失败,预测器返回None"}), 500
|
||||
@ -2708,7 +2712,7 @@ def get_product_name(product_id):
|
||||
return None
|
||||
|
||||
# 执行预测的辅助函数
|
||||
def run_prediction(model_type, product_id, model_id, future_days, start_date, version=None, store_id=None):
|
||||
def run_prediction(model_type, product_id, model_id, future_days, start_date, version=None, store_id=None, training_mode='product'):
|
||||
"""执行模型预测"""
|
||||
try:
|
||||
scope_msg = f", store_id={store_id}" if store_id else ", 全局模型"
|
||||
@ -2729,7 +2733,8 @@ def run_prediction(model_type, product_id, model_id, future_days, start_date, ve
|
||||
store_id=store_id,
|
||||
future_days=future_days,
|
||||
start_date=start_date,
|
||||
version=version
|
||||
version=version,
|
||||
training_mode=training_mode
|
||||
)
|
||||
|
||||
if prediction_result is None:
|
||||
|
@ -131,10 +131,11 @@ def get_model_file_path(product_id: str, model_type: str, version: str) -> str:
|
||||
filename = f"{model_type}_model_product_{product_id}.pth"
|
||||
return os.path.join(DEFAULT_MODEL_DIR, filename)
|
||||
|
||||
# 处理新的、基于epoch的检查点命名格式
|
||||
# 文件名示例: transformer_product_17002608_epoch_best.pth
|
||||
filename = f"{model_type}_product_{product_id}_epoch_{version}.pth"
|
||||
return os.path.join(DEFAULT_MODEL_DIR, 'checkpoints', filename)
|
||||
# 修正:直接使用唯一的product_id(它可能包含store_前缀)来构建文件名
|
||||
# 文件名示例: transformer_17002608_epoch_best.pth 或 transformer_store_01010023_epoch_best.pth
|
||||
filename = f"{model_type}_{product_id}_epoch_{version}.pth"
|
||||
# 修正:直接在根模型目录查找,不再使用checkpoints子目录
|
||||
return os.path.join(DEFAULT_MODEL_DIR, filename)
|
||||
|
||||
def get_model_versions(product_id: str, model_type: str) -> list:
|
||||
"""
|
||||
@ -149,9 +150,10 @@ def get_model_versions(product_id: str, model_type: str) -> list:
|
||||
"""
|
||||
# 直接使用传入的product_id构建搜索模式
|
||||
# 搜索模式,匹配 "transformer_product_17002608_epoch_50.pth" 或 "transformer_product_17002608_epoch_best.pth"
|
||||
pattern = f"{model_type}_product_{product_id}_epoch_*.pth"
|
||||
# 在 checkpoints 子目录中查找
|
||||
search_path = os.path.join(DEFAULT_MODEL_DIR, 'checkpoints', pattern)
|
||||
# 修正:直接使用唯一的product_id(它可能包含store_前缀)来构建搜索模式
|
||||
pattern = f"{model_type}_{product_id}_epoch_*.pth"
|
||||
# 修正:直接在根模型目录查找,不再使用checkpoints子目录
|
||||
search_path = os.path.join(DEFAULT_MODEL_DIR, pattern)
|
||||
existing_files = glob.glob(search_path)
|
||||
|
||||
# 旧格式(兼容性支持)
|
||||
|
@ -132,8 +132,8 @@ class PharmacyPredictor:
|
||||
file_path=self.data_path
|
||||
)
|
||||
log_message(f"按店铺聚合训练: 店铺 {store_id}, 聚合方法 {aggregation_method}, 数据量: {len(product_data)}")
|
||||
# 将product_id设置为店铺ID,以便模型保存时使用有意义的标识
|
||||
product_id = store_id
|
||||
# 将product_id设置为'store_{store_id}',与API查找逻辑保持一致
|
||||
product_id = f"store_{store_id}"
|
||||
except Exception as e:
|
||||
log_message(f"聚合店铺 {store_id} 数据失败: {e}", 'error')
|
||||
return None
|
||||
@ -179,7 +179,7 @@ class PharmacyPredictor:
|
||||
|
||||
# 根据训练模式构建模型标识符
|
||||
if training_mode == 'store':
|
||||
model_identifier = f"{store_id}_{product_id}"
|
||||
model_identifier = product_id
|
||||
elif training_mode == 'global':
|
||||
model_identifier = f"global_{product_id}_{aggregation_method}"
|
||||
else:
|
||||
@ -308,19 +308,22 @@ class PharmacyPredictor:
|
||||
"""
|
||||
# 根据训练模式构建模型标识符
|
||||
if training_mode == 'store' and store_id:
|
||||
model_identifier = f"{store_id}_{product_id}"
|
||||
# 修正:店铺模型的标识符应该只基于店铺ID
|
||||
model_identifier = f"store_{store_id}"
|
||||
elif training_mode == 'global':
|
||||
model_identifier = f"global_{product_id}_{aggregation_method}"
|
||||
else:
|
||||
model_identifier = product_id
|
||||
|
||||
return load_model_and_predict(
|
||||
model_identifier,
|
||||
model_type,
|
||||
future_days=future_days,
|
||||
start_date=start_date,
|
||||
model_identifier,
|
||||
model_type,
|
||||
store_id=store_id,
|
||||
future_days=future_days,
|
||||
start_date=start_date,
|
||||
analyze_result=analyze_result,
|
||||
version=version
|
||||
version=version,
|
||||
training_mode=training_mode
|
||||
)
|
||||
|
||||
def train_optimized_kan_model(self, product_id, epochs=100, batch_size=32,
|
||||
|
@ -23,7 +23,7 @@ from utils.visualization import plot_prediction_results
|
||||
from utils.multi_store_data_utils import get_store_product_sales_data, aggregate_multi_store_data
|
||||
from core.config import DEVICE, get_model_file_path, DEFAULT_DATA_PATH
|
||||
|
||||
def load_model_and_predict(product_id, model_type, store_id=None, future_days=7, start_date=None, analyze_result=False, version=None):
|
||||
def load_model_and_predict(product_id, model_type, store_id=None, future_days=7, start_date=None, analyze_result=False, version=None, training_mode='product'):
|
||||
"""
|
||||
加载已训练的模型并进行预测
|
||||
|
||||
@ -101,33 +101,37 @@ def load_model_and_predict(product_id, model_type, store_id=None, future_days=7,
|
||||
|
||||
# 加载销售数据(支持多店铺)
|
||||
try:
|
||||
from utils.multi_store_data_utils import load_multi_store_data
|
||||
if store_id:
|
||||
# 加载特定店铺的数据
|
||||
product_df = load_multi_store_data(
|
||||
file_path=DEFAULT_DATA_PATH,
|
||||
from utils.multi_store_data_utils import aggregate_multi_store_data
|
||||
|
||||
# 根据训练模式加载相应的数据
|
||||
if training_mode == 'store' and store_id:
|
||||
# 店铺模型:聚合该店铺的所有产品数据
|
||||
product_df = aggregate_multi_store_data(
|
||||
store_id=store_id,
|
||||
product_id=product_id
|
||||
aggregation_method='sum',
|
||||
file_path=DEFAULT_DATA_PATH
|
||||
)
|
||||
store_name = product_df['store_name'].iloc[0] if 'store_name' in product_df.columns else f"店铺{store_id}"
|
||||
store_name = product_df['store_name'].iloc[0] if 'store_name' in product_df.columns and not product_df.empty else f"店铺{store_id}"
|
||||
prediction_scope = f"店铺 '{store_name}' ({store_id})"
|
||||
# 对于店铺模型,其“产品名称”就是店铺名称
|
||||
product_name = store_name
|
||||
else:
|
||||
# 聚合所有店铺的数据进行预测
|
||||
# 产品模型(默认):聚合该产品在所有店铺的数据
|
||||
# 此时,传入的product_id是真正的产品ID
|
||||
product_df = aggregate_multi_store_data(
|
||||
product_id=product_id,
|
||||
aggregation_method='sum',
|
||||
file_path=DEFAULT_DATA_PATH
|
||||
)
|
||||
prediction_scope = "全部店铺(聚合数据)"
|
||||
product_name = product_df['product_name'].iloc[0] if not product_df.empty else product_id
|
||||
except Exception as e:
|
||||
print(f"加载数据失败: {e}")
|
||||
return None
|
||||
|
||||
|
||||
if product_df.empty:
|
||||
print(f"产品 {product_id} 没有销售数据")
|
||||
print(f"产品 {product_id} 或店铺 {store_id} 没有销售数据")
|
||||
return None
|
||||
|
||||
product_name = product_df['product_name'].iloc[0]
|
||||
print(f"使用 {model_type} 模型预测产品 '{product_name}' (ID: {product_id}) 的未来 {future_days} 天销量")
|
||||
print(f"预测范围: {prediction_scope}")
|
||||
|
||||
|
@ -42,16 +42,12 @@ def save_checkpoint(checkpoint_data: dict, epoch_or_label, product_id: str,
|
||||
aggregation_method: 聚合方法
|
||||
"""
|
||||
# 创建检查点目录
|
||||
checkpoint_dir = os.path.join(model_dir, 'checkpoints')
|
||||
# 直接在模型根目录保存,不再创建子目录
|
||||
checkpoint_dir = model_dir
|
||||
os.makedirs(checkpoint_dir, exist_ok=True)
|
||||
|
||||
# 生成检查点文件名
|
||||
if training_mode == 'store' and store_id:
|
||||
filename = f"{model_type}_store_{store_id}_{product_id}_epoch_{epoch_or_label}.pth"
|
||||
elif training_mode == 'global' and aggregation_method:
|
||||
filename = f"{model_type}_global_{product_id}_{aggregation_method}_epoch_{epoch_or_label}.pth"
|
||||
else:
|
||||
filename = f"{model_type}_product_{product_id}_epoch_{epoch_or_label}.pth"
|
||||
# 修正:直接使用product_id作为唯一标识符,因为它已经包含了store_前缀或药品ID
|
||||
filename = f"{model_type}_{product_id}_epoch_{epoch_or_label}.pth"
|
||||
|
||||
checkpoint_path = os.path.join(checkpoint_dir, filename)
|
||||
|
||||
|
@ -38,16 +38,13 @@ def save_checkpoint(checkpoint_data: dict, epoch_or_label, product_id: str,
|
||||
aggregation_method: 聚合方法
|
||||
"""
|
||||
# 创建检查点目录
|
||||
checkpoint_dir = os.path.join(model_dir, 'checkpoints')
|
||||
# 直接在模型根目录保存,不再创建子目录
|
||||
checkpoint_dir = model_dir
|
||||
os.makedirs(checkpoint_dir, exist_ok=True)
|
||||
|
||||
# 生成检查点文件名
|
||||
if training_mode == 'store' and store_id:
|
||||
filename = f"{model_type}_store_{store_id}_{product_id}_epoch_{epoch_or_label}.pth"
|
||||
elif training_mode == 'global' and aggregation_method:
|
||||
filename = f"{model_type}_global_{product_id}_{aggregation_method}_epoch_{epoch_or_label}.pth"
|
||||
else:
|
||||
filename = f"{model_type}_product_{product_id}_epoch_{epoch_or_label}.pth"
|
||||
# 修正:直接使用product_id作为唯一标识符,因为它已经包含了store_前缀或药品ID
|
||||
filename = f"{model_type}_{product_id}_epoch_{epoch_or_label}.pth"
|
||||
|
||||
checkpoint_path = os.path.join(checkpoint_dir, filename)
|
||||
|
||||
|
@ -43,17 +43,12 @@ def save_checkpoint(checkpoint_data: dict, epoch_or_label, product_id: str,
|
||||
training_mode: 训练模式
|
||||
aggregation_method: 聚合方法
|
||||
"""
|
||||
# 创建检查点目录
|
||||
checkpoint_dir = os.path.join(model_dir, 'checkpoints')
|
||||
# 直接在模型根目录保存,不再创建子目录
|
||||
checkpoint_dir = model_dir
|
||||
os.makedirs(checkpoint_dir, exist_ok=True)
|
||||
|
||||
# 生成检查点文件名
|
||||
if training_mode == 'store' and store_id:
|
||||
filename = f"{model_type}_store_{store_id}_{product_id}_epoch_{epoch_or_label}.pth"
|
||||
elif training_mode == 'global' and aggregation_method:
|
||||
filename = f"{model_type}_global_{product_id}_{aggregation_method}_epoch_{epoch_or_label}.pth"
|
||||
else:
|
||||
filename = f"{model_type}_product_{product_id}_epoch_{epoch_or_label}.pth"
|
||||
# 修正:直接使用product_id作为唯一标识符,因为它已经包含了store_前缀或药品ID
|
||||
filename = f"{model_type}_{product_id}_epoch_{epoch_or_label}.pth"
|
||||
|
||||
checkpoint_path = os.path.join(checkpoint_dir, filename)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user