"""Database operation tests."""

from datetime import date, timedelta

import pytest

from src.vitals.core import database as db
from src.vitals.core.models import Exercise, Meal, Sleep, Weight, UserConfig, User


class TestExerciseDB:
    """Database tests for exercise records."""

    def test_add_exercise(self, sample_exercise):
        """Adding an exercise returns a positive record id."""
        record_id = db.add_exercise(sample_exercise)
        assert record_id > 0

    def test_get_exercises(self, sample_exercise):
        """A stored exercise is returned by an unfiltered query."""
        db.add_exercise(sample_exercise)
        exercises = db.get_exercises()
        assert len(exercises) == 1
        assert exercises[0].type == "跑步"

    def test_get_exercises_by_date_range(self, sample_exercise):
        """Date-range filters include matching records and exclude others.

        This test assumes the ``sample_exercise`` fixture is dated in January 2026.
        """
        db.add_exercise(sample_exercise)

        # A range that contains the record's date.
        exercises = db.get_exercises(
            start_date=date(2026, 1, 1),
            end_date=date(2026, 1, 31),
        )
        assert len(exercises) == 1

        # A range that does not contain the record's date.
        exercises = db.get_exercises(
            start_date=date(2026, 2, 1),
            end_date=date(2026, 2, 28),
        )
        assert len(exercises) == 0

    def test_multiple_exercises(self):
        """Several records are all returned, ordered by date descending."""
        for offset in range(3):
            exercise = Exercise(
                date=date(2026, 1, 15 + offset),
                type="跑步",
                duration=30 + offset * 10,
                calories=200 + offset * 50,
            )
            db.add_exercise(exercise)

        exercises = db.get_exercises()
        assert len(exercises) == 3
        # Descending by date: check the full sequence, not only the first pair,
        # so a partially sorted result cannot slip through.
        assert [e.date for e in exercises] == [
            date(2026, 1, 17),
            date(2026, 1, 16),
            date(2026, 1, 15),
        ]


class TestMealDB:
    """Database tests for meal records."""

    def test_add_meal(self, sample_meal):
        """Inserting a meal yields a positive record id."""
        new_id = db.add_meal(sample_meal)
        assert new_id > 0

    def test_get_meals(self, sample_meal):
        """The stored meal comes back from an unfiltered query."""
        db.add_meal(sample_meal)
        stored = db.get_meals()
        assert len(stored) == 1
        first = stored[0]
        assert first.meal_type == "午餐"

    def test_get_meals_by_date(self, sample_meal):
        """Querying a single-day range returns the meal logged on that day."""
        db.add_meal(sample_meal)
        day = date(2026, 1, 18)
        found = db.get_meals(start_date=day, end_date=day)
        assert len(found) == 1

    def test_meal_with_food_items(self):
        """Food item entries attached to a meal survive a save/load cycle."""
        items = [
            {"name": "米饭", "calories": 200},
            {"name": "红烧肉", "calories": 300},
        ]
        db.add_meal(
            Meal(
                date=date(2026, 1, 18),
                meal_type="午餐",
                description="测试",
                calories=500,
                food_items=items,
            )
        )

        fetched = db.get_meals()
        assert len(fetched) == 1
        loaded_items = fetched[0].food_items
        assert loaded_items is not None
        assert len(loaded_items) == 2


class TestSleepDB:
    """Database tests for sleep records."""

    def test_add_sleep(self, sample_sleep):
        """Adding a sleep record returns a positive record id."""
        record_id = db.add_sleep(sample_sleep)
        assert record_id > 0

    def test_get_sleep_records(self, sample_sleep):
        """A stored sleep record is returned with its duration intact."""
        db.add_sleep(sample_sleep)
        records = db.get_sleep_records()
        assert len(records) == 1
        assert records[0].duration == 7.5

    def test_sleep_quality_range(self):
        """Every quality score from 1 to 5 can be stored and read back."""
        for quality in range(1, 6):
            sleep = Sleep(
                date=date(2026, 1, 10 + quality),
                duration=7.0,
                quality=quality,
            )
            db.add_sleep(sleep)

        records = db.get_sleep_records()
        assert len(records) == 5
        # Counting rows alone would not catch a lost or clamped quality value.
        assert sorted(r.quality for r in records) == [1, 2, 3, 4, 5]


class TestWeightDB:
    """Database tests for weight records."""

    def test_add_weight(self, sample_weight):
        """Adding a weight record returns a positive record id."""
        record_id = db.add_weight(sample_weight)
        assert record_id > 0

    def test_get_weight_records(self, sample_weight):
        """A stored weight record is returned with its value intact."""
        db.add_weight(sample_weight)
        records = db.get_weight_records()
        assert len(records) == 1
        assert records[0].weight_kg == 72.5

    def test_get_latest_weight(self, sample_weight):
        """The latest weight is chosen by date, not by insertion order."""
        db.add_weight(sample_weight)

        # Insert an earlier-dated record *after* the sample, so insertion
        # order and date order disagree.
        older = Weight(
            date=date(2026, 1, 10),
            weight_kg=73.0,
        )
        db.add_weight(older)

        latest = db.get_latest_weight()
        # Fail with a clear assertion instead of an AttributeError on None.
        assert latest is not None
        assert latest.weight_kg == 72.5  # the latest entry is dated Jan 18

    def test_get_latest_weight_empty(self):
        """With no records, get_latest_weight returns None."""
        latest = db.get_latest_weight()
        assert latest is None


