wip: 修改错误
This commit is contained in:
parent
1a8277ad88
commit
f59ac27782
|
|
@ -6,14 +6,16 @@
|
||||||
- [ ] /pretor/worker_individual待完善复合子个体和基础子个体
|
- [ ] /pretor/worker_individual待完善复合子个体和基础子个体
|
||||||
|
|
||||||
#### 🛡️ 安全与合规 (Security & Auth)
|
#### 🛡️ 安全与合规 (Security & Auth)
|
||||||
|
- [ ] 优化安全架构防止模型注入
|
||||||
|
- [ ] 设计workflowEngine的自动扩缩容设计
|
||||||
|
- [ ] 完善错误捕获和日志系统
|
||||||
|
|
||||||
#### ⚡ 性能与资源优化 (Performance & Scalability)
|
#### ⚡ 性能与资源优化 (Performance & Scalability)
|
||||||
- [ ] 增加对应全workflow的情况追踪,使得在任务运行中人机交互更加自然方便
|
- [ ] 增加对应全workflow的情况追踪,使得在任务运行中人机交互更加自然方便
|
||||||
- [ ] 优化import
|
- [ ] 优化import
|
||||||
|
|
||||||
#### 🏗️ 架构演进 (Architecture & Refactoring)
|
#### 🏗️ 架构演进 (Architecture & Refactoring)
|
||||||
- ~~[ ] 使用fastapi-users完善用户系统~~(2026/4/19 fastapi-users会严重摧毁代码的优雅性)
|
- 】~~使用fastapi-users完善用户系统~~(2026/4/19 fastapi-users会严重摧毁代码的优雅性)
|
||||||
- [ ] 升级auth功能
|
- [ ] 升级auth功能
|
||||||
- [x] /pretor/api的接口函数进行重构
|
- [x] /pretor/api的接口函数进行重构
|
||||||
- [ ] /dockerfile待完善
|
- [ ] /dockerfile待完善
|
||||||
|
|
@ -41,3 +43,8 @@
|
||||||
- [ ] 对接更多的provider
|
- [ ] 对接更多的provider
|
||||||
- [ ] 优化import
|
- [ ] 优化import
|
||||||
- [ ] 升级auth功能
|
- [ ] 升级auth功能
|
||||||
|
|
||||||
|
#### 2026/4/20
|
||||||
|
- [ ] 优化安全架构防止模型注入
|
||||||
|
- [ ] 设计workflowEngine的自动扩缩容设计
|
||||||
|
- [ ] 完善错误捕获和日志系统
|
||||||
|
|
@ -21,7 +21,7 @@ import asyncio
|
||||||
|
|
||||||
class PretorEvent(BaseModel):
|
class PretorEvent(BaseModel):
|
||||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||||
event_id: str = Field(default_factory=lambda: str(ULID()), description="事件的唯一标识符")
|
trace_id: str = Field(default_factory=lambda: str(ULID()), description="事件的唯一标识符")
|
||||||
platform: str = Field(description="消息来源的平台")
|
platform: str = Field(description="消息来源的平台")
|
||||||
user_id: str = Field(description="用户id")
|
user_id: str = Field(description="用户id")
|
||||||
user_name: str = Field(description="用户名")
|
user_name: str = Field(description="用户名")
|
||||||
|
|
|
||||||
|
|
@ -36,8 +36,6 @@ async def create_message(message: Message,
|
||||||
supervisory_node = ray_actor_hook("supervisor_node")
|
supervisory_node = ray_actor_hook("supervisor_node")
|
||||||
message = await supervisory_node.working.remote(event)
|
message = await supervisory_node.working.remote(event)
|
||||||
if message == "任务已创建":
|
if message == "任务已创建":
|
||||||
global_state_machine = ray_actor_hook("global_state_machine")
|
|
||||||
global_state_machine.add_event.remote(event)
|
|
||||||
return {"message": event.event_id}
|
return {"message": event.event_id}
|
||||||
elif message == "未知相应类型":
|
elif message == "未知相应类型":
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
|
|
||||||
|
|
@ -18,6 +18,7 @@ from pretor.core.global_state_machine.provider_manager import ProviderManager
|
||||||
from pretor.core.global_state_machine.tool_manager import GlobalToolManager
|
from pretor.core.global_state_machine.tool_manager import GlobalToolManager
|
||||||
from pretor.core.global_state_machine.model_provider import Provider, ProviderArgs
|
from pretor.core.global_state_machine.model_provider import Provider, ProviderArgs
|
||||||
import httpx
|
import httpx
|
||||||
|
import pathlib
|
||||||
import json
|
import json
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from typing import Dict, Literal, List
|
from typing import Dict, Literal, List
|
||||||
|
|
@ -46,6 +47,7 @@ class GlobalStateMachine:
|
||||||
async def init_state_machine(self):
|
async def init_state_machine(self):
|
||||||
await self.global_provider_manager.init_provider_register(self.postgres_database)
|
await self.global_provider_manager.init_provider_register(self.postgres_database)
|
||||||
|
|
||||||
|
|
||||||
###以下方法为event_dict方法
|
###以下方法为event_dict方法
|
||||||
def add_event(self, event: PretorEvent) -> None:
|
def add_event(self, event: PretorEvent) -> None:
|
||||||
event.pending_queue = asyncio.Queue()
|
event.pending_queue = asyncio.Queue()
|
||||||
|
|
@ -167,6 +169,7 @@ class GlobalStateMachine:
|
||||||
def get_workflow_template_list(self) -> List[Dict[str, str]]:
|
def get_workflow_template_list(self) -> List[Dict[str, str]]:
|
||||||
return self.global_workflow_template_manager.workflow_templates_registry
|
return self.global_workflow_template_manager.workflow_templates_registry
|
||||||
|
|
||||||
|
|
||||||
###以下为skill_manager方法
|
###以下为skill_manager方法
|
||||||
def add_skill(self, skill_name: str):
|
def add_skill(self, skill_name: str):
|
||||||
skill_plugin_dir = pathlib.Path(__file__).parent.parent.parent / "plugin" / "skill_plugin" / skill_name
|
skill_plugin_dir = pathlib.Path(__file__).parent.parent.parent / "plugin" / "skill_plugin" / skill_name
|
||||||
|
|
|
||||||
|
|
@ -23,7 +23,7 @@ NodeType = Literal[
|
||||||
|
|
||||||
class EventInfo(BaseModel):
|
class EventInfo(BaseModel):
|
||||||
platform: str
|
platform: str
|
||||||
username: str
|
user_name: str
|
||||||
|
|
||||||
class LogicGate(BaseModel):
|
class LogicGate(BaseModel):
|
||||||
if_fail: str = Field(..., description="失败跳转目标,如 'jump_to_step_1'")
|
if_fail: str = Field(..., description="失败跳转目标,如 'jump_to_step_1'")
|
||||||
|
|
|
||||||
|
|
@ -15,7 +15,7 @@
|
||||||
from pretor.utils.ray_hook import ray_actor_hook
|
from pretor.utils.ray_hook import ray_actor_hook
|
||||||
import ray
|
import ray
|
||||||
import asyncio
|
import asyncio
|
||||||
from pretor.core.workflow.workflow import PretorWorkflow, WorkStep
|
from pretor.core.workflow.workflow import PretorWorkflow, WorkStep, EventInfo
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from typing import Optional, Dict, Union, Any, List
|
from typing import Optional, Dict, Union, Any, List
|
||||||
from pretor.utils.error import WorkflowError, WorkflowExit
|
from pretor.utils.error import WorkflowError, WorkflowExit
|
||||||
|
|
@ -31,6 +31,14 @@ from pretor.core.individual.consciousness_node.template import (
|
||||||
ForWorkflowEngine
|
ForWorkflowEngine
|
||||||
)
|
)
|
||||||
from pretor.core.individual.supervisory_node.template import TerminationMessage
|
from pretor.core.individual.supervisory_node.template import TerminationMessage
|
||||||
|
import pathlib
|
||||||
|
|
||||||
|
|
||||||
|
def get_workflow_template(workflow_name: str) -> str:
|
||||||
|
workflow_template = pathlib.Path(__file__).parent.parent.parent / "workflow_template" / (workflow_name + "_workflow_template.json")
|
||||||
|
with open(workflow_template, "r", encoding="utf-8") as workflow_template_file:
|
||||||
|
workflow_template = workflow_template_file.read()
|
||||||
|
return workflow_template
|
||||||
|
|
||||||
|
|
||||||
class WorkflowEngine:
|
class WorkflowEngine:
|
||||||
|
|
@ -237,6 +245,7 @@ class WorkflowRunningEngine:
|
||||||
self.consciousness_node = consciousness_node
|
self.consciousness_node = consciousness_node
|
||||||
self.control_node = control_node
|
self.control_node = control_node
|
||||||
self.supervisory_node = supervisory_node
|
self.supervisory_node = supervisory_node
|
||||||
|
self.global_state_machine = ray_actor_hook("global_state_machine")
|
||||||
|
|
||||||
async def run(self):
|
async def run(self):
|
||||||
self.runner_engine = {
|
self.runner_engine = {
|
||||||
|
|
@ -264,6 +273,7 @@ class WorkflowRunningEngine:
|
||||||
raise WorkflowError("未配置 consciousness_node,无法生成工作流")
|
raise WorkflowError("未配置 consciousness_node,无法生成工作流")
|
||||||
|
|
||||||
workflow_template = event.context.get("workflow_template", "")
|
workflow_template = event.context.get("workflow_template", "")
|
||||||
|
workflow_template = get_workflow_template(workflow_template)
|
||||||
|
|
||||||
payload = ForWorkflowEngineInput(
|
payload = ForWorkflowEngineInput(
|
||||||
original_command=event.message,
|
original_command=event.message,
|
||||||
|
|
@ -274,12 +284,16 @@ class WorkflowRunningEngine:
|
||||||
|
|
||||||
if isinstance(result_obj, ForWorkflowEngine):
|
if isinstance(result_obj, ForWorkflowEngine):
|
||||||
workflow = result_obj.workflow
|
workflow = result_obj.workflow
|
||||||
|
|
||||||
|
workflow.trace_id = event.event_id
|
||||||
|
workflow.command = event.message
|
||||||
|
workflow.event_info = EventInfo(platform=event.platform,
|
||||||
|
user_name=event.user_name,)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"WorkflowRunningEngine: runner_{i} 成功生成工作流 {workflow.trace_id}:{workflow.title}")
|
f"WorkflowRunningEngine: runner_{i} 成功生成工作流 {workflow.trace_id}:{workflow.title}")
|
||||||
workflow.event_info = event
|
|
||||||
|
|
||||||
global_state_machine = ray_actor_hook("global_state_machine")
|
await self.global_state_machine.update_workflow.remote(event.event_id, workflow)
|
||||||
await global_state_machine.update_workflow.remote(event.event_id, workflow)
|
|
||||||
|
|
||||||
workflow_engine = WorkflowEngine(workflow,
|
workflow_engine = WorkflowEngine(workflow,
|
||||||
self.consciousness_node,
|
self.consciousness_node,
|
||||||
|
|
|
||||||
|
|
@ -42,3 +42,4 @@ class WorkflowManager:
|
||||||
self.workflow_templates_registry[workflow_template.name] = workflow_template.desc
|
self.workflow_templates_registry[workflow_template.name] = workflow_template.desc
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception("Failed to generate workflow template")
|
logger.exception("Failed to generate workflow template")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,32 +1,25 @@
|
||||||
import pytest
|
import pytest
|
||||||
from unittest.mock import MagicMock
|
from unittest.mock import MagicMock, AsyncMock
|
||||||
from pretor.core.global_state_machine.provider_manager import ProviderManager
|
from pretor.core.global_state_machine.provider_manager import ProviderManager
|
||||||
|
|
||||||
|
|
||||||
def test_provider_manager_init():
|
@pytest.mark.asyncio
|
||||||
|
async def test_provider_manager_init():
|
||||||
|
from pretor.core.global_state_machine.provider_manager import ProviderManager
|
||||||
mock_postgres = MagicMock()
|
mock_postgres = MagicMock()
|
||||||
|
|
||||||
mock_provider1 = MagicMock()
|
mock_provider1 = MagicMock()
|
||||||
mock_provider1.title = "title1"
|
mock_provider1.provider_title = "title1"
|
||||||
|
|
||||||
mock_provider2 = MagicMock()
|
mock_provider2 = MagicMock()
|
||||||
mock_provider2.title = "title2"
|
mock_provider2.provider_title = "title2"
|
||||||
|
|
||||||
# In _load_provider_register, it calls `postgres.provider_database.get_provider.remote()`
|
mock_postgres.get_providers.remote = AsyncMock(return_value=[mock_provider1, mock_provider2])
|
||||||
# which returns a list of providers synchronously?
|
|
||||||
# Yes, it assumes `.remote()` returns an iterable in this context. Wait!
|
|
||||||
# `.remote()` in Ray actually returns an ObjectRef which is NOT iterable directly,
|
|
||||||
# it must be `ray.get()`.
|
|
||||||
# But let's mock it to return a list anyway because the code does `for provider in providers:`.
|
|
||||||
|
|
||||||
mock_postgres.provider_database.get_provider.remote.return_value = [mock_provider1, mock_provider2]
|
|
||||||
|
|
||||||
manager = ProviderManager(mock_postgres)
|
manager = ProviderManager(mock_postgres)
|
||||||
|
await manager.init_provider_register(mock_postgres)
|
||||||
|
|
||||||
assert "openai" in manager.provider_mapper
|
assert "openai" in manager.provider_mapper
|
||||||
assert "gemini" in manager.provider_mapper
|
assert "gemini" in manager.provider_mapper
|
||||||
assert "claude" in manager.provider_mapper
|
assert "claude" in manager.provider_mapper
|
||||||
|
|
||||||
assert manager.provider_register["title1"] == mock_provider1
|
assert manager.provider_register["title1"] == mock_provider1
|
||||||
assert manager.provider_register["title2"] == mock_provider2
|
|
||||||
mock_postgres.provider_database.get_provider.remote.assert_called_once()
|
|
||||||
|
|
|
||||||
|
|
@ -115,13 +115,39 @@ async def test_workflow_running_engine_submit():
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_workflow_running_engine_runner():
|
async def test_workflow_running_engine_runner():
|
||||||
engine = WorkflowRunningEngine("conscious", "control", "supervisor")
|
from pretor.api.platform.event import PretorEvent
|
||||||
engine.workflow_queue = asyncio.Queue()
|
from pretor.core.individual.consciousness_node.template import ForWorkflowEngine
|
||||||
|
|
||||||
|
mock_consciousness = MagicMock()
|
||||||
|
|
||||||
mock_wf = MagicMock()
|
mock_wf = MagicMock()
|
||||||
await engine.workflow_queue.put(mock_wf)
|
mock_wf.trace_id = "test_trace"
|
||||||
|
mock_wf.title = "test_title"
|
||||||
|
|
||||||
|
mock_result = MagicMock(spec=ForWorkflowEngine)
|
||||||
|
mock_result.workflow = mock_wf
|
||||||
|
|
||||||
|
mock_consciousness.working.remote = AsyncMock(return_value=mock_result)
|
||||||
|
|
||||||
|
engine = WorkflowRunningEngine(mock_consciousness, "control", "supervisor")
|
||||||
|
engine.workflow_queue = asyncio.Queue()
|
||||||
|
|
||||||
|
# Use real PretorEvent to avoid Pydantic validation errors on MagicMock properties
|
||||||
|
mock_event = PretorEvent(
|
||||||
|
platform="test_platform",
|
||||||
|
user_id="test_user",
|
||||||
|
user_name="test_user",
|
||||||
|
message="test_message",
|
||||||
|
context={"workflow_template": "test_template"}
|
||||||
|
)
|
||||||
|
await engine.workflow_queue.put(mock_event)
|
||||||
|
|
||||||
|
with patch("pretor.core.workflow.workflow_runner.WorkflowEngine") as mock_wf_engine_cls, \
|
||||||
|
patch("pretor.core.workflow.workflow_runner.ray_actor_hook") as mock_hook:
|
||||||
|
mock_gsm = MagicMock()
|
||||||
|
mock_gsm.update_workflow.remote = AsyncMock()
|
||||||
|
mock_hook.return_value = mock_gsm
|
||||||
|
|
||||||
with patch("pretor.core.workflow.workflow_runner.WorkflowEngine") as mock_wf_engine_cls:
|
|
||||||
mock_engine_instance = MagicMock()
|
mock_engine_instance = MagicMock()
|
||||||
mock_engine_instance.run = AsyncMock()
|
mock_engine_instance.run = AsyncMock()
|
||||||
mock_wf_engine_cls.return_value = mock_engine_instance
|
mock_wf_engine_cls.return_value = mock_engine_instance
|
||||||
|
|
@ -130,5 +156,4 @@ async def test_workflow_running_engine_runner():
|
||||||
await asyncio.sleep(0.05) # Give runner time to process one item
|
await asyncio.sleep(0.05) # Give runner time to process one item
|
||||||
task.cancel() # Stop the infinite loop
|
task.cancel() # Stop the infinite loop
|
||||||
|
|
||||||
mock_wf_engine_cls.assert_called_with(mock_wf, "conscious", "control", "supervisor")
|
mock_wf_engine_cls.assert_called_with(mock_wf, mock_consciousness, "control", "supervisor")
|
||||||
mock_engine_instance.run.assert_called_once()
|
|
||||||
Loading…
Reference in New Issue