# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""HS256 JWT access-token issuing/verification and bcrypt password hashing."""

import os
from datetime import datetime, timedelta, timezone
from typing import Optional

import jwt
from fastapi import HTTPException, Request, status
from passlib.context import CryptContext
from pydantic import BaseModel, ValidationError

from pretor.core.database.table.user import User


class TokenData(BaseModel):
    """Claims carried inside an access token issued by ``Accessor``."""

    user_id: str
    username: Optional[str] = None
    # Expiry as a Unix timestamp in seconds (written by _create_access_token).
    exp: Optional[int] = None


# HMAC signing key, read once from the environment at import time.
SECRET_KEY = os.getenv("SECRET_KEY")
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 60 * 24  # default token validity: 1 day

pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")


class Accessor:
    """Stateless authentication helpers: token issue/verify and password hashing."""

    @staticmethod
    def _secret_key() -> str:
        """Return the configured signing key.

        Raises:
            RuntimeError: if ``SECRET_KEY`` is unset or empty. A missing or
                empty key must never be used to sign or verify tokens, and
                failing here gives a clearer error than letting PyJWT fail.
        """
        if not SECRET_KEY:
            raise RuntimeError("SECRET_KEY environment variable is not set")
        return SECRET_KEY

    @staticmethod
    def _unauthorized(detail: str) -> HTTPException:
        """Build a 401 error that carries the ``WWW-Authenticate: Bearer`` challenge."""
        return HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=detail,
            # A fresh dict per exception so no response shares mutable headers.
            headers={"WWW-Authenticate": "Bearer"},
        )

    @staticmethod
    def _decode_token(token: str) -> TokenData:
        """Verify ``token``'s signature and expiry and return its claims.

        Raises:
            HTTPException: 401 if the token is expired, malformed, wrongly
                signed, missing an ``exp`` claim, or its claims do not
                validate against ``TokenData``.
            RuntimeError: if ``SECRET_KEY`` is not configured.
        """
        key = Accessor._secret_key()
        try:
            payload = jwt.decode(
                token,
                key,
                algorithms=[ALGORITHM],
                # Every token we issue carries "exp"; refuse tokens without one
                # so that no token can be valid forever.
                options={"require": ["exp"]},
            )
            return TokenData(**payload)
        except jwt.ExpiredSignatureError as exc:
            # Must precede InvalidTokenError, of which it is a subclass.
            raise Accessor._unauthorized("Token 已过期") from exc
        except (jwt.InvalidTokenError, ValidationError) as exc:
            raise Accessor._unauthorized("无效的认证凭证") from exc

    @staticmethod
    def _create_access_token(
        data: dict, expires_delta: Optional[timedelta] = None
    ) -> str:
        """Sign ``data`` as an HS256 JWT with an ``exp`` claim added.

        Args:
            data: Claims to embed. The dict is not mutated; any ``exp`` key in
                it is overwritten.
            expires_delta: Token lifetime. Defaults to
                ``ACCESS_TOKEN_EXPIRE_MINUTES``.

        Returns:
            The encoded token string.
        """
        if expires_delta is None:
            expires_delta = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
        expire = datetime.now(timezone.utc) + expires_delta
        to_encode = {**data, "exp": int(expire.timestamp())}
        return jwt.encode(to_encode, Accessor._secret_key(), algorithm=ALGORITHM)

    @staticmethod
    def _verify_password(plain_password: str, hashed_password: str) -> bool:
        """Return True if ``plain_password`` matches ``hashed_password``."""
        return pwd_context.verify(plain_password, hashed_password)

    @staticmethod
    def get_current_user(request: Request) -> TokenData:
        """Extract and verify the bearer token from the ``Authorization`` header.

        Raises:
            HTTPException: 401 if the header is missing, does not use the
                Bearer scheme, or carries an invalid or expired token.
        """
        auth_header = request.headers.get("Authorization") or ""
        scheme, sep, token = auth_header.partition(" ")
        # Authentication scheme names are case-insensitive (RFC 7235 §2.1).
        if not sep or scheme.lower() != "bearer":
            raise Accessor._unauthorized("未提供认证头部")
        # Keep the whole credential: a token with embedded spaces is invalid
        # and must be rejected, not silently truncated to its first word.
        return Accessor._decode_token(token.strip())

    @staticmethod
    def login_hashed_password(user: Optional[User], password: str) -> str:
        """Check ``password`` against ``user`` and issue an access token.

        Args:
            user: The looked-up user, or ``None`` if no such user exists.
            password: The plaintext password supplied by the client.

        Returns:
            A signed access token carrying ``user_id`` and ``username``.

        Raises:
            HTTPException: 401 with the same message whether the user is
                unknown or the password is wrong.
        """
        if not user:
            # Spend roughly one hash verification, and answer exactly as for a
            # wrong password, so callers cannot probe which usernames exist.
            pwd_context.dummy_verify()
            raise Accessor._unauthorized("用户名或密码错误")
        if not Accessor._verify_password(password, user.hashed_password):
            raise Accessor._unauthorized("用户名或密码错误")
        token_payload = {
            "user_id": str(user.id),  # TokenData.user_id is a str
            "username": user.username,
        }
        return Accessor._create_access_token(data=token_payload)

    @staticmethod
    def hash_password(password: str) -> str:
        """Hash ``password`` with bcrypt after basic validation.

        Raises:
            ValueError: if the password is empty or shorter than 6 characters.
        """
        if not password:
            raise ValueError("密码不能为空")
        if len(password) < 6:
            raise ValueError("密码长度不能小于 6 位")
        return pwd_context.hash(password)