ShopTRAINING/docs/多店铺销售预测系统扩展方案.md

21 KiB
Raw Permalink Blame History

药店销售预测系统 - 多店铺扩展方案

背景

当前系统仅支持单店铺单商品的销售预测,需要扩展为支持多店铺销售预测的系统。

当前系统分析

当前系统存在以下限制:

  1. 数据结构:pharmacy_sales.xlsx 数据集中没有店铺相关字段
  2. 模型训练:训练函数如 train_product_model_with_mlstm 只接受单个 product_id 参数
  3. API设计预测和训练接口只支持单个产品ID参数
  4. 模型保存模型保存路径仅考虑产品ID不包含店铺信息

扩展方案

1. 数据库结构修改

1.1 新增数据表

-- 店铺表
CREATE TABLE stores (
    store_id VARCHAR(20) PRIMARY KEY,
    store_name VARCHAR(100) NOT NULL,
    location VARCHAR(200),
    size FLOAT,
    type VARCHAR(50),
    opening_date DATE,
    status VARCHAR(20) DEFAULT 'active'
);

-- 店铺-产品关联表
CREATE TABLE store_products (
    store_id VARCHAR(20),
    product_id VARCHAR(20),
    first_sale_date DATE,
    PRIMARY KEY (store_id, product_id),
    FOREIGN KEY (store_id) REFERENCES stores(store_id),
    FOREIGN KEY (product_id) REFERENCES products(product_id)
);

1.2 修改销售数据表结构

在销售数据表中添加 store_id 字段:

ALTER TABLE sales ADD COLUMN store_id VARCHAR(20) NOT NULL;

2. 数据文件格式修改

修改 pharmacy_sales.xlsx 文件格式,添加 store_idstore_name 列。

3. 后端代码修改

3.1 修改数据加载函数

# 修改前
def load_data(file_path='pharmacy_sales.xlsx'):
    df = pd.read_excel(file_path)
    return df

# 修改后
def load_data(file_path='pharmacy_sales.xlsx', store_id=None):
    df = pd.read_excel(file_path)
    if store_id:
        df = df[df['store_id'] == store_id]
    return df

3.2 修改模型训练函数

# 修改前
def train_product_model(product_id, epochs=50):
    return train_product_model_with_mlstm(product_id, epochs)

# 修改后
def train_product_model(product_id, store_id=None, epochs=50):
    return train_product_model_with_mlstm(product_id, store_id, epochs)

# 修改训练函数
def train_product_model_with_mlstm(product_id, store_id=None, epochs=50):
    # 读取销售数据
    df = pd.read_excel('pharmacy_sales.xlsx')
    
    # 筛选特定产品和店铺数据
    if store_id:
        product_df = df[(df['product_id'] == product_id) & (df['store_id'] == store_id)]
        store_name = product_df['store_name'].iloc[0]
        model_dir = f'models/mlstm/{store_id}'
        model_path = f'{model_dir}/{product_id}_model.pt'
        log_path = f'{model_dir}/{product_id}_log.json'
    else:
        product_df = df[df['product_id'] == product_id]
        store_name = "全部店铺"
        model_dir = 'models/mlstm'
        model_path = f'{model_dir}/{product_id}_model.pt'
        log_path = f'{model_dir}/{product_id}_log.json'
    
    product_name = product_df['product_name'].iloc[0]
    
    print(f"使用mLSTM模型训练店铺 '{store_name}' 产品 '{product_name}' (ID: {product_id}) 的销售预测模型")
    
    # ... 后续训练代码保持不变 ...
    
    # 保存模型时添加店铺信息
    os.makedirs(model_dir, exist_ok=True)
    
    # 保存模型
    torch.save({
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_loss': train_losses,
        'test_loss': test_losses,
        'scaler_X': scaler_X,
        'scaler_y': scaler_y,
        'features': features,
        'look_back': look_back,
        'T': T,
        'model_type': 'mlstm',
        'store_id': store_id
    }, model_path)
    
    # 保存日志
    log_data = {
        'product_id': product_id,
        'product_name': product_name,
        'store_id': store_id,
        'store_name': store_name,
        'model_type': 'mlstm',
        'training_completed_at': datetime.now().isoformat(),
        'epochs': epochs,
        'metrics': metrics,
        'file_path': model_path
    }
    
    with open(log_path, 'w', encoding='utf-8') as f:
        json.dump(log_data, f, indent=4, ensure_ascii=False)
    
    return model, metrics

