# ShopTRAINING/fix_old_predictions.py
#
# 123 lines
# 4.8 KiB
# Python

import sys
import os
import json
import sqlite3
import traceback
# Append <directory of this script>/server to the import path so that the
# project modules imported below (predictors, init_multi_store_db, utils, api)
# can be found when this file is run directly as a script.
current_dir = os.path.split(os.path.abspath(__file__))[0]
sys.path.append(os.path.join(current_dir, 'server'))
from predictors.model_predictor import load_model_and_predict
from init_multi_store_db import get_db_connection
from utils.model_manager import ModelManager
from api import CustomJSONEncoder
def _read_prediction_length(file_path):
    """Return how many prediction rows the JSON file at ``file_path`` holds.

    Current files keep the rows under ``prediction_data``; older files used
    ``predictions``. The legacy key is consulted only when the current key is
    missing or null, and a file with neither counts as zero rows.

    Raises:
        ValueError: if the top-level JSON value is not an object.
        OSError, json.JSONDecodeError: if the file cannot be read or parsed.
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    if not isinstance(data, dict):
        raise ValueError(
            f"unexpected JSON layout in {file_path}: "
            f"top-level {type(data).__name__}, expected object"
        )
    rows = data.get('prediction_data')
    if rows is None:
        # Legacy format fallback. The previous default-argument form crashed
        # with len(None) when 'prediction_data' was present but null.
        rows = data.get('predictions')
    return len(rows) if rows is not None else 0


def _find_model_path(model_manager, product_id, model_type, version):
    """Return the ``file_path`` of the saved model whose version matches, or None.

    Candidates come from ``model_manager.list_models`` filtered by product and
    model type; the first entry whose ``version`` equals ``version`` wins.
    """
    models_result = model_manager.list_models(
        product_id=product_id,
        model_type=model_type
    )
    models = models_result.get('models') or []
    match = next((m for m in models if m.get('version') == version), None)
    return match.get('file_path') if match else None


def _write_json_atomically(file_path, payload):
    """Replace ``file_path`` with ``payload`` serialized as JSON.

    The JSON is written to a sibling ``<file_path>.tmp`` first and only then
    moved over the target with ``os.replace``. A serialization error therefore
    leaves the original file untouched instead of truncated.
    """
    tmp_path = f"{file_path}.tmp"
    try:
        with open(tmp_path, 'w', encoding='utf-8') as f:
            json.dump(payload, f, ensure_ascii=False, indent=4, cls=CustomJSONEncoder)
        os.replace(tmp_path, file_path)
    except BaseException:
        # Best-effort cleanup of the partial temp file, then propagate the error.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise


def fix_old_prediction_data():
    """Regenerate history prediction files that hold fewer days than recorded.

    For every row of ``prediction_history`` (ordered by id):

    * skip it if its ``file_path`` is empty or does not exist;
    * compare the number of prediction rows in the JSON file with the row's
      ``future_days``;
    * when the file is short, find the saved model whose version is the last
      ``_``-separated part of ``model_id``, re-run ``load_model_and_predict``
      and atomically overwrite the file with the new result.

    Per-record failures are printed and counted; they do not abort the run.
    A summary of the counters is printed at the end, and the DB connection is
    always closed.
    """
    print("开始修复历史预测数据...")
    conn = get_db_connection()
    conn.row_factory = sqlite3.Row  # allow column access by name: record['id']
    cursor = conn.cursor()
    try:
        cursor.execute("SELECT * FROM prediction_history ORDER BY id")
        records = cursor.fetchall()
        if not records:
            print("✅ 数据库中没有历史预测记录,无需修复。")
            return
        print(f"发现 {len(records)} 条历史记录,开始逐一检查...")
        model_manager = ModelManager(os.path.join(current_dir, 'saved_models'))
        fixed_count = 0
        skipped_count = 0   # file present and already complete
        missing_count = 0   # no file on disk to check or repair
        error_count = 0
        for record in records:
            record_id = record['id']
            file_path = record['file_path']
            future_days_db = record['future_days']
            try:
                if not file_path or not os.path.exists(file_path):
                    print(f"跳过记录 {record_id}: 文件不存在 at {file_path}")
                    missing_count += 1
                    continue

                # A NULL or non-numeric future_days cannot be compared with the
                # file length, so report it explicitly instead of via a TypeError.
                try:
                    expected_days = int(future_days_db)
                except (TypeError, ValueError):
                    print(f"❌ 记录 {record_id} 的 future_days 无效: {future_days_db!r}")
                    error_count += 1
                    continue

                file_days = _read_prediction_length(file_path)
                if file_days >= expected_days:
                    skipped_count += 1
                    continue

                print(f"记录 {record_id} 需要修复 (文件: {file_days}天, 数据库: {expected_days}天). 开始重新生成...")

                model_id = record['model_id']
                if not model_id:
                    print(f"❌ 记录 {record_id} 缺少 model_id,无法定位模型。")
                    error_count += 1
                    continue
                # The model version is encoded as the last '_'-separated part of model_id.
                version = model_id.split('_')[-1]

                model_path = _find_model_path(
                    model_manager, record['product_id'], record['model_type'], version
                )
                if not model_path:
                    print(f"❌ 无法找到用于修复的模型文件: product={record['product_id']}, type={record['model_type']}, version={version}")
                    error_count += 1
                    continue

                new_prediction_result = load_model_and_predict(
                    model_path=model_path,
                    product_id=record['product_id'],
                    model_type=record['model_type'],
                    version=version,
                    future_days=expected_days,
                    start_date=record['start_date'],
                    history_lookback_days=30  # the old default lookback window
                )
                if not new_prediction_result:
                    print(f"修复记录 {record_id} 失败: 预测函数返回空结果。")
                    error_count += 1
                    continue

                _write_json_atomically(file_path, new_prediction_result)
                print(f"成功修复并覆盖文件: {file_path}")
                fixed_count += 1
            except Exception as e:
                # Per-record boundary: log with traceback and continue with the next record.
                print(f"处理记录 {record_id} 时发生错误: {e}")
                traceback.print_exc()
                error_count += 1
        print("\n--- 修复完成 ---")
        print(f"总记录数: {len(records)}")
        print(f"已修复: {fixed_count}")
        print(f"已跳过 (无需修复): {skipped_count}")
        print(f"已跳过 (文件不存在): {missing_count}")
        print(f"失败: {error_count}")
    finally:
        conn.close()
if __name__ == '__main__':
    # Only run the repair when executed as a script, not when imported.
    fix_old_prediction_data()