From 81da2e9f812957efb8089e4de672223aeed46c4b Mon Sep 17 00:00:00 2001 From: zhaoxi Date: Tue, 21 Apr 2026 23:00:59 +0800 Subject: [PATCH] =?UTF-8?q?wip:=20=E5=AF=B9=E4=BA=8E=E7=94=A8=E6=88=B7?= =?UTF-8?q?=E6=A8=A1=E5=9E=8B=E8=BF=9B=E8=A1=8C=E4=BA=86=E4=BC=98=E5=8C=96?= =?UTF-8?q?=EF=BC=8C=E5=A2=9E=E5=8A=A0=E4=BA=86=E6=9D=83=E9=99=90=E6=A8=A1?= =?UTF-8?q?=E5=9E=8B=E5=92=8C=E6=9D=83=E9=99=90=E6=A0=A1=E9=AA=8C?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- docs/problem.md | 6 +- pretor/api/agent.py | 90 ++++++++++++++++++++++- pretor/core/database/module/individual.py | 62 +++++++++++++++- pretor/core/database/module/user.py | 10 ++- pretor/core/database/postgres.py | 23 +++++- pretor/core/database/table/user.py | 11 ++- pretor/utils/check_user/role_check.py | 40 ++++++++++ 7 files changed, 231 insertions(+), 11 deletions(-) create mode 100644 pretor/utils/check_user/role_check.py diff --git a/docs/problem.md b/docs/problem.md index 23f5e4b..81f8042 100644 --- a/docs/problem.md +++ b/docs/problem.md @@ -18,7 +18,7 @@ - 】~~使用fastapi-users完善用户系统~~(2026/4/19 fastapi-users会严重摧毁代码的优雅性) - [ ] 升级auth功能 - [x] /pretor/api的接口函数进行重构 -- [ ] /dockerfile待完善 +- [x] /dockerfile待完善 - [ ] 完善沙箱功能 - [ ] 完善爬虫功能 - [ ] 对接更多的provider @@ -30,7 +30,7 @@ - [x] /pretor/core/individual每个template进行优化 - [ ] /pretor/worker_individual待完善复合子个体和基础子个体 - [ ] /pretor/api待完善 -- [ ] /dockerfile待完善 +- [x] /dockerfile待完善 #### 2026/4/16 - [ ] 发布v0.1.0正式版 @@ -42,7 +42,7 @@ - [ ] 完善爬虫功能 - [ ] 对接更多的provider - [ ] 优化import -- [ ] 升级auth功能 +- [x] 升级auth功能 #### 2026/4/20 - [ ] 优化安全架构防止模型注入 diff --git a/pretor/api/agent.py b/pretor/api/agent.py index 37a0f31..ec84a7d 100644 --- a/pretor/api/agent.py +++ b/pretor/api/agent.py @@ -15,9 +15,12 @@ from typing import Union from pretor.utils.ray_hook import ray_actor_hook -from fastapi import APIRouter, Request, Depends +from fastapi import APIRouter, Depends from pydantic import BaseModel from pretor.utils.access import Accessor, TokenData +from pretor.core.database.table.individual import AgentType +from fastapi import HTTPException +from typing import Optional, List, Dict agent_router = APIRouter(prefix="/api/v1/agent", tags=["agent"]) @@ -50,4 +53,87 @@ async def load_agent(agent_register: Union[AgentRegister, AgentLocalRegister], node.create_agent.remote(global_state_machine,agent_register.provider_title,agent_register.model_id) case _: pass - return {"message": "创建成功"} \ No newline at end of file + return {"message": "创建成功"} + + +class WorkerIndividualCreate(BaseModel): + agent_name: str + agent_type: AgentType + description: str + provider_title: str + model_id: str + system_prompt: str + output_template: dict + bound_skill: Dict[str, List[str]] + workspace: List[str] + + +class WorkerIndividualUpdate(BaseModel): + agent_name: Optional[str] = None + agent_type: Optional[AgentType] = None + description: Optional[str] = None + provider_title: Optional[str] = None + model_id: Optional[str] = None + system_prompt: Optional[str] = None + output_template: Optional[dict] = None + bound_skill: Optional[Dict[str, List[str]]] = None + workspace: Optional[List[str]] = None + + +@agent_router.post("/worker") +async def create_worker_individual(worker_data: WorkerIndividualCreate, + token_data: TokenData = Depends(Accessor.get_current_user)): + postgres_database = ray_actor_hook("postgres_database") + data_dict = worker_data.model_dump() + data_dict["owner_id"] = token_data.user_id + worker = await postgres_database.add_worker_individual.remote(**data_dict) + return {"message": "success", "agent_id": worker.agent_id} + + +@agent_router.get("/worker") +async def get_worker_individual_list(token_data: TokenData = Depends(Accessor.get_current_user)): + postgres_database = ray_actor_hook("postgres_database") + workers = await postgres_database.get_worker_individual_list.remote(owner_id=token_data.user_id) + return {"workers": workers} + + +@agent_router.get("/worker/{agent_id}") +async def get_worker_individual(agent_id: str, + token_data: TokenData = Depends(Accessor.get_current_user)): + postgres_database = ray_actor_hook("postgres_database") + worker = await postgres_database.get_worker_individual.remote(agent_id=agent_id) + if not worker: + raise HTTPException(status_code=404, detail="Agent not found") + if worker.owner_id != token_data.user_id: + raise HTTPException(status_code=403, detail="Forbidden: You do not own this agent") + return worker + + +@agent_router.put("/worker/{agent_id}") +async def update_worker_individual(agent_id: str, + worker_data: WorkerIndividualUpdate, + token_data: TokenData = Depends(Accessor.get_current_user)): + postgres_database = ray_actor_hook("postgres_database") + worker = await postgres_database.get_worker_individual.remote(agent_id=agent_id) + if not worker: + raise HTTPException(status_code=404, detail="Agent not found") + if worker.owner_id != token_data.user_id: + raise HTTPException(status_code=403, detail="Forbidden: You do not own this agent") + + update_data = worker_data.model_dump(exclude_unset=True) + updated_worker = await postgres_database.update_worker_individual.remote(agent_id=agent_id, **update_data) + return {"message": "success", "worker": updated_worker} + + +@agent_router.delete("/worker/{agent_id}") +async def delete_worker_individual(agent_id: str, + token_data: TokenData = Depends(Accessor.get_current_user)): + postgres_database = ray_actor_hook("postgres_database") + worker = await postgres_database.get_worker_individual.remote(agent_id=agent_id) + if not worker: + raise HTTPException(status_code=404, detail="Agent not found") + if worker.owner_id != token_data.user_id: + raise HTTPException(status_code=403, detail="Forbidden: You do not own this agent") + + await postgres_database.delete_worker_individual.remote(agent_id=agent_id) + return {"message": "success"} \ No newline at end of file diff --git a/pretor/core/database/module/individual.py b/pretor/core/database/module/individual.py index 87f3d56..97cb80c 100644 --- a/pretor/core/database/module/individual.py +++ b/pretor/core/database/module/individual.py @@ -14,5 +14,63 @@ from pretor.core.database.table import WorkerIndividual from sqlmodel import select -from pretor.utils.error import UserNotExistError, UserPasswordError -from pretor.core.database.database_exception import database_exception \ No newline at end of file +from typing import List, Optional +from pretor.core.database.database_exception import database_exception + +from ulid import ULID + +class IndividualDatabase: + def __init__(self, async_session_maker): + self.async_session_maker = async_session_maker + + @database_exception + async def add_worker_individual(self, **kwargs) -> WorkerIndividual: + async with self.async_session_maker() as session: + agent_id = str(ULID()) + individual = WorkerIndividual(agent_id=agent_id, **kwargs) + session.add(individual) + await session.commit() + await session.refresh(individual) + return individual + + @database_exception + async def get_worker_individual(self, agent_id: str) -> Optional[WorkerIndividual]: + async with self.async_session_maker() as session: + statement = select(WorkerIndividual).where(WorkerIndividual.agent_id == agent_id) + results = await session.execute(statement) + return results.scalar_one_or_none() + + @database_exception + async def get_worker_individual_list(self, owner_id: str) -> List[WorkerIndividual]: + async with self.async_session_maker() as session: + statement = select(WorkerIndividual).where(WorkerIndividual.owner_id == owner_id) + results = await session.execute(statement) + return list(results.scalars().all()) + + @database_exception + async def update_worker_individual(self, agent_id: str, **kwargs) -> Optional[WorkerIndividual]: + async with self.async_session_maker() as session: + statement = select(WorkerIndividual).where(WorkerIndividual.agent_id == agent_id) + results = await session.execute(statement) + individual = results.scalar_one_or_none() + if not individual: + return None + for key, value in kwargs.items(): + if value is not None: + setattr(individual, key, value) + session.add(individual) + await session.commit() + await session.refresh(individual) + return individual + + @database_exception + async def delete_worker_individual(self, agent_id: str) -> bool: + async with self.async_session_maker() as session: + statement = select(WorkerIndividual).where(WorkerIndividual.agent_id == agent_id) + results = await session.execute(statement) + individual = results.scalar_one_or_none() + if not individual: + return False + session.delete(individual) + await session.commit() + return True diff --git a/pretor/core/database/module/user.py b/pretor/core/database/module/user.py index fa5d6d7..0fc9eea 100644 --- a/pretor/core/database/module/user.py +++ b/pretor/core/database/module/user.py @@ -16,7 +16,7 @@ from pretor.core.database.table import User from sqlmodel import select from pretor.utils.error import UserNotExistError, UserPasswordError from pretor.core.database.database_exception import database_exception - +from pretor.core.database.table.user import UserAuthority class AuthDatabase: def __init__(self, async_session_maker): self.async_session_maker = async_session_maker @@ -65,4 +65,10 @@ class AuthDatabase: user = results.scalar_one_or_none() if user is None: raise UserNotExistError() - return user \ No newline at end of file + return user + + @database_exception + async def get_user_authority(self, user_id: str) -> UserAuthority: + async with self.async_session_maker() as session: + user = await session.get(User, user_id) + return user.user_authority \ No newline at end of file diff --git a/pretor/core/database/postgres.py b/pretor/core/database/postgres.py index 5cd88a2..6b95f89 100644 --- a/pretor/core/database/postgres.py +++ b/pretor/core/database/postgres.py @@ -19,6 +19,7 @@ from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession from sqlalchemy.orm import sessionmaker from sqlmodel import SQLModel +from pretor.core.database.module.individual import IndividualDatabase from pretor.core.database.module.user import AuthDatabase from pretor.core.database.module.provider import ProviderDatabase @@ -36,6 +37,7 @@ class PostgresDatabase: self.auth_database = AuthDatabase(self.async_session_maker) self.provider_database = ProviderDatabase(self.async_session_maker) + self.individual_database = IndividualDatabase(self.async_session_maker) async def init_db(self) -> None: async with self.async_engine.begin() as conn: @@ -59,4 +61,23 @@ class PostgresDatabase: return await self.auth_database.delete_user(**kwargs) async def login_user(self, **kwargs): - return await self.auth_database.login_user(**kwargs) \ No newline at end of file + return await self.auth_database.login_user(**kwargs) + + async def get_user_authority(self, **kwargs): + return await self.auth_database.get_user_authority(**kwargs) + + ##individual_database 操作 + async def add_worker_individual(self, **kwargs): + return await self.individual_database.add_worker_individual(**kwargs) + + async def get_worker_individual(self, agent_id: str): + return await self.individual_database.get_worker_individual(agent_id) + + async def get_worker_individual_list(self, owner_id: str): + return await self.individual_database.get_worker_individual_list(owner_id) + + async def update_worker_individual(self, agent_id: str, **kwargs): + return await self.individual_database.update_worker_individual(agent_id, **kwargs) + + async def delete_worker_individual(self, agent_id: str): + return await self.individual_database.delete_worker_individual(agent_id) \ No newline at end of file diff --git a/pretor/core/database/table/user.py b/pretor/core/database/table/user.py index dff6562..28c2459 100644 --- a/pretor/core/database/table/user.py +++ b/pretor/core/database/table/user.py @@ -13,10 +13,19 @@ # limitations under the License. from sqlmodel import SQLModel, Field +from enum import IntEnum +class UserAuthority(IntEnum): + SUPER_ADMINISTRATOR = 100 + ADMINISTRATOR = 50 + USER = 20 + UNAUTHORIZED_USER = 10 + GUEST = 0 class User(SQLModel): __tablename__ = 'user' - user_id: int = Field(default=None, primary_key=True) + user_id: str = Field(primary_key=True) user_name: str = Field(index=True) hashed_password: str + user_authority: UserAuthority = Field(default=UserAuthority.USER) + diff --git a/pretor/utils/check_user/role_check.py b/pretor/utils/check_user/role_check.py new file mode 100644 index 0000000..60c555a --- /dev/null +++ b/pretor/utils/check_user/role_check.py @@ -0,0 +1,40 @@ +# Copyright 2026 zhaoxi826 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from functools import lru_cache +from typing import Annotated +from fastapi import Depends, HTTPException +from pretor.utils.access import Accessor, TokenData +from pretor.core.database.table.user import UserAuthority +from pretor.utils.ray_hook import ray_actor_hook + +@lru_cache +async def get_authority(user_id: str) -> UserAuthority: + postgres_database = ray_actor_hook("postgres_database") + user_authority = await postgres_database.get_user_authority.remote(user_id=user_id) + return user_authority + +class RoleChecker: + def __init__(self, **kwargs): + self.allowed_roles = kwargs.get("allowed_roles", ) + + async def __call__(self, + token_data: Annotated[TokenData, Depends(Accessor.get_current_user)]): + user_authority = await get_authority(token_data.user_id) + if user_authority < self.allowed_roles: + raise HTTPException( + status_code=403, + detail={"message": f"User {token_data.user_id} does not have allowed roles"}, + ) + return token_data +