diff --git a/src/vitals/core/database.py b/src/vitals/core/database.py index 25cd0f8..b85b3de 100644 --- a/src/vitals/core/database.py +++ b/src/vitals/core/database.py @@ -984,3 +984,107 @@ def delete_invite(invite_id: int): """删除邀请码""" with get_connection() as (conn, cursor): cursor.execute("DELETE FROM invites WHERE id = %s", (invite_id,)) + + +# ===== API Key 管理 ===== + +# API Key 与环境变量的映射 +API_KEY_ENV_MAP = { + "dashscope": "DASHSCOPE_API_KEY", + "deepseek": "DEEPSEEK_API_KEY", + "anthropic": "ANTHROPIC_API_KEY", +} + +# API Key 显示名称 +API_KEY_NAMES = { + "dashscope": "通义千问 (DashScope)", + "deepseek": "DeepSeek", + "anthropic": "Anthropic (Claude)", +} + + +def get_api_key(provider: str) -> Optional[str]: + """获取 API Key,数据库优先,环境变量备用""" + with get_connection() as (conn, cursor): + cursor.execute( + "SELECT value FROM config WHERE `key` = %s", + (f"api_key.{provider}",) + ) + row = cursor.fetchone() + if row and row["value"]: + return row["value"] + + # 回退到环境变量 + env_var = API_KEY_ENV_MAP.get(provider) + return os.environ.get(env_var) if env_var else None + + +def set_api_key(provider: str, value: str): + """保存 API Key 到数据库""" + if provider not in API_KEY_ENV_MAP: + raise ValueError(f"Unknown provider: {provider}") + + with get_connection() as (conn, cursor): + cursor.execute( + "REPLACE INTO config (`key`, value) VALUES (%s, %s)", + (f"api_key.{provider}", value) + ) + + +def delete_api_key(provider: str): + """从数据库删除 API Key(将回退到环境变量)""" + with get_connection() as (conn, cursor): + cursor.execute( + "DELETE FROM config WHERE `key` = %s", + (f"api_key.{provider}",) + ) + + +def get_all_api_keys(masked: bool = True) -> dict: + """获取所有 API Keys 状态 + + Args: + masked: 是否掩码显示值 + + Returns: + dict: {provider: {"name": 显示名, "value": 值或掩码, "source": "database"|"env"|None}} + """ + result = {} + + with get_connection() as (conn, cursor): + for provider, env_var in API_KEY_ENV_MAP.items(): + # 查数据库 + cursor.execute( + "SELECT value FROM config WHERE `key` = %s", + (f"api_key.{provider}",) + ) + row = cursor.fetchone() + + db_value = row["value"] if row and row["value"] else None + env_value = os.environ.get(env_var) + + # 确定来源和值 + if db_value: + source = "database" + value = db_value + elif env_value: + source = "env" + value = env_value + else: + source = None + value = None + + # 掩码处理 + if masked and value: + if len(value) > 8: + value = value[:4] + "*" * (len(value) - 8) + value[-4:] + else: + value = "*" * len(value) + + result[provider] = { + "name": API_KEY_NAMES.get(provider, provider), + "value": value, + "source": source, + } + + return result diff --git a/src/vitals/vision/analyzer.py b/src/vitals/vision/analyzer.py index c1b679f..c477173 100644 --- a/src/vitals/vision/analyzer.py +++ b/src/vitals/vision/analyzer.py @@ -31,8 +31,8 @@ class ClaudeFoodAnalyzer(FoodAnalyzer): """使用 Claude Vision API 的食物分析器""" def __init__(self, api_key: Optional[str] = None): - import os - self.api_key = api_key or os.environ.get("ANTHROPIC_API_KEY") + from ..core import database as db + self.api_key = api_key or db.get_api_key("anthropic") def analyze(self, image_path: Path) -> dict: """使用 Claude Vision 分析食物图片""" diff --git a/src/vitals/vision/providers/deepseek.py b/src/vitals/vision/providers/deepseek.py index 92d902b..0e0eb4b 100644 --- a/src/vitals/vision/providers/deepseek.py +++ b/src/vitals/vision/providers/deepseek.py @@ -1,20 +1,20 @@ """DeepSeek Vision API 适配器""" import base64 -import os from pathlib import Path from typing import Optional import httpx from ...core.calories import estimate_meal_calories +from ...core import database as db class DeepSeekVisionAnalyzer: """DeepSeek Vision 食物识别分析器""" def __init__(self, api_key: Optional[str] = None): - self.api_key = api_key or os.environ.get("DEEPSEEK_API_KEY") + self.api_key = api_key or db.get_api_key("deepseek") self.base_url = "https://api.deepseek.com/v1" def analyze_image(self, image_path: Path) -> dict: diff --git a/src/vitals/vision/providers/qwen.py b/src/vitals/vision/providers/qwen.py index 0813771..c2837b0 100644 --- a/src/vitals/vision/providers/qwen.py +++ b/src/vitals/vision/providers/qwen.py @@ -1,20 +1,20 @@ """Qwen VL (通义千问视觉) API 适配器""" import base64 -import os from pathlib import Path from typing import Optional import httpx from ...core.calories import estimate_meal_calories +from ...core import database as db class QwenVisionAnalyzer: """Qwen VL 食物识别分析器""" def __init__(self, api_key: Optional[str] = None): - self.api_key = api_key or os.environ.get("DASHSCOPE_API_KEY") + self.api_key = api_key or db.get_api_key("dashscope") # 阿里云百炼 OpenAI 兼容接口 self.base_url = "https://dashscope.aliyuncs.com/compatible-mode/v1" diff --git a/src/vitals/web/app.py b/src/vitals/web/app.py index 9fd759f..3f61434 100644 --- a/src/vitals/web/app.py +++ b/src/vitals/web/app.py @@ -726,6 +726,50 @@ async def admin_delete_invite(invite_id: int, admin: User = Depends(require_admi return {"message": "邀请码已删除"} +# ===== API 密钥管理(管理员) ===== + + +@app.get("/api/admin/api-keys") +async def admin_get_api_keys(admin: User = Depends(require_admin)): + """获取所有 API Keys 状态(掩码显示)""" + return db.get_all_api_keys(masked=True) + + +@app.get("/api/admin/api-keys/{provider}") +async def admin_get_api_key(provider: str, admin: User = Depends(require_admin)): + """获取指定 API Key 完整值""" + if provider not in db.API_KEY_ENV_MAP: + raise HTTPException(status_code=400, detail=f"未知的 provider: {provider}") + value = db.get_api_key(provider) + return { + "provider": provider, + "name": db.API_KEY_NAMES.get(provider, provider), + "value": value, + } + + +class ApiKeyInput(BaseModel): + value: str + + +@app.put("/api/admin/api-keys/{provider}") +async def admin_set_api_key(provider: str, data: ApiKeyInput, admin: User = Depends(require_admin)): + """设置/更新 API Key""" + if provider not in db.API_KEY_ENV_MAP: + raise HTTPException(status_code=400, detail=f"未知的 provider: {provider}") + db.set_api_key(provider, data.value) + return {"message": "API Key 已保存", "provider": provider} + + +@app.delete("/api/admin/api-keys/{provider}") +async def admin_delete_api_key(provider: str, admin: User = Depends(require_admin)): + """删除 API Key(回退到环境变量)""" + if provider not in db.API_KEY_ENV_MAP: + raise HTTPException(status_code=400, detail=f"未知的 provider: {provider}") + db.delete_api_key(provider) + return {"message": "API Key 已删除,将使用环境变量配置", "provider": provider} + + # ===== 页面路由 ===== @@ -8030,6 +8074,18 @@ def get_settings_page_html() -> str: + +
+