备份,模型管理待完善
This commit is contained in:
parent
88f245b957
commit
af7638aeca
@ -17,7 +17,22 @@
|
||||
|
||||
<el-form :inline="true" @submit.prevent="fetchModels">
|
||||
<el-form-item label="产品ID" style="width:300px">
|
||||
<el-input v-model="filters.product_id" placeholder="按产品ID筛选" clearable></el-input>
|
||||
<el-select
|
||||
v-model="filters.product_id"
|
||||
placeholder="按产品/店铺/全局筛选"
|
||||
clearable
|
||||
filterable
|
||||
remote
|
||||
:remote-method="searchOptions"
|
||||
:loading="searchLoading"
|
||||
>
|
||||
<el-option
|
||||
v-for="item in searchResults"
|
||||
:key="item.value"
|
||||
:label="item.label"
|
||||
:value="item.value"
|
||||
/>
|
||||
</el-select>
|
||||
</el-form-item>
|
||||
<el-form-item label="模型类型" style="width:300px">
|
||||
<el-select v-model="filters.model_type" placeholder="按模型类型筛选" clearable>
|
||||
@ -127,7 +142,7 @@
|
||||
{{ selectedModelDetails.model_info.model_type }}
|
||||
</el-tag>
|
||||
</el-descriptions-item>
|
||||
<el-descriptions-item label="产品">{{ selectedModelDetails.model_info.product_name }}</el-descriptions-item>
|
||||
<el-descriptions-item :label="selectedModelDetails.model_info.scopeLabel">{{ selectedModelDetails.model_info.scopeName }}</el-descriptions-item>
|
||||
<el-descriptions-item label="创建时间">{{ formatDateTime(selectedModelDetails.model_info.created_at) }}</el-descriptions-item>
|
||||
</el-descriptions>
|
||||
|
||||
@ -207,6 +222,27 @@ const models = ref([])
|
||||
const modelTypes = ref([])
|
||||
const loading = ref(true)
|
||||
const filters = reactive({ product_id: '', model_type: '' })
|
||||
const searchLoading = ref(false)
|
||||
const searchResults = ref([])
|
||||
|
||||
const searchOptions = async (query) => {
|
||||
if (query) {
|
||||
searchLoading.value = true
|
||||
try {
|
||||
const response = await axios.get('/api/management/filter-options', { params: { query } })
|
||||
if (response.data.status === 'success') {
|
||||
searchResults.value = response.data.data
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('获取筛选选项失败:', error)
|
||||
searchResults.value = []
|
||||
} finally {
|
||||
searchLoading.value = false
|
||||
}
|
||||
} else {
|
||||
searchResults.value = []
|
||||
}
|
||||
}
|
||||
|
||||
// 分页相关
|
||||
const pagination = reactive({
|
||||
@ -290,29 +326,43 @@ const handlePageSizeChange = (pageSize) => {
|
||||
|
||||
const viewDetails = async (model) => {
|
||||
detailsDialogVisible.value = true;
|
||||
selectedModelDetails.value = null; // 重置
|
||||
selectedModelDetails.value = null;
|
||||
try {
|
||||
// 新逻辑: 直接使用行数据,因为列表API已返回足够信息
|
||||
const details = {
|
||||
model_info: {
|
||||
model_id: model.model_uid,
|
||||
model_type: model.model_type,
|
||||
product_name: model.display_name,
|
||||
created_at: model.created_at,
|
||||
},
|
||||
training_metrics: model.performance_metrics,
|
||||
chart_data: {
|
||||
loss_chart: model.artifacts?.loss_curve_data || { epochs: [], train_loss: [], test_loss: [] }
|
||||
}
|
||||
};
|
||||
selectedModelDetails.value = details;
|
||||
nextTick(() => {
|
||||
initLossChart();
|
||||
});
|
||||
const response = await axios.get(`/api/models/${model.model_uid}`);
|
||||
if (response.data.status === 'success') {
|
||||
const details = response.data.data;
|
||||
details.training_metrics = normalizeMetricsKeys(details.training_metrics);
|
||||
selectedModelDetails.value = details;
|
||||
|
||||
// 准备用于显示的数据结构
|
||||
// 准备用于显示的数据结构
|
||||
const model_info = {
|
||||
model_id: details.model_uid,
|
||||
model_type: details.model_type,
|
||||
created_at: details.created_at,
|
||||
scopeLabel: '范围', // 默认标签
|
||||
scopeName: details.display_name // 默认名称
|
||||
};
|
||||
|
||||
if (details.training_mode === 'product') {
|
||||
model_info.scopeLabel = '产品';
|
||||
model_info.scopeName = details.product_name || details.display_name;
|
||||
} else if (details.training_mode === 'store') {
|
||||
model_info.scopeLabel = '店铺';
|
||||
model_info.scopeName = details.store_name || details.display_name;
|
||||
} else if (details.training_mode === 'global') {
|
||||
model_info.scopeLabel = '全局模型';
|
||||
model_info.scopeName = details.display_name;
|
||||
}
|
||||
|
||||
const formattedDetails = {
|
||||
model_info,
|
||||
training_metrics: normalizeMetricsKeys(details.performance_metrics),
|
||||
chart_data: {
|
||||
loss_chart: details.artifacts?.loss_curve_data || { epochs: [], train_loss: [], test_loss: [] }
|
||||
}
|
||||
};
|
||||
|
||||
selectedModelDetails.value = formattedDetails;
|
||||
|
||||
nextTick(() => {
|
||||
initLossChart();
|
||||
});
|
||||
|
82
data/old_5shops_50skus-据结构字典.md
Normal file
82
data/old_5shops_50skus-据结构字典.md
Normal file
@ -0,0 +1,82 @@
|
||||
|
||||
| 分类 | 字段名 | 数据类型 | 描述 | 来源 |
|
||||
| --- | --- | --- | --- | --- |
|
||||
| **标识符** | `subbh` | String | 店铺唯一标识 | 骨架 |
|
||||
| | `hh` | String | 商品唯一标识 | 骨架 |
|
||||
| | `kdrq` | Date | 开单日期 (主键之一) | 骨架 |
|
||||
| **核心指标** | `sales_quantity` | Float | 当日销售量 (无销售则为0) | 当日销售 |
|
||||
| | `return_quantity` | Float | 当日退货量 (无销售则为0) | 当日销售 |
|
||||
| | `net_sales_quantity` | Float | **当日净销售量 (目标变量)** | 当日销售 |
|
||||
| | `gross_profit_total` | Float | 当日毛利 (无销售则为0) | 当日销售 |
|
||||
| | `transaction_count` | Integer | 当日交易次数 (无销售则为0) | 当日销售 |
|
||||
| **日期特征** | `date` | Date | 日期 (冗余字段) | 时序计算 |
|
||||
| | `is_weekend` | Boolean | 是否为周末 (True/False) | 时序计算 |
|
||||
| | `day_of_week` | Integer | 一周中的第几天 (0=周一, 6=周日) | 时序计算 |
|
||||
| | `day_of_month` | Integer | 一月中的第几天 (1-31) | 时序计算 |
|
||||
| | `day_of_year` | Integer | 一年中的第几天 (1-366) | 时序计算 |
|
||||
| | `week_of_month` | Integer | 当月第几周 (1-5) | 时序计算 |
|
||||
| | `month` | Integer | 月份 (1-12) | 时序计算 |
|
||||
| | `quarter` | Integer | 季度 (1-4) | 时序计算 |
|
||||
| | `is_holiday` | Boolean | 是否为节假日 (True/False) | 时序计算 |
|
||||
| **生命周期特征** | `first_sale_date` | Date | SKU在店首次销售日期 | 生命周期 |
|
||||
| | `last_sale_date` | Date | SKU在店末次销售日期 | 生命周期 |
|
||||
| | `lifecycle_days` | Integer | SKU在店生命周期总天数 | 生命周期 |
|
||||
| | `sample_category` | String | 生命周期分类 (new/medium/old) | 生命周期 |
|
||||
| | `rolling_7d_valid` | Boolean | 7日滚动窗口是否有效 (距离首次销售>=7天) | 生命周期 |
|
||||
| | `rolling_15d_valid` | Boolean | 15日滚动窗口是否有效 | 生命周期 |
|
||||
| | `rolling_30d_valid` | Boolean | 30日滚动窗口是否有效 | 生命周期 |
|
||||
| | `rolling_90d_valid` | Boolean | 90日滚动窗口是否有效 | 生命周期 |
|
||||
| **滚动特征 (7天)** | `sales_quantity_rolling_mean_7d` | Float | 过去7日平均销售量 | 历史滚动 |
|
||||
| | `return_quantity_rolling_mean_7d` | Float | 过去7日平均退货量 | 历史滚动 |
|
||||
| | `net_sales_quantity_rolling_mean_7d`| Float | 过去7日平均净销量 | 历史滚动 |
|
||||
| | `sales_quantity_rolling_sum_7d` | Float | 过去7日总销售量 | 历史滚动 |
|
||||
| | `return_quantity_rolling_sum_7d` | Float | 过去7日总退货量 | 历史滚动 |
|
||||
| | `net_sales_quantity_rolling_sum_7d` | Float | 过去7日总净销量 | 历史滚动 |
|
||||
| **滚动特征 (15天)** | `sales_quantity_rolling_mean_15d` | Float | 过去15日平均销售量 | 历史滚动 |
|
||||
| | `return_quantity_rolling_mean_15d` | Float | 过去15日平均退货量 | 历史滚动 |
|
||||
| | `net_sales_quantity_rolling_mean_15d`| Float | 过去15日平均净销量 | 历史滚动 |
|
||||
| | `sales_quantity_rolling_sum_15d` | Float | 过去15日总销售量 | 历史滚动 |
|
||||
| | `return_quantity_rolling_sum_15d` | Float | 过去15日总退货量 | 历史滚动 |
|
||||
| | `net_sales_quantity_rolling_sum_15d` | Float | 过去15日总净销量 | 历史滚动 |
|
||||
| **滚动特征 (30天)** | `sales_quantity_rolling_mean_30d` | Float | 过去30日平均销售量 | 历史滚动 |
|
||||
| | `return_quantity_rolling_mean_30d` | Float | 过去30日平均退货量 | 历史滚动 |
|
||||
| | `net_sales_quantity_rolling_mean_30d`| Float | 过去30日平均净销量 | 历史滚动 |
|
||||
| | `sales_quantity_rolling_sum_30d` | Float | 过去30日总销售量 | 历史滚动 |
|
||||
| | `return_quantity_rolling_sum_30d` | Float | 过去30日总退货量 | 历史滚动 |
|
||||
| | `net_sales_quantity_rolling_sum_30d` | Float | 过去30日总净销量 | 历史滚动 |
|
||||
| **滚动特征 (90天)** | `sales_quantity_rolling_mean_90d` | Float | 过去90日平均销售量 | 历史滚动 |
|
||||
| | `return_quantity_rolling_mean_90d` | Float | 过去90日平均退货量 | 历史滚动 |
|
||||
| | `net_sales_quantity_rolling_mean_90d`| Float | 过去90日平均净销量 | 历史滚动 |
|
||||
| | `sales_quantity_rolling_sum_90d` | Float | 过去90日总销售量 | 历史滚动 |
|
||||
| | `return_quantity_rolling_sum_90d` | Float | 过去90日总退货量 | 历史滚动 |
|
||||
| | `net_sales_quantity_rolling_sum_90d` | Float | 过去90日总净销量 | 历史滚动 |
|
||||
| **滚动特征 (180天)** | `sales_quantity_rolling_mean_180d` | Float | 过去180日平均销售量 | 历史滚动 |
|
||||
| | `return_quantity_rolling_mean_180d` | Float | 过去180日平均退货量 | 历史滚动 |
|
||||
| | `net_sales_quantity_rolling_mean_180d`| Float | 过去180日平均净销量 | 历史滚动 |
|
||||
| | `sales_quantity_rolling_sum_180d` | Float | 过去180日总销售量 | 历史滚动 |
|
||||
| | `return_quantity_rolling_sum_180d` | Float | 过去180日总退货量 | 历史滚动 |
|
||||
| | `net_sales_quantity_rolling_sum_180d` | Float | 过去180日总净销量 | 历史滚动 |
|
||||
| **滚动特征 (365天)** | `sales_quantity_rolling_mean_365d` | Float | 过去365日平均销售量 | 历史滚动 |
|
||||
| | `return_quantity_rolling_mean_365d` | Float | 过去365日平均退货量 | 历史滚动 |
|
||||
| | `net_sales_quantity_rolling_mean_365d`| Float | 过去365日平均净销量 | 历史滚动 |
|
||||
| | `sales_quantity_rolling_sum_365d` | Float | 过去365日总销售量 | 历史滚动 |
|
||||
| | `return_quantity_rolling_sum_365d` | Float | 过去365日总退货量 | 历史滚动 |
|
||||
| | `net_sales_quantity_rolling_sum_365d` | Float | 过去365日总净销量 | 历史滚动 |
|
||||
| **店铺特征** | `province` | String | 店铺所在省份 | 店铺特征 |
|
||||
| | `city` | String | 店铺所在城市 | 店铺特征 |
|
||||
| | `district` | String | 店铺所在行政区 | 店铺特征 |
|
||||
| | `poi_residential_count` | Integer | 周边住宅区POI数量 | 店铺特征 |
|
||||
| | `poi_school_count` | Integer | 周边学校POI数量 | 店铺特征 |
|
||||
| | `poi_mall_count` | Integer | 周边购物中心POI数量 | 店铺特征 |
|
||||
| | `temperature_2m_max` | Float | 当日最高气温 | 店铺特征 |
|
||||
| | `temperature_2m_min` | Float | 当日最低气温 | 店铺特征 |
|
||||
| | `temperature_2m_mean`| Float | 当日平均气温 | 店铺特征 |
|
||||
| **商品特征** | `零售大类代码_encoded` | Integer | 零售大类代码的数字编码 | 商品特征 |
|
||||
| | `零售中类代码_encoded` | Integer | 零售中类代码的数字编码 | 商品特征 |
|
||||
| | `零售小类代码_encoded` | Integer | 零售小类代码的数字编码 | 商品特征 |
|
||||
| | `商品ABC分类_encoded` | Integer | 商品ABC分类的数字编码 | 商品特征 |
|
||||
| | `商品手册代码_encoded` | Integer | 商品手册代码的数字编码 | 商品特征 |
|
||||
| | `产地_encoded` | Integer | 产地的数字编码 | 商品特征 |
|
||||
| | `brand_encoded` | Integer | 品牌的数字编码 | 商品特征 |
|
||||
| | `packaging_quantity` | Float | 包装数量 (从规格中提取) | 商品特征 |
|
||||
| | `approval_type_encoded` | Integer | 批准文号类型的数字编码 | 商品特征 |
|
BIN
data/old_5shops_50skus.parquet
Normal file
BIN
data/old_5shops_50skus.parquet
Normal file
Binary file not shown.
Binary file not shown.
258
server/api.py
258
server/api.py
@ -1993,34 +1993,54 @@ def delete_prediction(prediction_id):
|
||||
def get_history_filter_options():
|
||||
"""获取历史记录页面用于筛选的选项列表"""
|
||||
try:
|
||||
# 1. 获取所有标准产品
|
||||
products = get_available_products()
|
||||
conn = get_db_connection()
|
||||
cursor = conn.cursor()
|
||||
|
||||
# 2. 获取所有店铺
|
||||
stores = get_available_stores()
|
||||
# 从 prediction_history 表中查询所有唯一的 product_name
|
||||
cursor.execute("SELECT DISTINCT product_name FROM prediction_history WHERE product_name IS NOT NULL")
|
||||
records = cursor.fetchall()
|
||||
conn.close()
|
||||
|
||||
# 3. 从历史记录中获取特殊的预测名称(如全局预测)
|
||||
options_map = {}
|
||||
# 构建选项列表
|
||||
options = [{'value': row['product_name'], 'label': row['product_name']} for row in records]
|
||||
|
||||
# 添加产品
|
||||
for p in products:
|
||||
options_map[p['product_id']] = {'value': p['product_id'], 'label': p['product_name'], 'type': 'product'}
|
||||
|
||||
# 添加店铺
|
||||
for s in stores:
|
||||
options_map[s['store_id']] = {'value': s['store_id'], 'label': s['store_name'], 'type': 'store'}
|
||||
|
||||
# v14 修复: 硬编码添加“全局预测”选项
|
||||
global_prediction_key = "全局预测"
|
||||
if global_prediction_key not in options_map:
|
||||
options_map[global_prediction_key] = {'value': global_prediction_key, 'label': global_prediction_key, 'type': 'special'}
|
||||
|
||||
return jsonify({"status": "success", "data": list(options_map.values())})
|
||||
return jsonify({"status": "success", "data": options})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取历史筛选选项失败: {e}\n{traceback.format_exc()}")
|
||||
return jsonify({"status": "error", "message": str(e)}), 500
|
||||
|
||||
@app.route('/api/management/filter-options', methods=['GET'])
|
||||
def get_management_filter_options():
|
||||
"""获取模型管理页面筛选框的动态选项"""
|
||||
try:
|
||||
query = request.args.get('query', '').lower()
|
||||
|
||||
options = []
|
||||
|
||||
# 添加全局预测选项
|
||||
global_option = {'value': 'global', 'label': '全局预测'}
|
||||
if not query or '全局' in query or 'global' in query:
|
||||
options.append(global_option)
|
||||
|
||||
# 获取并筛选产品
|
||||
products = get_available_products()
|
||||
for p in products:
|
||||
label = f"{p.get('product_name', '未知产品')} ({p['product_id']})"
|
||||
if not query or query in label.lower():
|
||||
options.append({'value': p['product_id'], 'label': label})
|
||||
|
||||
# 获取并筛选店铺
|
||||
stores = get_available_stores()
|
||||
for s in stores:
|
||||
label = f"{s.get('store_name', '未知店铺')} ({s['store_id']})"
|
||||
if not query or query in label.lower():
|
||||
options.append({'value': s['store_id'], 'label': label})
|
||||
|
||||
return jsonify({"status": "success", "data": options})
|
||||
except Exception as e:
|
||||
logger.error(f"获取管理筛选选项失败: {e}\n{traceback.format_exc()}")
|
||||
return jsonify({"status": "error", "message": str(e)}), 500
|
||||
# 4. 模型管理API
|
||||
@app.route('/api/models', methods=['GET'])
|
||||
@swag_from({
|
||||
@ -2227,100 +2247,60 @@ def list_models():
|
||||
})
|
||||
def get_model_details(model_id):
|
||||
"""
|
||||
获取单个模型的详细信息
|
||||
---
|
||||
tags:
|
||||
- 模型管理
|
||||
parameters:
|
||||
- name: model_id
|
||||
in: path
|
||||
type: string
|
||||
required: true
|
||||
description: "模型的唯一标识符 (格式: model_type_product_id)"
|
||||
responses:
|
||||
200:
|
||||
description: 模型的详细信息
|
||||
404:
|
||||
description: 未找到模型
|
||||
获取单个模型的详细信息 (v3 - 统一数据结构)
|
||||
"""
|
||||
try:
|
||||
model_type, product_id = model_id.split('_', 1)
|
||||
# 智能处理带 '_best' 后缀的UID
|
||||
db_query_uid = model_id[:-len('_best')] if model_id.endswith('_best') else model_id
|
||||
|
||||
# 处理优化版KAN模型的文件名
|
||||
file_model_type = model_type
|
||||
if model_type == 'optimized_kan':
|
||||
file_model_type = 'kan_optimized'
|
||||
model_record = find_model_by_uid(db_query_uid)
|
||||
|
||||
# 首先尝试从app配置中获取模型目录
|
||||
models_dir = app.config.get('MODEL_DIR', DEFAULT_MODEL_DIR)
|
||||
if not model_record:
|
||||
return jsonify({"status": "error", "message": "模型未找到"}), 404
|
||||
|
||||
# 检查models_dir是否存在,如果不存在,使用DEFAULT_MODEL_DIR作为后备
|
||||
if not os.path.exists(models_dir) and os.path.exists(DEFAULT_MODEL_DIR):
|
||||
print(f"警告: 配置的模型目录 '{models_dir}' 不存在,使用默认目录 '{DEFAULT_MODEL_DIR}'")
|
||||
models_dir = DEFAULT_MODEL_DIR
|
||||
# 将 sqlite3.Row 转换为可修改的字典
|
||||
model_data = dict(model_record)
|
||||
|
||||
# 尝试多种可能的文件名格式
|
||||
possible_patterns = [
|
||||
f'{file_model_type}_product_{product_id}_v1.pth', # 新格式
|
||||
f'{file_model_type}_model_product_{product_id}.pth', # 旧格式
|
||||
f'{file_model_type}_{product_id}_v1.pth', # 备用格式
|
||||
]
|
||||
# 解析JSON字段
|
||||
model_data['training_scope'] = json.loads(model_data.get('training_scope', '{}'))
|
||||
model_data['performance_metrics'] = json.loads(model_data.get('performance_metrics', '{}'))
|
||||
model_data['artifacts'] = json.loads(model_data.get('artifacts', '{}'))
|
||||
|
||||
model_path = None
|
||||
for pattern in possible_patterns:
|
||||
test_path = os.path.join(models_dir, pattern)
|
||||
if os.path.exists(test_path):
|
||||
model_path = test_path
|
||||
print(f"找到模型文件: {pattern}")
|
||||
break
|
||||
# 统一化修复:完全复制 list_models 的名称处理逻辑,确保数据结构一致
|
||||
scope = model_data.get('training_scope', {})
|
||||
mode = model_data.get('training_mode')
|
||||
|
||||
if not model_path:
|
||||
print(f"未找到模型文件,尝试的路径:")
|
||||
for pattern in possible_patterns:
|
||||
test_path = os.path.join(models_dir, pattern)
|
||||
print(f" - {test_path}")
|
||||
return jsonify({"status": "error", "error": "模型未找到"}), 404
|
||||
# 初始化所有可能的名称字段
|
||||
model_data['product_name'] = None
|
||||
model_data['store_name'] = None
|
||||
# 优先使用数据库中的 display_name
|
||||
display_name = model_data.get('display_name')
|
||||
|
||||
# 加载模型文件
|
||||
try:
|
||||
# 添加weights_only=False参数,解决PyTorch 2.6序列化问题
|
||||
checkpoint = torch.load(model_path, map_location='cpu', weights_only=False)
|
||||
if isinstance(scope, dict):
|
||||
product_info = scope.get('product')
|
||||
store_info = scope.get('store')
|
||||
|
||||
# 提取模型信息
|
||||
model_info = {
|
||||
"model_id": model_id,
|
||||
"product_id": product_id,
|
||||
"model_type": model_type,
|
||||
"created_at": datetime.fromtimestamp(os.path.getctime(model_path)).isoformat(),
|
||||
"file_path": model_path,
|
||||
"file_size": f"{os.path.getsize(model_path) / (1024 * 1024):.2f} MB"
|
||||
}
|
||||
if mode == 'product' and isinstance(product_info, dict):
|
||||
model_data['product_name'] = product_info.get('name')
|
||||
if not display_name: display_name = model_data['product_name']
|
||||
|
||||
# 如果checkpoint是字典,提取其中的信息
|
||||
if isinstance(checkpoint, dict):
|
||||
# 提取配置信息
|
||||
if 'config' in checkpoint:
|
||||
config = checkpoint['config']
|
||||
for key, value in config.items():
|
||||
model_info[key] = value
|
||||
elif mode == 'store' and isinstance(store_info, dict):
|
||||
model_data['store_name'] = store_info.get('name')
|
||||
if not display_name: display_name = model_data['store_name']
|
||||
|
||||
# 提取评估指标
|
||||
if 'metrics' in checkpoint:
|
||||
model_info['metrics'] = checkpoint['metrics']
|
||||
elif mode == 'global':
|
||||
if not display_name: display_name = "全局模型"
|
||||
|
||||
# 获取产品名称
|
||||
product_name = get_product_name(product_id)
|
||||
if product_name:
|
||||
model_info['product_name'] = product_name
|
||||
# 提供最终的后备方案
|
||||
if not display_name:
|
||||
display_name = "信息不完整"
|
||||
|
||||
return jsonify({"status": "success", "data": model_info})
|
||||
except Exception as e:
|
||||
print(f"加载模型文件失败: {str(e)}")
|
||||
return jsonify({"status": "error", "error": f"加载模型文件失败: {str(e)}"}), 500
|
||||
except ValueError:
|
||||
return jsonify({"status": "error", "error": "无效的model_id格式"}), 400
|
||||
model_data['display_name'] = display_name
|
||||
|
||||
return jsonify({"status": "success", "data": model_data})
|
||||
except Exception as e:
|
||||
return jsonify({"status": "error", "error": f"获取模型详情失败: {e}"}), 500
|
||||
logger.error(f"获取模型详情失败: {e}\n{traceback.format_exc()}")
|
||||
return jsonify({"status": "error", "message": str(e)}), 500
|
||||
|
||||
@app.route('/api/models/<model_id>', methods=['DELETE'])
|
||||
@swag_from({
|
||||
@ -2357,68 +2337,46 @@ def get_model_details(model_id):
|
||||
})
|
||||
def delete_model(model_id):
|
||||
"""
|
||||
删除一个模型及其关联文件
|
||||
---
|
||||
tags:
|
||||
- 模型管理
|
||||
parameters:
|
||||
- name: model_id
|
||||
in: path
|
||||
type: string
|
||||
required: true
|
||||
description: "要删除的模型的ID (格式: model_type_product_id)"
|
||||
responses:
|
||||
200:
|
||||
description: 模型删除成功
|
||||
404:
|
||||
description: 模型未找到
|
||||
删除一个模型及其关联文件 (v2 - 基于数据库)
|
||||
"""
|
||||
try:
|
||||
model_type, product_id = model_id.split('_', 1)
|
||||
# 智能处理带 '_best' 后缀的UID
|
||||
db_query_uid = model_id[:-len('_best')] if model_id.endswith('_best') else model_id
|
||||
|
||||
# 处理优化版KAN模型的文件名
|
||||
file_model_type = model_type
|
||||
if model_type == 'optimized_kan':
|
||||
file_model_type = 'kan_optimized'
|
||||
conn = get_db_connection()
|
||||
cursor = conn.cursor()
|
||||
|
||||
# 首先尝试从app配置中获取模型目录
|
||||
models_dir = app.config.get('MODEL_DIR', DEFAULT_MODEL_DIR)
|
||||
# 查找模型记录
|
||||
cursor.execute("SELECT artifacts FROM models WHERE model_uid = ?", (db_query_uid,))
|
||||
record = cursor.fetchone()
|
||||
|
||||
# 检查models_dir是否存在,如果不存在,使用DEFAULT_MODEL_DIR作为后备
|
||||
if not os.path.exists(models_dir) and os.path.exists(DEFAULT_MODEL_DIR):
|
||||
print(f"警告: 配置的模型目录 '{models_dir}' 不存在,使用默认目录 '{DEFAULT_MODEL_DIR}'")
|
||||
models_dir = DEFAULT_MODEL_DIR
|
||||
if not record:
|
||||
conn.close()
|
||||
return jsonify({"status": "error", "message": "模型未找到"}), 404
|
||||
|
||||
# 尝试多种可能的文件名格式
|
||||
possible_patterns = [
|
||||
f'{file_model_type}_product_{product_id}_v1.pth', # 新格式
|
||||
f'{file_model_type}_model_product_{product_id}.pth', # 旧格式
|
||||
f'{file_model_type}_{product_id}_v1.pth', # 备用格式
|
||||
]
|
||||
# 删除数据库记录
|
||||
cursor.execute("DELETE FROM models WHERE model_uid = ?", (db_query_uid,))
|
||||
conn.commit()
|
||||
|
||||
model_path = None
|
||||
for pattern in possible_patterns:
|
||||
test_path = os.path.join(models_dir, pattern)
|
||||
if os.path.exists(test_path):
|
||||
model_path = test_path
|
||||
print(f"找到模型文件: {pattern}")
|
||||
break
|
||||
# 删除关联的模型文件
|
||||
try:
|
||||
artifacts = json.loads(record['artifacts'])
|
||||
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
if not model_path:
|
||||
print(f"未找到模型文件,尝试的路径:")
|
||||
for pattern in possible_patterns:
|
||||
test_path = os.path.join(models_dir, pattern)
|
||||
print(f" - {test_path}")
|
||||
return jsonify({"status": "error", "error": "模型未找到"}), 404
|
||||
|
||||
# 删除模型文件
|
||||
os.remove(model_path)
|
||||
for key, path in artifacts.items():
|
||||
if path and isinstance(path, str):
|
||||
full_path = os.path.join(project_root, path)
|
||||
if os.path.exists(full_path):
|
||||
os.remove(full_path)
|
||||
logger.info(f"已删除文件: {full_path}")
|
||||
except (json.JSONDecodeError, TypeError, OSError) as e:
|
||||
logger.error(f"删除模型文件失败: {e}")
|
||||
|
||||
conn.close()
|
||||
return jsonify({"status": "success", "message": f"模型 {model_id} 已删除"})
|
||||
except ValueError:
|
||||
return jsonify({"status": "error", "error": "无效的model_id格式"}), 400
|
||||
except Exception as e:
|
||||
return jsonify({"status": "error", "error": f"删除模型失败: {e}"}), 500
|
||||
logger.error(f"删除模型失败: {e}\n{traceback.format_exc()}")
|
||||
return jsonify({"status": "error", "message": str(e)}), 500
|
||||
|
||||
@app.route('/api/models/<model_id>/export', methods=['GET'])
|
||||
@swag_from({
|
||||
|
@ -65,13 +65,21 @@ def query_models_from_db(filters: dict, page: int = 1, page_size: int = 10):
|
||||
conditions = []
|
||||
params = []
|
||||
|
||||
if filters.get('product_id'):
|
||||
conditions.append("json_extract(training_scope, '$.product.id') = ?")
|
||||
params.append(filters['product_id'])
|
||||
product_id_filter = filters.get('product_id')
|
||||
if product_id_filter:
|
||||
if product_id_filter.lower() == 'global':
|
||||
conditions.append("training_mode = ?")
|
||||
params.append('global')
|
||||
elif product_id_filter.startswith('S'):
|
||||
conditions.append("json_extract(training_scope, '$.store.id') = ?")
|
||||
params.append(product_id_filter)
|
||||
else:
|
||||
conditions.append("(json_extract(training_scope, '$.product.id') = ? OR display_name LIKE ?)")
|
||||
params.extend([product_id_filter, f"%{product_id_filter}%"])
|
||||
|
||||
if filters.get('model_type'):
|
||||
conditions.append("model_type = ?")
|
||||
params.append(filters['model_type'])
|
||||
conditions.append("model_type LIKE ?")
|
||||
params.append(f"{filters['model_type']}%")
|
||||
|
||||
if filters.get('store_id'):
|
||||
conditions.append("json_extract(training_scope, '$.store.id') = ?")
|
||||
|
Loading…
x
Reference in New Issue
Block a user