154 lines
5.1 KiB
Python
154 lines
5.1 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
测试新的三种训练模式API
|
|
"""
|
|
|
|
import requests
|
|
import json
|
|
import time
|
|
|
|
BASE_URL = "http://localhost:5000"
|
|
|
|
def test_api_endpoint(method, url, data=None, params=None):
|
|
"""测试API端点"""
|
|
try:
|
|
if method.upper() == 'GET':
|
|
response = requests.get(url, params=params)
|
|
elif method.upper() == 'POST':
|
|
response = requests.post(url, json=data)
|
|
else:
|
|
raise ValueError(f"不支持的HTTP方法: {method}")
|
|
|
|
print(f"{method} {url}")
|
|
if data:
|
|
print(f"请求数据: {json.dumps(data, indent=2, ensure_ascii=False)}")
|
|
if params:
|
|
print(f"查询参数: {params}")
|
|
|
|
print(f"状态码: {response.status_code}")
|
|
print(f"响应: {json.dumps(response.json(), indent=2, ensure_ascii=False)}")
|
|
print("-" * 80)
|
|
|
|
return response.json()
|
|
except Exception as e:
|
|
print(f"请求失败: {str(e)}")
|
|
print("-" * 80)
|
|
return None
|
|
|
|
def main():
|
|
print("🧪 测试新的多模式训练API")
|
|
print("=" * 80)
|
|
|
|
# 1. 测试获取店铺列表
|
|
print("📝 1. 测试获取店铺列表")
|
|
stores_response = test_api_endpoint('GET', f"{BASE_URL}/api/stores")
|
|
|
|
# 2. 测试获取产品列表
|
|
print("📝 2. 测试获取产品列表")
|
|
products_response = test_api_endpoint('GET', f"{BASE_URL}/api/products")
|
|
|
|
# 3. 测试获取模型类型
|
|
print("📝 3. 测试获取模型类型")
|
|
model_types_response = test_api_endpoint('GET', f"{BASE_URL}/api/model_types")
|
|
|
|
# 获取第一个店铺和产品用于测试
|
|
if stores_response and stores_response.get('status') == 'success':
|
|
stores = stores_response['data']
|
|
if stores:
|
|
test_store_id = stores[0]['store_id']
|
|
print(f"✅ 使用测试店铺: {test_store_id}")
|
|
else:
|
|
print("❌ 没有可用的店铺数据")
|
|
return
|
|
else:
|
|
print("❌ 无法获取店铺列表")
|
|
return
|
|
|
|
if products_response and products_response.get('status') == 'success':
|
|
products = products_response['data']
|
|
if products:
|
|
test_product_id = products[0]['product_id']
|
|
print(f"✅ 使用测试产品: {test_product_id}")
|
|
else:
|
|
print("❌ 没有可用的产品数据")
|
|
return
|
|
else:
|
|
print("❌ 无法获取产品列表")
|
|
return
|
|
|
|
# 4. 测试获取店铺的产品列表
|
|
print("📝 4. 测试获取店铺产品列表")
|
|
test_api_endpoint('GET', f"{BASE_URL}/api/stores/{test_store_id}/products")
|
|
|
|
# 5. 测试全局训练统计API
|
|
print("📝 5. 测试全局训练统计API")
|
|
test_api_endpoint('GET', f"{BASE_URL}/api/training/global/stats", params={
|
|
'training_scope': 'all_stores_all_products',
|
|
'aggregation_method': 'sum'
|
|
})
|
|
|
|
# 6. 测试按产品训练(全局数据)
|
|
print("📝 6. 测试按产品训练(全局数据)")
|
|
product_training_data = {
|
|
"training_mode": "product",
|
|
"product_id": test_product_id,
|
|
"model_type": "mlstm",
|
|
"epochs": 5, # 使用较少的轮次进行测试
|
|
"store_id": None # 使用全局数据
|
|
}
|
|
product_training_response = test_api_endpoint('POST', f"{BASE_URL}/api/training", product_training_data)
|
|
|
|
# 7. 测试按产品训练(特定店铺)
|
|
print("📝 7. 测试按产品训练(特定店铺)")
|
|
product_store_training_data = {
|
|
"training_mode": "product",
|
|
"product_id": test_product_id,
|
|
"store_id": test_store_id,
|
|
"model_type": "mlstm",
|
|
"epochs": 5
|
|
}
|
|
test_api_endpoint('POST', f"{BASE_URL}/api/training", product_store_training_data)
|
|
|
|
# 8. 测试按店铺训练
|
|
print("📝 8. 测试按店铺训练")
|
|
store_training_data = {
|
|
"training_mode": "store",
|
|
"store_id": test_store_id,
|
|
"model_type": "mlstm",
|
|
"epochs": 5,
|
|
"product_scope": "all"
|
|
}
|
|
test_api_endpoint('POST', f"{BASE_URL}/api/training", store_training_data)
|
|
|
|
# 9. 测试全局训练
|
|
print("📝 9. 测试全局训练")
|
|
global_training_data = {
|
|
"training_mode": "global",
|
|
"model_type": "mlstm",
|
|
"epochs": 5,
|
|
"training_scope": "all_stores_all_products",
|
|
"aggregation_method": "sum"
|
|
}
|
|
test_api_endpoint('POST', f"{BASE_URL}/api/training", global_training_data)
|
|
|
|
# 10. 等待一下,然后查看训练任务列表
|
|
print("📝 10. 等待3秒后查看训练任务列表")
|
|
time.sleep(3)
|
|
training_tasks_response = test_api_endpoint('GET', f"{BASE_URL}/api/training")
|
|
|
|
# 11. 测试模型版本管理API
|
|
print("📝 11. 测试模型版本管理API")
|
|
|
|
# 产品模型版本
|
|
test_api_endpoint('GET', f"{BASE_URL}/api/models/{test_product_id}/mlstm/versions")
|
|
|
|
# 店铺模型版本
|
|
test_api_endpoint('GET', f"{BASE_URL}/api/models/store/{test_store_id}/mlstm/versions")
|
|
|
|
# 全局模型版本
|
|
test_api_endpoint('GET', f"{BASE_URL}/api/models/global/mlstm/versions")
|
|
|
|
print("🎉 测试完成!")
|
|
|
|
if __name__ == "__main__":
|
|
main() |