3.3 修改预测函数

# 修改前
def load_model_and_predict(product_id, model_type, future_days=7, start_date=None):
    # ...

# 修改后
def load_model_and_predict(product_id, model_type, store_id=None, future_days=7, start_date=None):
    # 确定模型路径
    if store_id:
        model_dir = f'models/{model_type}/{store_id}'
        model_path = f'{model_dir}/{product_id}_model.pt'
    else:
        model_dir = f'models/{model_type}'
        model_path = f'{model_dir}/{product_id}_model.pt'
    
    # ... 后续代码保持不变 ...

4. API接口修改

4.1 修改训练API

@app.route('/api/training', methods=['POST'])
def start_training():
    data = request.json
    product_id = data.get('product_id')
    store_id = data.get('store_id')  # 新增店铺ID参数
    model_type = data.get('model_type', 'mlstm')
    epochs = data.get('epochs', 50)
    
    if not product_id:
        return jsonify({"status": "error", "message": "缺少product_id参数"}), 400
    
    # 生成任务ID
    task_id = str(uuid.uuid4())
    
    # 创建训练任务
    with tasks_lock:
        training_tasks[task_id] = {
            "product_id": product_id,
            "store_id": store_id,  # 新增店铺ID
            "model_type": model_type,
            "status": "pending",
            "start_time": datetime.now().isoformat()
        }
    
    # 启动训练线程
    executor.submit(train_task, product_id, store_id, epochs, model_type, task_id)
    
    return jsonify({
        "status": "success", 
        "message": "训练任务已提交", 
        "task_id": task_id
    })

def train_task(product_id, store_id, epochs, model_type, task_id):
    try:
        # 更新任务状态
        with tasks_lock:
            training_tasks[task_id]["status"] = "running"
        
        # 根据模型类型选择训练函数
        if model_type == "mlstm":
            model, metrics = train_product_model_with_mlstm(product_id, store_id, epochs)
        elif model_type == "transformer":
            model, metrics = train_product_model_with_transformer(product_id, store_id, epochs)
        elif model_type == "kan":
            model, metrics = train_product_model_with_kan(product_id, store_id, epochs)
        elif model_type == "optimized_kan":
            model, metrics = train_product_model_with_kan(product_id, store_id, epochs, use_optimized=True)
        else:
            raise ValueError(f"不支持的模型类型: {model_type}")
        
        # 确定模型路径
        if store_id:
            model_path = f"models/{model_type}/{store_id}/{product_id}_model.pt"
        else:
            model_path = f"models/{model_type}/{product_id}_model.pt"
        
        # 更新任务状态
        with tasks_lock:
            training_tasks[task_id]["status"] = "completed"
            training_tasks[task_id]["metrics"] = metrics
            training_tasks[task_id]["model_path"] = model_path
    
    except Exception as e:
        # 记录错误
        with tasks_lock:
            training_tasks[task_id]["status"] = "failed"
            training_tasks[task_id]["error"] = str(e)
        print(f"训练任务 {task_id} 失败: {e}")
        traceback.print_exc()

4.2 修改预测API

@app.route('/api/prediction', methods=['POST'])
def predict():
    data = request.json
    product_id = data.get('product_id')
    store_id = data.get('store_id')  # 新增店铺ID参数
    model_type = data.get('model_type', 'mlstm')
    future_days = data.get('future_days', 7)
    include_visualization = data.get('include_visualization', True)
    start_date = data.get('start_date')
    
    if not product_id:
        return jsonify({"status": "error", "message": "缺少product_id参数"}), 400
    
    try:
        # 调用预测函数
        result = load_model_and_predict(
            product_id=product_id,
            model_type=model_type,
            store_id=store_id,  # 传递店铺ID
            future_days=future_days,
            start_date=start_date
        )
        
        # ... 后续代码保持不变 ...

