修改预测界面不显示best版本和不能预测
This commit is contained in:
parent
d5ec662070
commit
290e402181
@ -116,7 +116,7 @@ const filters = reactive({
|
|||||||
|
|
||||||
const pagination = reactive({
|
const pagination = reactive({
|
||||||
currentPage: 1,
|
currentPage: 1,
|
||||||
pageSize: 8
|
pageSize: 12
|
||||||
})
|
})
|
||||||
|
|
||||||
const filteredModelList = computed(() => {
|
const filteredModelList = computed(() => {
|
||||||
|
@ -124,7 +124,7 @@ const filters = reactive({
|
|||||||
|
|
||||||
const pagination = reactive({
|
const pagination = reactive({
|
||||||
currentPage: 1,
|
currentPage: 1,
|
||||||
pageSize: 8
|
pageSize: 12
|
||||||
})
|
})
|
||||||
|
|
||||||
const filteredModelList = computed(() => {
|
const filteredModelList = computed(() => {
|
||||||
|
@ -126,7 +126,7 @@ const filters = reactive({
|
|||||||
|
|
||||||
const pagination = reactive({
|
const pagination = reactive({
|
||||||
currentPage: 1,
|
currentPage: 1,
|
||||||
pageSize: 8
|
pageSize: 12
|
||||||
})
|
})
|
||||||
|
|
||||||
const filteredModelList = computed(() => {
|
const filteredModelList = computed(() => {
|
||||||
|
Binary file not shown.
@ -1329,10 +1329,17 @@ def predict():
|
|||||||
if not model_uid:
|
if not model_uid:
|
||||||
return jsonify({"status": "error", "message": "缺少 'model_uid' 参数"}), 400
|
return jsonify({"status": "error", "message": "缺少 'model_uid' 参数"}), 400
|
||||||
|
|
||||||
|
# 智能处理带 '_best' 后缀的UID,以解决前端key冲突问题 (兼容旧版Python)
|
||||||
|
is_best_request = model_uid.endswith('_best')
|
||||||
|
if is_best_request:
|
||||||
|
db_query_uid = model_uid[:-len('_best')]
|
||||||
|
else:
|
||||||
|
db_query_uid = model_uid
|
||||||
|
|
||||||
# 从数据库查找模型记录
|
# 从数据库查找模型记录
|
||||||
model_record = find_model_by_uid(model_uid)
|
model_record = find_model_by_uid(db_query_uid)
|
||||||
if not model_record:
|
if not model_record:
|
||||||
return jsonify({"status": "error", "message": f"模型UID '{model_uid}' 不存在"}), 404
|
return jsonify({"status": "error", "message": f"找不到基础模型UID '{db_query_uid}'"}), 404
|
||||||
|
|
||||||
# 解析必要的模型元数据
|
# 解析必要的模型元数据
|
||||||
model_type = model_record.get('model_type')
|
model_type = model_record.get('model_type')
|
||||||
@ -1341,7 +1348,17 @@ def predict():
|
|||||||
|
|
||||||
# 解析 artifacts 找到模型文件路径
|
# 解析 artifacts 找到模型文件路径
|
||||||
artifacts = json.loads(model_record.get('artifacts', '{}'))
|
artifacts = json.loads(model_record.get('artifacts', '{}'))
|
||||||
model_file_path = artifacts.get('best_model') or artifacts.get('versioned_model')
|
|
||||||
|
# 根据请求类型选择正确的模型路径
|
||||||
|
if is_best_request:
|
||||||
|
model_file_path = artifacts.get('best_model')
|
||||||
|
else:
|
||||||
|
# 对于 v1, v2 等版本,使用 versioned_model
|
||||||
|
model_file_path = artifacts.get('versioned_model')
|
||||||
|
|
||||||
|
# 如果特定版本路径不存在,提供一个后备方案,增加鲁棒性
|
||||||
|
if not model_file_path:
|
||||||
|
model_file_path = artifacts.get('best_model') or artifacts.get('versioned_model')
|
||||||
|
|
||||||
# 修正路径问题:将相对路径转换为绝对路径以进行可靠的文件检查
|
# 修正路径问题:将相对路径转换为绝对路径以进行可靠的文件检查
|
||||||
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||||
@ -1412,10 +1429,13 @@ def predict():
|
|||||||
analysis_result = prediction_result.get('analysis', {}) if prediction_result else {}
|
analysis_result = prediction_result.get('analysis', {}) if prediction_result else {}
|
||||||
metrics_result = analysis_result.get('metrics', {}) if analysis_result else {}
|
metrics_result = analysis_result.get('metrics', {}) if analysis_result else {}
|
||||||
|
|
||||||
|
# 准备要保存到历史记录的model_type,确保与前端显示一致
|
||||||
|
history_model_type = f"{model_type}(best)" if is_best_request else model_type
|
||||||
|
|
||||||
db_payload = {
|
db_payload = {
|
||||||
"prediction_uid": prediction_uid,
|
"prediction_uid": prediction_uid,
|
||||||
"model_id": model_uid,
|
"model_id": model_uid,
|
||||||
"model_type": model_type,
|
"model_type": history_model_type, # 使用处理后的一致性名称
|
||||||
"product_name": product_name_to_save, # 使用修正后的名称
|
"product_name": product_name_to_save, # 使用修正后的名称
|
||||||
"model_version": model_record.get('display_name'), # 将模型信息保存到新字段
|
"model_version": model_record.get('display_name'), # 将模型信息保存到新字段
|
||||||
"prediction_scope": {"product_id": product_id, "store_id": store_id},
|
"prediction_scope": {"product_id": product_id, "store_id": store_id},
|
||||||
|
@ -86,6 +86,11 @@ register_predictor('default', default_pytorch_predictor)
|
|||||||
register_predictor('xgboost', default_pytorch_predictor)
|
register_predictor('xgboost', default_pytorch_predictor)
|
||||||
# 将新模型也注册给默认预测器
|
# 将新模型也注册给默认预测器
|
||||||
register_predictor('cnn_bilstm_attention', default_pytorch_predictor)
|
register_predictor('cnn_bilstm_attention', default_pytorch_predictor)
|
||||||
|
register_predictor('transformer', default_pytorch_predictor)
|
||||||
|
register_predictor('mlstm', default_pytorch_predictor)
|
||||||
|
register_predictor('kan', default_pytorch_predictor)
|
||||||
|
register_predictor('optimized_kan', default_pytorch_predictor)
|
||||||
|
register_predictor('tcn', default_pytorch_predictor)
|
||||||
|
|
||||||
|
|
||||||
def load_model_and_predict(model_path: str, product_id: str, model_type: str, store_id: Optional[str] = None, future_days: int = 7, start_date: Optional[str] = None, analyze_result: bool = False, version: Optional[str] = None, training_mode: str = 'product', history_lookback_days: int = 30):
|
def load_model_and_predict(model_path: str, product_id: str, model_type: str, store_id: Optional[str] = None, future_days: int = 7, start_date: Optional[str] = None, analyze_result: bool = False, version: Optional[str] = None, training_mode: str = 'product', history_lookback_days: int = 30):
|
||||||
|
@ -84,26 +84,48 @@ def query_models_from_db(filters: dict, page: int = 1, page_size: int = 10):
|
|||||||
if conditions:
|
if conditions:
|
||||||
query += " WHERE " + " AND ".join(conditions)
|
query += " WHERE " + " AND ".join(conditions)
|
||||||
|
|
||||||
# 获取总数
|
# 修复分页逻辑:先查询所有符合条件的模型,再在内存中进行增强和分页
|
||||||
count_query = query.replace("*", "COUNT(*)")
|
query += " ORDER BY created_at DESC"
|
||||||
cursor.execute(count_query, params)
|
|
||||||
total_count = cursor.fetchone()[0]
|
|
||||||
|
|
||||||
# 添加分页
|
|
||||||
offset = (page - 1) * page_size
|
|
||||||
query += " ORDER BY created_at DESC LIMIT ? OFFSET ?"
|
|
||||||
params.extend([page_size, offset])
|
|
||||||
|
|
||||||
cursor.execute(query, params)
|
cursor.execute(query, params)
|
||||||
models = [dict(row) for row in cursor.fetchall()]
|
all_models = [dict(row) for row in cursor.fetchall()]
|
||||||
|
|
||||||
|
# 增强逻辑:如果模型同时有best版本,则额外生成一条记录
|
||||||
|
enhanced_models = []
|
||||||
|
for model in all_models:
|
||||||
|
# 首先添加原始的版本记录
|
||||||
|
enhanced_models.append(model)
|
||||||
|
|
||||||
|
try:
|
||||||
|
artifacts = json.loads(model.get('artifacts', '{}'))
|
||||||
|
if artifacts and artifacts.get('best_model'):
|
||||||
|
original_version = model.get('version', 'N/A')
|
||||||
|
best_model_record = model.copy()
|
||||||
|
best_model_record['version'] = f"best({original_version})"
|
||||||
|
best_model_record['model_uid'] = f"{model['model_uid']}_best"
|
||||||
|
best_model_record['model_type'] = f"{model['model_type']}(best)"
|
||||||
|
|
||||||
|
new_artifacts = {
|
||||||
|
'best_model': artifacts.get('best_model'),
|
||||||
|
'scaler_X': artifacts.get('scaler_X'),
|
||||||
|
'scaler_y': artifacts.get('scaler_y')
|
||||||
|
}
|
||||||
|
best_model_record['artifacts'] = json.dumps(new_artifacts)
|
||||||
|
enhanced_models.append(best_model_record)
|
||||||
|
except (json.JSONDecodeError, TypeError):
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 在内存中进行分页
|
||||||
|
total_count = len(enhanced_models)
|
||||||
|
offset = (page - 1) * page_size
|
||||||
|
paginated_models = enhanced_models[offset : offset + page_size]
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"models": models,
|
"models": paginated_models,
|
||||||
"pagination": {
|
"pagination": {
|
||||||
"total": total_count,
|
"total": total_count,
|
||||||
"page": page,
|
"page": page,
|
||||||
"page_size": page_size,
|
"page_size": page_size,
|
||||||
"total_pages": (total_count + page_size - 1) // page_size
|
"total_pages": (total_count + page_size - 1) // page_size if page_size > 0 else 0
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user