ShopTRAINING/docs/全局模型训练实施手册.md

18 KiB
Raw Permalink Blame History

全局模型训练实施手册

1. 引言

本手册旨在为大规模销售预测系统提供三种核心的"全局模型"Global Models训练策略的详细技术实施指南。全局模型的核心思想是利用多个时间序列之间的共性和关联性训练出比独立模型更强大、更具泛化能力的模型。这对于处理成千上万个店铺和品类的预测任务至关重要。

我们将详细探讨以下三种方案:

  1. 店铺专属全局模型 (Store-Level):一个模型预测一个店铺的所有品类。
  2. 品类专属全局模型 (Category-Level):一个模型预测一个品类的所有店铺销量。(推荐优先实施)
  3. 终极全局模型 (Ultimate Global):一个模型预测所有店铺的所有品类。

2. 方案一:店铺专属全局模型 (Store-Level)

2.1 核心思想

为每一家店铺训练一个专属模型。该模型学习这家店铺特有的销售模式(如客流、促销敏感度),并能同时预测该店铺内所有品类的销售情况。

  • 模型数量2000个 (每个店铺一个)
  • 关键:模型通过学习**品类嵌入Product Embedding**来区分不同品类。

2.2 数据准备与处理

在此方案中,我们将数据按 store_id 进行分组,为每个分组(即每个店铺)准备训练数据。

使用 pyspark 的处理逻辑示例:

from pyspark.sql import SparkSession
from pyspark.sql.functions import col
from pyspark.ml.feature import StringIndexer

# 初始化Spark会话
spark = SparkSession.builder.appName("StoreLevelDataPrep").getOrCreate()

# 1. 加载已经过ETL处理的销售数据宽表
sales_df = spark.read.format("delta").load("/path/to/your/feature_store/sales_wide_table")

# 2. 为品类ID创建全局索引
# 这一步很关键,确保所有模型的品类嵌入层是一致的
product_indexer = StringIndexer(inputCol="product_id", outputCol="product_idx").fit(sales_df)
indexed_df = product_indexer.transform(sales_df)

# 3. 按店铺ID进行分组处理
all_stores = [row.store_id for row in indexed_df.select("store_id").distinct().collect()]

for store in all_stores:
    print(f"Preparing data for store: {store}")
    
    # 筛选出当前店铺的数据
    store_df = indexed_df.filter(col("store_id") == store)
    
    # store_df 现在包含了该店铺所有品类的销售数据
    # product_idx 列将作为模型输入,用于品类嵌入
    
    # 接下来,可以将 store_df 转换为 Pandas DataFrame 或直接在 Spark 中处理
    # 以便送入后续的PyTorch DataLoader
    # store_df.write.format("parquet").save(f"/path/to/training_data/store_level/{store}")

2.3 模型架构PyTorch conceptual

模型需要一个嵌入层来处理 product_idx

import torch
import torch.nn as nn

class StoreLevelGlobalModel(nn.Module):
    def __init__(self, num_products, product_embedding_dim=32, num_other_features=8, hidden_size=128):
        super().__init__()
        
        # 关键:品类嵌入层
        self.product_embedding = nn.Embedding(num_embeddings=num_products, embedding_dim=product_embedding_dim)
        
        # 其他数值型特征的维度
        self.num_other_features = num_other_features
        
        # LSTM 或 Transformer 作为核心
        # 输入维度 = 品类嵌入维度 + 其他特征维度
        input_dim = product_embedding_dim + num_other_features
        self.lstm = nn.LSTM(input_size=input_dim, hidden_size=hidden_size, batch_first=True)
        
        self.fc = nn.Linear(hidden_size, 1) # 预测未来1天的销量

    def forward(self, product_idx_input, other_features_input):
        # product_idx_input shape: (batch, sequence_length)
        # other_features_input shape: (batch, sequence_length, num_other_features)
        
        # 1. 获取品类嵌入向量
        product_embeds = self.product_embedding(product_idx_input)
        
        # 2. 拼接嵌入向量和其他特征
        model_input = torch.cat([product_embeds, other_features_input], dim=2)
        
        # 3. 送入LSTM
        lstm_out, _ = self.lstm(model_input)
        
        # 4. 全连接层预测
        # 只取序列最后一个时间点的输出进行预测
        predictions = self.fc(lstm_out[:, -1, :])
        
        return predictions