5. 前端界面修改

5.1 添加店铺选择器组件

<!-- StoreSelector.vue -->
<template>
  <div class="store-selector">
    <label>选择店铺:</label>
    <select v-model="selectedStore" @change="onChange">
      <option value="">全部店铺</option>
      <option v-for="store in stores" :key="store.store_id" :value="store.store_id">
        {{ store.store_name }}
      </option>
    </select>
  </div>
</template>

<script>
export default {
  name: 'StoreSelector',
  data() {
    return {
      selectedStore: '',
      stores: []
    }
  },
  created() {
    this.fetchStores();
  },
  methods: {
    async fetchStores() {
      try {
        const response = await fetch('/api/stores');
        const result = await response.json();
        if (result.status === 'success') {
          this.stores = result.data;
        }
      } catch (error) {
        console.error('获取店铺列表失败:', error);
      }
    },
    onChange() {
      this.$emit('store-selected', this.selectedStore);
    }
  }
}
</script>

5.2 修改预测视图

<!-- PredictionView.vue -->
<template>
  <div class="prediction-view">
    <!-- 添加店铺选择器 -->
    <StoreSelector @store-selected="onStoreSelected" />
    
    <!-- 其他组件保持不变 -->
    <ProductSelector @product-selected="onProductSelected" />
    <ModelSelector @model-selected="onModelSelected" />
    <!-- ... -->
  </div>
</template>

<script>
import StoreSelector from '@/components/StoreSelector.vue';

export default {
  components: {
    StoreSelector,
    // ... 其他组件
  },
  data() {
    return {
      selectedStore: '',
      selectedProduct: '',
      selectedModel: 'mlstm',
      // ... 其他数据
    }
  },
  methods: {
    onStoreSelected(storeId) {
      this.selectedStore = storeId;
    },
    async fetchPrediction() {
      // 修改API调用添加店铺ID
      const requestData = {
        product_id: this.selectedProduct,
        store_id: this.selectedStore,  // 添加店铺ID
        model_type: this.selectedModel,
        future_days: this.futureDays
      };
      
      // ... 后续代码保持不变 ...
    }
  }
}
</script>

5.3 添加店铺管理界面

<!-- StoreManagementView.vue -->
<template>
  <div class="store-management">
    <h1>店铺管理</h1>
    
    <!-- 店铺列表 -->
    <div class="store-list">
      <table>
        <thead>
          <tr>
            <th>ID</th>
            <th>名称</th>
            <th>位置</th>
            <th>类型</th>
            <th>状态</th>
            <th>操作</th>
          </tr>
        </thead>
        <tbody>
          <tr v-for="store in stores" :key="store.store_id">
            <td>{{ store.store_id }}</td>
            <td>{{ store.store_name }}</td>
            <td>{{ store.location }}</td>
            <td>{{ store.type }}</td>
            <td>{{ store.status }}</td>
            <td>
              <button @click="editStore(store)">编辑</button>
              <button @click="deleteStore(store.store_id)">删除</button>
            </td>
          </tr>
        </tbody>
      </table>
    </div>
    
    <!-- 添加/编辑店铺表单 -->
    <div class="store-form">
      <h2>{{ isEditing ? '编辑店铺' : '添加店铺' }}</h2>
      <form @submit.prevent="saveStore">
        <div class="form-group">
          <label>店铺ID:</label>
          <input v-model="currentStore.store_id" :disabled="isEditing" required />
        </div>
        <div class="form-group">
          <label>店铺名称:</label>
          <input v-model="currentStore.store_name" required />
        </div>
        <div class="form-group">
          <label>位置:</label>
          <input v-model="currentStore.location" />
        </div>
        <div class="form-group">
          <label>大小(平方米):</label>
          <input v-model="currentStore.size" type="number" />
        </div>
        <div class="form-group">
          <label>类型:</label>
          <select v-model="currentStore.type">
            <option value="flagship">旗舰店</option>
            <option value="standard">标准店</option>
            <option value="community">社区店</option>
          </select>
        </div>
        <div class="form-group">
          <label>开业日期:</label>
          <input v-model="currentStore.opening_date" type="date" />
        </div>
        <div class="form-group">
          <label>状态:</label>
          <select v-model="currentStore.status">
            <option value="active">营业中</option>
            <option value="closed">已关闭</option>
            <option value="renovating">装修中</option>
          </select>
        </div>
        <div class="form-actions">
          <button type="submit">保存</button>
          <button type="button" @click="resetForm">取消</button>
        </div>
      </form>
    </div>
  </div>
