Files
DDUp/tests/test_database.py
2026-01-22 12:57:26 +08:00

418 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""数据库操作测试"""
from datetime import date, timedelta
import pytest
from src.vitals.core import database as db
from src.vitals.core.models import Exercise, Meal, Sleep, Weight, UserConfig, User
class TestExerciseDB:
    """Database tests for exercise records."""

    def test_add_exercise(self, sample_exercise):
        """Adding an exercise record yields a positive record id."""
        new_id = db.add_exercise(sample_exercise)
        assert new_id > 0

    def test_get_exercises(self, sample_exercise):
        """A stored exercise is returned by an unfiltered query."""
        db.add_exercise(sample_exercise)

        stored = db.get_exercises()

        assert len(stored) == 1
        first = stored[0]
        assert first.type == "跑步"

    def test_get_exercises_by_date_range(self, sample_exercise):
        """Date-range filtering includes or excludes the stored record."""
        db.add_exercise(sample_exercise)

        # A range that contains the record's date.
        in_range = db.get_exercises(
            start_date=date(2026, 1, 1),
            end_date=date(2026, 1, 31),
        )
        assert len(in_range) == 1

        # A range that does not contain the record's date.
        out_of_range = db.get_exercises(
            start_date=date(2026, 2, 1),
            end_date=date(2026, 2, 28),
        )
        assert len(out_of_range) == 0

    def test_multiple_exercises(self):
        """Several records can be stored and are returned together."""
        for offset in range(3):
            db.add_exercise(
                Exercise(
                    date=date(2026, 1, 15 + offset),
                    type="跑步",
                    duration=30 + offset * 10,
                    calories=200 + offset * 50,
                )
            )

        stored = db.get_exercises()
        assert len(stored) == 3

        # Results are expected in descending date order; only the first
        # two entries are compared here.
        newest, runner_up = stored[0], stored[1]
        assert newest.date > runner_up.date
class TestMealDB:
    """Database tests for meal records."""

    def test_add_meal(self, sample_meal):
        """Adding a meal record yields a positive record id."""
        new_id = db.add_meal(sample_meal)
        assert new_id > 0

    def test_get_meals(self, sample_meal):
        """A stored meal is returned by an unfiltered query."""
        db.add_meal(sample_meal)

        stored = db.get_meals()

        assert len(stored) == 1
        first = stored[0]
        assert first.meal_type == "午餐"

    def test_get_meals_by_date(self, sample_meal):
        """Querying a single-day range returns the stored meal."""
        db.add_meal(sample_meal)

        target_day = date(2026, 1, 18)
        same_day = db.get_meals(start_date=target_day, end_date=target_day)

        assert len(same_day) == 1

    def test_meal_with_food_items(self):
        """A meal's food item list survives a save/load round trip."""
        lunch = Meal(
            date=date(2026, 1, 18),
            meal_type="午餐",
            description="测试",
            calories=500,
            food_items=[
                {"name": "米饭", "calories": 200},
                {"name": "红烧肉", "calories": 300},
            ],
        )
        db.add_meal(lunch)

        stored = db.get_meals()
        assert len(stored) == 1

        loaded_items = stored[0].food_items
        assert loaded_items is not None
        assert len(loaded_items) == 2
class TestSleepDB:
    """Database tests for sleep records."""

    def test_add_sleep(self, sample_sleep):
        """Adding a sleep record yields a positive record id."""
        new_id = db.add_sleep(sample_sleep)
        assert new_id > 0

    def test_get_sleep_records(self, sample_sleep):
        """A stored sleep record is returned with its duration intact."""
        db.add_sleep(sample_sleep)

        stored = db.get_sleep_records()

        assert len(stored) == 1
        first = stored[0]
        assert first.duration == 7.5

    def test_sleep_quality_range(self):
        """Records with quality scores 1 through 5 can all be stored."""
        for score in range(1, 6):
            # One record per score, each on its own day (Jan 11 .. Jan 15).
            night = Sleep(
                date=date(2026, 1, 10 + score),
                duration=7.0,
                quality=score,
            )
            db.add_sleep(night)

        stored = db.get_sleep_records()
        assert len(stored) == 5
class TestWeightDB:
    """Database tests for weight records."""

    def test_add_weight(self, sample_weight):
        """Adding a weight record yields a positive record id."""
        record_id = db.add_weight(sample_weight)
        assert record_id > 0

    def test_get_weight_records(self, sample_weight):
        """A stored weight record is returned with its value intact."""
        db.add_weight(sample_weight)

        records = db.get_weight_records()

        assert len(records) == 1
        assert records[0].weight_kg == 72.5

    def test_get_latest_weight(self, sample_weight):
        """The latest weight is picked by record date, not insertion order."""
        db.add_weight(sample_weight)

        # Insert a record with an earlier date *after* the fixture record,
        # so insertion order and date order disagree.
        older = Weight(
            date=date(2026, 1, 10),
            weight_kg=73.0,
        )
        db.add_weight(older)

        latest = db.get_latest_weight()

        # get_latest_weight() returns None when nothing is found (see
        # test_get_latest_weight_empty); guard so that case fails as a
        # clear assertion instead of an AttributeError on the next line.
        assert latest is not None
        # The fixture record (dated Jan 18 according to the original note)
        # is expected to be the newest one.
        assert latest.weight_kg == 72.5

    def test_get_latest_weight_empty(self):
        """With no weight records stored, the latest weight is None."""
        latest = db.get_latest_weight()
        assert latest is None
