363 lines
16 KiB
Python
363 lines
16 KiB
Python
import os
|
|
import pandas as pd
|
|
import numpy as np
|
|
import matplotlib.pyplot as plt
|
|
from datetime import datetime
|
|
import sys
|
|
|
|
# 导入模块
|
|
from pharmacy_predictor import train_product_model, train_product_model_with_mlstm, train_product_model_with_kan, train_product_model_with_transformer
|
|
from pharmacy_predictor import load_model_and_predict, load_kan_model_and_predict
|
|
from models import ModelManager
|
|
|
|
def print_header():
|
|
print("\n" + "="*60)
|
|
print("📊 药店单品销售预测系统 📊")
|
|
print("="*60)
|
|
|
|
def main():
|
|
# 初始化模型管理器
|
|
model_manager = ModelManager()
|
|
|
|
# 首先检查数据文件是否存在
|
|
try:
|
|
df = pd.read_excel('pharmacy_sales.xlsx')
|
|
print("✅ 检测到现有数据文件 'pharmacy_sales.xlsx'")
|
|
|
|
# 获取所有产品
|
|
products = df[['product_id', 'product_name']].drop_duplicates().sort_values('product_id')
|
|
|
|
print(f"\n📋 发现 {len(products)} 种药品:")
|
|
for i, (_, row) in enumerate(products.iterrows(), 1):
|
|
print(f" {i}. {row['product_name']} (ID: {row['product_id']})")
|
|
except:
|
|
print("❌ 未找到数据文件。正在生成模拟数据...")
|
|
import generate_pharmacy_data
|
|
df = pd.read_excel('pharmacy_sales.xlsx')
|
|
print("✅ 数据生成完成!")
|
|
|
|
# 获取所有产品
|
|
products = df[['product_id', 'product_name']].drop_duplicates().sort_values('product_id')
|
|
|
|
while True:
|
|
print("\n" + "="*60)
|
|
print("📋 请选择操作:")
|
|
print(" 1. 训练所有药品的销售预测模型")
|
|
print(" 2. 训练单个药品的销售预测模型")
|
|
print(" 3. 使用mLSTM模型训练单个药品的销售预测模型")
|
|
print(" 4. 使用KAN模型训练单个药品的销售预测模型")
|
|
print(" 5. 使用Transformer模型训练单个药品的销售预测模型")
|
|
print(" 6. 查看已有预测结果")
|
|
print(" 7. 使用已训练的模型进行预测")
|
|
print(" 8. 比较不同模型的预测结果")
|
|
print(" 9. 模型管理")
|
|
print(" 0. 退出")
|
|
print("="*60)
|
|
|
|
choice = input("\n请输入选项 (0-9): ")
|
|
|
|
if choice == '0':
|
|
print("感谢使用药店销售预测系统!再见!")
|
|
break
|
|
|
|
elif choice == '1':
|
|
# 训练所有药品的预测模型
|
|
print("\n开始训练所有药品的销售预测模型...")
|
|
|
|
all_metrics = {}
|
|
for _, row in products.iterrows():
|
|
product_id = row['product_id']
|
|
print(f"\n{'='*50}")
|
|
print(f"开始训练产品 {row['product_name']} (ID: {product_id}) 的模型")
|
|
print(f"{'='*50}")
|
|
|
|
_, metrics = train_product_model(product_id)
|
|
all_metrics[product_id] = metrics
|
|
|
|
# 输出所有产品的评估指标
|
|
print("\n所有产品模型评估结果汇总:")
|
|
for product_id, metrics in all_metrics.items():
|
|
product_name = df[df['product_id'] == product_id]['product_name'].iloc[0]
|
|
print(f"\n{product_name} (ID: {product_id}):")
|
|
for metric, value in metrics.items():
|
|
print(f" {metric}: {value:.4f}")
|
|
|
|
print("\n模型训练和评估完成!")
|
|
|
|
elif choice == '2':
|
|
# 训练单个药品的预测模型
|
|
print("\n请选择要训练的药品:")
|
|
for i, (_, row) in enumerate(products.iterrows(), 1):
|
|
print(f" {i}. {row['product_name']} (ID: {row['product_id']})")
|
|
|
|
product_idx = int(input("\n请输入药品编号 (1-{}): ".format(len(products))))
|
|
|
|
if 1 <= product_idx <= len(products):
|
|
product_id = products.iloc[product_idx-1]['product_id']
|
|
product_name = products.iloc[product_idx-1]['product_name']
|
|
|
|
print(f"\n开始训练 {product_name} (ID: {product_id}) 的销售预测模型...")
|
|
|
|
_, metrics = train_product_model(product_id)
|
|
|
|
print(f"\n{product_name} 模型评估指标:")
|
|
for metric, value in metrics.items():
|
|
print(f" {metric}: {value:.4f}")
|
|
|
|
print(f"\n模型训练和评估完成!")
|
|
else:
|
|
print("\n❌ 无效的选择!")
|
|
|
|
elif choice == '3':
|
|
# 使用mLSTM模型训练单个药品
|
|
print("\n请选择要训练的药品:")
|
|
for i, (_, row) in enumerate(products.iterrows(), 1):
|
|
print(f" {i}. {row['product_name']} (ID: {row['product_id']})")
|
|
|
|
product_idx = int(input("\n请输入药品编号 (1-{}): ".format(len(products))))
|
|
|
|
if 1 <= product_idx <= len(products):
|
|
product_id = products.iloc[product_idx-1]['product_id']
|
|
product_name = products.iloc[product_idx-1]['product_name']
|
|
|
|
print(f"\n开始使用mLSTM模型训练 {product_name} (ID: {product_id}) 的销售预测模型...")
|
|
|
|
model, metrics = train_product_model_with_mlstm(product_id)
|
|
|
|
# 保存模型到模型管理器
|
|
try:
|
|
model_manager.save_model(
|
|
model=model,
|
|
model_type='mlstm',
|
|
product_id=product_id,
|
|
metrics=metrics,
|
|
features=['sales', 'price', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature'],
|
|
look_back=14,
|
|
T=7
|
|
)
|
|
print("✅ 模型已保存到模型管理器")
|
|
except Exception as e:
|
|
print(f"❌ 保存模型到管理器时出错: {str(e)}")
|
|
|
|
print(f"\n{product_name} mLSTM模型评估指标:")
|
|
for metric, value in metrics.items():
|
|
print(f" {metric}: {value:.4f}")
|
|
|
|
print(f"\n模型训练和评估完成!")
|
|
else:
|
|
print("\n❌ 无效的选择!")
|
|
|
|
elif choice == '4':
|
|
# 使用KAN模型训练单个药品
|
|
print("\n请选择要训练的药品:")
|
|
for i, (_, row) in enumerate(products.iterrows(), 1):
|
|
print(f" {i}. {row['product_name']} (ID: {row['product_id']})")
|
|
|
|
product_idx = int(input("\n请输入药品编号 (1-{}): ".format(len(products))))
|
|
|
|
if 1 <= product_idx <= len(products):
|
|
product_id = products.iloc[product_idx-1]['product_id']
|
|
product_name = products.iloc[product_idx-1]['product_name']
|
|
|
|
print(f"\n开始使用KAN模型训练 {product_name} (ID: {product_id}) 的销售预测模型...")
|
|
|
|
model, metrics = train_product_model_with_kan(product_id)
|
|
|
|
# 保存模型到模型管理器
|
|
try:
|
|
model_manager.save_model(
|
|
model=model,
|
|
model_type='kan',
|
|
product_id=product_id,
|
|
metrics=metrics,
|
|
features=['sales', 'price', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature'],
|
|
look_back=14,
|
|
T=7
|
|
)
|
|
print("✅ 模型已保存到模型管理器")
|
|
except Exception as e:
|
|
print(f"❌ 保存模型到管理器时出错: {str(e)}")
|
|
|
|
print(f"\n{product_name} KAN模型评估指标:")
|
|
for metric, value in metrics.items():
|
|
print(f" {metric}: {value:.4f}")
|
|
|
|
print(f"\n模型训练和评估完成!")
|
|
else:
|
|
print("\n❌ 无效的选择!")
|
|
|
|
elif choice == '5':
|
|
# 使用Transformer模型训练单个药品
|
|
print("\n请选择要训练的药品:")
|
|
for i, (_, row) in enumerate(products.iterrows(), 1):
|
|
print(f" {i}. {row['product_name']} (ID: {row['product_id']})")
|
|
|
|
product_idx = int(input("\n请输入药品编号 (1-{}): ".format(len(products))))
|
|
|
|
if 1 <= product_idx <= len(products):
|
|
product_id = products.iloc[product_idx-1]['product_id']
|
|
product_name = products.iloc[product_idx-1]['product_name']
|
|
|
|
print(f"\n开始使用Transformer模型训练 {product_name} (ID: {product_id}) 的销售预测模型...")
|
|
|
|
model, metrics = train_product_model_with_transformer(product_id)
|
|
|
|
# 保存模型到模型管理器
|
|
try:
|
|
model_manager.save_model(
|
|
model=model,
|
|
model_type='transformer',
|
|
product_id=product_id,
|
|
metrics=metrics,
|
|
features=['sales', 'price', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature'],
|
|
look_back=14,
|
|
T=7
|
|
)
|
|
print("✅ 模型已保存到模型管理器")
|
|
except Exception as e:
|
|
print(f"❌ 保存模型到管理器时出错: {str(e)}")
|
|
|
|
print(f"\n{product_name} Transformer模型评估指标:")
|
|
for metric, value in metrics.items():
|
|
print(f" {metric}: {value:.4f}")
|
|
|
|
print(f"\n模型训练和评估完成!")
|
|
else:
|
|
print("\n❌ 无效的选择!")
|
|
|
|
elif choice == '6':
|
|
# 查看已有预测结果
|
|
print("\n正在搜索预测结果文件...")
|
|
prediction_files = [f for f in os.listdir() if f.endswith('_prediction_results.csv')]
|
|
|
|
if not prediction_files:
|
|
print("❌ 未找到任何预测结果文件。请先训练模型生成预测结果。")
|
|
continue
|
|
|
|
print(f"\n找到 {len(prediction_files)} 个预测结果文件:")
|
|
for i, file in enumerate(prediction_files, 1):
|
|
print(f" {i}. {file}")
|
|
|
|
file_idx = int(input("\n请选择要查看的文件 (1-{}): ".format(len(prediction_files))))
|
|
|
|
if 1 <= file_idx <= len(prediction_files):
|
|
file_path = prediction_files[file_idx-1]
|
|
|
|
try:
|
|
results_df = pd.read_csv(file_path)
|
|
print(f"\n{file_path} 内容:")
|
|
print(results_df)
|
|
|
|
# 可视化结果
|
|
plt.figure(figsize=(12, 6))
|
|
plt.plot(results_df['date'], results_df['actual_sales'], 'b-', label='实际销量')
|
|
plt.plot(results_df['date'], results_df['predicted_sales'], 'r--', label='预测销量')
|
|
plt.title('销量预测结果')
|
|
plt.xlabel('日期')
|
|
plt.ylabel('销量')
|
|
plt.legend()
|
|
plt.grid(True)
|
|
plt.xticks(rotation=45)
|
|
plt.tight_layout()
|
|
plt.show()
|
|
except Exception as e:
|
|
print(f"❌ 读取文件时出错: {str(e)}")
|
|
else:
|
|
print("\n❌ 无效的选择!")
|
|
|
|
elif choice == '7':
|
|
# 使用已训练的模型进行预测
|
|
print("\n请选择要使用的模型类型:")
|
|
print(" 1. mLSTM模型")
|
|
print(" 2. KAN模型")
|
|
print(" 3. Transformer模型")
|
|
|
|
model_choice = input("\n请输入选项 (1-3): ")
|
|
|
|
if model_choice not in ['1', '2', '3']:
|
|
print("\n❌ 无效的选择!")
|
|
continue
|
|
|
|
# 选择产品
|
|
print("\n请选择要预测的药品:")
|
|
for i, (_, row) in enumerate(products.iterrows(), 1):
|
|
print(f" {i}. {row['product_name']} (ID: {row['product_id']})")
|
|
|
|
product_idx = int(input("\n请输入药品编号 (1-{}): ".format(len(products))))
|
|
|
|
if 1 <= product_idx <= len(products):
|
|
product_id = products.iloc[product_idx-1]['product_id']
|
|
product_name = products.iloc[product_idx-1]['product_name']
|
|
|
|
# 使用模型管理器进行预测
|
|
try:
|
|
if model_choice == '1':
|
|
# mLSTM模型
|
|
model_type = 'mlstm'
|
|
elif model_choice == '2':
|
|
# KAN模型
|
|
model_type = 'kan'
|
|
else:
|
|
# Transformer模型
|
|
model_type = 'transformer'
|
|
|
|
print(f"\n使用{model_type.upper()}模型预测 {product_name} (ID: {product_id}) 的未来销量...")
|
|
|
|
# 使用模型管理器进行预测
|
|
product_df = df[df['product_id'] == product_id].sort_values('date')
|
|
predictions = model_manager.predict_with_model(
|
|
product_id=product_id,
|
|
model_type=model_type,
|
|
product_df=product_df
|
|
)
|
|
|
|
if predictions is not None:
|
|
print("\n✅ 预测完成!")
|
|
except Exception as e:
|
|
print(f"\n❌ 预测时出错: {str(e)}")
|
|
print("请确保已经训练并保存了对应的模型。")
|
|
else:
|
|
print("\n❌ 无效的选择!")
|
|
|
|
elif choice == '8':
|
|
# 比较不同模型的预测结果
|
|
print("\n请选择要比较的药品:")
|
|
for i, (_, row) in enumerate(products.iterrows(), 1):
|
|
print(f" {i}. {row['product_name']} (ID: {row['product_id']})")
|
|
|
|
product_idx = int(input("\n请输入药品编号 (1-{}): ".format(len(products))))
|
|
|
|
if 1 <= product_idx <= len(products):
|
|
product_id = products.iloc[product_idx-1]['product_id']
|
|
product_name = products.iloc[product_idx-1]['product_name']
|
|
|
|
print(f"\n比较 {product_name} (ID: {product_id}) 的不同模型预测结果...")
|
|
|
|
# 使用模型管理器进行比较
|
|
try:
|
|
product_df = df[df['product_id'] == product_id].sort_values('date')
|
|
comparison = model_manager.compare_models(
|
|
product_id=product_id,
|
|
product_df=product_df
|
|
)
|
|
|
|
if comparison is not None:
|
|
print("\n✅ 比较完成!")
|
|
except Exception as e:
|
|
print(f"\n❌ 比较时出错: {str(e)}")
|
|
print("请确保已经训练并保存了对应的模型。")
|
|
else:
|
|
print("\n❌ 无效的选择!")
|
|
|
|
elif choice == '9':
|
|
# 模型管理
|
|
print("\n启动模型管理工具...")
|
|
import model_management
|
|
model_management.interactive_mode()
|
|
|
|
else:
|
|
print("\n❌ 无效的选项!请重新输入。")
|
|
|
|
if __name__ == "__main__":
|
|
print_header()
|
|
main() |