数据更换调试
This commit is contained in:
parent
a2ce7659c9
commit
f72ea91cab
@ -12,7 +12,7 @@ import '@/assets/fonts.css'
|
||||
import '@/assets/element-theme.css'
|
||||
|
||||
// 配置axios基础URL
|
||||
axios.defaults.baseURL = 'http://127.0.0.1:5000'
|
||||
// axios.defaults.baseURL = 'http://127.0.0.1:5000'
|
||||
console.log('API基础URL已设置为:', axios.defaults.baseURL)
|
||||
|
||||
const app = createApp(App)
|
||||
|
180
Windows_快速启动.bat
180
Windows_快速启动.bat
@ -1,180 +0,0 @@
|
||||
@echo off
|
||||
chcp 65001 >nul
|
||||
echo ====================================
|
||||
echo 药店销售预测系统 - Windows 快速启动
|
||||
echo ====================================
|
||||
echo.
|
||||
|
||||
:: 检查Python
|
||||
echo [1/6] 检查Python环境...
|
||||
python --version >nul 2>&1
|
||||
if errorlevel 1 (
|
||||
echo ❌ 未找到Python,请先安装Python 3.8+
|
||||
pause
|
||||
exit /b 1
|
||||
)
|
||||
echo ✓ Python环境正常
|
||||
|
||||
:: 检查虚拟环境
|
||||
echo.
|
||||
echo [2/6] 检查虚拟环境...
|
||||
if not exist ".venv\Scripts\python.exe" (
|
||||
echo 🔄 创建虚拟环境...
|
||||
python -m venv .venv
|
||||
if errorlevel 1 (
|
||||
echo ❌ 虚拟环境创建失败
|
||||
pause
|
||||
exit /b 1
|
||||
)
|
||||
)
|
||||
echo ✓ 虚拟环境准备完成
|
||||
|
||||
:: 激活虚拟环境
|
||||
echo.
|
||||
echo [3/6] 激活虚拟环境...
|
||||
call .venv\Scripts\activate.bat
|
||||
if errorlevel 1 (
|
||||
echo ❌ 虚拟环境激活失败
|
||||
pause
|
||||
exit /b 1
|
||||
)
|
||||
echo ✓ 虚拟环境已激活
|
||||
|
||||
:: 安装依赖
|
||||
echo.
|
||||
echo [4/6] 检查Python依赖...
|
||||
pip show flask >nul 2>&1
|
||||
if errorlevel 1 (
|
||||
echo 🔄 安装Python依赖...
|
||||
pip install -r install\requirements.txt
|
||||
if errorlevel 1 (
|
||||
echo ❌ 依赖安装失败
|
||||
pause
|
||||
exit /b 1
|
||||
)
|
||||
)
|
||||
echo ✓ Python依赖已安装
|
||||
|
||||
:: 检查数据文件
|
||||
echo.
|
||||
echo [5/6] 检查数据文件...
|
||||
if not exist "pharmacy_sales_multi_store.csv" (
|
||||
echo 🔄 生成示例数据...
|
||||
python generate_multi_store_data.py
|
||||
if errorlevel 1 (
|
||||
echo ❌ 数据生成失败
|
||||
pause
|
||||
exit /b 1
|
||||
)
|
||||
)
|
||||
echo ✓ 数据文件准备完成
|
||||
|
||||
:: 初始化数据库
|
||||
echo.
|
||||
echo [6/6] 初始化数据库...
|
||||
if not exist "prediction_history.db" (
|
||||
echo 🔄 初始化数据库...
|
||||
python server\init_multi_store_db.py
|
||||
if errorlevel 1 (
|
||||
echo ❌ 数据库初始化失败
|
||||
pause
|
||||
exit /b 1
|
||||
)
|
||||
)
|
||||
echo ✓ 数据库准备完成
|
||||
|
||||
echo.
|
||||
echo ====================================
|
||||
echo ✅ 环境准备完成!
|
||||
echo ====================================
|
||||
echo.
|
||||
echo 接下来请选择启动方式:
|
||||
echo [1] 启动API服务器 (后端)
|
||||
echo [2] 启动前端开发服务器
|
||||
echo [3] 运行API测试
|
||||
echo [4] 查看项目状态
|
||||
echo [0] 退出
|
||||
echo.
|
||||
|
||||
:menu
|
||||
set /p choice="请选择 (0-4): "
|
||||
|
||||
if "%choice%"=="1" goto start_api
|
||||
if "%choice%"=="2" goto start_frontend
|
||||
if "%choice%"=="3" goto run_tests
|
||||
if "%choice%"=="4" goto show_status
|
||||
if "%choice%"=="0" goto end
|
||||
echo 无效选择,请重新输入
|
||||
goto menu
|
||||
|
||||
:start_api
|
||||
echo.
|
||||
echo 🚀 启动API服务器...
|
||||
echo 服务器将在 http://localhost:5000 启动
|
||||
echo API文档访问: http://localhost:5000/swagger
|
||||
echo.
|
||||
echo 按 Ctrl+C 停止服务器
|
||||
echo.
|
||||
cd server
|
||||
python api.py
|
||||
goto end
|
||||
|
||||
:start_frontend
|
||||
echo.
|
||||
echo 🚀 启动前端开发服务器...
|
||||
cd UI
|
||||
if not exist "node_modules" (
|
||||
echo 🔄 安装前端依赖...
|
||||
npm install
|
||||
if errorlevel 1 (
|
||||
echo ❌ 前端依赖安装失败
|
||||
pause
|
||||
goto menu
|
||||
)
|
||||
)
|
||||
echo 前端将在 http://localhost:5173 启动
|
||||
echo.
|
||||
npm run dev
|
||||
goto end
|
||||
|
||||
:run_tests
|
||||
echo.
|
||||
echo 🧪 运行API测试...
|
||||
python test_api_endpoints.py
|
||||
echo.
|
||||
pause
|
||||
goto menu
|
||||
|
||||
:show_status
|
||||
echo.
|
||||
echo 📊 项目状态检查...
|
||||
echo.
|
||||
echo === 文件检查 ===
|
||||
if exist "pharmacy_sales_multi_store.csv" (echo ✓ 多店铺数据文件) else (echo ❌ 多店铺数据文件缺失)
|
||||
if exist "prediction_history.db" (echo ✓ 预测历史数据库) else (echo ❌ 预测历史数据库缺失)
|
||||
if exist "server\api.py" (echo ✓ API服务器文件) else (echo ❌ API服务器文件缺失)
|
||||
if exist "UI\package.json" (echo ✓ 前端项目文件) else (echo ❌ 前端项目文件缺失)
|
||||
|
||||
echo.
|
||||
echo === 模型文件 ===
|
||||
if exist "saved_models" (
|
||||
echo 已保存的模型:
|
||||
dir saved_models\*.pth /b 2>nul || echo 暂无已训练的模型
|
||||
) else (
|
||||
echo ❌ 模型目录不存在
|
||||
)
|
||||
|
||||
echo.
|
||||
echo === 虚拟环境状态 ===
|
||||
python -c "import sys; print('Python版本:', sys.version)"
|
||||
python -c "import flask; print('Flask版本:', flask.__version__)" 2>nul || echo ❌ Flask未安装
|
||||
|
||||
echo.
|
||||
pause
|
||||
goto menu
|
||||
|
||||
:end
|
||||
echo.
|
||||
echo 感谢使用药店销售预测系统!
|
||||
echo.
|
||||
pause
|
@ -107,14 +107,18 @@ def generate_multi_store_sales_data():
|
||||
df = pd.DataFrame(sales_data)
|
||||
|
||||
# 保存到CSV文件
|
||||
df.to_csv('pharmacy_sales_multi_store.csv', index=False, encoding='utf-8')
|
||||
output_dir = 'data'
|
||||
if not os.path.exists(output_dir):
|
||||
os.makedirs(output_dir)
|
||||
output_path = os.path.join(output_dir, 'timeseries_training_data_sample_10s50p.parquet')
|
||||
df.to_parquet(output_path, index=False)
|
||||
|
||||
print(f"多店铺销售数据生成完成!")
|
||||
print(f"数据记录数: {len(df)}")
|
||||
print(f"日期范围: {df['date'].min()} 到 {df['date'].max()}")
|
||||
print(f"店铺数量: {df['store_id'].nunique()}")
|
||||
print(f"产品数量: {df['product_id'].nunique()}")
|
||||
print(f"文件保存为: pharmacy_sales_multi_store.csv")
|
||||
print(f"文件保存为: {output_path}")
|
||||
|
||||
# 显示数据样本
|
||||
print("\n数据样本:")
|
||||
|
File diff suppressed because it is too large
Load Diff
Binary file not shown.
Binary file not shown.
@ -47,12 +47,14 @@ simple-websocket==1.1.0
|
||||
six==1.17.0
|
||||
sympy==1.13.3
|
||||
threadpoolctl==3.6.0
|
||||
torch==2.7.1+cu128
|
||||
torchaudio==2.7.1+cu128
|
||||
torchvision==0.22.1+cu128
|
||||
torch==2.7.1
|
||||
torchaudio==2.7.1
|
||||
torchvision==0.22.1
|
||||
tqdm==4.67.1
|
||||
typing-extensions==4.12.2
|
||||
tzdata==2025.2
|
||||
werkzeug==3.1.3
|
||||
win32-setctime==1.2.0
|
||||
wsproto==1.2.0
|
||||
|
||||
pyarrow
|
||||
|
Binary file not shown.
@ -57,7 +57,7 @@ from analysis.metrics import evaluate_model, compare_models
|
||||
from core.config import (
|
||||
DEFAULT_MODEL_DIR, WEBSOCKET_NAMESPACE,
|
||||
get_model_versions, get_latest_model_version, get_next_model_version,
|
||||
get_model_file_path, save_model_version_info
|
||||
get_model_file_path, save_model_version_info, DEFAULT_DATA_PATH
|
||||
)
|
||||
|
||||
# 导入多店铺数据工具
|
||||
@ -160,7 +160,7 @@ def train_store_model(store_id, model_type, epochs=50, product_scope='all', prod
|
||||
|
||||
# 读取店铺所有数据,找到第一个有数据的药品
|
||||
try:
|
||||
df = pd.read_csv('pharmacy_sales_multi_store.csv')
|
||||
df = pd.read_parquet(DEFAULT_DATA_PATH)
|
||||
store_products = df[df['store_id'] == store_id]['product_id'].unique()
|
||||
|
||||
if len(store_products) == 0:
|
||||
@ -207,7 +207,7 @@ def train_global_model(model_type, epochs=50, training_scope='all_stores_all_pro
|
||||
import pandas as pd
|
||||
|
||||
# 读取数据
|
||||
df = pd.read_csv('pharmacy_sales_multi_store.csv')
|
||||
df = pd.read_parquet(DEFAULT_DATA_PATH)
|
||||
|
||||
# 根据训练范围过滤数据
|
||||
if training_scope == 'selected_stores' and store_ids:
|
||||
@ -631,7 +631,7 @@ def swagger_ui():
|
||||
def get_products():
|
||||
try:
|
||||
from utils.multi_store_data_utils import get_available_products
|
||||
products = get_available_products('pharmacy_sales_multi_store.csv')
|
||||
products = get_available_products(DEFAULT_DATA_PATH)
|
||||
return jsonify({"status": "success", "data": products})
|
||||
except Exception as e:
|
||||
return jsonify({"status": "error", "message": str(e)}), 500
|
||||
@ -686,7 +686,7 @@ def get_products():
|
||||
def get_product(product_id):
|
||||
try:
|
||||
from utils.multi_store_data_utils import load_multi_store_data
|
||||
df = load_multi_store_data('pharmacy_sales_multi_store.csv', product_id=product_id)
|
||||
df = load_multi_store_data(DEFAULT_DATA_PATH, product_id=product_id)
|
||||
|
||||
if df.empty:
|
||||
return jsonify({"status": "error", "message": "产品不存在"}), 404
|
||||
@ -764,7 +764,7 @@ def get_product_sales(product_id):
|
||||
|
||||
from utils.multi_store_data_utils import load_multi_store_data
|
||||
df = load_multi_store_data(
|
||||
'pharmacy_sales_multi_store.csv',
|
||||
DEFAULT_DATA_PATH,
|
||||
product_id=product_id,
|
||||
start_date=start_date,
|
||||
end_date=end_date
|
||||
@ -1713,7 +1713,7 @@ def compare_predictions():
|
||||
predictor = PharmacyPredictor()
|
||||
|
||||
# 获取产品名称
|
||||
df = pd.read_excel('pharmacy_sales.xlsx')
|
||||
df = pd.read_parquet(DEFAULT_DATA_PATH)
|
||||
product_df = df[df['product_id'] == product_id]
|
||||
|
||||
if product_df.empty:
|
||||
@ -1868,7 +1868,7 @@ def analyze_prediction():
|
||||
predictions_array = np.array(predictions)
|
||||
|
||||
# 获取产品特征数据
|
||||
df = pd.read_excel('pharmacy_sales.xlsx')
|
||||
df = pd.read_parquet(DEFAULT_DATA_PATH)
|
||||
product_df = df[df['product_id'] == product_id].sort_values('date')
|
||||
|
||||
if product_df.empty:
|
||||
@ -2689,7 +2689,7 @@ def get_product_name(product_id):
|
||||
"""根据产品ID获取产品名称"""
|
||||
try:
|
||||
# 从Excel文件中查找产品名称
|
||||
df = pd.read_excel('pharmacy_sales.xlsx')
|
||||
df = pd.read_parquet(DEFAULT_DATA_PATH)
|
||||
product_df = df[df['product_id'] == product_id]
|
||||
if not product_df.empty:
|
||||
return product_df['product_name'].iloc[0]
|
||||
@ -2750,7 +2750,7 @@ def run_prediction(model_type, product_id, model_id, future_days, start_date, ve
|
||||
# 获取历史数据用于对比
|
||||
try:
|
||||
# 读取原始数据
|
||||
df = pd.read_excel('pharmacy_sales.xlsx')
|
||||
df = pd.read_parquet(DEFAULT_DATA_PATH)
|
||||
product_df = df[df['product_id'] == product_id].copy()
|
||||
|
||||
if not product_df.empty:
|
||||
@ -4026,7 +4026,7 @@ def get_stores():
|
||||
"""
|
||||
try:
|
||||
from utils.multi_store_data_utils import get_available_stores
|
||||
stores = get_available_stores('pharmacy_sales_multi_store.csv')
|
||||
stores = get_available_stores(DEFAULT_DATA_PATH)
|
||||
|
||||
return jsonify({
|
||||
"status": "success",
|
||||
@ -4046,7 +4046,7 @@ def get_store(store_id):
|
||||
"""
|
||||
try:
|
||||
from utils.multi_store_data_utils import get_available_stores
|
||||
stores = get_available_stores('pharmacy_sales_multi_store.csv')
|
||||
stores = get_available_stores(DEFAULT_DATA_PATH)
|
||||
|
||||
store = None
|
||||
for s in stores:
|
||||
@ -4282,7 +4282,7 @@ def get_global_training_stats():
|
||||
import pandas as pd
|
||||
|
||||
# 读取数据
|
||||
df = pd.read_csv('pharmacy_sales_multi_store.csv')
|
||||
df = pd.read_parquet(DEFAULT_DATA_PATH)
|
||||
|
||||
# 根据训练范围过滤数据
|
||||
if training_scope == 'selected_stores' and store_ids:
|
||||
@ -4340,6 +4340,7 @@ def get_sales_data():
|
||||
"""
|
||||
获取销售数据列表,支持分页和过滤
|
||||
"""
|
||||
logger.info("\n[DEBUG-API] ===> Entering get_sales_data endpoint.")
|
||||
try:
|
||||
# 获取查询参数
|
||||
store_id = request.args.get('store_id')
|
||||
@ -4349,6 +4350,8 @@ def get_sales_data():
|
||||
page = int(request.args.get('page', 1))
|
||||
page_size = int(request.args.get('page_size', 20))
|
||||
|
||||
logger.info(f"[DEBUG-API] Request Params: store_id={store_id}, product_id={product_id}, start_date={start_date}, end_date={end_date}, page={page}, page_size={page_size}")
|
||||
|
||||
# 验证参数
|
||||
if page < 1:
|
||||
page = 1
|
||||
@ -4358,16 +4361,19 @@ def get_sales_data():
|
||||
# 使用多店铺数据工具加载数据
|
||||
from utils.multi_store_data_utils import load_multi_store_data, get_sales_statistics
|
||||
|
||||
logger.info("[DEBUG-API] Calling load_multi_store_data...")
|
||||
# 加载过滤后的数据
|
||||
df = load_multi_store_data(
|
||||
'pharmacy_sales_multi_store.csv',
|
||||
DEFAULT_DATA_PATH,
|
||||
store_id=store_id,
|
||||
product_id=product_id,
|
||||
start_date=start_date,
|
||||
end_date=end_date
|
||||
)
|
||||
logger.info(f"[DEBUG-API] load_multi_store_data returned DataFrame with shape: {df.shape}")
|
||||
|
||||
if df.empty:
|
||||
logger.info("[DEBUG-API] DataFrame is empty after loading/filtering. Returning empty success response.")
|
||||
return jsonify({
|
||||
"status": "success",
|
||||
"data": [],
|
||||
@ -4382,11 +4388,13 @@ def get_sales_data():
|
||||
|
||||
# 计算总数
|
||||
total_records = len(df)
|
||||
logger.info(f"[DEBUG-API] Total records found: {total_records}")
|
||||
|
||||
# 分页处理
|
||||
start_idx = (page - 1) * page_size
|
||||
end_idx = start_idx + page_size
|
||||
paginated_df = df.iloc[start_idx:end_idx]
|
||||
logger.info(f"[DEBUG-API] Paginated DataFrame shape: {paginated_df.shape}")
|
||||
|
||||
# 转换为字典列表
|
||||
data = []
|
||||
@ -4419,16 +4427,22 @@ def get_sales_data():
|
||||
}
|
||||
}
|
||||
|
||||
return jsonify({
|
||||
response_payload = {
|
||||
"status": "success",
|
||||
"data": data,
|
||||
"total": total_records,
|
||||
"page": page,
|
||||
"page_size": page_size,
|
||||
"statistics": statistics
|
||||
})
|
||||
}
|
||||
logger.info(f"[DEBUG-API] Preparing to send response. Data length: {len(data)}, Total records: {total_records}")
|
||||
logger.info("[DEBUG-API] <=== Exiting get_sales_data endpoint successfully.")
|
||||
return jsonify(response_payload)
|
||||
|
||||
except Exception as e:
|
||||
import traceback
|
||||
logger.error("[DEBUG-API] !!! An error occurred in get_sales_data endpoint.")
|
||||
logger.error(traceback.format_exc())
|
||||
return jsonify({
|
||||
"status": "error",
|
||||
"message": f"获取销售数据失败: {str(e)}"
|
||||
@ -4523,7 +4537,8 @@ if __name__ == '__main__':
|
||||
port=args.port,
|
||||
debug=args.debug,
|
||||
use_reloader=False, # 关闭重载器避免冲突
|
||||
log_output=True
|
||||
log_output=True,
|
||||
allow_unsafe_werkzeug=True
|
||||
)
|
||||
finally:
|
||||
# 确保在退出时停止训练进程管理器
|
||||
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -26,7 +26,7 @@ def get_device():
|
||||
DEVICE = get_device()
|
||||
|
||||
# 数据相关配置
|
||||
DEFAULT_DATA_PATH = 'pharmacy_sales.xlsx'
|
||||
DEFAULT_DATA_PATH = 'data/timeseries_training_data_sample_10s50p.parquet'
|
||||
DEFAULT_MODEL_DIR = 'saved_models'
|
||||
DEFAULT_FEATURES = ['sales', 'price', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']
|
||||
|
||||
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -19,7 +19,7 @@ from models.optimized_kan_forecaster import OptimizedKANForecaster
|
||||
from utils.data_utils import create_dataset, PharmacyDataset
|
||||
from utils.visualization import plot_loss_curve
|
||||
from analysis.metrics import evaluate_model
|
||||
from core.config import DEVICE, DEFAULT_MODEL_DIR, LOOK_BACK, FORECAST_HORIZON
|
||||
from core.config import DEVICE, DEFAULT_MODEL_DIR, LOOK_BACK, FORECAST_HORIZON, DEFAULT_DATA_PATH
|
||||
|
||||
def train_product_model_with_kan(product_id, store_id=None, training_mode='product', aggregation_method='sum', epochs=50, use_optimized=False, model_dir=DEFAULT_MODEL_DIR):
|
||||
"""
|
||||
@ -44,7 +44,7 @@ def train_product_model_with_kan(product_id, store_id=None, training_mode='produ
|
||||
product_df = get_store_product_sales_data(
|
||||
store_id,
|
||||
product_id,
|
||||
'pharmacy_sales_multi_store.csv'
|
||||
DEFAULT_DATA_PATH
|
||||
)
|
||||
training_scope = f"店铺 {store_id}"
|
||||
elif training_mode == 'global':
|
||||
@ -52,17 +52,17 @@ def train_product_model_with_kan(product_id, store_id=None, training_mode='produ
|
||||
product_df = aggregate_multi_store_data(
|
||||
product_id,
|
||||
aggregation_method=aggregation_method,
|
||||
file_path='pharmacy_sales_multi_store.csv'
|
||||
file_path=DEFAULT_DATA_PATH
|
||||
)
|
||||
training_scope = f"全局聚合({aggregation_method})"
|
||||
else:
|
||||
# 默认:加载所有店铺的产品数据
|
||||
product_df = load_multi_store_data('pharmacy_sales_multi_store.csv', product_id=product_id)
|
||||
product_df = load_multi_store_data(DEFAULT_DATA_PATH, product_id=product_id)
|
||||
training_scope = "所有店铺"
|
||||
except Exception as e:
|
||||
print(f"多店铺数据加载失败: {e}")
|
||||
# 后备方案:尝试原始数据
|
||||
df = pd.read_excel('pharmacy_sales.xlsx')
|
||||
df = pd.read_parquet(DEFAULT_DATA_PATH)
|
||||
product_df = df[df['product_id'] == product_id].sort_values('date')
|
||||
training_scope = "原始数据"
|
||||
|
||||
|
@ -20,7 +20,7 @@ from utils.multi_store_data_utils import get_store_product_sales_data, aggregate
|
||||
from utils.visualization import plot_loss_curve
|
||||
from analysis.metrics import evaluate_model
|
||||
from core.config import (
|
||||
DEVICE, DEFAULT_MODEL_DIR, LOOK_BACK, FORECAST_HORIZON,
|
||||
DEVICE, DEFAULT_MODEL_DIR, LOOK_BACK, FORECAST_HORIZON, DEFAULT_DATA_PATH,
|
||||
get_next_model_version, get_model_file_path, get_latest_model_version
|
||||
)
|
||||
from utils.training_progress import progress_manager
|
||||
@ -212,7 +212,7 @@ def train_product_model_with_mlstm(
|
||||
product_df = get_store_product_sales_data(
|
||||
store_id,
|
||||
product_id,
|
||||
'pharmacy_sales_multi_store.csv'
|
||||
DEFAULT_DATA_PATH
|
||||
)
|
||||
training_scope = f"店铺 {store_id}"
|
||||
elif training_mode == 'global':
|
||||
@ -220,17 +220,17 @@ def train_product_model_with_mlstm(
|
||||
product_df = aggregate_multi_store_data(
|
||||
product_id,
|
||||
aggregation_method=aggregation_method,
|
||||
file_path='pharmacy_sales_multi_store.csv'
|
||||
file_path=DEFAULT_DATA_PATH
|
||||
)
|
||||
training_scope = f"全局聚合({aggregation_method})"
|
||||
else:
|
||||
# 默认:加载所有店铺的产品数据
|
||||
product_df = load_multi_store_data('pharmacy_sales_multi_store.csv', product_id=product_id)
|
||||
product_df = load_multi_store_data(DEFAULT_DATA_PATH, product_id=product_id)
|
||||
training_scope = "所有店铺"
|
||||
except Exception as e:
|
||||
print(f"多店铺数据加载失败: {e}")
|
||||
# 后备方案:尝试原始数据
|
||||
df = pd.read_excel('pharmacy_sales.xlsx')
|
||||
df = pd.read_parquet(DEFAULT_DATA_PATH)
|
||||
product_df = df[df['product_id'] == product_id].sort_values(by='date')
|
||||
training_scope = "原始数据"
|
||||
|
||||
|
@ -18,7 +18,7 @@ from models.tcn_model import TCNForecaster
|
||||
from utils.data_utils import create_dataset, PharmacyDataset
|
||||
from utils.visualization import plot_loss_curve
|
||||
from analysis.metrics import evaluate_model
|
||||
from core.config import DEVICE, DEFAULT_MODEL_DIR, LOOK_BACK, FORECAST_HORIZON
|
||||
from core.config import DEVICE, DEFAULT_MODEL_DIR, LOOK_BACK, FORECAST_HORIZON, DEFAULT_DATA_PATH
|
||||
from utils.training_progress import progress_manager
|
||||
|
||||
def save_checkpoint(checkpoint_data: dict, epoch_or_label, product_id: str,
|
||||
@ -123,7 +123,7 @@ def train_product_model_with_tcn(
|
||||
product_df = get_store_product_sales_data(
|
||||
store_id,
|
||||
product_id,
|
||||
'pharmacy_sales_multi_store.csv'
|
||||
DEFAULT_DATA_PATH
|
||||
)
|
||||
training_scope = f"店铺 {store_id}"
|
||||
elif training_mode == 'global':
|
||||
@ -131,17 +131,17 @@ def train_product_model_with_tcn(
|
||||
product_df = aggregate_multi_store_data(
|
||||
product_id,
|
||||
aggregation_method=aggregation_method,
|
||||
file_path='pharmacy_sales_multi_store.csv'
|
||||
file_path=DEFAULT_DATA_PATH
|
||||
)
|
||||
training_scope = f"全局聚合({aggregation_method})"
|
||||
else:
|
||||
# 默认:加载所有店铺的产品数据
|
||||
product_df = load_multi_store_data('pharmacy_sales_multi_store.csv', product_id=product_id)
|
||||
product_df = load_multi_store_data(DEFAULT_DATA_PATH, product_id=product_id)
|
||||
training_scope = "所有店铺"
|
||||
except Exception as e:
|
||||
print(f"多店铺数据加载失败: {e}")
|
||||
# 后备方案:尝试原始数据
|
||||
df = pd.read_excel('pharmacy_sales.xlsx')
|
||||
df = pd.read_parquet(DEFAULT_DATA_PATH)
|
||||
product_df = df[df['product_id'] == product_id].sort_values('date')
|
||||
training_scope = "原始数据"
|
||||
|
||||
|
@ -21,7 +21,7 @@ from utils.multi_store_data_utils import get_store_product_sales_data, aggregate
|
||||
from utils.visualization import plot_loss_curve
|
||||
from analysis.metrics import evaluate_model
|
||||
from core.config import (
|
||||
DEVICE, DEFAULT_MODEL_DIR, LOOK_BACK, FORECAST_HORIZON,
|
||||
DEVICE, DEFAULT_MODEL_DIR, LOOK_BACK, FORECAST_HORIZON, DEFAULT_DATA_PATH,
|
||||
get_next_model_version, get_model_file_path, get_latest_model_version
|
||||
)
|
||||
from utils.training_progress import progress_manager
|
||||
@ -138,7 +138,7 @@ def train_product_model_with_transformer(
|
||||
product_df = get_store_product_sales_data(
|
||||
store_id,
|
||||
product_id,
|
||||
'pharmacy_sales_multi_store.csv'
|
||||
DEFAULT_DATA_PATH
|
||||
)
|
||||
training_scope = f"店铺 {store_id}"
|
||||
elif training_mode == 'global':
|
||||
@ -146,17 +146,17 @@ def train_product_model_with_transformer(
|
||||
product_df = aggregate_multi_store_data(
|
||||
product_id,
|
||||
aggregation_method=aggregation_method,
|
||||
file_path='pharmacy_sales_multi_store.csv'
|
||||
file_path=DEFAULT_DATA_PATH
|
||||
)
|
||||
training_scope = f"全局聚合({aggregation_method})"
|
||||
else:
|
||||
# 默认:加载所有店铺的产品数据
|
||||
product_df = load_multi_store_data('pharmacy_sales_multi_store.csv', product_id=product_id)
|
||||
product_df = load_multi_store_data(DEFAULT_DATA_PATH, product_id=product_id)
|
||||
training_scope = "所有店铺"
|
||||
except Exception as e:
|
||||
print(f"多店铺数据加载失败: {e}")
|
||||
# 后备方案:尝试原始数据
|
||||
df = pd.read_excel('pharmacy_sales.xlsx')
|
||||
df = pd.read_parquet(DEFAULT_DATA_PATH)
|
||||
product_df = df[df['product_id'] == product_id].sort_values('date')
|
||||
training_scope = "原始数据"
|
||||
|
||||
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -9,7 +9,13 @@ import os
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, List, Tuple, Dict, Any
|
||||
|
||||
def load_multi_store_data(file_path: str = 'pharmacy_sales_multi_store.csv',
|
||||
import logging
|
||||
|
||||
# 配置日志
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def load_multi_store_data(file_path: str = 'data/timeseries_training_data_sample_10s50p.parquet',
|
||||
store_id: Optional[str] = None,
|
||||
product_id: Optional[str] = None,
|
||||
start_date: Optional[str] = None,
|
||||
@ -27,72 +33,77 @@ def load_multi_store_data(file_path: str = 'pharmacy_sales_multi_store.csv',
|
||||
返回:
|
||||
DataFrame: 过滤后的销售数据
|
||||
"""
|
||||
logger.info("\n[DEBUG-UTIL] ---> Entering load_multi_store_data function.")
|
||||
logger.info(f"[DEBUG-UTIL] Initial params: file_path={file_path}, store_id={store_id}, product_id={product_id}, start_date={start_date}, end_date={end_date}")
|
||||
|
||||
# 尝试多个可能的文件路径
|
||||
possible_paths = [
|
||||
file_path,
|
||||
f'../{file_path}',
|
||||
f'server/{file_path}',
|
||||
'pharmacy_sales_multi_store.csv',
|
||||
'../pharmacy_sales_multi_store.csv',
|
||||
'pharmacy_sales.xlsx', # 后向兼容原始文件
|
||||
'../pharmacy_sales.xlsx'
|
||||
]
|
||||
try:
|
||||
# 构建数据文件的绝对路径,使其独立于当前工作目录
|
||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
project_root = os.path.dirname(os.path.dirname(script_dir))
|
||||
absolute_file_path = os.path.join(project_root, file_path)
|
||||
|
||||
logger.info(f"[DEBUG-UTIL] Constructed absolute path to data file: {absolute_file_path}")
|
||||
|
||||
if not os.path.exists(absolute_file_path):
|
||||
logger.error(f"[DEBUG-UTIL] Data file not found at absolute path: {absolute_file_path}")
|
||||
error_message = f"V3-FIX: 无法找到数据文件: {absolute_file_path}"
|
||||
raise FileNotFoundError(error_message)
|
||||
|
||||
logger.info(f"[DEBUG-UTIL] Attempting to load file from: {absolute_file_path}")
|
||||
if absolute_file_path.endswith('.parquet'):
|
||||
df = pd.read_parquet(absolute_file_path)
|
||||
elif absolute_file_path.endswith('.csv'):
|
||||
df = pd.read_csv(absolute_file_path)
|
||||
elif absolute_file_path.endswith('.xlsx'):
|
||||
df = pd.read_excel(absolute_file_path)
|
||||
else:
|
||||
raise ValueError(f"不支持的文件格式: {absolute_file_path}")
|
||||
|
||||
logger.info(f"[DEBUG-UTIL] Successfully loaded data file. Shape: {df.shape}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[DEBUG-UTIL] !!! An error occurred during file loading: {e}")
|
||||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
raise e
|
||||
|
||||
df = None
|
||||
for path in possible_paths:
|
||||
try:
|
||||
if path.endswith('.csv'):
|
||||
df = pd.read_csv(path)
|
||||
elif path.endswith('.xlsx'):
|
||||
df = pd.read_excel(path)
|
||||
# 为原始Excel文件添加默认店铺信息
|
||||
if 'store_id' not in df.columns:
|
||||
df['store_id'] = 'S001'
|
||||
df['store_name'] = '默认店铺'
|
||||
df['store_location'] = '未知位置'
|
||||
df['store_type'] = 'standard'
|
||||
|
||||
if df is not None:
|
||||
print(f"成功加载数据文件: {path}")
|
||||
break
|
||||
except Exception as e:
|
||||
continue
|
||||
|
||||
if df is None:
|
||||
raise FileNotFoundError(f"无法找到数据文件,尝试的路径: {possible_paths}")
|
||||
logger.info(f"[DEBUG-UTIL] Initial DataFrame columns: {df.columns.tolist()}")
|
||||
|
||||
# 确保date列是datetime类型
|
||||
if 'date' in df.columns:
|
||||
df['date'] = pd.to_datetime(df['date'])
|
||||
logger.info("[DEBUG-UTIL] Converted 'date' column to datetime objects.")
|
||||
|
||||
# 按店铺过滤
|
||||
if store_id:
|
||||
df = df[df['store_id'] == store_id].copy()
|
||||
print(f"按店铺过滤: {store_id}, 剩余记录数: {len(df)}")
|
||||
logger.info(f"[DEBUG-UTIL] Filtered by store_id='{store_id}'. Records remaining: {len(df)}")
|
||||
|
||||
# 按产品过滤
|
||||
if product_id:
|
||||
df = df[df['product_id'] == product_id].copy()
|
||||
print(f"按产品过滤: {product_id}, 剩余记录数: {len(df)}")
|
||||
logger.info(f"[DEBUG-UTIL] Filtered by product_id='{product_id}'. Records remaining: {len(df)}")
|
||||
|
||||
# 按时间范围过滤
|
||||
if start_date:
|
||||
start_date = pd.to_datetime(start_date)
|
||||
df = df[df['date'] >= start_date].copy()
|
||||
print(f"开始日期过滤: {start_date}, 剩余记录数: {len(df)}")
|
||||
start_date_dt = pd.to_datetime(start_date)
|
||||
df = df[df['date'] >= start_date_dt].copy()
|
||||
logger.info(f"[DEBUG-UTIL] Filtered by start_date>='{start_date}'. Records remaining: {len(df)}")
|
||||
|
||||
if end_date:
|
||||
end_date = pd.to_datetime(end_date)
|
||||
df = df[df['date'] <= end_date].copy()
|
||||
print(f"结束日期过滤: {end_date}, 剩余记录数: {len(df)}")
|
||||
end_date_dt = pd.to_datetime(end_date)
|
||||
df = df[df['date'] <= end_date_dt].copy()
|
||||
logger.info(f"[DEBUG-UTIL] Filtered by end_date<='{end_date}'. Records remaining: {len(df)}")
|
||||
|
||||
if len(df) == 0:
|
||||
print("警告: 过滤后没有数据")
|
||||
logger.warning("[DEBUG-UTIL] Warning: DataFrame is empty after filtering.")
|
||||
|
||||
# 标准化列名以匹配训练代码期望的格式
|
||||
logger.info("[DEBUG-UTIL] Calling standardize_column_names...")
|
||||
df = standardize_column_names(df)
|
||||
logger.info(f"[DEBUG-UTIL] DataFrame columns after standardization: {df.columns.tolist()}")
|
||||
|
||||
logger.info("[DEBUG-UTIL] <--- Exiting load_multi_store_data function.")
|
||||
return df
|
||||
|
||||
def standardize_column_names(df: pd.DataFrame) -> pd.DataFrame:
|
||||
@ -109,7 +120,7 @@ def standardize_column_names(df: pd.DataFrame) -> pd.DataFrame:
|
||||
|
||||
# 列名映射:新列名 -> 原列名
|
||||
column_mapping = {
|
||||
'sales': 'quantity_sold', # 销售数量
|
||||
'quantity_sold': 'sales_quantity', # 销售数量
|
||||
'price': 'unit_price', # 单价
|
||||
'weekday': 'day_of_week' # 星期几
|
||||
}
|
||||
@ -170,7 +181,7 @@ def standardize_column_names(df: pd.DataFrame) -> pd.DataFrame:
|
||||
|
||||
return df
|
||||
|
||||
def get_available_stores(file_path: str = 'pharmacy_sales_multi_store.csv') -> List[Dict[str, Any]]:
|
||||
def get_available_stores(file_path: str = 'data/timeseries_training_data_sample_10s50p.parquet') -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取可用的店铺列表
|
||||
|
||||
@ -191,7 +202,7 @@ def get_available_stores(file_path: str = 'pharmacy_sales_multi_store.csv') -> L
|
||||
print(f"获取店铺列表失败: {e}")
|
||||
return []
|
||||
|
||||
def get_available_products(file_path: str = 'pharmacy_sales_multi_store.csv',
|
||||
def get_available_products(file_path: str = 'data/timeseries_training_data_sample_10s50p.parquet',
|
||||
store_id: Optional[str] = None) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取可用的产品列表
|
||||
@ -222,7 +233,7 @@ def get_available_products(file_path: str = 'pharmacy_sales_multi_store.csv',
|
||||
|
||||
def get_store_product_sales_data(store_id: str,
|
||||
product_id: str,
|
||||
file_path: str = 'pharmacy_sales_multi_store.csv') -> pd.DataFrame:
|
||||
file_path: str = 'data/timeseries_training_data_sample_10s50p.parquet') -> pd.DataFrame:
|
||||
"""
|
||||
获取特定店铺和产品的销售数据,用于模型训练
|
||||
|
||||
@ -256,7 +267,7 @@ def get_store_product_sales_data(store_id: str,
|
||||
|
||||
def aggregate_multi_store_data(product_id: str,
|
||||
aggregation_method: str = 'sum',
|
||||
file_path: str = 'pharmacy_sales_multi_store.csv') -> pd.DataFrame:
|
||||
file_path: str = 'data/timeseries_training_data_sample_10s50p.parquet') -> pd.DataFrame:
|
||||
"""
|
||||
聚合多个店铺的销售数据,用于全局模型训练
|
||||
|
||||
@ -319,7 +330,7 @@ def aggregate_multi_store_data(product_id: str,
|
||||
|
||||
return aggregated_df
|
||||
|
||||
def get_sales_statistics(file_path: str = 'pharmacy_sales_multi_store.csv',
|
||||
def get_sales_statistics(file_path: str = 'data/timeseries_training_data_sample_10s50p.parquet',
|
||||
store_id: Optional[str] = None,
|
||||
product_id: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
@ -358,7 +369,7 @@ def get_sales_statistics(file_path: str = 'pharmacy_sales_multi_store.csv',
|
||||
return {'error': str(e)}
|
||||
|
||||
# 向后兼容的函数
|
||||
def load_data(file_path='pharmacy_sales.xlsx', store_id=None):
|
||||
def load_data(file_path='data/timeseries_training_data_sample_10s50p.parquet', store_id=None):
|
||||
"""
|
||||
向后兼容的数据加载函数
|
||||
"""
|
||||
|
@ -16,8 +16,8 @@ def start_api_debug():
|
||||
|
||||
# 启动命令
|
||||
cmd = [
|
||||
sys.executable,
|
||||
"./server/api.py",
|
||||
"uv", "run",
|
||||
"./server/api.py",
|
||||
"--debug",
|
||||
"--host", "0.0.0.0",
|
||||
"--port", "5000"
|
||||
|
4
temp_check_parquet.py
Normal file
4
temp_check_parquet.py
Normal file
@ -0,0 +1,4 @@
|
||||
import pandas as pd
|
||||
|
||||
df = pd.read_parquet('data/timeseries_training_data_sample_10s50p.parquet')
|
||||
print(df.columns)
|
@ -1,41 +0,0 @@
|
||||
@echo off
|
||||
chcp 65001 >nul 2>&1
|
||||
echo 🚀 启动药店销售预测系统API服务器 (WebSocket修复版)
|
||||
echo.
|
||||
|
||||
:: 设置编码环境变量
|
||||
set PYTHONIOENCODING=utf-8
|
||||
set PYTHONLEGACYWINDOWSSTDIO=0
|
||||
|
||||
:: 显示当前配置
|
||||
echo 📋 当前环境配置:
|
||||
echo 编码: UTF-8
|
||||
echo 路径: %CD%
|
||||
echo Python: uv管理
|
||||
echo.
|
||||
|
||||
:: 检查依赖
|
||||
echo 🔍 检查Python依赖...
|
||||
uv list --quiet >nul 2>&1
|
||||
if errorlevel 1 (
|
||||
echo ⚠️ UV环境未配置,正在初始化...
|
||||
uv sync
|
||||
)
|
||||
|
||||
echo ✅ 依赖检查完成
|
||||
echo.
|
||||
|
||||
:: 启动API服务器
|
||||
echo 🌐 启动API服务器 (WebSocket支持)...
|
||||
echo 💡 访问地址: http://localhost:5000
|
||||
echo 🔗 WebSocket端点: ws://localhost:5000/socket.io
|
||||
echo.
|
||||
echo 📝 启动日志:
|
||||
echo ----------------------------------------
|
||||
|
||||
uv run server/api.py --host 0.0.0.0 --port 5000
|
||||
|
||||
echo.
|
||||
echo ----------------------------------------
|
||||
echo 🛑 API服务器已停止
|
||||
pause
|
11
启动API服务器.bat
11
启动API服务器.bat
@ -1,11 +0,0 @@
|
||||
@echo off
|
||||
chcp 65001 >nul 2>&1
|
||||
set PYTHONIOENCODING=utf-8
|
||||
set PYTHONLEGACYWINDOWSSTDIO=0
|
||||
cd /d %~dp0
|
||||
echo 🚀 启动药店销售预测系统API服务器...
|
||||
echo 📝 编码设置: UTF-8
|
||||
echo 🌐 服务地址: http://127.0.0.1:5000
|
||||
echo.
|
||||
uv run server/api.py
|
||||
pause
|
30
导出依赖配置.bat
30
导出依赖配置.bat
@ -1,30 +0,0 @@
|
||||
@echo off
|
||||
chcp 65001 >nul 2>&1
|
||||
echo 📦 导出UV依赖配置
|
||||
echo.
|
||||
|
||||
:: 设置编码
|
||||
set PYTHONIOENCODING=utf-8
|
||||
|
||||
echo 📋 导出requirements.txt格式...
|
||||
uv export --format requirements-txt > requirements-exported.txt
|
||||
|
||||
echo 📋 导出依赖树状图...
|
||||
uv tree > dependency-tree.txt
|
||||
|
||||
echo 📋 显示当前已安装的包...
|
||||
uv list > installed-packages.txt
|
||||
|
||||
echo 📋 显示uv配置...
|
||||
uv config list > uv-config.txt
|
||||
|
||||
echo.
|
||||
echo ✅ 依赖配置导出完成!
|
||||
echo.
|
||||
echo 📁 生成的文件:
|
||||
echo - requirements-exported.txt (标准requirements格式)
|
||||
echo - dependency-tree.txt (依赖关系树)
|
||||
echo - installed-packages.txt (已安装包列表)
|
||||
echo - uv-config.txt (UV配置信息)
|
||||
echo.
|
||||
pause
|
43
快速安装依赖.bat
43
快速安装依赖.bat
@ -1,43 +0,0 @@
|
||||
@echo off
|
||||
chcp 65001 >nul 2>&1
|
||||
echo 🚀 药店销售预测系统 - 快速安装依赖
|
||||
echo.
|
||||
|
||||
:: 设置编码环境变量
|
||||
set PYTHONIOENCODING=utf-8
|
||||
set PYTHONLEGACYWINDOWSSTDIO=0
|
||||
|
||||
echo 📁 配置UV缓存目录...
|
||||
uv config set cache-dir ".uv_cache"
|
||||
|
||||
echo 🌐 配置镜像源...
|
||||
uv config set global.index-url "https://pypi.tuna.tsinghua.edu.cn/simple"
|
||||
|
||||
echo.
|
||||
echo 📦 安装核心依赖包...
|
||||
echo.
|
||||
|
||||
:: 分批安装,避免超时
|
||||
echo 1/4 安装基础数据处理包...
|
||||
uv add numpy pandas openpyxl
|
||||
|
||||
echo 2/4 安装机器学习包...
|
||||
uv add scikit-learn matplotlib tqdm
|
||||
|
||||
echo 3/4 安装Web框架包...
|
||||
uv add flask flask-cors flask-socketio flasgger werkzeug
|
||||
|
||||
echo 4/4 安装深度学习框架...
|
||||
uv add torch torchvision --index-url https://download.pytorch.org/whl/cpu
|
||||
|
||||
echo.
|
||||
echo ✅ 核心依赖安装完成!
|
||||
echo.
|
||||
echo 🔍 检查安装状态...
|
||||
uv list
|
||||
|
||||
echo.
|
||||
echo 🎉 依赖安装完成!可以启动系统了
|
||||
echo 💡 启动命令: uv run server/api.py
|
||||
echo.
|
||||
pause
|
43
配置UV环境.bat
43
配置UV环境.bat
@ -1,43 +0,0 @@
|
||||
@echo off
|
||||
chcp 65001 >nul 2>&1
|
||||
echo 🔧 配置药店销售预测系统UV环境...
|
||||
echo.
|
||||
|
||||
:: 设置编码环境变量
|
||||
set PYTHONIOENCODING=utf-8
|
||||
set PYTHONLEGACYWINDOWSSTDIO=0
|
||||
|
||||
:: 设置缓存目录
|
||||
echo 📁 设置UV缓存目录...
|
||||
uv config set cache-dir "H:\_Workings\_OneTree\_ShopTRAINING\.uv_cache"
|
||||
|
||||
:: 设置镜像源
|
||||
echo 🌐 配置国内镜像源...
|
||||
uv config set global.index-url "https://pypi.tuna.tsinghua.edu.cn/simple"
|
||||
|
||||
:: 设置信任主机
|
||||
echo 🔒 配置信任主机...
|
||||
uv config set global.trusted-host "pypi.tuna.tsinghua.edu.cn"
|
||||
|
||||
echo.
|
||||
echo ✅ UV环境配置完成
|
||||
echo 📋 当前配置:
|
||||
uv config list
|
||||
|
||||
echo.
|
||||
echo 🚀 初始化项目并同步依赖...
|
||||
uv sync
|
||||
|
||||
echo.
|
||||
echo 📦 安装完成,检查依赖状态...
|
||||
uv tree
|
||||
|
||||
echo.
|
||||
echo 🎉 环境配置和依赖同步完成!
|
||||
echo.
|
||||
echo 💡 使用方法:
|
||||
echo 启动API服务器: uv run server/api.py
|
||||
echo 运行测试: uv run pytest
|
||||
echo 格式化代码: uv run black server/
|
||||
echo.
|
||||
pause
|
Loading…
x
Reference in New Issue
Block a user