feat: 用户配置隔离与食物智能识别

1. Config 表用户隔离
   - 添加 user_id 字段,复合主键 (user_id, key)
   - 现有数据归属 ID=1 用户
   - 所有 get_config/save_config 调用传入 user_id

2. 食物文字智能识别
   - 本地数据库优先匹配(快速)
   - 识别失败时自动调用通义千问 AI(准确)
   - 有配置 API Key 才调用,否则返回本地结果

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
2026-01-24 11:32:17 +08:00
parent afc6d2fb5e
commit 0f11e8ad56
7 changed files with 215 additions and 29 deletions

View File

@@ -408,11 +408,13 @@ def show_today():
today = date.today()
# 获取今日数据
active_user = db.get_active_user()
user_id = active_user.id if active_user else 1
exercises = db.get_exercises(today, today)
meals = db.get_meals(today, today)
sleep_records = db.get_sleep_records(today, today)
weight_records = db.get_weight_records(today, today)
config = db.get_config()
config = db.get_config(user_id)
console.print(Panel(f"[bold]📊 今日概览 - {today}[/bold]"))
@@ -555,7 +557,9 @@ def config_set(
goal: Optional[str] = typer.Option(None, "--goal", help="目标 (lose/maintain/gain)"),
):
"""设置用户配置"""
config = db.get_config()
active_user = db.get_active_user()
user_id = active_user.id if active_user else 1
config = db.get_config(user_id)
if age:
config.age = age
@@ -570,7 +574,7 @@ def config_set(
if goal:
config.goal = goal
db.save_config(config)
db.save_config(user_id, config)
console.print("[green]✓[/green] 配置已保存")
# 显示计算结果
@@ -583,7 +587,9 @@ def config_set(
@config_app.command("show")
def config_show():
"""显示当前配置"""
config = db.get_config()
active_user = db.get_active_user()
user_id = active_user.id if active_user else 1
config = db.get_config(user_id)
table = Table(title="用户配置")
table.add_column("项目", style="cyan")

View File

@@ -181,9 +181,88 @@ def _chinese_to_num(chinese: str) -> float:
return mapping.get(chinese, 1)
def estimate_meal_calories(description: str) -> dict:
def _local_estimate(description: str) -> dict:
"""本地估算卡路里(使用静态数据库)"""
items = parse_food_description(description)
return {
"total_calories": sum(item["calories"] for item in items),
"total_protein": round(sum(item["protein"] for item in items), 1),
"total_carbs": round(sum(item["carbs"] for item in items), 1),
"total_fat": round(sum(item["fat"] for item in items), 1),
"items": items,
}
def _ai_estimate(description: str, api_key: str) -> dict:
"""使用通义千问 AI 标准化食物描述后再估算"""
try:
import httpx
headers = {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
}
payload = {
"model": "qwen-vl-max-latest",
"messages": [
{
"role": "system",
"content": """你是专业的营养分析助手。用户会告诉你吃了什么,你需要:
1. 识别所有食物
2. 标准化食物名称(如"""面条""""猪肉"
3. 提取数量(如"两个""一碗""100g"
4. 按照 "食物1+食物2+食物3" 格式返回
示例:
用户输入:"今天吃了一碗米饭、两个鸡蛋还有一杯牛奶"
你返回:"一碗米饭+两个鸡蛋+一杯牛奶"
只返回标准化后的食物列表,不需要其他解释。"""
},
{
"role": "user",
"content": description
}
],
"temperature": 0.3,
}
with httpx.Client(timeout=30.0) as client:
response = client.post(
"https://dashscope.aliyuncs.com/compatible-mode/v1/chat/completions",
headers=headers,
json=payload,
)
response.raise_for_status()
result = response.json()
standardized = result["choices"][0]["message"]["content"].strip()
# 用标准化后的描述进行本地计算
local_result = _local_estimate(standardized)
local_result["description"] = standardized
local_result["original_input"] = description
local_result["provider"] = "qwen"
return local_result
except Exception as e:
# AI 调用失败,回退到本地结果
return _local_estimate(description)
def estimate_meal_calories(description: str, use_ai_fallback: bool = True) -> dict:
"""
估算一餐的总卡路里
估算一餐的总卡路里(智能识别)
优先使用本地数据库匹配,如果有未识别的食物且配置了 API Key
则调用通义千问 AI 进行标准化后重新计算。
Args:
description: 食物描述
use_ai_fallback: 是否在本地识别失败时调用 AI
返回:
{
@@ -194,17 +273,26 @@ def estimate_meal_calories(description: str) -> dict:
"items": [...]
}
"""
items = parse_food_description(description)
# 1. 先用本地数据库计算
result = _local_estimate(description)
total = {
"total_calories": sum(item["calories"] for item in items),
"total_protein": round(sum(item["protein"] for item in items), 1),
"total_carbs": round(sum(item["carbs"] for item in items), 1),
"total_fat": round(sum(item["fat"] for item in items), 1),
"items": items,
}
# 2. 检查是否有未识别的食物calories=0 且 estimated=False
has_unknown = any(
item["calories"] == 0 and not item.get("estimated", False)
for item in result["items"]
)
return total
# 3. 如果有未识别且启用 AI 回退,尝试调用大模型
if has_unknown and use_ai_fallback:
try:
from . import database as db
api_key = db.get_api_key("dashscope")
if api_key:
return _ai_estimate(description, api_key)
except Exception:
pass # 忽略导入或调用错误,返回本地结果
return result
def estimate_exercise_calories(exercise_type: str, duration_mins: int, weight_kg: float = 70) -> int:

View File

@@ -470,10 +470,31 @@ def delete_weight(weight_id: int):
# ===== 用户配置 =====
def get_config() -> UserConfig:
def migrate_config_add_user_id():
"""迁移:为 config 表添加 user_id 字段,实现用户隔离"""
with get_connection() as (conn, cursor):
# 检查 user_id 列是否已存在
cursor.execute("""
SELECT COLUMN_NAME FROM INFORMATION_SCHEMA.COLUMNS
WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = 'config' AND COLUMN_NAME = 'user_id'
""")
if cursor.fetchone():
return # 已迁移
# 1. 添加 user_id 列,默认值为 1
cursor.execute("ALTER TABLE config ADD COLUMN user_id INT NOT NULL DEFAULT 1")
# 2. 删除原主键
cursor.execute("ALTER TABLE config DROP PRIMARY KEY")
# 3. 创建新的复合主键
cursor.execute("ALTER TABLE config ADD PRIMARY KEY (user_id, `key`)")
def get_config(user_id: int = 1) -> UserConfig:
"""获取用户配置"""
with get_connection() as (conn, cursor):
cursor.execute("SELECT `key`, value FROM config")
cursor.execute("SELECT `key`, value FROM config WHERE user_id = %s", (user_id,))
rows = cursor.fetchall()
config_dict = {row["key"]: row["value"] for row in rows}
@@ -488,7 +509,7 @@ def get_config() -> UserConfig:
)
def save_config(config: UserConfig):
def save_config(user_id: int, config: UserConfig):
"""保存用户配置"""
with get_connection() as (conn, cursor):
config_dict = {
@@ -503,8 +524,9 @@ def save_config(config: UserConfig):
for key, value in config_dict.items():
if value is not None:
cursor.execute("""
REPLACE INTO config (`key`, value) VALUES (%s, %s)
""", (key, value))
INSERT INTO config (user_id, `key`, value) VALUES (%s, %s, %s)
ON DUPLICATE KEY UPDATE value = VALUES(value)
""", (user_id, key, value))
# ===== 用户管理 =====

View File

@@ -25,7 +25,7 @@ def _json_default(value):
raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")
def export_all_data_json(output_path: Path) -> None:
def export_all_data_json(output_path: Path, user_id: int = 1) -> None:
"""导出所有数据为 JSON"""
data = {
"version": "1.0",
@@ -34,7 +34,7 @@ def export_all_data_json(output_path: Path) -> None:
"meals": [m.to_dict() for m in db.get_meals()],
"sleep": [s.to_dict() for s in db.get_sleep_records()],
"weight": [w.to_dict() for w in db.get_weight_records()],
"config": db.get_config().to_dict(),
"config": db.get_config(user_id).to_dict(),
}
output_path.parent.mkdir(parents=True, exist_ok=True)
output_path.write_text(json.dumps(data, ensure_ascii=False, indent=2, default=_json_default), encoding="utf-8")
@@ -117,7 +117,7 @@ def export_to_csv(
writer.writerow(row)
def import_from_json(input_path: Path) -> dict[str, int]:
def import_from_json(input_path: Path, user_id: int = 1) -> dict[str, int]:
"""从 JSON 导入数据(最小实现:覆盖性导入,不做去重)"""
data = json.loads(input_path.read_text(encoding="utf-8"))
@@ -128,6 +128,7 @@ def import_from_json(input_path: Path) -> dict[str, int]:
if isinstance(config, dict):
# 仅持久化可写字段
db.save_config(
user_id,
UserConfig(
age=config.get("age"),
gender=config.get("gender"),

View File

@@ -112,7 +112,7 @@ class MonthlyReport:
WEEKDAY_NAMES = ["周一", "周二", "周三", "周四", "周五", "周六", "周日"]
def generate_weekly_report(target_date: Optional[date] = None) -> WeeklyReport:
def generate_weekly_report(target_date: Optional[date] = None, user_id: int = 1) -> WeeklyReport:
"""生成周报"""
if target_date is None:
target_date = date.today()
@@ -126,7 +126,7 @@ def generate_weekly_report(target_date: Optional[date] = None) -> WeeklyReport:
meals = db.get_meals(start_date, end_date)
sleep_records = db.get_sleep_records(start_date, end_date)
weight_records = db.get_weight_records(start_date, end_date)
config = db.get_config()
config = db.get_config(user_id)
# 运动统计
exercise_duration = sum(e.duration for e in exercises)

View File

@@ -20,6 +20,7 @@ from ..core.auth import hash_password, verify_password, create_token, decode_tok
# 初始化数据库
db.init_db()
db.migrate_auth_fields() # 迁移认证字段
db.migrate_config_add_user_id() # 迁移 config 表添加 user_id
app = FastAPI(
@@ -887,7 +888,9 @@ async def settings_page():
@app.get("/api/config", response_model=ConfigResponse)
async def get_config():
"""获取用户配置"""
config = db.get_config()
active_user = db.get_active_user()
user_id = active_user.id if active_user else 1
config = db.get_config(user_id)
return ConfigResponse(
age=config.age,
gender=config.gender,
@@ -909,7 +912,7 @@ async def get_today_summary():
raise HTTPException(status_code=400, detail="没有激活的用户")
today = date.today()
config = db.get_config()
config = db.get_config(active_user.id)
# 获取今日数据
exercises = db.get_exercises(start_date=today, end_date=today, user_id=active_user.id)
@@ -1187,7 +1190,7 @@ async def add_exercise_api(data: ExerciseInput):
raise HTTPException(status_code=400, detail="日期格式应为 YYYY-MM-DD") from exc
from ..core.calories import estimate_exercise_calories
config = db.get_config()
config = db.get_config(active_user.id)
weight_kg = config.weight or 70
calories = data.calories if data.calories is not None else estimate_exercise_calories(
data.type, data.duration, weight_kg
@@ -1579,7 +1582,9 @@ async def get_weight_records(
@app.get("/api/weight/goal")
async def get_weight_goal():
"""获取目标体重(基于用户配置推断)"""
config = db.get_config()
active_user = db.get_active_user()
user_id = active_user.id if active_user else 1
config = db.get_config(user_id)
if not config.weight:
return {"goal_weight": None}