wip:改进

This commit is contained in:
朝夕 2026-04-09 23:06:01 +08:00
parent 2552017ea7
commit a7bd7f786e
7 changed files with 173 additions and 41 deletions

View File

@ -14,7 +14,8 @@ class Message(BaseModel):
async def create_message(message: Message,
request: Request,
token_date: TokenData = Depends(Accessor.get_current_user)):
logger.info(f"收到消息,来源:客户端,消息内容:{message.message}")
logger.info("收到消息,来源:客户端")
logger.debug(f"消息内容:{message.message}")
event = PretorEvent(platform="client",
user_id=str(token_date.user_id),
user_name=token_date.user_name,

View File

@ -1,7 +1,6 @@
import ray
from pydantic_ai import Agent
from pretor.core.workflow.workflow import PretorWorkflow, WorkStep, WorkerGroup
import uuid
@ray.remote
class ConsciousnessNode:

View File

@ -3,34 +3,22 @@ import asyncio
from pretor.core.workflow.workflow import PretorWorkflow, WorkStep
from loguru import logger
from typing import Optional, Dict, Union, Any, List
from pretor.utils.error import WorkflowError, WorkflowExit
class WorkflowEngine:
def __init__(self, workflow: PretorWorkflow):
self.workflow: PretorWorkflow = workflow
# 局部上下文记忆(黑板):用于存放上一个步骤的 output作为下一个步骤的 input
self.context_memory: Dict[str, Any] = {}
def _get_step_by_id(self, step_id: int) -> Optional[WorkStep]:
"""根据序号获取当前步骤的定义"""
for step in self.workflow.work_link:
if step.step == step_id:
return step
return None
self._steps_by_id: Dict[int, WorkStep] = {step.step: step for step in self.workflow.work_link}
def _prepare_inputs(self, inputs: Optional[Union[str, List[str]]]) -> Any:
"""从上下文中提取当前步骤所需的入参"""
if not inputs:
return None
if isinstance(inputs, str):
# 如果 input 是单一变量名,直接返回该变量的值
return self.context_memory.get(inputs)
if isinstance(inputs, list):
# 如果 input 是列表,返回包含这些变量名及其值的字典
return {k: self.context_memory.get(k) for k in inputs}
return None
match inputs:
case None:
return None
case str(name):
return self.context_memory.get(name)
case list(names):
return {k: self.context_memory.get(k) for k in names}
async def run(self):
logger.info(f"🚀 工作流引擎启动: {self.workflow.title} [Trace ID: {self.workflow.trace_id}]")
@ -39,7 +27,7 @@ class WorkflowEngine:
# 核心调度循环:只要 step 在合法范围内,就一直执行
while 1 <= self.workflow.status.step <= max_step:
current_step_id = self.workflow.status.step
current_step = self._get_step_by_id(current_step_id)
current_step = self._steps_by_id.get(current_step_id)
if not current_step:
logger.error(f"严重错误:找不到步骤 {current_step_id},工作流强制终止。")
@ -71,6 +59,13 @@ class WorkflowEngine:
# 4. 根据执行成功与否处理逻辑门跳转
self._handle_logic_gate(current_step, is_success)
except WorkflowExit:
logger.info("命中 if_pass='exit',工作流被主动要求结束。")
break
except WorkflowError as e:
logger.error(f"{e},终止工作流。")
self.workflow.status.status = "failed"
break
except Exception as e:
# 捕获系统级崩溃 (例如 Ray Actor 断联、网络异常)
logger.error(f"❌ Step {current_step_id} 发生系统级未捕获异常: {e}", exc_info=True)
@ -98,6 +93,7 @@ class WorkflowEngine:
# worker = get_worker_actor(step.action)
# result = await worker.run.remote(step.desc, input_data)
# return result, True
await asyncio.sleep(1)
simulated_result = f"这是 {step.action} 动作产生的模拟结果"
is_success = True
@ -109,22 +105,21 @@ class WorkflowEngine:
if is_success:
if gate and gate.if_pass == "exit":
logger.info("命中 if_pass='exit',工作流被主动要求结束。")
self.workflow.status.step = 999999 # 设置一个越界值来终结 while 循环
else:
self.workflow.status.step += 1 # 默认成功则步数 +1继续下一步
raise WorkflowExit()
self.workflow.status.step += 1 # 默认成功则步数 +1继续下一步
else:
if gate and gate.if_fail:
if gate.if_fail.startswith("jump_to_step_"):
target_step = int(gate.if_fail.split("_")[-1])
if not gate or not gate.if_fail:
raise WorkflowError(f"步骤 {step.step} 失败且未配置 if_fail 兜底方案")
match gate.if_fail.split("_"):
case ["jump", "to", "step", target] if target.isdigit():
target_step = int(target)
logger.warning(f"触发逻辑门分支!从 Step {step.step} 跳转至 Step {target_step}")
self.workflow.status.step = target_step
else:
logger.error(f"未知的 if_fail 格式: {gate.if_fail},终止工作流。")
self.workflow.status.step = 999999
else:
logger.error(f"步骤 {step.step} 失败且未配置 if_fail 兜底方案,工作流异常终止。")
self.workflow.status.step = 999999
case _:
raise WorkflowError(f"未知的 if_fail 格式: {gate.if_fail}")
@ray.remote
class WorkflowRunningEngine:
@ -139,12 +134,16 @@ class WorkflowRunningEngine:
}
self.workflow_queue = asyncio.Queue()
async def runner(self,i: int):
async def runner(self, i: int):
while True:
try:
workflow = await self.workflow_queue.get()
logger.info(f"WorkflowRunningEngine: runner_{i}接收工作流{workflow.trace_id}:{workflow.title}")
workflow_engine = WorkflowEngine(workflow)
await workflow_engine.run()
except:
pass
except asyncio.CancelledError:
logger.info(f"WorkflowRunningEngine: runner_{i} 被取消。")
raise
except Exception as e:
logger.error(f"WorkflowRunningEngine: runner_{i} 遇到未捕获的异常: {e}", exc_info=True)

