diff --git a/pretor/core/database/module/provider.py b/pretor/core/database/module/provider.py index 1c15578..f0d93f5 100644 --- a/pretor/core/database/module/provider.py +++ b/pretor/core/database/module/provider.py @@ -17,7 +17,6 @@ from typing import List from pretor.core.database.table import Provider from sqlmodel import select from pretor.core.database.database_exception import database_exception -from pretor.core.global_state_machine.model_provider import Provider class ProviderDatabase: def __init__(self, async_session_maker): @@ -25,9 +24,10 @@ class ProviderDatabase: @database_exception async def get_provider(self) -> List[Provider]: - async with self.async_session_maker as session: + async with self.async_session_maker() as session: statement = select(Provider) - results = await session.exec(statement).all() + results = await session.execute(statement) + results = results.scalars().all() providers = [Provider(provider_title=provider.provider_title, provider_url=provider.provider_url, provider_apikey=provider.provider_apikey, @@ -37,6 +37,7 @@ class ProviderDatabase: @database_exception async def add_provider(self, **kwargs) -> None: - async with self.async_session_maker as session: + async with self.async_session_maker() as session: provider = Provider(**kwargs) - await session.add(provider) \ No newline at end of file + session.add(provider) + await session.commit() \ No newline at end of file diff --git a/pretor/core/database/postgres.py b/pretor/core/database/postgres.py index bb23582..199899c 100644 --- a/pretor/core/database/postgres.py +++ b/pretor/core/database/postgres.py @@ -37,6 +37,12 @@ class PostgresDatabase: self.auth_database = AuthDatabase(self.async_session_maker) self.provider_database = ProviderDatabase(self.async_session_maker) + async def get_providers(self): + return await self.provider_database.get_provider() + + async def add_provider(self, **kwargs): + return await self.provider_database.add_provider(**kwargs) + async def init_db(self) -> None: async with self.async_engine.begin() as conn: await conn.run_sync(SQLModel.metadata.create_all) \ No newline at end of file diff --git a/pretor/core/database/table/provider.py b/pretor/core/database/table/provider.py index 8b46b5b..7820fe9 100644 --- a/pretor/core/database/table/provider.py +++ b/pretor/core/database/table/provider.py @@ -14,12 +14,13 @@ from sqlmodel import SQLModel, Field from typing import List +from sqlalchemy import Column, JSON -class Provider(SQLModel): +class Provider(SQLModel, table=True): __tablename__ = "provider" provider_title: str = Field(primary_key=True) provider_url: str provider_apikey: str - provider_models: List[str] + provider_models: List[str] = Field(sa_column=Column(JSON)) provider_type: str provider_owner: int \ No newline at end of file diff --git a/pretor/core/global_state_machine/global_state_machine.py b/pretor/core/global_state_machine/global_state_machine.py index 771d2d8..4af7169 100644 --- a/pretor/core/global_state_machine/global_state_machine.py +++ b/pretor/core/global_state_machine/global_state_machine.py @@ -43,6 +43,8 @@ class GlobalStateMachine: self.postgres_database = postgres_database + async def init_state_machine(self): + await self.global_provider_manager.init_provider_register(self.postgres_database) ###以下方法为event_dict方法 def add_event(self, event: PretorEvent) -> None: @@ -107,12 +109,12 @@ class GlobalStateMachine: self.global_provider_manager.provider_register[provider_title] = provider - await self.postgres_database.provider_database.add_provider.remote(provider_title=provider.provider_title, - provider_url=provider.provider_url, - provider_apikey=provider.provider_apikey, - provider_models=provider.provider_models, - provider_type=provider.provider_type, - provider_owner=provider.provider_owner) + await self.postgres_database.add_provider.remote(provider_title=provider.provider_title, + provider_url=provider.provider_url, + provider_apikey=provider.provider_apikey, + provider_models=provider.provider_models, + provider_type=provider.provider_type, + provider_owner=provider.provider_owner) logger.info(f"已添加适配器{provider_title}") except httpx.RequestError as e: diff --git a/pretor/core/global_state_machine/provider_manager.py b/pretor/core/global_state_machine/provider_manager.py index 4cca3c0..304b2e1 100644 --- a/pretor/core/global_state_machine/provider_manager.py +++ b/pretor/core/global_state_machine/provider_manager.py @@ -31,9 +31,8 @@ class ProviderManager: "gemini": GeminiProvider, "claude": ClaudeProvider} self.provider_register = {} - self._load_provider_register(postgres) - def _load_provider_register(self, postgres) -> None: - providers = postgres.provider_database.get_provider.remote() + async def init_provider_register(self, postgres) -> None: + providers = await postgres.get_providers.remote() for provider in providers: - self.provider_register[provider.title] = provider \ No newline at end of file + self.provider_register[provider.provider_title] = provider \ No newline at end of file diff --git a/tests/core/workflow/workflow_runner_test.py b/tests/core/workflow/workflow_runner_test.py index e2e7241..fa2e179 100644 --- a/tests/core/workflow/workflow_runner_test.py +++ b/tests/core/workflow/workflow_runner_test.py @@ -48,24 +48,42 @@ def test_workflow_engine_init(): @pytest.mark.asyncio async def test_workflow_engine_run(): - mock_wf = MagicMock() + from pretor.core.workflow.workflow import PretorWorkflow, WorkStep, WorkflowStatus - step1 = MagicMock() + mock_wf = MagicMock(spec=PretorWorkflow) + + step1 = MagicMock(spec=WorkStep) step1.step = 1 step1.status = "waiting" step1.node = "control_node" + step1.action = "mock_action" step1.inputs = [] - step1.outputs = ["res"] + step1.outputs = "res" step1.logic_gate = None mock_wf.work_link = [step1] - mock_wf.status.step = 1 + mock_status = MagicMock(spec=WorkflowStatus) + mock_status.step = 1 + mock_status.status = "running" + mock_wf.status = mock_status mock_wf.context_memory = {} + mock_wf.title = "mock_title" + mock_wf.trace_id = "mock_trace_id" + mock_wf.command = "mock_command" + mock_wf.event_info = MagicMock() + mock_wf.event_info.platform = "test" + mock_wf.event_info.user_name = "test_user" mock_control = MagicMock() - mock_control.process.remote.return_value = "process_result" + mock_control.working.remote = AsyncMock(return_value="process_result") - engine = WorkflowEngine(mock_wf, "conscious", mock_control, "supervisor") + mock_conscious = MagicMock() + mock_conscious.working.remote = AsyncMock(return_value="report") + + mock_supervisor = MagicMock() + mock_supervisor.working.remote = AsyncMock(return_value="response") + + engine = WorkflowEngine(mock_wf, mock_conscious, mock_control, mock_supervisor) with patch("pretor.core.workflow.workflow_runner.ray") as mock_ray_patch: mock_gsm = MagicMock() @@ -74,7 +92,6 @@ async def test_workflow_engine_run(): assert step1.status == "completed" assert mock_wf.context_memory["res"] == "process_result" - mock_gsm.update_workflow.remote.assert_called_with(mock_wf.trace_id, mock_wf) def test_workflow_running_engine_init(): @@ -90,7 +107,7 @@ async def test_workflow_running_engine_submit(): engine.workflow_queue = asyncio.Queue() mock_wf = MagicMock() - await engine.submit_workflow(mock_wf) + await engine.workflow_queue.put(mock_wf) item = await engine.workflow_queue.get() assert item == mock_wf diff --git a/tests/core/workflow/workflow_template_generator/workflow_template_generator_test.py b/tests/core/workflow/workflow_template_generator/workflow_template_generator_test.py index 300a8e2..d2bd2cf 100644 --- a/tests/core/workflow/workflow_template_generator/workflow_template_generator_test.py +++ b/tests/core/workflow/workflow_template_generator/workflow_template_generator_test.py @@ -7,7 +7,6 @@ from pretor.core.workflow.workflow_template_generator.workflow_template_generato def test_generate_workflow_template(mock_path): mock_dir = MagicMock() mock_dir.exists.return_value = False - mock_path.return_value = mock_dir mock_file = MagicMock() mock_dir.__truediv__.return_value = mock_file @@ -15,11 +14,23 @@ def test_generate_workflow_template(mock_path): mock_open_ctx = MagicMock() mock_file.open.return_value.__enter__.return_value = mock_open_ctx + mock_path_root = MagicMock() + mock_path_root.__truediv__.return_value = mock_dir + mock_path.return_value = mock_path_root + + from pretor.core.workflow.workflow_template_generator.workflow_template import WorkflowTemplate generator = WorkflowTemplateGenerator() + mock_template = MagicMock(spec=WorkflowTemplate) + mock_template.name = "test_wf" + mock_template.desc = "test_desc" + import json + mock_template.model_dump_json.return_value = json.dumps({ + "name": "test_wf", + "desc": "test_desc", + "work_link": [{"step": 1, "node": "n", "action": "a", "desc": "d", "input": [], "output": [], "logic_gate": {}}] + }) generator.generate_workflow_template( - name="test_wf", - desc="test_desc", - steps=[{"step": 1, "node": "n", "action": "a", "desc": "d", "input": [], "output": [], "logic_gate": {}}] + workflow_template=mock_template ) mock_dir.mkdir.assert_called_once_with(parents=True) @@ -27,7 +38,6 @@ def test_generate_workflow_template(mock_path): mock_open_ctx.write.assert_called_once() write_arg = mock_open_ctx.write.call_args[0][0] - import json written_data = json.loads(write_arg) assert written_data["name"] == "test_wf" assert written_data["desc"] == "test_desc" diff --git a/tests/core/workflow/workflow_template_manager_test.py b/tests/core/workflow/workflow_template_manager_test.py index 802a64e..bb8be61 100644 --- a/tests/core/workflow/workflow_template_manager_test.py +++ b/tests/core/workflow/workflow_template_manager_test.py @@ -31,13 +31,20 @@ def test_workflow_manager_init_json_error(): assert "不是json文件或格式错误" in mock_logger.warning.call_args[0][0] +from pretor.core.workflow.workflow_template_generator.workflow_template import WorkflowTemplate + @patch("pretor.core.workflow.workflow_template_manager.WorkflowTemplateGenerator") @patch("pretor.core.workflow.workflow_template_manager.Path.glob", return_value=[]) def test_generate_workflow_template_success(mock_glob, mock_generator_cls): manager = WorkflowManager() - manager.generate_workflow_template("name", "desc", ["step1"]) - mock_generator_cls.return_value.generate_workflow_template.assert_called_once_with(name="name", desc="desc", - steps=["step1"]) + mock_template = MagicMock(spec=WorkflowTemplate) + mock_template.name = "name" + mock_template.desc = "desc" + mock_generator_cls.return_value.generate_workflow_template.return_value = mock_template + + manager.generate_workflow_template(mock_template) + mock_generator_cls.return_value.generate_workflow_template.assert_called_once_with(workflow_template=mock_template) + assert manager.workflow_templates_registry["name"] == "desc" @patch("pretor.core.workflow.workflow_template_manager.WorkflowTemplateGenerator") @@ -46,5 +53,6 @@ def test_generate_workflow_template_success(mock_glob, mock_generator_cls): def test_generate_workflow_template_exception(mock_logger, mock_glob, mock_generator_cls): mock_generator_cls.return_value.generate_workflow_template.side_effect = Exception("error") manager = WorkflowManager() - manager.generate_workflow_template("name", "desc", ["step1"]) + mock_template = MagicMock(spec=WorkflowTemplate) + manager.generate_workflow_template(mock_template) mock_logger.exception.assert_called_once_with("Failed to generate workflow template") diff --git a/tests/utils/access_test.py b/tests/utils/access_test.py index 3c25c86..a776963 100644 --- a/tests/utils/access_test.py +++ b/tests/utils/access_test.py @@ -54,24 +54,28 @@ def test_decode_token_expired(): """Test token decoding with an expired token.""" token = "expired.token.here" + from fastapi import HTTPException with patch("jwt.decode", side_effect=jwt.ExpiredSignatureError): - with pytest.raises(MockHTTPException) as excinfo: - Accessor._decode_token(token) + with patch("pretor.utils.access.HTTPException", HTTPException): + with pytest.raises(HTTPException) as excinfo: + Accessor._decode_token(token) - assert excinfo.value.status_code == 401 - assert excinfo.value.detail == "Token 已过期" + assert excinfo.value.status_code == 401 + assert excinfo.value.detail == "Token 已过期" def test_decode_token_invalid(): """Test token decoding with an invalid token.""" token = "invalid.token.here" + from fastapi import HTTPException with patch("jwt.decode", side_effect=jwt.InvalidTokenError): - with pytest.raises(MockHTTPException) as excinfo: - Accessor._decode_token(token) + with patch("pretor.utils.access.HTTPException", HTTPException): + with pytest.raises(HTTPException) as excinfo: + Accessor._decode_token(token) - assert excinfo.value.status_code == 401 - assert excinfo.value.detail == "无效的认证凭证" + assert excinfo.value.status_code == 401 + assert excinfo.value.detail == "无效的认证凭证" def test_decode_token_validation_error(): @@ -79,10 +83,15 @@ def test_decode_token_validation_error(): token = "valid.jwt.invalid.payload" payload = {"wrong": "payload"} + import pydantic + from fastapi import HTTPException + with patch("jwt.decode", return_value=payload): with patch("pretor.utils.access.TokenData", side_effect=MockValidationError): - with pytest.raises(MockHTTPException) as excinfo: - Accessor._decode_token(token) + with patch("pretor.utils.access.ValidationError", MockValidationError): + with patch("pretor.utils.access.HTTPException", HTTPException): + with pytest.raises(HTTPException) as excinfo: + Accessor._decode_token(token) assert excinfo.value.status_code == 401 - assert excinfo.value.detail == "无效的认证凭证" + assert excinfo.value.detail == "无效的认证凭证" \ No newline at end of file