按药品训练-预测跑通
This commit is contained in:
parent
e437658b9d
commit
6f3240c723
Binary file not shown.
@ -133,6 +133,13 @@ def get_model_file_path(product_id: str, model_type: str, version: str) -> str:
|
|||||||
|
|
||||||
# 修正:直接使用唯一的product_id(它可能包含store_前缀)来构建文件名
|
# 修正:直接使用唯一的product_id(它可能包含store_前缀)来构建文件名
|
||||||
# 文件名示例: transformer_17002608_epoch_best.pth 或 transformer_store_01010023_epoch_best.pth
|
# 文件名示例: transformer_17002608_epoch_best.pth 或 transformer_store_01010023_epoch_best.pth
|
||||||
|
# 针对 KAN 和 optimized_kan,使用 model_manager 的命名约定
|
||||||
|
if model_type in ['kan', 'optimized_kan']:
|
||||||
|
# 格式: {model_type}_product_{product_id}_{version}.pth
|
||||||
|
# 注意:KAN trainer 保存时,product_id 就是 model_identifier
|
||||||
|
filename = f"{model_type}_product_{product_id}_{version}.pth"
|
||||||
|
else:
|
||||||
|
# 其他模型使用 _epoch_ 约定
|
||||||
filename = f"{model_type}_{product_id}_epoch_{version}.pth"
|
filename = f"{model_type}_{product_id}_epoch_{version}.pth"
|
||||||
# 修正:直接在根模型目录查找,不再使用checkpoints子目录
|
# 修正:直接在根模型目录查找,不再使用checkpoints子目录
|
||||||
return os.path.join(DEFAULT_MODEL_DIR, filename)
|
return os.path.join(DEFAULT_MODEL_DIR, filename)
|
||||||
@ -151,32 +158,46 @@ def get_model_versions(product_id: str, model_type: str) -> list:
|
|||||||
# 直接使用传入的product_id构建搜索模式
|
# 直接使用传入的product_id构建搜索模式
|
||||||
# 搜索模式,匹配 "transformer_product_17002608_epoch_50.pth" 或 "transformer_product_17002608_epoch_best.pth"
|
# 搜索模式,匹配 "transformer_product_17002608_epoch_50.pth" 或 "transformer_product_17002608_epoch_best.pth"
|
||||||
# 修正:直接使用唯一的product_id(它可能包含store_前缀)来构建搜索模式
|
# 修正:直接使用唯一的product_id(它可能包含store_前缀)来构建搜索模式
|
||||||
pattern = f"{model_type}_{product_id}_epoch_*.pth"
|
# 扩展搜索模式以兼容多种命名约定
|
||||||
# 修正:直接在根模型目录查找,不再使用checkpoints子目录
|
patterns = [
|
||||||
|
f"{model_type}_{product_id}_epoch_*.pth", # 原始格式 (e.g., transformer_123_epoch_best.pth)
|
||||||
|
f"{model_type}_product_{product_id}_*.pth" # KAN/ModelManager格式 (e.g., kan_product_123_v1.pth)
|
||||||
|
]
|
||||||
|
|
||||||
|
existing_files = []
|
||||||
|
for pattern in patterns:
|
||||||
search_path = os.path.join(DEFAULT_MODEL_DIR, pattern)
|
search_path = os.path.join(DEFAULT_MODEL_DIR, pattern)
|
||||||
existing_files = glob.glob(search_path)
|
existing_files.extend(glob.glob(search_path))
|
||||||
|
|
||||||
# 旧格式(兼容性支持)
|
# 旧格式(兼容性支持)
|
||||||
pattern_old = f"{model_type}_model_product_{product_id}.pth"
|
pattern_old = f"{model_type}_model_product_{product_id}.pth"
|
||||||
old_file_path = os.path.join(DEFAULT_MODEL_DIR, pattern_old)
|
old_file_path = os.path.join(DEFAULT_MODEL_DIR, pattern_old)
|
||||||
has_old_format = os.path.exists(old_file_path)
|
if os.path.exists(old_file_path):
|
||||||
old_file_path = os.path.join(DEFAULT_MODEL_DIR, pattern_old)
|
existing_files.append(old_file_path)
|
||||||
has_old_format = os.path.exists(old_file_path)
|
|
||||||
|
|
||||||
versions = set() # 使用集合避免重复
|
versions = set() # 使用集合避免重复
|
||||||
|
|
||||||
# 从找到的文件中提取版本信息
|
# 从找到的文件中提取版本信息
|
||||||
for file_path in existing_files:
|
for file_path in existing_files:
|
||||||
filename = os.path.basename(file_path)
|
filename = os.path.basename(file_path)
|
||||||
# 匹配 _epoch_ 后面的内容作为版本
|
|
||||||
version_match = re.search(r"_epoch_(.+)\.pth$", filename)
|
|
||||||
if version_match:
|
|
||||||
versions.add(version_match.group(1))
|
|
||||||
|
|
||||||
# 如果存在旧格式文件,将其视为v1
|
# 尝试匹配 _epoch_ 格式
|
||||||
if has_old_format:
|
version_match_epoch = re.search(r"_epoch_(.+)\.pth$", filename)
|
||||||
|
if version_match_epoch:
|
||||||
|
versions.add(version_match_epoch.group(1))
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 尝试匹配 _product_..._v 格式 (KAN)
|
||||||
|
version_match_kan = re.search(r"_product_.+_v(\d+)\.pth$", filename)
|
||||||
|
if version_match_kan:
|
||||||
|
versions.add(f"v{version_match_kan.group(1)}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 尝试匹配旧的 _model_product_ 格式
|
||||||
|
if pattern_old in filename:
|
||||||
versions.add("v1_legacy") # 添加一个特殊标识
|
versions.add("v1_legacy") # 添加一个特殊标识
|
||||||
print(f"检测到旧格式模型文件: {old_file_path},将其视为版本 v1_legacy")
|
print(f"检测到旧格式模型文件: {old_file_path},将其视为版本 v1_legacy")
|
||||||
|
continue
|
||||||
|
|
||||||
# 转换为列表并排序
|
# 转换为列表并排序
|
||||||
sorted_versions = sorted(list(versions))
|
sorted_versions = sorted(list(versions))
|
||||||
|
Binary file not shown.
@ -216,11 +216,11 @@ def load_model_and_predict(product_id, model_type, store_id=None, future_days=7,
|
|||||||
model = MatrixLSTM(
|
model = MatrixLSTM(
|
||||||
num_features=config['input_dim'],
|
num_features=config['input_dim'],
|
||||||
hidden_size=config['hidden_size'],
|
hidden_size=config['hidden_size'],
|
||||||
mlstm_layers=config['num_layers'],
|
mlstm_layers=config['mlstm_layers'],
|
||||||
embed_dim=embed_dim,
|
embed_dim=embed_dim,
|
||||||
dense_dim=dense_dim,
|
dense_dim=dense_dim,
|
||||||
num_heads=num_heads,
|
num_heads=num_heads,
|
||||||
dropout_rate=config['dropout'],
|
dropout_rate=config['dropout_rate'],
|
||||||
num_blocks=num_blocks,
|
num_blocks=num_blocks,
|
||||||
output_sequence_length=config['output_dim']
|
output_sequence_length=config['output_dim']
|
||||||
).to(DEVICE)
|
).to(DEVICE)
|
||||||
@ -241,7 +241,7 @@ def load_model_and_predict(product_id, model_type, store_id=None, future_days=7,
|
|||||||
num_features=config['input_dim'],
|
num_features=config['input_dim'],
|
||||||
output_sequence_length=config['output_dim'],
|
output_sequence_length=config['output_dim'],
|
||||||
num_channels=[config['hidden_size']] * config['num_layers'],
|
num_channels=[config['hidden_size']] * config['num_layers'],
|
||||||
kernel_size=3,
|
kernel_size=config['kernel_size'],
|
||||||
dropout=config['dropout']
|
dropout=config['dropout']
|
||||||
).to(DEVICE)
|
).to(DEVICE)
|
||||||
else:
|
else:
|
||||||
|
@ -168,6 +168,7 @@ def train_product_model_with_kan(product_id, model_identifier, product_df=None,
|
|||||||
train_losses = []
|
train_losses = []
|
||||||
test_losses = []
|
test_losses = []
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
best_loss = float('inf')
|
||||||
|
|
||||||
for epoch in range(epochs):
|
for epoch in range(epochs):
|
||||||
model.train()
|
model.train()
|
||||||
@ -226,6 +227,43 @@ def train_product_model_with_kan(product_id, model_identifier, product_df=None,
|
|||||||
test_loss = test_loss / len(test_loader)
|
test_loss = test_loss / len(test_loader)
|
||||||
test_losses.append(test_loss)
|
test_losses.append(test_loss)
|
||||||
|
|
||||||
|
# 检查是否为最佳模型
|
||||||
|
model_type_name = 'optimized_kan' if use_optimized else 'kan'
|
||||||
|
if test_loss < best_loss:
|
||||||
|
best_loss = test_loss
|
||||||
|
print(f"🎉 新的最佳模型发现在 epoch {epoch+1},测试损失: {test_loss:.4f}")
|
||||||
|
|
||||||
|
# 为保存最佳模型准备数据
|
||||||
|
best_model_data = {
|
||||||
|
'model_state_dict': model.state_dict(),
|
||||||
|
'scaler_X': scaler_X,
|
||||||
|
'scaler_y': scaler_y,
|
||||||
|
'config': {
|
||||||
|
'input_dim': input_dim,
|
||||||
|
'output_dim': output_dim,
|
||||||
|
'hidden_size': hidden_size,
|
||||||
|
'hidden_sizes': [hidden_size, hidden_size * 2, hidden_size],
|
||||||
|
'sequence_length': sequence_length,
|
||||||
|
'forecast_horizon': forecast_horizon,
|
||||||
|
'model_type': model_type_name,
|
||||||
|
'use_optimized': use_optimized
|
||||||
|
},
|
||||||
|
'epoch': epoch + 1
|
||||||
|
}
|
||||||
|
|
||||||
|
# 使用模型管理器保存 'best' 版本
|
||||||
|
from utils.model_manager import model_manager
|
||||||
|
model_manager.save_model(
|
||||||
|
model_data=best_model_data,
|
||||||
|
product_id=model_identifier,
|
||||||
|
model_type=model_type_name,
|
||||||
|
version='best',
|
||||||
|
store_id=store_id,
|
||||||
|
training_mode=training_mode,
|
||||||
|
aggregation_method=aggregation_method,
|
||||||
|
product_name=product_name
|
||||||
|
)
|
||||||
|
|
||||||
if (epoch + 1) % 10 == 0:
|
if (epoch + 1) % 10 == 0:
|
||||||
print(f"Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}")
|
print(f"Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}")
|
||||||
|
|
||||||
@ -301,7 +339,7 @@ def train_product_model_with_kan(product_id, model_identifier, product_df=None,
|
|||||||
model_data=model_data,
|
model_data=model_data,
|
||||||
product_id=model_identifier,
|
product_id=model_identifier,
|
||||||
model_type=model_type_name,
|
model_type=model_type_name,
|
||||||
version='v1', # KAN训练器默认使用v1
|
version=f'final_epoch_{epochs}',
|
||||||
store_id=store_id,
|
store_id=store_id,
|
||||||
training_mode=training_mode,
|
training_mode=training_mode,
|
||||||
aggregation_method=aggregation_method,
|
aggregation_method=aggregation_method,
|
||||||
|
11
项目快速上手指南.md
11
项目快速上手指南.md
@ -93,6 +93,17 @@
|
|||||||
* 在这个新函数里,确保实例化的是你的 `NewNet` 模型。
|
* 在这个新函数里,确保实例化的是你的 `NewNet` 模型。
|
||||||
* **最关键的一步**: 在保存checkpoint时,确保 `config` 字典里包含了重建 `NewNet` 所需的所有超参数(比如层数、节点数等)。
|
* **最关键的一步**: 在保存checkpoint时,确保 `config` 字典里包含了重建 `NewNet` 所需的所有超参数(比如层数、节点数等)。
|
||||||
|
|
||||||
|
* **重要开发规范:参数命名规则**
|
||||||
|
为了防止在模型加载时出现参数不匹配的错误(例如 `KeyError: 'num_layers'`),我们制定了以下命名规范:
|
||||||
|
> **规则:** 对于特定于某个算法的超参数,其在 `config` 字典中的键名(key)必须以该算法的名称作为前缀或唯一标识。
|
||||||
|
|
||||||
|
**示例:**
|
||||||
|
* 对于 `mLSTM` 模型的层数,键名应为 `mlstm_layers`。
|
||||||
|
* 对于 `TCN` 模型的通道数,键名可以是 `tcn_channels`。
|
||||||
|
* 对于 `Transformer` 模型的编码器层数,键名可以是 `num_encoder_layers` (因为这在Transformer语境下是明确的)。
|
||||||
|
|
||||||
|
在 **加载模型时** ([`server/predictors/model_predictor.py`](server/predictors/model_predictor.py:1)),必须使用与保存时完全一致的键名来读取这些参数。遵循此规则可以从根本上杜绝因参数名不一致导致的模型加载失败问题。
|
||||||
|
|
||||||
2. **注册新模型**:
|
2. **注册新模型**:
|
||||||
* 打开 `server/core/config.py` 文件。
|
* 打开 `server/core/config.py` 文件。
|
||||||
* 找到 `SUPPORTED_MODELS` 列表。
|
* 找到 `SUPPORTED_MODELS` 列表。
|
||||||
|
Loading…
x
Reference in New Issue
Block a user