297 lines
11 KiB
Python
297 lines
11 KiB
Python
from sqlalchemy.orm import Session, joinedload, subqueryload
|
|
from sqlalchemy import func
|
|
import models, schemas
|
|
from security import get_password_hash, verify_password
|
|
from fastapi import HTTPException
|
|
from typing import List
|
|
|
|
# --- User CRUD ---
|
|
def get_user_by_phone_number(db: Session, phone_number: str):
    """Look up a single user by phone number.

    Returns the first matching ``models.User`` row, or ``None`` when no user
    has that phone number.
    """
    matching_users = db.query(models.User).filter(
        models.User.phone_number == phone_number
    )
    return matching_users.first()
|
|
|
|
def create_user(db: Session, user: schemas.UserCreate):
    """Create, commit and return a new user.

    A password hash is stored only when ``user.password`` is truthy;
    otherwise ``hashed_password`` is left as ``None``.
    """
    password_hash = get_password_hash(user.password) if user.password else None

    new_user = models.User(
        phone_number=user.phone_number,
        hashed_password=password_hash,
    )
    db.add(new_user)
    db.commit()
    db.refresh(new_user)
    return new_user
|
|
|
|
def authenticate_user(db: Session, phone_number: str, password: str):
    """Check a phone-number/password pair.

    Returns the matching user on success. Returns ``False`` when no user has
    that phone number, the user has no stored password hash, or the password
    does not verify against the stored hash.
    """
    candidate = get_user_by_phone_number(db, phone_number)
    if (
        candidate
        and candidate.hashed_password
        and verify_password(password, candidate.hashed_password)
    ):
        return candidate
    return False
|
|
|
|
def set_user_password(db: Session, user: models.User, password_set: schemas.PasswordSet):
    """Replace ``user``'s password hash with a hash of ``password_set.new_password``.

    The change is committed immediately, and the refreshed user is returned.
    """
    new_hash = get_password_hash(password_set.new_password)
    user.hashed_password = new_hash

    db.add(user)
    db.commit()
    db.refresh(user)
    return user
|
|
|
|
# --- Family Member & Health Profile CRUD ---
|
|
def create_family_member(db: Session, member: schemas.FamilyMemberCreate, owner_id: int):
    """Create a family member owned by ``owner_id`` from the fields of ``member``.

    The new row is committed and returned after a refresh.
    """
    member_fields = member.dict()
    new_member = models.FamilyMember(owner_id=owner_id, **member_fields)

    db.add(new_member)
    db.commit()
    db.refresh(new_member)
    return new_member
|
|
|
|
def get_family_members_by_user(db: Session, user_id: int):
    """Return every family member whose ``owner_id`` is ``user_id``."""
    owned_members = db.query(models.FamilyMember).filter(
        models.FamilyMember.owner_id == user_id
    )
    return owned_members.all()
|
|
|
|
def get_family_member(db: Session, member_id: int, user_id: int):
    """Fetch family member ``member_id`` only if it is owned by ``user_id``.

    Returns ``None`` when the member does not exist or belongs to another user.
    """
    return (
        db.query(models.FamilyMember)
        .filter(models.FamilyMember.id == member_id)
        .filter(models.FamilyMember.owner_id == user_id)
        .first()
    )
|
|
|
|
def update_health_profile(db: Session, member_id: int, profiles: List[schemas.HealthProfileCreate]):
    """Replace all health profiles of ``member_id`` with ``profiles``.

    Existing rows for the member are bulk-deleted, one new row is added per
    entry in ``profiles``, and the session is committed. Returns the member's
    health profiles as re-queried after the commit.
    """
    member_profiles = db.query(models.HealthProfile).filter(
        models.HealthProfile.member_id == member_id
    )
    member_profiles.delete()

    for incoming in profiles:
        db.add(
            models.HealthProfile(
                member_id=member_id,
                profile_type=incoming.profile_type,
                value=incoming.value,
            )
        )

    db.commit()
    return member_profiles.all()
|
|
|
|
|
|
# --- User Preference CRUD ---
|
|
def create_user_preferences(db: Session, user_id: int, preferences: schemas.OnboardingPreferences):
    """Replace all of ``user_id``'s preferences with those in ``preferences``.

    Each field of ``preferences`` is treated as a category whose value is an
    iterable. One ``UserPreference`` row is written per (category, value)
    pair. Previous rows for the user are bulk-deleted first, and the result is
    committed. Returns ``None``.
    """
    db.query(models.UserPreference).filter(
        models.UserPreference.user_id == user_id
    ).delete()

    # Lazily flatten {category: [values...]} into (category, value) pairs.
    pairs = (
        (category, value)
        for category, values in preferences.dict().items()
        for value in values
    )
    for category, value in pairs:
        db.add(models.UserPreference(user_id=user_id, category=category, value=value))

    db.commit()