</template>

<script>
export default {
  name: 'StoreManagementView',
  data() {
    return {
      stores: [],
      currentStore: this.getEmptyStore(),
      isEditing: false
    }
  },
  created() {
    this.fetchStores();
  },
  methods: {
    getEmptyStore() {
      return {
        store_id: '',
        store_name: '',
        location: '',
        size: null,
        type: 'standard',
        opening_date: '',
        status: 'active'
      };
    },
    async fetchStores() {
      try {
        const response = await fetch('/api/stores');
        const result = await response.json();
        if (result.status === 'success') {
          this.stores = result.data;
        }
      } catch (error) {
        console.error('获取店铺列表失败:', error);
      }
    },
    editStore(store) {
      this.currentStore = { ...store };
      this.isEditing = true;
    },
    async deleteStore(storeId) {
      if (!confirm(`确定要删除店铺 ${storeId} 吗?`)) return;
      
      try {
        const response = await fetch(`/api/stores/${storeId}`, {
          method: 'DELETE'
        });
        const result = await response.json();
        if (result.status === 'success') {
          this.fetchStores();
        }
      } catch (error) {
        console.error('删除店铺失败:', error);
      }
    },
    async saveStore() {
      try {
        const url = this.isEditing ? `/api/stores/${this.currentStore.store_id}` : '/api/stores';
        const method = this.isEditing ? 'PUT' : 'POST';
        
        const response = await fetch(url, {
          method,
          headers: {
            'Content-Type': 'application/json'
          },
          body: JSON.stringify(this.currentStore)
        });
        
        const result = await response.json();
        if (result.status === 'success') {
          this.fetchStores();
          this.resetForm();
        }
      } catch (error) {
        console.error('保存店铺失败:', error);
      }
    },
    resetForm() {
      this.currentStore = this.getEmptyStore();
      this.isEditing = false;
    }
  }
}
</script>

6. 新增API接口

6.1 店铺管理API

# 获取所有店铺
@app.route('/api/stores', methods=['GET'])
def get_stores():
    try:
        conn = get_db_connection()
        cursor = conn.cursor()
        cursor.execute("SELECT * FROM stores ORDER BY store_name")
        stores = cursor.fetchall()
        conn.close()
        
        return jsonify({
            "status": "success",
            "data": stores
        })
    except Exception as e:
        return jsonify({
            "status": "error",
            "message": str(e)
        }), 500

# 获取单个店铺
@app.route('/api/stores/<store_id>', methods=['GET'])
def get_store(store_id):
    try:
        conn = get_db_connection()
        cursor = conn.cursor()
        cursor.execute("SELECT * FROM stores WHERE store_id = ?", (store_id,))
        store = cursor.fetchone()
        conn.close()
        
        if not store:
            return jsonify({
                "status": "error",
                "message": f"店铺 {store_id} 不存在"
            }), 404
        
        return jsonify({
            "status": "success",
            "data": store
        })
    except Exception as e:
        return jsonify({
            "status": "error",
            "message": str(e)
        }), 500

