141 lines
4.3 KiB
Python
141 lines
4.3 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
测试多店铺训练功能的简单脚本
|
|
"""
|
|
|
|
def test_method_signature():
|
|
"""测试方法签名是否正确"""
|
|
print("测试 PharmacyPredictor.train_model 方法签名...")
|
|
|
|
try:
|
|
# 读取文件内容
|
|
with open('server/core/predictor.py', 'r', encoding='utf-8') as f:
|
|
content = f.read()
|
|
|
|
# 检查方法签名
|
|
if 'def train_model(self, product_id, model_type=' in content:
|
|
# 查找完整的方法定义
|
|
start = content.find('def train_model(')
|
|
if start != -1:
|
|
end = content.find('"""', start + 100) # 找到docstring开始
|
|
method_def = content[start:end].strip()
|
|
|
|
print("找到方法定义:")
|
|
print(method_def)
|
|
|
|
# 检查参数
|
|
params_to_check = ['store_id', 'training_mode', 'aggregation_method']
|
|
results = {}
|
|
|
|
for param in params_to_check:
|
|
if param in method_def:
|
|
results[param] = True
|
|
print(f"✓ 参数 {param} 存在")
|
|
else:
|
|
results[param] = False
|
|
print(f"✗ 参数 {param} 不存在")
|
|
|
|
return all(results.values())
|
|
|
|
return False
|
|
|
|
except Exception as e:
|
|
print(f"错误: {e}")
|
|
return False
|
|
|
|
def test_data_file():
|
|
"""测试多店铺数据文件是否存在"""
|
|
print("\n测试多店铺数据文件...")
|
|
|
|
import os
|
|
|
|
files_to_check = [
|
|
'pharmacy_sales_multi_store.csv',
|
|
'server/utils/multi_store_data_utils.py'
|
|
]
|
|
|
|
results = []
|
|
for file_path in files_to_check:
|
|
if os.path.exists(file_path):
|
|
print(f"✓ 文件存在: {file_path}")
|
|
results.append(True)
|
|
else:
|
|
print(f"✗ 文件不存在: {file_path}")
|
|
results.append(False)
|
|
|
|
return all(results)
|
|
|
|
def test_import_structure():
|
|
"""测试导入结构是否正确"""
|
|
print("\n测试导入结构...")
|
|
|
|
try:
|
|
# 读取predictor.py文件
|
|
with open('server/core/predictor.py', 'r', encoding='utf-8') as f:
|
|
content = f.read()
|
|
|
|
# 检查多店铺工具函数的导入
|
|
imports_to_check = [
|
|
'from utils.multi_store_data_utils import',
|
|
'load_multi_store_data',
|
|
'get_store_product_sales_data',
|
|
'aggregate_multi_store_data'
|
|
]
|
|
|
|
results = []
|
|
for import_check in imports_to_check:
|
|
if import_check in content:
|
|
print(f"✓ 导入存在: {import_check}")
|
|
results.append(True)
|
|
else:
|
|
print(f"✗ 导入不存在: {import_check}")
|
|
results.append(False)
|
|
|
|
return all(results)
|
|
|
|
except Exception as e:
|
|
print(f"错误: {e}")
|
|
return False
|
|
|
|
def main():
|
|
print("=" * 60)
|
|
print("多店铺训练功能验证")
|
|
print("=" * 60)
|
|
|
|
tests = [
|
|
("方法签名检查", test_method_signature),
|
|
("数据文件检查", test_data_file),
|
|
("导入结构检查", test_import_structure)
|
|
]
|
|
|
|
results = []
|
|
for test_name, test_func in tests:
|
|
print(f"\n{test_name}:")
|
|
print("-" * 40)
|
|
result = test_func()
|
|
results.append(result)
|
|
|
|
print("\n" + "=" * 60)
|
|
print("测试结果:")
|
|
print("=" * 60)
|
|
|
|
for i, (test_name, result) in enumerate(zip([t[0] for t in tests], results)):
|
|
status = "通过" if result else "失败"
|
|
print(f"{i+1}. {test_name}: {status}")
|
|
|
|
overall = all(results)
|
|
print(f"\n总体结果: {'所有测试通过' if overall else '存在失败的测试'}")
|
|
|
|
if overall:
|
|
print("\n✓ PharmacyPredictor 类已成功更新以支持多店铺训练!")
|
|
print("现在可以使用以下参数调用 train_model 方法:")
|
|
print("- store_id: 指定店铺ID")
|
|
print("- training_mode: 'product', 'store', 或 'global'")
|
|
print("- aggregation_method: 'sum', 'mean', 或 'median' (仅用于全局训练)")
|
|
else:
|
|
print("\n✗ 还有一些问题需要解决")
|
|
|
|
return overall
|
|
|
|
if __name__ == "__main__":
|
|
main() |