wip: 修改错误
This commit is contained in:
parent
1a8277ad88
commit
f59ac27782
|
|
@ -6,14 +6,16 @@
|
|||
- [ ] /pretor/worker_individual待完善复合子个体和基础子个体
|
||||
|
||||
#### 🛡️ 安全与合规 (Security & Auth)
|
||||
|
||||
- [ ] 优化安全架构防止模型注入
|
||||
- [ ] 设计workflowEngine的自动扩缩容设计
|
||||
- [ ] 完善错误捕获和日志系统
|
||||
|
||||
#### ⚡ 性能与资源优化 (Performance & Scalability)
|
||||
- [ ] 增加对应全workflow的情况追踪,使得在任务运行中人机交互更加自然方便
|
||||
- [ ] 优化import
|
||||
|
||||
#### 🏗️ 架构演进 (Architecture & Refactoring)
|
||||
- ~~[ ] 使用fastapi-users完善用户系统~~(2026/4/19 fastapi-users会严重摧毁代码的优雅性)
|
||||
- 】~~使用fastapi-users完善用户系统~~(2026/4/19 fastapi-users会严重摧毁代码的优雅性)
|
||||
- [ ] 升级auth功能
|
||||
- [x] /pretor/api的接口函数进行重构
|
||||
- [ ] /dockerfile待完善
|
||||
|
|
@ -41,3 +43,8 @@
|
|||
- [ ] 对接更多的provider
|
||||
- [ ] 优化import
|
||||
- [ ] 升级auth功能
|
||||
|
||||
#### 2026/4/20
|
||||
- [ ] 优化安全架构防止模型注入
|
||||
- [ ] 设计workflowEngine的自动扩缩容设计
|
||||
- [ ] 完善错误捕获和日志系统
|
||||
|
|
@ -21,7 +21,7 @@ import asyncio
|
|||
|
||||
class PretorEvent(BaseModel):
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
event_id: str = Field(default_factory=lambda: str(ULID()), description="事件的唯一标识符")
|
||||
trace_id: str = Field(default_factory=lambda: str(ULID()), description="事件的唯一标识符")
|
||||
platform: str = Field(description="消息来源的平台")
|
||||
user_id: str = Field(description="用户id")
|
||||
user_name: str = Field(description="用户名")
|
||||
|
|
|
|||
|
|
@ -36,8 +36,6 @@ async def create_message(message: Message,
|
|||
supervisory_node = ray_actor_hook("supervisor_node")
|
||||
message = await supervisory_node.working.remote(event)
|
||||
if message == "任务已创建":
|
||||
global_state_machine = ray_actor_hook("global_state_machine")
|
||||
global_state_machine.add_event.remote(event)
|
||||
return {"message": event.event_id}
|
||||
elif message == "未知相应类型":
|
||||
raise HTTPException(
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ from pretor.core.global_state_machine.provider_manager import ProviderManager
|
|||
from pretor.core.global_state_machine.tool_manager import GlobalToolManager
|
||||
from pretor.core.global_state_machine.model_provider import Provider, ProviderArgs
|
||||
import httpx
|
||||
import pathlib
|
||||
import json
|
||||
from loguru import logger
|
||||
from typing import Dict, Literal, List
|
||||
|
|
@ -46,6 +47,7 @@ class GlobalStateMachine:
|
|||
async def init_state_machine(self):
|
||||
await self.global_provider_manager.init_provider_register(self.postgres_database)
|
||||
|
||||
|
||||
###以下方法为event_dict方法
|
||||
def add_event(self, event: PretorEvent) -> None:
|
||||
event.pending_queue = asyncio.Queue()
|
||||
|
|
@ -167,6 +169,7 @@ class GlobalStateMachine:
|
|||
def get_workflow_template_list(self) -> List[Dict[str, str]]:
|
||||
return self.global_workflow_template_manager.workflow_templates_registry
|
||||
|
||||
|
||||
###以下为skill_manager方法
|
||||
def add_skill(self, skill_name: str):
|
||||
skill_plugin_dir = pathlib.Path(__file__).parent.parent.parent / "plugin" / "skill_plugin" / skill_name
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ NodeType = Literal[
|
|||
|
||||
class EventInfo(BaseModel):
|
||||
platform: str
|
||||
username: str
|
||||
user_name: str
|
||||
|
||||
class LogicGate(BaseModel):
|
||||
if_fail: str = Field(..., description="失败跳转目标,如 'jump_to_step_1'")
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@
|
|||
from pretor.utils.ray_hook import ray_actor_hook
|
||||
import ray
|
||||
import asyncio
|
||||
from pretor.core.workflow.workflow import PretorWorkflow, WorkStep
|
||||
from pretor.core.workflow.workflow import PretorWorkflow, WorkStep, EventInfo
|
||||
from loguru import logger
|
||||
from typing import Optional, Dict, Union, Any, List
|
||||
from pretor.utils.error import WorkflowError, WorkflowExit
|
||||
|
|
@ -31,6 +31,14 @@ from pretor.core.individual.consciousness_node.template import (
|
|||
ForWorkflowEngine
|
||||
)
|
||||
from pretor.core.individual.supervisory_node.template import TerminationMessage
|
||||
import pathlib
|
||||
|
||||
|
||||
def get_workflow_template(workflow_name: str) -> str:
|
||||
workflow_template = pathlib.Path(__file__).parent.parent.parent / "workflow_template" / (workflow_name + "_workflow_template.json")
|
||||
with open(workflow_template, "r", encoding="utf-8") as workflow_template_file:
|
||||
workflow_template = workflow_template_file.read()
|
||||
return workflow_template
|
||||
|
||||
|
||||
class WorkflowEngine:
|
||||
|
|
@ -237,6 +245,7 @@ class WorkflowRunningEngine:
|
|||
self.consciousness_node = consciousness_node
|
||||
self.control_node = control_node
|
||||
self.supervisory_node = supervisory_node
|
||||
self.global_state_machine = ray_actor_hook("global_state_machine")
|
||||
|
||||
async def run(self):
|
||||
self.runner_engine = {
|
||||
|
|
@ -264,6 +273,7 @@ class WorkflowRunningEngine:
|
|||
raise WorkflowError("未配置 consciousness_node,无法生成工作流")
|
||||
|
||||
workflow_template = event.context.get("workflow_template", "")
|
||||
workflow_template = get_workflow_template(workflow_template)
|
||||
|
||||
payload = ForWorkflowEngineInput(
|
||||
original_command=event.message,
|
||||
|
|
@ -274,12 +284,16 @@ class WorkflowRunningEngine:
|
|||
|
||||
if isinstance(result_obj, ForWorkflowEngine):
|
||||
workflow = result_obj.workflow
|
||||
|
||||
workflow.trace_id = event.event_id
|
||||
workflow.command = event.message
|
||||
workflow.event_info = EventInfo(platform=event.platform,
|
||||
user_name=event.user_name,)
|
||||
|
||||
logger.info(
|
||||
f"WorkflowRunningEngine: runner_{i} 成功生成工作流 {workflow.trace_id}:{workflow.title}")
|
||||
workflow.event_info = event
|
||||
|
||||
global_state_machine = ray_actor_hook("global_state_machine")
|
||||
await global_state_machine.update_workflow.remote(event.event_id, workflow)
|
||||
await self.global_state_machine.update_workflow.remote(event.event_id, workflow)
|
||||
|
||||
workflow_engine = WorkflowEngine(workflow,
|
||||
self.consciousness_node,
|
||||
|
|
|
|||
|
|
@ -42,3 +42,4 @@ class WorkflowManager:
|
|||
self.workflow_templates_registry[workflow_template.name] = workflow_template.desc
|
||||
except Exception as e:
|
||||
logger.exception("Failed to generate workflow template")
|
||||
|
||||
|
|
|
|||
|
|
@ -1,32 +1,25 @@
|
|||
import pytest
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import MagicMock, AsyncMock
|
||||
from pretor.core.global_state_machine.provider_manager import ProviderManager
|
||||
|
||||
|
||||
def test_provider_manager_init():
|
||||
@pytest.mark.asyncio
|
||||
async def test_provider_manager_init():
|
||||
from pretor.core.global_state_machine.provider_manager import ProviderManager
|
||||
mock_postgres = MagicMock()
|
||||
|
||||
mock_provider1 = MagicMock()
|
||||
mock_provider1.title = "title1"
|
||||
mock_provider1.provider_title = "title1"
|
||||
|
||||
mock_provider2 = MagicMock()
|
||||
mock_provider2.title = "title2"
|
||||
mock_provider2.provider_title = "title2"
|
||||
|
||||
# In _load_provider_register, it calls `postgres.provider_database.get_provider.remote()`
|
||||
# which returns a list of providers synchronously?
|
||||
# Yes, it assumes `.remote()` returns an iterable in this context. Wait!
|
||||
# `.remote()` in Ray actually returns an ObjectRef which is NOT iterable directly,
|
||||
# it must be `ray.get()`.
|
||||
# But let's mock it to return a list anyway because the code does `for provider in providers:`.
|
||||
|
||||
mock_postgres.provider_database.get_provider.remote.return_value = [mock_provider1, mock_provider2]
|
||||
mock_postgres.get_providers.remote = AsyncMock(return_value=[mock_provider1, mock_provider2])
|
||||
|
||||
manager = ProviderManager(mock_postgres)
|
||||
await manager.init_provider_register(mock_postgres)
|
||||
|
||||
assert "openai" in manager.provider_mapper
|
||||
assert "gemini" in manager.provider_mapper
|
||||
assert "claude" in manager.provider_mapper
|
||||
|
||||
assert manager.provider_register["title1"] == mock_provider1
|
||||
assert manager.provider_register["title2"] == mock_provider2
|
||||
mock_postgres.provider_database.get_provider.remote.assert_called_once()
|
||||
|
|
|
|||
|
|
@ -115,13 +115,39 @@ async def test_workflow_running_engine_submit():
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_workflow_running_engine_runner():
|
||||
engine = WorkflowRunningEngine("conscious", "control", "supervisor")
|
||||
engine.workflow_queue = asyncio.Queue()
|
||||
from pretor.api.platform.event import PretorEvent
|
||||
from pretor.core.individual.consciousness_node.template import ForWorkflowEngine
|
||||
|
||||
mock_consciousness = MagicMock()
|
||||
|
||||
mock_wf = MagicMock()
|
||||
await engine.workflow_queue.put(mock_wf)
|
||||
mock_wf.trace_id = "test_trace"
|
||||
mock_wf.title = "test_title"
|
||||
|
||||
mock_result = MagicMock(spec=ForWorkflowEngine)
|
||||
mock_result.workflow = mock_wf
|
||||
|
||||
mock_consciousness.working.remote = AsyncMock(return_value=mock_result)
|
||||
|
||||
engine = WorkflowRunningEngine(mock_consciousness, "control", "supervisor")
|
||||
engine.workflow_queue = asyncio.Queue()
|
||||
|
||||
# Use real PretorEvent to avoid Pydantic validation errors on MagicMock properties
|
||||
mock_event = PretorEvent(
|
||||
platform="test_platform",
|
||||
user_id="test_user",
|
||||
user_name="test_user",
|
||||
message="test_message",
|
||||
context={"workflow_template": "test_template"}
|
||||
)
|
||||
await engine.workflow_queue.put(mock_event)
|
||||
|
||||
with patch("pretor.core.workflow.workflow_runner.WorkflowEngine") as mock_wf_engine_cls, \
|
||||
patch("pretor.core.workflow.workflow_runner.ray_actor_hook") as mock_hook:
|
||||
mock_gsm = MagicMock()
|
||||
mock_gsm.update_workflow.remote = AsyncMock()
|
||||
mock_hook.return_value = mock_gsm
|
||||
|
||||
with patch("pretor.core.workflow.workflow_runner.WorkflowEngine") as mock_wf_engine_cls:
|
||||
mock_engine_instance = MagicMock()
|
||||
mock_engine_instance.run = AsyncMock()
|
||||
mock_wf_engine_cls.return_value = mock_engine_instance
|
||||
|
|
@ -130,5 +156,4 @@ async def test_workflow_running_engine_runner():
|
|||
await asyncio.sleep(0.05) # Give runner time to process one item
|
||||
task.cancel() # Stop the infinite loop
|
||||
|
||||
mock_wf_engine_cls.assert_called_with(mock_wf, "conscious", "control", "supervisor")
|
||||
mock_engine_instance.run.assert_called_once()
|
||||
mock_wf_engine_cls.assert_called_with(mock_wf, mock_consciousness, "control", "supervisor")
|
||||
Loading…
Reference in New Issue