2.4 训练流程(概念)

训练流程由一个外部循环驱动,遍历所有店铺。

# 获取所有店铺ID列表
stores_to_train = get_all_store_ids()
# 获取总品类数,用于初始化模型
total_num_products = get_total_product_count() 

for store_id in stores_to_train:
    print(f"--- Training model for Store: {store_id} ---")
    
    # 1. 加载该店铺的训练数据
    train_loader = create_dataloader_for_store(store_id)
    
    # 2. 初始化模型
    model = StoreLevelGlobalModel(num_products=total_num_products)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.MSELoss()
    
    # 3. 执行标准训练循环
    for epoch in range(num_epochs):
        for product_indices, other_features, targets in train_loader:
            # ... training steps ...
            pass
            
    # 4. 保存该店铺专属的模型
    # 使用MLflow记录模型
    # mlflow.pytorch.log_model(model, f"store_model_{store_id}")
    torch.save(model.state_dict(), f"models/store_level/{store_id}.pt")

2.5 优缺点

  • 优点: 能有效捕捉店铺内部的商品交叉销售和替代效应;对新商品有较好的冷启动能力。
  • 缺点: 无法利用不同店铺间的共性模型数量依然较多2000个管理和更新成本相对较高。

3. 方案二:品类专属全局模型 (Category-Level)

3.1 核心思想

为每一个品类或SKU训练一个专属模型。该模型学习这个品类固有的销售模式如季节性、生命周期并能同时预测它在所有店铺的销售情况。这是业界最常用且性价比最高的方案

  • 模型数量300+个 (每个品类一个)
  • 关键:模型通过学习**店铺嵌入Store Embedding**来区分不同店铺。

3.2 数据准备与处理

与方案一类似,但数据按 product_id 分组。

# (Spark环境已初始化)
sales_df = spark.read.format("delta").load("/path/to/your/feature_store/sales_wide_table")

# 为店铺ID创建全局索引
store_indexer = StringIndexer(inputCol="store_id", outputCol="store_idx").fit(sales_df)
indexed_df = store_indexer.transform(sales_df)

# 按品类ID进行分组处理
all_products = [row.product_id for row in indexed_df.select("product_id").distinct().collect()]

for product in all_products:
    print(f"Preparing data for product: {product}")
    
    # 筛选出当前品类的数据
    product_df = indexed_df.filter(col("product_id") == product)
    
    # product_df 现在包含了该品类在所有店铺的销售数据
    # store_idx 列将作为模型输入,用于店铺嵌入
    # product_df.write.format("parquet").save(f"/path/to/training_data/category_level/{product}")

3.3 模型架构PyTorch conceptual

模型将包含一个 store_id 的嵌入层。

import torch.nn as nn

class CategoryLevelGlobalModel(nn.Module):
    def __init__(self, num_stores, store_embedding_dim=16, num_other_features=8, hidden_size=128):
        super().__init__()
        
        # 关键:店铺嵌入层
        self.store_embedding = nn.Embedding(num_embeddings=num_stores, embedding_dim=store_embedding_dim)
        
        # LSTM输入维度 = 店铺嵌入维度 + 其他特征维度
        input_dim = store_embedding_dim + num_other_features
        self.lstm = nn.LSTM(input_size=input_dim, hidden_size=hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, 1)

    def forward(self, store_idx_input, other_features_input):
        # 获取店铺嵌入向量
        store_embeds = self.store_embedding(store_idx_input)
        
        # 拼接嵌入向量和其他特征
        model_input = torch.cat([store_embeds, other_features_input], dim=2)
        
        # ... 后续与方案一类似 ...
        lstm_out, _ = self.lstm(model_input)
        predictions = self.fc(lstm_out[:, -1, :])
        return predictions

2.4 训练流程(概念)

循环遍历所有品类。

products_to_train = get_all_product_ids()
total_num_stores = get_total_store_count()

for product_id in products_to_train:
    print(f"--- Training model for Product: {product_id} ---")
    
    train_loader = create_dataloader_for_product(product_id)
    model = CategoryLevelGlobalModel(num_stores=total_num_stores)
    # ... training loop ...
    
    # 保存该品类专属的模型
    torch.save(model.state_dict(), f"models/category_level/{product_id}.pt")

