ShopTRAINING/test/test_multi_store_training.py
2025-07-02 11:05:23 +08:00

141 lines
4.3 KiB
Python

#!/usr/bin/env python3
"""
测试多店铺训练功能的简单脚本
"""
def test_method_signature():
"""测试方法签名是否正确"""
print("测试 PharmacyPredictor.train_model 方法签名...")
try:
# 读取文件内容
with open('server/core/predictor.py', 'r', encoding='utf-8') as f:
content = f.read()
# 检查方法签名
if 'def train_model(self, product_id, model_type=' in content:
# 查找完整的方法定义
start = content.find('def train_model(')
if start != -1:
end = content.find('"""', start + 100) # 找到docstring开始
method_def = content[start:end].strip()
print("找到方法定义:")
print(method_def)
# 检查参数
params_to_check = ['store_id', 'training_mode', 'aggregation_method']
results = {}
for param in params_to_check:
if param in method_def:
results[param] = True
print(f"✓ 参数 {param} 存在")
else:
results[param] = False
print(f"✗ 参数 {param} 不存在")
return all(results.values())
return False
except Exception as e:
print(f"错误: {e}")
return False
def test_data_file():
"""测试多店铺数据文件是否存在"""
print("\n测试多店铺数据文件...")
import os
files_to_check = [
'pharmacy_sales_multi_store.csv',
'server/utils/multi_store_data_utils.py'
]
results = []
for file_path in files_to_check:
if os.path.exists(file_path):
print(f"✓ 文件存在: {file_path}")
results.append(True)
else:
print(f"✗ 文件不存在: {file_path}")
results.append(False)
return all(results)
def test_import_structure():
"""测试导入结构是否正确"""
print("\n测试导入结构...")
try:
# 读取predictor.py文件
with open('server/core/predictor.py', 'r', encoding='utf-8') as f:
content = f.read()
# 检查多店铺工具函数的导入
imports_to_check = [
'from utils.multi_store_data_utils import',
'load_multi_store_data',
'get_store_product_sales_data',
'aggregate_multi_store_data'
]
results = []
for import_check in imports_to_check:
if import_check in content:
print(f"✓ 导入存在: {import_check}")
results.append(True)
else:
print(f"✗ 导入不存在: {import_check}")
results.append(False)
return all(results)
except Exception as e:
print(f"错误: {e}")
return False
def main():
print("=" * 60)
print("多店铺训练功能验证")
print("=" * 60)
tests = [
("方法签名检查", test_method_signature),
("数据文件检查", test_data_file),
("导入结构检查", test_import_structure)
]
results = []
for test_name, test_func in tests:
print(f"\n{test_name}:")
print("-" * 40)
result = test_func()
results.append(result)
print("\n" + "=" * 60)
print("测试结果:")
print("=" * 60)
for i, (test_name, result) in enumerate(zip([t[0] for t in tests], results)):
status = "通过" if result else "失败"
print(f"{i+1}. {test_name}: {status}")
overall = all(results)
print(f"\n总体结果: {'所有测试通过' if overall else '存在失败的测试'}")
if overall:
print("\n✓ PharmacyPredictor 类已成功更新以支持多店铺训练!")
print("现在可以使用以下参数调用 train_model 方法:")
print("- store_id: 指定店铺ID")
print("- training_mode: 'product', 'store', 或 'global'")
print("- aggregation_method: 'sum', 'mean', 或 'median' (仅用于全局训练)")
else:
print("\n✗ 还有一些问题需要解决")
return overall
if __name__ == "__main__":
main()