View File

@ -17,4 +17,12 @@ class ProviderError(Exception):
pass
class ProviderNotExistError(ProviderError):
pass
pass
class WorkflowError(Exception):
pass
class WorkflowExit(WorkflowError):
pass

View File

@ -13,6 +13,7 @@ dependencies = [
"loguru>=0.7.3",
"passlib[bcrypt]>=1.7.4",
"pydantic-ai>=1.73.0",
"pytest>=9.0.3",
"python-ulid>=3.1.0",
"ray[default,serve]>=2.54.0",
"sqlmodel>=0.0.37",

View File

@ -0,0 +1,88 @@
import sys
from unittest.mock import MagicMock, patch
# Mock dependencies before importing the module under test
class MockHTTPException(Exception):
def __init__(self, status_code, detail=None, headers=None):
self.status_code = status_code
self.detail = detail
self.headers = headers
class MockValidationError(Exception):
pass
mock_fastapi = MagicMock()
mock_fastapi.HTTPException = MockHTTPException
mock_fastapi.status.HTTP_401_UNAUTHORIZED = 401
mock_pydantic = MagicMock()
mock_pydantic.ValidationError = MockValidationError
sys.modules["fastapi"] = mock_fastapi
sys.modules["pydantic"] = mock_pydantic
sys.modules["sqlmodel"] = MagicMock()
sys.modules["passlib"] = MagicMock()
sys.modules["passlib.context"] = MagicMock()
sys.modules["pretor.core.database.table.user"] = MagicMock()
import pytest
import jwt
from pretor.utils.access import Accessor
def test_decode_token_success():
"""Test successful token decoding."""
token = "valid.token.here"
payload = {"user_id": "123", "username": "testuser", "exp": 1234567890}
with patch("jwt.decode", return_value=payload) as mock_decode:
with patch("pretor.utils.access.TokenData") as mock_token_data_cls:
mock_token_data_instance = MagicMock()
mock_token_data_cls.return_value = mock_token_data_instance
result = Accessor._decode_token(token)
mock_decode.assert_called_once()
mock_token_data_cls.assert_called_once_with(**payload)
assert result == mock_token_data_instance
def test_decode_token_expired():
"""Test token decoding with an expired token."""
token = "expired.token.here"
with patch("jwt.decode", side_effect=jwt.ExpiredSignatureError):
with pytest.raises(MockHTTPException) as excinfo:
Accessor._decode_token(token)
assert excinfo.value.status_code == 401
assert excinfo.value.detail == "Token 已过期"
def test_decode_token_invalid():
"""Test token decoding with an invalid token."""
token = "invalid.token.here"
with patch("jwt.decode", side_effect=jwt.InvalidTokenError):
with pytest.raises(MockHTTPException) as excinfo:
Accessor._decode_token(token)
assert excinfo.value.status_code == 401
assert excinfo.value.detail == "无效的认证凭证"
def test_decode_token_validation_error():
"""Test token decoding with a payload that fails validation."""
token = "valid.jwt.invalid.payload"
payload = {"wrong": "payload"}
with patch("jwt.decode", return_value=payload):
with patch("pretor.utils.access.TokenData", side_effect=MockValidationError):
with pytest.raises(MockHTTPException) as excinfo:
Accessor._decode_token(token)
assert excinfo.value.status_code == 401
assert excinfo.value.detail == "无效的认证凭证"

36
uv.lock
View File

@ -189,6 +189,7 @@ dependencies = [
{ name = "loguru" },
{ name = "passlib", extra = ["bcrypt"] },
{ name = "pydantic-ai" },
{ name = "pytest" },
{ name = "python-ulid" },
{ name = "ray", extra = ["default", "serve"] },
{ name = "sqlmodel" },
@ -205,6 +206,7 @@ requires-dist = [
{ name = "loguru", specifier = ">=0.7.3" },
{ name = "passlib", extras = ["bcrypt"], specifier = ">=1.7.4" },
{ name = "pydantic-ai", specifier = ">=1.73.0" },
{ name = "pytest", specifier = ">=9.0.3" },
{ name = "python-ulid", specifier = ">=3.1.0" },
{ name = "ray", extras = ["default", "serve"], specifier = ">=2.54.0" },
{ name = "sqlmodel", specifier = ">=0.0.37" },
@ -1234,6 +1236,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/fa/5e/f8e9a1d23b9c20a551a8a02ea3637b4642e22c2626e3a13a9a29cdea99eb/importlib_metadata-8.7.1-py3-none-any.whl", hash = "sha256:5a1f80bf1daa489495071efbb095d75a634cf28a8bc299581244063b53176151", size = 27865, upload-time = "2025-12-21T10:00:18.329Z" },
]
[[package]]
name = "iniconfig"
version = "2.3.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/72/34/14ca021ce8e5dfedc35312d08ba8bf51fdd999c576889fc2c24cb97f4f10/iniconfig-2.3.0.tar.gz", hash = "sha256:c76315c77db068650d49c5b56314774a7804df16fee4402c1f19d6d15d8c4730", size = 20503, upload-time = "2025-10-18T21:55:43.219Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/cb/b1/3846dd7f199d53cb17f49cba7e651e9ce294d8497c8c150530ed11865bb8/iniconfig-2.3.0-py3-none-any.whl", hash = "sha256:f631c04d2c48c52b84d0d0549c99ff3859c98df65b3101406327ecc7d53fbf12", size = 7484, upload-time = "2025-10-18T21:55:41.639Z" },
]
[[package]]
name = "jaraco-classes"
version = "3.4.0"
@ -1954,6 +1965,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/63/d7/97f7e3a6abb67d8080dd406fd4df842c2be0efaf712d1c899c32a075027c/platformdirs-4.9.4-py3-none-any.whl", hash = "sha256:68a9a4619a666ea6439f2ff250c12a853cd1cbd5158d258bd824a7df6be2f868", size = 21216, upload-time = "2026-03-05T18:34:12.172Z" },
]
[[package]]
name = "pluggy"
version = "1.6.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/f9/e2/3e91f31a7d2b083fe6ef3fa267035b518369d9511ffab804f839851d2779/pluggy-1.6.0.tar.gz", hash = "sha256:7dcc130b76258d33b90f61b658791dede3486c3e6bfb003ee5c9bfb396dd22f3", size = 69412, upload-time = "2025-05-15T12:30:07.975Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538, upload-time = "2025-05-15T12:30:06.134Z" },
]
[[package]]
name = "prometheus-client"
version = "0.24.1"
@ -2404,6 +2424,22 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/df/80/fc9d01d5ed37ba4c42ca2b55b4339ae6e200b456be3a1aaddf4a9fa99b8c/pyperclip-1.11.0-py3-none-any.whl", hash = "sha256:299403e9ff44581cb9ba2ffeed69c7aa96a008622ad0c46cb575ca75b5b84273", size = 11063, upload-time = "2025-09-26T14:40:36.069Z" },
]
[[package]]
name = "pytest"
version = "9.0.3"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "colorama", marker = "sys_platform == 'win32'" },
{ name = "iniconfig" },
{ name = "packaging" },
{ name = "pluggy" },
{ name = "pygments" },
]
sdist = { url = "https://files.pythonhosted.org/packages/7d/0d/549bd94f1a0a402dc8cf64563a117c0f3765662e2e668477624baeec44d5/pytest-9.0.3.tar.gz", hash = "sha256:b86ada508af81d19edeb213c681b1d48246c1a91d304c6c81a427674c17eb91c", size = 1572165, upload-time = "2026-04-07T17:16:18.027Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/d4/24/a372aaf5c9b7208e7112038812994107bc65a84cd00e0354a88c2c77a617/pytest-9.0.3-py3-none-any.whl", hash = "sha256:2c5efc453d45394fdd706ade797c0a81091eccd1d6e4bccfcd476e2b8e0ab5d9", size = 375249, upload-time = "2026-04-07T17:16:16.13Z" },
]
[[package]]
name = "python-dateutil"
version = "2.9.0.post0"