# ShopTRAINING/server/api_test.py
#
# (Source-viewer metadata: 104 lines, 3.1 KiB, Python)

import requests
import json
import matplotlib.pyplot as plt
import base64
from io import BytesIO
# Base URL of the API; every request helper below appends an endpoint path
# (e.g. '/products') to it. Points at a server on localhost, port 5000.
API_BASE_URL = 'http://localhost:5000/api'
# Example 1: fetch the product list.
def get_products(*, timeout=30):
    """Fetch the product list from the server.

    Args:
        timeout: Seconds to wait for the server before giving up. Without a
            timeout, ``requests`` may block indefinitely on an unresponsive
            server.

    Returns:
        The decoded JSON body of ``GET {API_BASE_URL}/products``.

    Raises:
        requests.exceptions.RequestException: on connection errors or timeout.
        ValueError: (requests' JSONDecodeError subclass) if the body is not JSON.
    """
    response = requests.get(f'{API_BASE_URL}/products', timeout=timeout)
    return response.json()
# Example 2: start training a model.
def train_model(product_id, model_type='mlstm', parameters=None, *, timeout=30):
    """Ask the server to train a model for one product.

    Args:
        product_id: Product to train on; sent as-is in the JSON payload.
        model_type: Model name sent to the server (default ``'mlstm'``).
        parameters: Optional mapping of training hyper-parameters. Keys given
            here override the defaults below; omitted keys keep their default.
            The caller's mapping is never modified.
        timeout: Seconds to wait for the server's response.

    Returns:
        The decoded JSON body of ``POST {API_BASE_URL}/training``.

    Raises:
        requests.exceptions.RequestException: on connection errors or timeout.
        ValueError: (requests' JSONDecodeError subclass) if the body is not JSON.
    """
    # Defaults of the original example, so existing calls send an identical payload.
    default_parameters = {
        'look_back': 14,
        'future_days': 7,
        'batch_size': 64,
        'epochs': 100,
    }
    data = {
        'product_id': product_id,
        'model_type': model_type,
        # Build a new dict so neither the defaults nor the caller's mapping is mutated.
        'parameters': {**default_parameters, **(parameters or {})},
    }
    response = requests.post(f'{API_BASE_URL}/training', json=data, timeout=timeout)
    return response.json()
# Example 3: query the status of a training task.
def check_training_status(task_id, *, timeout=30):
    """Fetch the current status of a training task.

    Args:
        task_id: Identifier of the training task. It is percent-encoded as a
            single URL path segment, so characters such as '/' or '?' cannot
            change which endpoint is called.
        timeout: Seconds to wait for the server's response.

    Returns:
        The decoded JSON body of ``GET {API_BASE_URL}/training/<task_id>``.

    Raises:
        requests.exceptions.RequestException: on connection errors or timeout.
        ValueError: (requests' JSONDecodeError subclass) if the body is not JSON.
    """
    from urllib.parse import quote

    # safe='' also encodes '/', keeping the id inside one path segment.
    encoded_id = quote(str(task_id), safe='')
    response = requests.get(f'{API_BASE_URL}/training/{encoded_id}', timeout=timeout)
    return response.json()
# Example 4: forecast sales with a trained model.
def predict_sales(product_id, model_type='mlstm', future_days=7,
                  include_visualization=True, *, show_plot=True, timeout=120):
    """Request a sales forecast and optionally display the returned chart.

    Args:
        product_id: Product to forecast.
        model_type: Model name sent to the server (default ``'mlstm'``).
        future_days: Forecast horizon sent to the server (default 7).
        include_visualization: Ask the server to include a chart in the response.
        show_plot: When True and the response carries a non-empty
            ``data.visualization`` string, decode it as base64 and show it
            with matplotlib. ``plt.show()`` may block until the window closes.
        timeout: Seconds to wait for the server. The default is longer than
            for the simple GET helpers because the server presumably runs
            inference and renders a chart here; adjust if needed.

    Returns:
        The decoded JSON body of ``POST {API_BASE_URL}/prediction``.

    Raises:
        requests.exceptions.RequestException: on connection errors or timeout.
        ValueError: (requests' JSONDecodeError subclass) if the body is not JSON.
    """
    data = {
        'product_id': product_id,
        'model_type': model_type,
        'future_days': future_days,
        'include_visualization': include_visualization,
    }
    response = requests.post(f'{API_BASE_URL}/prediction', json=data, timeout=timeout)
    result = response.json()

    # Only look inside 'data' when it is a dict: a response with data=None
    # would otherwise make the membership test raise TypeError.
    payload = result.get('data') if isinstance(result, dict) else None
    encoded_image = payload.get('visualization') if isinstance(payload, dict) else None
    # Skip empty/None values, which cannot be decoded into an image.
    if show_plot and encoded_image:
        _show_base64_image(encoded_image)
    return result


