# File-viewer metadata captured with the original paste: 123 lines, 4.8 KiB, Python.
import sys
|
|
import os
|
|
import json
|
|
import sqlite3
|
|
import traceback
|
|
|
|
# Locate this script's own directory and put the sibling ``server`` package
# on ``sys.path`` so the project imports that follow can be resolved.
_this_file = os.path.abspath(__file__)
current_dir = os.path.dirname(_this_file)
_server_pkg_dir = os.path.join(current_dir, 'server')
sys.path.append(_server_pkg_dir)
# Drop the temporaries so the module namespace matches the original.
del _this_file, _server_pkg_dir
from predictors.model_predictor import load_model_and_predict
|
|
from init_multi_store_db import get_db_connection
|
|
from utils.model_manager import ModelManager
|
|
from api import CustomJSONEncoder
|
|
|
|
def _find_model(model_manager, product_id, model_type, version):
    """Return the first model entry from ``model_manager.list_models`` whose
    ``version`` equals *version*, or ``None`` when there is no such entry."""
    models_result = model_manager.list_models(
        product_id=product_id,
        model_type=model_type
    )
    models = models_result.get('models', [])
    return next((m for m in models if m.get('version') == version), None)


def _write_json_atomically(file_path, payload):
    """Serialize *payload* to *file_path* without risking a partial overwrite.

    Opening the target directly with mode ``'w'`` truncates it at once, so a
    serialization or I/O error mid-write would destroy the existing file.
    Instead the JSON is written to a sibling temp file, which then replaces
    the target via ``os.replace``. On failure the temp file is removed and
    the exception is re-raised, leaving the original file untouched.
    """
    tmp_path = f"{file_path}.tmp"
    try:
        with open(tmp_path, 'w', encoding='utf-8') as f:
            json.dump(payload, f, ensure_ascii=False, indent=4, cls=CustomJSONEncoder)
        os.replace(tmp_path, file_path)
    except BaseException:
        # Best-effort cleanup of the half-written temp file; the original
        # error is what the caller needs to see, so it is always re-raised.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise


def fix_old_prediction_data():
    """
    Regenerate history prediction files that were saved truncated.

    Walks every row of ``prediction_history`` (ordered by ``id``). For each one:

    * a missing/empty ``file_path``, or a path that does not exist, is skipped;
    * a file whose series already has at least ``future_days`` entries is
      skipped;
    * otherwise the model matching the record's product, model type and the
      version suffix of ``model_id`` is looked up, a new prediction is produced
      with ``load_model_and_predict`` and written over the old JSON file.

    Per-record failures are printed (with traceback) and counted without
    stopping the run; a summary is printed at the end. The database
    connection is always closed.
    """
    print("开始修复历史预测数据...")

    conn = get_db_connection()
    conn.row_factory = sqlite3.Row
    cursor = conn.cursor()

    try:
        cursor.execute("SELECT * FROM prediction_history ORDER BY id")
        records = cursor.fetchall()

        if not records:
            print("✅ 数据库中没有历史预测记录,无需修复。")
            return

        print(f"发现 {len(records)} 条历史记录,开始逐一检查...")

        model_manager = ModelManager(os.path.join(current_dir, 'saved_models'))
        fixed_count = 0
        skipped_count = 0
        error_count = 0

        for record in records:
            record_id = record['id']
            file_path = record['file_path']
            future_days_db = record['future_days']

            try:
                if not file_path or not os.path.exists(file_path):
                    print(f"跳过记录 {record_id}: 文件不存在 at {file_path}")
                    skipped_count += 1
                    continue

                with open(file_path, 'r', encoding='utf-8') as f:
                    data = json.load(f)

                # Older files store the series under 'predictions' rather than
                # 'prediction_data'; a null value counts as an empty series.
                prediction_data = data.get('prediction_data', data.get('predictions', [])) or []

                if len(prediction_data) >= future_days_db:
                    # Already complete: nothing to repair.
                    skipped_count += 1
                    continue

                print(f"记录 {record_id} 需要修复 (文件: {len(prediction_data)}天, 数据库: {future_days_db}天). 开始重新生成...")

                # The model version is the last '_'-separated token of model_id.
                version = record['model_id'].split('_')[-1]

                found_model = _find_model(
                    model_manager, record['product_id'], record['model_type'], version
                )
                if not found_model:
                    print(f"❌ 无法找到用于修复的模型文件: product={record['product_id']}, type={record['model_type']}, version={version}")
                    error_count += 1
                    continue

                # Re-run the prediction for the full horizon recorded in the DB.
                new_prediction_result = load_model_and_predict(
                    model_path=found_model['file_path'],
                    product_id=record['product_id'],
                    model_type=record['model_type'],
                    version=version,
                    future_days=future_days_db,
                    start_date=record['start_date'],
                    history_lookback_days=30  # the old default value
                )

                if new_prediction_result:
                    # Replace the old JSON file without risking truncating it.
                    _write_json_atomically(file_path, new_prediction_result)
                    print(f"成功修复并覆盖文件: {file_path}")
                    fixed_count += 1
                else:
                    print(f"修复记录 {record_id} 失败: 预测函数返回空结果。")
                    error_count += 1

            except Exception as e:
                # One bad record must not abort the whole repair run.
                print(f"处理记录 {record_id} 时发生错误: {e}")
                traceback.print_exc()
                error_count += 1

        print("\n--- 修复完成 ---")
        print(f"总记录数: {len(records)}")
        print(f"已修复: {fixed_count}")
        print(f"已跳过 (无需修复): {skipped_count}")
        print(f"失败: {error_count}")

    finally:
        conn.close()
def _main():
    """Script entry point: run the historical prediction repair once."""
    fix_old_prediction_data()


if __name__ == '__main__':
    _main()