wip: 优化
This commit is contained in:
parent
2796c20f5e
commit
1715b64d17
5
.env
5
.env
|
|
@ -0,0 +1,5 @@
|
||||||
|
POSTGRES_USER=postgres
|
||||||
|
POSTGRES_PASSWORD=postgres
|
||||||
|
POSTGRES_HOST=127.0.0.1
|
||||||
|
POSTGRES_PORT=5432
|
||||||
|
POSTGRES_DB=pretor
|
||||||
2
LICENSE
2
LICENSE
|
|
@ -187,7 +187,7 @@
|
||||||
same "printed page" as the copyright notice for easier
|
same "printed page" as the copyright notice for easier
|
||||||
identification within third-party archives.
|
identification within third-party archives.
|
||||||
|
|
||||||
Copyright [2026] [zhaoxi826]
|
Copyright 2026 zhaoxi826
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
## ArchonBot项目开发
|
## Pretor项目开发
|
||||||
#项目规划
|
#项目规划
|
||||||
---
|
---
|
||||||
#### 全局规划:
|
#### 全局规划:
|
||||||
|
|
|
||||||
49
main.py
49
main.py
|
|
@ -1,7 +1,56 @@
|
||||||
|
import asyncio
|
||||||
|
import ray
|
||||||
|
|
||||||
from pretor.utils.banner import print_banner
|
from pretor.utils.banner import print_banner
|
||||||
|
from pretor.core.database.postgres import PostgresDatabase
|
||||||
|
from pretor.core.global_state_machine.global_state_machine import GlobalStateMachine
|
||||||
|
from pretor.core.individual.supervisory_node.supervisory_node import SupervisoryNode
|
||||||
|
from pretor.core.individual.consciousness_node.consciousness_node import ConsciousnessNode
|
||||||
|
from pretor.core.individual.control_node.control_node import ControlNode
|
||||||
|
from pretor.core.workflow.workflow_runner import WorkflowRunningEngine
|
||||||
|
from pretor.core.api import PretorGateway
|
||||||
|
|
||||||
|
|
||||||
|
async def start_system():
|
||||||
|
# 1. 初始化 Ray
|
||||||
|
ray.init(ignore_reinit_error=True)
|
||||||
|
|
||||||
|
# 2. 启动数据库组件
|
||||||
|
postgres_database = PostgresDatabase.remote()
|
||||||
|
await postgres_database.init_db.remote()
|
||||||
|
|
||||||
|
# 3. 启动全局状态机
|
||||||
|
global_state_machine = GlobalStateMachine.remote(postgres_database)
|
||||||
|
|
||||||
|
# 4. 启动核心节点
|
||||||
|
supervisory_node = SupervisoryNode.remote()
|
||||||
|
consciousness_node = ConsciousnessNode.remote()
|
||||||
|
control_node = ControlNode.remote()
|
||||||
|
|
||||||
|
# 5. 启动工作流运行引擎
|
||||||
|
workflow_engine = WorkflowRunningEngine.remote(
|
||||||
|
consciousness_node=consciousness_node,
|
||||||
|
control_node=control_node,
|
||||||
|
supervisory_node=supervisory_node
|
||||||
|
)
|
||||||
|
# 异步拉起 runner 协程群
|
||||||
|
workflow_engine.run.remote()
|
||||||
|
|
||||||
|
# 6. 启动 FastAPI 网关
|
||||||
|
pretor_gateway = PretorGateway.remote(
|
||||||
|
postgres_database=postgres_database,
|
||||||
|
global_state_machine=global_state_machine,
|
||||||
|
supervisory_node=supervisory_node
|
||||||
|
)
|
||||||
|
|
||||||
|
# 挂起在网关服务上,暴露 8000 端口
|
||||||
|
await pretor_gateway.server_run.remote(host="0.0.0.0", port=8000)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
print_banner()
|
print_banner()
|
||||||
|
asyncio.run(start_system())
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
main()
|
main()
|
||||||
|
|
@ -1,14 +0,0 @@
|
||||||
# Copyright 2026 zhaoxi826
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
|
|
@ -12,4 +12,3 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from pretor.adapter.model_adapter.provider_manager import ProviderManager
|
|
||||||
|
|
@ -13,13 +13,14 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import datetime
|
import datetime
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field, ConfigDict
|
||||||
from ulid import ULID
|
from ulid import ULID
|
||||||
from typing import Dict, List
|
from typing import Dict, List
|
||||||
from pretor.core.workflow.workflow import PretorWorkflow
|
from pretor.core.workflow.workflow import PretorWorkflow
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
class PretorEvent(BaseModel):
|
class PretorEvent(BaseModel):
|
||||||
|
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||||
event_id: str = Field(default_factory=lambda: str(ULID()), description="事件的唯一标识符")
|
event_id: str = Field(default_factory=lambda: str(ULID()), description="事件的唯一标识符")
|
||||||
platform: str = Field(description="消息来源的平台")
|
platform: str = Field(description="消息来源的平台")
|
||||||
user_id: str = Field(description="用户id")
|
user_id: str = Field(description="用户id")
|
||||||
|
|
|
||||||
|
|
@ -13,12 +13,16 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import ray
|
import ray
|
||||||
|
import uvicorn
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
from fastapi import FastAPI,WebSocket
|
from fastapi import FastAPI,WebSocket
|
||||||
from pretor.core.database.postgres import PostgresDatabase
|
from pretor.core.database.postgres import PostgresDatabase
|
||||||
from pretor.core.global_state_machine.global_state_machine import GlobalStateMachine
|
from pretor.core.global_state_machine.global_state_machine import GlobalStateMachine
|
||||||
from pretor.core.individual.supervisory_node.supervisory_node import SupervisoryNode
|
from pretor.core.individual.supervisory_node.supervisory_node import SupervisoryNode
|
||||||
|
|
||||||
|
from pretor.api.platform.frontend import client_router
|
||||||
|
from pretor.api.auth import auth_router
|
||||||
|
from pretor.api.provider import provider_router
|
||||||
|
|
||||||
@ray.remote
|
@ray.remote
|
||||||
class PretorGateway:
|
class PretorGateway:
|
||||||
|
|
@ -29,12 +33,19 @@ class PretorGateway:
|
||||||
supervisory_node: SupervisoryNode,):
|
supervisory_node: SupervisoryNode,):
|
||||||
self.app = FastAPI()
|
self.app = FastAPI()
|
||||||
self.gateway = {}
|
self.gateway = {}
|
||||||
self.app = FastAPI()
|
|
||||||
|
|
||||||
self.app.state.postgres_database = postgres_database
|
self.app.state.postgres_database = postgres_database
|
||||||
self.app.state.global_state_machine = global_state_machine
|
self.app.state.global_state_machine = global_state_machine
|
||||||
self.app.state.supervisory = supervisory_node
|
self.app.state.supervisory = supervisory_node
|
||||||
|
self.app.state.supervisory_node = supervisory_node
|
||||||
|
|
||||||
|
self.app.include_router(client_router)
|
||||||
|
self.app.include_router(auth_router)
|
||||||
|
self.app.include_router(provider_router)
|
||||||
|
|
||||||
|
async def server_run(self, host="0.0.0.0", port=8000):
|
||||||
|
config = uvicorn.Config(app=self.app, host=host, port=port, loop="asyncio")
|
||||||
|
server = uvicorn.Server(config)
|
||||||
|
await server.serve()
|
||||||
|
|
||||||
async def server_run(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -17,7 +17,7 @@ from pydantic import BaseModel
|
||||||
from typing import List
|
from typing import List
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
|
||||||
class ProviderStatus(Enum, str):
|
class ProviderStatus(str, Enum):
|
||||||
UP = "up"
|
UP = "up"
|
||||||
DOWN = "down"
|
DOWN = "down"
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -12,3 +12,4 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
from .consciousness_node import ConsciousnessNode
|
||||||
|
|
@ -12,3 +12,4 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
from .control_node import ControlNode
|
||||||
|
|
@ -12,3 +12,4 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
from .supervisory_node import SupervisoryNode
|
||||||
|
|
@ -67,6 +67,7 @@ class PretorWorkflow(BaseModel):
|
||||||
output: Dict[str, Any] = Field(default_factory=dict, description="工作流最终产出结果")
|
output: Dict[str, Any] = Field(default_factory=dict, description="工作流最终产出结果")
|
||||||
status: WorkflowStatus = Field(default_factory=WorkflowStatus, description="运行时状态对象")
|
status: WorkflowStatus = Field(default_factory=WorkflowStatus, description="运行时状态对象")
|
||||||
event_info: EventInfo | None = Field(default_factory=None)
|
event_info: EventInfo | None = Field(default_factory=None)
|
||||||
|
context_memory: Dict[str, Any] = Field(default=Dict())
|
||||||
|
|
||||||
@model_validator(mode='after')
|
@model_validator(mode='after')
|
||||||
def validate_workflow_integrity(self) -> 'PretorWorkflow':
|
def validate_workflow_integrity(self) -> 'PretorWorkflow':
|
||||||
|
|
|
||||||
|
|
@ -33,8 +33,6 @@ class WorkflowEngine:
|
||||||
supervisory_node=None):
|
supervisory_node=None):
|
||||||
self.workflow: PretorWorkflow = workflow
|
self.workflow: PretorWorkflow = workflow
|
||||||
"""工作流:当前WorkflowEngine待执行的workflow"""
|
"""工作流:当前WorkflowEngine待执行的workflow"""
|
||||||
self.context_memory: Dict[str, Any] = {}
|
|
||||||
"""上下文管理器:当前workflow执行过程中的缓存"""
|
|
||||||
self._steps_by_id: Dict[int, WorkStep] = {step.step: step for step in self.workflow.work_link}
|
self._steps_by_id: Dict[int, WorkStep] = {step.step: step for step in self.workflow.work_link}
|
||||||
"""步骤表:将当前workflow的步骤序号和步骤内容存放"""
|
"""步骤表:将当前workflow的步骤序号和步骤内容存放"""
|
||||||
|
|
||||||
|
|
@ -58,9 +56,9 @@ class WorkflowEngine:
|
||||||
case None:
|
case None:
|
||||||
return None
|
return None
|
||||||
case str(name):
|
case str(name):
|
||||||
return self.context_memory.get(name)
|
return self.workflow.context_memory.get(name)
|
||||||
case list(names):
|
case list(names):
|
||||||
return {k: self.context_memory.get(k) for k in names}
|
return {k: self.workflow.context_memory.get(k) for k in names}
|
||||||
|
|
||||||
async def run(self):
|
async def run(self):
|
||||||
"""
|
"""
|
||||||
|
|
@ -84,7 +82,7 @@ class WorkflowEngine:
|
||||||
step_result, is_success = await self._dispatch_to_node(current_step, step_input_data)
|
step_result, is_success = await self._dispatch_to_node(current_step, step_input_data)
|
||||||
if is_success:
|
if is_success:
|
||||||
if current_step.outputs:
|
if current_step.outputs:
|
||||||
self.context_memory[current_step.outputs] = step_result
|
self.workflow.context_memory[current_step.outputs] = step_result
|
||||||
logger.debug(f"Step {current_step_id} 产出已保存至变量: '{current_step.outputs}'")
|
logger.debug(f"Step {current_step_id} 产出已保存至变量: '{current_step.outputs}'")
|
||||||
current_step.status = "completed"
|
current_step.status = "completed"
|
||||||
else:
|
else:
|
||||||
|
|
@ -104,7 +102,7 @@ class WorkflowEngine:
|
||||||
self.workflow.status.status = "failed"
|
self.workflow.status.status = "failed"
|
||||||
break
|
break
|
||||||
logger.info(f"✅ 工作流 {self.workflow.title} 执行步骤结束。")
|
logger.info(f"✅ 工作流 {self.workflow.title} 执行步骤结束。")
|
||||||
self.workflow.output = self.context_memory
|
self.workflow.output = self.workflow.context_memory
|
||||||
await self._report_results()
|
await self._report_results()
|
||||||
|
|
||||||
async def _report_results(self):
|
async def _report_results(self):
|
||||||
|
|
@ -141,7 +139,7 @@ class WorkflowEngine:
|
||||||
message=f"工作流执行完毕。系统报告:{report}"
|
message=f"工作流执行完毕。系统报告:{report}"
|
||||||
)
|
)
|
||||||
user_response = await self.supervisory_node.working.remote(term_msg)
|
user_response = await self.supervisory_node.working.remote(term_msg)
|
||||||
self.context_memory["_final_user_response"] = user_response
|
self.workflow.context_memory["_final_user_response"] = user_response
|
||||||
logger.info(f"Supervisory 最终回复:{user_response}")
|
logger.info(f"Supervisory 最终回复:{user_response}")
|
||||||
else:
|
else:
|
||||||
logger.warning("未提供 supervisory_node 句柄,跳过用户反馈生成。")
|
logger.warning("未提供 supervisory_node 句柄,跳过用户反馈生成。")
|
||||||
|
|
|
||||||
|
|
@ -14,7 +14,6 @@
|
||||||
|
|
||||||
import json
|
import json
|
||||||
from pretor.core.workflow.workflow_template_generator.workflow_template_generator import WorkflowTemplateGenerator
|
from pretor.core.workflow.workflow_template_generator.workflow_template_generator import WorkflowTemplateGenerator
|
||||||
from pretor.core.workflow.workflow import PretorWorkflow
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
|
|
@ -41,9 +40,4 @@ class WorkflowManager:
|
||||||
try:
|
try:
|
||||||
self.workflow_template_generator.generate_workflow_template(name=name, desc=desc, steps=steps)
|
self.workflow_template_generator.generate_workflow_template(name=name, desc=desc, steps=steps)
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def create_workflow(workflow_json: str) -> PretorWorkflow:
|
|
||||||
workflow = PretorWorkflow.model_validate_json(workflow_json)
|
|
||||||
return workflow
|
|
||||||
Loading…
Reference in New Issue