wip:改进

This commit is contained in:
朝夕 2026-04-09 23:06:01 +08:00
parent 2552017ea7
commit a7bd7f786e
7 changed files with 173 additions and 41 deletions

View File

@ -14,7 +14,8 @@ class Message(BaseModel):
async def create_message(message: Message, async def create_message(message: Message,
request: Request, request: Request,
token_date: TokenData = Depends(Accessor.get_current_user)): token_date: TokenData = Depends(Accessor.get_current_user)):
logger.info(f"收到消息,来源:客户端,消息内容:{message.message}") logger.info("收到消息,来源:客户端")
logger.debug(f"消息内容:{message.message}")
event = PretorEvent(platform="client", event = PretorEvent(platform="client",
user_id=str(token_date.user_id), user_id=str(token_date.user_id),
user_name=token_date.user_name, user_name=token_date.user_name,

View File

@ -1,7 +1,6 @@
import ray import ray
from pydantic_ai import Agent from pydantic_ai import Agent
from pretor.core.workflow.workflow import PretorWorkflow, WorkStep, WorkerGroup from pretor.core.workflow.workflow import PretorWorkflow, WorkStep, WorkerGroup
import uuid
@ray.remote @ray.remote
class ConsciousnessNode: class ConsciousnessNode:

View File

@ -3,34 +3,22 @@ import asyncio
from pretor.core.workflow.workflow import PretorWorkflow, WorkStep from pretor.core.workflow.workflow import PretorWorkflow, WorkStep
from loguru import logger from loguru import logger
from typing import Optional, Dict, Union, Any, List from typing import Optional, Dict, Union, Any, List
from pretor.utils.error import WorkflowError, WorkflowExit
class WorkflowEngine: class WorkflowEngine:
def __init__(self, workflow: PretorWorkflow): def __init__(self, workflow: PretorWorkflow):
self.workflow: PretorWorkflow = workflow self.workflow: PretorWorkflow = workflow
# 局部上下文记忆(黑板):用于存放上一个步骤的 output作为下一个步骤的 input
self.context_memory: Dict[str, Any] = {} self.context_memory: Dict[str, Any] = {}
self._steps_by_id: Dict[int, WorkStep] = {step.step: step for step in self.workflow.work_link}
def _get_step_by_id(self, step_id: int) -> Optional[WorkStep]:
"""根据序号获取当前步骤的定义"""
for step in self.workflow.work_link:
if step.step == step_id:
return step
return None
def _prepare_inputs(self, inputs: Optional[Union[str, List[str]]]) -> Any: def _prepare_inputs(self, inputs: Optional[Union[str, List[str]]]) -> Any:
"""从上下文中提取当前步骤所需的入参""" match inputs:
if not inputs: case None:
return None return None
case str(name):
if isinstance(inputs, str): return self.context_memory.get(name)
# 如果 input 是单一变量名,直接返回该变量的值 case list(names):
return self.context_memory.get(inputs) return {k: self.context_memory.get(k) for k in names}
if isinstance(inputs, list):
# 如果 input 是列表,返回包含这些变量名及其值的字典
return {k: self.context_memory.get(k) for k in inputs}
return None
async def run(self): async def run(self):
logger.info(f"🚀 工作流引擎启动: {self.workflow.title} [Trace ID: {self.workflow.trace_id}]") logger.info(f"🚀 工作流引擎启动: {self.workflow.title} [Trace ID: {self.workflow.trace_id}]")
@ -39,7 +27,7 @@ class WorkflowEngine:
# 核心调度循环:只要 step 在合法范围内,就一直执行 # 核心调度循环:只要 step 在合法范围内,就一直执行
while 1 <= self.workflow.status.step <= max_step: while 1 <= self.workflow.status.step <= max_step:
current_step_id = self.workflow.status.step current_step_id = self.workflow.status.step
current_step = self._get_step_by_id(current_step_id) current_step = self._steps_by_id.get(current_step_id)
if not current_step: if not current_step:
logger.error(f"严重错误:找不到步骤 {current_step_id},工作流强制终止。") logger.error(f"严重错误:找不到步骤 {current_step_id},工作流强制终止。")
@ -71,6 +59,13 @@ class WorkflowEngine:
# 4. 根据执行成功与否处理逻辑门跳转 # 4. 根据执行成功与否处理逻辑门跳转
self._handle_logic_gate(current_step, is_success) self._handle_logic_gate(current_step, is_success)
except WorkflowExit:
logger.info("命中 if_pass='exit',工作流被主动要求结束。")
break
except WorkflowError as e:
logger.error(f"{e},终止工作流。")
self.workflow.status.status = "failed"
break
except Exception as e: except Exception as e:
# 捕获系统级崩溃 (例如 Ray Actor 断联、网络异常) # 捕获系统级崩溃 (例如 Ray Actor 断联、网络异常)
logger.error(f"❌ Step {current_step_id} 发生系统级未捕获异常: {e}", exc_info=True) logger.error(f"❌ Step {current_step_id} 发生系统级未捕获异常: {e}", exc_info=True)
@ -98,6 +93,7 @@ class WorkflowEngine:
# worker = get_worker_actor(step.action) # worker = get_worker_actor(step.action)
# result = await worker.run.remote(step.desc, input_data) # result = await worker.run.remote(step.desc, input_data)
# return result, True # return result, True
await asyncio.sleep(1) await asyncio.sleep(1)
simulated_result = f"这是 {step.action} 动作产生的模拟结果" simulated_result = f"这是 {step.action} 动作产生的模拟结果"
is_success = True is_success = True
@ -109,22 +105,21 @@ class WorkflowEngine:
if is_success: if is_success:
if gate and gate.if_pass == "exit": if gate and gate.if_pass == "exit":
logger.info("命中 if_pass='exit',工作流被主动要求结束。") raise WorkflowExit()
self.workflow.status.step = 999999 # 设置一个越界值来终结 while 循环 self.workflow.status.step += 1 # 默认成功则步数 +1继续下一步
else:
self.workflow.status.step += 1 # 默认成功则步数 +1继续下一步
else: else:
if gate and gate.if_fail: if not gate or not gate.if_fail:
if gate.if_fail.startswith("jump_to_step_"): raise WorkflowError(f"步骤 {step.step} 失败且未配置 if_fail 兜底方案")
target_step = int(gate.if_fail.split("_")[-1])
match gate.if_fail.split("_"):
case ["jump", "to", "step", target] if target.isdigit():
target_step = int(target)
logger.warning(f"触发逻辑门分支!从 Step {step.step} 跳转至 Step {target_step}") logger.warning(f"触发逻辑门分支!从 Step {step.step} 跳转至 Step {target_step}")
self.workflow.status.step = target_step self.workflow.status.step = target_step
else: case _:
logger.error(f"未知的 if_fail 格式: {gate.if_fail},终止工作流。") raise WorkflowError(f"未知的 if_fail 格式: {gate.if_fail}")
self.workflow.status.step = 999999
else:
logger.error(f"步骤 {step.step} 失败且未配置 if_fail 兜底方案,工作流异常终止。")
self.workflow.status.step = 999999
@ray.remote @ray.remote
class WorkflowRunningEngine: class WorkflowRunningEngine:
@ -139,12 +134,16 @@ class WorkflowRunningEngine:
} }
self.workflow_queue = asyncio.Queue() self.workflow_queue = asyncio.Queue()
async def runner(self,i: int): async def runner(self, i: int):
while True: while True:
try: try:
workflow = await self.workflow_queue.get() workflow = await self.workflow_queue.get()
logger.info(f"WorkflowRunningEngine: runner_{i}接收工作流{workflow.trace_id}:{workflow.title}") logger.info(f"WorkflowRunningEngine: runner_{i}接收工作流{workflow.trace_id}:{workflow.title}")
workflow_engine = WorkflowEngine(workflow) workflow_engine = WorkflowEngine(workflow)
await workflow_engine.run() await workflow_engine.run()
except: except asyncio.CancelledError:
pass logger.info(f"WorkflowRunningEngine: runner_{i} 被取消。")
raise
except Exception as e:
logger.error(f"WorkflowRunningEngine: runner_{i} 遇到未捕获的异常: {e}", exc_info=True)

