84 lines
2.3 KiB
Python
84 lines
2.3 KiB
Python
#!/usr/bin/env python3
|
||
"""
|
||
简单的模型解析调试
|
||
"""
|
||
|
||
import os
|
||
import glob
|
||
|
||
def debug_model_files():
|
||
"""调试模型文件"""
|
||
|
||
# 获取模型目录
|
||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||
model_dir = os.path.join(current_dir, 'saved_models')
|
||
|
||
print(f"模型目录: {model_dir}")
|
||
print(f"目录存在: {os.path.exists(model_dir)}")
|
||
|
||
if not os.path.exists(model_dir):
|
||
print("模型目录不存在!")
|
||
return
|
||
|
||
# 列出所有.pth文件
|
||
pattern = os.path.join(model_dir, "*.pth")
|
||
model_files = glob.glob(pattern)
|
||
|
||
print(f"找到 {len(model_files)} 个模型文件:")
|
||
for model_file in model_files:
|
||
filename = os.path.basename(model_file)
|
||
print(f" {filename}")
|
||
|
||
# 测试解析逻辑
|
||
model_id = filename.replace('.pth', '')
|
||
print(f" 生成的model_id: {model_id}")
|
||
|
||
# 测试文件名解析(复制ModelManager的逻辑)
|
||
model_info = parse_model_filename(filename)
|
||
print(f" 解析结果: {model_info}")
|
||
print()
|
||
|
||
def parse_model_filename(filename):
|
||
"""
|
||
简化的文件名解析逻辑(复制自ModelManager)
|
||
"""
|
||
if not filename.endswith('.pth'):
|
||
return None
|
||
|
||
base_name = filename.replace('.pth', '')
|
||
|
||
try:
|
||
# 新格式解析
|
||
if '_product_' in base_name:
|
||
# 产品模式: model_type_product_product_id_version
|
||
parts = base_name.split('_product_')
|
||
model_type = parts[0]
|
||
rest = parts[1]
|
||
|
||
# 分离产品ID和版本
|
||
if '_v' in rest:
|
||
last_v_index = rest.rfind('_v')
|
||
product_id = rest[:last_v_index]
|
||
version = rest[last_v_index+1:]
|
||
else:
|
||
product_id = rest
|
||
version = 'v1'
|
||
|
||
return {
|
||
'model_type': model_type,
|
||
'product_id': product_id,
|
||
'version': version,
|
||
'training_mode': 'product',
|
||
'store_id': None,
|
||
'aggregation_method': None
|
||
}
|
||
|
||
# 如果不匹配任何格式
|
||
return None
|
||
|
||
except Exception as e:
|
||
print(f"解析失败: {e}")
|
||
return None
|
||
|
||
if __name__ == "__main__":
|
||
debug_model_files() |