ShopTRAINING/run_pharmacy_prediction.py
gdtiti c0fe213b70 修复图表显示和数据处理问题
1. 修复前端图表日期排序问题:
   - 改进 PredictionView.vue 和 HistoryView.vue 中的图表渲染逻辑
   - 确保历史数据和预测数据按照正确的日期顺序显示

2. 修复后端API处理:
   - 解决 optimized_kan 模型类型的路径映射问题
   - 添加 JSON 序列化器处理 Pandas Timestamp 对象
   - 改进预测数据与历史数据的衔接处理

3. 优化图表样式和用户体验
2025-06-15 00:01:57 +08:00

279 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from datetime import datetime
import sys
# 导入模块
from pharmacy_predictor import train_product_model, train_product_model_with_mlstm, train_product_model_with_kan, train_product_model_with_transformer
from pharmacy_predictor import load_model_and_predict, load_kan_model_and_predict
from models import ModelManager
from pharmacy_predictor import PharmacyPredictor
from generate_pharmacy_data import generate_pharmacy_data
def clear_screen():
    """Clear the console by running the platform's clear-screen command."""
    # `cls` is used when os.name reports 'nt'; every other platform gets `clear`.
    if os.name == 'nt':
        command = 'cls'
    else:
        command = 'clear'
    os.system(command)
def print_header():
    """Print the application title between two 80-character rules."""
    rule = "=" * 80
    padding = " " * 25
    print(rule)
    print(padding + "药店销售预测系统" + padding)
    print(rule)
def main_menu():
    """Print the main menu and return the option the user typed.

    Returns:
        str: The entered text with surrounding whitespace removed, so that
        an entry such as ``"1 "`` still compares equal to the option key
        ``"1"`` in the caller.
    """
    menu_lines = (
        "\n主菜单:",
        "1. 训练所有药品的销售预测模型",
        "2. 训练单个药品的销售预测模型 (Transformer)",
        "3. 训练单个药品的销售预测模型 (mLSTM)",
        "4. 训练单个药品的销售预测模型 (KAN)",
        "5. 训练单个药品的销售预测模型 (优化版KAN)",  # newly added option
        "6. 比较原始KAN和优化版KAN模型性能",  # newly added option
        "7. 查看已有预测结果",
        "8. 使用已训练的模型进行预测",
        "9. 比较不同模型的预测结果",
        "10. 模型管理",
        "0. 退出",
    )
    for line in menu_lines:
        print(line)
    # Strip so stray spaces around the number are not treated as an
    # invalid selection by the caller's exact string comparisons.
    return input("\n请输入您的选择 (0-10): ").strip()
def train_all_products_menu(predictor):
    """Ask for a model type and train it for every product in the data set.

    Args:
        predictor: Object exposing ``data`` (indexed here by
            ``'product_id'``), ``train_model`` and
            ``train_optimized_kan_model``.

    Returns:
        None. Returns immediately, without waiting for Enter, when the user
        picks ``0``.
    """
    clear_screen()
    print_header()

    # The single-product and comparison menus treat missing data as a
    # recoverable condition; do the same here instead of failing on
    # ``None['product_id']``.
    if predictor.data is None:
        print("\n没有可用的数据。请先生成或加载数据。")
        input("\n按Enter键继续...")
        return

    print("\n选择模型类型:")
    print("1. Transformer")
    print("2. mLSTM")
    print("3. KAN")
    print("4. 优化版KAN")  # newly added option
    print("0. 返回主菜单")
    choice = input("\n请输入您的选择 (0-4): ")
    if choice == '0':
        return

    model_types = {
        '1': 'transformer',
        '2': 'mlstm',
        '3': 'kan',
        # Same model type as option 3; the optimized variant is selected by
        # calling the dedicated training method below.
        '4': 'kan',
    }
    if choice not in model_types:
        print("\n无效的选择!")
        input("\n按Enter键继续...")
        return

    model_type = model_types[choice]
    use_optimized = choice == '4'

    for product_id in predictor.data['product_id'].unique():
        print(f"\n正在训练产品 {product_id} 的模型...")
        if model_type == 'kan' and use_optimized:
            predictor.train_optimized_kan_model(product_id)
        else:
            predictor.train_model(product_id, model_type=model_type)

    print("\n所有产品的模型训练完成!")
    input("\n按Enter键继续...")
