# ShopTRAINING/run_pharmacy_prediction.py
#
# 363 lines
# 16 KiB
# Python

import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from datetime import datetime
import sys
# 导入模块
from pharmacy_predictor import train_product_model, train_product_model_with_mlstm, train_product_model_with_kan, train_product_model_with_transformer
from pharmacy_predictor import load_model_and_predict, load_kan_model_and_predict
from models import ModelManager
def print_header():
    """Print the application banner: a blank line, a 60-char rule, the title, and a closing rule."""
    rule = "=" * 60
    # The leading newline on the first entry separates the banner from any
    # earlier terminal output.
    for banner_line in ("\n" + rule, "📊 药店单品销售预测系统 📊", rule):
        print(banner_line)
# Horizontal rule drawn above and below the main menu.
_MENU_RULE = "=" * 60

# Main-menu text, printed verbatim on every loop iteration.
_MENU_LINES = (
    "📋 请选择操作:",
    " 1. 训练所有药品的销售预测模型",
    " 2. 训练单个药品的销售预测模型",
    " 3. 使用mLSTM模型训练单个药品的销售预测模型",
    " 4. 使用KAN模型训练单个药品的销售预测模型",
    " 5. 使用Transformer模型训练单个药品的销售预测模型",
    " 6. 查看已有预测结果",
    " 7. 使用已训练的模型进行预测",
    " 8. 比较不同模型的预测结果",
    " 9. 模型管理",
    " 0. 退出",
)

# Metadata passed to ModelManager.save_model() with every trained model.
# NOTE(review): look_back / T are presumably the input window and forecast
# horizon used by the trainers in pharmacy_predictor -- confirm they match.
_SAVED_MODEL_FEATURES = (
    'sales', 'price', 'weekday', 'month',
    'is_holiday', 'is_weekend', 'is_promotion', 'temperature',
)
_SAVED_MODEL_LOOK_BACK = 14
_SAVED_MODEL_T = 7

# Sub-menu key -> model_type string understood by the model manager.
_MODEL_TYPE_BY_CHOICE = {'1': 'mlstm', '2': 'kan', '3': 'transformer'}

_INVALID_SELECTION_MSG = "\n❌ 无效的选择!"


def _distinct_products(df):
    """Return the unique (product_id, product_name) rows of *df*, sorted by product_id."""
    return df[['product_id', 'product_name']].drop_duplicates().sort_values('product_id')


def _print_product_list(products):
    """Print *products* as a 1-based numbered list."""
    for i, (_, row) in enumerate(products.iterrows(), 1):
        print(f" {i}. {row['product_name']} (ID: {row['product_id']})")


def _load_sales_data():
    """Load 'pharmacy_sales.xlsx', generating it first if the file is missing.

    Returns:
        (df, products): the full sales frame and its distinct product table.

    Only a missing file triggers generation. Any other failure (unreadable
    workbook, missing Excel engine, missing columns, Ctrl+C) propagates
    instead of being reported as "file not found".
    """
    try:
        df = pd.read_excel('pharmacy_sales.xlsx')
    except FileNotFoundError:
        print("❌ 未找到数据文件。正在生成模拟数据...")
        # NOTE(review): generation relies on the import-time side effect of
        # generate_pharmacy_data (presumably it writes pharmacy_sales.xlsx
        # when imported) -- confirm against that module.
        import generate_pharmacy_data  # noqa: F401
        df = pd.read_excel('pharmacy_sales.xlsx')
        print("✅ 数据生成完成!")
        return df, _distinct_products(df)

    print("✅ 检测到现有数据文件 'pharmacy_sales.xlsx'")
    products = _distinct_products(df)
    print(f"\n📋 发现 {len(products)} 种药品:")
    _print_product_list(products)
    return df, products


def _read_index(prompt, count):
    """Prompt for a 1-based index and return it, or None if the input is invalid.

    Non-numeric input is treated as an invalid choice rather than raising
    ValueError, so a typo cannot crash the menu loop.
    """
    try:
        value = int(input(prompt))
    except ValueError:
        return None
    return value if 1 <= value <= count else None


