ShopTRAINING/server/utils/feature_selection.py

133 lines
5.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
新数据管道 - 算法专属特征选择模块
本模块的核心是为系统中的每一种算法模型提供一个专属的特征选择函数。
这取代了过去硬编码或从单一配置文件读取特征列表的方式,
使得每个模型都能使用最适合其特性的特征组合进行训练。
"""
def select_features_for_mlstm(all_columns):
"""
为mLSTM模型挑选最佳特征子集 (已根据权威字典更新)。
策略: 核心时序信号 + 短期滚动特征 + 实际存在的静态上下文特征。
"""
selected = [
# 核心指标
'net_sales_quantity', 'gross_profit_total', 'transaction_count',
# 日期特征
'is_weekend', 'day_of_week', 'month', 'is_holiday',
# 生命周期特征
'lifecycle_days', 'sample_category',
# 短期滚动特征
'net_sales_quantity_rolling_mean_7d', 'return_quantity_rolling_mean_7d', 'net_sales_quantity_rolling_sum_7d',
# 店铺静态特征 (实际存在)
'district', 'poi_residential_count', 'poi_school_count', 'temperature_2m_mean',
# 商品静态特征 (实际存在)
'零售大类代码_encoded', '商品ABC分类_encoded', 'brand_encoded', 'approval_type_encoded'
]
return [col for col in selected if col in all_columns]
def select_features_for_transformer(all_columns):
"""
为Transformer模型挑选特征 (已根据权威字典更新)。
策略: 在mLSTM基础上增加中长期滚动特征和更精细的时间特征。
"""
base_features = select_features_for_mlstm(all_columns)
additional = [
# 中长期滚动特征
'net_sales_quantity_rolling_mean_30d', 'net_sales_quantity_rolling_sum_30d', 'net_sales_quantity_rolling_mean_90d',
# 精细化时间特征
'day_of_year', 'week_of_month'
]
# 合并并去重
return list(dict.fromkeys(base_features + [col for col in additional if col in all_columns]))
def select_features_for_tcn(all_columns):
"""
为TCN模型挑选特征 (已根据权威字典更新)。
策略: 与Transformer类似提供丰富的多尺度特征让卷积核学习。
"""
return select_features_for_transformer(all_columns)
def select_features_for_hybrid(all_columns):
"""
为CNN-BiLSTM-Attention混合模型挑选特征 (已根据权威字典更新)。
策略: 提供核心指标和中长期背景让CNN层自己学习局部模式。
"""
selected = [
'net_sales_quantity', 'gross_profit_total', 'transaction_count',
'is_weekend', 'day_of_week', 'month', 'is_holiday',
'lifecycle_days', 'sample_category',
# 中长期滚动特征作为宏观背景
'net_sales_quantity_rolling_mean_30d', 'net_sales_quantity_rolling_mean_90d',
# 所有实际存在的静态特征
'district', 'poi_residential_count', 'poi_school_count', 'temperature_2m_mean',
'零售大类代码_encoded', '商品ABC分类_encoded', 'brand_encoded', 'approval_type_encoded'
]
return [col for col in selected if col in all_columns]
def select_features_for_xgboost(all_columns):
"""
为XGBoost挑选特征 (已根据权威字典更新)。
策略: 提供除标识符和元数据外的所有特征,让树模型自己挖掘关系。
"""
# 排除标识符、原始日期、可能导致数据泄漏的目标相关变量,以及非数值的地理元数据
excluded = [
'subbh', 'hh', 'kdrq', 'date',
'sales_quantity', 'return_quantity',
'first_sale_date', 'last_sale_date',
'adcode', 'district_name', 'business_areas' # 排除非数值或高基数的分类特征
]
return [col for col in all_columns if col not in excluded]
def select_features_for_kan(all_columns):
"""
为KAN模型挑选特征 (已根据权威字典更新)。
策略: 与XGBoost完全相同提供最全的数值化特征集。
"""
return select_features_for_xgboost(all_columns)
def select_features_for_default(all_columns):
"""
一个后备的、基础的特征集。
"""
return select_features_for_mlstm(all_columns)
# --- 特征选择注册表 ---
FEATURE_SELECTION_REGISTRY = {
'mlstm': select_features_for_mlstm,
'transformer': select_features_for_transformer,
'tcn': select_features_for_tcn,
'xgboost': select_features_for_xgboost,
'kan': select_features_for_kan,
'cnn_bilstm_attention': select_features_for_hybrid,
# 'optimized_kan' 可以复用 'kan' 的选择逻辑
'optimized_kan': select_features_for_kan
}
def get_feature_list_for_model(model_type: str, all_columns: list) -> list:
"""
根据模型类型,从注册表中获取并调用对应的特征选择函数。
Args:
model_type (str): 模型的类型标识符 (e.g., 'mlstm', 'xgboost').
all_columns (list): 从数据文件中加载的全部列名。
Returns:
list: 为该模型挑选出的特征列表。
"""
if model_type not in FEATURE_SELECTION_REGISTRY:
print(f"警告: 未在注册表中找到模型 '{model_type}' 的专属特征选择函数,将使用默认的基础特征集。")
selection_function = select_features_for_default
else:
selection_function = FEATURE_SELECTION_REGISTRY[model_type]
return selection_function(all_columns)