# 药店销售预测系统 - 多店铺扩展方案 ## 背景 当前系统仅支持单店铺单商品的销售预测,需要扩展为支持多店铺销售预测的系统。 ## 当前系统分析 当前系统存在以下限制: 1. 数据结构:`pharmacy_sales.xlsx` 数据集中没有店铺相关字段 2. 模型训练:训练函数如 `train_product_model_with_mlstm` 只接受单个 `product_id` 参数 3. API设计:预测和训练接口只支持单个产品ID参数 4. 模型保存:模型保存路径仅考虑产品ID,不包含店铺信息 ## 扩展方案 ### 1. 数据库结构修改 #### 1.1 新增数据表 ```sql -- 店铺表 CREATE TABLE stores ( store_id VARCHAR(20) PRIMARY KEY, store_name VARCHAR(100) NOT NULL, location VARCHAR(200), size FLOAT, type VARCHAR(50), opening_date DATE, status VARCHAR(20) DEFAULT 'active' ); -- 店铺-产品关联表 CREATE TABLE store_products ( store_id VARCHAR(20), product_id VARCHAR(20), first_sale_date DATE, PRIMARY KEY (store_id, product_id), FOREIGN KEY (store_id) REFERENCES stores(store_id), FOREIGN KEY (product_id) REFERENCES products(product_id) ); ``` #### 1.2 修改销售数据表结构 在销售数据表中添加 `store_id` 字段: ```sql ALTER TABLE sales ADD COLUMN store_id VARCHAR(20) NOT NULL; ``` ### 2. 数据文件格式修改 修改 `pharmacy_sales.xlsx` 文件格式,添加 `store_id` 和 `store_name` 列。 ### 3. 后端代码修改 #### 3.1 修改数据加载函数 ```python # 修改前 def load_data(file_path='pharmacy_sales.xlsx'): df = pd.read_excel(file_path) return df # 修改后 def load_data(file_path='pharmacy_sales.xlsx', store_id=None): df = pd.read_excel(file_path) if store_id: df = df[df['store_id'] == store_id] return df ``` #### 3.2 修改模型训练函数 ```python # 修改前 def train_product_model(product_id, epochs=50): return train_product_model_with_mlstm(product_id, epochs) # 修改后 def train_product_model(product_id, store_id=None, epochs=50): return train_product_model_with_mlstm(product_id, store_id, epochs) # 修改训练函数 def train_product_model_with_mlstm(product_id, store_id=None, epochs=50): # 读取销售数据 df = pd.read_excel('pharmacy_sales.xlsx') # 筛选特定产品和店铺数据 if store_id: product_df = df[(df['product_id'] == product_id) & (df['store_id'] == store_id)] store_name = product_df['store_name'].iloc[0] model_dir = f'models/mlstm/{store_id}' model_path = f'{model_dir}/{product_id}_model.pt' log_path = f'{model_dir}/{product_id}_log.json' else: product_df = df[df['product_id'] == product_id] store_name = "全部店铺" model_dir = 'models/mlstm' model_path = f'{model_dir}/{product_id}_model.pt' log_path = f'{model_dir}/{product_id}_log.json' product_name = product_df['product_name'].iloc[0] print(f"使用mLSTM模型训练店铺 '{store_name}' 产品 '{product_name}' (ID: {product_id}) 的销售预测模型") # ... 后续训练代码保持不变 ... # 保存模型时添加店铺信息 os.makedirs(model_dir, exist_ok=True) # 保存模型 torch.save({ 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'train_loss': train_losses, 'test_loss': test_losses, 'scaler_X': scaler_X, 'scaler_y': scaler_y, 'features': features, 'look_back': look_back, 'T': T, 'model_type': 'mlstm', 'store_id': store_id }, model_path) # 保存日志 log_data = { 'product_id': product_id, 'product_name': product_name, 'store_id': store_id, 'store_name': store_name, 'model_type': 'mlstm', 'training_completed_at': datetime.now().isoformat(), 'epochs': epochs, 'metrics': metrics, 'file_path': model_path } with open(log_path, 'w', encoding='utf-8') as f: json.dump(log_data, f, indent=4, ensure_ascii=False) return model, metrics ``` #### 3.3 修改预测函数 ```python # 修改前 def load_model_and_predict(product_id, model_type, future_days=7, start_date=None): # ... # 修改后 def load_model_and_predict(product_id, model_type, store_id=None, future_days=7, start_date=None): # 确定模型路径 if store_id: model_dir = f'models/{model_type}/{store_id}' model_path = f'{model_dir}/{product_id}_model.pt' else: model_dir = f'models/{model_type}' model_path = f'{model_dir}/{product_id}_model.pt' # ... 后续代码保持不变 ... ``` ### 4. API接口修改 #### 4.1 修改训练API ```python @app.route('/api/training', methods=['POST']) def start_training(): data = request.json product_id = data.get('product_id') store_id = data.get('store_id') # 新增店铺ID参数 model_type = data.get('model_type', 'mlstm') epochs = data.get('epochs', 50) if not product_id: return jsonify({"status": "error", "message": "缺少product_id参数"}), 400 # 生成任务ID task_id = str(uuid.uuid4()) # 创建训练任务 with tasks_lock: training_tasks[task_id] = { "product_id": product_id, "store_id": store_id, # 新增店铺ID "model_type": model_type, "status": "pending", "start_time": datetime.now().isoformat() } # 启动训练线程 executor.submit(train_task, product_id, store_id, epochs, model_type, task_id) return jsonify({ "status": "success", "message": "训练任务已提交", "task_id": task_id }) def train_task(product_id, store_id, epochs, model_type, task_id): try: # 更新任务状态 with tasks_lock: training_tasks[task_id]["status"] = "running" # 根据模型类型选择训练函数 if model_type == "mlstm": model, metrics = train_product_model_with_mlstm(product_id, store_id, epochs) elif model_type == "transformer": model, metrics = train_product_model_with_transformer(product_id, store_id, epochs) elif model_type == "kan": model, metrics = train_product_model_with_kan(product_id, store_id, epochs) elif model_type == "optimized_kan": model, metrics = train_product_model_with_kan(product_id, store_id, epochs, use_optimized=True) else: raise ValueError(f"不支持的模型类型: {model_type}") # 确定模型路径 if store_id: model_path = f"models/{model_type}/{store_id}/{product_id}_model.pt" else: model_path = f"models/{model_type}/{product_id}_model.pt" # 更新任务状态 with tasks_lock: training_tasks[task_id]["status"] = "completed" training_tasks[task_id]["metrics"] = metrics training_tasks[task_id]["model_path"] = model_path except Exception as e: # 记录错误 with tasks_lock: training_tasks[task_id]["status"] = "failed" training_tasks[task_id]["error"] = str(e) print(f"训练任务 {task_id} 失败: {e}") traceback.print_exc() ``` #### 4.2 修改预测API ```python @app.route('/api/prediction', methods=['POST']) def predict(): data = request.json product_id = data.get('product_id') store_id = data.get('store_id') # 新增店铺ID参数 model_type = data.get('model_type', 'mlstm') future_days = data.get('future_days', 7) include_visualization = data.get('include_visualization', True) start_date = data.get('start_date') if not product_id: return jsonify({"status": "error", "message": "缺少product_id参数"}), 400 try: # 调用预测函数 result = load_model_and_predict( product_id=product_id, model_type=model_type, store_id=store_id, # 传递店铺ID future_days=future_days, start_date=start_date ) # ... 后续代码保持不变 ... ``` ### 5. 前端界面修改 #### 5.1 添加店铺选择器组件 ```vue ``` #### 5.2 修改预测视图 ```vue ``` #### 5.3 添加店铺管理界面 ```vue ``` ### 6. 新增API接口 #### 6.1 店铺管理API ```python # 获取所有店铺 @app.route('/api/stores', methods=['GET']) def get_stores(): try: conn = get_db_connection() cursor = conn.cursor() cursor.execute("SELECT * FROM stores ORDER BY store_name") stores = cursor.fetchall() conn.close() return jsonify({ "status": "success", "data": stores }) except Exception as e: return jsonify({ "status": "error", "message": str(e) }), 500 # 获取单个店铺 @app.route('/api/stores/', methods=['GET']) def get_store(store_id): try: conn = get_db_connection() cursor = conn.cursor() cursor.execute("SELECT * FROM stores WHERE store_id = ?", (store_id,)) store = cursor.fetchone() conn.close() if not store: return jsonify({ "status": "error", "message": f"店铺 {store_id} 不存在" }), 404 return jsonify({ "status": "success", "data": store }) except Exception as e: return jsonify({ "status": "error", "message": str(e) }), 500 # 创建店铺 @app.route('/api/stores', methods=['POST']) def create_store(): data = request.json try: conn = get_db_connection() cursor = conn.cursor() cursor.execute( "INSERT INTO stores (store_id, store_name, location, size, type, opening_date, status) VALUES (?, ?, ?, ?, ?, ?, ?)", ( data['store_id'], data['store_name'], data.get('location'), data.get('size'), data.get('type'), data.get('opening_date'), data.get('status', 'active') ) ) conn.commit() conn.close() return jsonify({ "status": "success", "message": "店铺创建成功", "data": { "store_id": data['store_id'] } }) except Exception as e: return jsonify({ "status": "error", "message": str(e) }), 500 # 更新店铺 @app.route('/api/stores/', methods=['PUT']) def update_store(store_id): data = request.json try: conn = get_db_connection() cursor = conn.cursor() cursor.execute( "UPDATE stores SET store_name = ?, location = ?, size = ?, type = ?, opening_date = ?, status = ? WHERE store_id = ?", ( data['store_name'], data.get('location'), data.get('size'), data.get('type'), data.get('opening_date'), data.get('status'), store_id ) ) conn.commit() conn.close() return jsonify({ "status": "success", "message": "店铺更新成功" }) except Exception as e: return jsonify({ "status": "error", "message": str(e) }), 500 # 删除店铺 @app.route('/api/stores/', methods=['DELETE']) def delete_store(store_id): try: conn = get_db_connection() cursor = conn.cursor() # 检查是否有关联的销售数据 cursor.execute("SELECT COUNT(*) FROM sales WHERE store_id = ?", (store_id,)) count = cursor.fetchone()[0] if count > 0: return jsonify({ "status": "error", "message": f"无法删除店铺 {store_id},存在 {count} 条关联的销售数据" }), 400 # 删除店铺-产品关联 cursor.execute("DELETE FROM store_products WHERE store_id = ?", (store_id,)) # 删除店铺 cursor.execute("DELETE FROM stores WHERE store_id = ?", (store_id,)) conn.commit() conn.close() return jsonify({ "status": "success", "message": "店铺删除成功" }) except Exception as e: return jsonify({ "status": "error", "message": str(e) }), 500 ``` ### 7. 模型目录结构修改 修改模型保存目录结构,以支持多店铺: ``` models/ ├── mlstm/ │ ├── S001/ # 店铺S001的模型 │ │ ├── P001_model.pt │ │ ├── P001_log.json │ │ ├── P002_model.pt │ │ └── P002_log.json │ ├── S002/ # 店铺S002的模型 │ │ ├── P001_model.pt │ │ └── P001_log.json │ └── global/ # 全局模型(所有店铺数据训练) │ ├── P001_model.pt │ └── P001_log.json ├── transformer/ │ └── ... └── kan/ └── ... ``` ## 实施计划 ### 阶段一:数据结构调整 1. 创建店铺表和店铺-产品关联表 2. 修改销售数据表,添加店铺ID字段 3. 更新数据导入/导出功能,支持店铺信息 ### 阶段二:后端功能实现 1. 修改模型训练函数,支持按店铺训练 2. 修改预测函数,支持按店铺预测 3. 实现店铺管理API接口 4. 调整模型保存目录结构 ### 阶段三:前端界面更新 1. 添加店铺选择器组件 2. 创建店铺管理界面 3. 更新预测和训练界面,支持店铺选择 4. 优化数据可视化,支持店铺间比较 ### 阶段四:测试和优化 1. 单元测试和集成测试 2. 性能优化 3. 用户界面优化 4. 文档更新 ## 扩展建议 1. 实现店铺间销售数据对比分析功能 2. 添加店铺聚类分析,识别相似店铺模式 3. 开发跨店铺产品销售趋势分析 4. 实现区域级别的销售预测