def _choose_product(products, heading):
    """Show *heading* and the product list; return (product_id, product_name) or None.

    Prints the invalid-selection message itself when None is returned.
    """
    print(heading)
    _print_product_list(products)
    idx = _read_index("\n请输入药品编号 (1-{}): ".format(len(products)), len(products))
    if idx is None:
        print(_INVALID_SELECTION_MSG)
        return None
    row = products.iloc[idx - 1]
    return row['product_id'], row['product_name']


def _print_metrics(metrics):
    """Print each metric name with its value formatted to four decimals."""
    for metric, value in metrics.items():
        print(f" {metric}: {value:.4f}")


def _train_all_products(df, products):
    """Train the default model for every product, then print a metrics summary."""
    print("\n开始训练所有药品的销售预测模型...")
    all_metrics = {}
    for _, row in products.iterrows():
        product_id = row['product_id']
        print(f"\n{'='*50}")
        print(f"开始训练产品 {row['product_name']} (ID: {product_id}) 的模型")
        print(f"{'='*50}")
        _, metrics = train_product_model(product_id)
        all_metrics[product_id] = metrics

    print("\n所有产品模型评估结果汇总:")
    for product_id, metrics in all_metrics.items():
        product_name = df[df['product_id'] == product_id]['product_name'].iloc[0]
        print(f"\n{product_name} (ID: {product_id}):")
        _print_metrics(metrics)
    print("\n模型训练和评估完成!")


def _train_single_product(products, trainer, label, model_manager, model_type):
    """Train one user-selected product with *trainer* and report its metrics.

    Args:
        products: distinct product table to choose from.
        trainer: callable taking a product_id and returning (model, metrics).
        label: model name shown in messages ("" for the default model).
        model_manager: manager used to persist the trained model.
        model_type: type tag stored with the model; None skips saving.
    """
    selection = _choose_product(products, "\n请选择要训练的药品:")
    if selection is None:
        return
    product_id, product_name = selection

    if label:
        print(f"\n开始使用{label}模型训练 {product_name} (ID: {product_id}) 的销售预测模型...")
    else:
        print(f"\n开始训练 {product_name} (ID: {product_id}) 的销售预测模型...")
    model, metrics = trainer(product_id)

    if model_type is not None:
        # A failed save is reported but must not discard the metrics of a
        # training run that already completed.
        try:
            model_manager.save_model(
                model=model,
                model_type=model_type,
                product_id=product_id,
                metrics=metrics,
                features=list(_SAVED_MODEL_FEATURES),
                look_back=_SAVED_MODEL_LOOK_BACK,
                T=_SAVED_MODEL_T,
            )
            print("✅ 模型已保存到模型管理器")
        except Exception as e:
            print(f"❌ 保存模型到管理器时出错: {e}")

    print(f"\n{product_name} {label}模型评估指标:")
    _print_metrics(metrics)
    print("\n模型训练和评估完成!")


def _view_prediction_results():
    """List '*_prediction_results.csv' files in the working directory and plot one."""
    print("\n正在搜索预测结果文件...")
    # os.listdir() returns entries in arbitrary order; sorting keeps the
    # numbering stable between runs.
    prediction_files = sorted(
        f for f in os.listdir() if f.endswith('_prediction_results.csv')
    )
    if not prediction_files:
        print("❌ 未找到任何预测结果文件。请先训练模型生成预测结果。")
        return

    print(f"\n找到 {len(prediction_files)} 个预测结果文件:")
    for i, file in enumerate(prediction_files, 1):
        print(f" {i}. {file}")
    file_idx = _read_index(
        "\n请选择要查看的文件 (1-{}): ".format(len(prediction_files)),
        len(prediction_files),
    )
    if file_idx is None:
        print(_INVALID_SELECTION_MSG)
        return

    file_path = prediction_files[file_idx - 1]
    try:
        results_df = pd.read_csv(file_path)
        print(f"\n{file_path} 内容:")
        print(results_df)
        # Plot actual vs. predicted sales over time.
        plt.figure(figsize=(12, 6))
        plt.plot(results_df['date'], results_df['actual_sales'], 'b-', label='实际销量')
        plt.plot(results_df['date'], results_df['predicted_sales'], 'r--', label='预测销量')
        plt.title('销量预测结果')
        plt.xlabel('日期')
        plt.ylabel('销量')
        plt.legend()
        plt.grid(True)
        plt.xticks(rotation=45)
        plt.tight_layout()
        plt.show()
    except Exception as e:
        print(f"❌ 读取文件时出错: {e}")


