132 lines
3.7 KiB
Python
132 lines
3.7 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
测试API修复的简单脚本
|
|
"""
|
|
|
|
def test_api_syntax():
|
|
"""测试API文件语法是否正确"""
|
|
print("检查API文件语法...")
|
|
|
|
try:
|
|
# 尝试编译API文件
|
|
with open('server/api.py', 'r', encoding='utf-8') as f:
|
|
content = f.read()
|
|
|
|
compile(content, 'server/api.py', 'exec')
|
|
print("✓ API文件语法检查通过")
|
|
return True
|
|
except SyntaxError as e:
|
|
print(f"✗ API文件语法错误: {e}")
|
|
print(f"行号: {e.lineno}, 位置: {e.offset}")
|
|
return False
|
|
except Exception as e:
|
|
print(f"✗ 检查失败: {e}")
|
|
return False
|
|
|
|
def test_imports():
|
|
"""测试关键导入是否存在"""
|
|
print("\n检查关键导入...")
|
|
|
|
try:
|
|
with open('server/api.py', 'r', encoding='utf-8') as f:
|
|
content = f.read()
|
|
|
|
imports_to_check = [
|
|
'import threading',
|
|
'from core.predictor import PharmacyPredictor',
|
|
'from utils.multi_store_data_utils import'
|
|
]
|
|
|
|
missing_imports = []
|
|
for imp in imports_to_check:
|
|
if imp not in content:
|
|
missing_imports.append(imp)
|
|
|
|
if missing_imports:
|
|
print("✗ 缺少导入:")
|
|
for imp in missing_imports:
|
|
print(f" - {imp}")
|
|
return False
|
|
else:
|
|
print("✓ 所有关键导入都存在")
|
|
return True
|
|
|
|
except Exception as e:
|
|
print(f"✗ 检查失败: {e}")
|
|
return False
|
|
|
|
def test_function_fixes():
|
|
"""测试函数修复是否正确"""
|
|
print("\n检查函数修复...")
|
|
|
|
try:
|
|
with open('server/api.py', 'r', encoding='utf-8') as f:
|
|
content = f.read()
|
|
|
|
# 检查是否包含修复的函数调用
|
|
fixes_to_check = [
|
|
'predictor.train_model(',
|
|
'training_mode=',
|
|
'store_id=',
|
|
'aggregation_method='
|
|
]
|
|
|
|
missing_fixes = []
|
|
for fix in fixes_to_check:
|
|
if fix not in content:
|
|
missing_fixes.append(fix)
|
|
|
|
if missing_fixes:
|
|
print("✗ 缺少修复:")
|
|
for fix in missing_fixes:
|
|
print(f" - {fix}")
|
|
return False
|
|
else:
|
|
print("✓ 所有函数修复都存在")
|
|
return True
|
|
|
|
except Exception as e:
|
|
print(f"✗ 检查失败: {e}")
|
|
return False
|
|
|
|
def main():
|
|
print("=" * 50)
|
|
print("API修复验证")
|
|
print("=" * 50)
|
|
|
|
tests = [
|
|
("语法检查", test_api_syntax),
|
|
("导入检查", test_imports),
|
|
("修复检查", test_function_fixes)
|
|
]
|
|
|
|
results = []
|
|
for test_name, test_func in tests:
|
|
result = test_func()
|
|
results.append(result)
|
|
|
|
print("\n" + "=" * 50)
|
|
print("测试结果:")
|
|
print("=" * 50)
|
|
|
|
for i, (test_name, result) in enumerate(zip([t[0] for t in tests], results)):
|
|
status = "通过" if result else "失败"
|
|
print(f"{i+1}. {test_name}: {status}")
|
|
|
|
overall = all(results)
|
|
print(f"\n总体结果: {'所有测试通过' if overall else '存在失败的测试'}")
|
|
|
|
if overall:
|
|
print("\n✓ API修复完成! 现在应该可以正常处理训练请求了")
|
|
print("\n修复内容:")
|
|
print("- 添加了缺少的 threading 导入")
|
|
print("- 修复了训练函数参数不匹配的问题")
|
|
print("- 统一使用 PharmacyPredictor.train_model() 方法")
|
|
print("- 添加了正确的 training_mode 参数")
|
|
else:
|
|
print("\n✗ 还有一些问题需要解决")
|
|
|
|
return overall
|
|
|
|
if __name__ == "__main__":
|
|
main() |