Fix: prediction page does not show the 'best' model version and cannot run predictions
This commit is contained in:
parent d5ec662070
commit 290e402181
@@ -116,7 +116,7 @@ const filters = reactive({
 
 const pagination = reactive({
   currentPage: 1,
-  pageSize: 8
+  pageSize: 12
 })
 
 const filteredModelList = computed(() => {
@@ -124,7 +124,7 @@ const filters = reactive({
 
 const pagination = reactive({
   currentPage: 1,
-  pageSize: 8
+  pageSize: 12
 })
 
 const filteredModelList = computed(() => {
@@ -126,7 +126,7 @@ const filters = reactive({
 
 const pagination = reactive({
   currentPage: 1,
-  pageSize: 8
+  pageSize: 12
 })
 
 const filteredModelList = computed(() => {
Binary file not shown.
@@ -1329,10 +1329,17 @@ def predict():
     if not model_uid:
         return jsonify({"status": "error", "message": "Missing 'model_uid' parameter"}), 400
 
+    # Handle UIDs carrying a '_best' suffix, which the frontend adds to avoid key collisions (written to stay compatible with older Python)
+    is_best_request = model_uid.endswith('_best')
+    if is_best_request:
+        db_query_uid = model_uid[:-len('_best')]
+    else:
+        db_query_uid = model_uid
+
     # Look up the model record in the database
-    model_record = find_model_by_uid(model_uid)
+    model_record = find_model_by_uid(db_query_uid)
     if not model_record:
-        return jsonify({"status": "error", "message": f"Model UID '{model_uid}' does not exist"}), 404
+        return jsonify({"status": "error", "message": f"Base model UID '{db_query_uid}' not found"}), 404
 
     # Extract the required model metadata
     model_type = model_record.get('model_type')
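
The suffix handling above is self-contained enough to sketch on its own. A minimal illustration (split_model_uid is an invented name for this note, not part of the codebase) of how a UID the frontend suffixed with '_best' is mapped back to its base record while remembering which weights were requested:

# Illustrative helper mirroring the '_best' suffix logic in predict().
def split_model_uid(model_uid):
    """Return (db_query_uid, is_best_request) for a possibly suffixed UID."""
    if model_uid.endswith('_best'):
        return model_uid[:-len('_best')], True
    return model_uid, False

# The frontend appends '_best' to keep list keys unique; the backend strips it again.
assert split_model_uid('a1b2c3_best') == ('a1b2c3', True)
assert split_model_uid('a1b2c3') == ('a1b2c3', False)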
@@ -1341,6 +1348,16 @@ def predict():
 
     # Parse artifacts to locate the model file path
     artifacts = json.loads(model_record.get('artifacts', '{}'))
 
+    # Choose the correct model path for the type of request
+    if is_best_request:
+        model_file_path = artifacts.get('best_model')
+    else:
+        # For versions such as v1, v2, use versioned_model
+        model_file_path = artifacts.get('versioned_model')
+
+    # If the version-specific path is missing, fall back to whatever exists for robustness
+    if not model_file_path:
+        model_file_path = artifacts.get('best_model') or artifacts.get('versioned_model')
 
     # Path fix: convert the relative path to an absolute path for reliable file checks
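
Viewed in isolation, the path-selection rule above reduces to a small pure function. A sketch under the assumption that artifacts is the dict parsed from the model record; pick_model_path is an illustrative name, not an existing helper:

# Illustrative sketch of the selection-with-fallback rule introduced above.
def pick_model_path(artifacts, is_best_request):
    primary = 'best_model' if is_best_request else 'versioned_model'
    path = artifacts.get(primary)
    # Fall back to whichever artifact exists so a missing key does not break prediction.
    return path or artifacts.get('best_model') or artifacts.get('versioned_model')

artifacts = {'best_model': 'saved_models/m1_best.pth'}
print(pick_model_path(artifacts, is_best_request=False))  # saved_models/m1_best.pth (fallback)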
@@ -1412,10 +1429,13 @@ def predict():
     analysis_result = prediction_result.get('analysis', {}) if prediction_result else {}
     metrics_result = analysis_result.get('metrics', {}) if analysis_result else {}
 
+    # Prepare the model_type saved to history so it matches what the frontend displays
+    history_model_type = f"{model_type}(best)" if is_best_request else model_type
+
     db_payload = {
         "prediction_uid": prediction_uid,
         "model_id": model_uid,
-        "model_type": model_type,
+        "model_type": history_model_type,  # use the normalized, consistent name
         "product_name": product_name_to_save,  # use the corrected name
         "model_version": model_record.get('display_name'),  # save the model info to the new field
         "prediction_scope": {"product_id": product_id, "store_id": store_id},
@@ -86,6 +86,11 @@ register_predictor('default', default_pytorch_predictor)
 register_predictor('xgboost', default_pytorch_predictor)
+# Register the new models with the default predictor as well
+register_predictor('cnn_bilstm_attention', default_pytorch_predictor)
+register_predictor('transformer', default_pytorch_predictor)
+register_predictor('mlstm', default_pytorch_predictor)
+register_predictor('kan', default_pytorch_predictor)
 register_predictor('optimized_kan', default_pytorch_predictor)
 register_predictor('tcn', default_pytorch_predictor)
 
 
 def load_model_and_predict(model_path: str, product_id: str, model_type: str, store_id: Optional[str] = None, future_days: int = 7, start_date: Optional[str] = None, analyze_result: bool = False, version: Optional[str] = None, training_mode: str = 'product', history_lookback_days: int = 30):
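
These registrations rely on a plain name-to-callable registry. A minimal sketch of that dispatch pattern, assuming register_predictor simply stores callables in a module-level dict; the internals shown here are illustrative, not the project's actual implementation:

# Illustrative registry sketch: map model_type strings to predictor callables.
_PREDICTORS = {}

def register_predictor(model_type, predictor_fn):
    _PREDICTORS[model_type] = predictor_fn

def get_predictor(model_type):
    # Unknown types fall back to the 'default' entry, mirroring how every model
    # type above is wired to the same default_pytorch_predictor.
    return _PREDICTORS.get(model_type, _PREDICTORS.get('default'))

def default_pytorch_predictor(**kwargs):
    # Stand-in for the real predictor; it only echoes its inputs here.
    return {'status': 'ok', 'model_type': kwargs.get('model_type')}

for name in ('default', 'transformer', 'mlstm', 'kan', 'optimized_kan', 'tcn'):
    register_predictor(name, default_pytorch_predictor)

print(get_predictor('transformer') is default_pytorch_predictor)  # True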
@@ -84,26 +84,48 @@ def query_models_from_db(filters: dict, page: int = 1, page_size: int = 10):
     if conditions:
         query += " WHERE " + " AND ".join(conditions)
 
-    # Get the total count
-    count_query = query.replace("*", "COUNT(*)")
-    cursor.execute(count_query, params)
-    total_count = cursor.fetchone()[0]
-
-    # Apply pagination
-    offset = (page - 1) * page_size
-    query += " ORDER BY created_at DESC LIMIT ? OFFSET ?"
-    params.extend([page_size, offset])
-
+    # Pagination fix: fetch every matching model first, then enhance and paginate in memory
+    query += " ORDER BY created_at DESC"
     cursor.execute(query, params)
-    models = [dict(row) for row in cursor.fetchall()]
+    all_models = [dict(row) for row in cursor.fetchall()]
+
+    # Enhancement: if a model also has a 'best' artifact, emit an extra record for it
+    enhanced_models = []
+    for model in all_models:
+        # Always add the original versioned record first
+        enhanced_models.append(model)
+
+        try:
+            artifacts = json.loads(model.get('artifacts', '{}'))
+            if artifacts and artifacts.get('best_model'):
+                original_version = model.get('version', 'N/A')
+                best_model_record = model.copy()
+                best_model_record['version'] = f"best({original_version})"
+                best_model_record['model_uid'] = f"{model['model_uid']}_best"
+                best_model_record['model_type'] = f"{model['model_type']}(best)"
+
+                new_artifacts = {
+                    'best_model': artifacts.get('best_model'),
+                    'scaler_X': artifacts.get('scaler_X'),
+                    'scaler_y': artifacts.get('scaler_y')
+                }
+                best_model_record['artifacts'] = json.dumps(new_artifacts)
+                enhanced_models.append(best_model_record)
+        except (json.JSONDecodeError, TypeError):
+            continue
+
+    # Paginate in memory
+    total_count = len(enhanced_models)
+    offset = (page - 1) * page_size
+    paginated_models = enhanced_models[offset : offset + page_size]
 
     return {
-        "models": models,
+        "models": paginated_models,
         "pagination": {
             "total": total_count,
             "page": page,
             "page_size": page_size,
-            "total_pages": (total_count + page_size - 1) // page_size
+            "total_pages": (total_count + page_size - 1) // page_size if page_size > 0 else 0
         }
     }
 
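
Pulled out of the database layer, the enhance-then-paginate step can be exercised on plain dicts. A self-contained sketch of the same idea (field names follow the diff; the wrapper function itself is illustrative, not the project's code):

import json

# Illustrative sketch of the enhance-then-paginate logic in query_models_from_db.
def enhance_and_paginate(all_models, page, page_size):
    enhanced = []
    for model in all_models:
        enhanced.append(model)                      # the original versioned record
        try:
            artifacts = json.loads(model.get('artifacts', '{}'))
        except (json.JSONDecodeError, TypeError):
            continue
        if artifacts.get('best_model'):
            best = dict(model)                      # derive a synthetic 'best' record
            best['version'] = f"best({model.get('version', 'N/A')})"
            best['model_uid'] = f"{model['model_uid']}_best"
            best['model_type'] = f"{model['model_type']}(best)"
            enhanced.append(best)
    total = len(enhanced)
    offset = (page - 1) * page_size
    return enhanced[offset:offset + page_size], total

models = [{'model_uid': 'a1', 'version': 'v2', 'model_type': 'kan',
           'artifacts': json.dumps({'best_model': 'a1_best.pth'})}]
page_items, total = enhance_and_paginate(models, page=1, page_size=12)
print(total)  # 2: the v2 record plus its synthetic best(v2) record

Counting the total after enhancement is what lets the 'best' entries show up in the paginated model list the prediction page reads from.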