104 lines
3.1 KiB
Python
104 lines
3.1 KiB
Python
import requests
|
|
import json
|
|
import matplotlib.pyplot as plt
|
|
import base64
|
|
from io import BytesIO
|
|
|
|
# Base URL of the API; the request helpers below append their endpoint
# path to this value.
API_BASE_URL = 'http://localhost:5000/api'
|
|
|
|
# 示例1: 获取产品列表
|
|
def get_products():
    """Fetch the product list from the API.

    Sends a GET request to ``{API_BASE_URL}/products`` and returns the
    response body decoded from JSON.
    """
    products_url = f'{API_BASE_URL}/products'
    return requests.get(products_url).json()
|
|
|
|
# 示例2: 训练模型
|
|
def train_model(product_id, model_type='mlstm', parameters=None):
    """Submit a training request for one product.

    Args:
        product_id: Identifier of the product to train on; sent as-is.
        model_type: Model type name sent to the API (default ``'mlstm'``).
        parameters: Optional mapping of training parameters. Its entries
            override the defaults defined below; keys that are not supplied
            keep their default value. ``None`` (the default) sends the
            defaults unchanged, which matches the previous behaviour.

    Returns:
        The JSON-decoded body of the POST response from
        ``{API_BASE_URL}/training``.
    """
    # Defaults that were previously hard-coded into the request payload.
    training_parameters = {
        'look_back': 14,
        'future_days': 7,
        'batch_size': 64,
        'epochs': 100,
    }
    if parameters:
        # Caller-supplied values win; the caller's mapping is not mutated.
        training_parameters.update(parameters)

    payload = {
        'product_id': product_id,
        'model_type': model_type,
        'parameters': training_parameters,
    }
    response = requests.post(f'{API_BASE_URL}/training', json=payload)
    return response.json()
|
|
|
|
# 示例3: 查询训练状态
|
|
def check_training_status(task_id):
    """Look up a training task by its id.

    Sends a GET request to ``{API_BASE_URL}/training/{task_id}`` and returns
    the response body decoded from JSON.
    """
    status_url = f'{API_BASE_URL}/training/{task_id}'
    return requests.get(status_url).json()
|
|
|
|
# 示例4: 使用模型预测
|
|
def _show_base64_image(encoded_image, figsize=(12, 6)):
    """Decode a base64-encoded image and display it in a matplotlib figure."""
    image = plt.imread(BytesIO(base64.b64decode(encoded_image)))
    plt.figure(figsize=figsize)
    plt.imshow(image)
    plt.axis('off')
    plt.show()


def predict_sales(product_id, model_type='mlstm', future_days=7):
    """Request a sales prediction and display the returned chart, if any.

    Args:
        product_id: Identifier of the product to predict for; sent as-is.
        model_type: Model type name sent to the API (default ``'mlstm'``).
        future_days: Value sent as ``future_days`` in the request. Defaults
            to 7, the value that was previously hard-coded.

    Returns:
        The JSON-decoded body of the POST response from
        ``{API_BASE_URL}/prediction``. The request always asks for a
        visualization (``include_visualization`` is ``True``).
    """
    payload = {
        'product_id': product_id,
        'model_type': model_type,
        'future_days': future_days,
        'include_visualization': True,
    }
    response = requests.post(f'{API_BASE_URL}/prediction', json=payload)
    result = response.json()

    # Only try to render when the response really carries a non-empty chart.
    # A null 'data' value, or a null/empty 'visualization', previously raised
    # TypeError (or failed inside the image decoder) instead of simply
    # returning the response to the caller.
    data = result.get('data') if isinstance(result, dict) else None
    encoded_chart = data.get('visualization') if isinstance(data, dict) else None
    if encoded_chart:
        _show_base64_image(encoded_chart)

    return result
|
|
|
|
# 示例5: 比较不同模型
|
|
def compare_models(product_id):
    """Request a prediction comparison across several model types.

    Posts the product id, the model types ``'mlstm'``, ``'transformer'`` and
    ``'kan'``, and ``include_visualization=True`` to
    ``{API_BASE_URL}/prediction/compare``, then returns the response body
    decoded from JSON.
    """
    compared_model_types = ['mlstm', 'transformer', 'kan']
    payload = {
        'product_id': product_id,
        'model_types': compared_model_types,
        'include_visualization': True,
    }
    compare_url = f'{API_BASE_URL}/prediction/compare'
    return requests.post(compare_url, json=payload).json()
|
|
|
|
# 示例6: 获取模型列表
|
|
def list_models(product_id=None, model_type=None):
    """List models, optionally filtered by product and/or model type.

    Args:
        product_id: Sent as the ``product_id`` query parameter when truthy;
            omitted otherwise.
        model_type: Sent as the ``model_type`` query parameter when truthy;
            omitted otherwise.

    Returns:
        The JSON-decoded body of the GET response from
        ``{API_BASE_URL}/models``.
    """
    candidate_filters = (
        ('product_id', product_id),
        ('model_type', model_type),
    )
    # Keep only the filters the caller actually supplied (truthy values).
    query = {name: value for name, value in candidate_filters if value}
    return requests.get(f'{API_BASE_URL}/models', params=query).json()
|
|
|
|
# 示例代码使用
|
|
if __name__ == '__main__':
    # Fetch the product list and echo the raw response.
    products = get_products()
    print(json.dumps(products, indent=2))

    # Continue only when the listing succeeded and is non-empty; the first
    # product is used for the rest of the walkthrough.
    listing_succeeded = products['status'] == 'success'
    if listing_succeeded and products['data']:
        first_product = products['data'][0]
        product_id = first_product['product_id']

        # Start a training run for that product.
        training_result = train_model(product_id)
        print(json.dumps(training_result, indent=2))

        # If the training request was accepted, query the task once.
        if training_result['status'] == 'success':
            task_id = training_result['data']['task_id']
            status = check_training_status(task_id)
            print(json.dumps(status, indent=2))

        # Request a prediction from the model.
        predictions = predict_sales(product_id)
        print(json.dumps(predictions, indent=2))

        # Compare the different model types for the same product.
        comparison = compare_models(product_id)
        print(json.dumps(comparison, indent=2))