def _show_base64_image(encoded_image):
    """Decode a base64-encoded image and display it in a 12x6 inch figure without axes."""
    img = plt.imread(BytesIO(base64.b64decode(encoded_image)))
    plt.figure(figsize=(12, 6))
    plt.imshow(img)
    plt.axis('off')
    plt.show()
# Example 5: compare several model types on one product.
def compare_models(product_id, model_types=None, include_visualization=True, *, timeout=120):
    """Ask the server to compare forecasts from several model types.

    Args:
        product_id: Product to compare models on.
        model_types: Iterable of model names to compare. ``None`` (the
            default) means ``['mlstm', 'transformer', 'kan']``. A ``None``
            sentinel is used instead of a list default so no mutable object
            is shared between calls.
        include_visualization: Ask the server to include a chart in the response.
        timeout: Seconds to wait for the server. The default is longer than for
            the simple GET helpers because several models are presumably run.

    Returns:
        The decoded JSON body of ``POST {API_BASE_URL}/prediction/compare``.

    Raises:
        requests.exceptions.RequestException: on connection errors or timeout.
        ValueError: (requests' JSONDecodeError subclass) if the body is not JSON.
    """
    if model_types is None:
        model_types = ['mlstm', 'transformer', 'kan']
    data = {
        'product_id': product_id,
        # list() accepts tuples/generators and always sends a JSON array.
        'model_types': list(model_types),
        'include_visualization': include_visualization,
    }
    response = requests.post(f'{API_BASE_URL}/prediction/compare', json=data, timeout=timeout)
    return response.json()
# Example 6: list trained models, optionally filtered.
def list_models(product_id=None, model_type=None, *, timeout=30):
    """Fetch the list of models, optionally filtered by product and/or model type.

    Args:
        product_id: Only list models for this product. ``None`` means no filter.
        model_type: Only list models of this type. ``None`` means no filter.
        timeout: Seconds to wait for the server's response.

    Returns:
        The decoded JSON body of ``GET {API_BASE_URL}/models``.

    Raises:
        requests.exceptions.RequestException: on connection errors or timeout.
        ValueError: (requests' JSONDecodeError subclass) if the body is not JSON.
    """
    filters = {'product_id': product_id, 'model_type': model_type}
    # Test against None, not truthiness, so a real but falsy filter such as
    # product_id=0 is still sent as a query parameter.
    params = {key: value for key, value in filters.items() if value is not None}
    response = requests.get(f'{API_BASE_URL}/models', params=params, timeout=timeout)
    return response.json()
# Example usage: run the whole workflow against a live server.
def _print_json(obj):
    """Pretty-print *obj* as JSON; ensure_ascii=False keeps non-ASCII (e.g. Chinese) text readable."""
    print(json.dumps(obj, indent=2, ensure_ascii=False))


def main():
    """List products, then train, check status, predict and compare for the first one."""
    products = get_products()
    _print_json(products)

    # Continue only if the listing succeeded and returned at least one product.
    # .get() avoids a KeyError when the server replies with an unexpected shape.
    if products.get('status') != 'success' or not products.get('data'):
        return
    product_id = products['data'][0]['product_id']

    training_result = train_model(product_id)
    _print_json(training_result)

    # Training is referenced by a task_id; its status is checked once here,
    # without waiting for the task to finish.
    # NOTE(review): the prediction below may therefore run before the new model
    # exists. TODO: poll check_training_status until the task reaches a terminal
    # state once the status response schema is confirmed.
    if training_result.get('status') == 'success':
        task_id = training_result['data']['task_id']
        _print_json(check_training_status(task_id))

    predictions = predict_sales(product_id)
    _print_json(predictions)

    comparison = compare_models(product_id)
    _print_json(comparison)


if __name__ == '__main__':
    main()