def train_single_product_menu(predictor, model_type, use_optimized=False):
    """Interactively train one model for a single, user-selected product.

    Lists the product IDs found in ``predictor.data``, prompts for the
    training hyperparameters (an empty answer selects the default shown in
    the prompt), trains the model and prints the returned metrics.

    Args:
        predictor: Object exposing ``data``, ``train_model`` and
            ``train_optimized_kan_model``.
        model_type: Model type forwarded to ``predictor.train_model``; also
            used as the display name unless the optimized KAN is selected.
        use_optimized: When true and ``model_type`` is ``'kan'``, train via
            ``predictor.train_optimized_kan_model`` instead.

    Returns:
        None. Returns immediately, without waiting for Enter, when the user
        enters ``0`` at the product prompt.
    """
    clear_screen()
    print_header()

    optimized_kan = model_type == 'kan' and use_optimized
    model_name = "优化版KAN" if optimized_kan else model_type
    print(f"\n训练单个药品的{model_name}模型")

    if predictor.data is None:
        print("\n没有可用的数据。请先生成或加载数据。")
        input("\n按Enter键继续...")
        return

    product_ids = predictor.data['product_id'].unique()
    print("\n可用的产品ID:")
    for number, listed_id in enumerate(product_ids, start=1):
        print(f"{number}. {listed_id}")

    choice = input("\n请选择产品ID的编号 (或输入0返回): ")
    if choice == '0':
        return

    # NOTE(review): this handler also covers ValueError/IndexError raised by
    # the training call itself, which is then reported as invalid input.
    try:
        idx = int(choice) - 1
        if 0 <= idx < len(product_ids):
            product_id = product_ids[idx]

            # Dict displays evaluate their values left to right, so the
            # prompts appear in the order written here.
            train_kwargs = {
                'product_id': product_id,
                'epochs': int(input("\n请输入训练轮数 (默认100): ") or "100"),
                'batch_size': int(input("请输入批次大小 (默认32): ") or "32"),
                'learning_rate': float(input("请输入学习率 (默认0.001): ") or "0.001"),
                'sequence_length': int(input("请输入输入序列长度 (默认30): ") or "30"),
                'forecast_horizon': int(input("请输入预测天数 (默认7): ") or "7"),
                'hidden_size': int(input("请输入隐藏层大小 (默认64): ") or "64"),
                'num_layers': int(input("请输入层数 (默认2): ") or "2"),
                'dropout': float(input("请输入Dropout比例 (默认0.1): ") or "0.1"),
            }

            # Separate the product ID from the model name; previously the
            # two were printed back to back with nothing between them.
            print(f"\n开始训练产品 {product_id} 的{model_name}模型...")

            if optimized_kan:
                metrics = predictor.train_optimized_kan_model(**train_kwargs)
            else:
                metrics = predictor.train_model(model_type=model_type, **train_kwargs)

            print("\n模型训练完成!")
            print(f"MSE: {metrics['mse']:.4f}")
            print(f"RMSE: {metrics['rmse']:.4f}")
            print(f"MAE: {metrics['mae']:.4f}")
            print(f"R²: {metrics['r2']:.4f}")
            print(f"MAPE: {metrics['mape']:.2f}%")
            print(f"训练时间: {metrics['training_time']:.2f}")
        else:
            print("\n无效的选择!")
    except (ValueError, IndexError):
        print("\n无效的输入!")

    input("\n按Enter键继续...")
