ShopTRAINING/test/simple_model_debug.py
2025-07-02 11:05:23 +08:00

84 lines
2.3 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
简单的模型解析调试
"""
import os
import glob
def debug_model_files():
    """Print every ``.pth`` model file found in ``saved_models`` and how it parses.

    The model directory is resolved relative to this script's own location
    (``<script dir>/saved_models``), not the current working directory.
    For each file found, the derived model id and the result of
    ``parse_model_filename`` are printed to stdout.

    Returns:
        None. All output goes to stdout; if the directory does not exist a
        message is printed and the function returns early.
    """
    suffix = '.pth'
    # Locate the model directory next to this script file.
    current_dir = os.path.dirname(os.path.abspath(__file__))
    model_dir = os.path.join(current_dir, 'saved_models')
    print(f"模型目录: {model_dir}")
    print(f"目录存在: {os.path.exists(model_dir)}")
    if not os.path.exists(model_dir):
        print("模型目录不存在!")
        return
    # Escape the directory part so glob metacharacters in the path
    # ("[", "*", "?") are matched literally, and sort the matches so the
    # listing is reproducible (glob order is otherwise unspecified).
    pattern = os.path.join(glob.escape(model_dir), "*" + suffix)
    model_files = sorted(glob.glob(pattern))
    print(f"找到 {len(model_files)} 个模型文件:")
    for model_file in model_files:
        filename = os.path.basename(model_file)
        print(f" {filename}")
        # Derive the model id by stripping only the trailing extension;
        # str.replace would also delete any earlier ".pth" in the name.
        model_id = filename[:-len(suffix)]
        print(f" 生成的model_id: {model_id}")
        # Parse the filename with the logic copied from ModelManager.
        model_info = parse_model_filename(filename)
        print(f" 解析结果: {model_info}")
        print()
def parse_model_filename(filename):
    """Parse a model filename into its parts (simplified copy of ModelManager's logic).

    Only the product-mode layout is recognised::

        <model_type>_product_<product_id>[_<version>].pth

    The version is the text after the last ``"_v"`` in the part following
    ``_product_``. When there is no ``"_v"``, the version defaults to ``'v1'``.

    Args:
        filename: Base filename (no directory), e.g. ``"lstm_product_P001_v2.pth"``.

    Returns:
        A dict with keys ``model_type``, ``product_id``, ``version``,
        ``training_mode`` (always ``'product'``), ``store_id`` and
        ``aggregation_method`` (both ``None``). Returns ``None`` if the name
        does not end in ``.pth`` or contains no ``_product_`` marker.
    """
    suffix = '.pth'
    if not filename.endswith(suffix):
        return None
    # Remove only the trailing extension, not every ".pth" in the name.
    base_name = filename[:-len(suffix)]
    # Split at the FIRST "_product_" only, so a product id that itself
    # contains "_product_" is kept intact instead of being truncated.
    model_type, marker, rest = base_name.partition('_product_')
    if not marker:
        # The name does not match any known format.
        return None
    # Separate product id and version at the LAST "_v".
    # NOTE(review): a product id containing "_v" (e.g. "P_vendor") with no
    # version suffix is mis-split by this heuristic; kept from the original.
    last_v_index = rest.rfind('_v')
    if last_v_index != -1:
        product_id = rest[:last_v_index]
        version = rest[last_v_index + 1:]
    else:
        product_id = rest
        version = 'v1'
    return {
        'model_type': model_type,
        'product_id': product_id,
        'version': version,
        'training_mode': 'product',
        'store_id': None,
        'aggregation_method': None,
    }
if __name__ == "__main__":
    # Only scan and print model files when run as a script, never on import.
    debug_model_files()