按药品训练-预测跑通

This commit is contained in:
LYFxiaoan 2025-07-17 17:54:53 +08:00
parent e437658b9d
commit 6f3240c723
6 changed files with 92 additions and 22 deletions

Binary file not shown.

View File

@ -133,7 +133,14 @@ def get_model_file_path(product_id: str, model_type: str, version: str) -> str:
# 修正直接使用唯一的product_id它可能包含store_前缀来构建文件名
# 文件名示例: transformer_17002608_epoch_best.pth 或 transformer_store_01010023_epoch_best.pth
filename = f"{model_type}_{product_id}_epoch_{version}.pth"
# 针对 KAN 和 optimized_kan使用 model_manager 的命名约定
if model_type in ['kan', 'optimized_kan']:
# 格式: {model_type}_product_{product_id}_{version}.pth
# 注意KAN trainer 保存时product_id 就是 model_identifier
filename = f"{model_type}_product_{product_id}_{version}.pth"
else:
# 其他模型使用 _epoch_ 约定
filename = f"{model_type}_{product_id}_epoch_{version}.pth"
# 修正直接在根模型目录查找不再使用checkpoints子目录
return os.path.join(DEFAULT_MODEL_DIR, filename)
@ -151,32 +158,46 @@ def get_model_versions(product_id: str, model_type: str) -> list:
# 直接使用传入的product_id构建搜索模式
# 搜索模式,匹配 "transformer_product_17002608_epoch_50.pth" 或 "transformer_product_17002608_epoch_best.pth"
# 修正直接使用唯一的product_id它可能包含store_前缀来构建搜索模式
pattern = f"{model_type}_{product_id}_epoch_*.pth"
# 修正直接在根模型目录查找不再使用checkpoints子目录
search_path = os.path.join(DEFAULT_MODEL_DIR, pattern)
existing_files = glob.glob(search_path)
# 扩展搜索模式以兼容多种命名约定
patterns = [
f"{model_type}_{product_id}_epoch_*.pth", # 原始格式 (e.g., transformer_123_epoch_best.pth)
f"{model_type}_product_{product_id}_*.pth" # KAN/ModelManager格式 (e.g., kan_product_123_v1.pth)
]
existing_files = []
for pattern in patterns:
search_path = os.path.join(DEFAULT_MODEL_DIR, pattern)
existing_files.extend(glob.glob(search_path))
# 旧格式(兼容性支持)
pattern_old = f"{model_type}_model_product_{product_id}.pth"
old_file_path = os.path.join(DEFAULT_MODEL_DIR, pattern_old)
has_old_format = os.path.exists(old_file_path)
old_file_path = os.path.join(DEFAULT_MODEL_DIR, pattern_old)
has_old_format = os.path.exists(old_file_path)
if os.path.exists(old_file_path):
existing_files.append(old_file_path)
versions = set() # 使用集合避免重复
# 从找到的文件中提取版本信息
for file_path in existing_files:
filename = os.path.basename(file_path)
# 匹配 _epoch_ 后面的内容作为版本
version_match = re.search(r"_epoch_(.+)\.pth$", filename)
if version_match:
versions.add(version_match.group(1))
# 如果存在旧格式文件将其视为v1
if has_old_format:
versions.add("v1_legacy") # 添加一个特殊标识
print(f"检测到旧格式模型文件: {old_file_path},将其视为版本 v1_legacy")
# 尝试匹配 _epoch_ 格式
version_match_epoch = re.search(r"_epoch_(.+)\.pth$", filename)
if version_match_epoch:
versions.add(version_match_epoch.group(1))
continue
# 尝试匹配 _product_..._v 格式 (KAN)
version_match_kan = re.search(r"_product_.+_v(\d+)\.pth$", filename)
if version_match_kan:
versions.add(f"v{version_match_kan.group(1)}")
continue
# 尝试匹配旧的 _model_product_ 格式
if pattern_old in filename:
versions.add("v1_legacy") # 添加一个特殊标识
print(f"检测到旧格式模型文件: {old_file_path},将其视为版本 v1_legacy")
continue
# 转换为列表并排序
sorted_versions = sorted(list(versions))

View File

