feat: 用户配置隔离与食物智能识别
1. Config 表用户隔离 - 添加 user_id 字段,复合主键 (user_id, key) - 现有数据归属 ID=1 用户 - 所有 get_config/save_config 调用传入 user_id 2. 食物文字智能识别 - 本地数据库优先匹配(快速) - 识别失败时自动调用通义千问 AI(准确) - 有配置 API Key 才调用,否则返回本地结果 Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -408,11 +408,13 @@ def show_today():
|
||||
today = date.today()
|
||||
|
||||
# 获取今日数据
|
||||
active_user = db.get_active_user()
|
||||
user_id = active_user.id if active_user else 1
|
||||
exercises = db.get_exercises(today, today)
|
||||
meals = db.get_meals(today, today)
|
||||
sleep_records = db.get_sleep_records(today, today)
|
||||
weight_records = db.get_weight_records(today, today)
|
||||
config = db.get_config()
|
||||
config = db.get_config(user_id)
|
||||
|
||||
console.print(Panel(f"[bold]📊 今日概览 - {today}[/bold]"))
|
||||
|
||||
@@ -555,7 +557,9 @@ def config_set(
|
||||
goal: Optional[str] = typer.Option(None, "--goal", help="目标 (lose/maintain/gain)"),
|
||||
):
|
||||
"""设置用户配置"""
|
||||
config = db.get_config()
|
||||
active_user = db.get_active_user()
|
||||
user_id = active_user.id if active_user else 1
|
||||
config = db.get_config(user_id)
|
||||
|
||||
if age:
|
||||
config.age = age
|
||||
@@ -570,7 +574,7 @@ def config_set(
|
||||
if goal:
|
||||
config.goal = goal
|
||||
|
||||
db.save_config(config)
|
||||
db.save_config(user_id, config)
|
||||
console.print("[green]✓[/green] 配置已保存")
|
||||
|
||||
# 显示计算结果
|
||||
@@ -583,7 +587,9 @@ def config_set(
|
||||
@config_app.command("show")
|
||||
def config_show():
|
||||
"""显示当前配置"""
|
||||
config = db.get_config()
|
||||
active_user = db.get_active_user()
|
||||
user_id = active_user.id if active_user else 1
|
||||
config = db.get_config(user_id)
|
||||
|
||||
table = Table(title="用户配置")
|
||||
table.add_column("项目", style="cyan")
|
||||
|
||||
@@ -181,9 +181,88 @@ def _chinese_to_num(chinese: str) -> float:
|
||||
return mapping.get(chinese, 1)
|
||||
|
||||
|
||||
def estimate_meal_calories(description: str) -> dict:
|
||||
def _local_estimate(description: str) -> dict:
|
||||
"""本地估算卡路里(使用静态数据库)"""
|
||||
items = parse_food_description(description)
|
||||
|
||||
return {
|
||||
"total_calories": sum(item["calories"] for item in items),
|
||||
"total_protein": round(sum(item["protein"] for item in items), 1),
|
||||
"total_carbs": round(sum(item["carbs"] for item in items), 1),
|
||||
"total_fat": round(sum(item["fat"] for item in items), 1),
|
||||
"items": items,
|
||||
}
|
||||
|
||||
|
||||
def _ai_estimate(description: str, api_key: str) -> dict:
|
||||
"""使用通义千问 AI 标准化食物描述后再估算"""
|
||||
try:
|
||||
import httpx
|
||||
|
||||
headers = {
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
payload = {
|
||||
"model": "qwen-vl-max-latest",
|
||||
"messages": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": """你是专业的营养分析助手。用户会告诉你吃了什么,你需要:
|
||||
1. 识别所有食物
|
||||
2. 标准化食物名称(如"面"→"面条","肉"→"猪肉")
|
||||
3. 提取数量(如"两个"、"一碗"、"100g")
|
||||
4. 按照 "食物1+食物2+食物3" 格式返回
|
||||
|
||||
示例:
|
||||
用户输入:"今天吃了一碗米饭、两个鸡蛋还有一杯牛奶"
|
||||
你返回:"一碗米饭+两个鸡蛋+一杯牛奶"
|
||||
|
||||
只返回标准化后的食物列表,不需要其他解释。"""
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": description
|
||||
}
|
||||
],
|
||||
"temperature": 0.3,
|
||||
}
|
||||
|
||||
with httpx.Client(timeout=30.0) as client:
|
||||
response = client.post(
|
||||
"https://dashscope.aliyuncs.com/compatible-mode/v1/chat/completions",
|
||||
headers=headers,
|
||||
json=payload,
|
||||
)
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
|
||||
standardized = result["choices"][0]["message"]["content"].strip()
|
||||
|
||||
# 用标准化后的描述进行本地计算
|
||||
local_result = _local_estimate(standardized)
|
||||
local_result["description"] = standardized
|
||||
local_result["original_input"] = description
|
||||
local_result["provider"] = "qwen"
|
||||
|
||||
return local_result
|
||||
|
||||
except Exception as e:
|
||||
# AI 调用失败,回退到本地结果
|
||||
return _local_estimate(description)
|
||||
|
||||
|
||||
def estimate_meal_calories(description: str, use_ai_fallback: bool = True) -> dict:
|
||||
"""
|
||||
估算一餐的总卡路里
|
||||
估算一餐的总卡路里(智能识别)
|
||||
|
||||
优先使用本地数据库匹配,如果有未识别的食物且配置了 API Key,
|
||||
则调用通义千问 AI 进行标准化后重新计算。
|
||||
|
||||
Args:
|
||||
description: 食物描述
|
||||
use_ai_fallback: 是否在本地识别失败时调用 AI
|
||||
|
||||
返回:
|
||||
{
|
||||
@@ -194,17 +273,26 @@ def estimate_meal_calories(description: str) -> dict:
|
||||
"items": [...]
|
||||
}
|
||||
"""
|
||||
items = parse_food_description(description)
|
||||
# 1. 先用本地数据库计算
|
||||
result = _local_estimate(description)
|
||||
|
||||
total = {
|
||||
"total_calories": sum(item["calories"] for item in items),
|
||||
"total_protein": round(sum(item["protein"] for item in items), 1),
|
||||
"total_carbs": round(sum(item["carbs"] for item in items), 1),
|
||||
"total_fat": round(sum(item["fat"] for item in items), 1),
|
||||
"items": items,
|
||||
}
|
||||
# 2. 检查是否有未识别的食物(calories=0 且 estimated=False)
|
||||
has_unknown = any(
|
||||
item["calories"] == 0 and not item.get("estimated", False)
|
||||
for item in result["items"]
|
||||
)
|
||||
|
||||
return total
|
||||
# 3. 如果有未识别且启用 AI 回退,尝试调用大模型
|
||||
if has_unknown and use_ai_fallback:
|
||||
try:
|
||||
from . import database as db
|
||||
api_key = db.get_api_key("dashscope")
|
||||
if api_key:
|
||||
return _ai_estimate(description, api_key)
|
||||
except Exception:
|
||||
pass # 忽略导入或调用错误,返回本地结果
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def estimate_exercise_calories(exercise_type: str, duration_mins: int, weight_kg: float = 70) -> int:
|
||||
|
||||
@@ -470,10 +470,31 @@ def delete_weight(weight_id: int):
|
||||
|
||||
# ===== 用户配置 =====
|
||||
|
||||
def get_config() -> UserConfig:
|
||||
def migrate_config_add_user_id():
|
||||
"""迁移:为 config 表添加 user_id 字段,实现用户隔离"""
|
||||
with get_connection() as (conn, cursor):
|
||||
# 检查 user_id 列是否已存在
|
||||
cursor.execute("""
|
||||
SELECT COLUMN_NAME FROM INFORMATION_SCHEMA.COLUMNS
|
||||
WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = 'config' AND COLUMN_NAME = 'user_id'
|
||||
""")
|
||||
if cursor.fetchone():
|
||||
return # 已迁移
|
||||
|
||||
# 1. 添加 user_id 列,默认值为 1
|
||||
cursor.execute("ALTER TABLE config ADD COLUMN user_id INT NOT NULL DEFAULT 1")
|
||||
|
||||
# 2. 删除原主键
|
||||
cursor.execute("ALTER TABLE config DROP PRIMARY KEY")
|
||||
|
||||
# 3. 创建新的复合主键
|
||||
cursor.execute("ALTER TABLE config ADD PRIMARY KEY (user_id, `key`)")
|
||||
|
||||
|
||||
def get_config(user_id: int = 1) -> UserConfig:
|
||||
"""获取用户配置"""
|
||||
with get_connection() as (conn, cursor):
|
||||
cursor.execute("SELECT `key`, value FROM config")
|
||||
cursor.execute("SELECT `key`, value FROM config WHERE user_id = %s", (user_id,))
|
||||
rows = cursor.fetchall()
|
||||
|
||||
config_dict = {row["key"]: row["value"] for row in rows}
|
||||
@@ -488,7 +509,7 @@ def get_config() -> UserConfig:
|
||||
)
|
||||
|
||||
|
||||
def save_config(config: UserConfig):
|
||||
def save_config(user_id: int, config: UserConfig):
|
||||
"""保存用户配置"""
|
||||
with get_connection() as (conn, cursor):
|
||||
config_dict = {
|
||||
@@ -503,8 +524,9 @@ def save_config(config: UserConfig):
|
||||
for key, value in config_dict.items():
|
||||
if value is not None:
|
||||
cursor.execute("""
|
||||
REPLACE INTO config (`key`, value) VALUES (%s, %s)
|
||||
""", (key, value))
|
||||
INSERT INTO config (user_id, `key`, value) VALUES (%s, %s, %s)
|
||||
ON DUPLICATE KEY UPDATE value = VALUES(value)
|
||||
""", (user_id, key, value))
|
||||
|
||||
|
||||
# ===== 用户管理 =====
|
||||
|
||||
@@ -25,7 +25,7 @@ def _json_default(value):
|
||||
raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")
|
||||
|
||||
|
||||
def export_all_data_json(output_path: Path) -> None:
|
||||
def export_all_data_json(output_path: Path, user_id: int = 1) -> None:
|
||||
"""导出所有数据为 JSON"""
|
||||
data = {
|
||||
"version": "1.0",
|
||||
@@ -34,7 +34,7 @@ def export_all_data_json(output_path: Path) -> None:
|
||||
"meals": [m.to_dict() for m in db.get_meals()],
|
||||
"sleep": [s.to_dict() for s in db.get_sleep_records()],
|
||||
"weight": [w.to_dict() for w in db.get_weight_records()],
|
||||
"config": db.get_config().to_dict(),
|
||||
"config": db.get_config(user_id).to_dict(),
|
||||
}
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
output_path.write_text(json.dumps(data, ensure_ascii=False, indent=2, default=_json_default), encoding="utf-8")
|
||||
@@ -117,7 +117,7 @@ def export_to_csv(
|
||||
writer.writerow(row)
|
||||
|
||||
|
||||
def import_from_json(input_path: Path) -> dict[str, int]:
|
||||
def import_from_json(input_path: Path, user_id: int = 1) -> dict[str, int]:
|
||||
"""从 JSON 导入数据(最小实现:覆盖性导入,不做去重)"""
|
||||
data = json.loads(input_path.read_text(encoding="utf-8"))
|
||||
|
||||
@@ -128,6 +128,7 @@ def import_from_json(input_path: Path) -> dict[str, int]:
|
||||
if isinstance(config, dict):
|
||||
# 仅持久化可写字段
|
||||
db.save_config(
|
||||
user_id,
|
||||
UserConfig(
|
||||
age=config.get("age"),
|
||||
gender=config.get("gender"),
|
||||
|
||||
@@ -112,7 +112,7 @@ class MonthlyReport:
|
||||
WEEKDAY_NAMES = ["周一", "周二", "周三", "周四", "周五", "周六", "周日"]
|
||||
|
||||
|
||||
def generate_weekly_report(target_date: Optional[date] = None) -> WeeklyReport:
|
||||
def generate_weekly_report(target_date: Optional[date] = None, user_id: int = 1) -> WeeklyReport:
|
||||
"""生成周报"""
|
||||
if target_date is None:
|
||||
target_date = date.today()
|
||||
@@ -126,7 +126,7 @@ def generate_weekly_report(target_date: Optional[date] = None) -> WeeklyReport:
|
||||
meals = db.get_meals(start_date, end_date)
|
||||
sleep_records = db.get_sleep_records(start_date, end_date)
|
||||
weight_records = db.get_weight_records(start_date, end_date)
|
||||
config = db.get_config()
|
||||
config = db.get_config(user_id)
|
||||
|
||||
# 运动统计
|
||||
exercise_duration = sum(e.duration for e in exercises)
|
||||
|
||||
@@ -20,6 +20,7 @@ from ..core.auth import hash_password, verify_password, create_token, decode_tok
|
||||
# 初始化数据库
|
||||
db.init_db()
|
||||
db.migrate_auth_fields() # 迁移认证字段
|
||||
db.migrate_config_add_user_id() # 迁移 config 表添加 user_id
|
||||
|
||||
|
||||
app = FastAPI(
|
||||
@@ -887,7 +888,9 @@ async def settings_page():
|
||||
@app.get("/api/config", response_model=ConfigResponse)
|
||||
async def get_config():
|
||||
"""获取用户配置"""
|
||||
config = db.get_config()
|
||||
active_user = db.get_active_user()
|
||||
user_id = active_user.id if active_user else 1
|
||||
config = db.get_config(user_id)
|
||||
return ConfigResponse(
|
||||
age=config.age,
|
||||
gender=config.gender,
|
||||
@@ -909,7 +912,7 @@ async def get_today_summary():
|
||||
raise HTTPException(status_code=400, detail="没有激活的用户")
|
||||
|
||||
today = date.today()
|
||||
config = db.get_config()
|
||||
config = db.get_config(active_user.id)
|
||||
|
||||
# 获取今日数据
|
||||
exercises = db.get_exercises(start_date=today, end_date=today, user_id=active_user.id)
|
||||
@@ -1187,7 +1190,7 @@ async def add_exercise_api(data: ExerciseInput):
|
||||
raise HTTPException(status_code=400, detail="日期格式应为 YYYY-MM-DD") from exc
|
||||
|
||||
from ..core.calories import estimate_exercise_calories
|
||||
config = db.get_config()
|
||||
config = db.get_config(active_user.id)
|
||||
weight_kg = config.weight or 70
|
||||
calories = data.calories if data.calories is not None else estimate_exercise_calories(
|
||||
data.type, data.duration, weight_kg
|
||||
@@ -1579,7 +1582,9 @@ async def get_weight_records(
|
||||
@app.get("/api/weight/goal")
|
||||
async def get_weight_goal():
|
||||
"""获取目标体重(基于用户配置推断)"""
|
||||
config = db.get_config()
|
||||
active_user = db.get_active_user()
|
||||
user_id = active_user.id if active_user else 1
|
||||
config = db.get_config(user_id)
|
||||
if not config.weight:
|
||||
return {"goal_weight": None}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user