ShopTRAINING/test/test_new_training_api.py
2025-07-02 11:05:23 +08:00

154 lines
5.1 KiB
Python

#!/usr/bin/env python3
"""
测试新的三种训练模式API
"""
import requests
import json
import time
BASE_URL = "http://localhost:5000"
def test_api_endpoint(method, url, data=None, params=None):
"""测试API端点"""
try:
if method.upper() == 'GET':
response = requests.get(url, params=params)
elif method.upper() == 'POST':
response = requests.post(url, json=data)
else:
raise ValueError(f"不支持的HTTP方法: {method}")
print(f"{method} {url}")
if data:
print(f"请求数据: {json.dumps(data, indent=2, ensure_ascii=False)}")
if params:
print(f"查询参数: {params}")
print(f"状态码: {response.status_code}")
print(f"响应: {json.dumps(response.json(), indent=2, ensure_ascii=False)}")
print("-" * 80)
return response.json()
except Exception as e:
print(f"请求失败: {str(e)}")
print("-" * 80)
return None
def main():
print("🧪 测试新的多模式训练API")
print("=" * 80)
# 1. 测试获取店铺列表
print("📝 1. 测试获取店铺列表")
stores_response = test_api_endpoint('GET', f"{BASE_URL}/api/stores")
# 2. 测试获取产品列表
print("📝 2. 测试获取产品列表")
products_response = test_api_endpoint('GET', f"{BASE_URL}/api/products")
# 3. 测试获取模型类型
print("📝 3. 测试获取模型类型")
model_types_response = test_api_endpoint('GET', f"{BASE_URL}/api/model_types")
# 获取第一个店铺和产品用于测试
if stores_response and stores_response.get('status') == 'success':
stores = stores_response['data']
if stores:
test_store_id = stores[0]['store_id']
print(f"✅ 使用测试店铺: {test_store_id}")
else:
print("❌ 没有可用的店铺数据")
return
else:
print("❌ 无法获取店铺列表")
return
if products_response and products_response.get('status') == 'success':
products = products_response['data']
if products:
test_product_id = products[0]['product_id']
print(f"✅ 使用测试产品: {test_product_id}")
else:
print("❌ 没有可用的产品数据")
return
else:
print("❌ 无法获取产品列表")
return
# 4. 测试获取店铺的产品列表
print("📝 4. 测试获取店铺产品列表")
test_api_endpoint('GET', f"{BASE_URL}/api/stores/{test_store_id}/products")
# 5. 测试全局训练统计API
print("📝 5. 测试全局训练统计API")
test_api_endpoint('GET', f"{BASE_URL}/api/training/global/stats", params={
'training_scope': 'all_stores_all_products',
'aggregation_method': 'sum'
})
# 6. 测试按产品训练(全局数据)
print("📝 6. 测试按产品训练(全局数据)")
product_training_data = {
"training_mode": "product",
"product_id": test_product_id,
"model_type": "mlstm",
"epochs": 5, # 使用较少的轮次进行测试
"store_id": None # 使用全局数据
}
product_training_response = test_api_endpoint('POST', f"{BASE_URL}/api/training", product_training_data)
# 7. 测试按产品训练(特定店铺)
print("📝 7. 测试按产品训练(特定店铺)")
product_store_training_data = {
"training_mode": "product",
"product_id": test_product_id,
"store_id": test_store_id,
"model_type": "mlstm",
"epochs": 5
}
test_api_endpoint('POST', f"{BASE_URL}/api/training", product_store_training_data)
# 8. 测试按店铺训练
print("📝 8. 测试按店铺训练")
store_training_data = {
"training_mode": "store",
"store_id": test_store_id,
"model_type": "mlstm",
"epochs": 5,
"product_scope": "all"
}
test_api_endpoint('POST', f"{BASE_URL}/api/training", store_training_data)
# 9. 测试全局训练
print("📝 9. 测试全局训练")
global_training_data = {
"training_mode": "global",
"model_type": "mlstm",
"epochs": 5,
"training_scope": "all_stores_all_products",
"aggregation_method": "sum"
}
test_api_endpoint('POST', f"{BASE_URL}/api/training", global_training_data)
# 10. 等待一下,然后查看训练任务列表
print("📝 10. 等待3秒后查看训练任务列表")
time.sleep(3)
training_tasks_response = test_api_endpoint('GET', f"{BASE_URL}/api/training")
# 11. 测试模型版本管理API
print("📝 11. 测试模型版本管理API")
# 产品模型版本
test_api_endpoint('GET', f"{BASE_URL}/api/models/{test_product_id}/mlstm/versions")
# 店铺模型版本
test_api_endpoint('GET', f"{BASE_URL}/api/models/store/{test_store_id}/mlstm/versions")
# 全局模型版本
test_api_endpoint('GET', f"{BASE_URL}/api/models/global/mlstm/versions")
print("🎉 测试完成!")
if __name__ == "__main__":
main()