|
|
|
|
# --- Food, Ingredient, Additive CRUD ---
|
|
def get_food_by_barcode(db: Session, barcode: str):
    """Return the food with ``barcode``, or ``None``.

    Ingredients and additives are eager-loaded with joined loading.
    """
    eager_relations = (
        joinedload(models.Food.ingredients),
        joinedload(models.Food.additives),
    )
    return (
        db.query(models.Food)
        .options(*eager_relations)
        .filter(models.Food.barcode == barcode)
        .first()
    )
|
|
|
|
def create_food(db: Session, food: schemas.FoodCreate):
    """Create a food and link it to existing ingredients and additives.

    ``food.ingredient_ids`` and ``food.additive_ids`` are resolved to existing
    rows (IDs with no matching row are silently skipped) and appended to the
    food's relationships. The food is committed and returned after a refresh.
    """
    scalar_fields = food.dict(exclude={'ingredient_ids', 'additive_ids'})
    new_food = models.Food(**scalar_fields)

    # Ingredients are resolved before additives, matching the original order.
    relation_specs = (
        (food.ingredient_ids, models.Ingredient, 'ingredients'),
        (food.additive_ids, models.Additive, 'additives'),
    )
    for wanted_ids, related_model, relation_name in relation_specs:
        if not wanted_ids:
            continue
        found = db.query(related_model).filter(related_model.id.in_(wanted_ids)).all()
        getattr(new_food, relation_name).extend(found)

    db.add(new_food)
    db.commit()
    db.refresh(new_food)
    return new_food
|
|
|
|
# --- Article CRUD ---
|
|
def create_article(db: Session, article: schemas.ArticleCreate):
    """Create an article from the fields of ``article``, commit it and return it."""
    new_article = models.Article(**article.dict())

    db.add(new_article)
    db.commit()
    db.refresh(new_article)
    return new_article
|
|
|
|
def get_articles_by_category(db: Session, category: str, skip: int = 0, limit: int = 10):
    """Return up to ``limit`` articles in ``category``, skipping the first ``skip``.

    No explicit ordering is applied.
    """
    return (
        db.query(models.Article)
        .filter(models.Article.category == category)
        .offset(skip)
        .limit(limit)
        .all()
    )
|
|
|
|
def get_article(db: Session, article_id: int):
    """Return the article with primary key ``article_id``, or ``None``."""
    by_id = db.query(models.Article).filter(models.Article.id == article_id)
    return by_id.first()
|
|
|
|
# --- Recipe CRUD ---
|
|
def create_recipe(db: Session, recipe: schemas.RecipeCreate, author_id: int):
    """Create a recipe authored by ``author_id`` together with its ingredient lines.

    Every entry in ``recipe.ingredients`` becomes a ``RecipeIngredient``
    attached to the new recipe. Everything is committed in one go, and the
    refreshed recipe is returned.
    """
    recipe_fields = recipe.dict(exclude={'ingredients'})
    new_recipe = models.Recipe(**recipe_fields, author_id=author_id)

    ingredient_lines = [
        models.RecipeIngredient(
            ingredient_id=line.ingredient_id,
            quantity=line.quantity,
            unit=line.unit,
        )
        for line in recipe.ingredients
    ]
    new_recipe.ingredients.extend(ingredient_lines)

    db.add(new_recipe)
    db.commit()
    db.refresh(new_recipe)
    return new_recipe
|
|
|
|
def get_recipe_details(db: Session, recipe_id: int):
    """Return recipe ``recipe_id`` with its ingredient lines and their ingredients.

    Both levels are joined-loaded. Returns ``None`` if the recipe does not exist.
    """
    load_lines_and_ingredients = joinedload(models.Recipe.ingredients).joinedload(
        models.RecipeIngredient.ingredient
    )
    return (
        db.query(models.Recipe)
        .options(load_lines_and_ingredients)
        .filter(models.Recipe.id == recipe_id)
        .first()
    )
|
|
|
|
# --- Community (Post, Comment, Favorite) CRUD ---
|
|
def create_post(db: Session, post: schemas.PostCreate, author_id: int):
    """Create a post authored by ``author_id``, commit it and return it."""
    post_fields = post.dict()
    new_post = models.Post(**post_fields, author_id=author_id)

    db.add(new_post)
    db.commit()
    db.refresh(new_post)
    return new_post
|
|
|
|
def get_posts_by_food(db: Session, food_id: int, skip: int = 0, limit: int = 10):
    """Return one page of posts about ``food_id``, newest first.

    Comments and favorites are subquery-loaded. Each returned post gets a
    ``favorites_count`` attribute equal to the number of its loaded favorites.
    """
    newest_first = models.Post.created_at.desc()
    page = (
        db.query(models.Post)
        .filter(models.Post.food_id == food_id)
        .order_by(newest_first)
        .offset(skip)
        .limit(limit)
        .options(
            subqueryload(models.Post.comments),
            subqueryload(models.Post.favorites),
        )
        .all()
    )

    for entry in page:
        entry.favorites_count = len(entry.favorites)
    return page