class TestConfigDB:
    """Database tests for the user configuration."""

    def test_save_and_get_config(self, sample_config):
        """A saved configuration is read back with the same field values."""
        db.save_config(sample_config)

        loaded = db.get_config()

        assert loaded.age == 28
        assert loaded.gender == "male"
        assert loaded.height == 175.0
        assert loaded.weight == 72.0

    def test_update_config(self):
        """Saving a second configuration replaces the first one."""
        initial = UserConfig(age=28, gender="male")
        db.save_config(initial)

        revised = UserConfig(age=29, gender="male", height=175.0)
        db.save_config(revised)

        loaded = db.get_config()
        assert loaded.age == 29
        assert loaded.height == 175.0

    def test_get_config_default(self):
        """Without a saved configuration, default values are returned."""
        defaults = db.get_config()

        assert defaults.activity_level == "moderate"
        assert defaults.goal == "maintain"
class TestDatabaseInit:
    """Tests that database initialisation creates the expected schema."""

    @staticmethod
    def _schema_object_names(db_path, object_type):
        """Return the names of all ``sqlite_master`` rows of *object_type*.

        Args:
            db_path: Path of the SQLite database file to inspect.
            object_type: Value of the ``type`` column to filter on,
                e.g. ``"table"`` or ``"index"``.

        Returns:
            A set with the ``name`` column of every matching row.
        """
        import sqlite3
        from contextlib import closing

        # sqlite3's own ``with conn:`` only manages the transaction and does
        # not close the connection, so ``closing`` is used to guarantee the
        # handle is released even if the query raises.
        with closing(sqlite3.connect(db_path)) as conn:
            cursor = conn.cursor()
            cursor.execute(
                "SELECT name FROM sqlite_master WHERE type=?",
                (object_type,),
            )
            return {row[0] for row in cursor.fetchall()}

    def test_init_creates_tables(self, test_db):
        """The initialised database contains every record table."""
        # The connection is already closed when the assertions run, so a
        # failing assertion can no longer leak it.
        tables = self._schema_object_names(test_db, "table")

        assert "exercise" in tables
        assert "meal" in tables
        assert "sleep" in tables
        assert "weight" in tables
        assert "config" in tables

    def test_init_creates_indexes(self, test_db):
        """The initialised database contains the per-table date indexes."""
        indexes = self._schema_object_names(test_db, "index")

        assert "idx_exercise_date" in indexes
        assert "idx_meal_date" in indexes
        assert "idx_sleep_date" in indexes
        assert "idx_weight_date" in indexes
class TestUserDB:
    """Database tests for user records."""

    def test_add_user(self):
        """Adding a user yields a positive user id."""
        user = User(name="测试用户")
        user_id = db.add_user(user)
        assert user_id > 0

    def test_get_users(self):
        """All added users appear in the user list."""
        db.add_user(User(name="用户1"))
        db.add_user(User(name="用户2"))

        users = db.get_users()

        assert len(users) == 2

    def test_get_user_by_id(self):
        """A user can be fetched by the id returned from add_user."""
        user = User(name="小明")
        user_id = db.add_user(user)

        fetched = db.get_user(user_id)

        assert fetched is not None
        assert fetched.name == "小明"

    def test_update_user(self):
        """Updating a user persists the changed name."""
        user = User(name="原名")
        user_id = db.add_user(user)

        # The in-memory object needs the generated id before it is updated.
        user.id = user_id
        user.name = "新名"
        db.update_user(user)

        fetched = db.get_user(user_id)
        # get_user() returns None for a missing user (see test_delete_user);
        # guard so that case fails as an assertion, not an AttributeError.
        assert fetched is not None
        assert fetched.name == "新名"

    def test_delete_user(self):
        """A deleted user can no longer be fetched."""
        user_id = db.add_user(User(name="待删除"))

        db.delete_user(user_id)

        fetched = db.get_user(user_id)
        assert fetched is None

    def test_set_active_user(self):
        """Exactly the user passed to set_active_user is marked active."""
        id1 = db.add_user(User(name="用户1"))
        id2 = db.add_user(User(name="用户2"))

        db.set_active_user(id1)
        user1 = db.get_user(id1)
        user2 = db.get_user(id2)
        # Truthiness checks rather than ``== True`` / ``== False`` (PEP 8).
        assert user1.is_active
        assert not user2.is_active

        # Switch the active user and verify the flags flip.
        db.set_active_user(id2)
        user1 = db.get_user(id1)
        user2 = db.get_user(id2)
        assert not user1.is_active
        assert user2.is_active

    def test_get_active_user(self):
        """get_active_user returns the user set via set_active_user."""
        id1 = db.add_user(User(name="用户1"))
        db.set_active_user(id1)

        active = db.get_active_user()

        assert active is not None
        assert active.id == id1