# 创建店铺
@app.route('/api/stores', methods=['POST'])
def create_store():
    data = request.json
    
    try:
        conn = get_db_connection()
        cursor = conn.cursor()
        cursor.execute(
            "INSERT INTO stores (store_id, store_name, location, size, type, opening_date, status) VALUES (?, ?, ?, ?, ?, ?, ?)",
            (
                data['store_id'],
                data['store_name'],
                data.get('location'),
                data.get('size'),
                data.get('type'),
                data.get('opening_date'),
                data.get('status', 'active')
            )
        )
        conn.commit()
        conn.close()
        
        return jsonify({
            "status": "success",
            "message": "店铺创建成功",
            "data": {
                "store_id": data['store_id']
            }
        })
    except Exception as e:
        return jsonify({
            "status": "error",
            "message": str(e)
        }), 500

# 更新店铺
@app.route('/api/stores/<store_id>', methods=['PUT'])
def update_store(store_id):
    data = request.json
    
    try:
        conn = get_db_connection()
        cursor = conn.cursor()
        cursor.execute(
            "UPDATE stores SET store_name = ?, location = ?, size = ?, type = ?, opening_date = ?, status = ? WHERE store_id = ?",
            (
                data['store_name'],
                data.get('location'),
                data.get('size'),
                data.get('type'),
                data.get('opening_date'),
                data.get('status'),
                store_id
            )
        )
        conn.commit()
        conn.close()
        
        return jsonify({
            "status": "success",
            "message": "店铺更新成功"
        })
    except Exception as e:
        return jsonify({
            "status": "error",
            "message": str(e)
        }), 500

# 删除店铺
@app.route('/api/stores/<store_id>', methods=['DELETE'])
def delete_store(store_id):
    try:
        conn = get_db_connection()
        cursor = conn.cursor()
        
        # 检查是否有关联的销售数据
        cursor.execute("SELECT COUNT(*) FROM sales WHERE store_id = ?", (store_id,))
        count = cursor.fetchone()[0]
        
        if count > 0:
            return jsonify({
                "status": "error",
                "message": f"无法删除店铺 {store_id},存在 {count} 条关联的销售数据"
            }), 400
        
        # 删除店铺-产品关联
        cursor.execute("DELETE FROM store_products WHERE store_id = ?", (store_id,))
        
        # 删除店铺
        cursor.execute("DELETE FROM stores WHERE store_id = ?", (store_id,))
        
        conn.commit()
        conn.close()
        
        return jsonify({
            "status": "success",
            "message": "店铺删除成功"
        })
    except Exception as e:
        return jsonify({
            "status": "error",
            "message": str(e)
        }), 500

7. 模型目录结构修改

修改模型保存目录结构,以支持多店铺:

models/
├── mlstm/
│   ├── S001/                # 店铺S001的模型
│   │   ├── P001_model.pt
│   │   ├── P001_log.json
│   │   ├── P002_model.pt
│   │   └── P002_log.json
│   ├── S002/                # 店铺S002的模型
│   │   ├── P001_model.pt
│   │   └── P001_log.json
│   └── global/              # 全局模型(所有店铺数据训练)
│       ├── P001_model.pt
│       └── P001_log.json
├── transformer/
│   └── ...
└── kan/
    └── ...

实施计划

阶段一:数据结构调整

  1. 创建店铺表和店铺-产品关联表
  2. 修改销售数据表添加店铺ID字段
  3. 更新数据导入/导出功能,支持店铺信息

阶段二:后端功能实现

  1. 修改模型训练函数,支持按店铺训练
  2. 修改预测函数,支持按店铺预测
  3. 实现店铺管理API接口
  4. 调整模型保存目录结构

阶段三:前端界面更新

  1. 添加店铺选择器组件
  2. 创建店铺管理界面
  3. 更新预测和训练界面,支持店铺选择
  4. 优化数据可视化,支持店铺间比较

阶段四:测试和优化

  1. 单元测试和集成测试
  2. 性能优化
  3. 用户界面优化
  4. 文档更新

扩展建议

  1. 实现店铺间销售数据对比分析功能
  2. 添加店铺聚类分析,识别相似店铺模式
  3. 开发跨店铺产品销售趋势分析
  4. 实现区域级别的销售预测