View File

@ -18,3 +18,11 @@ class ProviderError(Exception):
class ProviderNotExistError(ProviderError): class ProviderNotExistError(ProviderError):
pass pass
class WorkflowError(Exception):
pass
class WorkflowExit(WorkflowError):
pass

View File

@ -13,6 +13,7 @@ dependencies = [
"loguru>=0.7.3", "loguru>=0.7.3",
"passlib[bcrypt]>=1.7.4", "passlib[bcrypt]>=1.7.4",
"pydantic-ai>=1.73.0", "pydantic-ai>=1.73.0",
"pytest>=9.0.3",
"python-ulid>=3.1.0", "python-ulid>=3.1.0",
"ray[default,serve]>=2.54.0", "ray[default,serve]>=2.54.0",
"sqlmodel>=0.0.37", "sqlmodel>=0.0.37",

View File

@ -0,0 +1,88 @@
import sys
from unittest.mock import MagicMock, patch
# Mock dependencies before importing the module under test
class MockHTTPException(Exception):
def __init__(self, status_code, detail=None, headers=None):
self.status_code = status_code
self.detail = detail
self.headers = headers
class MockValidationError(Exception):
pass
mock_fastapi = MagicMock()
mock_fastapi.HTTPException = MockHTTPException
mock_fastapi.status.HTTP_401_UNAUTHORIZED = 401
mock_pydantic = MagicMock()
mock_pydantic.ValidationError = MockValidationError
sys.modules["fastapi"] = mock_fastapi
sys.modules["pydantic"] = mock_pydantic
sys.modules["sqlmodel"] = MagicMock()
sys.modules["passlib"] = MagicMock()
sys.modules["passlib.context"] = MagicMock()
sys.modules["pretor.core.database.table.user"] = MagicMock()
import pytest
import jwt
from pretor.utils.access import Accessor
def test_decode_token_success():
"""Test successful token decoding."""
token = "valid.token.here"
payload = {"user_id": "123", "username": "testuser", "exp": 1234567890}
with patch("jwt.decode", return_value=payload) as mock_decode:
with patch("pretor.utils.access.TokenData") as mock_token_data_cls:
mock_token_data_instance = MagicMock()
mock_token_data_cls.return_value = mock_token_data_instance
result = Accessor._decode_token(token)
mock_decode.assert_called_once()
mock_token_data_cls.assert_called_once_with(**payload)
assert result == mock_token_data_instance
def test_decode_token_expired():
"""Test token decoding with an expired token."""
token = "expired.token.here"
with patch("jwt.decode", side_effect=jwt.ExpiredSignatureError):
with pytest.raises(MockHTTPException) as excinfo:
Accessor._decode_token(token)
assert excinfo.value.status_code == 401
assert excinfo.value.detail == "Token 已过期"
def test_decode_token_invalid():
"""Test token decoding with an invalid token."""
token = "invalid.token.here"
with patch("jwt.decode", side_effect=jwt.InvalidTokenError):
with pytest.raises(MockHTTPException) as excinfo:
Accessor._decode_token(token)
assert excinfo.value.status_code == 401
assert excinfo.value.detail == "无效的认证凭证"
def test_decode_token_validation_error():
"""Test token decoding with a payload that fails validation."""
token = "valid.jwt.invalid.payload"
payload = {"wrong": "payload"}
with patch("jwt.decode", return_value=payload):
with patch("pretor.utils.access.TokenData", side_effect=MockValidationError):
with pytest.raises(MockHTTPException) as excinfo:
Accessor._decode_token(token)
assert excinfo.value.status_code == 401
assert excinfo.value.detail == "无效的认证凭证"

36
uv.lock
View File

@ -189,6 +189,7 @@ dependencies = [
{ name = "loguru" }, { name = "loguru" },
{ name = "passlib", extra = ["bcrypt"] }, { name = "passlib", extra = ["bcrypt"] },
{ name = "pydantic-ai" }, { name = "pydantic-ai" },
{ name = "pytest" },
{ name = "python-ulid" }, { name = "python-ulid" },
{ name = "ray", extra = ["default", "serve"] }, { name = "ray", extra = ["default", "serve"] },
{ name = "sqlmodel" }, { name = "sqlmodel" },
@ -205,6 +206,7 @@ requires-dist = [
{ name = "loguru", specifier = ">=0.7.3" }, { name = "loguru", specifier = ">=0.7.3" },
{ name = "passlib", extras = ["bcrypt"], specifier = ">=1.7.4" }, { name = "passlib", extras = ["bcrypt"], specifier = ">=1.7.4" },
{ name = "pydantic-ai", specifier = ">=1.73.0" }, { name = "pydantic-ai", specifier = ">=1.73.0" },
{ name = "pytest", specifier = ">=9.0.3" },
{ name = "python-ulid", specifier = ">=3.1.0" }, { name = "python-ulid", specifier = ">=3.1.0" },
{ name = "ray", extras = ["default", "serve"], specifier = ">=2.54.0" }, { name = "ray", extras = ["default", "serve"], specifier = ">=2.54.0" },
{ name = "sqlmodel", specifier = ">=0.0.37" }, { name = "sqlmodel", specifier = ">=0.0.37" },
@ -1234,6 +1236,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/fa/5e/f8e9a1d23b9c20a551a8a02ea3637b4642e22c2626e3a13a9a29cdea99eb/importlib_metadata-8.7.1-py3-none-any.whl", hash = "sha256:5a1f80bf1daa489495071efbb095d75a634cf28a8bc299581244063b53176151", size = 27865, upload-time = "2025-12-21T10:00:18.329Z" }, { url = "https://files.pythonhosted.org/packages/fa/5e/f8e9a1d23b9c20a551a8a02ea3637b4642e22c2626e3a13a9a29cdea99eb/importlib_metadata-8.7.1-py3-none-any.whl", hash = "sha256:5a1f80bf1daa489495071efbb095d75a634cf28a8bc299581244063b53176151", size = 27865, upload-time = "2025-12-21T10:00:18.329Z" },
] ]
[[package]]
name = "iniconfig"
version = "2.3.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/72/34/14ca021ce8e5dfedc35312d08ba8bf51fdd999c576889fc2c24cb97f4f10/iniconfig-2.3.0.tar.gz", hash = "sha256:c76315c77db068650d49c5b56314774a7804df16fee4402c1f19d6d15d8c4730", size = 20503, upload-time = "2025-10-18T21:55:43.219Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/cb/b1/3846dd7f199d53cb17f49cba7e651e9ce294d8497c8c150530ed11865bb8/iniconfig-2.3.0-py3-none-any.whl", hash = "sha256:f631c04d2c48c52b84d0d0549c99ff3859c98df65b3101406327ecc7d53fbf12", size = 7484, upload-time = "2025-10-18T21:55:41.639Z" },
]
[[package]] [[package]]
name = "jaraco-classes" name = "jaraco-classes"
version = "3.4.0" version = "3.4.0"
@ -1954,6 +1965,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/63/d7/97f7e3a6abb67d8080dd406fd4df842c2be0efaf712d1c899c32a075027c/platformdirs-4.9.4-py3-none-any.whl", hash = "sha256:68a9a4619a666ea6439f2ff250c12a853cd1cbd5158d258bd824a7df6be2f868", size = 21216, upload-time = "2026-03-05T18:34:12.172Z" }, { url = "https://files.pythonhosted.org/packages/63/d7/97f7e3a6abb67d8080dd406fd4df842c2be0efaf712d1c899c32a075027c/platformdirs-4.9.4-py3-none-any.whl", hash = "sha256:68a9a4619a666ea6439f2ff250c12a853cd1cbd5158d258bd824a7df6be2f868", size = 21216, upload-time = "2026-03-05T18:34:12.172Z" },
] ]
[[package]]
name = "pluggy"
version = "1.6.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/f9/e2/3e91f31a7d2b083fe6ef3fa267035b518369d9511ffab804f839851d2779/pluggy-1.6.0.tar.gz", hash = "sha256:7dcc130b76258d33b90f61b658791dede3486c3e6bfb003ee5c9bfb396dd22f3", size = 69412, upload-time = "2025-05-15T12:30:07.975Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538, upload-time = "2025-05-15T12:30:06.134Z" },
]
[[package]] [[package]]
name = "prometheus-client" name = "prometheus-client"
version = "0.24.1" version = "0.24.1"
@ -2404,6 +2424,22 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/df/80/fc9d01d5ed37ba4c42ca2b55b4339ae6e200b456be3a1aaddf4a9fa99b8c/pyperclip-1.11.0-py3-none-any.whl", hash = "sha256:299403e9ff44581cb9ba2ffeed69c7aa96a008622ad0c46cb575ca75b5b84273", size = 11063, upload-time = "2025-09-26T14:40:36.069Z" }, { url = "https://files.pythonhosted.org/packages/df/80/fc9d01d5ed37ba4c42ca2b55b4339ae6e200b456be3a1aaddf4a9fa99b8c/pyperclip-1.11.0-py3-none-any.whl", hash = "sha256:299403e9ff44581cb9ba2ffeed69c7aa96a008622ad0c46cb575ca75b5b84273", size = 11063, upload-time = "2025-09-26T14:40:36.069Z" },
] ]
[[package]]
name = "pytest"
version = "9.0.3"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "colorama", marker = "sys_platform == 'win32'" },
{ name = "iniconfig" },
{ name = "packaging" },
{ name = "pluggy" },
{ name = "pygments" },
]
sdist = { url = "https://files.pythonhosted.org/packages/7d/0d/549bd94f1a0a402dc8cf64563a117c0f3765662e2e668477624baeec44d5/pytest-9.0.3.tar.gz", hash = "sha256:b86ada508af81d19edeb213c681b1d48246c1a91d304c6c81a427674c17eb91c", size = 1572165, upload-time = "2026-04-07T17:16:18.027Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/d4/24/a372aaf5c9b7208e7112038812994107bc65a84cd00e0354a88c2c77a617/pytest-9.0.3-py3-none-any.whl", hash = "sha256:2c5efc453d45394fdd706ade797c0a81091eccd1d6e4bccfcd476e2b8e0ab5d9", size = 375249, upload-time = "2026-04-07T17:16:16.13Z" },
]
[[package]] [[package]]
name = "python-dateutil" name = "python-dateutil"
version = "2.9.0.post0" version = "2.9.0.post0"