318 lines
13 KiB
Python
318 lines
13 KiB
Python
#!/usr/bin/env python3
|
||
# -*- coding: utf-8 -*-
|
||
"""
|
||
综合测试脚本 - 验证现代化系统是否正确工作
|
||
"""
|
||
|
||
import os
|
||
import sys
|
||
import urllib.request
|
||
import json
|
||
import time
|
||
import threading
|
||
from datetime import datetime
|
||
|
||
# 设置UTF-8编码
|
||
os.environ['PYTHONIOENCODING'] = 'utf-8'
|
||
os.environ['PYTHONLEGACYWINDOWSSTDIO'] = '0'
|
||
|
||
if os.name == 'nt':
|
||
try:
|
||
os.system('chcp 65001 >nul 2>&1')
|
||
if hasattr(sys.stdout, 'reconfigure'):
|
||
sys.stdout.reconfigure(encoding='utf-8', errors='replace')
|
||
sys.stderr.reconfigure(encoding='utf-8', errors='replace')
|
||
except Exception:
|
||
pass
|
||
|
||
class ModernSystemTester:
|
||
"""现代化系统测试器"""
|
||
|
||
def __init__(self):
|
||
self.base_url = "http://127.0.0.1:5000"
|
||
self.test_results = []
|
||
|
||
def log_test(self, test_name, status, message="", details=None):
|
||
"""记录测试结果"""
|
||
result = {
|
||
"test": test_name,
|
||
"status": status,
|
||
"message": message,
|
||
"details": details,
|
||
"timestamp": datetime.now().strftime('%H:%M:%S')
|
||
}
|
||
self.test_results.append(result)
|
||
|
||
status_emoji = "✅" if status == "PASS" else "❌" if status == "FAIL" else "⏳"
|
||
print(f"{status_emoji} [{result['timestamp']}] {test_name}: {message}")
|
||
if details:
|
||
print(f" 详情: {details}")
|
||
|
||
def check_api_connectivity(self):
|
||
"""测试1: API连接性"""
|
||
try:
|
||
with urllib.request.urlopen(f"{self.base_url}/api/version", timeout=5) as response:
|
||
if response.getcode() == 200:
|
||
data = json.loads(response.read().decode('utf-8'))
|
||
version = data.get('data', {}).get('version', 'unknown')
|
||
features = data.get('data', {}).get('features', [])
|
||
self.log_test("API连接性", "PASS", f"版本 {version}", f"特性: {len(features)}个")
|
||
return True
|
||
except Exception as e:
|
||
self.log_test("API连接性", "FAIL", f"连接失败: {str(e)}")
|
||
return False
|
||
|
||
def test_modern_logging(self):
|
||
"""测试2: 现代化日志系统"""
|
||
try:
|
||
# 检查日志文件是否存在
|
||
log_files = []
|
||
today = datetime.now().strftime('%Y-%m-%d')
|
||
|
||
expected_files = [
|
||
f"api_{today}.log",
|
||
f"api_error_{today}.log"
|
||
]
|
||
|
||
for file in expected_files:
|
||
if os.path.exists(file):
|
||
log_files.append(file)
|
||
|
||
if log_files:
|
||
self.log_test("日志文件生成", "PASS", f"发现 {len(log_files)} 个日志文件", log_files)
|
||
return True
|
||
else:
|
||
self.log_test("日志文件生成", "FAIL", "未发现loguru日志文件")
|
||
return False
|
||
except Exception as e:
|
||
self.log_test("日志文件生成", "FAIL", f"检查失败: {str(e)}")
|
||
return False
|
||
|
||
def test_training_process_manager(self):
|
||
"""测试3: 训练进程管理器"""
|
||
try:
|
||
# 提交训练任务
|
||
training_data = {
|
||
"product_id": "P001",
|
||
"model_type": "transformer",
|
||
"epochs": 2,
|
||
"training_mode": "product"
|
||
}
|
||
|
||
json_data = json.dumps(training_data).encode('utf-8')
|
||
req = urllib.request.Request(
|
||
f"{self.base_url}/api/training",
|
||
data=json_data,
|
||
headers={'Content-Type': 'application/json'}
|
||
)
|
||
|
||
with urllib.request.urlopen(req, timeout=30) as response:
|
||
if response.getcode() == 200:
|
||
result = json.loads(response.read().decode('utf-8'))
|
||
task_id = result.get('task_id')
|
||
message = result.get('message', '')
|
||
|
||
# 检查是否使用了现代化进程管理器
|
||
uses_modern_system = '独立进程' in message
|
||
|
||
if uses_modern_system:
|
||
self.log_test("进程管理器", "PASS", "使用独立进程管理器", f"任务ID: {task_id[:8]}")
|
||
return task_id
|
||
else:
|
||
self.log_test("进程管理器", "FAIL", "未使用现代化进程管理器", message)
|
||
return None
|
||
|
||
except Exception as e:
|
||
self.log_test("进程管理器", "FAIL", f"提交任务失败: {str(e)}")
|
||
return None
|
||
|
||
def test_task_status_tracking(self, task_id):
|
||
"""测试4: 任务状态跟踪"""
|
||
if not task_id:
|
||
self.log_test("任务状态跟踪", "SKIP", "无有效任务ID")
|
||
return False
|
||
|
||
try:
|
||
max_checks = 10
|
||
for i in range(max_checks):
|
||
with urllib.request.urlopen(f"{self.base_url}/api/training/{task_id}", timeout=5) as response:
|
||
if response.getcode() == 200:
|
||
result = json.loads(response.read().decode('utf-8'))
|
||
if result.get('status') == 'success':
|
||
data = result.get('data', {})
|
||
status = data.get('status', 'unknown')
|
||
progress = data.get('progress', 0)
|
||
process_id = data.get('process_id', 'N/A')
|
||
|
||
if i == 0: # 第一次检查
|
||
self.log_test("任务状态跟踪", "PASS",
|
||
f"状态={status}, 进度={progress}%, 进程ID={process_id}")
|
||
|
||
if status in ['completed', 'failed']:
|
||
return status == 'completed'
|
||
|
||
time.sleep(2)
|
||
|
||
self.log_test("任务状态跟踪", "FAIL", f"超时: {max_checks * 2}秒后任务仍未完成")
|
||
return False
|
||
|
||
except Exception as e:
|
||
self.log_test("任务状态跟踪", "FAIL", f"状态查询失败: {str(e)}")
|
||
return False
|
||
|
||
def test_unicode_support(self):
|
||
"""测试5: 中文和emoji支持"""
|
||
try:
|
||
# 检查产品列表API是否正确返回中文
|
||
with urllib.request.urlopen(f"{self.base_url}/api/products", timeout=5) as response:
|
||
if response.getcode() == 200:
|
||
result = json.loads(response.read().decode('utf-8'))
|
||
products = result.get('data', [])
|
||
|
||
chinese_found = False
|
||
for product in products:
|
||
name = product.get('name', '')
|
||
if any('\u4e00' <= char <= '\u9fff' for char in name): # 检查中文字符
|
||
chinese_found = True
|
||
break
|
||
|
||
if chinese_found:
|
||
self.log_test("中文支持", "PASS", f"发现 {len(products)} 个产品,包含中文")
|
||
|
||
# 测试emoji是否在控制台正常显示
|
||
test_emoji = "🚀📊🔧💬✅❌⏳🎯"
|
||
print(f" Emoji测试: {test_emoji}")
|
||
self.log_test("Emoji支持", "PASS", "控制台emoji显示正常")
|
||
return True
|
||
else:
|
||
self.log_test("中文支持", "FAIL", "未发现中文产品名称")
|
||
return False
|
||
|
||
except Exception as e:
|
||
self.log_test("中文支持", "FAIL", f"测试失败: {str(e)}")
|
||
return False
|
||
|
||
def test_concurrent_training(self):
|
||
"""测试6: 并发训练能力"""
|
||
try:
|
||
# 同时提交3个训练任务
|
||
models = ['mlstm', 'tcn', 'kan']
|
||
task_ids = []
|
||
|
||
for model in models:
|
||
training_data = {
|
||
"product_id": "P001",
|
||
"model_type": model,
|
||
"epochs": 1,
|
||
"training_mode": "product"
|
||
}
|
||
|
||
json_data = json.dumps(training_data).encode('utf-8')
|
||
req = urllib.request.Request(
|
||
f"{self.base_url}/api/training",
|
||
data=json_data,
|
||
headers={'Content-Type': 'application/json'}
|
||
)
|
||
|
||
with urllib.request.urlopen(req, timeout=30) as response:
|
||
if response.getcode() == 200:
|
||
result = json.loads(response.read().decode('utf-8'))
|
||
task_ids.append(result.get('task_id'))
|
||
|
||
time.sleep(0.5) # 稍微错开提交时间
|
||
|
||
if len(task_ids) == 3:
|
||
self.log_test("并发训练", "PASS", f"成功提交 {len(task_ids)} 个并发任务")
|
||
return True
|
||
else:
|
||
self.log_test("并发训练", "FAIL", f"只成功提交 {len(task_ids)}/3 个任务")
|
||
return False
|
||
|
||
except Exception as e:
|
||
self.log_test("并发训练", "FAIL", f"并发测试失败: {str(e)}")
|
||
return False
|
||
|
||
def test_all_tasks_api(self):
|
||
"""测试7: 任务列表API"""
|
||
try:
|
||
with urllib.request.urlopen(f"{self.base_url}/api/training", timeout=5) as response:
|
||
if response.getcode() == 200:
|
||
result = json.loads(response.read().decode('utf-8'))
|
||
if result.get('status') == 'success':
|
||
tasks = result.get('data', [])
|
||
self.log_test("任务列表API", "PASS", f"发现 {len(tasks)} 个历史任务")
|
||
return True
|
||
else:
|
||
self.log_test("任务列表API", "FAIL", "API返回错误状态")
|
||
return False
|
||
except Exception as e:
|
||
self.log_test("任务列表API", "FAIL", f"API调用失败: {str(e)}")
|
||
return False
|
||
|
||
def generate_report(self):
|
||
"""生成测试报告"""
|
||
print("\n" + "=" * 80)
|
||
print("🎯 现代化系统测试报告")
|
||
print("=" * 80)
|
||
|
||
total_tests = len(self.test_results)
|
||
passed_tests = len([r for r in self.test_results if r['status'] == 'PASS'])
|
||
failed_tests = len([r for r in self.test_results if r['status'] == 'FAIL'])
|
||
skipped_tests = len([r for r in self.test_results if r['status'] == 'SKIP'])
|
||
|
||
print(f"📊 测试统计:")
|
||
print(f" 总计: {total_tests} 个测试")
|
||
print(f" 通过: {passed_tests} 个 ✅")
|
||
print(f" 失败: {failed_tests} 个 ❌")
|
||
print(f" 跳过: {skipped_tests} 个 ⏭️")
|
||
print(f" 成功率: {passed_tests/total_tests*100:.1f}%")
|
||
|
||
print(f"\n📋 详细结果:")
|
||
for result in self.test_results:
|
||
status_emoji = "✅" if result['status'] == "PASS" else "❌" if result['status'] == "FAIL" else "⏭️"
|
||
print(f" {status_emoji} {result['test']}: {result['message']}")
|
||
|
||
print(f"\n🎯 总体评估:")
|
||
if failed_tests == 0:
|
||
print(" 🏆 所有测试通过!现代化系统工作完全正常")
|
||
return True
|
||
elif failed_tests <= 2:
|
||
print(" ⚠️ 大部分测试通过,系统基本正常,有少量问题需要修复")
|
||
return False
|
||
else:
|
||
print(" ❌ 多个测试失败,现代化系统存在问题需要进一步修复")
|
||
return False
|
||
|
||
def run_all_tests(self):
|
||
"""运行所有测试"""
|
||
print("🚀 开始现代化系统综合测试")
|
||
print("=" * 80)
|
||
|
||
# 测试序列
|
||
if not self.check_api_connectivity():
|
||
print("❌ API服务器未运行,请先启动: uv run server/modern_api.py")
|
||
return False
|
||
|
||
self.test_modern_logging()
|
||
task_id = self.test_training_process_manager()
|
||
self.test_task_status_tracking(task_id)
|
||
self.test_unicode_support()
|
||
self.test_concurrent_training()
|
||
self.test_all_tasks_api()
|
||
|
||
return self.generate_report()
|
||
|
||
def main():
|
||
"""主函数"""
|
||
tester = ModernSystemTester()
|
||
success = tester.run_all_tests()
|
||
|
||
if success:
|
||
print("\n🎉 现代化系统修改完全正确!")
|
||
exit(0)
|
||
else:
|
||
print("\n⚠️ 现代化系统存在问题,需要进一步修复")
|
||
exit(1)
|
||
|
||
if __name__ == "__main__":
|
||
main() |