|
|
|
|
def create_comment(db: Session, comment: schemas.CommentCreate, author_id: int):
    """Create a comment authored by ``author_id``, commit it and return it."""
    comment_fields = comment.dict()
    new_comment = models.Comment(**comment_fields, author_id=author_id)

    db.add(new_comment)
    db.commit()
    db.refresh(new_comment)
    return new_comment
|
|
|
|
def toggle_favorite(db: Session, favorite: schemas.FavoriteCreate, user_id: int):
    """Flip ``user_id``'s favorite on (entity_type, entity_id).

    If the favorite exists it is deleted; otherwise it is created. The change
    is committed. Returns ``{"favorited": bool}`` describing the new state.
    """
    existing = db.query(models.Favorite).filter(
        models.Favorite.entity_type == favorite.entity_type,
        models.Favorite.entity_id == favorite.entity_id,
        models.Favorite.user_id == user_id,
    ).first()

    if not existing:
        db.add(models.Favorite(**favorite.dict(), user_id=user_id))
        db.commit()
        return {"favorited": True}

    db.delete(existing)
    db.commit()
    return {"favorited": False}
|
|
|
|
# --- Topic CRUD ---
|
|
def get_topics(db: Session, skip: int = 0, limit: int = 100):
    """Return up to ``limit`` topics after skipping ``skip`` (no explicit ordering)."""
    topics_page = db.query(models.Topic).offset(skip).limit(limit)
    return topics_page.all()
|
|
|
|
def get_posts_by_topic(db: Session, topic_id: int, skip: int = 0, limit: int = 10):
    """Return one page of posts in ``topic_id``, newest first.

    Comments and favorites are subquery-loaded. Each returned post gets a
    ``favorites_count`` attribute equal to the number of its loaded favorites.
    """
    newest_first = models.Post.created_at.desc()
    page = (
        db.query(models.Post)
        .filter(models.Post.topic_id == topic_id)
        .order_by(newest_first)
        .offset(skip)
        .limit(limit)
        .options(
            subqueryload(models.Post.comments),
            subqueryload(models.Post.favorites),
        )
        .all()
    )

    for entry in page:
        entry.favorites_count = len(entry.favorites)
    return page
|
|
|
|
# --- E-commerce CRUD ---
|
|
def get_products(db: Session, skip: int = 0, limit: int = 10):
    """Return up to ``limit`` products after skipping ``skip`` (no explicit ordering)."""
    products_page = db.query(models.Product).offset(skip).limit(limit)
    return products_page.all()
|
|
|
|
def get_cart_items(db: Session, user_id: int):
    """Return all cart items belonging to ``user_id``."""
    users_cart = db.query(models.CartItem).filter(models.CartItem.user_id == user_id)
    return users_cart.all()
|
|
|
|
def add_to_cart(db: Session, item: schemas.CartItemCreate, user_id: int):
    """Add ``item`` to ``user_id``'s cart, merging with an existing line.

    If the user already has a cart line for ``item.product_id``, its quantity
    is increased by ``item.quantity``; otherwise a new line is created. The
    change is committed, and the refreshed cart line is returned.
    """
    cart_line = db.query(models.CartItem).filter(
        models.CartItem.product_id == item.product_id,
        models.CartItem.user_id == user_id,
    ).first()

    if not cart_line:
        cart_line = models.CartItem(**item.dict(), user_id=user_id)
        db.add(cart_line)
    else:
        cart_line.quantity += item.quantity

    db.commit()
    db.refresh(cart_line)
    return cart_line
|
|
|
|
def create_order(db: Session, user_id: int):
    """Turn ``user_id``'s cart into an order and empty the cart.

    The order, its order items, and the deletion of the cart items are
    committed in a single transaction. A failure partway through therefore
    leaves neither an order without items nor a half-cleared cart.

    Each order item records the product's price as read at checkout. The
    order's ``total_price`` is the sum of ``price * quantity`` over those
    same prices.

    Raises:
        HTTPException: status 400 if the cart is empty.
    """
    cart_items = get_cart_items(db, user_id)
    if not cart_items:
        raise HTTPException(status_code=400, detail="Cart is empty")

    # Read each product price once so the order total and the per-item
    # prices are guaranteed to be computed from the same values.
    priced_items = [(item, item.product.price) for item in cart_items]
    total_price = sum(price * item.quantity for item, price in priced_items)

    try:
        db_order = models.Order(user_id=user_id, total_price=total_price)
        db.add(db_order)
        # flush (not commit) so db_order.id is assigned while the order,
        # its items and the cart cleanup remain one atomic transaction.
        db.flush()

        for item, price in priced_items:
            db.add(
                models.OrderItem(
                    order_id=db_order.id,
                    product_id=item.product_id,
                    quantity=item.quantity,
                    price=price,
                )
            )
            db.delete(item)  # clear the cart

        db.commit()
    except Exception:
        # Undo the partially built order so the session stays usable.
        db.rollback()
        raise

    db.refresh(db_order)
    return db_order
