50 lines
1.5 KiB
Python
50 lines
1.5 KiB
Python
from fastapi import Depends, HTTPException, status, Request
|
|
from sqlalchemy.orm import Session
|
|
from jose import JWTError
|
|
|
|
from app.core.database import get_db
|
|
from app.core.security import decode_token
|
|
from app.models.user import User
|
|
from app.services import user_service
|
|
|
|
|
|
def get_token_from_request(request: Request) -> str:
    """Extract the raw bearer token from an incoming request.

    The token is looked up first in the ``Authorization`` header
    (``Bearer <token>``) and, failing that, in the ``token`` query parameter.
    The token is only extracted here; it is not decoded or validated.

    Args:
        request: The incoming request.

    Returns:
        The non-empty token string.

    Raises:
        HTTPException: 401 when no non-empty token is found in either place.
    """
    # The auth scheme is case-insensitive (RFC 7235 §2.1), so accept
    # "bearer"/"BEARER" as well, and ignore whitespace around the credential.
    scheme, _, credentials = (
        request.headers.get("Authorization", "").strip().partition(" ")
    )
    if scheme.lower() == "bearer":
        header_token = credentials.strip()
        # An empty credential ("Bearer ") is treated as "no header token"
        # rather than returned as "", so the fallback below still applies.
        if header_token:
            return header_token

    # Fallback to the ``token`` query parameter "for some special cases".
    # Presumably this is for clients that cannot set headers; confirm with callers.
    # NOTE(review): tokens in URLs can leak via access logs, browser history
    # and Referer headers (RFC 6750 §2.3); keep this path as narrow as possible.
    query_token = request.query_params.get("token")
    if query_token:
        return query_token

    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="缺少认证令牌",
        headers={"WWW-Authenticate": "Bearer"},
    )
|
|
|
|
|
|
def get_current_user(
    request: Request,
    db: Session = Depends(get_db),
) -> User:
    """Resolve the user authenticated by the current request's token.

    Args:
        request: The incoming request; the token is read from it via
            ``get_token_from_request``.
        db: Database session supplied by the ``get_db`` dependency.

    Returns:
        The ``User`` whose id is stored in the token's ``sub`` claim.

    Raises:
        HTTPException: 401 if the token is missing, cannot be decoded, or has
            a missing / non-integer ``sub`` claim; 404 if no user has that id;
            403 if the user is not active.
    """
    # Single 401 response reused for every "bad token" failure mode below.
    credentials_error = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="无效的认证令牌",
        headers={"WWW-Authenticate": "Bearer"},
    )

    token = get_token_from_request(request)

    try:
        payload = decode_token(token)
    except JWTError:
        # NOTE(review): it is not visible here whether decode_token raises on
        # failure or returns a falsy value; both are mapped to the same 401.
        raise credentials_error from None
    if not payload:
        raise credentials_error

    # A missing ("sub" -> None) or non-numeric subject would otherwise escape
    # as TypeError/ValueError and surface as an HTTP 500.
    try:
        user_id = int(payload.get("sub"))
    except (TypeError, ValueError):
        raise credentials_error from None

    user = user_service.get_user_by_id(db, user_id)
    if not user:
        # NOTE(review): a valid token for a deleted user yields 404, not 401;
        # kept as-is for compatibility with existing clients.
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="用户不存在")

    if not user.is_active:
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="用户已被禁用")

    return user
|
|
|
|
|
|
def get_current_active_user(current_user: User = Depends(get_current_user)) -> User:
    """Dependency that yields the authenticated, active user.

    It performs no checks of its own. The active-status check (403) is already
    done by ``get_current_user``, which this dependency wraps. It exists so
    routes can declare an explicit "active user" requirement.

    Args:
        current_user: User resolved by the ``get_current_user`` dependency.

    Returns:
        The same ``User`` instance, unchanged.
    """
    active_user = current_user
    return active_user
|