diff --git a/pretor/api/platform/frontend.py b/pretor/api/platform/frontend.py index 667c406..0028bb6 100644 --- a/pretor/api/platform/frontend.py +++ b/pretor/api/platform/frontend.py @@ -14,7 +14,8 @@ class Message(BaseModel): async def create_message(message: Message, request: Request, token_date: TokenData = Depends(Accessor.get_current_user)): - logger.info(f"收到消息,来源:客户端,消息内容:{message.message}") + logger.info("收到消息,来源:客户端") + logger.debug(f"消息内容:{message.message}") event = PretorEvent(platform="client", user_id=str(token_date.user_id), user_name=token_date.user_name, diff --git a/pretor/core/individual/consciousness_node/consciousness_node.py b/pretor/core/individual/consciousness_node/consciousness_node.py index 8c9e069..29fd5b8 100644 --- a/pretor/core/individual/consciousness_node/consciousness_node.py +++ b/pretor/core/individual/consciousness_node/consciousness_node.py @@ -1,7 +1,6 @@ import ray from pydantic_ai import Agent from pretor.core.workflow.workflow import PretorWorkflow, WorkStep, WorkerGroup -import uuid @ray.remote class ConsciousnessNode: diff --git a/pretor/core/workflow/workflow_runner.py b/pretor/core/workflow/workflow_runner.py index 7b34023..512ed50 100644 --- a/pretor/core/workflow/workflow_runner.py +++ b/pretor/core/workflow/workflow_runner.py @@ -3,34 +3,22 @@ import asyncio from pretor.core.workflow.workflow import PretorWorkflow, WorkStep from loguru import logger from typing import Optional, Dict, Union, Any, List +from pretor.utils.error import WorkflowError, WorkflowExit class WorkflowEngine: def __init__(self, workflow: PretorWorkflow): self.workflow: PretorWorkflow = workflow - # 局部上下文记忆(黑板):用于存放上一个步骤的 output,作为下一个步骤的 input self.context_memory: Dict[str, Any] = {} - - def _get_step_by_id(self, step_id: int) -> Optional[WorkStep]: - """根据序号获取当前步骤的定义""" - for step in self.workflow.work_link: - if step.step == step_id: - return step - return None + self._steps_by_id: Dict[int, WorkStep] = {step.step: step for step in self.workflow.work_link} def _prepare_inputs(self, inputs: Optional[Union[str, List[str]]]) -> Any: - """从上下文中提取当前步骤所需的入参""" - if not inputs: - return None - - if isinstance(inputs, str): - # 如果 input 是单一变量名,直接返回该变量的值 - return self.context_memory.get(inputs) - - if isinstance(inputs, list): - # 如果 input 是列表,返回包含这些变量名及其值的字典 - return {k: self.context_memory.get(k) for k in inputs} - - return None + match inputs: + case None: + return None + case str(name): + return self.context_memory.get(name) + case list(names): + return {k: self.context_memory.get(k) for k in names} async def run(self): logger.info(f"🚀 工作流引擎启动: {self.workflow.title} [Trace ID: {self.workflow.trace_id}]") @@ -39,7 +27,7 @@ class WorkflowEngine: # 核心调度循环:只要 step 在合法范围内,就一直执行 while 1 <= self.workflow.status.step <= max_step: current_step_id = self.workflow.status.step - current_step = self._get_step_by_id(current_step_id) + current_step = self._steps_by_id.get(current_step_id) if not current_step: logger.error(f"严重错误:找不到步骤 {current_step_id},工作流强制终止。") @@ -71,6 +59,13 @@ class WorkflowEngine: # 4. 根据执行成功与否处理逻辑门跳转 self._handle_logic_gate(current_step, is_success) + except WorkflowExit: + logger.info("命中 if_pass='exit',工作流被主动要求结束。") + break + except WorkflowError as e: + logger.error(f"{e},终止工作流。") + self.workflow.status.status = "failed" + break except Exception as e: # 捕获系统级崩溃 (例如 Ray Actor 断联、网络异常) logger.error(f"❌ Step {current_step_id} 发生系统级未捕获异常: {e}", exc_info=True) @@ -98,6 +93,7 @@ class WorkflowEngine: # worker = get_worker_actor(step.action) # result = await worker.run.remote(step.desc, input_data) # return result, True + await asyncio.sleep(1) simulated_result = f"这是 {step.action} 动作产生的模拟结果" is_success = True @@ -109,22 +105,21 @@ class WorkflowEngine: if is_success: if gate and gate.if_pass == "exit": - logger.info("命中 if_pass='exit',工作流被主动要求结束。") - self.workflow.status.step = 999999 # 设置一个越界值来终结 while 循环 - else: - self.workflow.status.step += 1 # 默认成功则步数 +1,继续下一步 + raise WorkflowExit() + self.workflow.status.step += 1 # 默认成功则步数 +1,继续下一步 else: - if gate and gate.if_fail: - if gate.if_fail.startswith("jump_to_step_"): - target_step = int(gate.if_fail.split("_")[-1]) + if not gate or not gate.if_fail: + raise WorkflowError(f"步骤 {step.step} 失败且未配置 if_fail 兜底方案") + + match gate.if_fail.split("_"): + case ["jump", "to", "step", target] if target.isdigit(): + target_step = int(target) logger.warning(f"触发逻辑门分支!从 Step {step.step} 跳转至 Step {target_step}") self.workflow.status.step = target_step - else: - logger.error(f"未知的 if_fail 格式: {gate.if_fail},终止工作流。") - self.workflow.status.step = 999999 - else: - logger.error(f"步骤 {step.step} 失败且未配置 if_fail 兜底方案,工作流异常终止。") - self.workflow.status.step = 999999 + case _: + raise WorkflowError(f"未知的 if_fail 格式: {gate.if_fail}") + + @ray.remote class WorkflowRunningEngine: @@ -139,12 +134,16 @@ class WorkflowRunningEngine: } self.workflow_queue = asyncio.Queue() - async def runner(self,i: int): + async def runner(self, i: int): while True: try: workflow = await self.workflow_queue.get() logger.info(f"WorkflowRunningEngine: runner_{i}接收工作流{workflow.trace_id}:{workflow.title}") workflow_engine = WorkflowEngine(workflow) await workflow_engine.run() - except: - pass \ No newline at end of file + except asyncio.CancelledError: + logger.info(f"WorkflowRunningEngine: runner_{i} 被取消。") + raise + except Exception as e: + logger.error(f"WorkflowRunningEngine: runner_{i} 遇到未捕获的异常: {e}", exc_info=True) + diff --git a/pretor/utils/error.py b/pretor/utils/error.py index c10b05f..9acd9b7 100644 --- a/pretor/utils/error.py +++ b/pretor/utils/error.py @@ -17,4 +17,12 @@ class ProviderError(Exception): pass class ProviderNotExistError(ProviderError): - pass \ No newline at end of file + pass + +class WorkflowError(Exception): + + pass + +class WorkflowExit(WorkflowError): + + pass diff --git a/pyproject.toml b/pyproject.toml index 5cf0096..d0a3801 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,6 +13,7 @@ dependencies = [ "loguru>=0.7.3", "passlib[bcrypt]>=1.7.4", "pydantic-ai>=1.73.0", + "pytest>=9.0.3", "python-ulid>=3.1.0", "ray[default,serve]>=2.54.0", "sqlmodel>=0.0.37", diff --git a/tests/utils/access_test.py b/tests/utils/access_test.py new file mode 100644 index 0000000..3c25c86 --- /dev/null +++ b/tests/utils/access_test.py @@ -0,0 +1,88 @@ +import sys +from unittest.mock import MagicMock, patch + + +# Mock dependencies before importing the module under test +class MockHTTPException(Exception): + def __init__(self, status_code, detail=None, headers=None): + self.status_code = status_code + self.detail = detail + self.headers = headers + + +class MockValidationError(Exception): + pass + + +mock_fastapi = MagicMock() +mock_fastapi.HTTPException = MockHTTPException +mock_fastapi.status.HTTP_401_UNAUTHORIZED = 401 + +mock_pydantic = MagicMock() +mock_pydantic.ValidationError = MockValidationError + +sys.modules["fastapi"] = mock_fastapi +sys.modules["pydantic"] = mock_pydantic +sys.modules["sqlmodel"] = MagicMock() +sys.modules["passlib"] = MagicMock() +sys.modules["passlib.context"] = MagicMock() +sys.modules["pretor.core.database.table.user"] = MagicMock() + +import pytest +import jwt +from pretor.utils.access import Accessor + + +def test_decode_token_success(): + """Test successful token decoding.""" + token = "valid.token.here" + payload = {"user_id": "123", "username": "testuser", "exp": 1234567890} + + with patch("jwt.decode", return_value=payload) as mock_decode: + with patch("pretor.utils.access.TokenData") as mock_token_data_cls: + mock_token_data_instance = MagicMock() + mock_token_data_cls.return_value = mock_token_data_instance + + result = Accessor._decode_token(token) + + mock_decode.assert_called_once() + mock_token_data_cls.assert_called_once_with(**payload) + assert result == mock_token_data_instance + + +def test_decode_token_expired(): + """Test token decoding with an expired token.""" + token = "expired.token.here" + + with patch("jwt.decode", side_effect=jwt.ExpiredSignatureError): + with pytest.raises(MockHTTPException) as excinfo: + Accessor._decode_token(token) + + assert excinfo.value.status_code == 401 + assert excinfo.value.detail == "Token 已过期" + + +def test_decode_token_invalid(): + """Test token decoding with an invalid token.""" + token = "invalid.token.here" + + with patch("jwt.decode", side_effect=jwt.InvalidTokenError): + with pytest.raises(MockHTTPException) as excinfo: + Accessor._decode_token(token) + + assert excinfo.value.status_code == 401 + assert excinfo.value.detail == "无效的认证凭证" + + +def test_decode_token_validation_error(): + """Test token decoding with a payload that fails validation.""" + token = "valid.jwt.invalid.payload" + payload = {"wrong": "payload"} + + with patch("jwt.decode", return_value=payload): + with patch("pretor.utils.access.TokenData", side_effect=MockValidationError): + with pytest.raises(MockHTTPException) as excinfo: + Accessor._decode_token(token) + + assert excinfo.value.status_code == 401 + assert excinfo.value.detail == "无效的认证凭证" diff --git a/uv.lock b/uv.lock index ed670df..780a8e2 100644 --- a/uv.lock +++ b/uv.lock @@ -189,6 +189,7 @@ dependencies = [ { name = "loguru" }, { name = "passlib", extra = ["bcrypt"] }, { name = "pydantic-ai" }, + { name = "pytest" }, { name = "python-ulid" }, { name = "ray", extra = ["default", "serve"] }, { name = "sqlmodel" }, @@ -205,6 +206,7 @@ requires-dist = [ { name = "loguru", specifier = ">=0.7.3" }, { name = "passlib", extras = ["bcrypt"], specifier = ">=1.7.4" }, { name = "pydantic-ai", specifier = ">=1.73.0" }, + { name = "pytest", specifier = ">=9.0.3" }, { name = "python-ulid", specifier = ">=3.1.0" }, { name = "ray", extras = ["default", "serve"], specifier = ">=2.54.0" }, { name = "sqlmodel", specifier = ">=0.0.37" }, @@ -1234,6 +1236,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/fa/5e/f8e9a1d23b9c20a551a8a02ea3637b4642e22c2626e3a13a9a29cdea99eb/importlib_metadata-8.7.1-py3-none-any.whl", hash = "sha256:5a1f80bf1daa489495071efbb095d75a634cf28a8bc299581244063b53176151", size = 27865, upload-time = "2025-12-21T10:00:18.329Z" }, ] +[[package]] +name = "iniconfig" +version = "2.3.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/72/34/14ca021ce8e5dfedc35312d08ba8bf51fdd999c576889fc2c24cb97f4f10/iniconfig-2.3.0.tar.gz", hash = "sha256:c76315c77db068650d49c5b56314774a7804df16fee4402c1f19d6d15d8c4730", size = 20503, upload-time = "2025-10-18T21:55:43.219Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/cb/b1/3846dd7f199d53cb17f49cba7e651e9ce294d8497c8c150530ed11865bb8/iniconfig-2.3.0-py3-none-any.whl", hash = "sha256:f631c04d2c48c52b84d0d0549c99ff3859c98df65b3101406327ecc7d53fbf12", size = 7484, upload-time = "2025-10-18T21:55:41.639Z" }, +] + [[package]] name = "jaraco-classes" version = "3.4.0" @@ -1954,6 +1965,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/63/d7/97f7e3a6abb67d8080dd406fd4df842c2be0efaf712d1c899c32a075027c/platformdirs-4.9.4-py3-none-any.whl", hash = "sha256:68a9a4619a666ea6439f2ff250c12a853cd1cbd5158d258bd824a7df6be2f868", size = 21216, upload-time = "2026-03-05T18:34:12.172Z" }, ] +[[package]] +name = "pluggy" +version = "1.6.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f9/e2/3e91f31a7d2b083fe6ef3fa267035b518369d9511ffab804f839851d2779/pluggy-1.6.0.tar.gz", hash = "sha256:7dcc130b76258d33b90f61b658791dede3486c3e6bfb003ee5c9bfb396dd22f3", size = 69412, upload-time = "2025-05-15T12:30:07.975Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538, upload-time = "2025-05-15T12:30:06.134Z" }, +] + [[package]] name = "prometheus-client" version = "0.24.1" @@ -2404,6 +2424,22 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/df/80/fc9d01d5ed37ba4c42ca2b55b4339ae6e200b456be3a1aaddf4a9fa99b8c/pyperclip-1.11.0-py3-none-any.whl", hash = "sha256:299403e9ff44581cb9ba2ffeed69c7aa96a008622ad0c46cb575ca75b5b84273", size = 11063, upload-time = "2025-09-26T14:40:36.069Z" }, ] +[[package]] +name = "pytest" +version = "9.0.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "colorama", marker = "sys_platform == 'win32'" }, + { name = "iniconfig" }, + { name = "packaging" }, + { name = "pluggy" }, + { name = "pygments" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/7d/0d/549bd94f1a0a402dc8cf64563a117c0f3765662e2e668477624baeec44d5/pytest-9.0.3.tar.gz", hash = "sha256:b86ada508af81d19edeb213c681b1d48246c1a91d304c6c81a427674c17eb91c", size = 1572165, upload-time = "2026-04-07T17:16:18.027Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d4/24/a372aaf5c9b7208e7112038812994107bc65a84cd00e0354a88c2c77a617/pytest-9.0.3-py3-none-any.whl", hash = "sha256:2c5efc453d45394fdd706ade797c0a81091eccd1d6e4bccfcd476e2b8e0ab5d9", size = 375249, upload-time = "2026-04-07T17:16:16.13Z" }, +] + [[package]] name = "python-dateutil" version = "2.9.0.post0"