def _predict_with_saved_model(model_manager, df, products):
    """Ask for a model type and a product, then predict via the model manager."""
    print("\n请选择要使用的模型类型:")
    print(" 1. mLSTM模型")
    print(" 2. KAN模型")
    print(" 3. Transformer模型")
    model_type = _MODEL_TYPE_BY_CHOICE.get(input("\n请输入选项 (1-3): ").strip())
    if model_type is None:
        print(_INVALID_SELECTION_MSG)
        return

    selection = _choose_product(products, "\n请选择要预测的药品:")
    if selection is None:
        return
    product_id, product_name = selection

    try:
        print(f"\n使用{model_type.upper()}模型预测 {product_name} (ID: {product_id}) 的未来销量...")
        product_df = df[df['product_id'] == product_id].sort_values('date')
        predictions = model_manager.predict_with_model(
            product_id=product_id,
            model_type=model_type,
            product_df=product_df,
        )
        if predictions is not None:
            print("\n✅ 预测完成!")
    except Exception as e:
        print(f"\n❌ 预测时出错: {e}")
        print("请确保已经训练并保存了对应的模型。")


def _compare_saved_models(model_manager, df, products):
    """Ask for a product and compare its saved models via the model manager."""
    selection = _choose_product(products, "\n请选择要比较的药品:")
    if selection is None:
        return
    product_id, product_name = selection

    print(f"\n比较 {product_name} (ID: {product_id}) 的不同模型预测结果...")
    try:
        product_df = df[df['product_id'] == product_id].sort_values('date')
        comparison = model_manager.compare_models(
            product_id=product_id,
            product_df=product_df,
        )
        if comparison is not None:
            print("\n✅ 比较完成!")
    except Exception as e:
        print(f"\n❌ 比较时出错: {e}")
        print("请确保已经训练并保存了对应的模型。")


def main():
    """Run the interactive menu loop of the pharmacy sales prediction tool.

    Creates a ModelManager, loads (or generates) 'pharmacy_sales.xlsx', then
    repeatedly shows the menu and dispatches the chosen action until the
    user enters '0'. Invalid or non-numeric selections are reported and the
    menu is shown again.
    """
    model_manager = ModelManager()
    df, products = _load_sales_data()

    # Menu key -> (trainer, label used in messages, model_type saved to the
    # manager). Option 2 trains the default model, which is not saved here.
    single_product_trainers = {
        '2': (train_product_model, "", None),
        '3': (train_product_model_with_mlstm, "mLSTM", 'mlstm'),
        '4': (train_product_model_with_kan, "KAN", 'kan'),
        '5': (train_product_model_with_transformer, "Transformer", 'transformer'),
    }

    while True:
        print("\n" + _MENU_RULE)
        for line in _MENU_LINES:
            print(line)
        print(_MENU_RULE)
        choice = input("\n请输入选项 (0-9): ").strip()

        if choice == '0':
            print("感谢使用药店销售预测系统!再见!")
            break
        if choice == '1':
            _train_all_products(df, products)
        elif choice in single_product_trainers:
            trainer, label, model_type = single_product_trainers[choice]
            _train_single_product(products, trainer, label, model_manager, model_type)
        elif choice == '6':
            _view_prediction_results()
        elif choice == '7':
            _predict_with_saved_model(model_manager, df, products)
        elif choice == '8':
            _compare_saved_models(model_manager, df, products)
        elif choice == '9':
            print("\n启动模型管理工具...")
            # Imported on demand so the tool is loaded only when requested.
            import model_management
            model_management.interactive_mode()
        else:
            print("\n❌ 无效的选项!请重新输入。")
if __name__ == "__main__":
    # Running as a script: show the banner once, then enter the interactive menu.
    print_header()
    main()