ShopTRAINING/test/simple_model_debug.py

84 lines
2.3 KiB
Python
Raw Normal View History

2025-07-02 11:05:23 +08:00
#!/usr/bin/env python3
"""
简单的模型解析调试
"""
import os
import glob
def debug_model_files():
"""调试模型文件"""
# 获取模型目录
current_dir = os.path.dirname(os.path.abspath(__file__))
model_dir = os.path.join(current_dir, 'saved_models')
print(f"模型目录: {model_dir}")
print(f"目录存在: {os.path.exists(model_dir)}")
if not os.path.exists(model_dir):
print("模型目录不存在!")
return
# 列出所有.pth文件
pattern = os.path.join(model_dir, "*.pth")
model_files = glob.glob(pattern)
print(f"找到 {len(model_files)} 个模型文件:")
for model_file in model_files:
filename = os.path.basename(model_file)
print(f" {filename}")
# 测试解析逻辑
model_id = filename.replace('.pth', '')
print(f" 生成的model_id: {model_id}")
# 测试文件名解析复制ModelManager的逻辑
model_info = parse_model_filename(filename)
print(f" 解析结果: {model_info}")
print()
def parse_model_filename(filename):
"""
简化的文件名解析逻辑复制自ModelManager
"""
if not filename.endswith('.pth'):
return None
base_name = filename.replace('.pth', '')
try:
# 新格式解析
if '_product_' in base_name:
# 产品模式: model_type_product_product_id_version
parts = base_name.split('_product_')
model_type = parts[0]
rest = parts[1]
# 分离产品ID和版本
if '_v' in rest:
last_v_index = rest.rfind('_v')
product_id = rest[:last_v_index]
version = rest[last_v_index+1:]
else:
product_id = rest
version = 'v1'
return {
'model_type': model_type,
'product_id': product_id,
'version': version,
'training_mode': 'product',
'store_id': None,
'aggregation_method': None
}
# 如果不匹配任何格式
return None
except Exception as e:
print(f"解析失败: {e}")
return None
if __name__ == "__main__":
debug_model_files()