#!/usr/bin/env python3 """ 测试新的三种训练模式API """ import requests import json import time BASE_URL = "http://localhost:5000" def test_api_endpoint(method, url, data=None, params=None): """测试API端点""" try: if method.upper() == 'GET': response = requests.get(url, params=params) elif method.upper() == 'POST': response = requests.post(url, json=data) else: raise ValueError(f"不支持的HTTP方法: {method}") print(f"{method} {url}") if data: print(f"请求数据: {json.dumps(data, indent=2, ensure_ascii=False)}") if params: print(f"查询参数: {params}") print(f"状态码: {response.status_code}") print(f"响应: {json.dumps(response.json(), indent=2, ensure_ascii=False)}") print("-" * 80) return response.json() except Exception as e: print(f"请求失败: {str(e)}") print("-" * 80) return None def main(): print("🧪 测试新的多模式训练API") print("=" * 80) # 1. 测试获取店铺列表 print("📝 1. 测试获取店铺列表") stores_response = test_api_endpoint('GET', f"{BASE_URL}/api/stores") # 2. 测试获取产品列表 print("📝 2. 测试获取产品列表") products_response = test_api_endpoint('GET', f"{BASE_URL}/api/products") # 3. 测试获取模型类型 print("📝 3. 测试获取模型类型") model_types_response = test_api_endpoint('GET', f"{BASE_URL}/api/model_types") # 获取第一个店铺和产品用于测试 if stores_response and stores_response.get('status') == 'success': stores = stores_response['data'] if stores: test_store_id = stores[0]['store_id'] print(f"✅ 使用测试店铺: {test_store_id}") else: print("❌ 没有可用的店铺数据") return else: print("❌ 无法获取店铺列表") return if products_response and products_response.get('status') == 'success': products = products_response['data'] if products: test_product_id = products[0]['product_id'] print(f"✅ 使用测试产品: {test_product_id}") else: print("❌ 没有可用的产品数据") return else: print("❌ 无法获取产品列表") return # 4. 测试获取店铺的产品列表 print("📝 4. 测试获取店铺产品列表") test_api_endpoint('GET', f"{BASE_URL}/api/stores/{test_store_id}/products") # 5. 测试全局训练统计API print("📝 5. 测试全局训练统计API") test_api_endpoint('GET', f"{BASE_URL}/api/training/global/stats", params={ 'training_scope': 'all_stores_all_products', 'aggregation_method': 'sum' }) # 6. 测试按产品训练(全局数据) print("📝 6. 测试按产品训练(全局数据)") product_training_data = { "training_mode": "product", "product_id": test_product_id, "model_type": "mlstm", "epochs": 5, # 使用较少的轮次进行测试 "store_id": None # 使用全局数据 } product_training_response = test_api_endpoint('POST', f"{BASE_URL}/api/training", product_training_data) # 7. 测试按产品训练(特定店铺) print("📝 7. 测试按产品训练(特定店铺)") product_store_training_data = { "training_mode": "product", "product_id": test_product_id, "store_id": test_store_id, "model_type": "mlstm", "epochs": 5 } test_api_endpoint('POST', f"{BASE_URL}/api/training", product_store_training_data) # 8. 测试按店铺训练 print("📝 8. 测试按店铺训练") store_training_data = { "training_mode": "store", "store_id": test_store_id, "model_type": "mlstm", "epochs": 5, "product_scope": "all" } test_api_endpoint('POST', f"{BASE_URL}/api/training", store_training_data) # 9. 测试全局训练 print("📝 9. 测试全局训练") global_training_data = { "training_mode": "global", "model_type": "mlstm", "epochs": 5, "training_scope": "all_stores_all_products", "aggregation_method": "sum" } test_api_endpoint('POST', f"{BASE_URL}/api/training", global_training_data) # 10. 等待一下,然后查看训练任务列表 print("📝 10. 等待3秒后查看训练任务列表") time.sleep(3) training_tasks_response = test_api_endpoint('GET', f"{BASE_URL}/api/training") # 11. 测试模型版本管理API print("📝 11. 测试模型版本管理API") # 产品模型版本 test_api_endpoint('GET', f"{BASE_URL}/api/models/{test_product_id}/mlstm/versions") # 店铺模型版本 test_api_endpoint('GET', f"{BASE_URL}/api/models/store/{test_store_id}/mlstm/versions") # 全局模型版本 test_api_endpoint('GET', f"{BASE_URL}/api/models/global/mlstm/versions") print("🎉 测试完成!") if __name__ == "__main__": main()