ShopTRAINING/test/test_api_fix.py
2025-07-02 11:05:23 +08:00

132 lines
3.7 KiB
Python

#!/usr/bin/env python3
"""
测试API修复的简单脚本
"""
def test_api_syntax():
"""测试API文件语法是否正确"""
print("检查API文件语法...")
try:
# 尝试编译API文件
with open('server/api.py', 'r', encoding='utf-8') as f:
content = f.read()
compile(content, 'server/api.py', 'exec')
print("✓ API文件语法检查通过")
return True
except SyntaxError as e:
print(f"✗ API文件语法错误: {e}")
print(f"行号: {e.lineno}, 位置: {e.offset}")
return False
except Exception as e:
print(f"✗ 检查失败: {e}")
return False
def test_imports():
"""测试关键导入是否存在"""
print("\n检查关键导入...")
try:
with open('server/api.py', 'r', encoding='utf-8') as f:
content = f.read()
imports_to_check = [
'import threading',
'from core.predictor import PharmacyPredictor',
'from utils.multi_store_data_utils import'
]
missing_imports = []
for imp in imports_to_check:
if imp not in content:
missing_imports.append(imp)
if missing_imports:
print("✗ 缺少导入:")
for imp in missing_imports:
print(f" - {imp}")
return False
else:
print("✓ 所有关键导入都存在")
return True
except Exception as e:
print(f"✗ 检查失败: {e}")
return False
def test_function_fixes():
"""测试函数修复是否正确"""
print("\n检查函数修复...")
try:
with open('server/api.py', 'r', encoding='utf-8') as f:
content = f.read()
# 检查是否包含修复的函数调用
fixes_to_check = [
'predictor.train_model(',
'training_mode=',
'store_id=',
'aggregation_method='
]
missing_fixes = []
for fix in fixes_to_check:
if fix not in content:
missing_fixes.append(fix)
if missing_fixes:
print("✗ 缺少修复:")
for fix in missing_fixes:
print(f" - {fix}")
return False
else:
print("✓ 所有函数修复都存在")
return True
except Exception as e:
print(f"✗ 检查失败: {e}")
return False
def main():
print("=" * 50)
print("API修复验证")
print("=" * 50)
tests = [
("语法检查", test_api_syntax),
("导入检查", test_imports),
("修复检查", test_function_fixes)
]
results = []
for test_name, test_func in tests:
result = test_func()
results.append(result)
print("\n" + "=" * 50)
print("测试结果:")
print("=" * 50)
for i, (test_name, result) in enumerate(zip([t[0] for t in tests], results)):
status = "通过" if result else "失败"
print(f"{i+1}. {test_name}: {status}")
overall = all(results)
print(f"\n总体结果: {'所有测试通过' if overall else '存在失败的测试'}")
if overall:
print("\n✓ API修复完成! 现在应该可以正常处理训练请求了")
print("\n修复内容:")
print("- 添加了缺少的 threading 导入")
print("- 修复了训练函数参数不匹配的问题")
print("- 统一使用 PharmacyPredictor.train_model() 方法")
print("- 添加了正确的 training_mode 参数")
else:
print("\n✗ 还有一些问题需要解决")
return overall
if __name__ == "__main__":
main()