"""Example client for the sales-forecasting HTTP API.

Demonstrates listing products, starting a training job, checking its status,
running a prediction (rendering the returned chart if one is included),
comparing several model types, and listing trained models.
"""
import base64
import json
from io import BytesIO

import matplotlib.pyplot as plt
import requests

# Base URL of the API server.
API_BASE_URL = 'http://localhost:5000/api'

# Seconds to wait for the server before giving up. requests never times out
# on its own, so without this a stalled server would hang the script forever.
REQUEST_TIMEOUT = 60

# Hyper-parameters sent by train_model() unless overridden by the caller.
DEFAULT_TRAINING_PARAMETERS = {
    'look_back': 14,
    'future_days': 7,
    'batch_size': 64,
    'epochs': 100,
}

# Model types compared by compare_models() when none are given.
DEFAULT_COMPARE_MODEL_TYPES = ('mlstm', 'transformer', 'kan')


def _get(path, params=None):
    """GET ``API_BASE_URL + path`` and return the decoded JSON body.

    requests omits query parameters whose value is None.

    Raises:
        requests.exceptions.RequestException: on connection errors/timeouts.
        ValueError: if the response body is not valid JSON.
    """
    response = requests.get(
        f'{API_BASE_URL}{path}', params=params, timeout=REQUEST_TIMEOUT
    )
    # HTTP error codes are not raised here: callers inspect the JSON envelope's
    # 'status' field instead. NOTE(review): assumes the server returns JSON
    # bodies on errors too — confirm.
    return response.json()


def _post(path, payload):
    """POST *payload* as JSON to ``API_BASE_URL + path``; return decoded JSON.

    Raises:
        requests.exceptions.RequestException: on connection errors/timeouts.
        ValueError: if the response body is not valid JSON.
    """
    response = requests.post(
        f'{API_BASE_URL}{path}', json=payload, timeout=REQUEST_TIMEOUT
    )
    return response.json()


def _show_base64_image(encoded):
    """Decode a base64-encoded image and display it in a matplotlib window."""
    image = plt.imread(BytesIO(base64.b64decode(encoded)))
    fig = plt.figure(figsize=(12, 6))
    plt.imshow(image)
    plt.axis('off')
    plt.show()
    # Release the figure so repeated calls don't accumulate open figures.
    plt.close(fig)


# Example 1: fetch the product list.
def get_products():
    """Return the decoded JSON response of ``GET /products``."""
    return _get('/products')


# Example 2: train a model.
def train_model(product_id, model_type='mlstm', parameters=None):
    """Start a training job for *product_id*.

    Args:
        product_id: ID of the product to train on.
        model_type: Model type name sent to the server (default 'mlstm').
        parameters: Optional dict of hyper-parameter overrides, merged on top
            of DEFAULT_TRAINING_PARAMETERS.

    Returns:
        The decoded JSON response of ``POST /training``.
    """
    merged_parameters = {**DEFAULT_TRAINING_PARAMETERS, **(parameters or {})}
    payload = {
        'product_id': product_id,
        'model_type': model_type,
        'parameters': merged_parameters,
    }
    return _post('/training', payload)


# Example 3: query training status.
def check_training_status(task_id):
    """Return the decoded JSON response of ``GET /training/<task_id>``."""
    return _get(f'/training/{task_id}')


# Example 4: predict with a trained model.
def predict_sales(product_id, model_type='mlstm', future_days=7):
    """Request a sales prediction and display its chart if one is returned.

    Args:
        product_id: ID of the product to predict for.
        model_type: Model type name sent to the server (default 'mlstm').
        future_days: Number of days to predict (default 7).

    Returns:
        The decoded JSON response of ``POST /prediction``.
    """
    payload = {
        'product_id': product_id,
        'model_type': model_type,
        'future_days': future_days,
        'include_visualization': True,
    }
    result = _post('/prediction', payload)

    # 'data' may be missing, null, or a non-dict on error responses; only
    # render when it is a dict carrying a non-empty base64 'visualization'.
    data = result.get('data') if isinstance(result, dict) else None
    if isinstance(data, dict) and data.get('visualization'):
        _show_base64_image(data['visualization'])
    return result


# Example 5: compare different models.
def compare_models(product_id, model_types=None):
    """Request a side-by-side prediction comparison across model types.

    Args:
        product_id: ID of the product to compare models for.
        model_types: Iterable of model type names; defaults to
            DEFAULT_COMPARE_MODEL_TYPES.

    Returns:
        The decoded JSON response of ``POST /prediction/compare``.
    """
    if model_types is None:
        model_types = DEFAULT_COMPARE_MODEL_TYPES
    payload = {
        'product_id': product_id,
        'model_types': list(model_types),
        'include_visualization': True,
    }
    return _post('/prediction/compare', payload)


# Example 6: list models.
def list_models(product_id=None, model_type=None):
    """Return the decoded JSON response of ``GET /models``.

    Filters that are None are left out of the query string (requests drops
    None-valued params), so falsy but valid IDs such as 0 are still sent.
    """
    return _get('/models', params={'product_id': product_id, 'model_type': model_type})


def _print_json(obj):
    """Pretty-print a decoded JSON response."""
    print(json.dumps(obj, indent=2))


def main():
    """Run the end-to-end example against a live server."""
    products = get_products()
    _print_json(products)

    # Pick the first product, if the listing succeeded and is non-empty.
    if products.get('status') != 'success' or not products.get('data'):
        return
    product_id = products['data'][0]['product_id']

    training_result = train_model(product_id)
    _print_json(training_result)

    # If training was started successfully, check its status once.
    # NOTE(review): this polls immediately, so the job is probably still
    # running; the prediction below may rely on a previously trained model.
    if training_result.get('status') == 'success':
        task_id = training_result['data']['task_id']
        status = check_training_status(task_id)
        _print_json(status)

    predictions = predict_sales(product_id)
    _print_json(predictions)

    comparison = compare_models(product_id)
    _print_json(comparison)


if __name__ == '__main__':
    main()