class TestConfigDB:
    """Database tests for the persisted user configuration."""

    def test_save_and_get_config(self, sample_config):
        """A saved configuration is read back field for field."""
        db.save_config(sample_config)
        loaded = db.get_config()

        assert (loaded.age, loaded.gender, loaded.height, loaded.weight) == (
            28,
            "male",
            175.0,
            72.0,
        )

    def test_update_config(self):
        """Saving a second configuration replaces the values of the first."""
        db.save_config(UserConfig(age=28, gender="male"))
        db.save_config(UserConfig(age=29, gender="male", height=175.0))

        current = db.get_config()
        assert (current.age, current.height) == (29, 175.0)

    def test_get_config_default(self):
        """Without a saved configuration, the defaults are returned."""
        defaults = db.get_config()
        assert (defaults.activity_level, defaults.goal) == ("moderate", "maintain")


class TestDatabaseInit:
|
||
"""数据库初始化测试"""
|
||
|
||
def test_init_creates_tables(self, test_db):
|
||
"""测试初始化创建表"""
|
||
import sqlite3
|
||
|
||
conn = sqlite3.connect(test_db)
|
||
cursor = conn.cursor()
|
||
|
||
# 检查表是否存在
|
||
cursor.execute(
|
||
"SELECT name FROM sqlite_master WHERE type='table'"
|
||
)
|
||
tables = {row[0] for row in cursor.fetchall()}
|
||
|
||
assert "exercise" in tables
|
||
assert "meal" in tables
|
||
assert "sleep" in tables
|
||
assert "weight" in tables
|
||
assert "config" in tables
|
||
|
||
conn.close()
|
||
|
||
def test_init_creates_indexes(self, test_db):
|
||
"""测试初始化创建索引"""
|
||
import sqlite3
|
||
|
||
conn = sqlite3.connect(test_db)
|
||
cursor = conn.cursor()
|
||
|
||
cursor.execute(
|
||
"SELECT name FROM sqlite_master WHERE type='index'"
|
||
)
|
||
indexes = {row[0] for row in cursor.fetchall()}
|
||
|
||
assert "idx_exercise_date" in indexes
|
||
assert "idx_meal_date" in indexes
|
||
assert "idx_sleep_date" in indexes
|
||
assert "idx_weight_date" in indexes
|
||
|
||
conn.close()
|
||
|
||
|
||
class TestUserDB:
    """Database tests for user records."""

    def test_add_user(self):
        """Adding a user returns a positive user id."""
        user = User(name="测试用户")
        user_id = db.add_user(user)
        assert user_id > 0

    def test_get_users(self):
        """All added users are listed."""
        db.add_user(User(name="用户1"))
        db.add_user(User(name="用户2"))
        users = db.get_users()
        assert len(users) == 2

    def test_get_user_by_id(self):
        """A user can be fetched back by the id add_user returned."""
        user = User(name="小明")
        user_id = db.add_user(user)
        fetched = db.get_user(user_id)
        assert fetched is not None
        assert fetched.name == "小明"

    def test_update_user(self):
        """update_user persists a changed name."""
        user = User(name="原名")
        user_id = db.add_user(user)
        # update_user identifies the row through ``user.id``.
        user.id = user_id
        user.name = "新名"
        db.update_user(user)
        fetched = db.get_user(user_id)
        assert fetched is not None
        assert fetched.name == "新名"

    def test_delete_user(self):
        """A deleted user can no longer be fetched."""
        user_id = db.add_user(User(name="待删除"))
        db.delete_user(user_id)
        fetched = db.get_user(user_id)
        assert fetched is None

    def test_set_active_user(self):
        """Activating one user deactivates the previously active one."""
        id1 = db.add_user(User(name="用户1"))
        id2 = db.add_user(User(name="用户2"))

        db.set_active_user(id1)
        user1 = db.get_user(id1)
        user2 = db.get_user(id2)
        assert user1 is not None and user2 is not None
        assert user1.is_active
        assert not user2.is_active

        # Switch the active user.
        db.set_active_user(id2)
        user1 = db.get_user(id1)
        user2 = db.get_user(id2)
        assert user1 is not None and user2 is not None
        assert not user1.is_active
        assert user2.is_active

    def test_get_active_user(self):
        """get_active_user returns the user passed to set_active_user."""
        id1 = db.add_user(User(name="用户1"))
        db.set_active_user(id1)
        active = db.get_active_user()
        assert active is not None
        assert active.id == id1