2.5 优缺点

  • 优点: 模型数量大幅减少,管理高效;能学习品类在不同市场环境的通用规律;对新开店铺的冷启动效果极佳。
  • 缺点: 可能忽略某些店铺和品类之间非常独特的强关联效应。

4. 方案三:终极全局模型 (Ultimate Global)

4.1 核心思想

只训练一个(或极少数几个)巨大的模型,用它来预测所有店铺的所有品类。该模型通过同时学习店铺和品类的嵌入,来捕捉全局的、品类的、店铺的以及它们之间的交叉模式。

  • 模型数量1
  • 关键:模型同时包含店铺嵌入品类嵌入

4.2 数据准备与处理

无需分组,直接使用全量数据。

# (Spark环境已初始化)
sales_df = spark.read.format("delta").load("/path/to/your/feature_store/sales_wide_table")

# 为店铺和品类ID都创建索引
store_indexer = StringIndexer(inputCol="store_id", outputCol="store_idx").fit(sales_df)
indexed_df = store_indexer.transform(sales_df)

product_indexer = StringIndexer(inputCol="product_id", outputCol="product_idx").fit(indexed_df)
final_df = product_indexer.transform(indexed_df)

# final_df 将作为最终的训练数据源
# final_df.write.format("parquet").save("/path/to/training_data/ultimate_global/")

4.3 模型架构PyTorch conceptual

模型包含两个嵌入层。

import torch.nn as nn

class UltimateGlobalModel(nn.Module):
    def __init__(self, num_stores, num_products, store_embedding_dim=16, product_embedding_dim=32, num_other_features=8, hidden_size=256):
        super().__init__()
        
        # 关键:两个嵌入层
        self.store_embedding = nn.Embedding(num_embeddings=num_stores, embedding_dim=store_embedding_dim)
        self.product_embedding = nn.Embedding(num_embeddings=num_products, embedding_dim=product_embedding_dim)
        
        # LSTM输入维度 = 店铺嵌入 + 品类嵌入 + 其他特征
        input_dim = store_embedding_dim + product_embedding_dim + num_other_features
        self.lstm = nn.LSTM(input_size=input_dim, hidden_size=hidden_size, batch_first=True, num_layers=2) # 模型可以更深
        self.fc = nn.Linear(hidden_size, 1)

    def forward(self, store_idx_input, product_idx_input, other_features_input):
        # 获取店铺和品类的嵌入向量
        store_embeds = self.store_embedding(store_idx_input)
        product_embeds = self.product_embedding(product_idx_input)
        
        # 拼接所有特征
        model_input = torch.cat([store_embeds, product_embeds, other_features_input], dim=2)
        
        # ... 后续 ...
        lstm_out, _ = self.lstm(model_input)
        predictions = self.fc(lstm_out[:, -1, :])
        return predictions

4.4 训练流程(概念)

流程最简单,直接在全量数据上训练一个模型。

print(f"--- Training the Ultimate Global Model ---")
    
# 加载全量训练数据
train_loader = create_dataloader_for_all_data()

# 初始化一个更大、更复杂的模型
model = UltimateGlobalModel(
    num_stores=total_num_stores, 
    num_products=total_num_products,
    hidden_size=512, # 可能需要更大的隐藏层
    num_layers=4
)
# ... training loop ...
    
# 保存唯一的全局模型
torch.save(model.state_dict(), "models/ultimate_global_model.pt")

4.5 优缺点

  • 优点: 最大化利用数据,泛化能力和冷启动能力最强;模型管理极简。
  • 缺点: 对计算资源要求最高;模型设计和调优更复杂,需要仔细处理不同特征的交互;训练时间最长。

5. 总结与建议

  • 起步阶段:强烈建议从 方案二(品类专属全局模型) 开始实施。它在性能、成本和管理复杂度之间取得了最佳平衡,也是业界验证最成熟的方案。
  • 进阶阶段对于A类高价值商品可以尝试 方案三(终极全局模型) 或一个结合了多个品类的分组模型,以追求更高的准确率。
  • 辅助阶段方案一(店铺专属全局模型) 可以作为一种补充,用于分析特定店铺的内部销售动态。

无论选择哪种方案一个强大的MLOps平台MLflow)都是必不可少的,它将帮助您有效地追踪实验、注册模型、管理版本和自动化部署。


6. 补充任务适配KAN模型进行全局训练

6.1 KAN模型简介与动机