class TestUserIdMigration:
    """Tests for the user_id migration performed by ensure_default_user."""

    def test_ensure_default_user_creates_user(self):
        """ensure_default_user leaves at least one user, and one is active."""
        db.ensure_default_user()

        users = db.get_users()
        assert len(users) >= 1

        # There should be an active user afterwards.
        active = db.get_active_user()
        assert active is not None

    def test_existing_data_gets_default_user_id(self):
        """Records added before the migration are visible to the default user."""
        # Add an exercise record first, without passing a user_id.
        exercise = Exercise(
            date=date(2026, 1, 18),
            type="跑步",
            duration=30,
            calories=200,
        )
        db.add_exercise(exercise)

        # Run the migration.
        db.ensure_default_user()

        # Fetch the data belonging to the default (active) user.
        active = db.get_active_user()
        # Guard before dereferencing ``active.id`` so a missing active user
        # fails as a clear assertion, matching the sibling test above.
        assert active is not None
        exercises = db.get_exercises(user_id=active.id)
        assert len(exercises) >= 1
class TestDataClear:
    """Tests for previewing and clearing a user's data."""

    @staticmethod
    def _exercise(day, kind="跑步", duration=30, calories=200):
        """Build an exercise record dated January *day*, 2026."""
        return Exercise(
            date=date(2026, 1, day),
            type=kind,
            duration=duration,
            calories=calories,
        )

    @staticmethod
    def _lunch():
        """Build the lunch record dated 2026-01-10 used by these tests."""
        return Meal(
            date=date(2026, 1, 10),
            meal_type="午餐",
            description="米饭",
            calories=500,
        )

    def test_preview_delete_all(self):
        """Previewing mode "all" reports a count per data type and a total."""
        # Create a user and some data for it.
        user_id = db.add_user(User(name="测试用户"))
        db.add_exercise(self._exercise(10), user_id)
        db.add_exercise(
            self._exercise(11, kind="游泳", duration=45, calories=300),
            user_id,
        )
        db.add_meal(self._lunch(), user_id)

        preview = db.preview_delete(user_id, mode="all")

        assert preview["exercise"] == 2
        assert preview["meal"] == 1
        assert preview["sleep"] == 0
        assert preview["weight"] == 0
        assert preview["total"] == 3

    def test_preview_delete_by_range(self):
        """Previewing mode "range" counts only records inside the range."""
        user_id = db.add_user(User(name="测试用户"))
        for day in (5, 10, 15):
            db.add_exercise(self._exercise(day), user_id)

        preview = db.preview_delete(
            user_id,
            mode="range",
            date_from=date(2026, 1, 8),
            date_to=date(2026, 1, 12),
        )

        assert preview["exercise"] == 1
        assert preview["total"] == 1

    def test_preview_delete_by_type(self):
        """Previewing mode "type" counts only the requested data types."""
        user_id = db.add_user(User(name="测试用户"))
        db.add_exercise(self._exercise(10), user_id)
        db.add_meal(self._lunch(), user_id)

        preview = db.preview_delete(
            user_id,
            mode="type",
            data_types=["exercise"],
        )

        assert preview["exercise"] == 1
        assert preview["meal"] == 0
        assert preview["total"] == 1

    def test_clear_data_all(self):
        """Clearing with mode "all" removes the user's exercises and meals."""
        user_id = db.add_user(User(name="测试用户"))
        db.add_exercise(self._exercise(10), user_id)
        db.add_meal(self._lunch(), user_id)

        db.clear_data(user_id, mode="all")

        remaining_exercises = db.get_exercises(user_id=user_id)
        assert len(remaining_exercises) == 0
        remaining_meals = db.get_meals(user_id=user_id)
        assert len(remaining_meals) == 0

    def test_clear_data_by_range(self):
        """Clearing with mode "range" removes only records inside the range."""
        user_id = db.add_user(User(name="测试用户"))
        for day in (5, 10, 15):
            db.add_exercise(self._exercise(day), user_id)

        db.clear_data(
            user_id,
            mode="range",
            date_from=date(2026, 1, 8),
            date_to=date(2026, 1, 12),
        )

        remaining = db.get_exercises(user_id=user_id)
        assert len(remaining) == 2
        remaining_dates = [entry.date for entry in remaining]
        assert date(2026, 1, 10) not in remaining_dates

    def test_clear_data_by_type(self):
        """Clearing with mode "type" leaves other data types untouched."""
        user_id = db.add_user(User(name="测试用户"))
        db.add_exercise(self._exercise(10), user_id)
        db.add_meal(self._lunch(), user_id)

        db.clear_data(user_id, mode="type", data_types=["exercise"])

        remaining_exercises = db.get_exercises(user_id=user_id)
        assert len(remaining_exercises) == 0
        remaining_meals = db.get_meals(user_id=user_id)
        assert len(remaining_meals) == 1