@ -216,11 +216,11 @@ def load_model_and_predict(product_id, model_type, store_id=None, future_days=7,
model = MatrixLSTM(
num_features=config['input_dim'],
hidden_size=config['hidden_size'],
mlstm_layers=config['num_layers'],
mlstm_layers=config['mlstm_layers'],
embed_dim=embed_dim,
dense_dim=dense_dim,
num_heads=num_heads,
dropout_rate=config['dropout'],
dropout_rate=config['dropout_rate'],
num_blocks=num_blocks,
output_sequence_length=config['output_dim']
).to(DEVICE)
@ -241,7 +241,7 @@ def load_model_and_predict(product_id, model_type, store_id=None, future_days=7,
num_features=config['input_dim'],
output_sequence_length=config['output_dim'],
num_channels=[config['hidden_size']] * config['num_layers'],
kernel_size=3,
kernel_size=config['kernel_size'],
dropout=config['dropout']
).to(DEVICE)
else:

View File

@ -168,6 +168,7 @@ def train_product_model_with_kan(product_id, model_identifier, product_df=None,
train_losses = []
test_losses = []
start_time = time.time()
best_loss = float('inf')
for epoch in range(epochs):
model.train()
@ -225,6 +226,43 @@ def train_product_model_with_kan(product_id, model_identifier, product_df=None,
test_loss = test_loss / len(test_loader)
test_losses.append(test_loss)
# 检查是否为最佳模型
model_type_name = 'optimized_kan' if use_optimized else 'kan'
if test_loss < best_loss:
best_loss = test_loss
print(f"🎉 新的最佳模型发现在 epoch {epoch+1},测试损失: {test_loss:.4f}")
# 为保存最佳模型准备数据
best_model_data = {
'model_state_dict': model.state_dict(),
'scaler_X': scaler_X,
'scaler_y': scaler_y,
'config': {
'input_dim': input_dim,
'output_dim': output_dim,
'hidden_size': hidden_size,
'hidden_sizes': [hidden_size, hidden_size * 2, hidden_size],
'sequence_length': sequence_length,
'forecast_horizon': forecast_horizon,
'model_type': model_type_name,
'use_optimized': use_optimized
},
'epoch': epoch + 1
}
# 使用模型管理器保存 'best' 版本
from utils.model_manager import model_manager
model_manager.save_model(
model_data=best_model_data,
product_id=model_identifier,
model_type=model_type_name,
version='best',
store_id=store_id,
training_mode=training_mode,
aggregation_method=aggregation_method,
product_name=product_name
)
if (epoch + 1) % 10 == 0:
print(f"Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}")
@ -301,7 +339,7 @@ def train_product_model_with_kan(product_id, model_identifier, product_df=None,
model_data=model_data,
product_id=model_identifier,
model_type=model_type_name,
version='v1', # KAN训练器默认使用v1
version=f'final_epoch_{epochs}',
store_id=store_id,
training_mode=training_mode,
aggregation_method=aggregation_method,

View File

@ -93,6 +93,17 @@
* 在这个新函数里,确保实例化的是你的 `NewNet` 模型。
* **最关键的一步**: 在保存checkpoint时确保 `config` 字典里包含了重建 `NewNet` 所需的所有超参数(比如层数、节点数等)。
* **重要开发规范:参数命名规则**
为了防止在模型加载时出现参数不匹配的错误(例如 `KeyError: 'num_layers'`),我们制定了以下命名规范:
> **规则:** 对于特定于某个算法的超参数,其在 `config` 字典中的键名key必须以该算法的名称作为前缀或唯一标识。
**示例:**
* 对于 `mLSTM` 模型的层数,键名应为 `mlstm_layers`
* 对于 `TCN` 模型的通道数,键名可以是 `tcn_channels`
* 对于 `Transformer` 模型的编码器层数,键名可以是 `num_encoder_layers` 因为这在Transformer语境下是明确的
**加载模型时** ([`server/predictors/model_predictor.py`](server/predictors/model_predictor.py:1)),必须使用与保存时完全一致的键名来读取这些参数。遵循此规则可以从根本上杜绝因参数名不一致导致的模型加载失败问题。
2. **注册新模型**:
* 打开 `server/core/config.py` 文件。
* 找到 `SUPPORTED_MODELS` 列表。