Pretor/pretor/utils/access.py

107 lines
3.7 KiB
Python

# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import jwt
import os
from datetime import datetime, timedelta, timezone
from typing import Optional
from fastapi import HTTPException, status, Request
from pydantic import BaseModel, ValidationError
from pretor.core.database.table.user import User
from pwdlib import PasswordHash
class TokenData(BaseModel):
    """Claims extracted from a decoded access token.

    Instances are built from the JWT payload returned by ``jwt.decode``;
    a payload that does not satisfy these field types is rejected by
    pydantic with a ``ValidationError``.
    """

    # Identifier of the authenticated user (written as a string when the token is issued).
    user_id: str
    # Display/login name of the user; optional in the payload.
    username: Optional[str] = None
    # Expiry as a Unix timestamp in seconds; optional at the model level.
    exp: Optional[int] = None
# Signing key for access tokens, read from the environment at import time.
SECRET_KEY = os.getenv("SECRET_KEY")
# HMAC-SHA256 is used for both signing and verification.
ALGORITHM = "HS256"
# Token lifetime: one day, expressed in minutes.
ACCESS_TOKEN_EXPIRE_MINUTES = 24 * 60

# Refuse to import the module when no key is configured (unset or empty)
# or when a well-known placeholder value is still in use.
if SECRET_KEY in (None, "", "secret", "114514"):
    raise RuntimeError("未提供有效的 SECRET_KEY 或使用了不安全的默认值")

# Shared hasher using pwdlib's recommended algorithm configuration.
password_hasher = PasswordHash.recommended()
class Accessor:
    """Stateless helpers for password hashing and JWT-based authentication."""

    # Lazily computed hash of a throwaway password. It is used to spend the
    # same verification cost when a login targets a non-existent user.
    _dummy_hash: Optional[str] = None

    @staticmethod
    def _unauthorized(detail: str) -> HTTPException:
        """Build a 401 error that carries the Bearer ``WWW-Authenticate`` challenge.

        Args:
            detail: Human-readable error message returned to the client.

        Returns:
            A fresh ``HTTPException``; a new headers dict is built on every call.
        """
        return HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=detail,
            # RFC 7235 requires a 401 response to include a challenge header.
            headers={"WWW-Authenticate": "Bearer"},
        )

    @staticmethod
    def _timing_dummy_hash() -> str:
        """Return (and cache on first use) a hash used only to equalise login timing."""
        if Accessor._dummy_hash is None:
            # Computed lazily so importing the module does not pay the hashing cost.
            Accessor._dummy_hash = password_hasher.hash("pretor-timing-equaliser")
        return Accessor._dummy_hash

    @staticmethod
    def _decode_token(token: str) -> TokenData:
        """Verify ``token`` and convert its payload into ``TokenData``.

        Args:
            token: Encoded JWT string taken from the Authorization header.

        Returns:
            The validated token claims.

        Raises:
            HTTPException: 401 if the token has expired, has a bad signature or
                format, lacks an ``exp`` claim, or its claims fail validation.
        """
        try:
            payload = jwt.decode(
                token,
                SECRET_KEY,
                algorithms=[ALGORITHM],
                # Every token issued by this module carries ``exp``. Rejecting
                # tokens without it keeps them from living forever.
                options={"require": ["exp"]},
            )
        except jwt.ExpiredSignatureError as exc:
            # Must be caught before InvalidTokenError, which is its base class.
            raise Accessor._unauthorized("Token 已过期") from exc
        except jwt.InvalidTokenError as exc:
            raise Accessor._unauthorized("无效的认证凭证") from exc
        try:
            return TokenData(**payload)
        except ValidationError as exc:
            raise Accessor._unauthorized("无效的认证凭证") from exc

    @staticmethod
    def _create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
        """Sign ``data`` as a JWT with an ``exp`` claim.

        Args:
            data: Claims to embed; the caller's dict is not modified.
            expires_delta: Token lifetime; defaults to
                ``ACCESS_TOKEN_EXPIRE_MINUTES`` when omitted.

        Returns:
            The encoded token.
        """
        to_encode = data.copy()
        lifetime = (
            expires_delta
            if expires_delta is not None
            else timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
        )
        expire = datetime.now(timezone.utc) + lifetime
        to_encode.update({"exp": int(expire.timestamp())})
        return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)

    @staticmethod
    def verify_password(plain_password: str, hashed_password: str) -> bool:
        """Return True when ``plain_password`` matches ``hashed_password``."""
        return password_hasher.verify(plain_password, hashed_password)

    @staticmethod
    def get_current_user(request: Request) -> TokenData:
        """Extract and validate the Bearer token from the request's Authorization header.

        Args:
            request: Incoming request.

        Returns:
            The decoded token claims.

        Raises:
            HTTPException: 401 if the header is missing, does not use the Bearer
                scheme, carries no token, or the token is invalid.
        """
        auth_header = request.headers.get("Authorization") or ""
        scheme, _, token = auth_header.partition(" ")
        # The auth scheme is case-insensitive (RFC 7235). Extra whitespace
        # around the credentials is tolerated.
        token = token.strip()
        if scheme.lower() != "bearer" or not token:
            raise Accessor._unauthorized("未提供认证头部")
        return Accessor._decode_token(token)

    @staticmethod
    def login_hashed_password(user: Optional[User], password: str) -> str:
        """Check ``password`` against ``user`` and issue an access token.

        Args:
            user: The looked-up user record, or a falsy value if none was found.
            password: Plain-text password supplied by the client.

        Returns:
            A signed access token containing ``user_id`` and ``username``.

        Raises:
            HTTPException: 401 with the same generic message whether the user is
                unknown or the password is wrong.
        """
        if not user:
            # Spend a comparable hashing cost and return the same message as
            # a wrong password. Otherwise response content or timing would
            # reveal which accounts exist.
            Accessor.verify_password(password, Accessor._timing_dummy_hash())
            raise Accessor._unauthorized("用户名或密码错误")
        if not Accessor.verify_password(password, user.hashed_password):
            raise Accessor._unauthorized("用户名或密码错误")
        token_payload = {
            "user_id": str(user.user_id),
            "username": user.user_name,
        }
        return Accessor._create_access_token(data=token_payload)

    @staticmethod
    def hash_password(password: str) -> str:
        """Hash ``password`` for storage after basic length validation.

        Raises:
            ValueError: If the password is empty or shorter than 6 characters.
        """
        if not password:
            raise ValueError("密码不能为空")
        if len(password) < 6:
            raise ValueError("密码长度不能小于 6 位")
        return password_hasher.hash(password)