#!/usr/bin/env python3 """ 测试多店铺训练功能的简单脚本 """ def test_method_signature(): """测试方法签名是否正确""" print("测试 PharmacyPredictor.train_model 方法签名...") try: # 读取文件内容 with open('server/core/predictor.py', 'r', encoding='utf-8') as f: content = f.read() # 检查方法签名 if 'def train_model(self, product_id, model_type=' in content: # 查找完整的方法定义 start = content.find('def train_model(') if start != -1: end = content.find('"""', start + 100) # 找到docstring开始 method_def = content[start:end].strip() print("找到方法定义:") print(method_def) # 检查参数 params_to_check = ['store_id', 'training_mode', 'aggregation_method'] results = {} for param in params_to_check: if param in method_def: results[param] = True print(f"✓ 参数 {param} 存在") else: results[param] = False print(f"✗ 参数 {param} 不存在") return all(results.values()) return False except Exception as e: print(f"错误: {e}") return False def test_data_file(): """测试多店铺数据文件是否存在""" print("\n测试多店铺数据文件...") import os files_to_check = [ 'pharmacy_sales_multi_store.csv', 'server/utils/multi_store_data_utils.py' ] results = [] for file_path in files_to_check: if os.path.exists(file_path): print(f"✓ 文件存在: {file_path}") results.append(True) else: print(f"✗ 文件不存在: {file_path}") results.append(False) return all(results) def test_import_structure(): """测试导入结构是否正确""" print("\n测试导入结构...") try: # 读取predictor.py文件 with open('server/core/predictor.py', 'r', encoding='utf-8') as f: content = f.read() # 检查多店铺工具函数的导入 imports_to_check = [ 'from utils.multi_store_data_utils import', 'load_multi_store_data', 'get_store_product_sales_data', 'aggregate_multi_store_data' ] results = [] for import_check in imports_to_check: if import_check in content: print(f"✓ 导入存在: {import_check}") results.append(True) else: print(f"✗ 导入不存在: {import_check}") results.append(False) return all(results) except Exception as e: print(f"错误: {e}") return False def main(): print("=" * 60) print("多店铺训练功能验证") print("=" * 60) tests = [ ("方法签名检查", test_method_signature), ("数据文件检查", test_data_file), ("导入结构检查", test_import_structure) ] results = [] for test_name, test_func in tests: print(f"\n{test_name}:") print("-" * 40) result = test_func() results.append(result) print("\n" + "=" * 60) print("测试结果:") print("=" * 60) for i, (test_name, result) in enumerate(zip([t[0] for t in tests], results)): status = "通过" if result else "失败" print(f"{i+1}. {test_name}: {status}") overall = all(results) print(f"\n总体结果: {'所有测试通过' if overall else '存在失败的测试'}") if overall: print("\n✓ PharmacyPredictor 类已成功更新以支持多店铺训练!") print("现在可以使用以下参数调用 train_model 方法:") print("- store_id: 指定店铺ID") print("- training_mode: 'product', 'store', 或 'global'") print("- aggregation_method: 'sum', 'mean', 或 'median' (仅用于全局训练)") else: print("\n✗ 还有一些问题需要解决") return overall if __name__ == "__main__": main()