diff --git a/pretor/core/api/__init__.py b/pretor/core/api/__init__.py index 84a5c30..25a0a35 100644 --- a/pretor/core/api/__init__.py +++ b/pretor/core/api/__init__.py @@ -46,74 +46,70 @@ app = FastAPI() class PretorGateway: gateway: Dict[str, WebSocket] def __init__(self): - self.app = app self.gateway = {} - self.app.include_router(client_router)#客户端路径 + app.include_router(client_router)#客户端路径 - self.app.include_router(auth_router)#用户路径 - self.app.include_router(provider_router)#供应商路径 - self.app.include_router(resource_router)#资源路径 - self.app.include_router(cluster_router)#集群信息路径 - self.app.include_router(agent_router)#agent路径 + app.include_router(auth_router)#用户路径 + app.include_router(provider_router)#供应商路径 + app.include_router(resource_router)#资源路径 + app.include_router(cluster_router)#集群信息路径 + app.include_router(agent_router)#agent路径 - @self.app.exception_handler(UserNotExistError) + @app.exception_handler(UserNotExistError) async def user_not_exist_handler(request: Request, exc: UserNotExistError): return JSONResponse(status_code=404, content={"message": "用户不存在"}) - @self.app.exception_handler(UserPasswordError) + @app.exception_handler(UserPasswordError) async def user_password_handler(request: Request, exc: UserPasswordError): return JSONResponse(status_code=401, content={"message": "密码错误"}) - @self.app.exception_handler(UserError) + @app.exception_handler(UserError) async def user_error_handler(request: Request, exc: UserError): return JSONResponse(status_code=400, content={"message": "用户相关错误"}) - @self.app.exception_handler(ProviderNotExistError) + @app.exception_handler(ProviderNotExistError) async def provider_not_exist_handler(request: Request, exc: ProviderNotExistError): return JSONResponse(status_code=404, content={"message": "服务提供商不存在"}) - @self.app.exception_handler(ProviderError) + @app.exception_handler(ProviderError) async def provider_error_handler(request: Request, exc: ProviderError): return JSONResponse(status_code=400, content={"message": "服务提供商错误"}) - @self.app.exception_handler(ModelNotExistError) + @app.exception_handler(ModelNotExistError) async def model_not_exist_handler(request: Request, exc: ModelNotExistError): return JSONResponse(status_code=404, content={"message": "模型不存在"}) - @self.app.exception_handler(DemandError) + @app.exception_handler(DemandError) async def demand_error_handler(request: Request, exc: DemandError): return JSONResponse(status_code=400, content={"message": "需求格式错误或不满足"}) - @self.app.exception_handler(WorkflowExit) + @app.exception_handler(WorkflowExit) async def workflow_exit_handler(request: Request, exc: WorkflowExit): return JSONResponse(status_code=400, content={"message": "工作流已退出"}) - @self.app.exception_handler(WorkflowError) + @app.exception_handler(WorkflowError) async def workflow_error_handler(request: Request, exc: WorkflowError): return JSONResponse(status_code=500, content={"message": "工作流执行错误"}) frontend_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))), "frontend", "dist") if os.path.exists(frontend_dir): - self.app.mount("/assets", StaticFiles(directory=os.path.join(frontend_dir, "assets")), name="assets") + app.mount("/assets", StaticFiles(directory=os.path.join(frontend_dir, "assets")), name="assets") # Serve favicon and other top-level static files if they exist - @self.app.get("/favicon.svg", include_in_schema=False) + @app.get("/favicon.svg", include_in_schema=False) async def serve_favicon(): return FileResponse(os.path.join(frontend_dir, "favicon.svg")) - @self.app.get("/icons.svg", include_in_schema=False) + @app.get("/icons.svg", include_in_schema=False) async def serve_icons(): return FileResponse(os.path.join(frontend_dir, "icons.svg")) - @self.app.get("/{full_path:path}", include_in_schema=False) + @app.get("/{full_path:path}", include_in_schema=False) async def serve_frontend(full_path: str): # If a path isn't API or assets, fallback to index.html for React Router / SPA handling # In this specific case, it also fixes any root path reloading issues return FileResponse(os.path.join(frontend_dir, "index.html")) - async def __call__(self, request: Request): - return await self.app(request.scope, request.receive, request._send) - diff --git a/pretor/core/database/module/user.py b/pretor/core/database/module/user.py index 7783fb0..0ac0ed9 100644 --- a/pretor/core/database/module/user.py +++ b/pretor/core/database/module/user.py @@ -24,8 +24,22 @@ class AuthDatabase: @database_exception async def add_user(self, user_name: str, hashed_password: str) -> User: from ulid import ULID - user = User(user_id=str(ULID()), user_name=user_name, hashed_password=hashed_password) async with self.async_session_maker() as session: + # Check if any users exist + statement = select(User).limit(1) + results = await session.execute(statement) + existing_user = results.first() + + authority = UserAuthority.USER + if existing_user is None: + authority = UserAuthority.SUPER_ADMINISTRATOR + + user = User( + user_id=str(ULID()), + user_name=user_name, + hashed_password=hashed_password, + user_authority=authority + ) session.add(user) await session.commit() await session.refresh(user) diff --git a/pretor/core/database/postgres.py b/pretor/core/database/postgres.py index d550950..ec10379 100644 --- a/pretor/core/database/postgres.py +++ b/pretor/core/database/postgres.py @@ -40,8 +40,13 @@ class PostgresDatabase: self._individual_database = IndividualDatabase(self.async_session_maker) async def init_db(self) -> None: - async with self.async_engine.begin() as conn: - await conn.run_sync(SQLModel.metadata.create_all) + try: + async with self.async_engine.begin() as conn: + await conn.run_sync(SQLModel.metadata.create_all) + except Exception as e: + # Provide a warning if the database is not accessible, allowing + # the app to start up for development/UI tests without crashing immediately. + print(f"Warning: Failed to initialize PostgreSQL database: {e}") async def auth_database(self, method_name: str, *args, **kwargs): method = getattr(self._auth_database, method_name) diff --git a/tests/core/database/module/user_test.py b/tests/core/database/module/user_test.py index f526ad8..6deba1d 100644 --- a/tests/core/database/module/user_test.py +++ b/tests/core/database/module/user_test.py @@ -34,6 +34,10 @@ async def test_add_user(mock_session_maker, mock_dependencies): mock_user.hashed_password = "hashedpw" mock_user_cls.return_value = mock_user + mock_exec_result = MagicMock() + mock_exec_result.first.return_value = None + session.execute = AsyncMock(return_value=mock_exec_result) + user = await db.add_user("testuser", "hashedpw") assert user.user_name == "testuser" @@ -58,11 +62,11 @@ async def test_change_password_success(mock_session_maker, mock_dependencies): mock_exec_result = MagicMock() mock_exec_result.scalar_one_or_none.return_value = mock_user - session.exec = AsyncMock(return_value=mock_exec_result) + session.execute = AsyncMock(return_value=mock_exec_result) user = await db.change_password("testuser", "old_password", "new_password") - session.exec.assert_called_once_with(mock_statement) + session.execute.assert_called_once_with(mock_statement) assert user.hashed_password == "new_password" session.add.assert_called_once_with(mock_user) session.commit.assert_called_once() @@ -78,7 +82,7 @@ async def test_change_password_user_not_exist(mock_session_maker, mock_dependenc mock_exec_result = MagicMock() mock_exec_result.scalar_one_or_none.return_value = None - session.exec = AsyncMock(return_value=mock_exec_result) + session.execute = AsyncMock(return_value=mock_exec_result) result = await db.change_password("testuser", "old_password", "new_password") assert result is None @@ -95,7 +99,7 @@ async def test_change_password_wrong_password(mock_session_maker, mock_dependenc mock_user.hashed_password = "actual_password" mock_exec_result = MagicMock() mock_exec_result.scalar_one_or_none.return_value = mock_user - session.exec = AsyncMock(return_value=mock_exec_result) + session.execute = AsyncMock(return_value=mock_exec_result) from pretor.utils.error import UserPasswordError with pytest.raises(UserPasswordError): @@ -115,10 +119,10 @@ async def test_delete_user_success(mock_session_maker, mock_dependencies): mock_user = MagicMock() mock_exec_result = MagicMock() mock_exec_result.scalar_one_or_none.return_value = mock_user - session.exec = AsyncMock(return_value=mock_exec_result) + session.execute = AsyncMock(return_value=mock_exec_result) await db.delete_user("testuser") - session.exec.assert_called_once_with(mock_statement) + session.execute.assert_called_once_with(mock_statement) session.delete.assert_called_once_with(mock_user) session.commit.assert_called_once() @@ -132,7 +136,7 @@ async def test_delete_user_not_exist(mock_session_maker, mock_dependencies): mock_exec_result = MagicMock() mock_exec_result.scalar_one_or_none.return_value = None - session.exec = AsyncMock(return_value=mock_exec_result) + session.execute = AsyncMock(return_value=mock_exec_result) result = await db.delete_user("testuser") assert result is None @@ -151,10 +155,10 @@ async def test_login_user_success(mock_session_maker, mock_dependencies): mock_user = MagicMock() mock_exec_result = MagicMock() mock_exec_result.scalar_one_or_none.return_value = mock_user - session.exec = AsyncMock(return_value=mock_exec_result) + session.execute = AsyncMock(return_value=mock_exec_result) user = await db.login_user("testuser") - session.exec.assert_called_once_with(mock_statement) + session.execute.assert_called_once_with(mock_statement) assert user == mock_user @@ -167,7 +171,7 @@ async def test_login_user_not_exist(mock_session_maker, mock_dependencies): mock_exec_result = MagicMock() mock_exec_result.scalar_one_or_none.return_value = None - session.exec = AsyncMock(return_value=mock_exec_result) + session.execute = AsyncMock(return_value=mock_exec_result) result = await db.login_user("testuser") assert result is None