class TestUserIdMigration:
    """Tests for the user_id migration performed by ``db.ensure_default_user``."""

    def test_ensure_default_user_creates_user(self):
        """ensure_default_user leaves at least one user, and one user is active."""
        db.ensure_default_user()
        users = db.get_users()
        assert len(users) >= 1
        # There should be an active user after the migration.
        active = db.get_active_user()
        assert active is not None

    def test_existing_data_gets_default_user_id(self):
        """Records added without a user_id are expected to belong to the default user after migration."""
        # Add an exercise record without passing a user_id.
        exercise = Exercise(
            date=date(2026, 1, 18),
            type="跑步",
            duration=30,
            calories=200,
        )
        db.add_exercise(exercise)

        # Run the migration.
        db.ensure_default_user()

        # Fetch the data owned by the (default) active user.
        active = db.get_active_user()
        # Fail with a clear assertion instead of an AttributeError on ``active.id``.
        assert active is not None
        exercises = db.get_exercises(user_id=active.id)
        assert len(exercises) >= 1


class TestDataClear:
    """Tests for previewing and clearing a user's data."""

    @staticmethod
    def _new_user():
        """Create a throwaway user and return its id."""
        return db.add_user(User(name="测试用户"))

    @staticmethod
    def _exercise(day, type_="跑步", duration=30, calories=200):
        """Build an exercise record on ``day`` (defaults match most tests)."""
        return Exercise(date=day, type=type_, duration=duration, calories=calories)

    @staticmethod
    def _lunch(day):
        """Build the lunch meal record used throughout these tests."""
        return Meal(date=day, meal_type="午餐", description="米饭", calories=500)

    def test_preview_delete_all(self):
        """Preview in "all" mode counts every record of the user, per type and in total."""
        user_id = self._new_user()
        db.add_exercise(self._exercise(date(2026, 1, 10)), user_id)
        db.add_exercise(
            self._exercise(date(2026, 1, 11), type_="游泳", duration=45, calories=300),
            user_id,
        )
        db.add_meal(self._lunch(date(2026, 1, 10)), user_id)

        counts = db.preview_delete(user_id, mode="all")
        assert counts["exercise"] == 2
        assert counts["meal"] == 1
        assert counts["sleep"] == 0
        assert counts["weight"] == 0
        assert counts["total"] == 3

        # A preview only reports; nothing may actually be removed.
        assert len(db.get_exercises(user_id=user_id)) == 2
        assert len(db.get_meals(user_id=user_id)) == 1

    def test_preview_delete_by_range(self):
        """Preview in "range" mode counts only records inside the date range."""
        user_id = self._new_user()
        for day in (5, 10, 15):
            db.add_exercise(self._exercise(date(2026, 1, day)), user_id)

        counts = db.preview_delete(
            user_id, mode="range", date_from=date(2026, 1, 8), date_to=date(2026, 1, 12)
        )
        assert counts["exercise"] == 1
        assert counts["total"] == 1

    def test_preview_delete_by_type(self):
        """Preview in "type" mode counts only the selected data types."""
        user_id = self._new_user()
        db.add_exercise(self._exercise(date(2026, 1, 10)), user_id)
        db.add_meal(self._lunch(date(2026, 1, 10)), user_id)

        counts = db.preview_delete(user_id, mode="type", data_types=["exercise"])
        assert counts["exercise"] == 1
        assert counts["meal"] == 0
        assert counts["total"] == 1

    def test_clear_data_all(self):
        """Clearing in "all" mode removes every record of the user."""
        user_id = self._new_user()
        db.add_exercise(self._exercise(date(2026, 1, 10)), user_id)
        db.add_meal(self._lunch(date(2026, 1, 10)), user_id)

        db.clear_data(user_id, mode="all")

        assert len(db.get_exercises(user_id=user_id)) == 0
        assert len(db.get_meals(user_id=user_id)) == 0

    def test_clear_data_by_range(self):
        """Clearing in "range" mode removes only records inside the date range."""
        user_id = self._new_user()
        for day in (5, 10, 15):
            db.add_exercise(self._exercise(date(2026, 1, day)), user_id)

        db.clear_data(
            user_id, mode="range", date_from=date(2026, 1, 8), date_to=date(2026, 1, 12)
        )

        exercises = db.get_exercises(user_id=user_id)
        assert len(exercises) == 2
        # Pin the exact survivors rather than only checking the removed date
        # is absent.
        assert sorted(e.date for e in exercises) == [date(2026, 1, 5), date(2026, 1, 15)]

    def test_clear_data_by_type(self):
        """Clearing in "type" mode removes only the selected data types."""
        user_id = self._new_user()
        db.add_exercise(self._exercise(date(2026, 1, 10)), user_id)
        db.add_meal(self._lunch(date(2026, 1, 10)), user_id)

        db.clear_data(user_id, mode="type", data_types=["exercise"])

        assert len(db.get_exercises(user_id=user_id)) == 0
        assert len(db.get_meals(user_id=user_id)) == 1