|
|
|
|
# --- Search History CRUD ---
|
|
def create_search_history(db: Session, search: schemas.SearchHistoryCreate, user_id: int | None = None):
    """Record ``search.term`` as a search-history row, commit it and return it.

    ``user_id`` defaults to ``None``, which stores the search without an
    associated user.
    """
    entry = models.SearchHistory(term=search.term, user_id=user_id)

    db.add(entry)
    db.commit()
    db.refresh(entry)
    return entry
|
|
|
|
def get_search_history_by_user(db: Session, user_id: int, skip: int = 0, limit: int = 10):
    """Return one page of ``user_id``'s search history, most recent first."""
    most_recent_first = models.SearchHistory.created_at.desc()
    return (
        db.query(models.SearchHistory)
        .filter(models.SearchHistory.user_id == user_id)
        .order_by(most_recent_first)
        .offset(skip)
        .limit(limit)
        .all()
    )
|
|
|
|
def get_popular_searches(db: Session, limit: int = 10):
    """Return the ``limit`` most frequent search terms.

    Each row has a ``term`` column and a ``count`` column (number of history
    rows for that term). Rows are ordered by ``count`` descending.
    """
    occurrences = func.count(models.SearchHistory.term)
    return (
        db.query(models.SearchHistory.term, occurrences.label('count'))
        .group_by(models.SearchHistory.term)
        .order_by(occurrences.desc())
        .limit(limit)
        .all()
    )
|
|
|
|
def get_or_create_oauth_user(db: Session, provider: str, openid: str):
    """Return the user linked to the OAuth identity (provider, openid).

    If no ``OAuthAccount`` matches, a new ``User`` and a linking
    ``OAuthAccount`` are created. Both are committed in a single transaction,
    so a failure while creating the account cannot leave an orphaned ``User``
    row with no linked ``OAuthAccount``.

    Returns:
        A ``(user, created)`` tuple; ``created`` is ``True`` only when a new
        user was made.
    """
    oauth_account = db.query(models.OAuthAccount).filter(
        models.OAuthAccount.provider == provider,
        models.OAuthAccount.openid == openid,
    ).first()

    if oauth_account:
        return oauth_account.user, False

    try:
        new_user = models.User()
        db.add(new_user)
        # flush (not commit) so new_user.id is assigned while the user and
        # its OAuth account stay in one atomic transaction.
        db.flush()

        db.add(
            models.OAuthAccount(
                provider=provider,
                openid=openid,
                user_id=new_user.id,
            )
        )
        db.commit()
    except Exception:
        # Undo the half-created user (e.g. on a unique-constraint race).
        db.rollback()
        raise

    db.refresh(new_user)
    return new_user, True
|
|
|
|
def suggest_recipes_by_ingredients(db: Session, ingredient_ids: List[int], limit: int = 10):
    """Suggest recipes that use as many of ``ingredient_ids`` as possible.

    Recipes are ranked by how many *distinct* ingredients from
    ``ingredient_ids`` they contain. Recipes containing none of them are
    excluded by the join/filter. Ties are broken by ascending recipe id so
    the result order is deterministic.

    Returns:
        At most ``limit`` ``models.Recipe`` objects, or an empty list when
        ``ingredient_ids`` is empty.
    """
    if not ingredient_ids:
        return []

    # COUNT(DISTINCT ...) so a recipe listing the same ingredient on several
    # lines is not ranked above one that matches more different ingredients.
    match_count = func.count(models.RecipeIngredient.ingredient_id.distinct())

    rows = (
        db.query(models.Recipe, match_count.label('match_count'))
        .join(models.RecipeIngredient)
        .filter(models.RecipeIngredient.ingredient_id.in_(ingredient_ids))
        .group_by(models.Recipe.id)
        .order_by(match_count.desc(), models.Recipe.id)
        .limit(limit)
        .all()
    )

    # Each row is (Recipe, match_count); callers only need the Recipe.
    return [recipe for recipe, _match_count in rows]
|
|
|
|
def create_submitted_food(db: Session, food: schemas.SubmittedFoodCreate, user_id: int):
    """Store a user-submitted food, attributed to ``user_id``, and return it.

    The row is committed and returned after a refresh.
    """
    submission_fields = food.dict()
    submission = models.SubmittedFood(**submission_fields, submitted_by_id=user_id)

    db.add(submission)
    db.commit()
    db.refresh(submission)
    return submission