KAN (Kolmogorov-Arnold Networks) 是一种新型的神经网络架构其设计灵感来源于Kolmogorov-Arnold表示定理。与传统的MLP多层感知机在节点上使用固定的激活函数不同KAN在网络的"边"连接上放置可学习的激活函数通常是样条函数B-splines

引入动机:

  1. 卓越的可解释性: KAN最大的亮点在于其模型内部的可视化能力。我们可以直接看到每个特征是如何通过一个一维样条函数影响最终输出的从而直观地理解模型学到的非线性关系为业务分析提供深刻洞察。
  2. 更高的参数效率: 理论上KAN可以用更少的参数来拟合复杂函数这可能意味着在同等精度下模型可以更小、更快。

6.2 核心适配机制:嵌入层 + KAN

要将KAN用于全局模型我们采用与前面方案相同的核心技巧

将离散的类别IDstore_id, product_id)通过nn.Embedding层转换为密集的向量然后将这些向量与其他数值特征拼接起来一同作为KAN模型的输入。

6.3 适配代码示例 (PyTorch)

下面的伪代码展示了如何将KAN集成到**方案二(品类专属)方案三(终极全局)**中。

示例1: 品类专属全局模型 + KAN

import torch
import torch.nn as nn
from efficient_kan import KAN

class CategoryLevelGlobalKAN(nn.Module):
    def __init__(self, num_stores, store_embedding_dim=16, num_other_features=8, kan_layers=[64, 32], k=3):
        super().__init__()
        # 1. 店铺嵌入层
        self.store_embedding = nn.Embedding(num_stores, store_embedding_dim)
        
        # 2. KAN模型
        # 注意:输入维度是所有特征维度的总和
        input_dim = store_embedding_dim + num_other_features
        self.kan = KAN([input_dim] + kan_layers + [1], k=k)

    def forward(self, store_idx_input, other_features_input):
        # store_idx_input shape: (batch_size, 1) or (batch_size,)
        # other_features_input shape: (batch_size, num_other_features)
        
        # 将店铺ID转换为嵌入向量
        store_embeds = self.store_embedding(store_idx_input).squeeze(1) # .squeeze() to remove extra dim
        
        # 拼接所有特征
        kan_input = torch.cat([store_embeds, other_features_input], dim=1)
        
        return self.kan(kan_input)

示例2: 终极全局模型 + KAN

import torch
import torch.nn as nn
from efficient_kan import KAN

class UltimateGlobalKAN(nn.Module):
    def __init__(self, num_stores, num_products, store_embedding_dim=16, product_embedding_dim=32, num_other_features=8, kan_layers=[128, 64], k=3):
        super().__init__()
        # 两个独立的嵌入层
        self.store_embedding = nn.Embedding(num_stores, store_embedding_dim)
        self.product_embedding = nn.Embedding(num_products, product_embedding_dim)
        
        # KAN模型的输入维度是所有嵌入和特征维度的总和
        input_dim = store_embedding_dim + product_embedding_dim + num_other_features
        self.kan = KAN([input_dim] + kan_layers + [1], k=k)

    def forward(self, store_idx_input, product_idx_input, other_features_input):
        store_embeds = self.store_embedding(store_idx_input).squeeze(1)
        product_embeds = self.product_embedding(product_idx_input).squeeze(1)
        
        kan_input = torch.cat([store_embeds, product_embeds, other_features_input], dim=1)
        
        return self.kan(kan_input)

重要说明: 上述伪代码展示了KAN用于处理表格化数据的场景。在实际时序预测中需要先通过滑动窗口将时间序列数据如过去14天的特征"铺平"成一个长向量再送入KAN。因此num_other_features 实际上会是 (窗口长度 * 每日特征数)

6.4 实施建议与挑战

  • 实施建议: 将KAN模型的训练作为一个并行的实验性任务。在实施方案二品类专属全局模型可以为一部分重要的B类商品同时训练一个基于Transformer/LSTM的模型和一个基于KAN的模型。对比它们的预测精度、训练开销以及KAN提供的可解释性所带来的业务价值。
  • 主要挑战:
    1. 时序信息处理: KAN是前馈网络不直接处理时间依赖性。必须依赖滑动窗口式的特征工程。
    2. 计算成本: KAN的训练成本可能较高需要合理配置分布式计算资源。
    3. 技术成熟度: 作为前沿模型,在生产环境中应用需要更充分的测试和验证。