def compare_kan_models_menu(predictor):
    """Interactively compare the original and optimized KAN models.

    Lists the product IDs found in ``predictor.data``, prompts for the
    training hyperparameters (an empty answer selects the default shown in
    the prompt) and forwards them to ``predictor.compare_kan_models``.

    Args:
        predictor: Object exposing ``data`` and ``compare_kan_models``.

    Returns:
        None. Returns immediately, without waiting for Enter, when the user
        enters ``0`` at the product prompt.
    """
    clear_screen()
    print_header()
    print("\n比较原始KAN和优化版KAN模型性能")

    if predictor.data is None:
        print("\n没有可用的数据。请先生成或加载数据。")
        input("\n按Enter键继续...")
        return

    product_ids = predictor.data['product_id'].unique()
    print("\n可用的产品ID:")
    for number, listed_id in enumerate(product_ids, start=1):
        print(f"{number}. {listed_id}")

    choice = input("\n请选择产品ID的编号 (或输入0返回): ")
    if choice == '0':
        return

    # NOTE(review): this handler also covers ValueError/IndexError raised by
    # the comparison call itself, which is then reported as invalid input.
    try:
        idx = int(choice) - 1
        if 0 <= idx < len(product_ids):
            product_id = product_ids[idx]

            epochs = int(input("\n请输入训练轮数 (默认100): ") or "100")
            batch_size = int(input("请输入批次大小 (默认32): ") or "32")
            learning_rate = float(input("请输入学习率 (默认0.001): ") or "0.001")
            sequence_length = int(input("请输入输入序列长度 (默认30): ") or "30")
            forecast_horizon = int(input("请输入预测天数 (默认7): ") or "7")
            hidden_size = int(input("请输入隐藏层大小 (默认64): ") or "64")
            num_layers = int(input("请输入层数 (默认2): ") or "2")
            dropout = float(input("请输入Dropout比例 (默认0.1): ") or "0.1")

            print(f"\n开始比较产品 {product_id} 的原始KAN和优化版KAN模型...")

            # The return value is not used by this menu, so it is not bound
            # to a local name.
            predictor.compare_kan_models(
                product_id=product_id,
                epochs=epochs,
                batch_size=batch_size,
                learning_rate=learning_rate,
                sequence_length=sequence_length,
                forecast_horizon=forecast_horizon,
                hidden_size=hidden_size,
                num_layers=num_layers,
                dropout=dropout,
            )
            print("\n比较完成!")
        else:
            print("\n无效的选择!")
    except (ValueError, IndexError):
        print("\n无效的输入!")

    input("\n按Enter键继续...")
def main():
    """Run the interactive console application until the user exits.

    If ``pharmacy_sales.xlsx`` does not exist, ``generate_pharmacy_data`` is
    called to create it. A ``PharmacyPredictor`` is then built on that file
    and the main menu is shown in a loop, dispatching each selection to the
    matching sub-menu.
    """
    # Generate the data file on first run so the predictor has input.
    data_path = 'pharmacy_sales.xlsx'
    if not os.path.exists(data_path):
        print("生成药店销售数据...")
        generate_pharmacy_data(num_products=10, days=365, output_file=data_path)
        print(f"数据已生成并保存到 {data_path}")

    predictor = PharmacyPredictor(data_path=data_path)

    # Menu options that all route to the single-product training menu,
    # differing only in the keyword arguments passed along.
    single_product_options = {
        '2': {'model_type': 'transformer'},
        '3': {'model_type': 'mlstm'},
        '4': {'model_type': 'kan'},
        '5': {'model_type': 'kan', 'use_optimized': True},
    }
    # Listed in the menu but not wired to any action yet: view predictions,
    # predict with a trained model, compare models, model management.
    unimplemented_options = {'7', '8', '9', '10'}

    while True:
        clear_screen()
        print_header()
        choice = main_menu()

        if choice == '0':
            print("\n感谢使用药店销售预测系统,再见!")
            break

        if choice == '1':
            train_all_products_menu(predictor)
        elif choice in single_product_options:
            train_single_product_menu(predictor, **single_product_options[choice])
        elif choice == '6':
            compare_kan_models_menu(predictor)
        elif choice in unimplemented_options:
            # Tell the user instead of silently clearing and redrawing the
            # menu, which looked like the selection had been ignored.
            print("\n该功能尚未实现!")
            input("\n按Enter键继续...")
        else:
            print("\n无效的选择,请重试!")
            input("\n按Enter键继续...")
# Start the interactive menu only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()