diff --git a/src/vitals/web/app.py b/src/vitals/web/app.py index 3f61434..41da6ee 100644 --- a/src/vitals/web/app.py +++ b/src/vitals/web/app.py @@ -448,9 +448,26 @@ def get_token_from_header(authorization: Optional[str] = Header(None)) -> Option return None -def get_current_user(authorization: Optional[str] = Header(None)) -> Optional[User]: +def get_token_from_request( + request: Request, + authorization: Optional[str] = Header(None), + auth_token: Optional[str] = Cookie(None) +) -> Optional[str]: + """从 Cookie 或 Header 中获取 Token(优先 Cookie)""" + # 优先从 Cookie 获取 + if auth_token: + return auth_token + # 其次从 Header 获取(向后兼容) + return get_token_from_header(authorization) + + +def get_current_user( + request: Request, + authorization: Optional[str] = Header(None), + auth_token: Optional[str] = Cookie(None) +) -> Optional[User]: """获取当前登录用户(可选,返回 None 表示未登录)""" - token = get_token_from_header(authorization) + token = get_token_from_request(request, authorization, auth_token) if not token: return None payload = decode_token(token) @@ -462,25 +479,37 @@ def get_current_user(authorization: Optional[str] = Header(None)) -> Optional[Us return user -def require_user(authorization: Optional[str] = Header(None)) -> User: +def require_user( + request: Request, + authorization: Optional[str] = Header(None), + auth_token: Optional[str] = Cookie(None) +) -> User: """要求用户登录(必须认证)""" - user = get_current_user(authorization) + user = get_current_user(request, authorization, auth_token) if not user: raise HTTPException(status_code=401, detail="未登录或登录已过期") return user -def require_admin(authorization: Optional[str] = Header(None)) -> User: +def require_admin( + request: Request, + authorization: Optional[str] = Header(None), + auth_token: Optional[str] = Cookie(None) +) -> User: """要求管理员权限""" - user = require_user(authorization) + user = require_user(request, authorization, auth_token) if not user.is_admin: raise HTTPException(status_code=403, detail="需要管理员权限") return user -def get_current_user_id(authorization: Optional[str] = Header(None)) -> int: +def get_current_user_id( + request: Request, + authorization: Optional[str] = Header(None), + auth_token: Optional[str] = Cookie(None) +) -> int: """获取当前用户 ID(兼容模式:未登录时使用活跃用户)""" - user = get_current_user(authorization) + user = get_current_user(request, authorization, auth_token) if user: return user.id # 兼容模式:未登录时使用当前活跃用户