refactor(core): decouple actors and remove workflow templates (#67)

Removes the deprecated `workflow_template` concept entirely across both backend API routers, internal logic handling within the `supervisory_node` and `consciousness_node`, and front-end components. Enables `consciousness_node` to work autonomously.

Also refactors core package structure to enforce the "one python package, one Ray Actor" architectural rule. `GlobalWorkflowManager`, `WorkflowRunningEngine`, `PostgresDatabase`, and `WorkerCluster` have been moved to their own top-level decoupled package directories with properly exported `__init__.py` modules. Test suites have been relocated and import paths updated across the system.

Co-authored-by: google-labs-jules[bot] <161369871+google-labs-jules[bot]@users.noreply.github.com>
Co-authored-by: zhaoxi826 <198742034+zhaoxi826@users.noreply.github.com>
This commit is contained in:
2026-05-06 15:05:47 +08:00
committed by GitHub
parent b3ea4cd8d9
commit 209ba45477
97 changed files with 1872 additions and 1498 deletions
@@ -1,6 +1,5 @@
import { SkillSettings } from './SkillSettings'; import { SkillSettings } from './SkillSettings';
import { ToolSettings } from './ToolSettings'; import { ToolSettings } from './ToolSettings';
import { WorkflowTemplateSettings } from './WorkflowTemplateSettings';
interface PluginLayoutProps { interface PluginLayoutProps {
resourceTab: string; resourceTab: string;
@@ -20,14 +19,6 @@ export function PluginLayout({ resourceTab, setResourceTab }: PluginLayoutProps)
> >
Skills Skills
</button> </button>
<button
onClick={() => setResourceTab('workflow_template')}
className={`py-4 text-sm font-medium border-b-2 transition-colors ${
resourceTab === 'workflow_template' ? 'border-blue-600 text-blue-600' : 'border-transparent text-slate-500 hover:text-slate-800'
}`}
>
Workflow Templates
</button>
<button <button
onClick={() => setResourceTab('tool')} onClick={() => setResourceTab('tool')}
className={`py-4 text-sm font-medium border-b-2 transition-colors ${ className={`py-4 text-sm font-medium border-b-2 transition-colors ${
@@ -41,7 +32,6 @@ export function PluginLayout({ resourceTab, setResourceTab }: PluginLayoutProps)
{/* Main Content */} {/* Main Content */}
<div className="flex-1 overflow-y-auto p-8"> <div className="flex-1 overflow-y-auto p-8">
{resourceTab === 'skill' && <SkillSettings />} {resourceTab === 'skill' && <SkillSettings />}
{resourceTab === 'workflow_template' && <WorkflowTemplateSettings />}
{resourceTab === 'tool' && <ToolSettings />} {resourceTab === 'tool' && <ToolSettings />}
</div> </div>
</div> </div>
@@ -1,163 +0,0 @@
import { useState, useEffect } from 'react';
import apiClient from '../../api/client';
import { FileCode, Trash2, Plus, LayoutTemplate } from 'lucide-react';
import type { WorkflowTemplate as ParsedWorkflowTemplate } from '../../types';
export function WorkflowTemplateSettings() {
const [templates, setTemplates] = useState<Record<string, ParsedWorkflowTemplate>>({});
const [loading, setLoading] = useState(true);
const [templateJson, setTemplateJson] = useState('{\n "name": "my_template",\n "steps": [\n {\n "name": "step1",\n "actor": "actor_name"\n }\n ]\n}');
const [creating, setCreating] = useState(false);
const [message, setMessage] = useState('');
const [error, setError] = useState('');
const validateTemplate = (data: any): data is ParsedWorkflowTemplate => {
if (!data || typeof data !== 'object') return false;
if (typeof data.name !== 'string') return false;
if (!Array.isArray(data.steps)) return false;
for (const step of data.steps) {
if (typeof step.name !== 'string') return false;
if (typeof step.actor !== 'string') return false;
}
return true;
};
const fetchTemplates = async () => {
setLoading(true);
try {
const response = await apiClient.get('/api/v1/resource/workflow_template');
setTemplates(response.data.templates || {});
} catch (err) {
console.error('Failed to fetch templates:', err);
} finally {
setLoading(false);
}
};
useEffect(() => {
fetchTemplates();
}, []);
const handleCreate = async (e: React.FormEvent) => {
e.preventDefault();
setCreating(true);
setMessage('');
setError('');
try {
const parsedJson = JSON.parse(templateJson);
if (!validateTemplate(parsedJson)) {
throw new Error('JSON structure does not match WorkflowTemplate schema (requires name and steps array with name and actor).');
}
await apiClient.post('/api/v1/resource/workflow_template', parsedJson);
setMessage('Workflow template created successfully');
setTemplateJson('{\n "name": "my_template",\n "steps": []\n}');
fetchTemplates();
} catch (err: any) {
console.error(err);
if (err instanceof SyntaxError) {
setError('Invalid JSON format');
} else {
setError(err.message || err.response?.data?.message || 'Failed to create workflow template');
}
} finally {
setCreating(false);
}
};
const handleDelete = async (templateName: string) => {
if (!confirm(`Are you sure you want to delete ${templateName}?`)) return;
try {
await apiClient.delete(`/api/v1/resource/workflow_template/${templateName}`);
fetchTemplates();
} catch (err: any) {
console.error('Failed to delete template:', err);
alert('Failed to delete template');
}
};
return (
<div className="max-w-4xl space-y-6">
<div className="mb-8">
<h1 className="text-2xl font-bold text-slate-800">Workflow Templates</h1>
<p className="text-slate-500 mt-1">Manage and create reusable workflow templates.</p>
</div>
<div className="bg-white rounded-xl shadow-sm border border-slate-200 overflow-hidden">
<div className="p-6 border-b border-slate-100 flex items-center space-x-3">
<div className="w-10 h-10 bg-blue-50 text-blue-600 rounded-lg flex items-center justify-center">
<FileCode size={20} />
</div>
<div>
<h2 className="text-lg font-semibold text-slate-800">Create Template</h2>
<p className="text-sm text-slate-500">Provide the JSON definition for a new workflow template.</p>
</div>
</div>
<div className="p-6">
<form onSubmit={handleCreate} className="space-y-4">
<div>
<label className="block text-sm font-medium text-slate-700 mb-1">Template JSON Definition</label>
<textarea
required
rows={8}
value={templateJson}
onChange={(e) => setTemplateJson(e.target.value)}
className="w-full px-4 py-2 border border-slate-200 rounded-lg focus:outline-none focus:ring-2 focus:ring-blue-500 font-mono text-sm"
/>
</div>
{message && <div className="text-green-600 text-sm">{message}</div>}
{error && <div className="text-red-600 text-sm">{error}</div>}
<div className="flex justify-end">
<button
type="submit"
disabled={creating}
className="flex items-center px-4 py-2 bg-blue-600 text-white rounded-lg hover:bg-blue-700 transition-colors disabled:opacity-50"
>
<Plus size={16} className="mr-2" />
{creating ? 'Creating...' : 'Create Template'}
</button>
</div>
</form>
</div>
</div>
<div className="bg-white rounded-xl shadow-sm border border-slate-200 overflow-hidden">
<div className="p-6 border-b border-slate-100">
<h2 className="text-lg font-semibold text-slate-800">Available Templates</h2>
</div>
<div className="p-6">
{loading ? (
<div className="text-slate-500 text-sm">Loading templates...</div>
) : Object.keys(templates).length === 0 ? (
<div className="text-slate-500 text-sm">No workflow templates created yet.</div>
) : (
<div className="grid grid-cols-1 md:grid-cols-2 gap-4">
{Object.keys(templates).map((name) => (
<div key={name} className="p-4 border border-slate-200 rounded-xl flex items-center justify-between hover:shadow-sm transition-shadow">
<div className="flex items-center space-x-3">
<div className="w-8 h-8 rounded-lg bg-slate-100 flex items-center justify-center text-slate-500">
<LayoutTemplate size={16} />
</div>
<span className="font-medium text-slate-800">{name}</span>
</div>
<button
onClick={() => handleDelete(name)}
className="p-2 text-slate-400 hover:text-red-500 hover:bg-red-50 rounded-lg transition-colors"
title="Delete Template"
>
<Trash2 size={16} />
</button>
</div>
))}
</div>
)}
</div>
</div>
</div>
);
}
+32 -30
View File
@@ -1,14 +1,14 @@
import asyncio import asyncio
import ray import ray
from pretor.worker_individual.worker_cluster import WorkerCluster from pretor.worker_cluster import WorkerCluster
from pretor.utils.banner import print_banner from pretor.utils.banner import print_banner
from pretor.core.database.postgres import PostgresDatabase from pretor.core.postgres_database import PostgresDatabase
from pretor.core.global_state_machine.global_state_machine import GlobalStateMachine from pretor.core.global_state_machine import GlobalStateMachine
from pretor.core.global_state_machine.global_workflow_manager import GlobalWorkflowManager from pretor.core.global_workflow_manager import GlobalWorkflowManager
from pretor.core.individual.supervisory_node.supervisory_node import SupervisoryNode from pretor.core.individual.supervisory_node import SupervisoryNode
from pretor.core.individual.consciousness_node.consciousness_node import ConsciousnessNode from pretor.core.individual.consciousness_node import ConsciousnessNode
from pretor.core.individual.control_node.control_node import ControlNode from pretor.core.individual.control_node import ControlNode
from pretor.core.workflow.workflow_runner import WorkflowRunningEngine from pretor.core.workflow_running_engine import WorkflowRunningEngine
from pretor.api import PretorGateway from pretor.api import PretorGateway
from ray import serve from ray import serve
import os import os
@@ -18,7 +18,10 @@ _secret_key = os.getenv("SECRET_KEY")
if not _secret_key or _secret_key in {"secret", "114514"}: if not _secret_key or _secret_key in {"secret", "114514"}:
_secret_key = secrets.token_urlsafe(32) _secret_key = secrets.token_urlsafe(32)
os.environ["SECRET_KEY"] = _secret_key os.environ["SECRET_KEY"] = _secret_key
print("⚠️ 警告: 未提供有效的 SECRET_KEY 或使用了不安全的默认值,已生成并设置随机密钥。") print(
"⚠️ 警告: 未提供有效的 SECRET_KEY 或使用了不安全的默认值,已生成并设置随机密钥。"
)
async def start_system(): async def start_system():
env_vars = { env_vars = {
@@ -30,21 +33,20 @@ async def start_system():
"SECRET_KEY": os.getenv("SECRET_KEY"), "SECRET_KEY": os.getenv("SECRET_KEY"),
} }
ray.init(ignore_reinit_error=True, ray.init(
namespace="pretor", ignore_reinit_error=True,
dashboard_host="0.0.0.0", namespace="pretor",
dashboard_port=8265, dashboard_host="0.0.0.0",
runtime_env={"env_vars": env_vars}) dashboard_port=8265,
runtime_env={"env_vars": env_vars},
)
# 2. 启动数据库组件 # 2. 启动数据库组件
postgres_database = PostgresDatabase.options(name='postgres_database').remote() postgres_database = PostgresDatabase.options(name="postgres_database").remote()
await postgres_database.init_db.remote() await postgres_database.init_db.remote()
global_state_machine = GlobalStateMachine.options( global_state_machine = GlobalStateMachine.options(
name='global_state_machine', name="global_state_machine", namespace="pretor", lifetime="detached"
namespace='pretor',
lifetime='detached'
).remote(postgres_database) ).remote(postgres_database)
print("正在等待 GlobalStateMachine 初始化并加载注册表...") print("正在等待 GlobalStateMachine 初始化并加载注册表...")
@@ -58,29 +60,29 @@ async def start_system():
return return
global_workflow_manager = GlobalWorkflowManager.options( global_workflow_manager = GlobalWorkflowManager.options(
name='global_workflow_manager', name="global_workflow_manager", namespace="pretor", lifetime="detached"
namespace='pretor',
lifetime='detached'
).remote() ).remote()
# 4. 启动核心节点 # 4. 启动核心节点
supervisory_node = SupervisoryNode.options(name='supervisory_node').remote() supervisory_node = SupervisoryNode.options(name="supervisory_node").remote()
consciousness_node = ConsciousnessNode.options(name='consciousness_node').remote() consciousness_node = ConsciousnessNode.options(name="consciousness_node").remote()
control_node = ControlNode.options(name='control_node').remote() control_node = ControlNode.options(name="control_node").remote()
try: try:
WorkerCluster.options( WorkerCluster.options(
name="worker_cluster", name="worker_cluster",
lifetime="detached" # 保证它在后台一直运行 lifetime="detached", # 保证它在后台一直运行
).remote() ).remote()
print("✅ WorkerCluster 已成功启动并注册!") print("✅ WorkerCluster 已成功启动并注册!")
except ValueError: except ValueError:
print("WorkerCluster 已经存在。") print("WorkerCluster 已经存在。")
# 5. 启动工作流运行引擎 # 5. 启动工作流运行引擎
workflow_engine = WorkflowRunningEngine.options(name='workflow_running_engine').remote( workflow_engine = WorkflowRunningEngine.options(
name="workflow_running_engine"
).remote(
consciousness_node=consciousness_node, consciousness_node=consciousness_node,
control_node=control_node, control_node=control_node,
supervisory_node=supervisory_node supervisory_node=supervisory_node,
) )
# 异步拉起 runner 协程群 # 异步拉起 runner 协程群
workflow_engine.run.remote() workflow_engine.run.remote()
@@ -110,5 +112,5 @@ def main():
print("系统已退出。") print("系统已退出。")
if __name__ == '__main__': if __name__ == "__main__":
main() main()
+43 -29
View File
@@ -22,22 +22,28 @@ from pretor.core.global_state_machine.model_provider import Provider
from pretor.utils.agent_model import ResponseModel, DepsModel from pretor.utils.agent_model import ResponseModel, DepsModel
from pretor.utils.error import ModelNotExistError from pretor.utils.error import ModelNotExistError
class AgentFactory: class AgentFactory:
"""AgentFactory 核心组件类。 """AgentFactory 核心组件类。
这是一个领域数据模型或功能封装类,承载了 AgentFactory 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。 """ 这是一个领域数据模型或功能封装类,承载了 AgentFactory 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。"""
def __init__(self):
self._models_mapping = {"openai": (OpenAIChatModel, OpenAIProvider),
"claude": (AnthropicModel, AnthropicProvider),
"deepseek": (OpenAIChatModel, OpenAIProvider),}
def create_agent(self, def __init__(self):
provider: Provider, self._models_mapping = {
model_id: str, "openai": (OpenAIChatModel, OpenAIProvider),
output_type: ResponseModel, "claude": (AnthropicModel, AnthropicProvider),
system_prompt: str, "deepseek": (OpenAIChatModel, OpenAIProvider),
deps_type: DepsModel, }
agent_name: str,
tools: list = None) -> Agent: def create_agent(
self,
provider: Provider,
model_id: str,
output_type: ResponseModel,
system_prompt: str,
deps_type: DepsModel,
agent_name: str,
tools: list = None,
) -> Agent:
""" """
create_agent方法,将输入的provider对象实例化为一个pydantic-ai的agent对象 create_agent方法,将输入的provider对象实例化为一个pydantic-ai的agent对象
@@ -58,22 +64,30 @@ class AgentFactory:
if provider.provider_type not in self._models_mapping: if provider.provider_type not in self._models_mapping:
raise ValueError(f"不支持的协议类型: {provider.provider_type}") raise ValueError(f"不支持的协议类型: {provider.provider_type}")
model_class, provider_class = self._models_mapping[provider.provider_type] model_class, provider_class = self._models_mapping[provider.provider_type]
model = model_class(model_id, provider=provider_class(api_key=provider.provider_apikey, base_url=provider.provider_url)) model = model_class(
model_id,
provider=provider_class(
api_key=provider.provider_apikey, base_url=provider.provider_url
),
)
match provider.provider_type: match provider.provider_type:
case "deepseek": case "deepseek":
agent = DeepSeekReasonerAgent(model=model, agent = DeepSeekReasonerAgent(
name=agent_name, model=model,
output_type=output_type, name=agent_name,
deps_type=deps_type, output_type=output_type,
system_prompt=system_prompt, deps_type=deps_type,
tools=tools, system_prompt=system_prompt,
retries=3, tools=tools,
) retries=3,
)
case _: case _:
agent = Agent(model=model, agent = Agent(
name=agent_name, model=model,
system_prompt=system_prompt, name=agent_name,
output_type=output_type, system_prompt=system_prompt,
deps_type=deps_type, output_type=output_type,
tools=tools) deps_type=deps_type,
return agent tools=tools,
)
return agent
@@ -18,25 +18,29 @@ from typing import Type, TypeVar, Any, Generic
from pydantic import BaseModel, ValidationError from pydantic import BaseModel, ValidationError
from pydantic_ai import Agent from pydantic_ai import Agent
T = TypeVar('T', bound=BaseModel) T = TypeVar("T", bound=BaseModel)
class AgentRunResultProxy: class AgentRunResultProxy:
"""AgentRunResultProxy 核心组件类。 """AgentRunResultProxy 核心组件类。
这是一个领域数据模型或功能封装类,承载了 AgentRunResultProxy 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。 """ 这是一个领域数据模型或功能封装类,承载了 AgentRunResultProxy 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。"""
def __init__(self, original, parsed): def __init__(self, original, parsed):
self._original = original self._original = original
self._parsed = parsed self._parsed = parsed
def __getattr__(self, name): def __getattr__(self, name):
"""检索并获取特定的 getattr 数据集合或实例对象。 """检索并获取特定的 getattr 数据集合或实例对象。
根据提供的查询条件或上下文凭证,从数据库、缓存或第三方服务中读取对应的资源状态。 根据提供的查询条件或上下文凭证,从数据库、缓存或第三方服务中读取对应的资源状态。
Args: name: 赋予该实体的人类可读名称或标题字符串,主要用于前端 UI 展示、日志记录或模糊检索。 Args: name: 赋予该实体的人类可读名称或标题字符串,主要用于前端 UI 展示、日志记录或模糊检索。
Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。"""
if name == 'data': if name == "data":
return self._parsed return self._parsed
if name == 'output': if name == "output":
return self._parsed return self._parsed
return getattr(self._original, name) return getattr(self._original, name)
class DeepSeekReasonerAgent(Generic[T]): class DeepSeekReasonerAgent(Generic[T]):
""" """
专为 DeepSeek-V4/R1 设计的适配器。 专为 DeepSeek-V4/R1 设计的适配器。
@@ -44,15 +48,15 @@ class DeepSeekReasonerAgent(Generic[T]):
""" """
def __init__( def __init__(
self, self,
model, model,
name, name,
output_type: Any = str, output_type: Any = str,
system_prompt: str = "", system_prompt: str = "",
deps_type: Type[Any] = None, deps_type: Type[Any] = None,
tools: list = None, tools: list = None,
retries: int = 3, retries: int = 3,
**kwargs **kwargs,
): ):
self.output_schema = output_type self.output_schema = output_type
self.has_custom_output = output_type is not str and output_type is not None self.has_custom_output = output_type is not str and output_type is not None
@@ -63,6 +67,7 @@ class DeepSeekReasonerAgent(Generic[T]):
if self.has_custom_output: if self.has_custom_output:
try: try:
from pydantic import TypeAdapter from pydantic import TypeAdapter
schema_dict = TypeAdapter(self.output_schema).json_schema() schema_dict = TypeAdapter(self.output_schema).json_schema()
schema_str = json.dumps(schema_dict, ensure_ascii=False) schema_str = json.dumps(schema_dict, ensure_ascii=False)
format_instruction = ( format_instruction = (
@@ -77,14 +82,14 @@ class DeepSeekReasonerAgent(Generic[T]):
if self.tools: if self.tools:
tool_descs = [] tool_descs = []
for t in self.tools: for t in self.tools:
desc = getattr(t, '__name__', str(t)) desc = getattr(t, "__name__", str(t))
if hasattr(t, '__doc__') and t.__doc__: if hasattr(t, "__doc__") and t.__doc__:
desc += f": {t.__doc__.strip()}" desc += f": {t.__doc__.strip()}"
tool_descs.append(f"- {desc}") tool_descs.append(f"- {desc}")
tool_instruction = ( tool_instruction = (
"\n\n系统为您提供了以下工具。由于当前处于结构化降级模式,无法原生调用。" "\n\n系统为您提供了以下工具。由于当前处于结构化降级模式,无法原生调用。"
"但如果您在思考过程中判断必须使用这些工具,请在返回的结构中(或如果是自由文本)注明意图,由外层逻辑进行调度:\n" + "但如果您在思考过程中判断必须使用这些工具,请在返回的结构中(或如果是自由文本)注明意图,由外层逻辑进行调度:\n"
"\n".join(tool_descs) + "\n".join(tool_descs)
) )
self.agent = Agent( self.agent = Agent(
@@ -93,40 +98,41 @@ class DeepSeekReasonerAgent(Generic[T]):
output_type=str, # Force native agent to return str to disable function calling output_type=str, # Force native agent to return str to disable function calling
system_prompt=system_prompt + format_instruction + tool_instruction, system_prompt=system_prompt + format_instruction + tool_instruction,
deps_type=deps_type, deps_type=deps_type,
**kwargs **kwargs,
) )
def _parse_output(self, text: str) -> Any: def _parse_output(self, text: str) -> Any:
"""执行与 parse output 相关的核心业务流转操作。 """执行与 parse output 相关的核心业务流转操作。
该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。 该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。
Args: text (str): 控制逻辑流向的具体字符串参数,指定了期望的 text 内容。 Args: text (str): 控制逻辑流向的具体字符串参数,指定了期望的 text 内容。
Returns: (Any): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ Returns: (Any): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。"""
if not self.has_custom_output: if not self.has_custom_output:
return text return text
match = re.search(r'```json\s*(.*?)\s*```', text, re.DOTALL) match = re.search(r"```json\s*(.*?)\s*```", text, re.DOTALL)
json_str = match.group(1).strip() if match else text json_str = match.group(1).strip() if match else text
if not json_str.startswith('{') and not json_str.startswith('['): if not json_str.startswith("{") and not json_str.startswith("["):
start_obj = json_str.find('{') start_obj = json_str.find("{")
start_arr = json_str.find('[') start_arr = json_str.find("[")
start = -1 start = -1
end = -1 end = -1
if start_obj != -1 and (start_arr == -1 or start_obj < start_arr): if start_obj != -1 and (start_arr == -1 or start_obj < start_arr):
start = start_obj start = start_obj
end = json_str.rfind('}') end = json_str.rfind("}")
elif start_arr != -1: elif start_arr != -1:
start = start_arr start = start_arr
end = json_str.rfind(']') end = json_str.rfind("]")
if start != -1 and end != -1 and end > start: if start != -1 and end != -1 and end > start:
json_str = json_str[start:end+1] json_str = json_str[start : end + 1]
if not json_str: if not json_str:
raise ValueError("未找到有效的 JSON 块。请将结果包装在 ```json 中。") raise ValueError("未找到有效的 JSON 块。请将结果包装在 ```json 中。")
try: try:
from pydantic import TypeAdapter from pydantic import TypeAdapter
adapter = TypeAdapter(self.output_schema) adapter = TypeAdapter(self.output_schema)
return adapter.validate_json(json_str) return adapter.validate_json(json_str)
except ValidationError as e: except ValidationError as e:
@@ -134,33 +140,35 @@ class DeepSeekReasonerAgent(Generic[T]):
except json.JSONDecodeError as e: except json.JSONDecodeError as e:
raise ValueError(f"返回的不是合法的 JSON{e}") raise ValueError(f"返回的不是合法的 JSON{e}")
def __getattr__(self, item): def __getattr__(self, item):
# Delegate any unknown attributes (like .system_prompt, .tool) to the underlying pydantic_ai Agent # Delegate any unknown attributes (like .system_prompt, .tool) to the underlying pydantic_ai Agent
"""检索并获取特定的 getattr 数据集合或实例对象。 """检索并获取特定的 getattr 数据集合或实例对象。
根据提供的查询条件或上下文凭证,从数据库、缓存或第三方服务中读取对应的资源状态。 根据提供的查询条件或上下文凭证,从数据库、缓存或第三方服务中读取对应的资源状态。
Args: item: 参与 getattr 逻辑运算或数据构建的上下文依赖对象。 Args: item: 参与 getattr 逻辑运算或数据构建的上下文依赖对象。
Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。"""
return getattr(self.agent, item) return getattr(self.agent, item)
async def run(self, user_prompt: str, deps: Any = None, message_history: list = None, **kwargs) -> Any: async def run(
self, user_prompt: str, deps: Any = None, message_history: list = None, **kwargs
) -> Any:
# Custom retry loop # Custom retry loop
"""执行与 run 相关的核心业务流转操作。 """执行与 run 相关的核心业务流转操作。
该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。 该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。
Args: user_prompt (str): 控制逻辑流向的具体字符串参数,指定了期望的 user prompt 内容。 deps (Any): 参与 run 逻辑运算或数据构建的上下文依赖对象。 message_history (list): 批量操作所需的列表集合,囊括了需要统一处理的多个 message history 元素。 Args: user_prompt (str): 控制逻辑流向的具体字符串参数,指定了期望的 user prompt 内容。 deps (Any): 参与 run 逻辑运算或数据构建的上下文依赖对象。 message_history (list): 批量操作所需的列表集合,囊括了需要统一处理的多个 message history 元素。
Returns: (Any): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ Returns: (Any): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。"""
current_history = message_history or [] current_history = message_history or []
last_exception = None last_exception = None
for attempt in range(self.retries + 1): for attempt in range(self.retries + 1):
result = await self.agent.run( result = await self.agent.run(
user_prompt, user_prompt, deps=deps, message_history=current_history, **kwargs
deps=deps,
message_history=current_history,
**kwargs
) )
raw_text = result.data if hasattr(result, 'data') else getattr(result, 'output', str(result)) raw_text = (
result.data
if hasattr(result, "data")
else getattr(result, "output", str(result))
)
try: try:
parsed_data = self._parse_output(raw_text) parsed_data = self._parse_output(raw_text)
@@ -171,11 +179,15 @@ class DeepSeekReasonerAgent(Generic[T]):
except ValueError as e: except ValueError as e:
last_exception = e last_exception = e
# Prepare retry prompt # Prepare retry prompt
user_prompt = f"你的上一次输出解析失败,错误原因是: {e}\n请修正格式后重新输出。" user_prompt = (
f"你的上一次输出解析失败,错误原因是: {e}\n请修正格式后重新输出。"
)
# We need to maintain history manually so the model sees what it did wrong # We need to maintain history manually so the model sees what it did wrong
# Actually, pydantic-ai manages history inside the result. Let's use the all_messages from result # Actually, pydantic-ai manages history inside the result. Let's use the all_messages from result
if hasattr(result, 'all_messages'): if hasattr(result, "all_messages"):
current_history = result.all_messages() current_history = result.all_messages()
raise ValueError(f"Exceeded maximum retries ({self.retries}) for output validation. Last error: {last_exception}") raise ValueError(
f"Exceeded maximum retries ({self.retries}) for output validation. Last error: {last_exception}"
)
+29 -12
View File
@@ -28,9 +28,15 @@ from .provider import provider_router
from .resource import resource_router from .resource import resource_router
from .workflow import workflow_router from .workflow import workflow_router
from pretor.utils.error import ( from pretor.utils.error import (
DemandError, ModelNotExistError, UserError, DemandError,
UserNotExistError, UserPasswordError, ProviderError, ModelNotExistError,
ProviderNotExistError, WorkflowError, WorkflowExit UserError,
UserNotExistError,
UserPasswordError,
ProviderError,
ProviderNotExistError,
WorkflowError,
WorkflowExit,
) )
app = FastAPI() app = FastAPI()
@@ -43,6 +49,7 @@ app.include_router(cluster_router) # 集群信息路径
app.include_router(agent_router) # agent路径 app.include_router(agent_router) # agent路径
app.include_router(workflow_router) # workflow路径 app.include_router(workflow_router) # workflow路径
@app.exception_handler(UserNotExistError) @app.exception_handler(UserNotExistError)
async def user_not_exist_handler(request: Request, exc: UserNotExistError): async def user_not_exist_handler(request: Request, exc: UserNotExistError):
return JSONResponse(status_code=404, content={"message": "用户不存在"}) return JSONResponse(status_code=404, content={"message": "用户不存在"})
@@ -87,37 +94,47 @@ async def workflow_exit_handler(request: Request, exc: WorkflowExit):
async def workflow_error_handler(request: Request, exc: WorkflowError): async def workflow_error_handler(request: Request, exc: WorkflowError):
return JSONResponse(status_code=500, content={"message": "工作流执行错误"}) return JSONResponse(status_code=500, content={"message": "工作流执行错误"})
base_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
base_dir = os.path.dirname(
os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
)
frontend_dir = os.path.join(base_dir, "frontend", "dist") frontend_dir = os.path.join(base_dir, "frontend", "dist")
if os.path.exists(frontend_dir): if os.path.exists(frontend_dir):
app.mount("/assets", StaticFiles(directory=os.path.join(frontend_dir, "assets")), name="assets") app.mount(
"/assets",
StaticFiles(directory=os.path.join(frontend_dir, "assets")),
name="assets",
)
@app.get("/favicon.svg", include_in_schema=False) @app.get("/favicon.svg", include_in_schema=False)
async def serve_favicon(): async def serve_favicon():
return FileResponse(os.path.join(frontend_dir, "favicon.svg")) return FileResponse(os.path.join(frontend_dir, "favicon.svg"))
@app.get("/icons.svg", include_in_schema=False) @app.get("/icons.svg", include_in_schema=False)
async def serve_icons(): async def serve_icons():
return FileResponse(os.path.join(frontend_dir, "icons.svg")) return FileResponse(os.path.join(frontend_dir, "icons.svg"))
@app.get("/{full_path:path}", include_in_schema=False) @app.get("/{full_path:path}", include_in_schema=False)
async def serve_frontend(full_path: str): async def serve_frontend(full_path: str):
# 【重要安全修复】避免拦截不存在的 API 路由。如果是调用了不存在的 /api/ 接口,直接返回 404,不返回前端页面 # 【重要安全修复】避免拦截不存在的 API 路由。如果是调用了不存在的 /api/ 接口,直接返回 404,不返回前端页面
if full_path.startswith("api/"): if full_path.startswith("api/"):
return JSONResponse(status_code=404, content={"detail": "API endpoint not found"}) return JSONResponse(
status_code=404, content={"detail": "API endpoint not found"}
)
index_path = os.path.join(frontend_dir, "index.html") index_path = os.path.join(frontend_dir, "index.html")
if os.path.exists(index_path): if os.path.exists(index_path):
return FileResponse(index_path) return FileResponse(index_path)
return JSONResponse(status_code=404, content={"detail": "Frontend build not found"}) return JSONResponse(
status_code=404, content={"detail": "Frontend build not found"}
)
else: else:
import logging import logging
logging.getLogger("pretor").warning(f"Frontend dist folder not found at {frontend_dir}. Skipping frontend mount.") logging.getLogger("pretor").warning(
f"Frontend dist folder not found at {frontend_dir}. Skipping frontend mount."
)
@serve.deployment @serve.deployment
@@ -126,4 +143,4 @@ class PretorGateway:
gateway: Dict[str, WebSocket] gateway: Dict[str, WebSocket]
def __init__(self): def __init__(self):
self.gateway = {} self.gateway = {}
+92 -42
View File
@@ -26,38 +26,48 @@ from pretor.core.database.table.user import UserAuthority
agent_router = APIRouter(prefix="/api/v1/agent", tags=["agent"]) agent_router = APIRouter(prefix="/api/v1/agent", tags=["agent"])
class AgentRegister(BaseModel): class AgentRegister(BaseModel):
"""AgentRegister 核心组件类。 """AgentRegister 核心组件类。
这是一个领域数据模型或功能封装类,承载了 AgentRegister 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。 """ 这是一个领域数据模型或功能封装类,承载了 AgentRegister 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。"""
provider_title: str provider_title: str
model_id: str model_id: str
individual_name: str individual_name: str
tools: Optional[List[str]] = None tools: Optional[List[str]] = None
class AgentLocalRegister(BaseModel): class AgentLocalRegister(BaseModel):
"""AgentLocalRegister 核心组件类。 """AgentLocalRegister 核心组件类。
这是一个领域数据模型或功能封装类,承载了 AgentLocalRegister 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。 """ 这是一个领域数据模型或功能封装类,承载了 AgentLocalRegister 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。"""
path: str path: str
individual_name: str individual_name: str
tools: Optional[List[str]] = None tools: Optional[List[str]] = None
@agent_router.get("") @agent_router.get("")
async def get_system_nodes(_: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER))): async def get_system_nodes(
_: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER)),
):
"""处理针对 get system nodes 相关的 HTTP API 请求。 """处理针对 get system nodes 相关的 HTTP API 请求。
该接口负责解析前端传入的载荷数据,调用底层核心业务逻辑进行处理,并组装标准化的 JSON 响应。 该接口负责解析前端传入的载荷数据,调用底层核心业务逻辑进行处理,并组装标准化的 JSON 响应。
Args: _ (TokenData): 参与 get system nodes 逻辑运算或数据构建的上下文依赖对象。 Args: _ (TokenData): 参与 get system nodes 逻辑运算或数据构建的上下文依赖对象。
Returns: : 序列化后的标准网络响应模型(如包含业务状态码、成功标志及对应的数据载荷 Data)。 """ Returns: : 序列化后的标准网络响应模型(如包含业务状态码、成功标志及对应的数据载荷 Data)。"""
postgres_database = ray_actor_hook("postgres_database").postgres_database postgres_database = ray_actor_hook("postgres_database").postgres_database
configs = await postgres_database.get_all_system_node_configs.remote() configs = await postgres_database.get_all_system_node_configs.remote()
return {"system_nodes": configs} return {"system_nodes": configs}
@agent_router.post("") @agent_router.post("")
async def load_agent(agent_register: Union[AgentRegister, AgentLocalRegister], async def load_agent(
_: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER))): agent_register: Union[AgentRegister, AgentLocalRegister],
_: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER)),
):
"""处理针对 load agent 相关的 HTTP API 请求。 """处理针对 load agent 相关的 HTTP API 请求。
该接口负责解析前端传入的载荷数据,调用底层核心业务逻辑进行处理,并组装标准化的 JSON 响应。 该接口负责解析前端传入的载荷数据,调用底层核心业务逻辑进行处理,并组装标准化的 JSON 响应。
Args: agent_register (Union[AgentRegister, AgentLocalRegister]): 参与 load agent 逻辑运算或数据构建的上下文依赖对象。 _ (TokenData): 参与 load agent 逻辑运算或数据构建的上下文依赖对象。 Args: agent_register (Union[AgentRegister, AgentLocalRegister]): 参与 load agent 逻辑运算或数据构建的上下文依赖对象。 _ (TokenData): 参与 load agent 逻辑运算或数据构建的上下文依赖对象。
Returns: : 序列化后的标准网络响应模型(如包含业务状态码、成功标志及对应的数据载荷 Data)。 """ Returns: : 序列化后的标准网络响应模型(如包含业务状态码、成功标志及对应的数据载荷 Data)。"""
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
postgres_database = ray_actor_hook("postgres_database").postgres_database postgres_database = ray_actor_hook("postgres_database").postgres_database
@@ -71,20 +81,35 @@ async def load_agent(agent_register: Union[AgentRegister, AgentLocalRegister],
agent_register.individual_name, agent_register.individual_name,
agent_register.provider_title, agent_register.provider_title,
agent_register.model_id, agent_register.model_id,
agent_register.tools agent_register.tools,
) )
# Load agent into state machine # Load agent into state machine
match agent_register.individual_name: match agent_register.individual_name:
case "supervisory_node": case "supervisory_node":
node = ray_actor_hook("supervisory_node").supervisory_node node = ray_actor_hook("supervisory_node").supervisory_node
await node.create_agent.remote(global_state_machine,agent_register.provider_title,agent_register.model_id, agent_register.tools) await node.create_agent.remote(
global_state_machine,
agent_register.provider_title,
agent_register.model_id,
agent_register.tools,
)
case "consciousness_node": case "consciousness_node":
node = ray_actor_hook("consciousness_node").consciousness_node node = ray_actor_hook("consciousness_node").consciousness_node
await node.create_agent.remote(global_state_machine,agent_register.provider_title,agent_register.model_id, agent_register.tools) await node.create_agent.remote(
global_state_machine,
agent_register.provider_title,
agent_register.model_id,
agent_register.tools,
)
case "control_node": case "control_node":
node = ray_actor_hook("control_node").control_node node = ray_actor_hook("control_node").control_node
await node.create_agent.remote(global_state_machine,agent_register.provider_title,agent_register.model_id, agent_register.tools) await node.create_agent.remote(
global_state_machine,
agent_register.provider_title,
agent_register.model_id,
agent_register.tools,
)
case _: case _:
pass pass
except Exception as e: except Exception as e:
@@ -94,7 +119,8 @@ async def load_agent(agent_register: Union[AgentRegister, AgentLocalRegister],
class WorkerIndividualCreate(BaseModel): class WorkerIndividualCreate(BaseModel):
"""WorkerIndividualCreate 核心组件类。 """WorkerIndividualCreate 核心组件类。
这是一个具体的 Worker 智能体实体类,代表着具备特定人设、领域技能或长文本处理能力的数字员工。它可以被控制器动态拉起,并在安全沙箱内执行复杂的工作流指令与多步骤推理任务。 """ 这是一个具体的 Worker 智能体实体类,代表着具备特定人设、领域技能或长文本处理能力的数字员工。它可以被控制器动态拉起,并在安全沙箱内执行复杂的工作流指令与多步骤推理任务。"""
agent_name: str agent_name: str
agent_type: AgentType agent_type: AgentType
description: str description: str
@@ -109,7 +135,8 @@ class WorkerIndividualCreate(BaseModel):
class WorkerIndividualUpdate(BaseModel): class WorkerIndividualUpdate(BaseModel):
"""WorkerIndividualUpdate 核心组件类。 """WorkerIndividualUpdate 核心组件类。
这是一个具体的 Worker 智能体实体类,代表着具备特定人设、领域技能或长文本处理能力的数字员工。它可以被控制器动态拉起,并在安全沙箱内执行复杂的工作流指令与多步骤推理任务。 """ 这是一个具体的 Worker 智能体实体类,代表着具备特定人设、领域技能或长文本处理能力的数字员工。它可以被控制器动态拉起,并在安全沙箱内执行复杂的工作流指令与多步骤推理任务。"""
agent_name: Optional[str] = None agent_name: Optional[str] = None
agent_type: Optional[AgentType] = None agent_type: Optional[AgentType] = None
description: Optional[str] = None description: Optional[str] = None
@@ -123,63 +150,78 @@ class WorkerIndividualUpdate(BaseModel):
@agent_router.post("/worker") @agent_router.post("/worker")
async def create_worker_individual(worker_data: WorkerIndividualCreate, async def create_worker_individual(
token_data: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER))): worker_data: WorkerIndividualCreate,
token_data: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER)),
):
"""处理针对 create worker individual 相关的 HTTP API 请求。 """处理针对 create worker individual 相关的 HTTP API 请求。
该接口负责解析前端传入的载荷数据,调用底层核心业务逻辑进行处理,并组装标准化的 JSON 响应。 该接口负责解析前端传入的载荷数据,调用底层核心业务逻辑进行处理,并组装标准化的 JSON 响应。
Args: worker_data (WorkerIndividualCreate): 从客户端传递过来或由上游组件生成的核心业务数据体,通常需要进一步的清洗和结构化解析。 token_data (TokenData): 从客户端传递过来或由上游组件生成的核心业务数据体,通常需要进一步的清洗和结构化解析。 Args: worker_data (WorkerIndividualCreate): 从客户端传递过来或由上游组件生成的核心业务数据体,通常需要进一步的清洗和结构化解析。 token_data (TokenData): 从客户端传递过来或由上游组件生成的核心业务数据体,通常需要进一步的清洗和结构化解析。
Returns: : 序列化后的标准网络响应模型(如包含业务状态码、成功标志及对应的数据载荷 Data)。 """ Returns: : 序列化后的标准网络响应模型(如包含业务状态码、成功标志及对应的数据载荷 Data)。"""
postgres_database = ray_actor_hook("postgres_database").postgres_database postgres_database = ray_actor_hook("postgres_database").postgres_database
data_dict = worker_data.model_dump() data_dict = worker_data.model_dump()
data_dict["owner_id"] = token_data.user_id data_dict["owner_id"] = token_data.user_id
worker = await postgres_database.add_worker_individual.remote( **data_dict) worker = await postgres_database.add_worker_individual.remote(**data_dict)
return {"message": "success", "agent_id": worker.agent_id} return {"message": "success", "agent_id": worker.agent_id}
@agent_router.get("/worker") @agent_router.get("/worker")
async def get_worker_individual_list(token_data: TokenData = Depends(Accessor.get_current_user)): async def get_worker_individual_list(
token_data: TokenData = Depends(Accessor.get_current_user),
):
"""处理针对 get worker individual list 相关的 HTTP API 请求。 """处理针对 get worker individual list 相关的 HTTP API 请求。
该接口负责解析前端传入的载荷数据,调用底层核心业务逻辑进行处理,并组装标准化的 JSON 响应。 该接口负责解析前端传入的载荷数据,调用底层核心业务逻辑进行处理,并组装标准化的 JSON 响应。
Args: token_data (TokenData): 从客户端传递过来或由上游组件生成的核心业务数据体,通常需要进一步的清洗和结构化解析。 Args: token_data (TokenData): 从客户端传递过来或由上游组件生成的核心业务数据体,通常需要进一步的清洗和结构化解析。
Returns: : 序列化后的标准网络响应模型(如包含业务状态码、成功标志及对应的数据载荷 Data)。 """ Returns: : 序列化后的标准网络响应模型(如包含业务状态码、成功标志及对应的数据载荷 Data)。"""
postgres_database = ray_actor_hook("postgres_database").postgres_database postgres_database = ray_actor_hook("postgres_database").postgres_database
workers = await postgres_database.get_worker_individual_list.remote( owner_id=token_data.user_id) workers = await postgres_database.get_worker_individual_list.remote(
owner_id=token_data.user_id
)
return {"workers": workers} return {"workers": workers}
@agent_router.get("/worker/{agent_id}") @agent_router.get("/worker/{agent_id}")
async def get_worker_individual(agent_id: str, async def get_worker_individual(
token_data: TokenData = Depends(Accessor.get_current_user)): agent_id: str, token_data: TokenData = Depends(Accessor.get_current_user)
):
"""处理针对 get worker individual 相关的 HTTP API 请求。 """处理针对 get worker individual 相关的 HTTP API 请求。
该接口负责解析前端传入的载荷数据,调用底层核心业务逻辑进行处理,并组装标准化的 JSON 响应。 该接口负责解析前端传入的载荷数据,调用底层核心业务逻辑进行处理,并组装标准化的 JSON 响应。
Args: agent_id (str): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 agent 实例。 token_data (TokenData): 从客户端传递过来或由上游组件生成的核心业务数据体,通常需要进一步的清洗和结构化解析。 Args: agent_id (str): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 agent 实例。 token_data (TokenData): 从客户端传递过来或由上游组件生成的核心业务数据体,通常需要进一步的清洗和结构化解析。
Returns: : 序列化后的标准网络响应模型(如包含业务状态码、成功标志及对应的数据载荷 Data)。 """ Returns: : 序列化后的标准网络响应模型(如包含业务状态码、成功标志及对应的数据载荷 Data)。"""
postgres_database = ray_actor_hook("postgres_database").postgres_database postgres_database = ray_actor_hook("postgres_database").postgres_database
worker = await postgres_database.get_worker_individual.remote( agent_id=agent_id) worker = await postgres_database.get_worker_individual.remote(agent_id=agent_id)
if not worker: if not worker:
raise HTTPException(status_code=404, detail="Agent not found") raise HTTPException(status_code=404, detail="Agent not found")
if worker.owner_id != token_data.user_id: if worker.owner_id != token_data.user_id:
raise HTTPException(status_code=403, detail="Forbidden: You do not own this agent") raise HTTPException(
status_code=403, detail="Forbidden: You do not own this agent"
)
return worker return worker
@agent_router.put("/worker/{agent_id}") @agent_router.put("/worker/{agent_id}")
async def update_worker_individual(agent_id: str, async def update_worker_individual(
worker_data: WorkerIndividualUpdate, agent_id: str,
token_data: TokenData = Depends(Accessor.get_current_user)): worker_data: WorkerIndividualUpdate,
token_data: TokenData = Depends(Accessor.get_current_user),
):
"""处理针对 update worker individual 相关的 HTTP API 请求。 """处理针对 update worker individual 相关的 HTTP API 请求。
该接口负责解析前端传入的载荷数据,调用底层核心业务逻辑进行处理,并组装标准化的 JSON 响应。 该接口负责解析前端传入的载荷数据,调用底层核心业务逻辑进行处理,并组装标准化的 JSON 响应。
Args: agent_id (str): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 agent 实例。 worker_data (WorkerIndividualUpdate): 从客户端传递过来或由上游组件生成的核心业务数据体,通常需要进一步的清洗和结构化解析。 token_data (TokenData): 从客户端传递过来或由上游组件生成的核心业务数据体,通常需要进一步的清洗和结构化解析。 Args: agent_id (str): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 agent 实例。 worker_data (WorkerIndividualUpdate): 从客户端传递过来或由上游组件生成的核心业务数据体,通常需要进一步的清洗和结构化解析。 token_data (TokenData): 从客户端传递过来或由上游组件生成的核心业务数据体,通常需要进一步的清洗和结构化解析。
Returns: : 序列化后的标准网络响应模型(如包含业务状态码、成功标志及对应的数据载荷 Data)。 """ Returns: : 序列化后的标准网络响应模型(如包含业务状态码、成功标志及对应的数据载荷 Data)。"""
postgres_database = ray_actor_hook("postgres_database").postgres_database postgres_database = ray_actor_hook("postgres_database").postgres_database
worker = await postgres_database.get_worker_individual.remote( agent_id=agent_id) worker = await postgres_database.get_worker_individual.remote(agent_id=agent_id)
if not worker: if not worker:
raise HTTPException(status_code=404, detail="Agent not found") raise HTTPException(status_code=404, detail="Agent not found")
if worker.owner_id != token_data.user_id: if worker.owner_id != token_data.user_id:
raise HTTPException(status_code=403, detail="Forbidden: You do not own this agent") raise HTTPException(
status_code=403, detail="Forbidden: You do not own this agent"
)
update_data = worker_data.model_dump(exclude_unset=True) update_data = worker_data.model_dump(exclude_unset=True)
updated_worker = await postgres_database.update_worker_individual.remote( agent_id=agent_id, **update_data) updated_worker = await postgres_database.update_worker_individual.remote(
agent_id=agent_id, **update_data
)
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
try: try:
@@ -189,18 +231,23 @@ async def update_worker_individual(agent_id: str,
return {"message": "success", "worker": updated_worker} return {"message": "success", "worker": updated_worker}
@agent_router.post("/worker/{agent_id}/reload") @agent_router.post("/worker/{agent_id}/reload")
async def reload_worker_individual(agent_id: str, token_data: TokenData = Depends(Accessor.get_current_user)): async def reload_worker_individual(
agent_id: str, token_data: TokenData = Depends(Accessor.get_current_user)
):
"""处理针对 reload worker individual 相关的 HTTP API 请求。 """处理针对 reload worker individual 相关的 HTTP API 请求。
该接口负责解析前端传入的载荷数据,调用底层核心业务逻辑进行处理,并组装标准化的 JSON 响应。 该接口负责解析前端传入的载荷数据,调用底层核心业务逻辑进行处理,并组装标准化的 JSON 响应。
Args: agent_id (str): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 agent 实例。 token_data (TokenData): 从客户端传递过来或由上游组件生成的核心业务数据体,通常需要进一步的清洗和结构化解析。 Args: agent_id (str): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 agent 实例。 token_data (TokenData): 从客户端传递过来或由上游组件生成的核心业务数据体,通常需要进一步的清洗和结构化解析。
Returns: : 序列化后的标准网络响应模型(如包含业务状态码、成功标志及对应的数据载荷 Data)。 """ Returns: : 序列化后的标准网络响应模型(如包含业务状态码、成功标志及对应的数据载荷 Data)。"""
postgres_database = ray_actor_hook("postgres_database").postgres_database postgres_database = ray_actor_hook("postgres_database").postgres_database
worker = await postgres_database.get_worker_individual.remote(agent_id=agent_id) worker = await postgres_database.get_worker_individual.remote(agent_id=agent_id)
if not worker: if not worker:
raise HTTPException(status_code=404, detail="Agent not found") raise HTTPException(status_code=404, detail="Agent not found")
if worker.owner_id != token_data.user_id: if worker.owner_id != token_data.user_id:
raise HTTPException(status_code=403, detail="Forbidden: You do not own this agent") raise HTTPException(
status_code=403, detail="Forbidden: You do not own this agent"
)
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
await global_state_machine.remove_individual.remote(agent_id) await global_state_machine.remove_individual.remote(agent_id)
@@ -209,17 +256,20 @@ async def reload_worker_individual(agent_id: str, token_data: TokenData = Depend
@agent_router.delete("/worker/{agent_id}") @agent_router.delete("/worker/{agent_id}")
async def delete_worker_individual(agent_id: str, async def delete_worker_individual(
token_data: TokenData = Depends(Accessor.get_current_user)): agent_id: str, token_data: TokenData = Depends(Accessor.get_current_user)
):
"""处理针对 delete worker individual 相关的 HTTP API 请求。 """处理针对 delete worker individual 相关的 HTTP API 请求。
该接口负责解析前端传入的载荷数据,调用底层核心业务逻辑进行处理,并组装标准化的 JSON 响应。 该接口负责解析前端传入的载荷数据,调用底层核心业务逻辑进行处理,并组装标准化的 JSON 响应。
Args: agent_id (str): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 agent 实例。 token_data (TokenData): 从客户端传递过来或由上游组件生成的核心业务数据体,通常需要进一步的清洗和结构化解析。 Args: agent_id (str): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 agent 实例。 token_data (TokenData): 从客户端传递过来或由上游组件生成的核心业务数据体,通常需要进一步的清洗和结构化解析。
Returns: : 序列化后的标准网络响应模型(如包含业务状态码、成功标志及对应的数据载荷 Data)。 """ Returns: : 序列化后的标准网络响应模型(如包含业务状态码、成功标志及对应的数据载荷 Data)。"""
postgres_database = ray_actor_hook("postgres_database").postgres_database postgres_database = ray_actor_hook("postgres_database").postgres_database
worker = await postgres_database.get_worker_individual.remote( agent_id=agent_id) worker = await postgres_database.get_worker_individual.remote(agent_id=agent_id)
if not worker: if not worker:
raise HTTPException(status_code=404, detail="Agent not found") raise HTTPException(status_code=404, detail="Agent not found")
if worker.owner_id != token_data.user_id: if worker.owner_id != token_data.user_id:
raise HTTPException(status_code=403, detail="Forbidden: You do not own this agent") raise HTTPException(
await postgres_database.delete_worker_individual.remote( agent_id=agent_id) status_code=403, detail="Forbidden: You do not own this agent"
return {"message": "success"} )
await postgres_database.delete_worker_individual.remote(agent_id=agent_id)
return {"message": "success"}
+52 -18
View File
@@ -24,79 +24,113 @@ from pretor.utils.error import UserNotExistError
auth_router = APIRouter(prefix="/api/v1/auth", tags=["auth"]) auth_router = APIRouter(prefix="/api/v1/auth", tags=["auth"])
class UserRegister(BaseModel): class UserRegister(BaseModel):
"""UserRegister 核心组件类。 """UserRegister 核心组件类。
这是一个领域数据模型或功能封装类,承载了 UserRegister 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。 """ 这是一个领域数据模型或功能封装类,承载了 UserRegister 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。"""
user_name: str user_name: str
password: str password: str
@auth_router.post("/register") @auth_router.post("/register")
async def create_user(user_register: UserRegister): async def create_user(user_register: UserRegister):
"""处理针对 create user 相关的 HTTP API 请求。 """处理针对 create user 相关的 HTTP API 请求。
该接口负责解析前端传入的载荷数据,调用底层核心业务逻辑进行处理,并组装标准化的 JSON 响应。 该接口负责解析前端传入的载荷数据,调用底层核心业务逻辑进行处理,并组装标准化的 JSON 响应。
Args: user_register (UserRegister): 参与 create user 逻辑运算或数据构建的上下文依赖对象。 Args: user_register (UserRegister): 参与 create user 逻辑运算或数据构建的上下文依赖对象。
Returns: : 序列化后的标准网络响应模型(如包含业务状态码、成功标志及对应的数据载荷 Data)。 """ Returns: : 序列化后的标准网络响应模型(如包含业务状态码、成功标志及对应的数据载荷 Data)。"""
postgres_database = ray_actor_hook("postgres_database").postgres_database postgres_database = ray_actor_hook("postgres_database").postgres_database
hashed_password = await run_in_threadpool(Accessor.hash_password, user_register.password) hashed_password = await run_in_threadpool(
user = await postgres_database.add_user.remote( user_register.user_name, hashed_password) Accessor.hash_password, user_register.password
)
user = await postgres_database.add_user.remote(
user_register.user_name, hashed_password
)
return {"message": "success", "user_id": user.user_id} return {"message": "success", "user_id": user.user_id}
class UserLogin(BaseModel): class UserLogin(BaseModel):
"""UserLogin 核心组件类。 """UserLogin 核心组件类。
这是一个领域数据模型或功能封装类,承载了 UserLogin 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。 """ 这是一个领域数据模型或功能封装类,承载了 UserLogin 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。"""
user_name: str user_name: str
password: str password: str
@auth_router.post("/login") @auth_router.post("/login")
async def login_user(user_login: UserLogin): async def login_user(user_login: UserLogin):
"""处理针对 login user 相关的 HTTP API 请求。 """处理针对 login user 相关的 HTTP API 请求。
该接口负责解析前端传入的载荷数据,调用底层核心业务逻辑进行处理,并组装标准化的 JSON 响应。 该接口负责解析前端传入的载荷数据,调用底层核心业务逻辑进行处理,并组装标准化的 JSON 响应。
Args: user_login (UserLogin): 参与 login user 逻辑运算或数据构建的上下文依赖对象。 Args: user_login (UserLogin): 参与 login user 逻辑运算或数据构建的上下文依赖对象。
Returns: : 序列化后的标准网络响应模型(如包含业务状态码、成功标志及对应的数据载荷 Data)。 """ Returns: : 序列化后的标准网络响应模型(如包含业务状态码、成功标志及对应的数据载荷 Data)。"""
postgres_database = ray_actor_hook("postgres_database").postgres_database postgres_database = ray_actor_hook("postgres_database").postgres_database
user = await postgres_database.login_user.remote( user_login.user_name) user = await postgres_database.login_user.remote(user_login.user_name)
if not user: if not user:
raise UserNotExistError() raise UserNotExistError()
token = await run_in_threadpool(Accessor.login_hashed_password, user, user_login.password) token = await run_in_threadpool(
return {"message":"success", "token":token} Accessor.login_hashed_password, user, user_login.password
)
return {"message": "success", "token": token}
class ChangeAuthorityRequest(BaseModel): class ChangeAuthorityRequest(BaseModel):
"""ChangeAuthorityRequest 核心组件类。 """ChangeAuthorityRequest 核心组件类。
这是一个领域数据模型或功能封装类,承载了 ChangeAuthorityRequest 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。 """ 这是一个领域数据模型或功能封装类,承载了 ChangeAuthorityRequest 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。"""
user_id: str user_id: str
new_authority: UserAuthority new_authority: UserAuthority
@auth_router.put("/authority") @auth_router.put("/authority")
async def change_authority( async def change_authority(
request: ChangeAuthorityRequest, request: ChangeAuthorityRequest,
_: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.SUPER_ADMINISTRATOR)) _: TokenData = Depends(
RoleChecker(allowed_roles=UserAuthority.SUPER_ADMINISTRATOR)
),
): ):
""" """
Update a user's authority level. Only accessible by SUPER_ADMINISTRATOR. Update a user's authority level. Only accessible by SUPER_ADMINISTRATOR.
""" """
postgres_database = ray_actor_hook("postgres_database").postgres_database postgres_database = ray_actor_hook("postgres_database").postgres_database
user = await postgres_database.change_user_authority.remote( user_id=request.user_id, new_authority=request.new_authority) user = await postgres_database.change_user_authority.remote(
return {"message": "success", "user_id": user.user_id, "new_authority": user.user_authority} user_id=request.user_id, new_authority=request.new_authority
)
return {
"message": "success",
"user_id": user.user_id,
"new_authority": user.user_authority,
}
@auth_router.get("/list") @auth_router.get("/list")
async def get_user_list( async def get_user_list(
_: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.SUPER_ADMINISTRATOR)) _: TokenData = Depends(
RoleChecker(allowed_roles=UserAuthority.SUPER_ADMINISTRATOR)
),
): ):
""" """
Get a list of all users. Only accessible by SUPER_ADMINISTRATOR. Get a list of all users. Only accessible by SUPER_ADMINISTRATOR.
""" """
postgres_database = ray_actor_hook("postgres_database").postgres_database postgres_database = ray_actor_hook("postgres_database").postgres_database
users = await postgres_database.get_all_users.remote() users = await postgres_database.get_all_users.remote()
return {"users": [{"user_id": u.user_id, "user_name": u.user_name, "role": u.user_authority} for u in users]} return {
"users": [
{"user_id": u.user_id, "user_name": u.user_name, "role": u.user_authority}
for u in users
]
}
@auth_router.delete("/{user_id}") @auth_router.delete("/{user_id}")
async def delete_user( async def delete_user(
user_id: str, user_id: str,
_: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.SUPER_ADMINISTRATOR)) _: TokenData = Depends(
RoleChecker(allowed_roles=UserAuthority.SUPER_ADMINISTRATOR)
),
): ):
""" """
Delete a user. Only accessible by SUPER_ADMINISTRATOR. Delete a user. Only accessible by SUPER_ADMINISTRATOR.
""" """
postgres_database = ray_actor_hook("postgres_database").postgres_database postgres_database = ray_actor_hook("postgres_database").postgres_database
await postgres_database.delete_user_by_id.remote( user_id=user_id) await postgres_database.delete_user_by_id.remote(user_id=user_id)
return {"message": "success"} return {"message": "success"}
+1 -1
View File
@@ -16,4 +16,4 @@ from fastapi import APIRouter
cluster_router = APIRouter(prefix="/api/v1/cluster", tags=["cluster"]) cluster_router = APIRouter(prefix="/api/v1/cluster", tags=["cluster"])
# Monitor websocket API temporarily removed # Monitor websocket API temporarily removed
+1
View File
@@ -13,4 +13,5 @@
# limitations under the License. # limitations under the License.
from .frontend import client_router from .frontend import client_router
__all__ = ["client_router"] __all__ = ["client_router"]
+24 -10
View File
@@ -19,20 +19,34 @@ from typing import Any, Dict
from pretor.core.workflow.workflow import PretorWorkflow from pretor.core.workflow.workflow import PretorWorkflow
import asyncio import asyncio
class PretorEvent(BaseModel): class PretorEvent(BaseModel):
"""PretorEvent 核心组件类。 """PretorEvent 核心组件类。
这是一个领域数据模型或功能封装类,承载了 PretorEvent 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。 """ 这是一个领域数据模型或功能封装类,承载了 PretorEvent 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。"""
model_config = ConfigDict(arbitrary_types_allowed=True) model_config = ConfigDict(arbitrary_types_allowed=True)
trace_id: str = Field(default_factory=lambda: str(ULID()), description="事件的唯一标识符") trace_id: str = Field(
default_factory=lambda: str(ULID()), description="事件的唯一标识符"
)
platform: str = Field(description="消息来源的平台") platform: str = Field(description="消息来源的平台")
user_id: str = Field(description="用户id") user_id: str = Field(description="用户id")
user_name: str = Field(description="用户名") user_name: str = Field(description="用户名")
create_time: str = Field(default_factory=lambda: str(datetime.datetime.now(datetime.timezone.utc).isoformat()), create_time: str = Field(
description="事件创建时间") default_factory=lambda: str(
datetime.datetime.now(datetime.timezone.utc).isoformat()
),
description="事件创建时间",
)
message: str = Field(description="用户发来的消息") message: str = Field(description="用户发来的消息")
attachment: Dict[str, str] | None = Field(default=None,description="附件") attachment: Dict[str, str] | None = Field(default=None, description="附件")
#-------------------------------------------------------------------------------------------------------------- # --------------------------------------------------------------------------------------------------------------
context: Dict[str, Any] = Field(default_factory=dict, description="事件上下文内容,可包含工作流模板等信息") context: Dict[str, Any] = Field(
workflow: PretorWorkflow | None = Field(default=None,description="工作流") default_factory=dict, description="事件上下文内容,可包含工作流模板等信息"
pending_queue: asyncio.Queue[str] | None= Field(default=None,description="待处理队列") )
receive_queue: asyncio.Queue[str] | None = Field(default=None,description="待接收队列") workflow: PretorWorkflow | None = Field(default=None, description="工作流")
pending_queue: asyncio.Queue[str] | None = Field(
default=None, description="待处理队列"
)
receive_queue: asyncio.Queue[str] | None = Field(
default=None, description="待接收队列"
)
+27 -15
View File
@@ -22,45 +22,54 @@ import anyio
from pretor.utils.logger import get_logger from pretor.utils.logger import get_logger
logger = get_logger('frontend') logger = get_logger("frontend")
client_router = APIRouter(prefix="/api/v1/adapter/client", tags=["client"]) client_router = APIRouter(prefix="/api/v1/adapter/client", tags=["client"])
class Message(BaseModel): class Message(BaseModel):
"""Message 核心组件类。 """Message 核心组件类。
这是一个领域数据模型或功能封装类,承载了 Message 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。 """ 这是一个领域数据模型或功能封装类,承载了 Message 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。"""
message: str message: str
@client_router.post("") @client_router.post("")
async def create_message(message: Message, async def create_message(
token_data: TokenData = Depends(Accessor.get_current_user)): message: Message, token_data: TokenData = Depends(Accessor.get_current_user)
):
"""处理针对 create message 相关的 HTTP API 请求。 """处理针对 create message 相关的 HTTP API 请求。
该接口负责解析前端传入的载荷数据,调用底层核心业务逻辑进行处理,并组装标准化的 JSON 响应。 该接口负责解析前端传入的载荷数据,调用底层核心业务逻辑进行处理,并组装标准化的 JSON 响应。
Args: message (Message): 参与 create message 逻辑运算或数据构建的上下文依赖对象。 token_data (TokenData): 从客户端传递过来或由上游组件生成的核心业务数据体,通常需要进一步的清洗和结构化解析。 Args: message (Message): 参与 create message 逻辑运算或数据构建的上下文依赖对象。 token_data (TokenData): 从客户端传递过来或由上游组件生成的核心业务数据体,通常需要进一步的清洗和结构化解析。
Returns: : 序列化后的标准网络响应模型(如包含业务状态码、成功标志及对应的数据载荷 Data)。 """ Returns: : 序列化后的标准网络响应模型(如包含业务状态码、成功标志及对应的数据载荷 Data)。"""
logger.info("收到消息,来源:客户端") logger.info("收到消息,来源:客户端")
logger.debug(f"消息内容:{message.message}") logger.debug(f"消息内容:{message.message}")
event = PretorEvent(platform="client", event = PretorEvent(
user_id=str(token_data.user_id), platform="client",
user_name=token_data.username, user_id=str(token_data.user_id),
message=message.message) user_name=token_data.username,
message=message.message,
)
supervisory_node = ray_actor_hook("supervisory_node").supervisory_node supervisory_node = ray_actor_hook("supervisory_node").supervisory_node
message = await supervisory_node.working.remote(event) message = await supervisory_node.working.remote(event)
if message.startswith("任务已创建"): if message.startswith("任务已创建"):
return {"message": f"{event.trace_id}\n\n{message}"} return {"message": f"{event.trace_id}\n\n{message}"}
elif message == "未知相应类型": elif message == "未知相应类型":
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="模型回复错误"
detail="模型回复错误") )
else: else:
return {"message": message} return {"message": message}
@client_router.post("/upload") @client_router.post("/upload")
async def upload_file(file: UploadFile = File(...), async def upload_file(
token_data: TokenData = Depends(Accessor.get_current_user)): file: UploadFile = File(...),
token_data: TokenData = Depends(Accessor.get_current_user),
):
"""处理针对 upload file 相关的 HTTP API 请求。 """处理针对 upload file 相关的 HTTP API 请求。
该接口负责解析前端传入的载荷数据,调用底层核心业务逻辑进行处理,并组装标准化的 JSON 响应。 该接口负责解析前端传入的载荷数据,调用底层核心业务逻辑进行处理,并组装标准化的 JSON 响应。
Args: file (UploadFile): 参与 upload file 逻辑运算或数据构建的上下文依赖对象。 token_data (TokenData): 从客户端传递过来或由上游组件生成的核心业务数据体,通常需要进一步的清洗和结构化解析。 Args: file (UploadFile): 参与 upload file 逻辑运算或数据构建的上下文依赖对象。 token_data (TokenData): 从客户端传递过来或由上游组件生成的核心业务数据体,通常需要进一步的清洗和结构化解析。
Returns: : 序列化后的标准网络响应模型(如包含业务状态码、成功标志及对应的数据载荷 Data)。 """ Returns: : 序列化后的标准网络响应模型(如包含业务状态码、成功标志及对应的数据载荷 Data)。"""
try: try:
upload_dir = "uploads" upload_dir = "uploads"
os.makedirs(upload_dir, exist_ok=True) os.makedirs(upload_dir, exist_ok=True)
@@ -69,7 +78,10 @@ async def upload_file(file: UploadFile = File(...),
while chunk := await file.read(64 * 1024): # 64KB chunks while chunk := await file.read(64 * 1024): # 64KB chunks
await buffer.write(chunk) await buffer.write(chunk)
logger.info(f"用户 {token_data.username} 上传了文件: {file.filename}") logger.info(f"用户 {token_data.username} 上传了文件: {file.filename}")
return {"filename": file.filename, "message": f"File {file.filename} uploaded successfully"} return {
"filename": file.filename,
"message": f"File {file.filename} uploaded successfully",
}
except Exception as e: except Exception as e:
logger.error(f"文件上传失败: {e}") logger.error(f"文件上传失败: {e}")
raise HTTPException(status_code=500, detail="文件上传失败") raise HTTPException(status_code=500, detail="文件上传失败")
+32 -15
View File
@@ -24,45 +24,62 @@ from pretor.utils.ray_hook import ray_actor_hook
provider_router = APIRouter(prefix="/api/v1/provider", tags=["provider"]) provider_router = APIRouter(prefix="/api/v1/provider", tags=["provider"])
class ProviderRegister(BaseModel): class ProviderRegister(BaseModel):
"""ProviderRegister 核心组件类。 """ProviderRegister 核心组件类。
这是一个模型/服务提供商适配器类,屏蔽了外部不同供应商(如 OpenAI、Anthropic 等)的底层 API 差异。它负责标准化参数组装、网络请求发送、鉴权处理以及响应结构的反序列化。 """ 这是一个模型/服务提供商适配器类,屏蔽了外部不同供应商(如 OpenAI、Anthropic 等)的底层 API 差异。它负责标准化参数组装、网络请求发送、鉴权处理以及响应结构的反序列化。"""
provider_type: Literal["openai", "claude", "deepseek"] provider_type: Literal["openai", "claude", "deepseek"]
provider_title: str provider_title: str
provider_url: str provider_url: str
provider_apikey: str provider_apikey: str
@provider_router.post("") @provider_router.post("")
async def create_provider(provider_register: ProviderRegister, async def create_provider(
token_data: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER))) -> None: provider_register: ProviderRegister,
token_data: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER)),
) -> None:
"""处理针对 create provider 相关的 HTTP API 请求。 """处理针对 create provider 相关的 HTTP API 请求。
该接口负责解析前端传入的载荷数据,调用底层核心业务逻辑进行处理,并组装标准化的 JSON 响应。 该接口负责解析前端传入的载荷数据,调用底层核心业务逻辑进行处理,并组装标准化的 JSON 响应。
Args: provider_register (ProviderRegister): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_register 实例。 token_data (TokenData): 从客户端传递过来或由上游组件生成的核心业务数据体,通常需要进一步的清洗和结构化解析。 Args: provider_register (ProviderRegister): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_register 实例。 token_data (TokenData): 从客户端传递过来或由上游组件生成的核心业务数据体,通常需要进一步的清洗和结构化解析。
Returns: (None): 序列化后的标准网络响应模型(如包含业务状态码、成功标志及对应的数据载荷 Data)。 """ Returns: (None): 序列化后的标准网络响应模型(如包含业务状态码、成功标志及对应的数据载荷 Data)。"""
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
await global_state_machine.add_provider_wrap.remote(provider_type=provider_register.provider_type, await global_state_machine.add_provider_wrap.remote(
provider_title=provider_register.provider_title, provider_type=provider_register.provider_type,
provider_url=provider_register.provider_url, provider_title=provider_register.provider_title,
provider_apikey=provider_register.provider_apikey, provider_url=provider_register.provider_url,
provider_owner=token_data.user_id) provider_apikey=provider_register.provider_apikey,
provider_owner=token_data.user_id,
)
@provider_router.get("/list") @provider_router.get("/list")
async def get_provider_list(_: TokenData = Depends(Accessor.get_current_user)) -> Dict[str, Dict[str, Provider]]: async def get_provider_list(
_: TokenData = Depends(Accessor.get_current_user),
) -> Dict[str, Dict[str, Provider]]:
"""处理针对 get provider list 相关的 HTTP API 请求。 """处理针对 get provider list 相关的 HTTP API 请求。
该接口负责解析前端传入的载荷数据,调用底层核心业务逻辑进行处理,并组装标准化的 JSON 响应。 该接口负责解析前端传入的载荷数据,调用底层核心业务逻辑进行处理,并组装标准化的 JSON 响应。
Args: _ (TokenData): 参与 get provider list 逻辑运算或数据构建的上下文依赖对象。 Args: _ (TokenData): 参与 get provider list 逻辑运算或数据构建的上下文依赖对象。
Returns: (Dict[str, Dict[str, Provider]]): 序列化后的标准网络响应模型(如包含业务状态码、成功标志及对应的数据载荷 Data)。 """ Returns: (Dict[str, Dict[str, Provider]]): 序列化后的标准网络响应模型(如包含业务状态码、成功标志及对应的数据载荷 Data)。"""
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
provider_list: Dict[str, Provider] = await global_state_machine.get_provider_list.remote() provider_list: Dict[
str, Provider
] = await global_state_machine.get_provider_list.remote()
return {"provider_list": provider_list} return {"provider_list": provider_list}
@provider_router.delete("/{provider_title}") @provider_router.delete("/{provider_title}")
async def delete_provider(provider_title: str, _: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.SUPER_ADMINISTRATOR))) -> dict: async def delete_provider(
provider_title: str,
_: TokenData = Depends(
RoleChecker(allowed_roles=UserAuthority.SUPER_ADMINISTRATOR)
),
) -> dict:
"""处理针对 delete provider 相关的 HTTP API 请求。 """处理针对 delete provider 相关的 HTTP API 请求。
该接口负责解析前端传入的载荷数据,调用底层核心业务逻辑进行处理,并组装标准化的 JSON 响应。 该接口负责解析前端传入的载荷数据,调用底层核心业务逻辑进行处理,并组装标准化的 JSON 响应。
Args: provider_title (str): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_title 实例。 _ (TokenData): 参与 delete provider 逻辑运算或数据构建的上下文依赖对象。 Args: provider_title (str): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_title 实例。 _ (TokenData): 参与 delete provider 逻辑运算或数据构建的上下文依赖对象。
Returns: (dict): 序列化后的标准网络响应模型(如包含业务状态码、成功标志及对应的数据载荷 Data)。 """ Returns: (dict): 序列化后的标准网络响应模型(如包含业务状态码、成功标志及对应的数据载荷 Data)。"""
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
await global_state_machine.delete_provider.remote(provider_title=provider_title) await global_state_machine.delete_provider.remote(provider_title=provider_title)
return {"message": "success"} return {"message": "success"}
+35 -50
View File
@@ -14,7 +14,6 @@
from pydantic import BaseModel from pydantic import BaseModel
import viceroy import viceroy
from pretor.core.workflow.workflow_template_generator.workflow_template import WorkflowTemplate
from pretor.utils.ray_hook import ray_actor_hook from pretor.utils.ray_hook import ray_actor_hook
from fastapi import APIRouter, Depends from fastapi import APIRouter, Depends
from pretor.utils.access import TokenData from pretor.utils.access import TokenData
@@ -23,97 +22,83 @@ from pretor.core.database.table.user import UserAuthority
resource_router = APIRouter(prefix="/api/v1/resource") resource_router = APIRouter(prefix="/api/v1/resource")
@resource_router.post("/workflow_template")
async def create_workflow_template(workflow_template: WorkflowTemplate,
_: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER))):
"""处理针对 create workflow template 相关的 HTTP API 请求。
该接口负责解析前端传入的载荷数据,调用底层核心业务逻辑进行处理,并组装标准化的 JSON 响应。
Args: workflow_template (WorkflowTemplate): 参与 create workflow template 逻辑运算或数据构建的上下文依赖对象。 _ (TokenData): 参与 create workflow template 逻辑运算或数据构建的上下文依赖对象。
Returns: : 序列化后的标准网络响应模型(如包含业务状态码、成功标志及对应的数据载荷 Data)。 """
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
await global_state_machine.add_workflow_template.remote( workflow_template.name, workflow_template)
return {"message": "创建成功"}
@resource_router.get("/workflow_template")
async def get_workflow_templates(_: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER))):
"""处理针对 get workflow templates 相关的 HTTP API 请求。
该接口负责解析前端传入的载荷数据,调用底层核心业务逻辑进行处理,并组装标准化的 JSON 响应。
Args: _ (TokenData): 参与 get workflow templates 逻辑运算或数据构建的上下文依赖对象。
Returns: : 序列化后的标准网络响应模型(如包含业务状态码、成功标志及对应的数据载荷 Data)。 """
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
templates = await global_state_machine.get_all_workflow_templates.remote()
return {"templates": templates}
@resource_router.delete("/workflow_template/{template_name}")
async def delete_workflow_template(template_name: str, _: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.SUPER_ADMINISTRATOR))):
"""处理针对 delete workflow template 相关的 HTTP API 请求。
该接口负责解析前端传入的载荷数据,调用底层核心业务逻辑进行处理,并组装标准化的 JSON 响应。
Args: template_name (str): 赋予该实体的人类可读名称或标题字符串,主要用于前端 UI 展示、日志记录或模糊检索。 _ (TokenData): 参与 delete workflow template 逻辑运算或数据构建的上下文依赖对象。
Returns: : 序列化后的标准网络响应模型(如包含业务状态码、成功标志及对应的数据载荷 Data)。 """
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
await global_state_machine.delete_workflow_template.remote( template_name)
return {"message": "success"}
class Skill(BaseModel): class Skill(BaseModel):
"""Skill 核心组件类。 """Skill 核心组件类。
这是一个领域数据模型或功能封装类,承载了 Skill 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。 """ 这是一个领域数据模型或功能封装类,承载了 Skill 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。"""
repo_url: str repo_url: str
path: str | None path: str | None
@resource_router.post("/skill") @resource_router.post("/skill")
async def install_skill(skill: Skill, async def install_skill(
_: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER))): skill: Skill, _: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER))
):
"""处理针对 install skill 相关的 HTTP API 请求。 """处理针对 install skill 相关的 HTTP API 请求。
该接口负责解析前端传入的载荷数据,调用底层核心业务逻辑进行处理,并组装标准化的 JSON 响应。 该接口负责解析前端传入的载荷数据,调用底层核心业务逻辑进行处理,并组装标准化的 JSON 响应。
Args: skill (Skill): 参与 install skill 逻辑运算或数据构建的上下文依赖对象。 _ (TokenData): 参与 install skill 逻辑运算或数据构建的上下文依赖对象。 Args: skill (Skill): 参与 install skill 逻辑运算或数据构建的上下文依赖对象。 _ (TokenData): 参与 install skill 逻辑运算或数据构建的上下文依赖对象。
Returns: : 序列化后的标准网络响应模型(如包含业务状态码、成功标志及对应的数据载荷 Data)。 """ Returns: : 序列化后的标准网络响应模型(如包含业务状态码、成功标志及对应的数据载荷 Data)。"""
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
# noinspection PyUnresolvedReferences # noinspection PyUnresolvedReferences
import os import os
skill_output_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "plugin", "skill"))
skill_output_dir = os.path.abspath(
os.path.join(os.path.dirname(__file__), "..", "plugin", "skill")
)
os.makedirs(skill_output_dir, exist_ok=True) os.makedirs(skill_output_dir, exist_ok=True)
await viceroy.install_skill_async(url = skill.repo_url, await viceroy.install_skill_async(
path = skill.path, url=skill.repo_url, path=skill.path, output=skill_output_dir
output = skill_output_dir) )
if skill.path: if skill.path:
skill_name = skill.path.split("/")[-1] skill_name = skill.path.split("/")[-1]
else: else:
skill_name = skill.repo_url.split("/")[-1] skill_name = skill.repo_url.split("/")[-1]
await global_state_machine.add_skill.remote( skill_name) await global_state_machine.add_skill.remote(skill_name)
return {"message": "创建成功"} return {"message": "创建成功"}
@resource_router.get("/skill") @resource_router.get("/skill")
async def get_skills(_: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER))): async def get_skills(
_: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER)),
):
"""处理针对 get skills 相关的 HTTP API 请求。 """处理针对 get skills 相关的 HTTP API 请求。
该接口负责解析前端传入的载荷数据,调用底层核心业务逻辑进行处理,并组装标准化的 JSON 响应。 该接口负责解析前端传入的载荷数据,调用底层核心业务逻辑进行处理,并组装标准化的 JSON 响应。
Args: _ (TokenData): 参与 get skills 逻辑运算或数据构建的上下文依赖对象。 Args: _ (TokenData): 参与 get skills 逻辑运算或数据构建的上下文依赖对象。
Returns: : 序列化后的标准网络响应模型(如包含业务状态码、成功标志及对应的数据载荷 Data)。 """ Returns: : 序列化后的标准网络响应模型(如包含业务状态码、成功标志及对应的数据载荷 Data)。"""
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
skills = await global_state_machine.get_skill_list.remote() skills = await global_state_machine.get_skill_list.remote()
return {"skills": skills} return {"skills": skills}
@resource_router.delete("/skill/{skill_name}") @resource_router.delete("/skill/{skill_name}")
async def delete_skill(skill_name: str, _: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.SUPER_ADMINISTRATOR))): async def delete_skill(
skill_name: str,
_: TokenData = Depends(
RoleChecker(allowed_roles=UserAuthority.SUPER_ADMINISTRATOR)
),
):
"""处理针对 delete skill 相关的 HTTP API 请求。 """处理针对 delete skill 相关的 HTTP API 请求。
该接口负责解析前端传入的载荷数据,调用底层核心业务逻辑进行处理,并组装标准化的 JSON 响应。 该接口负责解析前端传入的载荷数据,调用底层核心业务逻辑进行处理,并组装标准化的 JSON 响应。
Args: skill_name (str): 赋予该实体的人类可读名称或标题字符串,主要用于前端 UI 展示、日志记录或模糊检索。 _ (TokenData): 参与 delete skill 逻辑运算或数据构建的上下文依赖对象。 Args: skill_name (str): 赋予该实体的人类可读名称或标题字符串,主要用于前端 UI 展示、日志记录或模糊检索。 _ (TokenData): 参与 delete skill 逻辑运算或数据构建的上下文依赖对象。
Returns: : 序列化后的标准网络响应模型(如包含业务状态码、成功标志及对应的数据载荷 Data)。 """ Returns: : 序列化后的标准网络响应模型(如包含业务状态码、成功标志及对应的数据载荷 Data)。"""
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
# Note: this only removes it from the state machine manager. # Note: this only removes it from the state machine manager.
await global_state_machine.remove_skill.remote( skill_name) await global_state_machine.remove_skill.remote(skill_name)
return {"message": "success"} return {"message": "success"}
@resource_router.get("/tool") @resource_router.get("/tool")
async def get_tools(_: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER))): async def get_tools(
_: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER)),
):
"""处理针对 get tools 相关的 HTTP API 请求。 """处理针对 get tools 相关的 HTTP API 请求。
该接口负责解析前端传入的载荷数据,调用底层核心业务逻辑进行处理,并组装标准化的 JSON 响应。 该接口负责解析前端传入的载荷数据,调用底层核心业务逻辑进行处理,并组装标准化的 JSON 响应。
Args: _ (TokenData): 参与 get tools 逻辑运算或数据构建的上下文依赖对象。 Args: _ (TokenData): 参与 get tools 逻辑运算或数据构建的上下文依赖对象。
Returns: : 序列化后的标准网络响应模型(如包含业务状态码、成功标志及对应的数据载荷 Data)。 """ Returns: : 序列化后的标准网络响应模型(如包含业务状态码、成功标志及对应的数据载荷 Data)。"""
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
tool_mapper = await global_state_machine.get_tool_mapper.remote() tool_mapper = await global_state_machine.get_tool_mapper.remote()
all_tool_names = set() all_tool_names = set()
for scope_tools in tool_mapper.values(): for scope_tools in tool_mapper.values():
all_tool_names.update(scope_tools.keys()) all_tool_names.update(scope_tools.keys())
return {"tools": list(all_tool_names)} return {"tools": list(all_tool_names)}
+31 -19
View File
@@ -20,12 +20,15 @@ import asyncio
workflow_router = APIRouter(prefix="/api/v1/workflow", tags=["workflow"]) workflow_router = APIRouter(prefix="/api/v1/workflow", tags=["workflow"])
@workflow_router.get("/list") @workflow_router.get("/list")
async def get_workflow_list(): async def get_workflow_list():
"""处理针对 get workflow list 相关的 HTTP API 请求。 """处理针对 get workflow list 相关的 HTTP API 请求。
该接口负责解析前端传入的载荷数据,调用底层核心业务逻辑进行处理,并组装标准化的 JSON 响应。 该接口负责解析前端传入的载荷数据,调用底层核心业务逻辑进行处理,并组装标准化的 JSON 响应。
Returns: : 序列化后的标准网络响应模型(如包含业务状态码、成功标志及对应的数据载荷 Data)。 """ Returns: : 序列化后的标准网络响应模型(如包含业务状态码、成功标志及对应的数据载荷 Data)。"""
global_workflow_manager = ray_actor_hook("global_workflow_manager").global_workflow_manager global_workflow_manager = ray_actor_hook(
"global_workflow_manager"
).global_workflow_manager
events = await global_workflow_manager.list_events.remote() events = await global_workflow_manager.list_events.remote()
return events return events
@@ -35,8 +38,10 @@ async def get_workflow_detail(trace_id: str):
"""处理针对 get workflow detail 相关的 HTTP API 请求。 """处理针对 get workflow detail 相关的 HTTP API 请求。
该接口负责解析前端传入的载荷数据,调用底层核心业务逻辑进行处理,并组装标准化的 JSON 响应。 该接口负责解析前端传入的载荷数据,调用底层核心业务逻辑进行处理,并组装标准化的 JSON 响应。
Args: trace_id (str): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 trace 实例。 Args: trace_id (str): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 trace 实例。
Returns: : 序列化后的标准网络响应模型(如包含业务状态码、成功标志及对应的数据载荷 Data)。 """ Returns: : 序列化后的标准网络响应模型(如包含业务状态码、成功标志及对应的数据载荷 Data)。"""
global_workflow_manager = ray_actor_hook("global_workflow_manager").global_workflow_manager global_workflow_manager = ray_actor_hook(
"global_workflow_manager"
).global_workflow_manager
event = await global_workflow_manager.get_event.remote(trace_id) event = await global_workflow_manager.get_event.remote(trace_id)
if not event: if not event:
raise HTTPException(status_code=404, detail="Workflow not found") raise HTTPException(status_code=404, detail="Workflow not found")
@@ -55,15 +60,17 @@ async def get_workflow_detail(trace_id: str):
steps = [] steps = []
for step in workflow.work_link: for step in workflow.work_link:
steps.append({ steps.append(
"step": step.step, {
"name": step.name, "step": step.step,
"node": step.node, "name": step.name,
"action": step.action, "node": step.node,
"desc": step.desc, "action": step.action,
"status": step.status, "desc": step.desc,
"agent_id": step.agent_id, "status": step.status,
}) "agent_id": step.agent_id,
}
)
return { return {
"event_id": trace_id, "event_id": trace_id,
"workflow_title": workflow.title, "workflow_title": workflow.title,
@@ -76,17 +83,20 @@ async def get_workflow_detail(trace_id: str):
"steps": steps, "steps": steps,
} }
@workflow_router.get("/sse/{trace_id}") @workflow_router.get("/sse/{trace_id}")
async def get_workflow_sse(trace_id: str, request: Request): async def get_workflow_sse(trace_id: str, request: Request):
"""处理针对 get workflow sse 相关的 HTTP API 请求。 """处理针对 get workflow sse 相关的 HTTP API 请求。
该接口负责解析前端传入的载荷数据,调用底层核心业务逻辑进行处理,并组装标准化的 JSON 响应。 该接口负责解析前端传入的载荷数据,调用底层核心业务逻辑进行处理,并组装标准化的 JSON 响应。
Args: trace_id (str): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 trace 实例。 request (Request): FastAPI 框架注入的原生 HTTP 请求对象,包含了完整的 Header 标头、查询参数和正文流。 Args: trace_id (str): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 trace 实例。 request (Request): FastAPI 框架注入的原生 HTTP 请求对象,包含了完整的 Header 标头、查询参数和正文流。
Returns: : 序列化后的标准网络响应模型(如包含业务状态码、成功标志及对应的数据载荷 Data)。 """ Returns: : 序列化后的标准网络响应模型(如包含业务状态码、成功标志及对应的数据载荷 Data)。"""
global_workflow_manager = ray_actor_hook("global_workflow_manager").global_workflow_manager global_workflow_manager = ray_actor_hook(
"global_workflow_manager"
).global_workflow_manager
async def event_generator(): async def event_generator():
"""执行与 event generator 相关的核心业务流转操作。 """执行与 event generator 相关的核心业务流转操作。
该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。 """ 该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。"""
try: try:
while True: while True:
if await request.is_disconnected(): if await request.is_disconnected():
@@ -102,15 +112,17 @@ async def get_workflow_sse(trace_id: str, request: Request):
return StreamingResponse(event_generator(), media_type="text/event-stream") return StreamingResponse(event_generator(), media_type="text/event-stream")
@workflow_router.post("/reply/{trace_id}") @workflow_router.post("/reply/{trace_id}")
async def post_workflow_reply(trace_id: str, request: Request): async def post_workflow_reply(trace_id: str, request: Request):
"""处理针对 post workflow reply 相关的 HTTP API 请求。 """处理针对 post workflow reply 相关的 HTTP API 请求。
该接口负责解析前端传入的载荷数据,调用底层核心业务逻辑进行处理,并组装标准化的 JSON 响应。 该接口负责解析前端传入的载荷数据,调用底层核心业务逻辑进行处理,并组装标准化的 JSON 响应。
Args: trace_id (str): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 trace 实例。 request (Request): FastAPI 框架注入的原生 HTTP 请求对象,包含了完整的 Header 标头、查询参数和正文流。 Args: trace_id (str): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 trace 实例。 request (Request): FastAPI 框架注入的原生 HTTP 请求对象,包含了完整的 Header 标头、查询参数和正文流。
Returns: : 序列化后的标准网络响应模型(如包含业务状态码、成功标志及对应的数据载荷 Data)。 """ Returns: : 序列化后的标准网络响应模型(如包含业务状态码、成功标志及对应的数据载荷 Data)。"""
data = await request.json() data = await request.json()
reply_msg = data.get("message", "") reply_msg = data.get("message", "")
global_workflow_manager = ray_actor_hook("global_workflow_manager").global_workflow_manager global_workflow_manager = ray_actor_hook(
"global_workflow_manager"
).global_workflow_manager
await global_workflow_manager.put_received.remote(trace_id, reply_msg) await global_workflow_manager.put_received.remote(trace_id, reply_msg)
return {"status": "ok"} return {"status": "ok"}
+9 -4
View File
@@ -17,16 +17,20 @@ from pydantic import ValidationError
from pretor.utils.error import UserNotExistError from pretor.utils.error import UserNotExistError
from pretor.utils.logger import get_logger from pretor.utils.logger import get_logger
logger = get_logger('database_exception')
logger = get_logger("database_exception")
def database_exception(func): def database_exception(func):
"""执行与 database exception 相关的核心业务流转操作。 """执行与 database exception 相关的核心业务流转操作。
该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。 该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。
Args: func: 参与 database exception 逻辑运算或数据构建的上下文依赖对象。 Args: func: 参与 database exception 逻辑运算或数据构建的上下文依赖对象。
Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。"""
async def wrapper(*args, **kwargs): async def wrapper(*args, **kwargs):
"""执行与 wrapper 相关的核心业务流转操作。 """执行与 wrapper 相关的核心业务流转操作。
该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。 该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。
Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。"""
try: try:
return await func(*args, **kwargs) return await func(*args, **kwargs)
except ValidationError as e: except ValidationError as e:
@@ -43,4 +47,5 @@ def database_exception(func):
except Exception as e: except Exception as e:
logger.exception(f"未预期的数据库错误: {e}") logger.exception(f"未预期的数据库错误: {e}")
raise e raise e
return wrapper
return wrapper
+1
View File
@@ -3,6 +3,7 @@ from typing import List, Optional
from pretor.core.database.table.event import EventRecord from pretor.core.database.table.event import EventRecord
from sqlalchemy.ext.asyncio import async_sessionmaker, AsyncSession from sqlalchemy.ext.asyncio import async_sessionmaker, AsyncSession
class EventDatabase: class EventDatabase:
def __init__(self, async_session_maker: async_sessionmaker[AsyncSession]): def __init__(self, async_session_maker: async_sessionmaker[AsyncSession]):
self.async_session_maker = async_session_maker self.async_session_maker = async_session_maker
+25 -13
View File
@@ -19,9 +19,11 @@ from pretor.core.database.database_exception import database_exception
from ulid import ULID from ulid import ULID
class IndividualDatabase: class IndividualDatabase:
"""IndividualDatabase 核心组件类。 """IndividualDatabase 核心组件类。
这是一个数据库操作层 (DAO/Repository) 封装类,专注于处理实体模型与关系型数据库表之间的映射。它将复杂的 SQL 查询、跨表 Join 和事务回滚逻辑进行了高级抽象,向上层服务暴露简洁的数据读写接口。 """ 这是一个数据库操作层 (DAO/Repository) 封装类,专注于处理实体模型与关系型数据库表之间的映射。它将复杂的 SQL 查询、跨表 Join 和事务回滚逻辑进行了高级抽象,向上层服务暴露简洁的数据读写接口。"""
def __init__(self, async_session_maker): def __init__(self, async_session_maker):
self.async_session_maker = async_session_maker self.async_session_maker = async_session_maker
@@ -29,7 +31,7 @@ class IndividualDatabase:
async def add_worker_individual(self, **kwargs) -> WorkerIndividual: async def add_worker_individual(self, **kwargs) -> WorkerIndividual:
"""创建并持久化新的 worker individual 实体。 """创建并持久化新的 worker individual 实体。
接收构建参数,执行必要的数据校验与默认值填充后,将新记录安全地写入底层存储或系统注册表中。 接收构建参数,执行必要的数据校验与默认值填充后,将新记录安全地写入底层存储或系统注册表中。
Returns: (WorkerIndividual): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ Returns: (WorkerIndividual): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。"""
async with self.async_session_maker() as session: async with self.async_session_maker() as session:
agent_id = str(ULID()) agent_id = str(ULID())
individual = WorkerIndividual(agent_id=agent_id, **kwargs) individual = WorkerIndividual(agent_id=agent_id, **kwargs)
@@ -43,9 +45,11 @@ class IndividualDatabase:
"""检索并获取特定的 worker individual 数据集合或实例对象。 """检索并获取特定的 worker individual 数据集合或实例对象。
根据提供的查询条件或上下文凭证,从数据库、缓存或第三方服务中读取对应的资源状态。 根据提供的查询条件或上下文凭证,从数据库、缓存或第三方服务中读取对应的资源状态。
Args: agent_id (str): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 agent 实例。 Args: agent_id (str): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 agent 实例。
Returns: (Optional[WorkerIndividual]): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ Returns: (Optional[WorkerIndividual]): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。"""
async with self.async_session_maker() as session: async with self.async_session_maker() as session:
statement = select(WorkerIndividual).where(WorkerIndividual.agent_id == agent_id) statement = select(WorkerIndividual).where(
WorkerIndividual.agent_id == agent_id
)
results = await session.execute(statement) results = await session.execute(statement)
return results.scalar_one_or_none() return results.scalar_one_or_none()
@@ -54,20 +58,26 @@ class IndividualDatabase:
"""检索并获取特定的 worker individual list 数据集合或实例对象。 """检索并获取特定的 worker individual list 数据集合或实例对象。
根据提供的查询条件或上下文凭证,从数据库、缓存或第三方服务中读取对应的资源状态。 根据提供的查询条件或上下文凭证,从数据库、缓存或第三方服务中读取对应的资源状态。
Args: owner_id (str): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 owner 实例。 Args: owner_id (str): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 owner 实例。
Returns: (List[WorkerIndividual]): 经过筛选、排序或分页处理后的实体对象列表集合。 """ Returns: (List[WorkerIndividual]): 经过筛选、排序或分页处理后的实体对象列表集合。"""
async with self.async_session_maker() as session: async with self.async_session_maker() as session:
statement = select(WorkerIndividual).where(WorkerIndividual.owner_id == owner_id) statement = select(WorkerIndividual).where(
WorkerIndividual.owner_id == owner_id
)
results = await session.execute(statement) results = await session.execute(statement)
return list(results.scalars().all()) return list(results.scalars().all())
@database_exception @database_exception
async def update_worker_individual(self, agent_id: str, **kwargs) -> Optional[WorkerIndividual]: async def update_worker_individual(
self, agent_id: str, **kwargs
) -> Optional[WorkerIndividual]:
"""对现有的 worker individual 进行状态更新或属性覆盖。 """对现有的 worker individual 进行状态更新或属性覆盖。
基于增量变更原则,合并最新的配置或数据,并触发相关依赖组件的缓存刷新或事件通知。 基于增量变更原则,合并最新的配置或数据,并触发相关依赖组件的缓存刷新或事件通知。
Args: agent_id (str): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 agent 实例。 Args: agent_id (str): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 agent 实例。
Returns: (Optional[WorkerIndividual]): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ Returns: (Optional[WorkerIndividual]): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。"""
async with self.async_session_maker() as session: async with self.async_session_maker() as session:
statement = select(WorkerIndividual).where(WorkerIndividual.agent_id == agent_id) statement = select(WorkerIndividual).where(
WorkerIndividual.agent_id == agent_id
)
results = await session.execute(statement) results = await session.execute(statement)
individual = results.scalar_one_or_none() individual = results.scalar_one_or_none()
if not individual: if not individual:
@@ -85,9 +95,11 @@ class IndividualDatabase:
"""安全地移除或注销 worker individual。 """安全地移除或注销 worker individual。
执行物理删除或逻辑删除操作,并妥善清理相关的关联数据及占用资源。 执行物理删除或逻辑删除操作,并妥善清理相关的关联数据及占用资源。
Args: agent_id (str): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 agent 实例。 Args: agent_id (str): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 agent 实例。
Returns: (bool): 一个布尔型结果标志,明确返回 True 表示该操作成功应用或条件达成,False 则表示失败或被拒绝。 """ Returns: (bool): 一个布尔型结果标志,明确返回 True 表示该操作成功应用或条件达成,False 则表示失败或被拒绝。"""
async with self.async_session_maker() as session: async with self.async_session_maker() as session:
statement = select(WorkerIndividual).where(WorkerIndividual.agent_id == agent_id) statement = select(WorkerIndividual).where(
WorkerIndividual.agent_id == agent_id
)
results = await session.execute(statement) results = await session.execute(statement)
individual = results.scalar_one_or_none() individual = results.scalar_one_or_none()
if not individual: if not individual:
@@ -100,8 +112,8 @@ class IndividualDatabase:
async def get_all_worker_individual(self) -> List[WorkerIndividual]: async def get_all_worker_individual(self) -> List[WorkerIndividual]:
"""检索并获取特定的 all worker individual 数据集合或实例对象。 """检索并获取特定的 all worker individual 数据集合或实例对象。
根据提供的查询条件或上下文凭证,从数据库、缓存或第三方服务中读取对应的资源状态。 根据提供的查询条件或上下文凭证,从数据库、缓存或第三方服务中读取对应的资源状态。
Returns: (List[WorkerIndividual]): 经过筛选、排序或分页处理后的实体对象列表集合。 """ Returns: (List[WorkerIndividual]): 经过筛选、排序或分页处理后的实体对象列表集合。"""
async with self.async_session_maker() as session: async with self.async_session_maker() as session:
statement = select(WorkerIndividual) statement = select(WorkerIndividual)
results = await session.execute(statement) results = await session.execute(statement)
return list(results.scalars().all()) return list(results.scalars().all())
+18 -11
View File
@@ -18,9 +18,11 @@ from pretor.core.database.table.provider import Provider
from sqlmodel import select from sqlmodel import select
from pretor.core.database.database_exception import database_exception from pretor.core.database.database_exception import database_exception
class ProviderDatabase: class ProviderDatabase:
"""ProviderDatabase 核心组件类。 """ProviderDatabase 核心组件类。
这是一个模型/服务提供商适配器类,屏蔽了外部不同供应商(如 OpenAI、Anthropic 等)的底层 API 差异。它负责标准化参数组装、网络请求发送、鉴权处理以及响应结构的反序列化。 """ 这是一个模型/服务提供商适配器类,屏蔽了外部不同供应商(如 OpenAI、Anthropic 等)的底层 API 差异。它负责标准化参数组装、网络请求发送、鉴权处理以及响应结构的反序列化。"""
def __init__(self, async_session_maker): def __init__(self, async_session_maker):
self.async_session_maker = async_session_maker self.async_session_maker = async_session_maker
@@ -28,23 +30,28 @@ class ProviderDatabase:
async def get_provider(self) -> List[Provider]: async def get_provider(self) -> List[Provider]:
"""检索并获取特定的 provider 数据集合或实例对象。 """检索并获取特定的 provider 数据集合或实例对象。
根据提供的查询条件或上下文凭证,从数据库、缓存或第三方服务中读取对应的资源状态。 根据提供的查询条件或上下文凭证,从数据库、缓存或第三方服务中读取对应的资源状态。
Returns: (List[Provider]): 经过筛选、排序或分页处理后的实体对象列表集合。 """ Returns: (List[Provider]): 经过筛选、排序或分页处理后的实体对象列表集合。"""
async with self.async_session_maker() as session: async with self.async_session_maker() as session:
statement = select(Provider) statement = select(Provider)
results = await session.execute(statement) results = await session.execute(statement)
results = results.scalars().all() results = results.scalars().all()
providers = [Provider(provider_title=provider.provider_title, providers = [
provider_url=provider.provider_url, Provider(
provider_apikey=provider.provider_apikey, provider_title=provider.provider_title,
provider_models=provider.provider_models, provider_url=provider.provider_url,
provider_type=provider.provider_type) for provider in results] provider_apikey=provider.provider_apikey,
provider_models=provider.provider_models,
provider_type=provider.provider_type,
)
for provider in results
]
return providers return providers
@database_exception @database_exception
async def add_provider(self, **kwargs) -> None: async def add_provider(self, **kwargs) -> None:
"""创建并持久化新的 provider 实体。 """创建并持久化新的 provider 实体。
接收构建参数,执行必要的数据校验与默认值填充后,将新记录安全地写入底层存储或系统注册表中。 接收构建参数,执行必要的数据校验与默认值填充后,将新记录安全地写入底层存储或系统注册表中。
Returns: (None): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ Returns: (None): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。"""
async with self.async_session_maker() as session: async with self.async_session_maker() as session:
provider = Provider(**kwargs) provider = Provider(**kwargs)
session.add(provider) session.add(provider)
@@ -55,7 +62,7 @@ class ProviderDatabase:
"""安全地移除或注销 provider。 """安全地移除或注销 provider。
执行物理删除或逻辑删除操作,并妥善清理相关的关联数据及占用资源。 执行物理删除或逻辑删除操作,并妥善清理相关的关联数据及占用资源。
Args: provider_id (str): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider 实例。 Args: provider_id (str): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider 实例。
Returns: (None): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ Returns: (None): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。"""
async with self.async_session_maker() as session: async with self.async_session_maker() as session:
provider = await session.get(Provider, provider_id) provider = await session.get(Provider, provider_id)
if provider is not None: if provider is not None:
@@ -67,7 +74,7 @@ class ProviderDatabase:
"""对现有的 provider 进行状态更新或属性覆盖。 """对现有的 provider 进行状态更新或属性覆盖。
基于增量变更原则,合并最新的配置或数据,并触发相关依赖组件的缓存刷新或事件通知。 基于增量变更原则,合并最新的配置或数据,并触发相关依赖组件的缓存刷新或事件通知。
Args: provider_id (str): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider 实例。 Args: provider_id (str): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider 实例。
Returns: (Provider): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ Returns: (Provider): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。"""
async with self.async_session_maker() as session: async with self.async_session_maker() as session:
provider = await session.get(Provider, provider_id) provider = await session.get(Provider, provider_id)
if provider is not None: if provider is not None:
@@ -77,4 +84,4 @@ class ProviderDatabase:
await session.commit() await session.commit()
await session.refresh(provider) await session.refresh(provider)
return provider return provider
return None return None
+28 -9
View File
@@ -17,20 +17,30 @@ from sqlmodel import select
from typing import List, Optional from typing import List, Optional
from pretor.core.database.database_exception import database_exception from pretor.core.database.database_exception import database_exception
class SystemNodeDatabase: class SystemNodeDatabase:
"""SystemNodeDatabase 核心组件类。 """SystemNodeDatabase 核心组件类。
这是一个系统执行节点类,作为多智能体架构中的独立处理单元。它能够接收工作流上下文,根据内置的大模型策略进行意图理解和自主决策,从而驱动特定阶段的任务闭环。 """ 这是一个系统执行节点类,作为多智能体架构中的独立处理单元。它能够接收工作流上下文,根据内置的大模型策略进行意图理解和自主决策,从而驱动特定阶段的任务闭环。"""
def __init__(self, async_session_maker): def __init__(self, async_session_maker):
self.async_session_maker = async_session_maker self.async_session_maker = async_session_maker
@database_exception @database_exception
async def upsert_system_node_config(self, node_name: str, provider_title: str, model_id: str, tools: Optional[List[str]] = None) -> SystemNodeConfig: async def upsert_system_node_config(
self,
node_name: str,
provider_title: str,
model_id: str,
tools: Optional[List[str]] = None,
) -> SystemNodeConfig:
"""执行与 upsert system node config 相关的核心业务流转操作。 """执行与 upsert system node config 相关的核心业务流转操作。
该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。 该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。
Args: node_name (str): 赋予该实体的人类可读名称或标题字符串,主要用于前端 UI 展示、日志记录或模糊检索。 provider_title (str): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_title 实例。 model_id (str): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 model 实例。 tools (Optional[List[str]]): 控制逻辑流向的具体字符串参数,指定了期望的 tools 内容。 Args: node_name (str): 赋予该实体的人类可读名称或标题字符串,主要用于前端 UI 展示、日志记录或模糊检索。 provider_title (str): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_title 实例。 model_id (str): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 model 实例。 tools (Optional[List[str]]): 控制逻辑流向的具体字符串参数,指定了期望的 tools 内容。
Returns: (SystemNodeConfig): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ Returns: (SystemNodeConfig): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。"""
async with self.async_session_maker() as session: async with self.async_session_maker() as session:
statement = select(SystemNodeConfig).where(SystemNodeConfig.node_name == node_name) statement = select(SystemNodeConfig).where(
SystemNodeConfig.node_name == node_name
)
results = await session.execute(statement) results = await session.execute(statement)
config = results.scalar_one_or_none() config = results.scalar_one_or_none()
if config: if config:
@@ -39,7 +49,12 @@ class SystemNodeDatabase:
if tools is not None: if tools is not None:
config.tools = tools config.tools = tools
else: else:
config = SystemNodeConfig(node_name=node_name, provider_title=provider_title, model_id=model_id, tools=tools) config = SystemNodeConfig(
node_name=node_name,
provider_title=provider_title,
model_id=model_id,
tools=tools,
)
session.add(config) session.add(config)
await session.commit() await session.commit()
await session.refresh(config) await session.refresh(config)
@@ -49,19 +64,23 @@ class SystemNodeDatabase:
async def get_all_system_node_configs(self) -> List[SystemNodeConfig]: async def get_all_system_node_configs(self) -> List[SystemNodeConfig]:
"""检索并获取特定的 all system node configs 数据集合或实例对象。 """检索并获取特定的 all system node configs 数据集合或实例对象。
根据提供的查询条件或上下文凭证,从数据库、缓存或第三方服务中读取对应的资源状态。 根据提供的查询条件或上下文凭证,从数据库、缓存或第三方服务中读取对应的资源状态。
Returns: (List[SystemNodeConfig]): 经过筛选、排序或分页处理后的实体对象列表集合。 """ Returns: (List[SystemNodeConfig]): 经过筛选、排序或分页处理后的实体对象列表集合。"""
async with self.async_session_maker() as session: async with self.async_session_maker() as session:
statement = select(SystemNodeConfig) statement = select(SystemNodeConfig)
results = await session.execute(statement) results = await session.execute(statement)
return list(results.scalars().all()) return list(results.scalars().all())
@database_exception @database_exception
async def get_system_node_config(self, node_name: str) -> Optional[SystemNodeConfig]: async def get_system_node_config(
self, node_name: str
) -> Optional[SystemNodeConfig]:
"""检索并获取特定的 system node config 数据集合或实例对象。 """检索并获取特定的 system node config 数据集合或实例对象。
根据提供的查询条件或上下文凭证,从数据库、缓存或第三方服务中读取对应的资源状态。 根据提供的查询条件或上下文凭证,从数据库、缓存或第三方服务中读取对应的资源状态。
Args: node_name (str): 赋予该实体的人类可读名称或标题字符串,主要用于前端 UI 展示、日志记录或模糊检索。 Args: node_name (str): 赋予该实体的人类可读名称或标题字符串,主要用于前端 UI 展示、日志记录或模糊检索。
Returns: (Optional[SystemNodeConfig]): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ Returns: (Optional[SystemNodeConfig]): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。"""
async with self.async_session_maker() as session: async with self.async_session_maker() as session:
statement = select(SystemNodeConfig).where(SystemNodeConfig.node_name == node_name) statement = select(SystemNodeConfig).where(
SystemNodeConfig.node_name == node_name
)
results = await session.execute(statement) results = await session.execute(statement)
return results.scalar_one_or_none() return results.scalar_one_or_none()
+16 -11
View File
@@ -19,9 +19,11 @@ from pretor.core.database.database_exception import database_exception
from pretor.core.database.table.user import UserAuthority from pretor.core.database.table.user import UserAuthority
from pretor.utils.access import Accessor from pretor.utils.access import Accessor
class AuthDatabase: class AuthDatabase:
"""AuthDatabase 核心组件类。 """AuthDatabase 核心组件类。
这是一个数据库操作层 (DAO/Repository) 封装类,专注于处理实体模型与关系型数据库表之间的映射。它将复杂的 SQL 查询、跨表 Join 和事务回滚逻辑进行了高级抽象,向上层服务暴露简洁的数据读写接口。 """ 这是一个数据库操作层 (DAO/Repository) 封装类,专注于处理实体模型与关系型数据库表之间的映射。它将复杂的 SQL 查询、跨表 Join 和事务回滚逻辑进行了高级抽象,向上层服务暴露简洁的数据读写接口。"""
def __init__(self, async_session_maker): def __init__(self, async_session_maker):
self.async_session_maker = async_session_maker self.async_session_maker = async_session_maker
@@ -30,8 +32,9 @@ class AuthDatabase:
"""创建并持久化新的 user 实体。 """创建并持久化新的 user 实体。
接收构建参数,执行必要的数据校验与默认值填充后,将新记录安全地写入底层存储或系统注册表中。 接收构建参数,执行必要的数据校验与默认值填充后,将新记录安全地写入底层存储或系统注册表中。
Args: user_name (str): 赋予该实体的人类可读名称或标题字符串,主要用于前端 UI 展示、日志记录或模糊检索。 hashed_password (str): 控制逻辑流向的具体字符串参数,指定了期望的 hashed password 内容。 Args: user_name (str): 赋予该实体的人类可读名称或标题字符串,主要用于前端 UI 展示、日志记录或模糊检索。 hashed_password (str): 控制逻辑流向的具体字符串参数,指定了期望的 hashed password 内容。
Returns: (User): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ Returns: (User): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。"""
from ulid import ULID from ulid import ULID
async with self.async_session_maker() as session: async with self.async_session_maker() as session:
# Check if any users exist # Check if any users exist
statement = select(User).limit(1) statement = select(User).limit(1)
@@ -46,7 +49,7 @@ class AuthDatabase:
user_id=str(ULID()), user_id=str(ULID()),
user_name=user_name, user_name=user_name,
hashed_password=hashed_password, hashed_password=hashed_password,
user_authority=authority user_authority=authority,
) )
session.add(user) session.add(user)
await session.commit() await session.commit()
@@ -58,7 +61,7 @@ class AuthDatabase:
"""执行与 change password 相关的核心业务流转操作。 """执行与 change password 相关的核心业务流转操作。
该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。 该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。
Args: user_name: 赋予该实体的人类可读名称或标题字符串,主要用于前端 UI 展示、日志记录或模糊检索。 old_password: 参与 change password 逻辑运算或数据构建的上下文依赖对象。 new_password: 参与 change password 逻辑运算或数据构建的上下文依赖对象。 Args: user_name: 赋予该实体的人类可读名称或标题字符串,主要用于前端 UI 展示、日志记录或模糊检索。 old_password: 参与 change password 逻辑运算或数据构建的上下文依赖对象。 new_password: 参与 change password 逻辑运算或数据构建的上下文依赖对象。
Returns: (User): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ Returns: (User): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。"""
async with self.async_session_maker() as session: async with self.async_session_maker() as session:
statement = select(User).where(User.user_name == user_name) statement = select(User).where(User.user_name == user_name)
results = await session.execute(statement) results = await session.execute(statement)
@@ -78,7 +81,7 @@ class AuthDatabase:
"""安全地移除或注销 user。 """安全地移除或注销 user。
执行物理删除或逻辑删除操作,并妥善清理相关的关联数据及占用资源。 执行物理删除或逻辑删除操作,并妥善清理相关的关联数据及占用资源。
Args: user_name (str): 赋予该实体的人类可读名称或标题字符串,主要用于前端 UI 展示、日志记录或模糊检索。 Args: user_name (str): 赋予该实体的人类可读名称或标题字符串,主要用于前端 UI 展示、日志记录或模糊检索。
Returns: (None): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ Returns: (None): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。"""
async with self.async_session_maker() as session: async with self.async_session_maker() as session:
statement = select(User).where(User.user_name == user_name) statement = select(User).where(User.user_name == user_name)
results = await session.execute(statement) results = await session.execute(statement)
@@ -93,7 +96,7 @@ class AuthDatabase:
"""安全地移除或注销 user by id。 """安全地移除或注销 user by id。
执行物理删除或逻辑删除操作,并妥善清理相关的关联数据及占用资源。 执行物理删除或逻辑删除操作,并妥善清理相关的关联数据及占用资源。
Args: user_id (str): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 user 实例。 Args: user_id (str): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 user 实例。
Returns: (None): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ Returns: (None): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。"""
async with self.async_session_maker() as session: async with self.async_session_maker() as session:
user = await session.get(User, user_id) user = await session.get(User, user_id)
if user is None: if user is None:
@@ -106,7 +109,7 @@ class AuthDatabase:
"""执行与 login user 相关的核心业务流转操作。 """执行与 login user 相关的核心业务流转操作。
该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。 该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。
Args: user_name (str): 赋予该实体的人类可读名称或标题字符串,主要用于前端 UI 展示、日志记录或模糊检索。 Args: user_name (str): 赋予该实体的人类可读名称或标题字符串,主要用于前端 UI 展示、日志记录或模糊检索。
Returns: (str): 处理流程所输出的具体字符串产物,可能是新生成的 ID 序列、格式化好的文本片段或 LLM 推理的回答内容。 """ Returns: (str): 处理流程所输出的具体字符串产物,可能是新生成的 ID 序列、格式化好的文本片段或 LLM 推理的回答内容。"""
async with self.async_session_maker() as session: async with self.async_session_maker() as session:
statement = select(User).where(User.user_name == user_name) statement = select(User).where(User.user_name == user_name)
results = await session.execute(statement) results = await session.execute(statement)
@@ -119,7 +122,7 @@ class AuthDatabase:
async def get_all_users(self) -> list[User]: async def get_all_users(self) -> list[User]:
"""检索并获取特定的 all users 数据集合或实例对象。 """检索并获取特定的 all users 数据集合或实例对象。
根据提供的查询条件或上下文凭证,从数据库、缓存或第三方服务中读取对应的资源状态。 根据提供的查询条件或上下文凭证,从数据库、缓存或第三方服务中读取对应的资源状态。
Returns: (list[User]): 经过筛选、排序或分页处理后的实体对象列表集合。 """ Returns: (list[User]): 经过筛选、排序或分页处理后的实体对象列表集合。"""
async with self.async_session_maker() as session: async with self.async_session_maker() as session:
statement = select(User) statement = select(User)
results = await session.execute(statement) results = await session.execute(statement)
@@ -131,7 +134,7 @@ class AuthDatabase:
"""检索并获取特定的 user authority 数据集合或实例对象。 """检索并获取特定的 user authority 数据集合或实例对象。
根据提供的查询条件或上下文凭证,从数据库、缓存或第三方服务中读取对应的资源状态。 根据提供的查询条件或上下文凭证,从数据库、缓存或第三方服务中读取对应的资源状态。
Args: user_id (str): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 user 实例。 Args: user_id (str): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 user 实例。
Returns: (UserAuthority): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ Returns: (UserAuthority): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。"""
async with self.async_session_maker() as session: async with self.async_session_maker() as session:
user = await session.get(User, user_id) user = await session.get(User, user_id)
if user is None: if user is None:
@@ -139,7 +142,9 @@ class AuthDatabase:
return user.user_authority return user.user_authority
@database_exception @database_exception
async def change_user_authority(self, user_id: str, new_authority: UserAuthority) -> User: async def change_user_authority(
self, user_id: str, new_authority: UserAuthority
) -> User:
""" """
Changes the authority level of a specific user. Changes the authority level of a specific user.
@@ -161,4 +166,4 @@ class AuthDatabase:
session.add(user) session.add(user)
await session.commit() await session.commit()
await session.refresh(user) await session.refresh(user)
return user return user
+1
View File
@@ -15,4 +15,5 @@
from pretor.core.database.table.user import User from pretor.core.database.table.user import User
from pretor.core.database.table.provider import Provider from pretor.core.database.table.provider import Provider
from pretor.core.database.table.individual import WorkerIndividual from pretor.core.database.table.individual import WorkerIndividual
__all__ = ["User", "Provider", "WorkerIndividual"] __all__ = ["User", "Provider", "WorkerIndividual"]
+4 -1
View File
@@ -1,5 +1,8 @@
from sqlmodel import SQLModel, Field from sqlmodel import SQLModel, Field
class EventRecord(SQLModel, table=True): class EventRecord(SQLModel, table=True):
trace_id: str = Field(primary_key=True, description="The unique trace ID of the PretorEvent") trace_id: str = Field(
primary_key=True, description="The unique trace ID of the PretorEvent"
)
event_data_json: str = Field(description="The JSON serialized PretorEvent data") event_data_json: str = Field(description="The JSON serialized PretorEvent data")
+10 -4
View File
@@ -17,16 +17,20 @@ from typing import List, Optional
from sqlalchemy import Column, JSON from sqlalchemy import Column, JSON
from enum import Enum from enum import Enum
class AgentType(str, Enum): class AgentType(str, Enum):
"""AgentType 核心组件类。 """AgentType 核心组件类。
这是一个领域数据模型或功能封装类,承载了 AgentType 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。 """ 这是一个领域数据模型或功能封装类,承载了 AgentType 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。"""
SKILL_INDIVIDUAL = "skill_individual" SKILL_INDIVIDUAL = "skill_individual"
ORDINARY_INDIVIDUAL = "ordinary_individual" ORDINARY_INDIVIDUAL = "ordinary_individual"
SPECIAL_INDIVIDUAL = "special_individual" SPECIAL_INDIVIDUAL = "special_individual"
class WorkerIndividual(SQLModel, table=True): class WorkerIndividual(SQLModel, table=True):
"""WorkerIndividual 核心组件类。 """WorkerIndividual 核心组件类。
这是一个具体的 Worker 智能体实体类,代表着具备特定人设、领域技能或长文本处理能力的数字员工。它可以被控制器动态拉起,并在安全沙箱内执行复杂的工作流指令与多步骤推理任务。 """ 这是一个具体的 Worker 智能体实体类,代表着具备特定人设、领域技能或长文本处理能力的数字员工。它可以被控制器动态拉起,并在安全沙箱内执行复杂的工作流指令与多步骤推理任务。"""
__tablename__ = "worker_individual" __tablename__ = "worker_individual"
agent_id: str = Field(primary_key=True) agent_id: str = Field(primary_key=True)
agent_name: str = Field(index=True) agent_name: str = Field(index=True)
@@ -35,8 +39,10 @@ class WorkerIndividual(SQLModel, table=True):
provider_title: str provider_title: str
model_id: str model_id: str
system_prompt: Optional[str] system_prompt: Optional[str]
output_template: Optional[dict] = Field(sa_column=Column(JSON),description="输出模板标识") output_template: Optional[dict] = Field(
sa_column=Column(JSON), description="输出模板标识"
)
bound_skill: Optional[str] = Field(sa_column=Column(JSON)) bound_skill: Optional[str] = Field(sa_column=Column(JSON))
workspace: Optional[List[str]] = Field(sa_column=Column(JSON)) workspace: Optional[List[str]] = Field(sa_column=Column(JSON))
tools: Optional[List[str]] = Field(sa_column=Column(JSON), default=None) tools: Optional[List[str]] = Field(sa_column=Column(JSON), default=None)
owner_id: str owner_id: str
+4 -2
View File
@@ -17,9 +17,11 @@ from typing import List
from sqlalchemy import Column, JSON from sqlalchemy import Column, JSON
from typing import Optional from typing import Optional
class Provider(SQLModel, table=True): class Provider(SQLModel, table=True):
"""Provider 核心组件类。 """Provider 核心组件类。
这是一个模型/服务提供商适配器类,屏蔽了外部不同供应商(如 OpenAI、Anthropic 等)的底层 API 差异。它负责标准化参数组装、网络请求发送、鉴权处理以及响应结构的反序列化。 """ 这是一个模型/服务提供商适配器类,屏蔽了外部不同供应商(如 OpenAI、Anthropic 等)的底层 API 差异。它负责标准化参数组装、网络请求发送、鉴权处理以及响应结构的反序列化。"""
__tablename__ = "provider" __tablename__ = "provider"
provider_id: str = Field(primary_key=True) provider_id: str = Field(primary_key=True)
provider_title: str = Field(index=True) provider_title: str = Field(index=True)
@@ -31,4 +33,4 @@ class Provider(SQLModel, table=True):
provider_models: List[str] = Field(sa_column=Column(JSON)) provider_models: List[str] = Field(sa_column=Column(JSON))
provider_owner: str provider_owner: str
is_active: bool = Field(default=True, description="该服务商节点是否在线/启用") is_active: bool = Field(default=True, description="该服务商节点是否在线/启用")
+3 -1
View File
@@ -17,9 +17,11 @@ from sqlmodel import SQLModel, Field
from typing import List, Optional from typing import List, Optional
from sqlalchemy import Column, JSON from sqlalchemy import Column, JSON
class SystemNodeConfig(SQLModel, table=True): class SystemNodeConfig(SQLModel, table=True):
"""SystemNodeConfig 核心组件类。 """SystemNodeConfig 核心组件类。
这是一个系统执行节点类,作为多智能体架构中的独立处理单元。它能够接收工作流上下文,根据内置的大模型策略进行意图理解和自主决策,从而驱动特定阶段的任务闭环。 """ 这是一个系统执行节点类,作为多智能体架构中的独立处理单元。它能够接收工作流上下文,根据内置的大模型策略进行意图理解和自主决策,从而驱动特定阶段的任务闭环。"""
__tablename__ = "system_node_config" __tablename__ = "system_node_config"
node_name: str = Field(primary_key=True) node_name: str = Field(primary_key=True)
provider_title: str provider_title: str
+7 -4
View File
@@ -15,21 +15,24 @@
from sqlmodel import SQLModel, Field from sqlmodel import SQLModel, Field
from enum import IntEnum from enum import IntEnum
class UserAuthority(IntEnum): class UserAuthority(IntEnum):
"""UserAuthority 核心组件类。 """UserAuthority 核心组件类。
这是一个领域数据模型或功能封装类,承载了 UserAuthority 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。 """ 这是一个领域数据模型或功能封装类,承载了 UserAuthority 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。"""
SUPER_ADMINISTRATOR = 100 SUPER_ADMINISTRATOR = 100
ADMINISTRATOR = 50 ADMINISTRATOR = 50
USER = 20 USER = 20
UNAUTHORIZED_USER = 10 UNAUTHORIZED_USER = 10
GUEST = 0 GUEST = 0
class User(SQLModel, table=True): class User(SQLModel, table=True):
"""User 核心组件类。 """User 核心组件类。
这是一个领域数据模型或功能封装类,承载了 User 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。 """ 这是一个领域数据模型或功能封装类,承载了 User 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。"""
__tablename__ = 'user'
__tablename__ = "user"
user_id: str = Field(primary_key=True) user_id: str = Field(primary_key=True)
user_name: str = Field(index=True) user_name: str = Field(index=True)
hashed_password: str hashed_password: str
user_authority: UserAuthority = Field(default=UserAuthority.USER) user_authority: UserAuthority = Field(default=UserAuthority.USER)
+2 -13
View File
@@ -1,14 +1,3 @@
# Copyright 2026 zhaoxi826 from pretor.core.global_state_machine.global_state_machine import GlobalStateMachine
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__all__ = ["GlobalStateMachine"]
@@ -15,8 +15,7 @@
import ray import ray
from pretor.core.global_state_machine.provider_manager import ProviderManager from pretor.core.global_state_machine.provider_manager import ProviderManager
from pretor.core.global_state_machine.tool_manager import GlobalToolManager from pretor.core.global_state_machine.tool_manager import GlobalToolManager
from pretor.core.database.postgres import PostgresDatabase from pretor.core.postgres_database import PostgresDatabase
from pretor.core.workflow.workflow_template_manager import WorkflowManager
from pretor.core.global_state_machine.skill_manager import GlobalSkillManager from pretor.core.global_state_machine.skill_manager import GlobalSkillManager
from pretor.core.global_state_machine.individual_manager import GlobalIndividualManager from pretor.core.global_state_machine.individual_manager import GlobalIndividualManager
@@ -24,17 +23,17 @@ from pretor.core.global_state_machine.individual_manager import GlobalIndividual
@ray.remote @ray.remote
class GlobalStateMachine: class GlobalStateMachine:
"""GlobalStateMachine 核心组件类。 """GlobalStateMachine 核心组件类。
这是一个领域数据模型或功能封装类,承载了 GlobalStateMachine 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。 """ 这是一个领域数据模型或功能封装类,承载了 GlobalStateMachine 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。"""
def __init__(self, postgres_database: PostgresDatabase): def __init__(self, postgres_database: PostgresDatabase):
import sys import sys
print("GSM __init__ START", file=sys.stderr, flush=True) print("GSM __init__ START", file=sys.stderr, flush=True)
print(" event_dict done", file=sys.stderr, flush=True) print(" event_dict done", file=sys.stderr, flush=True)
self._global_provider_manager = ProviderManager(postgres_database) self._global_provider_manager = ProviderManager(postgres_database)
print(" provider_manager done", file=sys.stderr, flush=True) print(" provider_manager done", file=sys.stderr, flush=True)
self._global_tool_manager = GlobalToolManager() self._global_tool_manager = GlobalToolManager()
print(" tool_manager done", file=sys.stderr, flush=True) print(" tool_manager done", file=sys.stderr, flush=True)
self._global_workflow_template_manager = WorkflowManager()
print(" workflow_template_manager done", file=sys.stderr, flush=True)
self._global_skill_manager = GlobalSkillManager() self._global_skill_manager = GlobalSkillManager()
print(" skill_manager done", file=sys.stderr, flush=True) print(" skill_manager done", file=sys.stderr, flush=True)
self._global_individual_manager = GlobalIndividualManager() self._global_individual_manager = GlobalIndividualManager()
@@ -44,50 +43,63 @@ class GlobalStateMachine:
async def init_state_machine(self): async def init_state_machine(self):
"""完成 state machine 模块的启动与依赖初始化。 """完成 state machine 模块的启动与依赖初始化。
在系统引导或服务拉起阶段被调用,负责建立网络连接、分配基础内存资源及注册核心服务组件。 """ 在系统引导或服务拉起阶段被调用,负责建立网络连接、分配基础内存资源及注册核心服务组件。"""
await self._global_provider_manager.init_provider_register(self.postgres_database) await self._global_provider_manager.init_provider_register(
await self._global_individual_manager.init_individual_register(self.postgres_database) self.postgres_database
)
await self._global_individual_manager.init_individual_register(
self.postgres_database
)
async def add_provider_wrap(self, provider_type, provider_title, provider_url, provider_apikey, provider_owner): async def add_provider_wrap(
self,
provider_type,
provider_title,
provider_url,
provider_apikey,
provider_owner,
):
"""创建并持久化新的 provider wrap 实体。 """创建并持久化新的 provider wrap 实体。
接收构建参数,执行必要的数据校验与默认值填充后,将新记录安全地写入底层存储或系统注册表中。 接收构建参数,执行必要的数据校验与默认值填充后,将新记录安全地写入底层存储或系统注册表中。
Args: provider_type: 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_type 实例。 provider_title: 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_title 实例。 provider_url: 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_url 实例。 provider_apikey: 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_apikey 实例。 provider_owner: 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_owner 实例。 Args: provider_type: 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_type 实例。 provider_title: 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_title 实例。 provider_url: 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_url 实例。 provider_apikey: 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_apikey 实例。 provider_owner: 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_owner 实例。
Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。"""
return await self._global_provider_manager.add_provider( return await self._global_provider_manager.add_provider(
provider_type=provider_type, provider_type=provider_type,
provider_title=provider_title, provider_title=provider_title,
provider_url=provider_url, provider_url=provider_url,
provider_apikey=provider_apikey, provider_apikey=provider_apikey,
provider_owner=provider_owner, provider_owner=provider_owner,
postgres_database=self.postgres_database postgres_database=self.postgres_database,
) )
# Provider Manager Methods # Provider Manager Methods
def get_provider_list(self): def get_provider_list(self):
"""检索并获取特定的 provider list 数据集合或实例对象。 """检索并获取特定的 provider list 数据集合或实例对象。
根据提供的查询条件或上下文凭证,从数据库、缓存或第三方服务中读取对应的资源状态。 根据提供的查询条件或上下文凭证,从数据库、缓存或第三方服务中读取对应的资源状态。
Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。"""
return self._global_provider_manager.get_provider_list() return self._global_provider_manager.get_provider_list()
def get_provider(self, provider_title): def get_provider(self, provider_title):
"""检索并获取特定的 provider 数据集合或实例对象。 """检索并获取特定的 provider 数据集合或实例对象。
根据提供的查询条件或上下文凭证,从数据库、缓存或第三方服务中读取对应的资源状态。 根据提供的查询条件或上下文凭证,从数据库、缓存或第三方服务中读取对应的资源状态。
Args: provider_title: 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_title 实例。 Args: provider_title: 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_title 实例。
Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。"""
return self._global_provider_manager.get_provider(provider_title) return self._global_provider_manager.get_provider(provider_title)
async def delete_provider(self, provider_title: str): async def delete_provider(self, provider_title: str):
"""安全地移除或注销 provider。 """安全地移除或注销 provider。
执行物理删除或逻辑删除操作,并妥善清理相关的关联数据及占用资源。 执行物理删除或逻辑删除操作,并妥善清理相关的关联数据及占用资源。
Args: provider_title (str): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_title 实例。 Args: provider_title (str): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_title 实例。
Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。"""
return await self._global_provider_manager.delete_provider(provider_title, self.postgres_database) return await self._global_provider_manager.delete_provider(
provider_title, self.postgres_database
)
# Tool Manager Methods # Tool Manager Methods
def get_tool_mapper(self): def get_tool_mapper(self):
"""检索并获取特定的 tool mapper 数据集合或实例对象。 """检索并获取特定的 tool mapper 数据集合或实例对象。
根据提供的查询条件或上下文凭证,从数据库、缓存或第三方服务中读取对应的资源状态。 根据提供的查询条件或上下文凭证,从数据库、缓存或第三方服务中读取对应的资源状态。
Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。"""
return self._global_tool_manager.tool_mapper return self._global_tool_manager.tool_mapper
def get_tool_list(self, agent_name: str): def get_tool_list(self, agent_name: str):
@@ -96,60 +108,32 @@ class GlobalStateMachine:
"""检索并获取特定的 tool list 数据集合或实例对象。 """检索并获取特定的 tool list 数据集合或实例对象。
根据提供的查询条件或上下文凭证,从数据库、缓存或第三方服务中读取对应的资源状态。 根据提供的查询条件或上下文凭证,从数据库、缓存或第三方服务中读取对应的资源状态。
Args: agent_name (str): 赋予该实体的人类可读名称或标题字符串,主要用于前端 UI 展示、日志记录或模糊检索。 Args: agent_name (str): 赋予该实体的人类可读名称或标题字符串,主要用于前端 UI 展示、日志记录或模糊检索。
Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。"""
tools = self._global_tool_manager.tool_mapper.get(agent_name, {}) tools = self._global_tool_manager.tool_mapper.get(agent_name, {})
# also include default tools # also include default tools
default_tools = self._global_tool_manager.tool_mapper.get("default", {}) default_tools = self._global_tool_manager.tool_mapper.get("default", {})
merged_tools = {**default_tools, **tools} merged_tools = {**default_tools, **tools}
return merged_tools return merged_tools
# Workflow Template Manager Methods
def get_all_workflow_templates(self):
"""检索并获取特定的 all workflow templates 数据集合或实例对象。
根据提供的查询条件或上下文凭证,从数据库、缓存或第三方服务中读取对应的资源状态。
Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """
return self._global_workflow_template_manager.get_all_workflow_templates()
def add_workflow_template(self, template_name: str, workflow_template):
"""创建并持久化新的 workflow template 实体。
接收构建参数,执行必要的数据校验与默认值填充后,将新记录安全地写入底层存储或系统注册表中。
Args: template_name (str): 赋予该实体的人类可读名称或标题字符串,主要用于前端 UI 展示、日志记录或模糊检索。 workflow_template: 参与 add workflow template 逻辑运算或数据构建的上下文依赖对象。
Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """
return self._global_workflow_template_manager.add_workflow_template(template_name, workflow_template)
def delete_workflow_template(self, template_name: str):
"""安全地移除或注销 workflow template。
执行物理删除或逻辑删除操作,并妥善清理相关的关联数据及占用资源。
Args: template_name (str): 赋予该实体的人类可读名称或标题字符串,主要用于前端 UI 展示、日志记录或模糊检索。
Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """
return self._global_workflow_template_manager.delete_workflow_template(template_name)
def generate_workflow_template(self, workflow_template):
"""执行与 generate workflow template 相关的核心业务流转操作。
该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。
Args: workflow_template: 参与 generate workflow template 逻辑运算或数据构建的上下文依赖对象。
Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """
return self._global_workflow_template_manager.generate_workflow_template(workflow_template)
# Skill Manager Methods # Skill Manager Methods
def add_skill(self, skill_name: str): def add_skill(self, skill_name: str):
"""创建并持久化新的 skill 实体。 """创建并持久化新的 skill 实体。
接收构建参数,执行必要的数据校验与默认值填充后,将新记录安全地写入底层存储或系统注册表中。 接收构建参数,执行必要的数据校验与默认值填充后,将新记录安全地写入底层存储或系统注册表中。
Args: skill_name (str): 赋予该实体的人类可读名称或标题字符串,主要用于前端 UI 展示、日志记录或模糊检索。 Args: skill_name (str): 赋予该实体的人类可读名称或标题字符串,主要用于前端 UI 展示、日志记录或模糊检索。
Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。"""
return self._global_skill_manager.add_skill(skill_name) return self._global_skill_manager.add_skill(skill_name)
def get_skill_list(self): def get_skill_list(self):
"""检索并获取特定的 skill list 数据集合或实例对象。 """检索并获取特定的 skill list 数据集合或实例对象。
根据提供的查询条件或上下文凭证,从数据库、缓存或第三方服务中读取对应的资源状态。 根据提供的查询条件或上下文凭证,从数据库、缓存或第三方服务中读取对应的资源状态。
Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。"""
return self._global_skill_manager.get_skill_list() return self._global_skill_manager.get_skill_list()
def remove_skill(self, skill_name: str): def remove_skill(self, skill_name: str):
"""安全地移除或注销 skill。 """安全地移除或注销 skill。
执行物理删除或逻辑删除操作,并妥善清理相关的关联数据及占用资源。 执行物理删除或逻辑删除操作,并妥善清理相关的关联数据及占用资源。
Args: skill_name (str): 赋予该实体的人类可读名称或标题字符串,主要用于前端 UI 展示、日志记录或模糊检索。 Args: skill_name (str): 赋予该实体的人类可读名称或标题字符串,主要用于前端 UI 展示、日志记录或模糊检索。
Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。"""
return self._global_skill_manager.remove_skill(skill_name) return self._global_skill_manager.remove_skill(skill_name)
# Individual Manager Methods # Individual Manager Methods
@@ -157,26 +141,25 @@ class GlobalStateMachine:
"""创建并持久化新的 individual 实体。 """创建并持久化新的 individual 实体。
接收构建参数,执行必要的数据校验与默认值填充后,将新记录安全地写入底层存储或系统注册表中。 接收构建参数,执行必要的数据校验与默认值填充后,将新记录安全地写入底层存储或系统注册表中。
Args: agent_id (str): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 agent 实例。 config: 驱动该模块运行的核心配置字典或 Pydantic 数据模型,定义了重试策略、超时时间及模型参数等选项。 Args: agent_id (str): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 agent 实例。 config: 驱动该模块运行的核心配置字典或 Pydantic 数据模型,定义了重试策略、超时时间及模型参数等选项。
Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。"""
return self._global_individual_manager.add_individual(agent_id, config) return self._global_individual_manager.add_individual(agent_id, config)
def get_individual(self, agent_id: str): def get_individual(self, agent_id: str):
"""检索并获取特定的 individual 数据集合或实例对象。 """检索并获取特定的 individual 数据集合或实例对象。
根据提供的查询条件或上下文凭证,从数据库、缓存或第三方服务中读取对应的资源状态。 根据提供的查询条件或上下文凭证,从数据库、缓存或第三方服务中读取对应的资源状态。
Args: agent_id (str): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 agent 实例。 Args: agent_id (str): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 agent 实例。
Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。"""
return self._global_individual_manager.get_individual(agent_id) return self._global_individual_manager.get_individual(agent_id)
def remove_individual(self, agent_id: str): def remove_individual(self, agent_id: str):
"""安全地移除或注销 individual。 """安全地移除或注销 individual。
执行物理删除或逻辑删除操作,并妥善清理相关的关联数据及占用资源。 执行物理删除或逻辑删除操作,并妥善清理相关的关联数据及占用资源。
Args: agent_id (str): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 agent 实例。 Args: agent_id (str): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 agent 实例。
Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。"""
return self._global_individual_manager.remove_individual(agent_id) return self._global_individual_manager.remove_individual(agent_id)
def list_individuals(self): def list_individuals(self):
"""执行与 list individuals 相关的核心业务流转操作。 """执行与 list individuals 相关的核心业务流转操作。
该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。 该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。
Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。"""
return self._global_individual_manager.list_individuals() return self._global_individual_manager.list_individuals()
@@ -14,11 +14,14 @@
from typing import Dict, Any from typing import Dict, Any
from pretor.utils.logger import get_logger from pretor.utils.logger import get_logger
logger = get_logger('individual_manager')
logger = get_logger("individual_manager")
class GlobalIndividualManager: class GlobalIndividualManager:
"""GlobalIndividualManager 核心组件类。 """GlobalIndividualManager 核心组件类。
这是一个管理器类,职责集中在维护整个系统内有关 GlobalIndividual 资源的全局生命周期。它提供了注册机制、状态同步以及跨组件的统一查询入口,确保系统中该类型资源的实例一致性与可控性。 """ 这是一个管理器类,职责集中在维护整个系统内有关 GlobalIndividual 资源的全局生命周期。它提供了注册机制、状态同步以及跨组件的统一查询入口,确保系统中该类型资源的实例一致性与可控性。"""
def __init__(self): def __init__(self):
self._individuals: Dict[str, Dict[str, Any]] = {} self._individuals: Dict[str, Dict[str, Any]] = {}
@@ -26,21 +29,31 @@ class GlobalIndividualManager:
"""完成 individual register 模块的启动与依赖初始化。 """完成 individual register 模块的启动与依赖初始化。
在系统引导或服务拉起阶段被调用,负责建立网络连接、分配基础内存资源及注册核心服务组件。 在系统引导或服务拉起阶段被调用,负责建立网络连接、分配基础内存资源及注册核心服务组件。
Args: postgres: 参与 init individual register 逻辑运算或数据构建的上下文依赖对象。 Args: postgres: 参与 init individual register 逻辑运算或数据构建的上下文依赖对象。
Returns: (None): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ Returns: (None): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。"""
try: try:
try: try:
individuals = await postgres.get_all_worker_individual.remote() individuals = await postgres.get_all_worker_individual.remote()
for ind in individuals: for ind in individuals:
agent_id = getattr(ind, 'agent_id', None) agent_id = getattr(ind, "agent_id", None)
if agent_id: if agent_id:
self._individuals[agent_id] = ind.model_dump() if hasattr(ind, 'model_dump') else dict(ind) self._individuals[agent_id] = (
logger.info(f"成功从数据库拉取了 {len(self._individuals)} 个 Worker Individual 配置。") ind.model_dump()
if hasattr(ind, "model_dump")
else dict(ind)
)
logger.info(
f"成功从数据库拉取了 {len(self._individuals)} 个 Worker Individual 配置。"
)
except AttributeError: except AttributeError:
logger.warning("数据库中 get_all_worker_individual 方法未实现,跳过全量加载。可以在将来完善该接口。") logger.warning(
"数据库中 get_all_worker_individual 方法未实现,跳过全量加载。可以在将来完善该接口。"
)
except Exception as e: except Exception as e:
# 捕获因 Ray 调用目标方法不存在引发的异常 # 捕获因 Ray 调用目标方法不存在引发的异常
if "has no attribute 'get_all_worker_individual'" in str(e): if "has no attribute 'get_all_worker_individual'" in str(e):
logger.warning("数据库 individual_database 中缺少 get_all_worker_individual 方法,无法全量拉取。") logger.warning(
"数据库 individual_database 中缺少 get_all_worker_individual 方法,无法全量拉取。"
)
else: else:
raise e raise e
except Exception as e: except Exception as e:
@@ -64,12 +77,12 @@ class GlobalIndividualManager:
"""安全地移除或注销 individual。 """安全地移除或注销 individual。
执行物理删除或逻辑删除操作,并妥善清理相关的关联数据及占用资源。 执行物理删除或逻辑删除操作,并妥善清理相关的关联数据及占用资源。
Args: agent_id (str): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 agent 实例。 Args: agent_id (str): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 agent 实例。
Returns: (None): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ Returns: (None): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。"""
if agent_id in self._individuals: if agent_id in self._individuals:
del self._individuals[agent_id] del self._individuals[agent_id]
def list_individuals(self) -> Dict[str, Dict[str, Any]]: def list_individuals(self) -> Dict[str, Dict[str, Any]]:
"""执行与 list individuals 相关的核心业务流转操作。 """执行与 list individuals 相关的核心业务流转操作。
该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。 该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。
Returns: (Dict[str, Dict[str, Any]]): 高度聚合的字典结构数据,将多维度的属性特征或统计指标组合后一并返回。 """ Returns: (Dict[str, Dict[str, Any]]): 高度聚合的字典结构数据,将多维度的属性特征或统计指标组合后一并返回。"""
return self._individuals return self._individuals
@@ -12,8 +12,24 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from pretor.core.global_state_machine.model_provider.base_provider import Provider, ProviderArgs from pretor.core.global_state_machine.model_provider.base_provider import (
from pretor.core.global_state_machine.model_provider.openai_provider import OpenAIProvider Provider,
from pretor.core.global_state_machine.model_provider.claude_provider import ClaudeProvider ProviderArgs,
from pretor.core.global_state_machine.model_provider.deepseek_provider import DeepseekProvider )
__all__ = ["Provider", "ProviderArgs", "OpenAIProvider", "ClaudeProvider", "DeepseekProvider"] from pretor.core.global_state_machine.model_provider.openai_provider import (
OpenAIProvider,
)
from pretor.core.global_state_machine.model_provider.claude_provider import (
ClaudeProvider,
)
from pretor.core.global_state_machine.model_provider.deepseek_provider import (
DeepseekProvider,
)
__all__ = [
"Provider",
"ProviderArgs",
"OpenAIProvider",
"ClaudeProvider",
"DeepseekProvider",
]
@@ -17,15 +17,19 @@ from pydantic import BaseModel
from typing import List from typing import List
from enum import Enum from enum import Enum
class ProviderStatus(str, Enum): class ProviderStatus(str, Enum):
"""ProviderStatus 核心组件类。 """ProviderStatus 核心组件类。
这是一个模型/服务提供商适配器类,屏蔽了外部不同供应商(如 OpenAI、Anthropic 等)的底层 API 差异。它负责标准化参数组装、网络请求发送、鉴权处理以及响应结构的反序列化。 """ 这是一个模型/服务提供商适配器类,屏蔽了外部不同供应商(如 OpenAI、Anthropic 等)的底层 API 差异。它负责标准化参数组装、网络请求发送、鉴权处理以及响应结构的反序列化。"""
UP = "up" UP = "up"
DOWN = "down" DOWN = "down"
class Provider(BaseModel): class Provider(BaseModel):
"""Provider 核心组件类。 """Provider 核心组件类。
这是一个模型/服务提供商适配器类,屏蔽了外部不同供应商(如 OpenAI、Anthropic 等)的底层 API 差异。它负责标准化参数组装、网络请求发送、鉴权处理以及响应结构的反序列化。 """ 这是一个模型/服务提供商适配器类,屏蔽了外部不同供应商(如 OpenAI、Anthropic 等)的底层 API 差异。它负责标准化参数组装、网络请求发送、鉴权处理以及响应结构的反序列化。"""
provider_title: str provider_title: str
provider_url: str provider_url: str
provider_apikey: str provider_apikey: str
@@ -34,17 +38,21 @@ class Provider(BaseModel):
provider_owner: str | None = None provider_owner: str | None = None
provider_status: ProviderStatus = ProviderStatus.UP provider_status: ProviderStatus = ProviderStatus.UP
class ProviderArgs(BaseModel): class ProviderArgs(BaseModel):
"""ProviderArgs 核心组件类。 """ProviderArgs 核心组件类。
这是一个模型/服务提供商适配器类,屏蔽了外部不同供应商(如 OpenAI、Anthropic 等)的底层 API 差异。它负责标准化参数组装、网络请求发送、鉴权处理以及响应结构的反序列化。 """ 这是一个模型/服务提供商适配器类,屏蔽了外部不同供应商(如 OpenAI、Anthropic 等)的底层 API 差异。它负责标准化参数组装、网络请求发送、鉴权处理以及响应结构的反序列化。"""
provider_title: str provider_title: str
provider_url: str provider_url: str
provider_apikey: str provider_apikey: str
provider_owner: str provider_owner: str
class BaseProvider(ABC): class BaseProvider(ABC):
"""BaseProvider 核心组件类。 """BaseProvider 核心组件类。
这是一个模型/服务提供商适配器类,屏蔽了外部不同供应商(如 OpenAI、Anthropic 等)的底层 API 差异。它负责标准化参数组装、网络请求发送、鉴权处理以及响应结构的反序列化。 """ 这是一个模型/服务提供商适配器类,屏蔽了外部不同供应商(如 OpenAI、Anthropic 等)的底层 API 差异。它负责标准化参数组装、网络请求发送、鉴权处理以及响应结构的反序列化。"""
@staticmethod @staticmethod
@abstractmethod @abstractmethod
async def create_provider(provider_args: ProviderArgs) -> Provider: async def create_provider(provider_args: ProviderArgs) -> Provider:
@@ -83,7 +91,9 @@ class BaseProvider(ABC):
@staticmethod @staticmethod
@abstractmethod @abstractmethod
def _return_provider(provider_args: ProviderArgs, provider_models: List[str]) -> Provider: def _return_provider(
provider_args: ProviderArgs, provider_models: List[str]
) -> Provider:
""" """
包装Provider对象并返回 包装Provider对象并返回
将provider_args和_load_models获取的方法包装为provider对象 将provider_args和_load_models获取的方法包装为provider对象
@@ -100,5 +110,3 @@ class BaseProvider(ABC):
返回一个Provider对象 返回一个Provider对象
""" """
pass pass
@@ -14,21 +14,29 @@
from pretor.utils.retry import retry_on_retryable_error from pretor.utils.retry import retry_on_retryable_error
from pretor.core.global_state_machine.model_provider.base_provider import BaseProvider, Provider, ProviderArgs from pretor.core.global_state_machine.model_provider.base_provider import (
BaseProvider,
Provider,
ProviderArgs,
)
import httpx import httpx
from typing import List from typing import List
class ClaudeProvider(BaseProvider): class ClaudeProvider(BaseProvider):
"""ClaudeProvider 核心组件类。 """ClaudeProvider 核心组件类。
这是一个模型/服务提供商适配器类,屏蔽了外部不同供应商(如 OpenAI、Anthropic 等)的底层 API 差异。它负责标准化参数组装、网络请求发送、鉴权处理以及响应结构的反序列化。 """ 这是一个模型/服务提供商适配器类,屏蔽了外部不同供应商(如 OpenAI、Anthropic 等)的底层 API 差异。它负责标准化参数组装、网络请求发送、鉴权处理以及响应结构的反序列化。"""
@staticmethod @staticmethod
async def create_provider(provider_args: ProviderArgs) -> Provider: async def create_provider(provider_args: ProviderArgs) -> Provider:
"""创建并持久化新的 provider 实体。 """创建并持久化新的 provider 实体。
接收构建参数,执行必要的数据校验与默认值填充后,将新记录安全地写入底层存储或系统注册表中。 接收构建参数,执行必要的数据校验与默认值填充后,将新记录安全地写入底层存储或系统注册表中。
Args: provider_args (ProviderArgs): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_args 实例。 Args: provider_args (ProviderArgs): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_args 实例。
Returns: (Provider): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ Returns: (Provider): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。"""
provider_models: List[str] = await ClaudeProvider._load_models(provider_args) provider_models: List[str] = await ClaudeProvider._load_models(provider_args)
provider: Provider = ClaudeProvider._return_provider(provider_args, provider_models) provider: Provider = ClaudeProvider._return_provider(
provider_args, provider_models
)
return provider return provider
@staticmethod @staticmethod
@@ -38,11 +46,11 @@ class ClaudeProvider(BaseProvider):
"""执行与 load models 相关的核心业务流转操作。 """执行与 load models 相关的核心业务流转操作。
该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。 该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。
Args: provider_args (ProviderArgs): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_args 实例。 Args: provider_args (ProviderArgs): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_args 实例。
Returns: (List[str]): 经过筛选、排序或分页处理后的实体对象列表集合。 """ Returns: (List[str]): 经过筛选、排序或分页处理后的实体对象列表集合。"""
headers = { headers = {
"x-api-key": provider_args.provider_apikey, "x-api-key": provider_args.provider_apikey,
"anthropic-version": "2023-06-01", "anthropic-version": "2023-06-01",
"Content-Type": "application/json" "Content-Type": "application/json",
} }
# 如果是官方 API,通常使用 /v1/models (如果支持) # 如果是官方 API,通常使用 /v1/models (如果支持)
# 注意:很多时候 Anthropic 并不返回完整列表,如果请求失败,建议返回硬编码的常用模型 # 注意:很多时候 Anthropic 并不返回完整列表,如果请求失败,建议返回硬编码的常用模型
@@ -57,19 +65,27 @@ class ClaudeProvider(BaseProvider):
return sorted(model_ids) return sorted(model_ids)
else: else:
# 如果官方列表接口不可用,fallback 到已知常用模型 # 如果官方列表接口不可用,fallback 到已知常用模型
return ["claude-3-5-sonnet-20240620", "claude-3-opus-20240229", "claude-3-haiku-20240307"] return [
"claude-3-5-sonnet-20240620",
"claude-3-opus-20240229",
"claude-3-haiku-20240307",
]
except Exception as e: except Exception as e:
print(f"[{provider_args.provider_title}] 获取 Claude 模型列表错误: {e}") print(f"[{provider_args.provider_title}] 获取 Claude 模型列表错误: {e}")
return [] return []
@staticmethod @staticmethod
def _return_provider(provider_args: ProviderArgs, provider_models: List[str]) -> Provider: def _return_provider(
provider_args: ProviderArgs, provider_models: List[str]
) -> Provider:
"""执行与 return provider 相关的核心业务流转操作。 """执行与 return provider 相关的核心业务流转操作。
该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。 该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。
Args: provider_args (ProviderArgs): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_args 实例。 provider_models (List[str]): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_models 实例。 Args: provider_args (ProviderArgs): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_args 实例。 provider_models (List[str]): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_models 实例。
Returns: (Provider): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ Returns: (Provider): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。"""
return Provider(provider_title=provider_args.provider_title, return Provider(
provider_apikey=provider_args.provider_apikey, provider_title=provider_args.provider_title,
provider_url=provider_args.provider_url, provider_apikey=provider_args.provider_apikey,
provider_models=provider_models, provider_url=provider_args.provider_url,
provider_type="claude") provider_models=provider_models,
provider_type="claude",
)
@@ -13,21 +13,29 @@
# limitations under the License. # limitations under the License.
from pretor.utils.retry import retry_on_retryable_error from pretor.utils.retry import retry_on_retryable_error
from pretor.core.global_state_machine.model_provider.base_provider import BaseProvider, Provider, ProviderArgs from pretor.core.global_state_machine.model_provider.base_provider import (
BaseProvider,
Provider,
ProviderArgs,
)
import httpx import httpx
from typing import List from typing import List
class DeepseekProvider(BaseProvider): class DeepseekProvider(BaseProvider):
"""DeepseekProvider 核心组件类。 """DeepseekProvider 核心组件类。
这是一个模型/服务提供商适配器类,屏蔽了外部不同供应商(如 OpenAI、Anthropic 等)的底层 API 差异。它负责标准化参数组装、网络请求发送、鉴权处理以及响应结构的反序列化。 """ 这是一个模型/服务提供商适配器类,屏蔽了外部不同供应商(如 OpenAI、Anthropic 等)的底层 API 差异。它负责标准化参数组装、网络请求发送、鉴权处理以及响应结构的反序列化。"""
@staticmethod @staticmethod
async def create_provider(provider_args: ProviderArgs) -> Provider: async def create_provider(provider_args: ProviderArgs) -> Provider:
"""创建并持久化新的 provider 实体。 """创建并持久化新的 provider 实体。
接收构建参数,执行必要的数据校验与默认值填充后,将新记录安全地写入底层存储或系统注册表中。 接收构建参数,执行必要的数据校验与默认值填充后,将新记录安全地写入底层存储或系统注册表中。
Args: provider_args (ProviderArgs): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_args 实例。 Args: provider_args (ProviderArgs): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_args 实例。
Returns: (Provider): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ Returns: (Provider): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。"""
provider_models: List[str] = await DeepseekProvider._load_models(provider_args) provider_models: List[str] = await DeepseekProvider._load_models(provider_args)
provider: Provider = DeepseekProvider._return_provider(provider_args, provider_models) provider: Provider = DeepseekProvider._return_provider(
provider_args, provider_models
)
return provider return provider
@staticmethod @staticmethod
@@ -36,17 +44,23 @@ class DeepseekProvider(BaseProvider):
"""执行与 load models 相关的核心业务流转操作。 """执行与 load models 相关的核心业务流转操作。
该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。 该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。
Args: provider_args (ProviderArgs): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_args 实例。 Args: provider_args (ProviderArgs): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_args 实例。
Returns: (List[str]): 经过筛选、排序或分页处理后的实体对象列表集合。 """ Returns: (List[str]): 经过筛选、排序或分页处理后的实体对象列表集合。"""
headers = { headers = {
"Authorization": f"Bearer {provider_args.provider_apikey}", "Authorization": f"Bearer {provider_args.provider_apikey}",
"Content-Type": "application/json" "Content-Type": "application/json",
} }
url = f"{provider_args.provider_url}/models" if "/v1" in provider_args.provider_url else f"{provider_args.provider_url}/v1/models" url = (
f"{provider_args.provider_url}/models"
if "/v1" in provider_args.provider_url
else f"{provider_args.provider_url}/v1/models"
)
try: try:
async with httpx.AsyncClient(timeout=10.0) as client: async with httpx.AsyncClient(timeout=10.0) as client:
response = await client.get(url, headers=headers) response = await client.get(url, headers=headers)
if response.status_code != 200: if response.status_code != 200:
print(f"[{provider_args.provider_title}] 获取模型失败: {response.status_code}") print(
f"[{provider_args.provider_title}] 获取模型失败: {response.status_code}"
)
return [] return []
data = response.json() data = response.json()
raw_models = data.get("data", []) raw_models = data.get("data", [])
@@ -54,20 +68,27 @@ class DeepseekProvider(BaseProvider):
return sorted(model_ids) return sorted(model_ids)
except httpx.RequestError as e: except httpx.RequestError as e:
from pretor.utils.error import RetryableError from pretor.utils.error import RetryableError
print(f"[{provider_args.provider_title}] 网络请求异常: {e}") print(f"[{provider_args.provider_title}] 网络请求异常: {e}")
raise RetryableError(f"[{provider_args.provider_title}] 网络请求异常: {e}") from e raise RetryableError(
f"[{provider_args.provider_title}] 网络请求异常: {e}"
) from e
except Exception as e: except Exception as e:
print(f"[{provider_args.provider_title}] 解析模型列表时发生错误: {e}") print(f"[{provider_args.provider_title}] 解析模型列表时发生错误: {e}")
return [] return []
@staticmethod @staticmethod
def _return_provider(provider_args: ProviderArgs, provider_models: List[str]) -> Provider: def _return_provider(
provider_args: ProviderArgs, provider_models: List[str]
) -> Provider:
"""执行与 return provider 相关的核心业务流转操作。 """执行与 return provider 相关的核心业务流转操作。
该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。 该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。
Args: provider_args (ProviderArgs): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_args 实例。 provider_models (List[str]): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_models 实例。 Args: provider_args (ProviderArgs): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_args 实例。 provider_models (List[str]): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_models 实例。
Returns: (Provider): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ Returns: (Provider): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。"""
return Provider(provider_title=provider_args.provider_title, return Provider(
provider_apikey=provider_args.provider_apikey, provider_title=provider_args.provider_title,
provider_url=provider_args.provider_url, provider_apikey=provider_args.provider_apikey,
provider_models=provider_models, provider_url=provider_args.provider_url,
provider_type="deepseek") provider_models=provider_models,
provider_type="deepseek",
)
@@ -13,21 +13,29 @@
# limitations under the License. # limitations under the License.
from pretor.utils.retry import retry_on_retryable_error from pretor.utils.retry import retry_on_retryable_error
from pretor.core.global_state_machine.model_provider.base_provider import BaseProvider, Provider, ProviderArgs from pretor.core.global_state_machine.model_provider.base_provider import (
BaseProvider,
Provider,
ProviderArgs,
)
import httpx import httpx
from typing import List from typing import List
class OpenAIProvider(BaseProvider): class OpenAIProvider(BaseProvider):
"""OpenAIProvider 核心组件类。 """OpenAIProvider 核心组件类。
这是一个模型/服务提供商适配器类,屏蔽了外部不同供应商(如 OpenAI、Anthropic 等)的底层 API 差异。它负责标准化参数组装、网络请求发送、鉴权处理以及响应结构的反序列化。 """ 这是一个模型/服务提供商适配器类,屏蔽了外部不同供应商(如 OpenAI、Anthropic 等)的底层 API 差异。它负责标准化参数组装、网络请求发送、鉴权处理以及响应结构的反序列化。"""
@staticmethod @staticmethod
async def create_provider(provider_args: ProviderArgs) -> Provider: async def create_provider(provider_args: ProviderArgs) -> Provider:
"""创建并持久化新的 provider 实体。 """创建并持久化新的 provider 实体。
接收构建参数,执行必要的数据校验与默认值填充后,将新记录安全地写入底层存储或系统注册表中。 接收构建参数,执行必要的数据校验与默认值填充后,将新记录安全地写入底层存储或系统注册表中。
Args: provider_args (ProviderArgs): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_args 实例。 Args: provider_args (ProviderArgs): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_args 实例。
Returns: (Provider): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ Returns: (Provider): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。"""
provider_models: List[str] = await OpenAIProvider._load_models(provider_args) provider_models: List[str] = await OpenAIProvider._load_models(provider_args)
provider: Provider = OpenAIProvider._return_provider(provider_args, provider_models) provider: Provider = OpenAIProvider._return_provider(
provider_args, provider_models
)
return provider return provider
@staticmethod @staticmethod
@@ -36,17 +44,23 @@ class OpenAIProvider(BaseProvider):
"""执行与 load models 相关的核心业务流转操作。 """执行与 load models 相关的核心业务流转操作。
该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。 该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。
Args: provider_args (ProviderArgs): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_args 实例。 Args: provider_args (ProviderArgs): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_args 实例。
Returns: (List[str]): 经过筛选、排序或分页处理后的实体对象列表集合。 """ Returns: (List[str]): 经过筛选、排序或分页处理后的实体对象列表集合。"""
headers = { headers = {
"Authorization": f"Bearer {provider_args.provider_apikey}", "Authorization": f"Bearer {provider_args.provider_apikey}",
"Content-Type": "application/json" "Content-Type": "application/json",
} }
url = f"{provider_args.provider_url}/models" if "/v1" in provider_args.provider_url else f"{provider_args.provider_url}/v1/models" url = (
f"{provider_args.provider_url}/models"
if "/v1" in provider_args.provider_url
else f"{provider_args.provider_url}/v1/models"
)
try: try:
async with httpx.AsyncClient(timeout=10.0) as client: async with httpx.AsyncClient(timeout=10.0) as client:
response = await client.get(url, headers=headers) response = await client.get(url, headers=headers)
if response.status_code != 200: if response.status_code != 200:
print(f"[{provider_args.provider_title}] 获取模型失败: {response.status_code}") print(
f"[{provider_args.provider_title}] 获取模型失败: {response.status_code}"
)
return [] return []
data = response.json() data = response.json()
raw_models = data.get("data", []) raw_models = data.get("data", [])
@@ -54,20 +68,27 @@ class OpenAIProvider(BaseProvider):
return sorted(model_ids) return sorted(model_ids)
except httpx.RequestError as e: except httpx.RequestError as e:
from pretor.utils.error import RetryableError from pretor.utils.error import RetryableError
print(f"[{provider_args.provider_title}] 网络请求异常: {e}") print(f"[{provider_args.provider_title}] 网络请求异常: {e}")
raise RetryableError(f"[{provider_args.provider_title}] 网络请求异常: {e}") from e raise RetryableError(
f"[{provider_args.provider_title}] 网络请求异常: {e}"
) from e
except Exception as e: except Exception as e:
print(f"[{provider_args.provider_title}] 解析模型列表时发生错误: {e}") print(f"[{provider_args.provider_title}] 解析模型列表时发生错误: {e}")
return [] return []
@staticmethod @staticmethod
def _return_provider(provider_args: ProviderArgs, provider_models: List[str]) -> Provider: def _return_provider(
provider_args: ProviderArgs, provider_models: List[str]
) -> Provider:
"""执行与 return provider 相关的核心业务流转操作。 """执行与 return provider 相关的核心业务流转操作。
该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。 该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。
Args: provider_args (ProviderArgs): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_args 实例。 provider_models (List[str]): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_models 实例。 Args: provider_args (ProviderArgs): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_args 实例。 provider_models (List[str]): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_models 实例。
Returns: (Provider): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ Returns: (Provider): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。"""
return Provider(provider_title=provider_args.provider_title, return Provider(
provider_apikey=provider_args.provider_apikey, provider_title=provider_args.provider_title,
provider_url=provider_args.provider_url, provider_apikey=provider_args.provider_apikey,
provider_models=provider_models, provider_url=provider_args.provider_url,
provider_type="openai") provider_models=provider_models,
provider_type="openai",
)
@@ -12,51 +12,73 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from pretor.core.global_state_machine.model_provider import Provider, OpenAIProvider, ClaudeProvider, DeepseekProvider from pretor.core.global_state_machine.model_provider import (
Provider,
OpenAIProvider,
ClaudeProvider,
DeepseekProvider,
)
from typing import Dict, Type from typing import Dict, Type
class ProviderManager: class ProviderManager:
""" """
模型供应商管理器 (ProviderManager)。 模型供应商管理器 (ProviderManager)。
负责维护不同的 LLM 协议适配器,提供从配置注册到 Agent 实例化的全生命周期管理。 负责维护不同的 LLM 协议适配器,提供从配置注册到 Agent 实例化的全生命周期管理。
""" """
# --- 类属性显式标注 (IDE 友好) --- # --- 类属性显式标注 (IDE 友好) ---
provider_mapper: Dict[str, Type[Provider]] provider_mapper: Dict[str, Type[Provider]]
"""协议映射表:键为协议名(如 'openai'),值为对应的 Provider 类。""" """协议映射表:键为协议名(如 'openai'),值为对应的 Provider 类。"""
provider_register: Dict[str, Provider] provider_register: Dict[str, Provider]
"""供应商注册表:键为用户自定义别名,值为已实例化的 Provider 对象。""" """供应商注册表:键为用户自定义别名,值为已实例化的 Provider 对象。"""
def __init__(self, postgres): def __init__(self, postgres):
self.provider_mapper = {"openai": OpenAIProvider, self.provider_mapper = {
"claude": ClaudeProvider, "openai": OpenAIProvider,
"deepseek": DeepseekProvider} "claude": ClaudeProvider,
"deepseek": DeepseekProvider,
}
self.provider_register = {} self.provider_register = {}
async def init_provider_register(self, postgres) -> None: async def init_provider_register(self, postgres) -> None:
"""完成 provider register 模块的启动与依赖初始化。 """完成 provider register 模块的启动与依赖初始化。
在系统引导或服务拉起阶段被调用,负责建立网络连接、分配基础内存资源及注册核心服务组件。 在系统引导或服务拉起阶段被调用,负责建立网络连接、分配基础内存资源及注册核心服务组件。
Args: postgres: 参与 init provider register 逻辑运算或数据构建的上下文依赖对象。 Args: postgres: 参与 init provider register 逻辑运算或数据构建的上下文依赖对象。
Returns: (None): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ Returns: (None): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。"""
providers = await postgres.get_provider.remote() providers = await postgres.get_provider.remote()
for provider in providers: for provider in providers:
self.provider_register[provider.provider_title] = provider self.provider_register[provider.provider_title] = provider
async def add_provider(self, provider_type, provider_title, provider_url, provider_apikey, provider_owner, postgres_database) -> None: async def add_provider(
self,
provider_type,
provider_title,
provider_url,
provider_apikey,
provider_owner,
postgres_database,
) -> None:
"""创建并持久化新的 provider 实体。 """创建并持久化新的 provider 实体。
接收构建参数,执行必要的数据校验与默认值填充后,将新记录安全地写入底层存储或系统注册表中。 接收构建参数,执行必要的数据校验与默认值填充后,将新记录安全地写入底层存储或系统注册表中。
Args: provider_type: 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_type 实例。 provider_title: 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_title 实例。 provider_url: 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_url 实例。 provider_apikey: 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_apikey 实例。 provider_owner: 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_owner 实例。 postgres_database: 从客户端传递过来或由上游组件生成的核心业务数据体,通常需要进一步的清洗和结构化解析。 Args: provider_type: 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_type 实例。 provider_title: 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_title 实例。 provider_url: 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_url 实例。 provider_apikey: 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_apikey 实例。 provider_owner: 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_owner 实例。 postgres_database: 从客户端传递过来或由上游组件生成的核心业务数据体,通常需要进一步的清洗和结构化解析。
Returns: (None): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ Returns: (None): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。"""
from pretor.core.global_state_machine.model_provider import ProviderArgs from pretor.core.global_state_machine.model_provider import ProviderArgs
from pretor.utils.logger import get_logger from pretor.utils.logger import get_logger
logger = get_logger('provider_manager')
logger = get_logger("provider_manager")
import httpx import httpx
provider_args: ProviderArgs = ProviderArgs(provider_title=provider_title, provider_args: ProviderArgs = ProviderArgs(
provider_url=provider_url, provider_title=provider_title,
provider_apikey=provider_apikey, provider_url=provider_url,
provider_owner=provider_owner) provider_apikey=provider_apikey,
provider_owner=provider_owner,
)
try: try:
import ulid import ulid
provider_class = self.provider_mapper.get(provider_type, None) provider_class = self.provider_mapper.get(provider_type, None)
if provider_class is None: if provider_class is None:
logger.warning(f"Provider type {provider_type} is not supported.") logger.warning(f"Provider type {provider_type} is not supported.")
@@ -65,41 +87,49 @@ class ProviderManager:
provider.provider_owner = provider_owner provider.provider_owner = provider_owner
self.provider_register[provider_title] = provider self.provider_register[provider_title] = provider
await postgres_database.add_provider_db.remote( await postgres_database.add_provider_db.remote(
provider_id=str(ulid.ULID()), provider_id=str(ulid.ULID()),
provider_title=provider.provider_title, provider_title=provider.provider_title,
provider_url=provider.provider_url, provider_url=provider.provider_url,
provider_apikey=provider.provider_apikey, provider_apikey=provider.provider_apikey,
provider_models=provider.provider_models, provider_models=provider.provider_models,
provider_type=provider.provider_type, provider_type=provider.provider_type,
provider_owner=provider.provider_owner) provider_owner=provider.provider_owner,
)
logger.info(f"已添加适配器{provider_title}") logger.info(f"已添加适配器{provider_title}")
except httpx.RequestError as e: except httpx.RequestError as e:
from pretor.utils.error import RetryableError from pretor.utils.error import RetryableError
logger.warning(f"[{provider_args.provider_title}] 网络请求异常: {e}") logger.warning(f"[{provider_args.provider_title}] 网络请求异常: {e}")
raise RetryableError(f"[{provider_args.provider_title}] 网络请求异常: {e}") from e raise RetryableError(
f"[{provider_args.provider_title}] 网络请求异常: {e}"
) from e
except Exception as e: except Exception as e:
logger.warning(f"[{provider_args.provider_title}] 解析模型列表时发生错误: {e}") logger.warning(
f"[{provider_args.provider_title}] 解析模型列表时发生错误: {e}"
)
def get_provider_list(self): def get_provider_list(self):
"""检索并获取特定的 provider list 数据集合或实例对象。 """检索并获取特定的 provider list 数据集合或实例对象。
根据提供的查询条件或上下文凭证,从数据库、缓存或第三方服务中读取对应的资源状态。 根据提供的查询条件或上下文凭证,从数据库、缓存或第三方服务中读取对应的资源状态。
Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。"""
return self.provider_register return self.provider_register
def get_provider(self, provider_title): def get_provider(self, provider_title):
"""检索并获取特定的 provider 数据集合或实例对象。 """检索并获取特定的 provider 数据集合或实例对象。
根据提供的查询条件或上下文凭证,从数据库、缓存或第三方服务中读取对应的资源状态。 根据提供的查询条件或上下文凭证,从数据库、缓存或第三方服务中读取对应的资源状态。
Args: provider_title: 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_title 实例。 Args: provider_title: 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_title 实例。
Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。"""
return self.provider_register.get(provider_title) return self.provider_register.get(provider_title)
async def delete_provider(self, provider_title: str, postgres_database) -> None: async def delete_provider(self, provider_title: str, postgres_database) -> None:
"""安全地移除或注销 provider。 """安全地移除或注销 provider。
执行物理删除或逻辑删除操作,并妥善清理相关的关联数据及占用资源。 执行物理删除或逻辑删除操作,并妥善清理相关的关联数据及占用资源。
Args: provider_title (str): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_title 实例。 postgres_database: 从客户端传递过来或由上游组件生成的核心业务数据体,通常需要进一步的清洗和结构化解析。 Args: provider_title (str): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_title 实例。 postgres_database: 从客户端传递过来或由上游组件生成的核心业务数据体,通常需要进一步的清洗和结构化解析。
Returns: (None): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ Returns: (None): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。"""
if provider_title in self.provider_register: if provider_title in self.provider_register:
provider = self.provider_register[provider_title] provider = self.provider_register[provider_title]
await postgres_database.delete_provider_db.remote( provider_id=provider.provider_id) await postgres_database.delete_provider_db.remote(
del self.provider_register[provider_title] provider_id=provider.provider_id
)
del self.provider_register[provider_title]
@@ -17,22 +17,29 @@ from collections import defaultdict
import pathlib import pathlib
import json import json
class GlobalSkillManager: class GlobalSkillManager:
"""GlobalSkillManager 核心组件类。 """GlobalSkillManager 核心组件类。
这是一个管理器类,职责集中在维护整个系统内有关 GlobalSkill 资源的全局生命周期。它提供了注册机制、状态同步以及跨组件的统一查询入口,确保系统中该类型资源的实例一致性与可控性。 """ 这是一个管理器类,职责集中在维护整个系统内有关 GlobalSkill 资源的全局生命周期。它提供了注册机制、状态同步以及跨组件的统一查询入口,确保系统中该类型资源的实例一致性与可控性。"""
skill_mapper = Dict[str,Tuple[str]]
skill_mapper = Dict[str, Tuple[str]]
"""skill的存储表""" """skill的存储表"""
def __init__(self): def __init__(self):
self.skill_mapper = defaultdict(tuple) self.skill_mapper = defaultdict(tuple)
import os import os
skill_plugin_dir = pathlib.Path(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "plugin", "skill")))
skill_plugin_dir = pathlib.Path(
os.path.abspath(
os.path.join(os.path.dirname(__file__), "..", "..", "plugin", "skill")
)
)
if not skill_plugin_dir.exists() or not skill_plugin_dir.is_dir(): if not skill_plugin_dir.exists() or not skill_plugin_dir.is_dir():
return return
for item in skill_plugin_dir.iterdir(): for item in skill_plugin_dir.iterdir():
if item.is_dir() and not item.name.startswith((".", "__")): if item.is_dir() and not item.name.startswith((".", "__")):
json_path = item / "skill.json" # 拼接文件路径 json_path = item / "skill.json" # 拼接文件路径
if json_path.exists(): if json_path.exists():
try: try:
with open(json_path, "r", encoding="utf-8") as f: with open(json_path, "r", encoding="utf-8") as f:
@@ -42,7 +49,7 @@ class GlobalSkillManager:
if name: if name:
self.skill_mapper[name] = ( self.skill_mapper[name] = (
skill.get("description", ""), skill.get("description", ""),
skill.get("instructions", "") skill.get("instructions", ""),
) )
except (json.JSONDecodeError, OSError) as e: except (json.JSONDecodeError, OSError) as e:
print(f"警告: 加载插件 {item.name} 失败: {e}") print(f"警告: 加载插件 {item.name} 失败: {e}")
@@ -50,7 +57,12 @@ class GlobalSkillManager:
def add_skill(self, skill_name: str) -> None: def add_skill(self, skill_name: str) -> None:
"""Add a skill to the manager by reading its skill.json from the path""" """Add a skill to the manager by reading its skill.json from the path"""
import os import os
skill_plugin_dir = pathlib.Path(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "plugin", "skill")))
skill_plugin_dir = pathlib.Path(
os.path.abspath(
os.path.join(os.path.dirname(__file__), "..", "..", "plugin", "skill")
)
)
item = skill_plugin_dir / skill_name item = skill_plugin_dir / skill_name
if item.is_dir() and not item.name.startswith((".", "__")): if item.is_dir() and not item.name.startswith((".", "__")):
json_path = item / "skill.json" json_path = item / "skill.json"
@@ -62,7 +74,7 @@ class GlobalSkillManager:
if name: if name:
self.skill_mapper[name] = ( self.skill_mapper[name] = (
skill.get("description", ""), skill.get("description", ""),
skill.get("instructions", "") skill.get("instructions", ""),
) )
except (json.JSONDecodeError, OSError) as e: except (json.JSONDecodeError, OSError) as e:
print(f"警告: 加载插件 {item.name} 失败: {e}") print(f"警告: 加载插件 {item.name} 失败: {e}")
@@ -19,17 +19,22 @@ from collections import defaultdict
from pretor.plugin.tool_plugin.base_tool import BaseToolData from pretor.plugin.tool_plugin.base_tool import BaseToolData
from typing import Dict, Type from typing import Dict, Type
from pretor.utils.logger import get_logger from pretor.utils.logger import get_logger
logger = get_logger('tool_manager')
logger = get_logger("tool_manager")
class GlobalToolManager: class GlobalToolManager:
"""GlobalToolManager 核心组件类。 """GlobalToolManager 核心组件类。
这是一个管理器类,职责集中在维护整个系统内有关 GlobalTool 资源的全局生命周期。它提供了注册机制、状态同步以及跨组件的统一查询入口,确保系统中该类型资源的实例一致性与可控性。 """ 这是一个管理器类,职责集中在维护整个系统内有关 GlobalTool 资源的全局生命周期。它提供了注册机制、状态同步以及跨组件的统一查询入口,确保系统中该类型资源的实例一致性与可控性。"""
tool_mapper: Dict[str, Dict[str, Type[BaseToolData]]] tool_mapper: Dict[str, Dict[str, Type[BaseToolData]]]
def __init__(self): def __init__(self):
self.tool_mapper = defaultdict(dict) self.tool_mapper = defaultdict(dict)
tool_plugin_dir = pathlib.Path(__file__).parent.parent.parent / "plugin" / "tool_plugin" tool_plugin_dir = (
pathlib.Path(__file__).parent.parent.parent / "plugin" / "tool_plugin"
)
if not tool_plugin_dir.exists() or not tool_plugin_dir.is_dir(): if not tool_plugin_dir.exists() or not tool_plugin_dir.is_dir():
return return
@@ -51,4 +56,4 @@ class GlobalToolManager:
for scope in action_scopes: for scope in action_scopes:
self.tool_mapper[scope][plugin_name] = obj self.tool_mapper[scope][plugin_name] = obj
except Exception as e: except Exception as e:
logger.warning(f"Failed to load tool plugin {plugin_name}: {e}") logger.warning(f"Failed to load tool plugin {plugin_name}: {e}")
@@ -0,0 +1,5 @@
from pretor.core.global_workflow_manager.global_workflow_manager import (
GlobalWorkflowManager,
)
__all__ = ["GlobalWorkflowManager"]
@@ -6,6 +6,7 @@ from pretor.core.workflow.workflow import PretorWorkflow
from pretor.utils.ray_hook import ray_actor_hook from pretor.utils.ray_hook import ray_actor_hook
from pretor.utils.logger import get_logger from pretor.utils.logger import get_logger
@ray.remote @ray.remote
class GlobalWorkflowManager: class GlobalWorkflowManager:
def __init__(self): def __init__(self):
@@ -31,7 +32,9 @@ class GlobalWorkflowManager:
event_copy = event.model_copy() event_copy = event.model_copy()
event_copy.pending_queue = None event_copy.pending_queue = None
event_copy.receive_queue = None event_copy.receive_queue = None
self.event_object_refs[event.trace_id] = ray.put(event_copy.model_dump_json()) self.event_object_refs[event.trace_id] = ray.put(
event_copy.model_dump_json()
)
except Exception as e: except Exception as e:
self.logger.error(f"Failed to load event {record.trace_id}: {e}") self.logger.error(f"Failed to load event {record.trace_id}: {e}")
@@ -40,13 +43,22 @@ class GlobalWorkflowManager:
# Trigger resumption of incomplete workflows # Trigger resumption of incomplete workflows
workflow_running_engine = None workflow_running_engine = None
for trace_id, event in self.event_dict.items(): for trace_id, event in self.event_dict.items():
if event.workflow and event.workflow.status.status in ["waiting_llm_working", "waiting_tool_working", "llm_working", "tool_working"]: if event.workflow and event.workflow.status.status in [
"waiting_llm_working",
"waiting_tool_working",
"llm_working",
"tool_working",
]:
self.logger.info(f"Resuming incomplete workflow {trace_id}") self.logger.info(f"Resuming incomplete workflow {trace_id}")
if not workflow_running_engine: if not workflow_running_engine:
try: try:
workflow_running_engine = ray_actor_hook("workflow_running_engine").workflow_running_engine workflow_running_engine = ray_actor_hook(
"workflow_running_engine"
).workflow_running_engine
except AttributeError: except AttributeError:
self.logger.warning("workflow_running_engine not found, cannot resume workflow") self.logger.warning(
"workflow_running_engine not found, cannot resume workflow"
)
break break
await workflow_running_engine.resume_workflow.remote(event) await workflow_running_engine.resume_workflow.remote(event)
@@ -64,12 +76,11 @@ class GlobalWorkflowManager:
# Update cache # Update cache
self.event_object_refs[event.trace_id] = ray.put(event_json) self.event_object_refs[event.trace_id] = ray.put(event_json)
await self.postgres_database.upsert_event.remote( await self.postgres_database.upsert_event.remote(event.trace_id, event_json)
event.trace_id,
event_json
)
except Exception as e: except Exception as e:
self.logger.error(f"Failed to upsert event {event.trace_id} to database: {e}") self.logger.error(
f"Failed to upsert event {event.trace_id} to database: {e}"
)
async def add_event(self, event: PretorEvent) -> None: async def add_event(self, event: PretorEvent) -> None:
event.pending_queue = asyncio.Queue() event.pending_queue = asyncio.Queue()
@@ -98,7 +109,9 @@ class GlobalWorkflowManager:
event_json = ray.get(self.event_object_refs[trace_id]) event_json = ray.get(self.event_object_refs[trace_id])
return PretorEvent.model_validate_json(event_json) return PretorEvent.model_validate_json(event_json)
except Exception as e: except Exception as e:
self.logger.warning(f"Failed to fetch event from cache for trace {trace_id}: {e}") self.logger.warning(
f"Failed to fetch event from cache for trace {trace_id}: {e}"
)
# Fallback to database # Fallback to database
try: try:
@@ -119,11 +132,15 @@ class GlobalWorkflowManager:
return event return event
except Exception as e: except Exception as e:
self.logger.error(f"Failed to fetch event {trace_id} from database fallback: {e}") self.logger.error(
f"Failed to fetch event {trace_id} from database fallback: {e}"
)
return None return None
async def update_attachment(self, trace_id: str, attachment: Dict[str, str]) -> None: async def update_attachment(
self, trace_id: str, attachment: Dict[str, str]
) -> None:
if trace_id in self.event_dict: if trace_id in self.event_dict:
self.event_dict[trace_id].attachment = attachment self.event_dict[trace_id].attachment = attachment
await self._upsert_event_to_db(self.event_dict[trace_id]) await self._upsert_event_to_db(self.event_dict[trace_id])
@@ -148,17 +165,25 @@ class GlobalWorkflowManager:
try: try:
event = PretorEvent.model_validate_json(record.event_data_json) event = PretorEvent.model_validate_json(record.event_data_json)
workflow_title = event.workflow.title if event.workflow else None workflow_title = event.workflow.title if event.workflow else None
workflow_status = event.workflow.status.status if event.workflow and event.workflow.status else None workflow_status = (
result.append({ event.workflow.status.status
"event_id": event.trace_id, if event.workflow and event.workflow.status
"workflow_title": workflow_title, else None
"status": workflow_status, )
"user_name": event.user_name, result.append(
"message": event.message, {
"create_time": event.create_time, "event_id": event.trace_id,
}) "workflow_title": workflow_title,
"status": workflow_status,
"user_name": event.user_name,
"message": event.message,
"create_time": event.create_time,
}
)
# Best-effort cache population # Best-effort cache population
self.event_object_refs[event.trace_id] = ray.put(record.event_data_json) self.event_object_refs[event.trace_id] = ray.put(
record.event_data_json
)
except Exception: except Exception:
continue continue
except Exception as e: except Exception as e:
@@ -173,7 +198,7 @@ class GlobalWorkflowManager:
async def get_pending(self, trace_id) -> str: async def get_pending(self, trace_id) -> str:
if trace_id in self.event_dict and self.event_dict[trace_id].pending_queue: if trace_id in self.event_dict and self.event_dict[trace_id].pending_queue:
return await self.event_dict[trace_id].pending_queue.get() return await self.event_dict[trace_id].pending_queue.get()
await asyncio.sleep(1) # Prevent CPU spinning if not found await asyncio.sleep(1) # Prevent CPU spinning if not found
return "" return ""
async def put_received(self, trace_id, item) -> None: async def put_received(self, trace_id, item) -> None:
@@ -183,5 +208,5 @@ class GlobalWorkflowManager:
async def get_received(self, trace_id) -> str: async def get_received(self, trace_id) -> str:
if trace_id in self.event_dict and self.event_dict[trace_id].receive_queue: if trace_id in self.event_dict and self.event_dict[trace_id].receive_queue:
return await self.event_dict[trace_id].receive_queue.get() return await self.event_dict[trace_id].receive_queue.get()
await asyncio.sleep(1) # Prevent CPU spinning if not found await asyncio.sleep(1) # Prevent CPU spinning if not found
return "" return ""
@@ -13,4 +13,5 @@
# limitations under the License. # limitations under the License.
from .consciousness_node import ConsciousnessNode from .consciousness_node import ConsciousnessNode
__all__ = ["ConsciousnessNode"] __all__ = ["ConsciousnessNode"]
@@ -15,8 +15,15 @@
import ray import ray
from typing import Union, overload from typing import Union, overload
from pretor.core.individual.consciousness_node.template import (ConsciousnessNodeDeps, ForSupervisoryNode, ForWorkflow,\ from pretor.core.individual.consciousness_node.template import (
ForWorkflowEngine, ForWorkflowInput, ForSupervisoryInput, ForWorkflowEngineInput) ConsciousnessNodeDeps,
ForSupervisoryNode,
ForWorkflow,
ForWorkflowEngine,
ForWorkflowInput,
ForSupervisoryInput,
ForWorkflowEngineInput,
)
from pydantic_ai import Agent, RunContext from pydantic_ai import Agent, RunContext
from pretor.core.global_state_machine.global_state_machine import GlobalStateMachine from pretor.core.global_state_machine.global_state_machine import GlobalStateMachine
from pretor.core.global_state_machine.model_provider.base_provider import Provider from pretor.core.global_state_machine.model_provider.base_provider import Provider
@@ -26,14 +33,21 @@ from pretor.adapter.model_adapter.agent_factory import AgentFactory
@ray.remote @ray.remote
class ConsciousnessNode: class ConsciousnessNode:
"""ConsciousnessNode 核心组件类。 """ConsciousnessNode 核心组件类。
这是一个系统执行节点类,作为多智能体架构中的独立处理单元。它能够接收工作流上下文,根据内置的大模型策略进行意图理解和自主决策,从而驱动特定阶段的任务闭环。 """ 这是一个系统执行节点类,作为多智能体架构中的独立处理单元。它能够接收工作流上下文,根据内置的大模型策略进行意图理解和自主决策,从而驱动特定阶段的任务闭环。"""
def __init__(self) -> None: def __init__(self) -> None:
from pretor.utils.logger import get_logger from pretor.utils.logger import get_logger
self.logger = get_logger('consciousness_node')
self.logger = get_logger("consciousness_node")
self.agent: None | Agent = None self.agent: None | Agent = None
async def create_agent(
async def create_agent(self, global_state_machine: GlobalStateMachine, provider_title: str, model_id: str, tools_list: list[str] = None) -> None: self,
global_state_machine: GlobalStateMachine,
provider_title: str,
model_id: str,
tools_list: list[str] = None,
) -> None:
""" """
create_agent方法,将agent对象装配到ConsciousnessNode的属性内 create_agent方法,将agent对象装配到ConsciousnessNode的属性内
该方法通过provider_title从global_state_machine中获取provider对象,然后从provider对象中取出供应商形象,装配为pydantic_ai的 该方法通过provider_title从global_state_machine中获取provider对象,然后从provider对象中取出供应商形象,装配为pydantic_ai的
@@ -58,32 +72,35 @@ class ConsciousnessNode:
) )
output_type = Union[ForSupervisoryNode, ForWorkflow, ForWorkflowEngine] output_type = Union[ForSupervisoryNode, ForWorkflow, ForWorkflowEngine]
from pretor.utils.get_tool import load_tools_from_list from pretor.utils.get_tool import load_tools_from_list
provider: Provider = await global_state_machine.get_provider.remote( provider_title)
provider: Provider = await global_state_machine.get_provider.remote(
provider_title
)
agent_factory = AgentFactory() agent_factory = AgentFactory()
callables = load_tools_from_list(tools_list) callables = load_tools_from_list(tools_list)
self.agent = agent_factory.create_agent(provider=provider, self.agent = agent_factory.create_agent(
model_id=model_id, provider=provider,
output_type=output_type, model_id=model_id,
system_prompt=system_prompt, output_type=output_type,
deps_type=ConsciousnessNodeDeps, system_prompt=system_prompt,
agent_name="consciousness_node", deps_type=ConsciousnessNodeDeps,
tools=callables) agent_name="consciousness_node",
tools=callables,
)
@self.agent.system_prompt @self.agent.system_prompt
async def dynamic_prompt(ctx: RunContext[ConsciousnessNodeDeps]): async def dynamic_prompt(ctx: RunContext[ConsciousnessNodeDeps]):
"""执行与 dynamic prompt 相关的核心业务流转操作。 """执行与 dynamic prompt 相关的核心业务流转操作。
该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。 该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。
Args: ctx (RunContext[ConsciousnessNodeDeps]): 参与 dynamic prompt 逻辑运算或数据构建的上下文依赖对象。 Args: ctx (RunContext[ConsciousnessNodeDeps]): 参与 dynamic prompt 逻辑运算或数据构建的上下文依赖对象。
Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。"""
prompt = system_prompt + "\n\n" prompt = system_prompt + "\n\n"
prompt += ( prompt += (
f"=== 当前任务上下文 ===\n" f"=== 当前任务上下文 ===\n"
f"- 当前指令 (Command): {ctx.deps.command}\n" f"- 当前指令 (Command): {ctx.deps.command}\n"
f"- 原始用户命令 (Original Command): {ctx.deps.original_command}\n" f"- 原始用户命令 (Original Command): {ctx.deps.original_command}\n"
) )
if ctx.deps.workflow_template:
prompt += f"- 选定工作流模板 (Workflow Template): {ctx.deps.workflow_template}\n"
if ctx.deps.available_skills: if ctx.deps.available_skills:
prompt += "\n=== 当前可用 Skill Individual ===\n" prompt += "\n=== 当前可用 Skill Individual ===\n"
prompt += "你可以直接将以下 Skill Individual 安排进工作流的步骤中(设置 node 为 skill_individual,并将 agent_id 设置为对应 Skill Individual 的真实 agent_id,不要用名称!),作为可调用的工具。\n" prompt += "你可以直接将以下 Skill Individual 安排进工作流的步骤中(设置 node 为 skill_individual,并将 agent_id 设置为对应 Skill Individual 的真实 agent_id,不要用名称!),作为可调用的工具。\n"
@@ -92,30 +109,34 @@ class ConsciousnessNode:
return prompt return prompt
async def working(self, payload: Union[ForWorkflowEngineInput, ForWorkflowInput, ForSupervisoryInput]) -> Union[ForWorkflowEngine, ForWorkflow, ForSupervisoryNode, None]: async def working(
self,
payload: Union[ForWorkflowEngineInput, ForWorkflowInput, ForSupervisoryInput],
) -> Union[ForWorkflowEngine, ForWorkflow, ForSupervisoryNode, None]:
"""执行与 working 相关的核心业务流转操作。 """执行与 working 相关的核心业务流转操作。
该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。 该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。
Args: payload (Union[ForWorkflowEngineInput, ForWorkflowInput, ForSupervisoryInput]): 从客户端传递过来或由上游组件生成的核心业务数据体,通常需要进一步的清洗和结构化解析。 Args: payload (Union[ForWorkflowEngineInput, ForWorkflowInput, ForSupervisoryInput]): 从客户端传递过来或由上游组件生成的核心业务数据体,通常需要进一步的清洗和结构化解析。
Returns: (Union[ForWorkflowEngine, ForWorkflow, ForSupervisoryNode, None]): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ Returns: (Union[ForWorkflowEngine, ForWorkflow, ForSupervisoryNode, None]): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。"""
try: try:
result = await self._run(payload) result = await self._run(payload)
if isinstance(result, (ForWorkflowEngine, ForWorkflow, ForSupervisoryNode)): if isinstance(result, (ForWorkflowEngine, ForWorkflow, ForSupervisoryNode)):
return result return result
else: else:
self.logger.error(f"ConsciousnessNode: 未知或不匹配的返回类型: {type(result)}") self.logger.error(
f"ConsciousnessNode: 未知或不匹配的返回类型: {type(result)}"
)
return None return None
except Exception: except Exception:
self.logger.exception("ConsciousnessNode在执行working时发生严重错误") self.logger.exception("ConsciousnessNode在执行working时发生严重错误")
return None return None
@overload @overload
async def _run(self, payload: ForWorkflowEngineInput) -> ForWorkflowEngine: async def _run(self, payload: ForWorkflowEngineInput) -> ForWorkflowEngine:
""" """
_run方法 _run方法
该分支应当在supervisory_node简单处理用户命令后,工作流创建前调用! 该分支应当在supervisory_node简单处理用户命令后,工作流创建前调用!
Args: Args:
payload: 应当包含workflow_template和event对象 payload: 应当包含原始命令和可用技能等信息
Returns: Returns:
ForWorkflowEngine对象,将被放到全局状态机后丢入WorkflowEngine的异步队列 ForWorkflowEngine对象,将被放到全局状态机后丢入WorkflowEngine的异步队列
@@ -148,45 +169,53 @@ class ConsciousnessNode:
""" """
pass pass
async def _run(self, payload: Union[ForSupervisoryInput, ForWorkflowInput, ForWorkflowEngineInput]) -> Union[ForSupervisoryNode, ForWorkflow, ForWorkflowEngine]: async def _run(
self,
payload: Union[ForSupervisoryInput, ForWorkflowInput, ForWorkflowEngineInput],
) -> Union[ForSupervisoryNode, ForWorkflow, ForWorkflowEngine]:
"""执行与 run 相关的核心业务流转操作。 """执行与 run 相关的核心业务流转操作。
该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。 该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。
Args: payload (Union[ForSupervisoryInput, ForWorkflowInput, ForWorkflowEngineInput]): 从客户端传递过来或由上游组件生成的核心业务数据体,通常需要进一步的清洗和结构化解析。 Args: payload (Union[ForSupervisoryInput, ForWorkflowInput, ForWorkflowEngineInput]): 从客户端传递过来或由上游组件生成的核心业务数据体,通常需要进一步的清洗和结构化解析。
Returns: (Union[ForSupervisoryNode, ForWorkflow, ForWorkflowEngine]): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ Returns: (Union[ForSupervisoryNode, ForWorkflow, ForWorkflowEngine]): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。"""
try: try:
self.agent.retries = 3 self.agent.retries = 3
if isinstance(payload, ForWorkflowEngineInput): if isinstance(payload, ForWorkflowEngineInput):
deps = ConsciousnessNodeDeps( deps = ConsciousnessNodeDeps(
original_command=payload.original_command, original_command=payload.original_command,
workflow_template=payload.workflow_template, command="自主分析并拆解原始命令,生成严密可执行的工作流",
command="拆解原始命令变成一个工作流", available_skills=payload.available_skills,
available_skills=payload.available_skills
) )
self.logger.debug("ConsciousnessNode: 开始生成工作流 (原生重试开启)") self.logger.debug("ConsciousnessNode: 开始生成工作流 (原生重试开启)")
prompt = "根据original_command制定严密的可执行workflow" prompt = "根据original_command制定严密的可执行workflow"
if payload.workflow_template:
prompt += ",可以学习并参考workflow_template的设计理念"
result = await self.agent.run(prompt, deps=deps) result = await self.agent.run(prompt, deps=deps)
return result.output return result.output
elif isinstance(payload, ForWorkflowInput): elif isinstance(payload, ForWorkflowInput):
deps = ConsciousnessNodeDeps( deps = ConsciousnessNodeDeps(
original_command=payload.original_command, original_command=payload.original_command,
command="完成workflow step中分配给意识节点的特定任务或指导" command="完成workflow step中分配给意识节点的特定任务或指导",
)
self.logger.debug(
"ConsciousnessNode: 开始处理工作流节点任务 (原生重试开启)"
)
result = await self.agent.run(
f"处理此工作流步骤信息:\n{payload.workflow_step.model_dump_json()}",
deps=deps,
) )
self.logger.debug("ConsciousnessNode: 开始处理工作流节点任务 (原生重试开启)")
result = await self.agent.run(f"处理此工作流步骤信息:\n{payload.workflow_step.model_dump_json()}",
deps=deps)
return result.output return result.output
elif isinstance(payload, ForSupervisoryInput): elif isinstance(payload, ForSupervisoryInput):
deps = ConsciousnessNodeDeps( deps = ConsciousnessNodeDeps(
original_command=payload.original_command, original_command=payload.original_command,
command="对于工作流整体执行结果进行检查,并且生成一份专业的技术性总结报告" command="对于工作流整体执行结果进行检查,并且生成一份专业的技术性总结报告",
)
self.logger.debug(
"ConsciousnessNode: 开始生成技术总结报告 (原生重试开启)"
)
result = await self.agent.run(
f"基于以下工作流的执行记录,生成技术报告:\n{payload.workflow.model_dump_json()}",
deps=deps,
) )
self.logger.debug("ConsciousnessNode: 开始生成技术总结报告 (原生重试开启)")
result = await self.agent.run(f"基于以下工作流的执行记录,生成技术报告:\n{payload.workflow.model_dump_json()}",
deps=deps)
return result.output return result.output
except Exception as e: except Exception as e:
self.logger.exception(f"ConsciousnessNode 模型生成最终失败: {str(e)}") self.logger.exception(f"ConsciousnessNode 模型生成最终失败: {str(e)}")
@@ -18,60 +18,71 @@ from pretor.utils.agent_model import ResponseModel, DepsModel, InputModel
from pydantic import Field from pydantic import Field
#意识节点回复类 # 意识节点回复类
class ConsciousnessNodeResponse(ResponseModel): class ConsciousnessNodeResponse(ResponseModel):
"""Consciousness response model,是意识节点所有回复类型的父类""" """Consciousness response model,是意识节点所有回复类型的父类"""
pass pass
class ForWorkflowEngine(ConsciousnessNodeResponse): class ForWorkflowEngine(ConsciousnessNodeResponse):
"""生成workflow并放入WorkflowEngine""" """生成workflow并放入WorkflowEngine"""
workflow: PretorWorkflow = Field(..., description="生成好的符合规范的完整工作流对象。")
workflow: PretorWorkflow = Field(
..., description="生成好的符合规范的完整工作流对象。"
)
reasoning: str = Field(..., description="生成此工作流的原因和思路简述。") reasoning: str = Field(..., description="生成此工作流的原因和思路简述。")
class ForWorkflow(ConsciousnessNodeResponse): class ForWorkflow(ConsciousnessNodeResponse):
"""处理workflow中需要ConsciousnessNode的工作""" """处理workflow中需要ConsciousnessNode的工作"""
output: str = Field(..., description="对当前工作流步骤的具体处理结果或指导意见。") output: str = Field(..., description="对当前工作流步骤的具体处理结果或指导意见。")
class ForSupervisoryNode(ConsciousnessNodeResponse): class ForSupervisoryNode(ConsciousnessNodeResponse):
"""工作流完成后进行校验并返回给SupervisoryNode""" """工作流完成后进行校验并返回给SupervisoryNode"""
output: str = Field(..., description="为监控节点提供的全工作流执行情况的技术性总结报告。")
output: str = Field(
..., description="为监控节点提供的全工作流执行情况的技术性总结报告。"
)
class ConsciousnessNodeDeps(DepsModel): class ConsciousnessNodeDeps(DepsModel):
"""ConsciousnessNodeDeps 核心组件类。 """ConsciousnessNodeDeps 核心组件类。
这是一个系统执行节点类,作为多智能体架构中的独立处理单元。它能够接收工作流上下文,根据内置的大模型策略进行意图理解和自主决策,从而驱动特定阶段的任务闭环。 """ 这是一个系统执行节点类,作为多智能体架构中的独立处理单元。它能够接收工作流上下文,根据内置的大模型策略进行意图理解和自主决策,从而驱动特定阶段的任务闭环。"""
original_command: str original_command: str
workflow_template: str | None = None
command: str command: str
available_skills: list[dict] | None = None available_skills: list[dict] | None = None
class ConsciousnessNodeInput(InputModel): class ConsciousnessNodeInput(InputModel):
"""ConsciousnessNodeInput 核心组件类。 """ConsciousnessNodeInput 核心组件类。
这是一个系统执行节点类,作为多智能体架构中的独立处理单元。它能够接收工作流上下文,根据内置的大模型策略进行意图理解和自主决策,从而驱动特定阶段的任务闭环。 """ 这是一个系统执行节点类,作为多智能体架构中的独立处理单元。它能够接收工作流上下文,根据内置的大模型策略进行意图理解和自主决策,从而驱动特定阶段的任务闭环。"""
pass pass
class ForWorkflowEngineInput(ConsciousnessNodeInput): class ForWorkflowEngineInput(ConsciousnessNodeInput):
"""ForWorkflowEngineInput 核心组件类。 """ForWorkflowEngineInput 核心组件类。
这是一个领域数据模型或功能封装类,承载了 ForWorkflowEngineInput 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。 """ 这是一个领域数据模型或功能封装类,承载了 ForWorkflowEngineInput 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。"""
workflow_template: str | None = None
original_command: str original_command: str
available_skills: list[dict] | None = None available_skills: list[dict] | None = None
class ForWorkflowInput(ConsciousnessNodeInput): class ForWorkflowInput(ConsciousnessNodeInput):
"""ForWorkflowInput 核心组件类。 """ForWorkflowInput 核心组件类。
这是一个领域数据模型或功能封装类,承载了 ForWorkflowInput 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。 """ 这是一个领域数据模型或功能封装类,承载了 ForWorkflowInput 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。"""
workflow_step: WorkStep workflow_step: WorkStep
original_command: str original_command: str
class ForSupervisoryInput(ConsciousnessNodeInput): class ForSupervisoryInput(ConsciousnessNodeInput):
"""ForSupervisoryInput 核心组件类。 """ForSupervisoryInput 核心组件类。
这是一个领域数据模型或功能封装类,承载了 ForSupervisoryInput 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。 """ 这是一个领域数据模型或功能封装类,承载了 ForSupervisoryInput 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。"""
workflow: PretorWorkflow workflow: PretorWorkflow
original_command: str original_command: str
@@ -13,4 +13,5 @@
# limitations under the License. # limitations under the License.
from .control_node import ControlNode from .control_node import ControlNode
__all__ = ["ControlNode"] __all__ = ["ControlNode"]
@@ -17,21 +17,31 @@ from pydantic_ai import Agent, RunContext
from pretor.core.global_state_machine.global_state_machine import GlobalStateMachine from pretor.core.global_state_machine.global_state_machine import GlobalStateMachine
from pretor.core.global_state_machine.model_provider.base_provider import Provider from pretor.core.global_state_machine.model_provider.base_provider import Provider
from pretor.adapter.model_adapter.agent_factory import AgentFactory from pretor.adapter.model_adapter.agent_factory import AgentFactory
from pretor.core.individual.control_node.template import ForWorkflow, ForWorkflowInput, ControlNodeDeps from pretor.core.individual.control_node.template import (
ForWorkflow,
ForWorkflowInput,
ControlNodeDeps,
)
@ray.remote @ray.remote
class ControlNode: class ControlNode:
"""ControlNode 核心组件类。 """ControlNode 核心组件类。
这是一个系统执行节点类,作为多智能体架构中的独立处理单元。它能够接收工作流上下文,根据内置的大模型策略进行意图理解和自主决策,从而驱动特定阶段的任务闭环。 """ 这是一个系统执行节点类,作为多智能体架构中的独立处理单元。它能够接收工作流上下文,根据内置的大模型策略进行意图理解和自主决策,从而驱动特定阶段的任务闭环。"""
def __init__(self): def __init__(self):
from pretor.utils.logger import get_logger from pretor.utils.logger import get_logger
self.logger = get_logger('control_node')
self.logger = get_logger("control_node")
self.agent: Agent | None = None self.agent: Agent | None = None
async def create_agent(
async def create_agent(self, global_state_machine: GlobalStateMachine, provider_title: str, model_id: str, tools_list: list[str] = None) -> None: self,
global_state_machine: GlobalStateMachine,
provider_title: str,
model_id: str,
tools_list: list[str] = None,
) -> None:
""" """
create_agent方法,将agent对象装配到Control的属性内 create_agent方法,将agent对象装配到Control的属性内
该方法通过provider_title从global_state_machine中获取provider对象,然后从provider对象中取出供应商形象,装配为pydantic_ai的 该方法通过provider_title从global_state_machine中获取provider对象,然后从provider对象中取出供应商形象,装配为pydantic_ai的
@@ -56,23 +66,29 @@ class ControlNode:
) )
output_type = ForWorkflow output_type = ForWorkflow
from pretor.utils.get_tool import load_tools_from_list from pretor.utils.get_tool import load_tools_from_list
provider: Provider = await global_state_machine.get_provider.remote( provider_title)
provider: Provider = await global_state_machine.get_provider.remote(
provider_title
)
agent_factory = AgentFactory() agent_factory = AgentFactory()
callables = load_tools_from_list(tools_list) callables = load_tools_from_list(tools_list)
self.agent = agent_factory.create_agent(provider=provider, self.agent = agent_factory.create_agent(
model_id=model_id, provider=provider,
output_type=output_type, model_id=model_id,
system_prompt=system_prompt, output_type=output_type,
deps_type=ControlNodeDeps, system_prompt=system_prompt,
agent_name="control_node", deps_type=ControlNodeDeps,
tools=callables) agent_name="control_node",
tools=callables,
)
@self.agent.system_prompt @self.agent.system_prompt
async def dynamic_prompt(ctx: RunContext[ControlNodeDeps]): async def dynamic_prompt(ctx: RunContext[ControlNodeDeps]):
"""执行与 dynamic prompt 相关的核心业务流转操作。 """执行与 dynamic prompt 相关的核心业务流转操作。
该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。 该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。
Args: ctx (RunContext[ControlNodeDeps]): 参与 dynamic prompt 逻辑运算或数据构建的上下文依赖对象。 Args: ctx (RunContext[ControlNodeDeps]): 参与 dynamic prompt 逻辑运算或数据构建的上下文依赖对象。
Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。"""
prompt = system_prompt + "\n\n" prompt = system_prompt + "\n\n"
prompt += ( prompt += (
f"=== 当前任务步骤上下文 ===\n" f"=== 当前任务步骤上下文 ===\n"
@@ -86,7 +102,7 @@ class ControlNode:
"""执行与 working 相关的核心业务流转操作。 """执行与 working 相关的核心业务流转操作。
该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。 该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。
Args: payload (ForWorkflowInput): 从客户端传递过来或由上游组件生成的核心业务数据体,通常需要进一步的清洗和结构化解析。 Args: payload (ForWorkflowInput): 从客户端传递过来或由上游组件生成的核心业务数据体,通常需要进一步的清洗和结构化解析。
Returns: (str): 处理流程所输出的具体字符串产物,可能是新生成的 ID 序列、格式化好的文本片段或 LLM 推理的回答内容。 """ Returns: (str): 处理流程所输出的具体字符串产物,可能是新生成的 ID 序列、格式化好的文本片段或 LLM 推理的回答内容。"""
try: try:
result: ForWorkflow = await self._run(payload) result: ForWorkflow = await self._run(payload)
return result return result
@@ -98,19 +114,21 @@ class ControlNode:
"""执行与 run 相关的核心业务流转操作。 """执行与 run 相关的核心业务流转操作。
该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。 该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。
Args: payload (ForWorkflowInput): 从客户端传递过来或由上游组件生成的核心业务数据体,通常需要进一步的清洗和结构化解析。 Args: payload (ForWorkflowInput): 从客户端传递过来或由上游组件生成的核心业务数据体,通常需要进一步的清洗和结构化解析。
Returns: (ForWorkflow): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ Returns: (ForWorkflow): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。"""
try: try:
self.agent.retries = 3 self.agent.retries = 3
deps = ControlNodeDeps( deps = ControlNodeDeps(workflow_step=payload.workflow_step)
workflow_step=payload.workflow_step self.logger.debug(
f"ControlNode: 开始执行工作流节点 [{payload.workflow_step.name}] (原生重试开启)"
) )
self.logger.debug(f"ControlNode: 开始执行工作流节点 [{payload.workflow_step.name}] (原生重试开启)")
result = await self.agent.run( result = await self.agent.run(
f"请根据提供的 workflow_step 上下文,执行此步骤并输出结果。\n详细指令或附加数据:{payload.workflow_step.model_dump_json()}", f"请根据提供的 workflow_step 上下文,执行此步骤并输出结果。\n详细指令或附加数据:{payload.workflow_step.model_dump_json()}",
deps=deps deps=deps,
) )
return result.output return result.output
except Exception as e: except Exception as e:
self.logger.exception(f"ControlNode 在执行步骤 [{payload.workflow_step.name}] 时最终失败: {str(e)}") self.logger.exception(
f"ControlNode 在执行步骤 [{payload.workflow_step.name}] 时最终失败: {str(e)}"
)
raise RuntimeError(f"ControlNode 执行步骤失败: {str(e)}") from e raise RuntimeError(f"ControlNode 执行步骤失败: {str(e)}") from e
@@ -17,31 +17,39 @@ from pydantic import Field
from pretor.core.workflow.workflow import WorkStep from pretor.core.workflow.workflow import WorkStep
from pretor.utils.agent_model import ResponseModel, InputModel, DepsModel from pretor.utils.agent_model import ResponseModel, InputModel, DepsModel
class ControlNodeResponse(ResponseModel): class ControlNodeResponse(ResponseModel):
"""控制节点回复的基类""" """控制节点回复的基类"""
pass pass
class ControlNodeInput(InputModel): class ControlNodeInput(InputModel):
"""ControlNodeInput 核心组件类。 """ControlNodeInput 核心组件类。
这是一个系统执行节点类,作为多智能体架构中的独立处理单元。它能够接收工作流上下文,根据内置的大模型策略进行意图理解和自主决策,从而驱动特定阶段的任务闭环。 """ 这是一个系统执行节点类,作为多智能体架构中的独立处理单元。它能够接收工作流上下文,根据内置的大模型策略进行意图理解和自主决策,从而驱动特定阶段的任务闭环。"""
pass pass
class ControlNodeDeps(DepsModel): class ControlNodeDeps(DepsModel):
"""ControlNodeDeps 核心组件类。 """ControlNodeDeps 核心组件类。
这是一个系统执行节点类,作为多智能体架构中的独立处理单元。它能够接收工作流上下文,根据内置的大模型策略进行意图理解和自主决策,从而驱动特定阶段的任务闭环。 """ 这是一个系统执行节点类,作为多智能体架构中的独立处理单元。它能够接收工作流上下文,根据内置的大模型策略进行意图理解和自主决策,从而驱动特定阶段的任务闭环。"""
workflow_step: WorkStep workflow_step: WorkStep
# In the future, this can be dynamically populated with tools specific to the current task execution # In the future, this can be dynamically populated with tools specific to the current task execution
class ForWorkflow(ControlNodeResponse): class ForWorkflow(ControlNodeResponse):
"""ForWorkflow 核心组件类。 """ForWorkflow 核心组件类。
这是一个领域数据模型或功能封装类,承载了 ForWorkflow 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。 """ 这是一个领域数据模型或功能封装类,承载了 ForWorkflow 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。"""
output: str = Field(..., description="控制节点执行特定工作流步骤的结果。包含执行细节和输出数据。")
output: str = Field(
..., description="控制节点执行特定工作流步骤的结果。包含执行细节和输出数据。"
)
class ForWorkflowInput(ControlNodeInput): class ForWorkflowInput(ControlNodeInput):
"""ForWorkflowInput 核心组件类。 """ForWorkflowInput 核心组件类。
这是一个领域数据模型或功能封装类,承载了 ForWorkflowInput 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。 """ 这是一个领域数据模型或功能封装类,承载了 ForWorkflowInput 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。"""
workflow_step: WorkStep workflow_step: WorkStep
@@ -13,4 +13,5 @@
# limitations under the License. # limitations under the License.
from .supervisory_node import SupervisoryNode from .supervisory_node import SupervisoryNode
__all__ = ["SupervisoryNode"] __all__ = ["SupervisoryNode"]
@@ -19,7 +19,12 @@ from pretor.api.platform.event import PretorEvent
from pretor.adapter.model_adapter.agent_factory import AgentFactory from pretor.adapter.model_adapter.agent_factory import AgentFactory
from pretor.core.global_state_machine.global_state_machine import GlobalStateMachine from pretor.core.global_state_machine.global_state_machine import GlobalStateMachine
from pretor.core.global_state_machine.model_provider import Provider from pretor.core.global_state_machine.model_provider import Provider
from pretor.core.individual.supervisory_node.template import ForConsciousnessNode, ForUser, SupervisoryNodeDeps, TerminationMessage from pretor.core.individual.supervisory_node.template import (
ForConsciousnessNode,
ForUser,
SupervisoryNodeDeps,
TerminationMessage,
)
from pydantic_ai import RunContext, Agent from pydantic_ai import RunContext, Agent
from pretor.utils.ray_hook import ray_actor_hook from pretor.utils.ray_hook import ray_actor_hook
@@ -27,14 +32,21 @@ from pretor.utils.ray_hook import ray_actor_hook
@ray.remote @ray.remote
class SupervisoryNode: class SupervisoryNode:
"""SupervisoryNode 核心组件类。 """SupervisoryNode 核心组件类。
这是一个系统执行节点类,作为多智能体架构中的独立处理单元。它能够接收工作流上下文,根据内置的大模型策略进行意图理解和自主决策,从而驱动特定阶段的任务闭环。 """ 这是一个系统执行节点类,作为多智能体架构中的独立处理单元。它能够接收工作流上下文,根据内置的大模型策略进行意图理解和自主决策,从而驱动特定阶段的任务闭环。"""
def __init__(self) -> None: def __init__(self) -> None:
from pretor.utils.logger import get_logger from pretor.utils.logger import get_logger
self.logger = get_logger('supervisory_node')
self.logger = get_logger("supervisory_node")
self.agent: None | Agent = None self.agent: None | Agent = None
async def create_agent(
async def create_agent(self, global_state_machine: GlobalStateMachine, provider_title: str, model_id: str, tools_list: list[str] = None) -> None: self,
global_state_machine: GlobalStateMachine,
provider_title: str,
model_id: str,
tools_list: list[str] = None,
) -> None:
""" """
create_agent方法,将agent对象装配到SupervisoryNode的属性内 create_agent方法,将agent对象装配到SupervisoryNode的属性内
该方法通过provider_title从global_state_machine中获取provider对象,然后从provider对象中取出供应商形象,装配为pydantic_ai的Agent实例, 该方法通过provider_title从global_state_machine中获取provider对象,然后从provider对象中取出供应商形象,装配为pydantic_ai的Agent实例,
@@ -53,43 +65,47 @@ class SupervisoryNode:
"你的核心职责是进行【意图识别与路由】。请仔细阅读用户的请求:\n" "你的核心职责是进行【意图识别与路由】。请仔细阅读用户的请求:\n"
"1. 如果用户只是进行简单的问候、闲聊或查询非常基础的信息,请直接生成友好的回复,使用 ForUser 格式。\n" "1. 如果用户只是进行简单的问候、闲聊或查询非常基础的信息,请直接生成友好的回复,使用 ForUser 格式。\n"
"2. 如果用户提出的是复杂任务(如需要编写代码、多步骤规划、数据处理等),请务必将其判定为需要工作流处理的任务," "2. 如果用户提出的是复杂任务(如需要编写代码、多步骤规划、数据处理等),请务必将其判定为需要工作流处理的任务,"
" 并使用 ForConsciousnessNode 格式。若提供的【可用模板列表】中有合适的模板请选用,若都不匹配则 workflow_template 设为 null\n" " 并使用 ForConsciousnessNode 格式将其移交意识节点处理\n"
"3. 如果你收到的是 TerminationMessage(代表工作流已完成并生成了报告),请将报告内容转化为友好的面向用户的回复,使用 ForUser 格式。\n" "3. 如果你收到的是 TerminationMessage(代表工作流已完成并生成了报告),请将报告内容转化为友好的面向用户的回复,使用 ForUser 格式。\n"
"请保持冷静、专业,并严格遵循上述路由规则。" "请保持冷静、专业,并严格遵循上述路由规则。"
) )
output_type = Union[ForConsciousnessNode, ForUser] output_type = Union[ForConsciousnessNode, ForUser]
from pretor.utils.get_tool import load_tools_from_list from pretor.utils.get_tool import load_tools_from_list
provider: Provider = await global_state_machine.get_provider.remote( provider_title)
provider: Provider = await global_state_machine.get_provider.remote(
provider_title
)
agent_factory = AgentFactory() agent_factory = AgentFactory()
callables = load_tools_from_list(tools_list) callables = load_tools_from_list(tools_list)
self.agent = agent_factory.create_agent(provider=provider, self.agent = agent_factory.create_agent(
model_id=model_id, provider=provider,
output_type=output_type, model_id=model_id,
system_prompt=system_prompt, output_type=output_type,
deps_type=SupervisoryNodeDeps, system_prompt=system_prompt,
agent_name="supervisory_node", deps_type=SupervisoryNodeDeps,
tools=callables) agent_name="supervisory_node",
tools=callables,
)
@self.agent.system_prompt @self.agent.system_prompt
async def dynamic_prompt(ctx: RunContext[SupervisoryNodeDeps]): async def dynamic_prompt(ctx: RunContext[SupervisoryNodeDeps]):
"""执行与 dynamic prompt 相关的核心业务流转操作。 """执行与 dynamic prompt 相关的核心业务流转操作。
该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。 该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。
Args: ctx (RunContext[SupervisoryNodeDeps]): 参与 dynamic prompt 逻辑运算或数据构建的上下文依赖对象。 Args: ctx (RunContext[SupervisoryNodeDeps]): 参与 dynamic prompt 逻辑运算或数据构建的上下文依赖对象。
Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。"""
prompt = system_prompt + "\n\n" prompt = system_prompt + "\n\n"
prompt += ( prompt += (
f"=== 当前上下文 ===\n" f"=== 当前上下文 ===\n"
f"- 平台 (Platform): {ctx.deps.platform}\n" f"- 平台 (Platform): {ctx.deps.platform}\n"
f"- 用户名 (User): {ctx.deps.user_name}\n" f"- 用户名 (User): {ctx.deps.user_name}\n"
f"- 当前时间 (Time): {ctx.deps.time}\n" f"- 当前时间 (Time): {ctx.deps.time}\n"
f"- 可用工作流模板 (Available Templates): {ctx.deps.available_templates}\n"
) )
# 修改 system_prompt 变量 # 修改 system_prompt 变量
prompt += ( prompt += (
"\n\n注意:你必须调用且只能调用一个函数(工具)来输出结果。" "\n\n注意:你必须调用且只能调用一个函数(工具)来输出结果。"
"如果你想直接回复用户,请调用 ForUser;" "如果你想直接回复用户,请调用 ForUser;"
"如果你想移交给工作流,请调用 ForConsciousnessNode(若没有合适的模板,workflow_template 填 null" "如果你想移交给工作流,请调用 ForConsciousnessNode。"
"严禁返回纯文本,必须使用工具格式!" "严禁返回纯文本,必须使用工具格式!"
) )
if ctx.deps.error_history: if ctx.deps.error_history:
@@ -113,16 +129,21 @@ class SupervisoryNode:
try: try:
result = await self._run(payload) result = await self._run(payload)
if isinstance(result, ForConsciousnessNode): if isinstance(result, ForConsciousnessNode):
self.logger.info(f"SupervisoryNode: 任务已分配给工作流引擎处理,选用模板 [{result.workflow_template}]") self.logger.info("SupervisoryNode: 任务已分配给工作流引擎处理")
if isinstance(payload, PretorEvent): if isinstance(payload, PretorEvent):
payload.context["workflow_template"] = result.workflow_template
try: try:
global_workflow_manager = ray_actor_hook("global_workflow_manager").global_workflow_manager global_workflow_manager = ray_actor_hook(
"global_workflow_manager"
).global_workflow_manager
await global_workflow_manager.add_event.remote(payload) await global_workflow_manager.add_event.remote(payload)
workflow_running_engine = ray_actor_hook("workflow_running_engine").workflow_running_engine workflow_running_engine = ray_actor_hook(
"workflow_running_engine"
).workflow_running_engine
await workflow_running_engine.put_event.remote(payload) await workflow_running_engine.put_event.remote(payload)
except Exception as e: except Exception as e:
self.logger.error(f"SupervisoryNode: 无法将事件放入 WorkflowRunningEngine: {e}") self.logger.error(
f"SupervisoryNode: 无法将事件放入 WorkflowRunningEngine: {e}"
)
return "抱歉,任务提交失败,系统内部错误。" return "抱歉,任务提交失败,系统内部错误。"
return f"任务已创建,准备创建工作流。原因:{result.reasoning}" return f"任务已创建,准备创建工作流。原因:{result.reasoning}"
elif isinstance(result, ForUser): elif isinstance(result, ForUser):
@@ -144,7 +165,7 @@ class SupervisoryNode:
Returns: Returns:
ForUser对象,监控节点对于用户进行的简单回答 ForUser对象,监控节点对于用户进行的简单回答
ForConsciousnessNode对象,监控节点将用户的请求判断为复杂任务,将PretorEvent传递给意识节点,并且给选择好的工作流模板 ForConsciousnessNode对象,监控节点将用户的请求判断为复杂任务,将PretorEvent传递给意识节点
""" """
... ...
@@ -160,7 +181,9 @@ class SupervisoryNode:
""" """
... ...
async def _run(self, payload: Union[PretorEvent, TerminationMessage]) -> Union[ForConsciousnessNode, ForUser]: async def _run(
self, payload: Union[PretorEvent, TerminationMessage]
) -> Union[ForConsciousnessNode, ForUser]:
""" """
_run方法,将payload转化为对llm发送的消息并发送 _run方法,将payload转化为对llm发送的消息并发送
Args: Args:
@@ -175,23 +198,15 @@ class SupervisoryNode:
message = payload.message message = payload.message
time_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") time_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
try: try:
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
workflow_template_dict = await global_state_machine.get_all_workflow_templates.remote()
available_templates_str = "\n".join([f"- 名称: {k}, 描述/内容: {v}" for k, v in
workflow_template_dict.items()]) if workflow_template_dict else "暂无注册的工作流模板"
deps = SupervisoryNodeDeps( deps = SupervisoryNodeDeps(
platform=platform, platform=platform, user_name=user_name, time=time_str
user_name=user_name,
time=time_str,
available_templates=available_templates_str
) )
self.logger.debug("SupervisoryNode 开始生成 (启用原生 Pydantic-AI 重试)") self.logger.debug("SupervisoryNode 开始生成 (启用原生 Pydantic-AI 重试)")
prompt_message = message prompt_message = message
if isinstance(payload, TerminationMessage): if isinstance(payload, TerminationMessage):
prompt_message = f"【工作流执行结束报告】\n请将以下技术报告转化为对用户的友好回复:\n{message}" prompt_message = f"【工作流执行结束报告】\n请将以下技术报告转化为对用户的友好回复:\n{message}"
self.agent.retries = 3 self.agent.retries = 3
result = await self.agent.run(prompt_message, result = await self.agent.run(prompt_message, deps=deps)
deps=deps)
return result.output return result.output
except Exception as e: except Exception as e:
self.logger.exception(f"SupervisoryNode 模型生成或解析最终失败: {str(e)}") self.logger.exception(f"SupervisoryNode 模型生成或解析最终失败: {str(e)}")
@@ -16,35 +16,46 @@ from pydantic import Field
from pretor.utils.agent_model import ResponseModel, DepsModel from pretor.utils.agent_model import ResponseModel, DepsModel
from pydantic import BaseModel from pydantic import BaseModel
class SupervisoryNodeResponse(ResponseModel): class SupervisoryNodeResponse(ResponseModel):
"""SupervisoryNodeResponse 核心组件类。 """SupervisoryNodeResponse 核心组件类。
这是一个系统执行节点类,作为多智能体架构中的独立处理单元。它能够接收工作流上下文,根据内置的大模型策略进行意图理解和自主决策,从而驱动特定阶段的任务闭环。 """ 这是一个系统执行节点类,作为多智能体架构中的独立处理单元。它能够接收工作流上下文,根据内置的大模型策略进行意图理解和自主决策,从而驱动特定阶段的任务闭环。"""
pass pass
class ForUser(SupervisoryNodeResponse): class ForUser(SupervisoryNodeResponse):
"""ForUser 核心组件类。 """ForUser 核心组件类。
这是一个领域数据模型或功能封装类,承载了 ForUser 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。 """ 这是一个领域数据模型或功能封装类,承载了 ForUser 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。"""
context: str = Field(..., description="对用户的回复,应当使用和蔼的语气进行回复。用于直接解答简单问题或返回最终报告。")
context: str = Field(
...,
description="对用户的回复,应当使用和蔼的语气进行回复。用于直接解答简单问题或返回最终报告。",
)
class ForConsciousnessNode(SupervisoryNodeResponse): class ForConsciousnessNode(SupervisoryNodeResponse):
"""ForConsciousnessNode 核心组件类。 """ForConsciousnessNode 核心组件类。
这是一个系统执行节点类,作为多智能体架构中的独立处理单元。它能够接收工作流上下文,根据内置的大模型策略进行意图理解和自主决策,从而驱动特定阶段的任务闭环。 """ 这是一个系统执行节点类,作为多智能体架构中的独立处理单元。它能够接收工作流上下文,根据内置的大模型策略进行意图理解和自主决策,从而驱动特定阶段的任务闭环。"""
workflow_template: str | None = Field(default=None, description="选择的工作流模板的名称,用于处理复杂任务。若无需模板则为 None。")
reasoning: str = Field(..., description="选择将任务移交意识节点并选用该模板的简短原因。") reasoning: str = Field(..., description="选择将任务移交意识节点的简短原因。")
class TerminationMessage(BaseModel): class TerminationMessage(BaseModel):
"""TerminationMessage 核心组件类。 """TerminationMessage 核心组件类。
这是一个领域数据模型或功能封装类,承载了 TerminationMessage 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。 """ 这是一个领域数据模型或功能封装类,承载了 TerminationMessage 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。"""
platform: str platform: str
user_name: str user_name: str
message: str message: str
class SupervisoryNodeDeps(DepsModel): class SupervisoryNodeDeps(DepsModel):
"""SupervisoryNodeDeps 核心组件类。 """SupervisoryNodeDeps 核心组件类。
这是一个系统执行节点类,作为多智能体架构中的独立处理单元。它能够接收工作流上下文,根据内置的大模型策略进行意图理解和自主决策,从而驱动特定阶段的任务闭环。 """ 这是一个系统执行节点类,作为多智能体架构中的独立处理单元。它能够接收工作流上下文,根据内置的大模型策略进行意图理解和自主决策,从而驱动特定阶段的任务闭环。"""
platform: str platform: str
user_name: str user_name: str
time: str time: str
retry_count: int = 0 retry_count: int = 0
error_history: str = "" error_history: str = ""
available_templates: str = "默认工作流 (default_workflow)"
@@ -0,0 +1,3 @@
from pretor.core.postgres_database.postgres import PostgresDatabase
__all__ = ["PostgresDatabase"]
@@ -26,19 +26,25 @@ from pretor.core.database.module.user import AuthDatabase
from pretor.core.database.module.provider import ProviderDatabase from pretor.core.database.module.provider import ProviderDatabase
from pretor.core.database.module.system_node import SystemNodeDatabase from pretor.core.database.module.system_node import SystemNodeDatabase
@ray.remote @ray.remote
class PostgresDatabase: class PostgresDatabase:
"""PostgresDatabase 核心组件类。 """PostgresDatabase 核心组件类。
这是一个数据库操作层 (DAO/Repository) 封装类专注于处理实体模型与关系型数据库表之间的映射它将复杂的 SQL 查询跨表 Join 和事务回滚逻辑进行了高级抽象向上层服务暴露简洁的数据读写接口 """ 这是一个数据库操作层 (DAO/Repository) 封装类专注于处理实体模型与关系型数据库表之间的映射它将复杂的 SQL 查询跨表 Join 和事务回滚逻辑进行了高级抽象向上层服务暴露简洁的数据读写接口"""
def __init__(self): def __init__(self):
user = os.environ.get('POSTGRES_USER') user = os.environ.get("POSTGRES_USER")
password = os.environ.get('POSTGRES_PASSWORD') password = os.environ.get("POSTGRES_PASSWORD")
host = os.environ.get('POSTGRES_HOST') host = os.environ.get("POSTGRES_HOST")
port = os.environ.get('POSTGRES_PORT') port = os.environ.get("POSTGRES_PORT")
database = os.environ.get('POSTGRES_DB') database = os.environ.get("POSTGRES_DB")
database_url = f"postgresql+asyncpg://{user}:{password}@{host}:{port}/{database}" database_url = (
f"postgresql+asyncpg://{user}:{password}@{host}:{port}/{database}"
)
self.async_engine = create_async_engine(database_url, echo=True) self.async_engine = create_async_engine(database_url, echo=True)
self.async_session_maker = sessionmaker(self.async_engine, class_=AsyncSession, expire_on_commit=False) self.async_session_maker = sessionmaker(
self.async_engine, class_=AsyncSession, expire_on_commit=False
)
self._auth_database = AuthDatabase(self.async_session_maker) self._auth_database = AuthDatabase(self.async_session_maker)
self._provider_database = ProviderDatabase(self.async_session_maker) self._provider_database = ProviderDatabase(self.async_session_maker)
@@ -51,7 +57,7 @@ class PostgresDatabase:
async def init_db(self) -> None: async def init_db(self) -> None:
"""完成 db 模块的启动与依赖初始化。 """完成 db 模块的启动与依赖初始化。
在系统引导或服务拉起阶段被调用负责建立网络连接分配基础内存资源及注册核心服务组件 在系统引导或服务拉起阶段被调用负责建立网络连接分配基础内存资源及注册核心服务组件
Returns: (None): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象 """ Returns: (None): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象"""
try: try:
async with self.async_engine.begin() as conn: async with self.async_engine.begin() as conn:
await conn.run_sync(SQLModel.metadata.create_all) await conn.run_sync(SQLModel.metadata.create_all)
@@ -67,7 +73,7 @@ class PostgresDatabase:
"""创建并持久化新的 user 实体。 """创建并持久化新的 user 实体。
接收构建参数执行必要的数据校验与默认值填充后将新记录安全地写入底层存储或系统注册表中 接收构建参数执行必要的数据校验与默认值填充后将新记录安全地写入底层存储或系统注册表中
Args: user_name (str): 赋予该实体的人类可读名称或标题字符串主要用于前端 UI 展示日志记录或模糊检索 hashed_password (str): 控制逻辑流向的具体字符串参数指定了期望的 hashed password 内容 Args: user_name (str): 赋予该实体的人类可读名称或标题字符串主要用于前端 UI 展示日志记录或模糊检索 hashed_password (str): 控制逻辑流向的具体字符串参数指定了期望的 hashed password 内容
Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象 """ Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象"""
await self.ready_event.wait() await self.ready_event.wait()
return await self._auth_database.add_user(user_name, hashed_password) return await self._auth_database.add_user(user_name, hashed_password)
@@ -75,15 +81,17 @@ class PostgresDatabase:
"""执行与 change password 相关的核心业务流转操作。 """执行与 change password 相关的核心业务流转操作。
该方法封装了具体的算法策略或状态控制逻辑确保操作能够在事务上下文中被原子且一致地执行 该方法封装了具体的算法策略或状态控制逻辑确保操作能够在事务上下文中被原子且一致地执行
Args: user_name: 赋予该实体的人类可读名称或标题字符串主要用于前端 UI 展示日志记录或模糊检索 old_password: 参与 change password 逻辑运算或数据构建的上下文依赖对象 new_password: 参与 change password 逻辑运算或数据构建的上下文依赖对象 Args: user_name: 赋予该实体的人类可读名称或标题字符串主要用于前端 UI 展示日志记录或模糊检索 old_password: 参与 change password 逻辑运算或数据构建的上下文依赖对象 new_password: 参与 change password 逻辑运算或数据构建的上下文依赖对象
Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象 """ Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象"""
await self.ready_event.wait() await self.ready_event.wait()
return await self._auth_database.change_password(user_name, old_password, new_password) return await self._auth_database.change_password(
user_name, old_password, new_password
)
async def delete_user(self, user_name: str): async def delete_user(self, user_name: str):
"""安全地移除或注销 user。 """安全地移除或注销 user。
执行物理删除或逻辑删除操作并妥善清理相关的关联数据及占用资源 执行物理删除或逻辑删除操作并妥善清理相关的关联数据及占用资源
Args: user_name (str): 赋予该实体的人类可读名称或标题字符串主要用于前端 UI 展示日志记录或模糊检索 Args: user_name (str): 赋予该实体的人类可读名称或标题字符串主要用于前端 UI 展示日志记录或模糊检索
Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象 """ Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象"""
await self.ready_event.wait() await self.ready_event.wait()
return await self._auth_database.delete_user(user_name) return await self._auth_database.delete_user(user_name)
@@ -91,7 +99,7 @@ class PostgresDatabase:
"""安全地移除或注销 user by id。 """安全地移除或注销 user by id。
执行物理删除或逻辑删除操作并妥善清理相关的关联数据及占用资源 执行物理删除或逻辑删除操作并妥善清理相关的关联数据及占用资源
Args: user_id (str): 目标对象的唯一全局标识符 (UUID/ULID)用于在数据库表或缓存结构中精准匹配该 user 实例 Args: user_id (str): 目标对象的唯一全局标识符 (UUID/ULID)用于在数据库表或缓存结构中精准匹配该 user 实例
Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象 """ Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象"""
await self.ready_event.wait() await self.ready_event.wait()
return await self._auth_database.delete_user_by_id(user_id) return await self._auth_database.delete_user_by_id(user_id)
@@ -99,14 +107,14 @@ class PostgresDatabase:
"""执行与 login user 相关的核心业务流转操作。 """执行与 login user 相关的核心业务流转操作。
该方法封装了具体的算法策略或状态控制逻辑确保操作能够在事务上下文中被原子且一致地执行 该方法封装了具体的算法策略或状态控制逻辑确保操作能够在事务上下文中被原子且一致地执行
Args: user_name (str): 赋予该实体的人类可读名称或标题字符串主要用于前端 UI 展示日志记录或模糊检索 Args: user_name (str): 赋予该实体的人类可读名称或标题字符串主要用于前端 UI 展示日志记录或模糊检索
Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象 """ Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象"""
await self.ready_event.wait() await self.ready_event.wait()
return await self._auth_database.login_user(user_name) return await self._auth_database.login_user(user_name)
async def get_all_users(self): async def get_all_users(self):
"""检索并获取特定的 all users 数据集合或实例对象。 """检索并获取特定的 all users 数据集合或实例对象。
根据提供的查询条件或上下文凭证从数据库缓存或第三方服务中读取对应的资源状态 根据提供的查询条件或上下文凭证从数据库缓存或第三方服务中读取对应的资源状态
Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象 """ Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象"""
await self.ready_event.wait() await self.ready_event.wait()
return await self._auth_database.get_all_users() return await self._auth_database.get_all_users()
@@ -114,7 +122,7 @@ class PostgresDatabase:
"""检索并获取特定的 user authority 数据集合或实例对象。 """检索并获取特定的 user authority 数据集合或实例对象。
根据提供的查询条件或上下文凭证从数据库缓存或第三方服务中读取对应的资源状态 根据提供的查询条件或上下文凭证从数据库缓存或第三方服务中读取对应的资源状态
Args: user_id (str): 目标对象的唯一全局标识符 (UUID/ULID)用于在数据库表或缓存结构中精准匹配该 user 实例 Args: user_id (str): 目标对象的唯一全局标识符 (UUID/ULID)用于在数据库表或缓存结构中精准匹配该 user 实例
Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象 """ Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象"""
await self.ready_event.wait() await self.ready_event.wait()
return await self._auth_database.get_user_authority(user_id) return await self._auth_database.get_user_authority(user_id)
@@ -122,7 +130,7 @@ class PostgresDatabase:
"""执行与 change user authority 相关的核心业务流转操作。 """执行与 change user authority 相关的核心业务流转操作。
该方法封装了具体的算法策略或状态控制逻辑确保操作能够在事务上下文中被原子且一致地执行 该方法封装了具体的算法策略或状态控制逻辑确保操作能够在事务上下文中被原子且一致地执行
Args: user_id (str): 目标对象的唯一全局标识符 (UUID/ULID)用于在数据库表或缓存结构中精准匹配该 user 实例 new_authority: 参与 change user authority 逻辑运算或数据构建的上下文依赖对象 Args: user_id (str): 目标对象的唯一全局标识符 (UUID/ULID)用于在数据库表或缓存结构中精准匹配该 user 实例 new_authority: 参与 change user authority 逻辑运算或数据构建的上下文依赖对象
Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象 """ Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象"""
await self.ready_event.wait() await self.ready_event.wait()
return await self._auth_database.change_user_authority(user_id, new_authority) return await self._auth_database.change_user_authority(user_id, new_authority)
@@ -130,14 +138,14 @@ class PostgresDatabase:
async def get_provider(self): async def get_provider(self):
"""检索并获取特定的 provider 数据集合或实例对象。 """检索并获取特定的 provider 数据集合或实例对象。
根据提供的查询条件或上下文凭证从数据库缓存或第三方服务中读取对应的资源状态 根据提供的查询条件或上下文凭证从数据库缓存或第三方服务中读取对应的资源状态
Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象 """ Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象"""
await self.ready_event.wait() await self.ready_event.wait()
return await self._provider_database.get_provider() return await self._provider_database.get_provider()
async def add_provider_db(self, **kwargs): async def add_provider_db(self, **kwargs):
"""创建并持久化新的 provider db 实体。 """创建并持久化新的 provider db 实体。
接收构建参数执行必要的数据校验与默认值填充后将新记录安全地写入底层存储或系统注册表中 接收构建参数执行必要的数据校验与默认值填充后将新记录安全地写入底层存储或系统注册表中
Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象 """ Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象"""
await self.ready_event.wait() await self.ready_event.wait()
return await self._provider_database.add_provider(**kwargs) return await self._provider_database.add_provider(**kwargs)
@@ -145,7 +153,7 @@ class PostgresDatabase:
"""安全地移除或注销 provider db。 """安全地移除或注销 provider db。
执行物理删除或逻辑删除操作并妥善清理相关的关联数据及占用资源 执行物理删除或逻辑删除操作并妥善清理相关的关联数据及占用资源
Args: provider_id (str): 目标对象的唯一全局标识符 (UUID/ULID)用于在数据库表或缓存结构中精准匹配该 provider 实例 Args: provider_id (str): 目标对象的唯一全局标识符 (UUID/ULID)用于在数据库表或缓存结构中精准匹配该 provider 实例
Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象 """ Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象"""
await self.ready_event.wait() await self.ready_event.wait()
return await self._provider_database.delete_provider(provider_id) return await self._provider_database.delete_provider(provider_id)
@@ -153,23 +161,31 @@ class PostgresDatabase:
"""对现有的 provider db 进行状态更新或属性覆盖。 """对现有的 provider db 进行状态更新或属性覆盖。
基于增量变更原则合并最新的配置或数据并触发相关依赖组件的缓存刷新或事件通知 基于增量变更原则合并最新的配置或数据并触发相关依赖组件的缓存刷新或事件通知
Args: provider_id (str): 目标对象的唯一全局标识符 (UUID/ULID)用于在数据库表或缓存结构中精准匹配该 provider 实例 Args: provider_id (str): 目标对象的唯一全局标识符 (UUID/ULID)用于在数据库表或缓存结构中精准匹配该 provider 实例
Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象 """ Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象"""
await self.ready_event.wait() await self.ready_event.wait()
return await self._provider_database.update_provider(provider_id, **kwargs) return await self._provider_database.update_provider(provider_id, **kwargs)
# System Node Database Methods # System Node Database Methods
async def upsert_system_node_config(self, node_name: str, provider_title: str, model_id: str, tools: list[str] = None): async def upsert_system_node_config(
self,
node_name: str,
provider_title: str,
model_id: str,
tools: list[str] = None,
):
"""执行与 upsert system node config 相关的核心业务流转操作。 """执行与 upsert system node config 相关的核心业务流转操作。
该方法封装了具体的算法策略或状态控制逻辑确保操作能够在事务上下文中被原子且一致地执行 该方法封装了具体的算法策略或状态控制逻辑确保操作能够在事务上下文中被原子且一致地执行
Args: node_name (str): 赋予该实体的人类可读名称或标题字符串主要用于前端 UI 展示日志记录或模糊检索 provider_title (str): 目标对象的唯一全局标识符 (UUID/ULID)用于在数据库表或缓存结构中精准匹配该 provider_title 实例 model_id (str): 目标对象的唯一全局标识符 (UUID/ULID)用于在数据库表或缓存结构中精准匹配该 model 实例 tools (list[str]): 控制逻辑流向的具体字符串参数指定了期望的 tools 内容 Args: node_name (str): 赋予该实体的人类可读名称或标题字符串主要用于前端 UI 展示日志记录或模糊检索 provider_title (str): 目标对象的唯一全局标识符 (UUID/ULID)用于在数据库表或缓存结构中精准匹配该 provider_title 实例 model_id (str): 目标对象的唯一全局标识符 (UUID/ULID)用于在数据库表或缓存结构中精准匹配该 model 实例 tools (list[str]): 控制逻辑流向的具体字符串参数指定了期望的 tools 内容
Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象 """ Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象"""
await self.ready_event.wait() await self.ready_event.wait()
return await self._system_node_database.upsert_system_node_config(node_name, provider_title, model_id, tools) return await self._system_node_database.upsert_system_node_config(
node_name, provider_title, model_id, tools
)
async def get_all_system_node_configs(self): async def get_all_system_node_configs(self):
"""检索并获取特定的 all system node configs 数据集合或实例对象。 """检索并获取特定的 all system node configs 数据集合或实例对象。
根据提供的查询条件或上下文凭证从数据库缓存或第三方服务中读取对应的资源状态 根据提供的查询条件或上下文凭证从数据库缓存或第三方服务中读取对应的资源状态
Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象 """ Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象"""
await self.ready_event.wait() await self.ready_event.wait()
return await self._system_node_database.get_all_system_node_configs() return await self._system_node_database.get_all_system_node_configs()
@@ -177,7 +193,7 @@ class PostgresDatabase:
async def add_worker_individual(self, **kwargs): async def add_worker_individual(self, **kwargs):
"""创建并持久化新的 worker individual 实体。 """创建并持久化新的 worker individual 实体。
接收构建参数执行必要的数据校验与默认值填充后将新记录安全地写入底层存储或系统注册表中 接收构建参数执行必要的数据校验与默认值填充后将新记录安全地写入底层存储或系统注册表中
Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象 """ Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象"""
await self.ready_event.wait() await self.ready_event.wait()
return await self._individual_database.add_worker_individual(**kwargs) return await self._individual_database.add_worker_individual(**kwargs)
@@ -185,7 +201,7 @@ class PostgresDatabase:
"""检索并获取特定的 worker individual 数据集合或实例对象。 """检索并获取特定的 worker individual 数据集合或实例对象。
根据提供的查询条件或上下文凭证从数据库缓存或第三方服务中读取对应的资源状态 根据提供的查询条件或上下文凭证从数据库缓存或第三方服务中读取对应的资源状态
Args: agent_id (str): 目标对象的唯一全局标识符 (UUID/ULID)用于在数据库表或缓存结构中精准匹配该 agent 实例 Args: agent_id (str): 目标对象的唯一全局标识符 (UUID/ULID)用于在数据库表或缓存结构中精准匹配该 agent 实例
Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象 """ Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象"""
await self.ready_event.wait() await self.ready_event.wait()
return await self._individual_database.get_worker_individual(agent_id) return await self._individual_database.get_worker_individual(agent_id)
@@ -193,7 +209,7 @@ class PostgresDatabase:
"""检索并获取特定的 worker individual list 数据集合或实例对象。 """检索并获取特定的 worker individual list 数据集合或实例对象。
根据提供的查询条件或上下文凭证从数据库缓存或第三方服务中读取对应的资源状态 根据提供的查询条件或上下文凭证从数据库缓存或第三方服务中读取对应的资源状态
Args: owner_id (str): 目标对象的唯一全局标识符 (UUID/ULID)用于在数据库表或缓存结构中精准匹配该 owner 实例 Args: owner_id (str): 目标对象的唯一全局标识符 (UUID/ULID)用于在数据库表或缓存结构中精准匹配该 owner 实例
Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象 """ Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象"""
await self.ready_event.wait() await self.ready_event.wait()
return await self._individual_database.get_worker_individual_list(owner_id) return await self._individual_database.get_worker_individual_list(owner_id)
@@ -201,24 +217,27 @@ class PostgresDatabase:
"""对现有的 worker individual 进行状态更新或属性覆盖。 """对现有的 worker individual 进行状态更新或属性覆盖。
基于增量变更原则合并最新的配置或数据并触发相关依赖组件的缓存刷新或事件通知 基于增量变更原则合并最新的配置或数据并触发相关依赖组件的缓存刷新或事件通知
Args: agent_id (str): 目标对象的唯一全局标识符 (UUID/ULID)用于在数据库表或缓存结构中精准匹配该 agent 实例 Args: agent_id (str): 目标对象的唯一全局标识符 (UUID/ULID)用于在数据库表或缓存结构中精准匹配该 agent 实例
Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象 """ Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象"""
await self.ready_event.wait() await self.ready_event.wait()
return await self._individual_database.update_worker_individual(agent_id, **kwargs) return await self._individual_database.update_worker_individual(
agent_id, **kwargs
)
async def delete_worker_individual(self, agent_id: str): async def delete_worker_individual(self, agent_id: str):
"""安全地移除或注销 worker individual。 """安全地移除或注销 worker individual。
执行物理删除或逻辑删除操作并妥善清理相关的关联数据及占用资源 执行物理删除或逻辑删除操作并妥善清理相关的关联数据及占用资源
Args: agent_id (str): 目标对象的唯一全局标识符 (UUID/ULID)用于在数据库表或缓存结构中精准匹配该 agent 实例 Args: agent_id (str): 目标对象的唯一全局标识符 (UUID/ULID)用于在数据库表或缓存结构中精准匹配该 agent 实例
Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象 """ Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象"""
await self.ready_event.wait() await self.ready_event.wait()
return await self._individual_database.delete_worker_individual(agent_id) return await self._individual_database.delete_worker_individual(agent_id)
async def get_all_worker_individual(self): async def get_all_worker_individual(self):
"""检索并获取特定的 all worker individual 数据集合或实例对象。 """检索并获取特定的 all worker individual 数据集合或实例对象。
根据提供的查询条件或上下文凭证从数据库缓存或第三方服务中读取对应的资源状态 根据提供的查询条件或上下文凭证从数据库缓存或第三方服务中读取对应的资源状态
Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象 """ Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象"""
await self.ready_event.wait() await self.ready_event.wait()
return await self._individual_database.get_all_worker_individual() return await self._individual_database.get_all_worker_individual()
# Event Database Methods # Event Database Methods
async def upsert_event(self, trace_id: str, event_data_json: str): async def upsert_event(self, trace_id: str, event_data_json: str):
await self.ready_event.wait() await self.ready_event.wait()
+43 -22
View File
@@ -16,69 +16,88 @@ from typing import List, Optional, Union, Literal, Dict, Any
from pydantic import BaseModel, Field, model_validator from pydantic import BaseModel, Field, model_validator
from pretor.utils.logger import get_logger from pretor.utils.logger import get_logger
logger = get_logger('workflow')
logger = get_logger("workflow")
NodeType = Literal[ NodeType = Literal[
"consciousness_node", "control_node", "supervisory_node", "skill_individual" "consciousness_node", "control_node", "supervisory_node", "skill_individual"
] ]
class EventInfo(BaseModel): class EventInfo(BaseModel):
"""EventInfo 核心组件类。 """EventInfo 核心组件类。
这是一个领域数据模型或功能封装类,承载了 EventInfo 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。 """ 这是一个领域数据模型或功能封装类,承载了 EventInfo 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。"""
platform: str platform: str
user_name: str user_name: str
class LogicGate(BaseModel): class LogicGate(BaseModel):
"""LogicGate 核心组件类。 """LogicGate 核心组件类。
这是一个领域数据模型或功能封装类,承载了 LogicGate 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。 """ 这是一个领域数据模型或功能封装类,承载了 LogicGate 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。"""
if_fail: str = Field(..., description="失败跳转目标,如 'jump_to_step_1'") if_fail: str = Field(..., description="失败跳转目标,如 'jump_to_step_1'")
if_pass: Literal["continue", "exit"] = Field(default="continue", description="成功后的动作") if_pass: Literal["continue", "exit"] = Field(
default="continue", description="成功后的动作"
)
class WorkStep(BaseModel): class WorkStep(BaseModel):
"""WorkStep 核心组件类。 """WorkStep 核心组件类。
这是一个领域数据模型或功能封装类,承载了 WorkStep 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。 """ 这是一个领域数据模型或功能封装类,承载了 WorkStep 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。"""
step: int = Field(..., gt=0, description="步骤序号,严格自增") step: int = Field(..., gt=0, description="步骤序号,严格自增")
name: str = Field(..., description="步骤名称") name: str = Field(..., description="步骤名称")
node: NodeType = Field(..., description="负责执行的节点类型") node: NodeType = Field(..., description="负责执行的节点类型")
action: str = Field(..., description="执行的原子动作") action: str = Field(..., description="执行的原子动作")
desc: str = Field(..., description="动作细节的自然语言描述,包含人工规范指导") desc: str = Field(..., description="动作细节的自然语言描述,包含人工规范指导")
inputs: Optional[Union[str, List[str]]] = Field(default=None, description="前置依赖输出") inputs: Optional[Union[str, List[str]]] = Field(
default=None, description="前置依赖输出"
)
outputs: Optional[str] = Field(default=None, description="当前步骤产出物变量名") outputs: Optional[str] = Field(default=None, description="当前步骤产出物变量名")
agent_id: Optional[str] = Field(default=None, description="分配给 skill_individual 的 Skill Individual 真实 agent_id,不可用名称代替") agent_id: Optional[str] = Field(
default=None,
description="分配给 skill_individual 的 Skill Individual 真实 agent_id,不可用名称代替",
)
logic_gate: Optional[LogicGate] = Field(default=None, description="逻辑跳转控制") logic_gate: Optional[LogicGate] = Field(default=None, description="逻辑跳转控制")
status: Literal["waiting", "running", "completed", "failed"] = Field( status: Literal["waiting", "running", "completed", "failed"] = Field(
default="waiting", default="waiting", description="执行状态 (LLM建议保留默认值)"
description="执行状态 (LLM建议保留默认值)"
) )
class WorkflowStatus(BaseModel): class WorkflowStatus(BaseModel):
"""WorkflowStatus 核心组件类。 """WorkflowStatus 核心组件类。
这是一个领域数据模型或功能封装类,承载了 WorkflowStatus 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。 """ 这是一个领域数据模型或功能封装类,承载了 WorkflowStatus 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。"""
step: int = Field(default=1, gt=0, description="当前运行到的工作流步数") step: int = Field(default=1, gt=0, description="当前运行到的工作流步数")
status: Literal["waiting_llm_working", "waiting_tool_working", "llm_working", "tool_working"] = Field( status: Literal[
default="waiting_llm_working", "waiting_llm_working", "waiting_tool_working", "llm_working", "tool_working"
description="当前系统调度状态" ] = Field(default="waiting_llm_working", description="当前系统调度状态")
)
class PretorWorkflow(BaseModel): class PretorWorkflow(BaseModel):
"""PretorWorkflow 核心组件类。 """PretorWorkflow 核心组件类。
这是一个领域数据模型或功能封装类,承载了 PretorWorkflow 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。 """ 这是一个领域数据模型或功能封装类,承载了 PretorWorkflow 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。"""
title: str = Field(..., description="工作流的标题") title: str = Field(..., description="工作流的标题")
work_link: List[WorkStep] = Field(..., description="工作链逻辑定义") work_link: List[WorkStep] = Field(..., description="工作链逻辑定义")
# ---------------- 以下为系统级管控字段,LLM 无需关心 ---------------- # # ---------------- 以下为系统级管控字段,LLM 无需关心 ---------------- #
trace_id: str | None = Field(description="系统自动生成的追溯ID") trace_id: str | None = Field(description="系统自动生成的追溯ID")
version: str = Field(default="v1.0", description="系统协议版本号") version: str = Field(default="v1.0", description="系统协议版本号")
command: Optional[str] = Field(default=None, description="触发此工作流的原始命令") command: Optional[str] = Field(default=None, description="触发此工作流的原始命令")
output: Dict[str, Any] = Field(default_factory=dict, description="工作流最终产出结果") output: Dict[str, Any] = Field(
status: WorkflowStatus = Field(default_factory=WorkflowStatus, description="运行时状态对象") default_factory=dict, description="工作流最终产出结果"
)
status: WorkflowStatus = Field(
default_factory=WorkflowStatus, description="运行时状态对象"
)
event_info: EventInfo | None = Field(default=None) event_info: EventInfo | None = Field(default=None)
context_memory: Dict[str, Any] = Field(default_factory=dict) context_memory: Dict[str, Any] = Field(default_factory=dict)
@model_validator(mode='after') @model_validator(mode="after")
def validate_workflow_integrity(self) -> 'PretorWorkflow': def validate_workflow_integrity(self) -> "PretorWorkflow":
"""执行与 validate workflow integrity 相关的核心业务流转操作。 """执行与 validate workflow integrity 相关的核心业务流转操作。
该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。 该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。
Returns: ('PretorWorkflow'): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ Returns: ('PretorWorkflow'): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。"""
steps = [s.step for s in self.work_link] steps = [s.step for s in self.work_link]
expected = list(range(1, len(steps) + 1)) expected = list(range(1, len(steps) + 1))
if steps != expected: if steps != expected:
@@ -90,9 +109,11 @@ class PretorWorkflow(BaseModel):
try: try:
target = int(s.logic_gate.if_fail.split("_")[-1]) target = int(s.logic_gate.if_fail.split("_")[-1])
if target > max_step or target < 1: if target > max_step or target < 1:
raise ValueError(f"Step {s.step} 的跳转目标 Step {target} 越界了!") raise ValueError(
f"Step {s.step} 的跳转目标 Step {target} 越界了!"
)
except ValueError as e: except ValueError as e:
if "越界" in str(e): if "越界" in str(e):
raise e raise e
raise ValueError(f"LogicGate 格式错误: {s.logic_gate.if_fail}") raise ValueError(f"LogicGate 格式错误: {s.logic_gate.if_fail}")
return self return self
@@ -1,14 +0,0 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@@ -1,46 +0,0 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pydantic import BaseModel, model_validator
from typing import Dict,List
class WorkflowTemplateStep(BaseModel):
"""WorkflowTemplateStep 核心组件类。
这是一个领域数据模型或功能封装类,承载了 WorkflowTemplateStep 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。 """
step: int
node: str
action: str
desc: str
input: List[str]
output: List[str]
logic_gate: Dict[str, str]
class WorkflowTemplate(BaseModel):
"""WorkflowTemplate 核心组件类。
这是一个领域数据模型或功能封装类,承载了 WorkflowTemplate 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。 """
name: str
desc: str
work_link: list[WorkflowTemplateStep]
@model_validator(mode='after')
def validate_steps(self) -> 'WorkflowTemplate':
"""执行与 validate steps 相关的核心业务流转操作。
该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。
Returns: ('WorkflowTemplate'): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """
steps = [s.step for s in self.work_link]
if len(steps) != len(set(steps)):
raise ValueError("Step numbers in work_link must be unique")
return self
@@ -1,32 +0,0 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pathlib import Path
from pretor.core.workflow.workflow_template_generator.workflow_template import WorkflowTemplate
class WorkflowTemplateGenerator:
"""WorkflowTemplateGenerator 核心组件类。
这是一个领域数据模型或功能封装类,承载了 WorkflowTemplateGenerator 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。 """
@staticmethod
def generate_workflow_template(workflow_template: WorkflowTemplate) -> WorkflowTemplate:
"""执行与 generate workflow template 相关的核心业务流转操作。
该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。
Args: workflow_template (WorkflowTemplate): 参与 generate workflow template 逻辑运算或数据构建的上下文依赖对象。
Returns: (WorkflowTemplate): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """
output_dir = Path("pretor") / "workflow_template"
if not output_dir.exists():
output_dir.mkdir(parents=True)
output_file = output_dir / f"{workflow_template.name}_workflow_template.json"
with output_file.open("w", encoding="utf-8") as f:
f.write(workflow_template.model_dump_json(indent=4))
@@ -1,76 +0,0 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from pretor.core.workflow.workflow_template_generator.workflow_template_generator import WorkflowTemplateGenerator
from pathlib import Path
from pretor.core.workflow.workflow_template_generator.workflow_template import WorkflowTemplate
from pretor.utils.logger import get_logger
logger = get_logger('workflow_template_manager')
class WorkflowManager:
"""WorkflowManager 核心组件类。
这是一个管理器类,职责集中在维护整个系统内有关 Workflow 资源的全局生命周期。它提供了注册机制、状态同步以及跨组件的统一查询入口,确保系统中该类型资源的实例一致性与可控性。 """
def __init__(self):
self.workflow_template_generator = WorkflowTemplateGenerator()
self.workflow_templates_registry = {}
self.template_path = Path("pretor/workflow_template")
self._load_workflow_template()
def _load_workflow_template(self) -> None:
"""执行与 load workflow template 相关的核心业务流转操作。
该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。
Returns: (None): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """
for workflow_template_file in self.template_path.glob("*_workflow_template.json"):
with workflow_template_file.open("r",encoding="utf-8") as f:
try:
workflow_template = json.load(f)
self.workflow_templates_registry[workflow_template.get("name")] = workflow_template.get("desc")
except json.decoder.JSONDecodeError:
logger.warning(f"{workflow_template_file}不是json文件或格式错误")
except KeyError:
logger.warning(f"{workflow_template_file}不符合workflow_template格式")
def generate_workflow_template(self, workflow_template: WorkflowTemplate) -> None:
"""执行与 generate workflow template 相关的核心业务流转操作。
该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。
Args: workflow_template (WorkflowTemplate): 参与 generate workflow template 逻辑运算或数据构建的上下文依赖对象。
Returns: (None): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """
try:
workflow_template = self.workflow_template_generator.generate_workflow_template(workflow_template=workflow_template)
self.workflow_templates_registry[workflow_template.name] = workflow_template.desc
except Exception:
logger.exception("Failed to generate workflow template")
def add_workflow_template(self, template_name: str, workflow_template: WorkflowTemplate) -> None:
"""创建并持久化新的 workflow template 实体。
接收构建参数,执行必要的数据校验与默认值填充后,将新记录安全地写入底层存储或系统注册表中。
Args: template_name (str): 赋予该实体的人类可读名称或标题字符串,主要用于前端 UI 展示、日志记录或模糊检索。 workflow_template (WorkflowTemplate): 参与 add workflow template 逻辑运算或数据构建的上下文依赖对象。
Returns: (None): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """
self.generate_workflow_template(workflow_template)
def get_all_workflow_templates(self) -> dict:
"""检索并获取特定的 all workflow templates 数据集合或实例对象。
根据提供的查询条件或上下文凭证,从数据库、缓存或第三方服务中读取对应的资源状态。
Returns: (dict): 高度聚合的字典结构数据,将多维度的属性特征或统计指标组合后一并返回。 """
return self.workflow_templates_registry
def delete_workflow_template(self, template_name: str) -> None:
"""安全地移除或注销 workflow template。
执行物理删除或逻辑删除操作,并妥善清理相关的关联数据及占用资源。
Args: template_name (str): 赋予该实体的人类可读名称或标题字符串,主要用于前端 UI 展示、日志记录或模糊检索。
Returns: (None): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """
if template_name in self.workflow_templates_registry:
del self.workflow_templates_registry[template_name]
@@ -0,0 +1,3 @@
from pretor.core.workflow_running_engine.workflow_runner import WorkflowRunningEngine
__all__ = ["WorkflowRunningEngine"]
@@ -19,44 +19,40 @@ from pretor.core.workflow.workflow import PretorWorkflow, WorkStep, EventInfo
from typing import Optional, Dict, Union, Any, List from typing import Optional, Dict, Union, Any, List
from pretor.utils.error import WorkflowError, WorkflowExit from pretor.utils.error import WorkflowError, WorkflowExit
from pretor.api.platform.event import PretorEvent from pretor.api.platform.event import PretorEvent
from pretor.core.individual.control_node.template import ForWorkflowInput as ControlForWorkflowInput, \ from pretor.core.individual.control_node.template import (
ForWorkflow as ControlForWorkflow ForWorkflowInput as ControlForWorkflowInput,
ForWorkflow as ControlForWorkflow,
)
from pretor.core.individual.consciousness_node.template import ( from pretor.core.individual.consciousness_node.template import (
ForWorkflowInput as ConsciousnessForWorkflowInput, ForWorkflowInput as ConsciousnessForWorkflowInput,
ForSupervisoryInput, ForSupervisoryInput,
ForSupervisoryNode, ForSupervisoryNode,
ForWorkflow as ConsciousnessForWorkflow, ForWorkflow as ConsciousnessForWorkflow,
ForWorkflowEngineInput, ForWorkflowEngineInput,
ForWorkflowEngine ForWorkflowEngine,
) )
from pretor.core.individual.supervisory_node.template import TerminationMessage from pretor.core.individual.supervisory_node.template import TerminationMessage
import pathlib
def get_workflow_template(workflow_name: str) -> str:
"""检索并获取特定的 workflow template 数据集合或实例对象。
根据提供的查询条件或上下文凭证从数据库缓存或第三方服务中读取对应的资源状态
Args: workflow_name (str): 赋予该实体的人类可读名称或标题字符串主要用于前端 UI 展示日志记录或模糊检索
Returns: (str): 处理流程所输出的具体字符串产物可能是新生成的 ID 序列格式化好的文本片段或 LLM 推理的回答内容 """
workflow_template = pathlib.Path(__file__).parent.parent.parent / "workflow_template" / (workflow_name + "_workflow_template.json")
with open(workflow_template, "r", encoding="utf-8") as workflow_template_file:
workflow_template = workflow_template_file.read()
return workflow_template
class WorkflowEngine: class WorkflowEngine:
"""WorkflowEngine 核心组件类。 """WorkflowEngine 核心组件类。
这是一个领域数据模型或功能封装类承载了 WorkflowEngine 相关的内聚属性定义与状态维护它的存在隔离了局部的业务复杂性并对外提供了类型安全的访问接口 """ 这是一个领域数据模型或功能封装类承载了 WorkflowEngine 相关的内聚属性定义与状态维护它的存在隔离了局部的业务复杂性并对外提供了类型安全的访问接口"""
def __init__(self,
workflow: PretorWorkflow, def __init__(
consciousness_node=None, self,
control_node=None, workflow: PretorWorkflow,
supervisory_node=None): consciousness_node=None,
control_node=None,
supervisory_node=None,
):
from pretor.utils.logger import get_logger from pretor.utils.logger import get_logger
self.logger = get_logger('workflow_runner')
self.logger = get_logger("workflow_runner")
self.workflow: PretorWorkflow = workflow self.workflow: PretorWorkflow = workflow
"""工作流:当前WorkflowEngine待执行的workflow""" """工作流:当前WorkflowEngine待执行的workflow"""
self._steps_by_id: Dict[int, WorkStep] = {step.step: step for step in self.workflow.work_link} self._steps_by_id: Dict[int, WorkStep] = {
step.step: step for step in self.workflow.work_link
}
"""步骤表:将当前workflow的步骤序号和步骤内容存放""" """步骤表:将当前workflow的步骤序号和步骤内容存放"""
self.consciousness_node = consciousness_node self.consciousness_node = consciousness_node
"""意识节点""" """意识节点"""
@@ -70,7 +66,7 @@ class WorkflowEngine:
"""执行与 push sse 相关的核心业务流转操作。 """执行与 push sse 相关的核心业务流转操作。
该方法封装了具体的算法策略或状态控制逻辑确保操作能够在事务上下文中被原子且一致地执行 该方法封装了具体的算法策略或状态控制逻辑确保操作能够在事务上下文中被原子且一致地执行
Args: msg (str): 控制逻辑流向的具体字符串参数指定了期望的 msg 内容 Args: msg (str): 控制逻辑流向的具体字符串参数指定了期望的 msg 内容
Returns: (None): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象 """ Returns: (None): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象"""
try: try:
await self._gwm.put_pending.remote(self.workflow.trace_id, msg) await self._gwm.put_pending.remote(self.workflow.trace_id, msg)
except Exception: except Exception:
@@ -95,37 +91,55 @@ class WorkflowEngine:
async def run(self): async def run(self):
""" """
run方法 run方法
处理并执行workflow的方法 处理并执行workflow的方法
""" """
self.logger.info(f"🚀 工作流引擎启动: {self.workflow.title} [Trace ID: {self.workflow.trace_id}]") self.logger.info(
f"🚀 工作流引擎启动: {self.workflow.title} [Trace ID: {self.workflow.trace_id}]"
)
await self._push_sse(f"[工作流启动] {self.workflow.title}") await self._push_sse(f"[工作流启动] {self.workflow.title}")
max_step = len(self.workflow.work_link) max_step = len(self.workflow.work_link)
while 1 <= self.workflow.status.step <= max_step: while 1 <= self.workflow.status.step <= max_step:
current_step_id = self.workflow.status.step current_step_id = self.workflow.status.step
current_step = self._steps_by_id.get(current_step_id) current_step = self._steps_by_id.get(current_step_id)
if not current_step: if not current_step:
self.logger.error(f"严重错误:找不到步骤 {current_step_id},工作流强制终止。") self.logger.error(
f"严重错误:找不到步骤 {current_step_id},工作流强制终止。"
)
self.workflow.status.status = "failed" self.workflow.status.status = "failed"
await self._push_sse(f"[工作流失败] 找不到步骤 {current_step_id}") await self._push_sse(f"[工作流失败] 找不到步骤 {current_step_id}")
break break
self.logger.info(f"▶️ 开始执行 Step {current_step_id}: [{current_step.node}] -> {current_step.action}") self.logger.info(
f"▶️ 开始执行 Step {current_step_id}: [{current_step.node}] -> {current_step.action}"
)
current_step.status = "running" current_step.status = "running"
await self._push_sse(f"[Step {current_step_id}] {current_step.name}: {current_step.desc}") await self._push_sse(
f"[Step {current_step_id}] {current_step.name}: {current_step.desc}"
)
try: try:
step_input_data = self._prepare_inputs(current_step.inputs) step_input_data = self._prepare_inputs(current_step.inputs)
step_result, is_success = await self._dispatch_to_node(current_step, step_input_data) step_result, is_success = await self._dispatch_to_node(
current_step, step_input_data
)
if is_success: if is_success:
if current_step.outputs: if current_step.outputs:
self.workflow.context_memory[current_step.outputs] = step_result self.workflow.context_memory[current_step.outputs] = step_result
self.logger.debug(f"Step {current_step_id} 产出已保存至变量: '{current_step.outputs}'") self.logger.debug(
f"Step {current_step_id} 产出已保存至变量: '{current_step.outputs}'"
)
current_step.status = "completed" current_step.status = "completed"
await self._push_sse(f"[Step {current_step_id} 完成] {current_step.name}") await self._push_sse(
f"[Step {current_step_id} 完成] {current_step.name}"
)
else: else:
self.logger.warning(f"Step {current_step_id} 执行遇到业务失败/驳回。") self.logger.warning(
f"Step {current_step_id} 执行遇到业务失败/驳回。"
)
current_step.status = "failed" current_step.status = "failed"
await self._push_sse(f"[Step {current_step_id} 失败] {current_step.name}") await self._push_sse(
f"[Step {current_step_id} 失败] {current_step.name}"
)
self._handle_logic_gate(current_step, is_success) self._handle_logic_gate(current_step, is_success)
except WorkflowExit: except WorkflowExit:
self.logger.info("命中 if_pass='exit',工作流被主动要求结束。") self.logger.info("命中 if_pass='exit',工作流被主动要求结束。")
@@ -137,7 +151,10 @@ class WorkflowEngine:
await self._push_sse(f"[工作流失败] {e}") await self._push_sse(f"[工作流失败] {e}")
break break
except Exception as e: except Exception as e:
self.logger.error(f"❌ Step {current_step_id} 发生系统级未捕获异常: {e}", exc_info=True) self.logger.error(
f"❌ Step {current_step_id} 发生系统级未捕获异常: {e}",
exc_info=True,
)
current_step.status = "failed" current_step.status = "failed"
self.workflow.status.status = "failed" self.workflow.status.status = "failed"
await self._push_sse(f"[工作流异常] {e}") await self._push_sse(f"[工作流异常] {e}")
@@ -163,9 +180,11 @@ class WorkflowEngine:
if self.consciousness_node: if self.consciousness_node:
supervisory_input = ForSupervisoryInput( supervisory_input = ForSupervisoryInput(
workflow=self.workflow, workflow=self.workflow,
original_command=self.workflow.command or "未知命令" original_command=self.workflow.command or "未知命令",
)
report_obj = await self.consciousness_node.working.remote(
supervisory_input
) )
report_obj = await self.consciousness_node.working.remote(supervisory_input)
if isinstance(report_obj, ForSupervisoryNode): if isinstance(report_obj, ForSupervisoryNode):
report = report_obj.output report = report_obj.output
elif isinstance(report_obj, str): elif isinstance(report_obj, str):
@@ -178,7 +197,7 @@ class WorkflowEngine:
term_msg = TerminationMessage( term_msg = TerminationMessage(
platform=self.workflow.event_info.platform, platform=self.workflow.event_info.platform,
user_name=self.workflow.event_info.user_name, user_name=self.workflow.event_info.user_name,
message=f"工作流执行完毕。系统报告:{report}" message=f"工作流执行完毕。系统报告:{report}",
) )
user_response = await self.supervisory_node.working.remote(term_msg) user_response = await self.supervisory_node.working.remote(term_msg)
self.workflow.context_memory["_final_user_response"] = user_response self.workflow.context_memory["_final_user_response"] = user_response
@@ -188,7 +207,9 @@ class WorkflowEngine:
except Exception: except Exception:
self.logger.exception("生成工作流执行汇报时发生错误") self.logger.exception("生成工作流执行汇报时发生错误")
async def _dispatch_to_node(self, step: WorkStep, input_data: Any) -> tuple[Any, bool]: async def _dispatch_to_node(
self, step: WorkStep, input_data: Any
) -> tuple[Any, bool]:
""" """
分流器 分流器
调用当前step的执行对象 调用当前step的执行对象
@@ -216,8 +237,7 @@ class WorkflowEngine:
raise WorkflowError("未提供 consciousness_node 句柄!") raise WorkflowError("未提供 consciousness_node 句柄!")
original_cmd = self.workflow.command or "" original_cmd = self.workflow.command or ""
payload = ConsciousnessForWorkflowInput( payload = ConsciousnessForWorkflowInput(
workflow_step=step, workflow_step=step, original_command=original_cmd
original_command=original_cmd
) )
result_obj = await self.consciousness_node.working.remote(payload) result_obj = await self.consciousness_node.working.remote(payload)
if isinstance(result_obj, ConsciousnessForWorkflow): if isinstance(result_obj, ConsciousnessForWorkflow):
@@ -225,9 +245,12 @@ class WorkflowEngine:
return result_obj, True return result_obj, True
elif step.node == "skill_individual": elif step.node == "skill_individual":
self.logger.info(f"正在通过 WorkerCluster 调度 skill_individual 执行 {step.action}") self.logger.info(
f"正在通过 WorkerCluster 调度 skill_individual 执行 {step.action}"
)
try: try:
from pretor.utils.ray_hook import ray_actor_hook from pretor.utils.ray_hook import ray_actor_hook
worker_cluster = ray_actor_hook("worker_cluster").worker_cluster worker_cluster = ray_actor_hook("worker_cluster").worker_cluster
task_id = f"{self.workflow.trace_id}_step_{step.step}" task_id = f"{self.workflow.trace_id}_step_{step.step}"
agent_id = step.agent_id or f"default_{step.node}" agent_id = step.agent_id or f"default_{step.node}"
@@ -235,18 +258,24 @@ class WorkflowEngine:
"action": step.action, "action": step.action,
"description": step.desc, "description": step.desc,
"input_data": input_data, "input_data": input_data,
"context_memory": self.workflow.context_memory "context_memory": self.workflow.context_memory,
} }
result_response = await worker_cluster.submit_task.remote(task_id, agent_id, task_event) result_response = await worker_cluster.submit_task.remote(
task_id, agent_id, task_event
)
if result_response.get("success"): if result_response.get("success"):
return result_response.get("data"), True return result_response.get("data"), True
else: else:
self.logger.error(f"WorkerCluster 执行 {step.node} 失败: {result_response.get('error')}") self.logger.error(
f"WorkerCluster 执行 {step.node} 失败: {result_response.get('error')}"
)
return result_response.get("error"), False return result_response.get("error"), False
except Exception as e: except Exception as e:
self.logger.exception(f"调度 WorkerCluster 执行 {step.node} 时发生异常: {e}") self.logger.exception(
f"调度 WorkerCluster 执行 {step.node} 时发生异常: {e}"
)
raise WorkflowError(f"WorkerCluster 调度异常: {e}") raise WorkflowError(f"WorkerCluster 调度异常: {e}")
else: else:
raise WorkflowError(f"未知的节点类型:{step.node}") raise WorkflowError(f"未知的节点类型:{step.node}")
@@ -275,7 +304,9 @@ class WorkflowEngine:
match gate.if_fail.split("_"): match gate.if_fail.split("_"):
case ["jump", "to", "step", target] if target.isdigit(): case ["jump", "to", "step", target] if target.isdigit():
target_step = int(target) target_step = int(target)
self.logger.warning(f"触发逻辑门分支!从 Step {step.step} 跳转至 Step {target_step}") self.logger.warning(
f"触发逻辑门分支!从 Step {step.step} 跳转至 Step {target_step}"
)
self.workflow.status.step = target_step self.workflow.status.step = target_step
case _: case _:
raise WorkflowError(f"未知的 if_fail 格式: {gate.if_fail}") raise WorkflowError(f"未知的 if_fail 格式: {gate.if_fail}")
@@ -284,10 +315,14 @@ class WorkflowEngine:
@ray.remote @ray.remote
class WorkflowRunningEngine: class WorkflowRunningEngine:
"""WorkflowRunningEngine 核心组件类。 """WorkflowRunningEngine 核心组件类。
这是一个领域数据模型或功能封装类承载了 WorkflowRunningEngine 相关的内聚属性定义与状态维护它的存在隔离了局部的业务复杂性并对外提供了类型安全的访问接口 """ 这是一个领域数据模型或功能封装类承载了 WorkflowRunningEngine 相关的内聚属性定义与状态维护它的存在隔离了局部的业务复杂性并对外提供了类型安全的访问接口"""
def __init__(self, consciousness_node=None, control_node=None, supervisory_node=None):
def __init__(
self, consciousness_node=None, control_node=None, supervisory_node=None
):
from pretor.utils.logger import get_logger from pretor.utils.logger import get_logger
self.logger = get_logger('workflow_runner')
self.logger = get_logger("workflow_runner")
self.runner_engine = {} self.runner_engine = {}
self.workflow_queue: asyncio.Queue[PretorEvent] = None self.workflow_queue: asyncio.Queue[PretorEvent] = None
self.consciousness_node = consciousness_node self.consciousness_node = consciousness_node
@@ -298,28 +333,31 @@ class WorkflowRunningEngine:
async def run(self): async def run(self):
# Move actor hook to async start so we don't race during __init__ across cluster # Move actor hook to async start so we don't race during __init__ across cluster
"""执行与 run 相关的核心业务流转操作。 """执行与 run 相关的核心业务流转操作。
该方法封装了具体的算法策略或状态控制逻辑确保操作能够在事务上下文中被原子且一致地执行 """ 该方法封装了具体的算法策略或状态控制逻辑确保操作能够在事务上下文中被原子且一致地执行"""
self.global_state_machine = ray_actor_hook("global_state_machine").global_state_machine self.global_state_machine = ray_actor_hook(
"global_state_machine"
).global_state_machine
self.workflow_queue = asyncio.Queue() self.workflow_queue = asyncio.Queue()
self.runner_engine = { self.runner_engine = {
f"runner_{i}": asyncio.create_task(self.runner(i)) f"runner_{i}": asyncio.create_task(self.runner(i)) for i in range(10)
for i in range(10)
} }
async def put_event(self, event: PretorEvent) -> None: async def put_event(self, event: PretorEvent) -> None:
"""执行与 put event 相关的核心业务流转操作。 """执行与 put event 相关的核心业务流转操作。
该方法封装了具体的算法策略或状态控制逻辑确保操作能够在事务上下文中被原子且一致地执行 该方法封装了具体的算法策略或状态控制逻辑确保操作能够在事务上下文中被原子且一致地执行
Args: event (PretorEvent): 由事件总线或工作流引擎分发过来的事件载荷封装了触发此次调用的上下文快照与任务目标指令 Args: event (PretorEvent): 由事件总线或工作流引擎分发过来的事件载荷封装了触发此次调用的上下文快照与任务目标指令
Returns: (None): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象 """ Returns: (None): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象"""
await self.workflow_queue.put(event) await self.workflow_queue.put(event)
async def resume_workflow(self, event: PretorEvent) -> None: async def resume_workflow(self, event: PretorEvent) -> None:
"""Resume an incomplete workflow that was loaded from the database.""" """Resume an incomplete workflow that was loaded from the database."""
self.logger.info(f"Resuming workflow {event.trace_id}") self.logger.info(f"Resuming workflow {event.trace_id}")
workflow_engine = WorkflowEngine(event.workflow, workflow_engine = WorkflowEngine(
self.consciousness_node, event.workflow,
self.control_node, self.consciousness_node,
self.supervisory_node) self.control_node,
self.supervisory_node,
)
# Assuming you want to schedule it via a task # Assuming you want to schedule it via a task
asyncio.create_task(workflow_engine.run()) asyncio.create_task(workflow_engine.run())
@@ -333,33 +371,37 @@ class WorkflowRunningEngine:
while True: while True:
try: try:
event = await self.workflow_queue.get() event = await self.workflow_queue.get()
self.logger.info(f"WorkflowRunningEngine: runner_{i} 接收到事件 {event.trace_id} 准备生成工作流。") self.logger.info(
f"WorkflowRunningEngine: runner_{i} 接收到事件 {event.trace_id} 准备生成工作流。"
)
if not self.consciousness_node: if not self.consciousness_node:
raise WorkflowError("未配置 consciousness_node,无法生成工作流") raise WorkflowError("未配置 consciousness_node,无法生成工作流")
workflow_template_name = event.context.get("workflow_template", "")
workflow_template = get_workflow_template(workflow_template_name) if workflow_template_name else None
available_skills = None available_skills = None
if self.global_state_machine: if self.global_state_machine:
try: try:
all_individuals = await self.global_state_machine.list_individuals.remote() all_individuals = (
await self.global_state_machine.list_individuals.remote()
)
available_skills = [] available_skills = []
for agent_id, config in all_individuals.items(): for agent_id, config in all_individuals.items():
if config.get("agent_type") == "skill_individual" or config.get("type") == "skill_individual": if (
available_skills.append({ config.get("agent_type") == "skill_individual"
"agent_id": agent_id, or config.get("type") == "skill_individual"
"name": config.get("agent_name", "Unknown"), ):
"description": config.get("description", "") available_skills.append(
}) {
"agent_id": agent_id,
"name": config.get("agent_name", "Unknown"),
"description": config.get("description", ""),
}
)
except Exception as e: except Exception as e:
self.logger.warning(f"获取Skill Individual列表失败: {e}") self.logger.warning(f"获取Skill Individual列表失败: {e}")
payload = ForWorkflowEngineInput( payload = ForWorkflowEngineInput(
original_command=event.message, original_command=event.message, available_skills=available_skills
workflow_template=workflow_template,
available_skills=available_skills
) )
result_obj = await self.consciousness_node.working.remote(payload) result_obj = await self.consciousness_node.working.remote(payload)
@@ -369,25 +411,39 @@ class WorkflowRunningEngine:
workflow.trace_id = event.trace_id workflow.trace_id = event.trace_id
workflow.command = event.message workflow.command = event.message
workflow.event_info = EventInfo(platform=event.platform, workflow.event_info = EventInfo(
user_name=event.user_name,) platform=event.platform,
user_name=event.user_name,
)
self.logger.info( self.logger.info(
f"WorkflowRunningEngine: runner_{i} 成功生成工作流 {workflow.trace_id}:{workflow.title}") f"WorkflowRunningEngine: runner_{i} 成功生成工作流 {workflow.trace_id}:{workflow.title}"
)
global_workflow_manager = ray_actor_hook("global_workflow_manager").global_workflow_manager global_workflow_manager = ray_actor_hook(
await global_workflow_manager.update_workflow.remote(event.trace_id, workflow) "global_workflow_manager"
).global_workflow_manager
await global_workflow_manager.update_workflow.remote(
event.trace_id, workflow
)
workflow_engine = WorkflowEngine(workflow, workflow_engine = WorkflowEngine(
self.consciousness_node, workflow,
self.control_node, self.consciousness_node,
self.supervisory_node) self.control_node,
self.supervisory_node,
)
await workflow_engine.run() await workflow_engine.run()
else: else:
self.logger.error(f"WorkflowRunningEngine: runner_{i} 无法生成工作流,返回类型为 {type(result_obj)}") self.logger.error(
f"WorkflowRunningEngine: runner_{i} 无法生成工作流,返回类型为 {type(result_obj)}"
)
except asyncio.CancelledError: except asyncio.CancelledError:
self.logger.info(f"WorkflowRunningEngine: runner_{i} 被取消。") self.logger.info(f"WorkflowRunningEngine: runner_{i} 被取消。")
raise raise
except Exception as e: except Exception as e:
self.logger.error(f"WorkflowRunningEngine: runner_{i} 遇到未捕获的异常: {e}", exc_info=True) self.logger.error(
f"WorkflowRunningEngine: runner_{i} 遇到未捕获的异常: {e}",
exc_info=True,
)
+1 -1
View File
@@ -10,4 +10,4 @@
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
@@ -13,4 +13,5 @@
# limitations under the License. # limitations under the License.
from .approval import ApprovalToolData, approval from .approval import ApprovalToolData, approval
__all__ = ["ApprovalToolData", "approval"] __all__ = ["ApprovalToolData", "approval"]
+14 -4
View File
@@ -16,12 +16,22 @@ from pretor.plugin.tool_plugin.base_tool import BaseToolData
from pretor.utils.ray_hook import ray_actor_hook from pretor.utils.ray_hook import ray_actor_hook
from typing import List, Literal, Dict from typing import List, Literal, Dict
class ApprovalToolData(BaseToolData): class ApprovalToolData(BaseToolData):
"""ApprovalToolData 核心组件类。 """ApprovalToolData 核心组件类。
这是一个可被智能体动态调用的外部工具组件类。它定义了清晰的输入参数 Schema 与执行契约,赋予智能体与外界真实系统(如文件、网页、API)进行交互的能力。 """ 这是一个可被智能体动态调用的外部工具组件类。它定义了清晰的输入参数 Schema 与执行契约,赋予智能体与外界真实系统(如文件、网页、API)进行交互的能力。"""
is_system: bool = True is_system: bool = True
action_scope: List[Literal["control_node", "consciousness_node", "supervisory_node", "growth_node", "", ""]] = [ action_scope: List[
"control_node", "consciousness_node"] Literal[
"control_node",
"consciousness_node",
"supervisory_node",
"growth_node",
"",
"",
]
] = ["control_node", "consciousness_node"]
config_args: Dict[str, str] = {} config_args: Dict[str, str] = {}
@@ -38,4 +48,4 @@ async def approval(message: str, trace_id: str) -> str:
actor_list = ray_actor_hook("global_state_machine") actor_list = ray_actor_hook("global_state_machine")
await actor_list.global_state_machine.put_pending.remote(trace_id, message) await actor_list.global_state_machine.put_pending.remote(trace_id, message)
reply = await actor_list.global_state_machine.get_received.remote(trace_id) reply = await actor_list.global_state_machine.get_received.remote(trace_id)
return reply return reply
+14 -3
View File
@@ -16,10 +16,21 @@ from pydantic import BaseModel
from typing import List, Literal, Dict from typing import List, Literal, Dict
from pydantic import ConfigDict from pydantic import ConfigDict
class BaseToolData(BaseModel): class BaseToolData(BaseModel):
"""BaseToolData 核心组件类。 """BaseToolData 核心组件类。
这是一个可被智能体动态调用的外部工具组件类。它定义了清晰的输入参数 Schema 与执行契约,赋予智能体与外界真实系统(如文件、网页、API)进行交互的能力。 """ 这是一个可被智能体动态调用的外部工具组件类。它定义了清晰的输入参数 Schema 与执行契约,赋予智能体与外界真实系统(如文件、网页、API)进行交互的能力。"""
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra="allow")
is_system: bool is_system: bool
action_scope: List[Literal["control_node", "consciousness_node", "supervisory_node", "growth_node", "", ""]] = [] action_scope: List[
config_args: Dict[str, str] = {} Literal[
"control_node",
"consciousness_node",
"supervisory_node",
"growth_node",
"",
"",
]
] = []
config_args: Dict[str, str] = {}
@@ -16,13 +16,16 @@ from pydantic_ai import RunContext
from pretor.plugin.tool_plugin.base_tool import BaseToolData from pretor.plugin.tool_plugin.base_tool import BaseToolData
import os import os
class FileReaderData(BaseToolData): class FileReaderData(BaseToolData):
"""FileReaderData 核心组件类。 """FileReaderData 核心组件类。
这是一个领域数据模型或功能封装类,承载了 FileReaderData 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。 """ 这是一个领域数据模型或功能封装类,承载了 FileReaderData 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。"""
is_system: bool = True is_system: bool = True
name: str = "file_reader" name: str = "file_reader"
description: str = "读取本地文件的内容" description: str = "读取本地文件的内容"
def file_reader(ctx: RunContext, filepath: str) -> str: def file_reader(ctx: RunContext, filepath: str) -> str:
"""读取本地文件内容的工具。 """读取本地文件内容的工具。
@@ -38,7 +41,7 @@ def file_reader(ctx: RunContext, filepath: str) -> str:
return f"Error: {filepath} 不是一个文件。" return f"Error: {filepath} 不是一个文件。"
try: try:
with open(filepath, 'r', encoding='utf-8') as f: with open(filepath, "r", encoding="utf-8") as f:
content = f.read() content = f.read()
return content return content
except Exception as e: except Exception as e:
+17 -19
View File
@@ -24,11 +24,13 @@ from pwdlib import PasswordHash
class TokenData(BaseModel): class TokenData(BaseModel):
"""TokenData 核心组件类。 """TokenData 核心组件类。
这是一个领域数据模型或功能封装类,承载了 TokenData 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。 """ 这是一个领域数据模型或功能封装类,承载了 TokenData 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。"""
user_id: str user_id: str
username: Optional[str] = None username: Optional[str] = None
exp: Optional[int] = None exp: Optional[int] = None
SECRET_KEY = os.getenv("SECRET_KEY") SECRET_KEY = os.getenv("SECRET_KEY")
ALGORITHM = "HS256" ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 60 * 24 ACCESS_TOKEN_EXPIRE_MINUTES = 60 * 24
@@ -41,19 +43,16 @@ password_hasher = PasswordHash.recommended()
class Accessor: class Accessor:
"""Accessor 核心组件类。 """Accessor 核心组件类。
这是一个领域数据模型或功能封装类,承载了 Accessor 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。 """ 这是一个领域数据模型或功能封装类,承载了 Accessor 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。"""
@staticmethod @staticmethod
def _decode_token(token: str) -> TokenData: def _decode_token(token: str) -> TokenData:
"""执行与 decode token 相关的核心业务流转操作。 """执行与 decode token 相关的核心业务流转操作。
该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。 该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。
Args: token (str): 由认证中心颁发的 JWT 或长期访问令牌,用于跨服务调用时的身份自证与权限校验。 Args: token (str): 由认证中心颁发的 JWT 或长期访问令牌,用于跨服务调用时的身份自证与权限校验。
Returns: (TokenData): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ Returns: (TokenData): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。"""
try: try:
payload = jwt.decode( payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
token,
SECRET_KEY,
algorithms=[ALGORITHM]
)
return TokenData(**payload) return TokenData(**payload)
except jwt.ExpiredSignatureError: except jwt.ExpiredSignatureError:
raise HTTPException( raise HTTPException(
@@ -71,9 +70,11 @@ class Accessor:
"""创建并持久化新的 access token 实体。 """创建并持久化新的 access token 实体。
接收构建参数,执行必要的数据校验与默认值填充后,将新记录安全地写入底层存储或系统注册表中。 接收构建参数,执行必要的数据校验与默认值填充后,将新记录安全地写入底层存储或系统注册表中。
Args: data (dict): 从客户端传递过来或由上游组件生成的核心业务数据体,通常需要进一步的清洗和结构化解析。 Args: data (dict): 从客户端传递过来或由上游组件生成的核心业务数据体,通常需要进一步的清洗和结构化解析。
Returns: (str): 处理流程所输出的具体字符串产物,可能是新生成的 ID 序列、格式化好的文本片段或 LLM 推理的回答内容。 """ Returns: (str): 处理流程所输出的具体字符串产物,可能是新生成的 ID 序列、格式化好的文本片段或 LLM 推理的回答内容。"""
to_encode = data.copy() to_encode = data.copy()
expire = datetime.now(timezone.utc) + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) expire = datetime.now(timezone.utc) + timedelta(
minutes=ACCESS_TOKEN_EXPIRE_MINUTES
)
to_encode.update({"exp": int(expire.timestamp())}) to_encode.update({"exp": int(expire.timestamp())})
return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
@@ -82,7 +83,7 @@ class Accessor:
"""执行与 verify password 相关的核心业务流转操作。 """执行与 verify password 相关的核心业务流转操作。
该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。 该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。
Args: plain_password (str): 控制逻辑流向的具体字符串参数,指定了期望的 plain password 内容。 hashed_password (str): 控制逻辑流向的具体字符串参数,指定了期望的 hashed password 内容。 Args: plain_password (str): 控制逻辑流向的具体字符串参数,指定了期望的 plain password 内容。 hashed_password (str): 控制逻辑流向的具体字符串参数,指定了期望的 hashed password 内容。
Returns: (bool): 一个布尔型结果标志,明确返回 True 表示该操作成功应用或条件达成,False 则表示失败或被拒绝。 """ Returns: (bool): 一个布尔型结果标志,明确返回 True 表示该操作成功应用或条件达成,False 则表示失败或被拒绝。"""
return password_hasher.verify(plain_password, hashed_password) return password_hasher.verify(plain_password, hashed_password)
@staticmethod @staticmethod
@@ -90,7 +91,7 @@ class Accessor:
"""检索并获取特定的 current user 数据集合或实例对象。 """检索并获取特定的 current user 数据集合或实例对象。
根据提供的查询条件或上下文凭证,从数据库、缓存或第三方服务中读取对应的资源状态。 根据提供的查询条件或上下文凭证,从数据库、缓存或第三方服务中读取对应的资源状态。
Args: request (Request): FastAPI 框架注入的原生 HTTP 请求对象,包含了完整的 Header 标头、查询参数和正文流。 Args: request (Request): FastAPI 框架注入的原生 HTTP 请求对象,包含了完整的 Header 标头、查询参数和正文流。
Returns: (TokenData): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ Returns: (TokenData): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。"""
auth_header = request.headers.get("Authorization") auth_header = request.headers.get("Authorization")
if not auth_header or not auth_header.startswith("Bearer "): if not auth_header or not auth_header.startswith("Bearer "):
raise HTTPException( raise HTTPException(
@@ -105,7 +106,7 @@ class Accessor:
"""执行与 login hashed password 相关的核心业务流转操作。 """执行与 login hashed password 相关的核心业务流转操作。
该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。 该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。
Args: user (User): 当前已通过鉴权流程的访问者实体对象,内部包含用户角色、权限层级及租户归属等核心元信息。 password (str): 控制逻辑流向的具体字符串参数,指定了期望的 password 内容。 Args: user (User): 当前已通过鉴权流程的访问者实体对象,内部包含用户角色、权限层级及租户归属等核心元信息。 password (str): 控制逻辑流向的具体字符串参数,指定了期望的 password 内容。
Returns: (str): 处理流程所输出的具体字符串产物,可能是新生成的 ID 序列、格式化好的文本片段或 LLM 推理的回答内容。 """ Returns: (str): 处理流程所输出的具体字符串产物,可能是新生成的 ID 序列、格式化好的文本片段或 LLM 推理的回答内容。"""
if not user: if not user:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
@@ -116,10 +117,7 @@ class Accessor:
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail="用户名或密码错误", detail="用户名或密码错误",
) )
token_payload = { token_payload = {"user_id": str(user.user_id), "username": user.user_name}
"user_id": str(user.user_id),
"username": user.user_name
}
return Accessor._create_access_token(data=token_payload) return Accessor._create_access_token(data=token_payload)
@staticmethod @staticmethod
@@ -127,9 +125,9 @@ class Accessor:
"""执行与 hash password 相关的核心业务流转操作。 """执行与 hash password 相关的核心业务流转操作。
该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。 该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。
Args: password (str): 控制逻辑流向的具体字符串参数,指定了期望的 password 内容。 Args: password (str): 控制逻辑流向的具体字符串参数,指定了期望的 password 内容。
Returns: (str): 处理流程所输出的具体字符串产物,可能是新生成的 ID 序列、格式化好的文本片段或 LLM 推理的回答内容。 """ Returns: (str): 处理流程所输出的具体字符串产物,可能是新生成的 ID 序列、格式化好的文本片段或 LLM 推理的回答内容。"""
if not password: if not password:
raise ValueError("密码不能为空") raise ValueError("密码不能为空")
if len(password) < 6: if len(password) < 6:
raise ValueError("密码长度不能小于 6 位") raise ValueError("密码长度不能小于 6 位")
return password_hasher.hash(password) return password_hasher.hash(password)
+10 -4
View File
@@ -15,17 +15,23 @@
from pydantic import BaseModel from pydantic import BaseModel
class ResponseModel(BaseModel): class ResponseModel(BaseModel):
"""ResponseModel 核心组件类。 """ResponseModel 核心组件类。
这是一个领域数据模型或功能封装类,承载了 ResponseModel 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。 """ 这是一个领域数据模型或功能封装类,承载了 ResponseModel 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。"""
pass pass
class DepsModel(BaseModel): class DepsModel(BaseModel):
"""DepsModel 核心组件类。 """DepsModel 核心组件类。
这是一个领域数据模型或功能封装类,承载了 DepsModel 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。 """ 这是一个领域数据模型或功能封装类,承载了 DepsModel 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。"""
pass pass
class InputModel(BaseModel): class InputModel(BaseModel):
"""InputModel 核心组件类。 """InputModel 核心组件类。
这是一个领域数据模型或功能封装类,承载了 InputModel 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。 """ 这是一个领域数据模型或功能封装类,承载了 InputModel 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。"""
pass
pass
+4 -2
View File
@@ -15,11 +15,13 @@
from rich.console import Console from rich.console import Console
from rich.text import Text from rich.text import Text
import yaml import yaml
def print_banner() -> None: def print_banner() -> None:
"""执行与 print banner 相关的核心业务流转操作。 """执行与 print banner 相关的核心业务流转操作。
该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。 该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。
Returns: (None): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ Returns: (None): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。"""
with open("config/config.yml","r") as config: with open("config/config.yml", "r") as config:
config = yaml.load(config, Loader=yaml.FullLoader) config = yaml.load(config, Loader=yaml.FullLoader)
version = config.get("version", "unknown") version = config.get("version", "unknown")
pretor_banner = """ pretor_banner = """
+22 -16
View File
@@ -17,47 +17,53 @@ from pretor.utils.access import Accessor, TokenData
from pretor.core.database.table.user import UserAuthority from pretor.core.database.table.user import UserAuthority
from pretor.utils.ray_hook import ray_actor_hook from pretor.utils.ray_hook import ray_actor_hook
async def get_authority(user_id: str) -> UserAuthority: async def get_authority(user_id: str) -> UserAuthority:
"""检索并获取特定的 authority 数据集合或实例对象。 """检索并获取特定的 authority 数据集合或实例对象。
根据提供的查询条件或上下文凭证,从数据库、缓存或第三方服务中读取对应的资源状态。 根据提供的查询条件或上下文凭证,从数据库、缓存或第三方服务中读取对应的资源状态。
Args: user_id (str): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 user 实例。 Args: user_id (str): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 user 实例。
Returns: (UserAuthority): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ Returns: (UserAuthority): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。"""
from pretor.utils.error import UserNotExistError from pretor.utils.error import UserNotExistError
postgres_database = ray_actor_hook("postgres_database").postgres_database postgres_database = ray_actor_hook("postgres_database").postgres_database
try: try:
user_authority = await postgres_database.get_user_authority.remote(user_id=user_id) user_authority = await postgres_database.get_user_authority.remote(
user_id=user_id
)
return user_authority return user_authority
except UserNotExistError: except UserNotExistError:
raise HTTPException( raise HTTPException(status_code=401, detail="用户不存在或已被删除,请重新登录")
status_code=401,
detail="用户不存在或已被删除,请重新登录"
)
except Exception as e: except Exception as e:
# Check if it's a RayTaskError wrapping UserNotExistError # Check if it's a RayTaskError wrapping UserNotExistError
if "UserNotExistError" in str(e): if "UserNotExistError" in str(e):
raise HTTPException( raise HTTPException(
status_code=401, status_code=401, detail="用户不存在或已被删除,请重新登录"
detail="用户不存在或已被删除,请重新登录"
) )
raise raise
class RoleChecker: class RoleChecker:
"""RoleChecker 核心组件类。 """RoleChecker 核心组件类。
这是一个领域数据模型或功能封装类,承载了 RoleChecker 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。 """ 这是一个领域数据模型或功能封装类,承载了 RoleChecker 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。"""
def __init__(self, **kwargs):
self.allowed_roles = kwargs.get("allowed_roles", )
async def __call__(self, def __init__(self, **kwargs):
token_data: Annotated[TokenData, Depends(Accessor.get_current_user)]): self.allowed_roles = kwargs.get(
"allowed_roles",
)
async def __call__(
self, token_data: Annotated[TokenData, Depends(Accessor.get_current_user)]
):
"""执行与 call 相关的核心业务流转操作。 """执行与 call 相关的核心业务流转操作。
该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。 该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。
Args: token_data (Annotated[TokenData, Depends(Accessor.get_current_user)]): 从客户端传递过来或由上游组件生成的核心业务数据体,通常需要进一步的清洗和结构化解析。 Args: token_data (Annotated[TokenData, Depends(Accessor.get_current_user)]): 从客户端传递过来或由上游组件生成的核心业务数据体,通常需要进一步的清洗和结构化解析。
Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。"""
user_authority = await get_authority(token_data.user_id) user_authority = await get_authority(token_data.user_id)
if user_authority < self.allowed_roles: if user_authority < self.allowed_roles:
raise HTTPException( raise HTTPException(
status_code=403, status_code=403,
detail={"message": f"User {token_data.user_id} does not have allowed roles"}, detail={
"message": f"User {token_data.user_id} does not have allowed roles"
},
) )
return token_data return token_data
+31 -11
View File
@@ -12,57 +12,77 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
class RetryableError(Exception): class RetryableError(Exception):
"""基类:所有可重试错误(如网络断开、抖动等临时性故障)""" """基类:所有可重试错误(如网络断开、抖动等临时性故障)"""
pass pass
class NonRetryableError(Exception): class NonRetryableError(Exception):
"""基类:所有不可重试错误(如数据验证失败、类型错误等业务逻辑故障)""" """基类:所有不可重试错误(如数据验证失败、类型错误等业务逻辑故障)"""
pass pass
class DemandError(NonRetryableError): class DemandError(NonRetryableError):
"""DemandError 核心组件类。 """DemandError 核心组件类。
这是一个自定义异常类,专门用于在 Demand 相关业务流程中触发中断。它携带了精确的错误上下文与追溯代码,帮助最外层网关能够统一捕获并返回友好的前端错误提示。 """ 这是一个自定义异常类,专门用于在 Demand 相关业务流程中触发中断。它携带了精确的错误上下文与追溯代码,帮助最外层网关能够统一捕获并返回友好的前端错误提示。"""
pass pass
class ModelNotExistError(Exception): class ModelNotExistError(Exception):
"""ModelNotExistError 核心组件类。 """ModelNotExistError 核心组件类。
这是一个自定义异常类,专门用于在 ModelNotExist 相关业务流程中触发中断。它携带了精确的错误上下文与追溯代码,帮助最外层网关能够统一捕获并返回友好的前端错误提示。 """ 这是一个自定义异常类,专门用于在 ModelNotExist 相关业务流程中触发中断。它携带了精确的错误上下文与追溯代码,帮助最外层网关能够统一捕获并返回友好的前端错误提示。"""
pass pass
class UserError(Exception): class UserError(Exception):
"""UserError 核心组件类。 """UserError 核心组件类。
这是一个自定义异常类,专门用于在 User 相关业务流程中触发中断。它携带了精确的错误上下文与追溯代码,帮助最外层网关能够统一捕获并返回友好的前端错误提示。 """ 这是一个自定义异常类,专门用于在 User 相关业务流程中触发中断。它携带了精确的错误上下文与追溯代码,帮助最外层网关能够统一捕获并返回友好的前端错误提示。"""
pass pass
class UserNotExistError(UserError): class UserNotExistError(UserError):
"""UserNotExistError 核心组件类。 """UserNotExistError 核心组件类。
这是一个自定义异常类,专门用于在 UserNotExist 相关业务流程中触发中断。它携带了精确的错误上下文与追溯代码,帮助最外层网关能够统一捕获并返回友好的前端错误提示。 """ 这是一个自定义异常类,专门用于在 UserNotExist 相关业务流程中触发中断。它携带了精确的错误上下文与追溯代码,帮助最外层网关能够统一捕获并返回友好的前端错误提示。"""
pass pass
class UserPasswordError(UserError): class UserPasswordError(UserError):
"""UserPasswordError 核心组件类。 """UserPasswordError 核心组件类。
这是一个自定义异常类,专门用于在 UserPassword 相关业务流程中触发中断。它携带了精确的错误上下文与追溯代码,帮助最外层网关能够统一捕获并返回友好的前端错误提示。 """ 这是一个自定义异常类,专门用于在 UserPassword 相关业务流程中触发中断。它携带了精确的错误上下文与追溯代码,帮助最外层网关能够统一捕获并返回友好的前端错误提示。"""
pass pass
class ProviderError(Exception): class ProviderError(Exception):
"""ProviderError 核心组件类。 """ProviderError 核心组件类。
这是一个模型/服务提供商适配器类,屏蔽了外部不同供应商(如 OpenAI、Anthropic 等)的底层 API 差异。它负责标准化参数组装、网络请求发送、鉴权处理以及响应结构的反序列化。 """ 这是一个模型/服务提供商适配器类,屏蔽了外部不同供应商(如 OpenAI、Anthropic 等)的底层 API 差异。它负责标准化参数组装、网络请求发送、鉴权处理以及响应结构的反序列化。"""
pass pass
class ProviderNotExistError(ProviderError): class ProviderNotExistError(ProviderError):
"""ProviderNotExistError 核心组件类。 """ProviderNotExistError 核心组件类。
这是一个模型/服务提供商适配器类,屏蔽了外部不同供应商(如 OpenAI、Anthropic 等)的底层 API 差异。它负责标准化参数组装、网络请求发送、鉴权处理以及响应结构的反序列化。 """ 这是一个模型/服务提供商适配器类,屏蔽了外部不同供应商(如 OpenAI、Anthropic 等)的底层 API 差异。它负责标准化参数组装、网络请求发送、鉴权处理以及响应结构的反序列化。"""
pass pass
class WorkflowError(Exception): class WorkflowError(Exception):
"""WorkflowError 核心组件类。 """WorkflowError 核心组件类。
这是一个自定义异常类,专门用于在 Workflow 相关业务流程中触发中断。它携带了精确的错误上下文与追溯代码,帮助最外层网关能够统一捕获并返回友好的前端错误提示。 """ 这是一个自定义异常类,专门用于在 Workflow 相关业务流程中触发中断。它携带了精确的错误上下文与追溯代码,帮助最外层网关能够统一捕获并返回友好的前端错误提示。"""
pass pass
class WorkflowExit(WorkflowError): class WorkflowExit(WorkflowError):
"""WorkflowExit 核心组件类。 """WorkflowExit 核心组件类。
这是一个领域数据模型或功能封装类,承载了 WorkflowExit 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。 """ 这是一个领域数据模型或功能封装类,承载了 WorkflowExit 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。"""
pass pass
+14 -7
View File
@@ -18,7 +18,8 @@ import sys
from typing import Callable, Dict, List from typing import Callable, Dict, List
from pretor.utils.logger import get_logger from pretor.utils.logger import get_logger
logger = get_logger('get_tool')
logger = get_logger("get_tool")
_tool_cache: Dict[str, Callable] = {} _tool_cache: Dict[str, Callable] = {}
@@ -26,13 +27,15 @@ def _get_tool_func(tool_name: str) -> Callable | None:
"""检索并获取特定的 tool func 数据集合或实例对象。 """检索并获取特定的 tool func 数据集合或实例对象。
根据提供的查询条件或上下文凭证,从数据库、缓存或第三方服务中读取对应的资源状态。 根据提供的查询条件或上下文凭证,从数据库、缓存或第三方服务中读取对应的资源状态。
Args: tool_name (str): 赋予该实体的人类可读名称或标题字符串,主要用于前端 UI 展示、日志记录或模糊检索。 Args: tool_name (str): 赋予该实体的人类可读名称或标题字符串,主要用于前端 UI 展示、日志记录或模糊检索。
Returns: (Callable | None): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ Returns: (Callable | None): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。"""
func = _tool_cache.get(tool_name, None) func = _tool_cache.get(tool_name, None)
if func: if func:
return func return func
app_root = "/app" app_root = "/app"
tool_plugin_dir = os.path.join(app_root, "pretor", "plugin", "tool_plugin", tool_name) tool_plugin_dir = os.path.join(
app_root, "pretor", "plugin", "tool_plugin", tool_name
)
if not os.path.exists(tool_plugin_dir) or not os.path.isdir(tool_plugin_dir): if not os.path.exists(tool_plugin_dir) or not os.path.isdir(tool_plugin_dir):
logger.error(f"Tool directory not found: {tool_plugin_dir}") logger.error(f"Tool directory not found: {tool_plugin_dir}")
@@ -57,7 +60,9 @@ def _get_tool_func(tool_name: str) -> Callable | None:
func = getattr(module, tool_name, None) func = getattr(module, tool_name, None)
if not callable(func): if not callable(func):
logger.error(f"Tool function '{tool_name}' not found or not callable in {module_name}") logger.error(
f"Tool function '{tool_name}' not found or not callable in {module_name}"
)
return None return None
_tool_cache[tool_name] = func _tool_cache[tool_name] = func
return func return func
@@ -65,19 +70,21 @@ def _get_tool_func(tool_name: str) -> Callable | None:
logger.error(f"Failed to load module {module_name}: {e}") logger.error(f"Failed to load module {module_name}: {e}")
return None return None
def del_tool_cache(tool_name: str) -> None: def del_tool_cache(tool_name: str) -> None:
"""执行与 del tool cache 相关的核心业务流转操作。 """执行与 del tool cache 相关的核心业务流转操作。
该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。 该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。
Args: tool_name (str): 赋予该实体的人类可读名称或标题字符串,主要用于前端 UI 展示、日志记录或模糊检索。 Args: tool_name (str): 赋予该实体的人类可读名称或标题字符串,主要用于前端 UI 展示、日志记录或模糊检索。
Returns: (None): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ Returns: (None): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。"""
if tool_name in _tool_cache: if tool_name in _tool_cache:
del _tool_cache[tool_name] del _tool_cache[tool_name]
def load_tools_from_list(tool_names: List[str] | None) -> List[Callable]: def load_tools_from_list(tool_names: List[str] | None) -> List[Callable]:
"""执行与 load tools from list 相关的核心业务流转操作。 """执行与 load tools from list 相关的核心业务流转操作。
该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。 该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。
Args: tool_names (List[str] | None): 赋予该实体的人类可读名称或标题字符串,主要用于前端 UI 展示、日志记录或模糊检索。 Args: tool_names (List[str] | None): 赋予该实体的人类可读名称或标题字符串,主要用于前端 UI 展示、日志记录或模糊检索。
Returns: (List[Callable]): 经过筛选、排序或分页处理后的实体对象列表集合。 """ Returns: (List[Callable]): 经过筛选、排序或分页处理后的实体对象列表集合。"""
if not tool_names: if not tool_names:
return [] return []
@@ -87,4 +94,4 @@ def load_tools_from_list(tool_names: List[str] | None) -> List[Callable]:
if tool_func: if tool_func:
tool_list.append(tool_func) tool_list.append(tool_func)
return tool_list return tool_list
+14 -5
View File
@@ -16,10 +16,11 @@ from loguru import logger
from rich.logging import RichHandler from rich.logging import RichHandler
from loguru._logger import Logger from loguru._logger import Logger
def setup_logger() -> Logger: def setup_logger() -> Logger:
"""对现有的 setup logger 进行状态更新或属性覆盖。 """对现有的 setup logger 进行状态更新或属性覆盖。
基于增量变更原则,合并最新的配置或数据,并触发相关依赖组件的缓存刷新或事件通知。 基于增量变更原则,合并最新的配置或数据,并触发相关依赖组件的缓存刷新或事件通知。
Returns: (Logger): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ Returns: (Logger): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。"""
logger.remove() logger.remove()
def format_record(record): def format_record(record):
@@ -27,7 +28,7 @@ def setup_logger() -> Logger:
"""执行与 format record 相关的核心业务流转操作。 """执行与 format record 相关的核心业务流转操作。
该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。 该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。
Args: record: 参与 format record 逻辑运算或数据构建的上下文依赖对象。 Args: record: 参与 format record 逻辑运算或数据构建的上下文依赖对象。
Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。"""
actor = record["extra"].get("actor_name", "System") actor = record["extra"].get("actor_name", "System")
trace_id = record["extra"].get("trace_id", "") trace_id = record["extra"].get("trace_id", "")
@@ -37,19 +38,27 @@ def setup_logger() -> Logger:
logger.configure(extra={"actor_name": "System", "trace_id": ""}) logger.configure(extra={"actor_name": "System", "trace_id": ""})
logger.add( logger.add(
RichHandler(rich_tracebacks=True, markup=True, show_time=False, show_level=False, show_path=False), RichHandler(
rich_tracebacks=True,
markup=True,
show_time=False,
show_level=False,
show_path=False,
),
format=format_record, format=format_record,
level="DEBUG", level="DEBUG",
enqueue=True, # 异步记录 enqueue=True, # 异步记录
) )
return logger return logger
global_logger = setup_logger() global_logger = setup_logger()
def get_logger(actor_name: str, trace_id: str = "") -> Logger: def get_logger(actor_name: str, trace_id: str = "") -> Logger:
"""检索并获取特定的 logger 数据集合或实例对象。 """检索并获取特定的 logger 数据集合或实例对象。
根据提供的查询条件或上下文凭证,从数据库、缓存或第三方服务中读取对应的资源状态。 根据提供的查询条件或上下文凭证,从数据库、缓存或第三方服务中读取对应的资源状态。
Args: actor_name (str): 赋予该实体的人类可读名称或标题字符串,主要用于前端 UI 展示、日志记录或模糊检索。 trace_id (str): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 trace 实例。 Args: actor_name (str): 赋予该实体的人类可读名称或标题字符串,主要用于前端 UI 展示、日志记录或模糊检索。 trace_id (str): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 trace 实例。
Returns: (Logger): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ Returns: (Logger): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。"""
return global_logger.bind(actor_name=actor_name, trace_id=trace_id) return global_logger.bind(actor_name=actor_name, trace_id=trace_id)
+4 -2
View File
@@ -17,6 +17,7 @@ from pydantic import BaseModel
T = TypeVar("T", bound=Type[BaseModel]) T = TypeVar("T", bound=Type[BaseModel])
def pickle(cls: T) -> T: def pickle(cls: T) -> T:
""" """
类装饰器pickle 类装饰器pickle
@@ -27,14 +28,15 @@ def pickle(cls: T) -> T:
Returns: Returns:
返回被重写了__reduce__魔术方法的cls类 返回被重写了__reduce__魔术方法的cls类
""" """
def __reduce__(self): def __reduce__(self):
# 1. 序列化:触发 Pydantic-core (Rust) 的极速序列化 # 1. 序列化:触发 Pydantic-core (Rust) 的极速序列化
"""执行与 reduce 相关的核心业务流转操作。 """执行与 reduce 相关的核心业务流转操作。
该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。 该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。
Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。"""
data = self.model_dump_json() data = self.model_dump_json()
# 2. 反序列化:告诉 Pickle 重建时调用 cls.model_validate_json # 2. 反序列化:告诉 Pickle 重建时调用 cls.model_validate_json
return cls.model_validate_json, (data,) return cls.model_validate_json, (data,)
cls.__reduce__ = __reduce__ cls.__reduce__ = __reduce__
return cls return cls
+11 -7
View File
@@ -14,23 +14,25 @@
import ray import ray
from functools import lru_cache from functools import lru_cache
class ActorList: class ActorList:
"""ActorList 核心组件类。 """ActorList 核心组件类。
这是一个领域数据模型或功能封装类,承载了 ActorList 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。 """ 这是一个领域数据模型或功能封装类,承载了 ActorList 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。"""
def __init__(self): def __init__(self):
super().__setattr__('dict', {}) super().__setattr__("dict", {})
def __setattr__(self, key, value): def __setattr__(self, key, value):
"""对现有的 setattr 进行状态更新或属性覆盖。 """对现有的 setattr 进行状态更新或属性覆盖。
基于增量变更原则,合并最新的配置或数据,并触发相关依赖组件的缓存刷新或事件通知。 基于增量变更原则,合并最新的配置或数据,并触发相关依赖组件的缓存刷新或事件通知。
Args: key: 参与 setattr 逻辑运算或数据构建的上下文依赖对象。 value: 参与 setattr 逻辑运算或数据构建的上下文依赖对象。 """ Args: key: 参与 setattr 逻辑运算或数据构建的上下文依赖对象。 value: 参与 setattr 逻辑运算或数据构建的上下文依赖对象。"""
self.dict[key] = value self.dict[key] = value
def __getattr__(self, key): def __getattr__(self, key):
"""检索并获取特定的 getattr 数据集合或实例对象。 """检索并获取特定的 getattr 数据集合或实例对象。
根据提供的查询条件或上下文凭证,从数据库、缓存或第三方服务中读取对应的资源状态。 根据提供的查询条件或上下文凭证,从数据库、缓存或第三方服务中读取对应的资源状态。
Args: key: 参与 getattr 逻辑运算或数据构建的上下文依赖对象。 Args: key: 参与 getattr 逻辑运算或数据构建的上下文依赖对象。
Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。"""
if key in self.dict: if key in self.dict:
return self.dict[key] return self.dict[key]
raise AttributeError(f"ActorList 对象没有属性 '{key}'") raise AttributeError(f"ActorList 对象没有属性 '{key}'")
@@ -38,28 +40,30 @@ class ActorList:
def __delattr__(self, key): def __delattr__(self, key):
"""执行与 delattr 相关的核心业务流转操作。 """执行与 delattr 相关的核心业务流转操作。
该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。 该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。
Args: key: 参与 delattr 逻辑运算或数据构建的上下文依赖对象。 """ Args: key: 参与 delattr 逻辑运算或数据构建的上下文依赖对象。"""
if key in self.dict: if key in self.dict:
del self.dict[key] del self.dict[key]
else: else:
raise AttributeError(f"ActorList对象没有属性 '{key}'") raise AttributeError(f"ActorList对象没有属性 '{key}'")
@lru_cache(maxsize=128) @lru_cache(maxsize=128)
def _get_cached_actor_handle(actor_name: str): def _get_cached_actor_handle(actor_name: str):
"""缓存接口""" """缓存接口"""
return ray.get_actor(actor_name, namespace="pretor") return ray.get_actor(actor_name, namespace="pretor")
def clear_actor_cache(): def clear_actor_cache():
"""清理接口""" """清理接口"""
_get_cached_actor_handle.cache_clear() _get_cached_actor_handle.cache_clear()
def ray_actor_hook(*actor_names: str): def ray_actor_hook(*actor_names: str):
"""执行与 ray actor hook 相关的核心业务流转操作。 """执行与 ray actor hook 相关的核心业务流转操作。
该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。 该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。
Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。"""
actor_list = ActorList() actor_list = ActorList()
for actor_name in actor_names: for actor_name in actor_names:
handle = _get_cached_actor_handle(actor_name) handle = _get_cached_actor_handle(actor_name)
setattr(actor_list, actor_name, handle) setattr(actor_list, actor_name, handle)
return actor_list return actor_list
+14 -6
View File
@@ -17,43 +17,51 @@ import asyncio
from functools import wraps from functools import wraps
from pretor.utils.error import RetryableError from pretor.utils.error import RetryableError
def retry_on_retryable_error(max_retries=3, base_delay=1): def retry_on_retryable_error(max_retries=3, base_delay=1):
"""执行与 retry on retryable error 相关的核心业务流转操作。 """执行与 retry on retryable error 相关的核心业务流转操作。
该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。 该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。
Args: max_retries: 参与 retry on retryable error 逻辑运算或数据构建的上下文依赖对象。 base_delay: 参与 retry on retryable error 逻辑运算或数据构建的上下文依赖对象。 Args: max_retries: 参与 retry on retryable error 逻辑运算或数据构建的上下文依赖对象。 base_delay: 参与 retry on retryable error 逻辑运算或数据构建的上下文依赖对象。
Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。"""
def decorator(func): def decorator(func):
"""执行与 decorator 相关的核心业务流转操作。 """执行与 decorator 相关的核心业务流转操作。
该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。 该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。
Args: func: 参与 decorator 逻辑运算或数据构建的上下文依赖对象。 Args: func: 参与 decorator 逻辑运算或数据构建的上下文依赖对象。
Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。"""
if asyncio.iscoroutinefunction(func): if asyncio.iscoroutinefunction(func):
@wraps(func) @wraps(func)
async def async_wrapper(*args, **kwargs): async def async_wrapper(*args, **kwargs):
"""执行与 async wrapper 相关的核心业务流转操作。 """执行与 async wrapper 相关的核心业务流转操作。
该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。 该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。
Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。"""
for attempt in range(max_retries): for attempt in range(max_retries):
try: try:
return await func(*args, **kwargs) return await func(*args, **kwargs)
except RetryableError: except RetryableError:
if attempt == max_retries - 1: if attempt == max_retries - 1:
raise raise
await asyncio.sleep(base_delay * (2 ** attempt)) await asyncio.sleep(base_delay * (2**attempt))
return async_wrapper return async_wrapper
else: else:
@wraps(func) @wraps(func)
def sync_wrapper(*args, **kwargs): def sync_wrapper(*args, **kwargs):
"""执行与 sync wrapper 相关的核心业务流转操作。 """执行与 sync wrapper 相关的核心业务流转操作。
该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。 该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。
Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。"""
import time import time
for attempt in range(max_retries): for attempt in range(max_retries):
try: try:
return func(*args, **kwargs) return func(*args, **kwargs)
except RetryableError: except RetryableError:
if attempt == max_retries - 1: if attempt == max_retries - 1:
raise raise
time.sleep(base_delay * (2 ** attempt)) time.sleep(base_delay * (2**attempt))
return sync_wrapper return sync_wrapper
return decorator return decorator
+3
View File
@@ -0,0 +1,3 @@
from pretor.worker_cluster.worker_cluster import WorkerCluster
__all__ = ["WorkerCluster"]
@@ -42,14 +42,16 @@ class WorkerCluster:
self.results_futures = {} self.results_futures = {}
self.runners = [] self.runners = []
self.num_runners = num_runners self.num_runners = num_runners
self.logger = get_logger('worker_cluster') self.logger = get_logger("worker_cluster")
async def start(self): async def start(self):
"""执行与 start 相关的核心业务流转操作。 """执行与 start 相关的核心业务流转操作。
该方法封装了具体的算法策略或状态控制逻辑确保操作能够在事务上下文中被原子且一致地执行 """ 该方法封装了具体的算法策略或状态控制逻辑确保操作能够在事务上下文中被原子且一致地执行"""
if self.task_queue is None: if self.task_queue is None:
self.task_queue = Queue() self.task_queue = Queue()
self.runners = [asyncio.create_task(self._runner(i)) for i in range(self.num_runners)] self.runners = [
asyncio.create_task(self._runner(i)) for i in range(self.num_runners)
]
self.logger.info(f"WorkerCluster 已启动 {self.num_runners} 个 runner 协程。") self.logger.info(f"WorkerCluster 已启动 {self.num_runners} 个 runner 协程。")
async def _recruit_worker(self, agent_id: str) -> BaseIndividual: async def _recruit_worker(self, agent_id: str) -> BaseIndividual:
@@ -58,8 +60,10 @@ class WorkerCluster:
self._active_workers.move_to_end(agent_id) self._active_workers.move_to_end(agent_id)
return self._active_workers[agent_id] return self._active_workers[agent_id]
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine global_state_machine = ray_actor_hook(
agent_config = await global_state_machine.get_individual.remote( agent_id) "global_state_machine"
).global_state_machine
agent_config = await global_state_machine.get_individual.remote(agent_id)
if not agent_config: if not agent_config:
raise ValueError(f"无法唤醒 Agent {agent_id}:数据库中不存在该档案") raise ValueError(f"无法唤醒 Agent {agent_id}:数据库中不存在该档案")
@@ -82,7 +86,7 @@ class WorkerCluster:
async def _runner(self, runner_id: int): async def _runner(self, runner_id: int):
"""执行与 runner 相关的核心业务流转操作。 """执行与 runner 相关的核心业务流转操作。
该方法封装了具体的算法策略或状态控制逻辑确保操作能够在事务上下文中被原子且一致地执行 该方法封装了具体的算法策略或状态控制逻辑确保操作能够在事务上下文中被原子且一致地执行
Args: runner_id (int): 目标对象的唯一全局标识符 (UUID/ULID)用于在数据库表或缓存结构中精准匹配该 runner 实例 """ Args: runner_id (int): 目标对象的唯一全局标识符 (UUID/ULID)用于在数据库表或缓存结构中精准匹配该 runner 实例"""
while True: while True:
try: try:
if self.task_queue is None: if self.task_queue is None:
@@ -93,7 +97,9 @@ class WorkerCluster:
agent_id = task.get("agent_id") agent_id = task.get("agent_id")
task_event = task.get("task_event") task_event = task.get("task_event")
self.logger.debug(f"[WorkerCluster Runner {runner_id}] 开始处理任务 {task_id} 给 Agent {agent_id}") self.logger.debug(
f"[WorkerCluster Runner {runner_id}] 开始处理任务 {task_id} 给 Agent {agent_id}"
)
start_time = time.time() start_time = time.time()
try: try:
@@ -105,40 +111,36 @@ class WorkerCluster:
"success": True, "success": True,
"agent_id": agent_id, "agent_id": agent_id,
"data": result, "data": result,
"metrics": {"cost_time_sec": round(cost_time, 2)} "metrics": {"cost_time_sec": round(cost_time, 2)},
} }
except Exception as e: except Exception as e:
self.logger.exception(f"[WorkerCluster Runner {runner_id}] 执行任务 {task_id} 时发生错误: {e}") self.logger.exception(
response = { f"[WorkerCluster Runner {runner_id}] 执行任务 {task_id} 时发生错误: {e}"
"success": False, )
"agent_id": agent_id, response = {"success": False, "agent_id": agent_id, "error": str(e)}
"error": str(e)
}
if task_id in self.results_futures: if task_id in self.results_futures:
future = self.results_futures[task_id] future = self.results_futures[task_id]
if not future.done(): if not future.done():
future.set_result(response) future.set_result(response)
except Exception as e: except Exception as e:
self.logger.error(f"[WorkerCluster Runner {runner_id}] 循环发生异常: {e}") self.logger.error(
f"[WorkerCluster Runner {runner_id}] 循环发生异常: {e}"
)
await asyncio.sleep(1) await asyncio.sleep(1)
async def submit_task(self, task_id: str, agent_id: str, task_event: dict): async def submit_task(self, task_id: str, agent_id: str, task_event: dict):
"""执行与 submit task 相关的核心业务流转操作。 """执行与 submit task 相关的核心业务流转操作。
该方法封装了具体的算法策略或状态控制逻辑确保操作能够在事务上下文中被原子且一致地执行 该方法封装了具体的算法策略或状态控制逻辑确保操作能够在事务上下文中被原子且一致地执行
Args: task_id (str): 目标对象的唯一全局标识符 (UUID/ULID)用于在数据库表或缓存结构中精准匹配该 task 实例 agent_id (str): 目标对象的唯一全局标识符 (UUID/ULID)用于在数据库表或缓存结构中精准匹配该 agent 实例 task_event (dict): 由事件总线或工作流引擎分发过来的事件载荷封装了触发此次调用的上下文快照与任务目标指令 Args: task_id (str): 目标对象的唯一全局标识符 (UUID/ULID)用于在数据库表或缓存结构中精准匹配该 task 实例 agent_id (str): 目标对象的唯一全局标识符 (UUID/ULID)用于在数据库表或缓存结构中精准匹配该 agent 实例 task_event (dict): 由事件总线或工作流引擎分发过来的事件载荷封装了触发此次调用的上下文快照与任务目标指令
Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象 """ Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象"""
if not self.runners: if not self.runners:
await self.start() await self.start()
future = asyncio.Future() future = asyncio.Future()
self.results_futures[task_id] = future self.results_futures[task_id] = future
task = { task = {"task_id": task_id, "agent_id": agent_id, "task_event": task_event}
"task_id": task_id,
"agent_id": agent_id,
"task_event": task_event
}
await self.task_queue.put_async(task) await self.task_queue.put_async(task)
self.logger.debug(f"[WorkerCluster] 任务 {task_id} 已加入队列。") self.logger.debug(f"[WorkerCluster] 任务 {task_id} 已加入队列。")
@@ -151,10 +153,10 @@ class WorkerCluster:
def get_cluster_metrics(self): def get_cluster_metrics(self):
"""检索并获取特定的 cluster metrics 数据集合或实例对象。 """检索并获取特定的 cluster metrics 数据集合或实例对象。
根据提供的查询条件或上下文凭证从数据库缓存或第三方服务中读取对应的资源状态 根据提供的查询条件或上下文凭证从数据库缓存或第三方服务中读取对应的资源状态
Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象 """ Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象"""
return { return {
"active_worker_count": len(self._active_workers), "active_worker_count": len(self._active_workers),
"max_capacity": self.max_capacity, "max_capacity": self.max_capacity,
"cached_agent_ids": list(self._active_workers.keys()), "cached_agent_ids": list(self._active_workers.keys()),
"queue_size": self.task_queue.size() "queue_size": self.task_queue.size(),
} }
+27 -15
View File
@@ -20,23 +20,31 @@ from pretor.utils.agent_model import ResponseModel, InputModel, DepsModel
from pretor.utils.ray_hook import ray_actor_hook from pretor.utils.ray_hook import ray_actor_hook
from pretor.utils.logger import get_logger from pretor.utils.logger import get_logger
logger = get_logger('worker_individual')
logger = get_logger("worker_individual")
class WorkerIndividualResponse(ResponseModel): class WorkerIndividualResponse(ResponseModel):
"""WorkerIndividualResponse 核心组件类。 """WorkerIndividualResponse 核心组件类。
这是一个具体的 Worker 智能体实体类,代表着具备特定人设、领域技能或长文本处理能力的数字员工。它可以被控制器动态拉起,并在安全沙箱内执行复杂的工作流指令与多步骤推理任务。 """ 这是一个具体的 Worker 智能体实体类,代表着具备特定人设、领域技能或长文本处理能力的数字员工。它可以被控制器动态拉起,并在安全沙箱内执行复杂的工作流指令与多步骤推理任务。"""
output: str = Field(..., description="Worker执行任务的输出结果") output: str = Field(..., description="Worker执行任务的输出结果")
class WorkerIndividualDeps(DepsModel): class WorkerIndividualDeps(DepsModel):
"""WorkerIndividualDeps 核心组件类。 """WorkerIndividualDeps 核心组件类。
这是一个具体的 Worker 智能体实体类,代表着具备特定人设、领域技能或长文本处理能力的数字员工。它可以被控制器动态拉起,并在安全沙箱内执行复杂的工作流指令与多步骤推理任务。 """ 这是一个具体的 Worker 智能体实体类,代表着具备特定人设、领域技能或长文本处理能力的数字员工。它可以被控制器动态拉起,并在安全沙箱内执行复杂的工作流指令与多步骤推理任务。"""
task_event: dict task_event: dict
class WorkerIndividualInput(InputModel): class WorkerIndividualInput(InputModel):
"""WorkerIndividualInput 核心组件类。 """WorkerIndividualInput 核心组件类。
这是一个具体的 Worker 智能体实体类,代表着具备特定人设、领域技能或长文本处理能力的数字员工。它可以被控制器动态拉起,并在安全沙箱内执行复杂的工作流指令与多步骤推理任务。 """ 这是一个具体的 Worker 智能体实体类,代表着具备特定人设、领域技能或长文本处理能力的数字员工。它可以被控制器动态拉起,并在安全沙箱内执行复杂的工作流指令与多步骤推理任务。"""
task_event: dict task_event: dict
class BaseIndividual: class BaseIndividual:
""" """
Worker Individual 的基类 Worker Individual 的基类
@@ -51,14 +59,21 @@ class BaseIndividual:
"""完成 agent 模块的启动与依赖初始化。 """完成 agent 模块的启动与依赖初始化。
在系统引导或服务拉起阶段被调用,负责建立网络连接、分配基础内存资源及注册核心服务组件。 在系统引导或服务拉起阶段被调用,负责建立网络连接、分配基础内存资源及注册核心服务组件。
Args: agent_name (str): 赋予该实体的人类可读名称或标题字符串,主要用于前端 UI 展示、日志记录或模糊检索。 system_prompt (str): 控制逻辑流向的具体字符串参数,指定了期望的 system prompt 内容。 Args: agent_name (str): 赋予该实体的人类可读名称或标题字符串,主要用于前端 UI 展示、日志记录或模糊检索。 system_prompt (str): 控制逻辑流向的具体字符串参数,指定了期望的 system prompt 内容。
Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。"""
from pretor.utils.get_tool import load_tools_from_list from pretor.utils.get_tool import load_tools_from_list
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
provider_title = self.agent_config.get("provider_title", "openai") # default fallback global_state_machine = ray_actor_hook(
"global_state_machine"
).global_state_machine
provider_title = self.agent_config.get(
"provider_title", "openai"
) # default fallback
model_id = self.agent_config.get("model_id", "gpt-4o") # default fallback model_id = self.agent_config.get("model_id", "gpt-4o") # default fallback
tools_list = self.agent_config.get("tools", None) tools_list = self.agent_config.get("tools", None)
provider: Provider = await global_state_machine.get_provider.remote( provider_title) provider: Provider = await global_state_machine.get_provider.remote(
provider_title
)
agent_factory = AgentFactory() agent_factory = AgentFactory()
callables = load_tools_from_list(tools_list) callables = load_tools_from_list(tools_list)
@@ -70,7 +85,7 @@ class BaseIndividual:
system_prompt=system_prompt, system_prompt=system_prompt,
deps_type=WorkerIndividualDeps, deps_type=WorkerIndividualDeps,
agent_name=agent_name, agent_name=agent_name,
tools=callables tools=callables,
) )
@self.agent.system_prompt @self.agent.system_prompt
@@ -78,17 +93,14 @@ class BaseIndividual:
"""执行与 dynamic prompt 相关的核心业务流转操作。 """执行与 dynamic prompt 相关的核心业务流转操作。
该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。 该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。
Args: ctx (RunContext[WorkerIndividualDeps]): 参与 dynamic prompt 逻辑运算或数据构建的上下文依赖对象。 Args: ctx (RunContext[WorkerIndividualDeps]): 参与 dynamic prompt 逻辑运算或数据构建的上下文依赖对象。
Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。"""
prompt = system_prompt + "\n\n" prompt = system_prompt + "\n\n"
prompt += ( prompt += f"=== 当前任务上下文 ===\n{ctx.deps.task_event}\n"
f"=== 当前任务上下文 ===\n"
f"{ctx.deps.task_event}\n"
)
return prompt return prompt
async def run(self, task_event: dict) -> dict: async def run(self, task_event: dict) -> dict:
"""执行与 run 相关的核心业务流转操作。 """执行与 run 相关的核心业务流转操作。
该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。 该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。
Args: task_event (dict): 由事件总线或工作流引擎分发过来的事件载荷,封装了触发此次调用的上下文快照与任务目标指令。 Args: task_event (dict): 由事件总线或工作流引擎分发过来的事件载荷,封装了触发此次调用的上下文快照与任务目标指令。
Returns: (dict): 高度聚合的字典结构数据,将多维度的属性特征或统计指标组合后一并返回。 """ Returns: (dict): 高度聚合的字典结构数据,将多维度的属性特征或统计指标组合后一并返回。"""
raise NotImplementedError("子类必须实现 run 方法") raise NotImplementedError("子类必须实现 run 方法")
@@ -12,10 +12,14 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from pretor.worker_individual.base_individual import BaseIndividual, WorkerIndividualDeps from pretor.worker_individual.base_individual import (
BaseIndividual,
WorkerIndividualDeps,
)
from pretor.utils.logger import get_logger from pretor.utils.logger import get_logger
logger = get_logger('ordinary_individual') logger = get_logger("ordinary_individual")
class OrdinaryIndividual(BaseIndividual): class OrdinaryIndividual(BaseIndividual):
""" """
@@ -29,18 +33,17 @@ class OrdinaryIndividual(BaseIndividual):
"""执行与 run 相关的核心业务流转操作。 """执行与 run 相关的核心业务流转操作。
该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。 该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。
Args: task_event (dict): 由事件总线或工作流引擎分发过来的事件载荷,封装了触发此次调用的上下文快照与任务目标指令。 Args: task_event (dict): 由事件总线或工作流引擎分发过来的事件载荷,封装了触发此次调用的上下文快照与任务目标指令。
Returns: (dict): 高度聚合的字典结构数据,将多维度的属性特征或统计指标组合后一并返回。 """ Returns: (dict): 高度聚合的字典结构数据,将多维度的属性特征或统计指标组合后一并返回。"""
if self.agent is None: if self.agent is None:
system_prompt = self.agent_config.get("prompt", "你是一个普通的AI助手,请尽力完成给定的任务。") system_prompt = self.agent_config.get(
"prompt", "你是一个普通的AI助手,请尽力完成给定的任务。"
)
await self._init_agent("ordinary_individual", system_prompt) await self._init_agent("ordinary_individual", system_prompt)
deps = WorkerIndividualDeps(task_event=task_event) deps = WorkerIndividualDeps(task_event=task_event)
self.agent.retries = 3 self.agent.retries = 3
try: try:
result = await self.agent.run( result = await self.agent.run(f"请执行以下任务:\n{task_event}", deps=deps)
f"请执行以下任务:\n{task_event}",
deps=deps
)
return {"output": result.data.output} return {"output": result.data.output}
except Exception as e: except Exception as e:
logger.exception(f"OrdinaryIndividual {self.agent_id} 执行失败: {e}") logger.exception(f"OrdinaryIndividual {self.agent_id} 执行失败: {e}")
+30 -12
View File
@@ -12,14 +12,18 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from pretor.worker_individual.base_individual import BaseIndividual, WorkerIndividualDeps from pretor.worker_individual.base_individual import (
BaseIndividual,
WorkerIndividualDeps,
)
from pretor.utils.logger import get_logger from pretor.utils.logger import get_logger
import os import os
import json import json
from pydantic_ai import Tool from pydantic_ai import Tool
import importlib.util import importlib.util
logger = get_logger('skill_individual') logger = get_logger("skill_individual")
class SkillIndividual(BaseIndividual): class SkillIndividual(BaseIndividual):
""" """
@@ -43,7 +47,9 @@ class SkillIndividual(BaseIndividual):
elif isinstance(bound_skill, dict): elif isinstance(bound_skill, dict):
skill_mapper = bound_skill skill_mapper = bound_skill
skill_base_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "plugin", "skill")) skill_base_dir = os.path.abspath(
os.path.join(os.path.dirname(__file__), "..", "plugin", "skill")
)
for skill_name, _ in skill_mapper.items(): for skill_name, _ in skill_mapper.items():
skill_path = os.path.join(skill_base_dir, skill_name) skill_path = os.path.join(skill_base_dir, skill_name)
@@ -52,7 +58,7 @@ class SkillIndividual(BaseIndividual):
continue continue
try: try:
with open(metadata_path, 'r', encoding='utf-8') as f: with open(metadata_path, "r", encoding="utf-8") as f:
metadata = json.load(f) metadata = json.load(f)
except Exception as e: except Exception as e:
logger.error(f"Failed to load metadata for skill {skill_name}: {e}") logger.error(f"Failed to load metadata for skill {skill_name}: {e}")
@@ -72,18 +78,28 @@ class SkillIndividual(BaseIndividual):
func_name = func_info.get("name") func_name = func_info.get("name")
try: try:
# Dynamically load the python module # Dynamically load the python module
spec = importlib.util.spec_from_file_location(func_name, script_path) spec = importlib.util.spec_from_file_location(
func_name, script_path
)
module = importlib.util.module_from_spec(spec) module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module) spec.loader.exec_module(module)
func = getattr(module, func_name) func = getattr(module, func_name)
if callable(func): if callable(func):
# Convert to PydanticAI Tool # Convert to PydanticAI Tool
tool = Tool(func, name=func_name, description=func_info.get("docstring", "")) tool = Tool(
func,
name=func_name,
description=func_info.get("docstring", ""),
)
tools.append(tool) tools.append(tool)
logger.info(f"Loaded skill tool: {func_name} from {skill_name}") logger.info(
f"Loaded skill tool: {func_name} from {skill_name}"
)
except Exception as e: except Exception as e:
logger.error(f"Failed to load function {func_name} from {script_path}: {e}") logger.error(
f"Failed to load function {func_name} from {script_path}: {e}"
)
return tools return tools
@@ -91,10 +107,12 @@ class SkillIndividual(BaseIndividual):
"""执行与 run 相关的核心业务流转操作。 """执行与 run 相关的核心业务流转操作。
该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。 该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。
Args: task_event (dict): 由事件总线或工作流引擎分发过来的事件载荷,封装了触发此次调用的上下文快照与任务目标指令。 Args: task_event (dict): 由事件总线或工作流引擎分发过来的事件载荷,封装了触发此次调用的上下文快照与任务目标指令。
Returns: (dict): 高度聚合的字典结构数据,将多维度的属性特征或统计指标组合后一并返回。 """ Returns: (dict): 高度聚合的字典结构数据,将多维度的属性特征或统计指标组合后一并返回。"""
if self.agent is None: if self.agent is None:
system_prompt = self.agent_config.get("prompt", system_prompt = self.agent_config.get(
"你是一个拥有专业技能的专家级AI助手,请利用你的专业知识完成给定的任务。") "prompt",
"你是一个拥有专业技能的专家级AI助手,请利用你的专业知识完成给定的任务。",
)
await self._init_agent("skill_individual", system_prompt) await self._init_agent("skill_individual", system_prompt)
deps = WorkerIndividualDeps(task_event=task_event) deps = WorkerIndividualDeps(task_event=task_event)
@@ -106,7 +124,7 @@ class SkillIndividual(BaseIndividual):
result = await self.agent.run( result = await self.agent.run(
f"请执行以下任务:\n{task_event}", f"请执行以下任务:\n{task_event}",
deps=deps, deps=deps,
tools=tools if tools else None tools=tools if tools else None,
) )
return {"output": result.data.output} return {"output": result.data.output}
except Exception as e: except Exception as e:
+11 -8
View File
@@ -12,10 +12,14 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from pretor.worker_individual.base_individual import BaseIndividual, WorkerIndividualDeps from pretor.worker_individual.base_individual import (
BaseIndividual,
WorkerIndividualDeps,
)
from pretor.utils.logger import get_logger from pretor.utils.logger import get_logger
logger = get_logger('special_individual') logger = get_logger("special_individual")
class SpecialIndividual(BaseIndividual): class SpecialIndividual(BaseIndividual):
""" """
@@ -29,18 +33,17 @@ class SpecialIndividual(BaseIndividual):
"""执行与 run 相关的核心业务流转操作。 """执行与 run 相关的核心业务流转操作。
该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。 该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。
Args: task_event (dict): 由事件总线或工作流引擎分发过来的事件载荷,封装了触发此次调用的上下文快照与任务目标指令。 Args: task_event (dict): 由事件总线或工作流引擎分发过来的事件载荷,封装了触发此次调用的上下文快照与任务目标指令。
Returns: (dict): 高度聚合的字典结构数据,将多维度的属性特征或统计指标组合后一并返回。 """ Returns: (dict): 高度聚合的字典结构数据,将多维度的属性特征或统计指标组合后一并返回。"""
if self.agent is None: if self.agent is None:
system_prompt = self.agent_config.get("prompt", "你是一个特殊的AI助手,负责处理特殊类型的任务。") system_prompt = self.agent_config.get(
"prompt", "你是一个特殊的AI助手,负责处理特殊类型的任务。"
)
await self._init_agent("special_individual", system_prompt) await self._init_agent("special_individual", system_prompt)
deps = WorkerIndividualDeps(task_event=task_event) deps = WorkerIndividualDeps(task_event=task_event)
self.agent.retries = 3 self.agent.retries = 3
try: try:
result = await self.agent.run( result = await self.agent.run(f"请执行以下任务:\n{task_event}", deps=deps)
f"请执行以下任务:\n{task_event}",
deps=deps
)
return {"output": result.data.output} return {"output": result.data.output}
except Exception as e: except Exception as e:
logger.exception(f"SpecialIndividual {self.agent_id} 执行失败: {e}") logger.exception(f"SpecialIndividual {self.agent_id} 执行失败: {e}")
@@ -12,8 +12,12 @@ def test_create_agent_success_real():
mock_provider.provider_url = "url" mock_provider.provider_url = "url"
with patch("pretor.adapter.model_adapter.agent_factory.Agent") as mock_agent_cls: with patch("pretor.adapter.model_adapter.agent_factory.Agent") as mock_agent_cls:
with patch("pretor.adapter.model_adapter.agent_factory.OpenAIChatModel") as mock_model_cls: with patch(
with patch("pretor.adapter.model_adapter.agent_factory.OpenAIProvider") as mock_provider_cls: "pretor.adapter.model_adapter.agent_factory.OpenAIChatModel"
) as mock_model_cls:
with patch(
"pretor.adapter.model_adapter.agent_factory.OpenAIProvider"
) as mock_provider_cls:
factory = AgentFactory() factory = AgentFactory()
agent = factory.create_agent( agent = factory.create_agent(
provider=mock_provider, provider=mock_provider,
@@ -21,17 +25,19 @@ def test_create_agent_success_real():
output_type=str, output_type=str,
system_prompt="You are an AI", system_prompt="You are an AI",
deps_type=dict, deps_type=dict,
agent_name="myagent" agent_name="myagent",
) )
mock_provider_cls.assert_called_once_with(api_key="key", base_url="url") mock_provider_cls.assert_called_once_with(api_key="key", base_url="url")
mock_model_cls.assert_called_once_with("gpt-4", provider=mock_provider_cls.return_value) mock_model_cls.assert_called_once_with(
"gpt-4", provider=mock_provider_cls.return_value
)
mock_agent_cls.assert_called_once_with( mock_agent_cls.assert_called_once_with(
model=mock_model_cls.return_value, model=mock_model_cls.return_value,
name="myagent", name="myagent",
system_prompt="You are an AI", system_prompt="You are an AI",
output_type=str, output_type=str,
deps_type=dict, deps_type=dict,
tools=None tools=None,
) )
assert agent == mock_agent_cls.return_value assert agent == mock_agent_cls.return_value
@@ -5,34 +5,42 @@ from pydantic import ValidationError
from pretor.utils.error import UserNotExistError from pretor.utils.error import UserNotExistError
from pretor.core.database.database_exception import database_exception from pretor.core.database.database_exception import database_exception
@database_exception @database_exception
async def success_func(): async def success_func():
return "success" return "success"
@database_exception @database_exception
async def validation_error_func(): async def validation_error_func():
raise ValidationError.from_exception_data(title="Mock", line_errors=[]) raise ValidationError.from_exception_data(title="Mock", line_errors=[])
@database_exception @database_exception
async def integrity_error_func(): async def integrity_error_func():
raise IntegrityError("mock_statement", "mock_params", "mock_orig") raise IntegrityError("mock_statement", "mock_params", "mock_orig")
@database_exception @database_exception
async def operational_error_func(): async def operational_error_func():
raise OperationalError("mock_statement", "mock_params", "mock_orig") raise OperationalError("mock_statement", "mock_params", "mock_orig")
@database_exception @database_exception
async def user_not_exist_error_func(): async def user_not_exist_error_func():
raise UserNotExistError("mock user") raise UserNotExistError("mock user")
@database_exception @database_exception
async def exception_func(): async def exception_func():
raise Exception("mock generic exception") raise Exception("mock generic exception")
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_success_func(): async def test_success_func():
assert await success_func() == "success" assert await success_func() == "success"
@pytest.mark.asyncio @pytest.mark.asyncio
@patch("pretor.core.database.database_exception.logger") @patch("pretor.core.database.database_exception.logger")
async def test_validation_error(mock_logger): async def test_validation_error(mock_logger):
@@ -41,6 +49,7 @@ async def test_validation_error(mock_logger):
mock_logger.error.assert_called_once() mock_logger.error.assert_called_once()
assert "对象校验失败" in mock_logger.error.call_args[0][0] assert "对象校验失败" in mock_logger.error.call_args[0][0]
@pytest.mark.asyncio @pytest.mark.asyncio
@patch("pretor.core.database.database_exception.logger") @patch("pretor.core.database.database_exception.logger")
async def test_integrity_error(mock_logger): async def test_integrity_error(mock_logger):
@@ -49,6 +58,7 @@ async def test_integrity_error(mock_logger):
mock_logger.error.assert_called_once() mock_logger.error.assert_called_once()
assert "数据库完整性错误" in mock_logger.error.call_args[0][0] assert "数据库完整性错误" in mock_logger.error.call_args[0][0]
@pytest.mark.asyncio @pytest.mark.asyncio
@patch("pretor.core.database.database_exception.logger") @patch("pretor.core.database.database_exception.logger")
async def test_operational_error(mock_logger): async def test_operational_error(mock_logger):
@@ -57,6 +67,7 @@ async def test_operational_error(mock_logger):
mock_logger.error.assert_called_once() mock_logger.error.assert_called_once()
assert "数据库连接异常" in mock_logger.error.call_args[0][0] assert "数据库连接异常" in mock_logger.error.call_args[0][0]
@pytest.mark.asyncio @pytest.mark.asyncio
@patch("pretor.core.database.database_exception.logger") @patch("pretor.core.database.database_exception.logger")
async def test_user_not_exist_error(mock_logger): async def test_user_not_exist_error(mock_logger):
@@ -65,6 +76,7 @@ async def test_user_not_exist_error(mock_logger):
mock_logger.error.assert_called_once() mock_logger.error.assert_called_once()
assert "更改密码失败,用户不存在" in mock_logger.error.call_args[0][0] assert "更改密码失败,用户不存在" in mock_logger.error.call_args[0][0]
@pytest.mark.asyncio @pytest.mark.asyncio
@patch("pretor.core.database.database_exception.logger") @patch("pretor.core.database.database_exception.logger")
async def test_generic_exception(mock_logger): async def test_generic_exception(mock_logger):
+10
View File
@@ -26,6 +26,7 @@ def mock_session_maker():
async def test_add_user(mock_session_maker, mock_dependencies): async def test_add_user(mock_session_maker, mock_dependencies):
mock_user_cls, _ = mock_dependencies mock_user_cls, _ = mock_dependencies
from pretor.core.database.module.user import AuthDatabase from pretor.core.database.module.user import AuthDatabase
maker, session = mock_session_maker maker, session = mock_session_maker
db = AuthDatabase(maker) db = AuthDatabase(maker)
@@ -51,6 +52,7 @@ async def test_add_user(mock_session_maker, mock_dependencies):
async def test_change_password_success(mock_session_maker, mock_dependencies): async def test_change_password_success(mock_session_maker, mock_dependencies):
mock_user_cls, mock_select = mock_dependencies mock_user_cls, mock_select = mock_dependencies
from pretor.core.database.module.user import AuthDatabase from pretor.core.database.module.user import AuthDatabase
maker, session = mock_session_maker maker, session = mock_session_maker
db = AuthDatabase(maker) db = AuthDatabase(maker)
@@ -79,6 +81,7 @@ async def test_change_password_success(mock_session_maker, mock_dependencies):
async def test_change_password_user_not_exist(mock_session_maker, mock_dependencies): async def test_change_password_user_not_exist(mock_session_maker, mock_dependencies):
mock_user_cls, mock_select = mock_dependencies mock_user_cls, mock_select = mock_dependencies
from pretor.core.database.module.user import AuthDatabase from pretor.core.database.module.user import AuthDatabase
maker, session = mock_session_maker maker, session = mock_session_maker
db = AuthDatabase(maker) db = AuthDatabase(maker)
@@ -94,10 +97,12 @@ async def test_change_password_user_not_exist(mock_session_maker, mock_dependenc
async def test_change_password_wrong_password(mock_session_maker, mock_dependencies): async def test_change_password_wrong_password(mock_session_maker, mock_dependencies):
mock_user_cls, mock_select = mock_dependencies mock_user_cls, mock_select = mock_dependencies
from pretor.core.database.module.user import AuthDatabase from pretor.core.database.module.user import AuthDatabase
maker, session = mock_session_maker maker, session = mock_session_maker
db = AuthDatabase(maker) db = AuthDatabase(maker)
from pretor.utils.access import Accessor from pretor.utils.access import Accessor
mock_user = MagicMock() mock_user = MagicMock()
mock_user.hashed_password = Accessor.hash_password("actual_password") mock_user.hashed_password = Accessor.hash_password("actual_password")
mock_exec_result = MagicMock() mock_exec_result = MagicMock()
@@ -105,6 +110,7 @@ async def test_change_password_wrong_password(mock_session_maker, mock_dependenc
session.execute = AsyncMock(return_value=mock_exec_result) session.execute = AsyncMock(return_value=mock_exec_result)
from pretor.utils.error import UserPasswordError from pretor.utils.error import UserPasswordError
with pytest.raises(UserPasswordError): with pytest.raises(UserPasswordError):
await db.change_password("testuser", "old_password", "new_password") await db.change_password("testuser", "old_password", "new_password")
@@ -113,6 +119,7 @@ async def test_change_password_wrong_password(mock_session_maker, mock_dependenc
async def test_delete_user_success(mock_session_maker, mock_dependencies): async def test_delete_user_success(mock_session_maker, mock_dependencies):
mock_user_cls, mock_select = mock_dependencies mock_user_cls, mock_select = mock_dependencies
from pretor.core.database.module.user import AuthDatabase from pretor.core.database.module.user import AuthDatabase
maker, session = mock_session_maker maker, session = mock_session_maker
db = AuthDatabase(maker) db = AuthDatabase(maker)
@@ -134,6 +141,7 @@ async def test_delete_user_success(mock_session_maker, mock_dependencies):
async def test_delete_user_not_exist(mock_session_maker, mock_dependencies): async def test_delete_user_not_exist(mock_session_maker, mock_dependencies):
mock_user_cls, mock_select = mock_dependencies mock_user_cls, mock_select = mock_dependencies
from pretor.core.database.module.user import AuthDatabase from pretor.core.database.module.user import AuthDatabase
maker, session = mock_session_maker maker, session = mock_session_maker
db = AuthDatabase(maker) db = AuthDatabase(maker)
@@ -149,6 +157,7 @@ async def test_delete_user_not_exist(mock_session_maker, mock_dependencies):
async def test_login_user_success(mock_session_maker, mock_dependencies): async def test_login_user_success(mock_session_maker, mock_dependencies):
mock_user_cls, mock_select = mock_dependencies mock_user_cls, mock_select = mock_dependencies
from pretor.core.database.module.user import AuthDatabase from pretor.core.database.module.user import AuthDatabase
maker, session = mock_session_maker maker, session = mock_session_maker
db = AuthDatabase(maker) db = AuthDatabase(maker)
@@ -169,6 +178,7 @@ async def test_login_user_success(mock_session_maker, mock_dependencies):
async def test_login_user_not_exist(mock_session_maker, mock_dependencies): async def test_login_user_not_exist(mock_session_maker, mock_dependencies):
mock_user_cls, mock_select = mock_dependencies mock_user_cls, mock_select = mock_dependencies
from pretor.core.database.module.user import AuthDatabase from pretor.core.database.module.user import AuthDatabase
maker, session = mock_session_maker maker, session = mock_session_maker
db = AuthDatabase(maker) db = AuthDatabase(maker)
@@ -1,5 +1,6 @@
from pretor.core.database.table.provider import Provider from pretor.core.database.table.provider import Provider
def test_provider_table(): def test_provider_table():
# Provide required fields # Provide required fields
provider = Provider( provider = Provider(
@@ -8,7 +9,7 @@ def test_provider_table():
provider_apikey="key", provider_apikey="key",
provider_models=["model_1"], provider_models=["model_1"],
provider_type="type", provider_type="type",
provider_owner=1 provider_owner=1,
) )
assert Provider.__tablename__ == 'provider' assert Provider.__tablename__ == "provider"
assert provider.provider_title == "title" assert provider.provider_title == "title"
+2 -1
View File
@@ -1,6 +1,7 @@
from pretor.core.database.table.user import User from pretor.core.database.table.user import User
def test_user_table(): def test_user_table():
user = User(user_id="id", user_name="name", hashed_password="pw") user = User(user_id="id", user_name="name", hashed_password="pw")
assert User.__tablename__ == 'user' assert User.__tablename__ == "user"
assert user.user_name == "name" assert user.user_name == "name"
@@ -7,14 +7,16 @@ real_import = builtins.__import__
def mock_import(name, globals=None, locals=None, fromlist=(), level=0): def mock_import(name, globals=None, locals=None, fromlist=(), level=0):
if name == 'ray': if name == "ray":
mock_ray = MagicMock() mock_ray = MagicMock()
def mock_remote(*args, **kwargs): def mock_remote(*args, **kwargs):
if len(args) == 1 and callable(args[0]): if len(args) == 1 and callable(args[0]):
return args[0] return args[0]
def decorator(cls): def decorator(cls):
return cls return cls
return decorator return decorator
mock_ray.remote = mock_remote mock_ray.remote = mock_remote
@@ -25,10 +27,10 @@ def mock_import(name, globals=None, locals=None, fromlist=(), level=0):
builtins.__import__ = mock_import builtins.__import__ = mock_import
for mod in list(sys.modules.keys()): for mod in list(sys.modules.keys()):
if 'pretor.core.global_state_machine.global_state_machine' in mod or 'ray' in mod: if "pretor.core.global_state_machine.global_state_machine" in mod or "ray" in mod:
del sys.modules[mod] del sys.modules[mod]
from pretor.core.global_state_machine.global_state_machine import GlobalStateMachine from pretor.core.global_state_machine.global_state_machine import GlobalStateMachine # noqa: E402
builtins.__import__ = real_import builtins.__import__ = real_import
@@ -82,13 +84,17 @@ async def test_add_provider_unsupported(gsm):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_add_provider_request_error(gsm): async def test_add_provider_request_error(gsm):
from httpx import RequestError from httpx import RequestError
mock_provider_class = AsyncMock() mock_provider_class = AsyncMock()
mock_provider_class.create_provider.side_effect = RequestError("Network Error", request=MagicMock()) mock_provider_class.create_provider.side_effect = RequestError(
"Network Error", request=MagicMock()
)
gsm._global_provider_manager.provider_mapper = {"openai": mock_provider_class} gsm._global_provider_manager.provider_mapper = {"openai": mock_provider_class}
with patch("pretor.utils.logger.global_logger.bind") as mock_bind: with patch("pretor.utils.logger.global_logger.bind") as mock_bind:
from pretor.utils.error import RetryableError from pretor.utils.error import RetryableError
import pytest import pytest
mock_logger = MagicMock() mock_logger = MagicMock()
mock_bind.return_value = mock_logger mock_bind.return_value = mock_logger
with pytest.raises(RetryableError): with pytest.raises(RetryableError):
@@ -117,3 +123,4 @@ def test_get_provider_list_and_get_provider(gsm):
assert gsm._global_provider_manager.get_provider_list() == {"p1": mock_provider} assert gsm._global_provider_manager.get_provider_list() == {"p1": mock_provider}
assert gsm._global_provider_manager.get_provider("p1") == mock_provider assert gsm._global_provider_manager.get_provider("p1") == mock_provider
assert gsm._global_provider_manager.get_provider("missing") is None assert gsm._global_provider_manager.get_provider("missing") is None
# noqa: E402
@@ -1,25 +1,32 @@
from pretor.core.global_state_machine.model_provider.base_provider import Provider, ProviderArgs, ProviderStatus from pretor.core.global_state_machine.model_provider.base_provider import (
Provider,
ProviderArgs,
ProviderStatus,
)
def test_provider_status(): def test_provider_status():
assert ProviderStatus.UP == "up" assert ProviderStatus.UP == "up"
assert ProviderStatus.DOWN == "down" assert ProviderStatus.DOWN == "down"
def test_provider_args(): def test_provider_args():
args = ProviderArgs( args = ProviderArgs(
provider_title="title", provider_title="title",
provider_url="url", provider_url="url",
provider_apikey="key", provider_apikey="key",
provider_owner="1" provider_owner="1",
) )
assert args.provider_title == "title" assert args.provider_title == "title"
def test_provider_model(): def test_provider_model():
p = Provider( p = Provider(
provider_title="title", provider_title="title",
provider_url="url", provider_url="url",
provider_apikey="key", provider_apikey="key",
provider_models=["model"], provider_models=["model"],
provider_type="openai" provider_type="openai",
) )
assert p.provider_status == ProviderStatus.UP assert p.provider_status == ProviderStatus.UP
assert p.provider_owner is None assert p.provider_owner is None
@@ -1,6 +1,9 @@
import pytest import pytest
from unittest.mock import patch, MagicMock, AsyncMock from unittest.mock import patch, MagicMock, AsyncMock
from pretor.core.global_state_machine.model_provider.claude_provider import ClaudeProvider, ProviderArgs from pretor.core.global_state_machine.model_provider.claude_provider import (
ClaudeProvider,
ProviderArgs,
)
@pytest.fixture @pytest.fixture
@@ -9,12 +12,14 @@ def provider_args():
provider_title="TestClaude", provider_title="TestClaude",
provider_url="https://api.anthropic.com", provider_url="https://api.anthropic.com",
provider_apikey="testkey", provider_apikey="testkey",
provider_owner="1" provider_owner="1",
) )
@pytest.mark.asyncio @pytest.mark.asyncio
@patch("pretor.core.global_state_machine.model_provider.claude_provider.httpx.AsyncClient") @patch(
"pretor.core.global_state_machine.model_provider.claude_provider.httpx.AsyncClient"
)
async def test_load_models_success(mock_client, provider_args): async def test_load_models_success(mock_client, provider_args):
mock_response = MagicMock() mock_response = MagicMock()
mock_response.status_code = 200 mock_response.status_code = 200
@@ -31,7 +36,9 @@ async def test_load_models_success(mock_client, provider_args):
@pytest.mark.asyncio @pytest.mark.asyncio
@patch("pretor.core.global_state_machine.model_provider.claude_provider.httpx.AsyncClient") @patch(
"pretor.core.global_state_machine.model_provider.claude_provider.httpx.AsyncClient"
)
async def test_load_models_error(mock_client, provider_args): async def test_load_models_error(mock_client, provider_args):
mock_client_instance = AsyncMock() mock_client_instance = AsyncMock()
mock_client_instance.get.side_effect = Exception("network error") mock_client_instance.get.side_effect = Exception("network error")
@@ -42,8 +49,10 @@ async def test_load_models_error(mock_client, provider_args):
@pytest.mark.asyncio @pytest.mark.asyncio
@patch("pretor.core.global_state_machine.model_provider.claude_provider.ClaudeProvider._load_models", @patch(
return_value=["claude-3"]) "pretor.core.global_state_machine.model_provider.claude_provider.ClaudeProvider._load_models",
return_value=["claude-3"],
)
async def test_create_provider(mock_load, provider_args): async def test_create_provider(mock_load, provider_args):
provider = await ClaudeProvider.create_provider(provider_args) provider = await ClaudeProvider.create_provider(provider_args)
assert provider.provider_title == "TestClaude" assert provider.provider_title == "TestClaude"
@@ -1,6 +1,9 @@
import pytest import pytest
from unittest.mock import patch, MagicMock, AsyncMock from unittest.mock import patch, MagicMock, AsyncMock
from pretor.core.global_state_machine.model_provider.openai_provider import OpenAIProvider, ProviderArgs from pretor.core.global_state_machine.model_provider.openai_provider import (
OpenAIProvider,
ProviderArgs,
)
@pytest.fixture @pytest.fixture
@@ -9,7 +12,7 @@ def provider_args():
provider_title="TestOpenAI", provider_title="TestOpenAI",
provider_url="https://api.openai.com/v1", provider_url="https://api.openai.com/v1",
provider_apikey="testkey", provider_apikey="testkey",
provider_owner="1" provider_owner="1",
) )
@@ -19,12 +22,14 @@ def provider_args_no_v1():
provider_title="TestOpenAI", provider_title="TestOpenAI",
provider_url="https://api.openai.com", provider_url="https://api.openai.com",
provider_apikey="testkey", provider_apikey="testkey",
provider_owner="1" provider_owner="1",
) )
@pytest.mark.asyncio @pytest.mark.asyncio
@patch("pretor.core.global_state_machine.model_provider.openai_provider.httpx.AsyncClient") @patch(
"pretor.core.global_state_machine.model_provider.openai_provider.httpx.AsyncClient"
)
async def test_load_models_success(mock_client, provider_args): async def test_load_models_success(mock_client, provider_args):
mock_response = MagicMock() mock_response = MagicMock()
mock_response.status_code = 200 mock_response.status_code = 200
@@ -40,12 +45,14 @@ async def test_load_models_success(mock_client, provider_args):
assert models == ["gpt-3.5-turbo", "gpt-4"] assert models == ["gpt-3.5-turbo", "gpt-4"]
mock_client_instance.get.assert_called_once_with( mock_client_instance.get.assert_called_once_with(
"https://api.openai.com/v1/models", "https://api.openai.com/v1/models",
headers={"Authorization": "Bearer testkey", "Content-Type": "application/json"} headers={"Authorization": "Bearer testkey", "Content-Type": "application/json"},
) )
@pytest.mark.asyncio @pytest.mark.asyncio
@patch("pretor.core.global_state_machine.model_provider.openai_provider.httpx.AsyncClient") @patch(
"pretor.core.global_state_machine.model_provider.openai_provider.httpx.AsyncClient"
)
async def test_load_models_no_v1(mock_client, provider_args_no_v1): async def test_load_models_no_v1(mock_client, provider_args_no_v1):
mock_response = MagicMock() mock_response = MagicMock()
mock_response.status_code = 200 mock_response.status_code = 200
@@ -59,12 +66,14 @@ async def test_load_models_no_v1(mock_client, provider_args_no_v1):
assert models == [] assert models == []
mock_client_instance.get.assert_called_once_with( mock_client_instance.get.assert_called_once_with(
"https://api.openai.com/v1/models", "https://api.openai.com/v1/models",
headers={"Authorization": "Bearer testkey", "Content-Type": "application/json"} headers={"Authorization": "Bearer testkey", "Content-Type": "application/json"},
) )
@pytest.mark.asyncio @pytest.mark.asyncio
@patch("pretor.core.global_state_machine.model_provider.openai_provider.httpx.AsyncClient") @patch(
"pretor.core.global_state_machine.model_provider.openai_provider.httpx.AsyncClient"
)
async def test_load_models_status_error(mock_client, provider_args): async def test_load_models_status_error(mock_client, provider_args):
mock_response = MagicMock() mock_response = MagicMock()
mock_response.status_code = 401 mock_response.status_code = 401
@@ -78,21 +87,29 @@ async def test_load_models_status_error(mock_client, provider_args):
@pytest.mark.asyncio @pytest.mark.asyncio
@patch("pretor.core.global_state_machine.model_provider.openai_provider.httpx.AsyncClient") @patch(
"pretor.core.global_state_machine.model_provider.openai_provider.httpx.AsyncClient"
)
async def test_load_models_request_error(mock_client, provider_args): async def test_load_models_request_error(mock_client, provider_args):
import httpx import httpx
mock_client_instance = AsyncMock() mock_client_instance = AsyncMock()
mock_client_instance.get.side_effect = httpx.RequestError("network error", request=MagicMock()) mock_client_instance.get.side_effect = httpx.RequestError(
"network error", request=MagicMock()
)
mock_client.return_value.__aenter__.return_value = mock_client_instance mock_client.return_value.__aenter__.return_value = mock_client_instance
import pytest import pytest
from pretor.utils.error import RetryableError from pretor.utils.error import RetryableError
with pytest.raises(RetryableError): with pytest.raises(RetryableError):
await OpenAIProvider._load_models(provider_args) await OpenAIProvider._load_models(provider_args)
@pytest.mark.asyncio @pytest.mark.asyncio
@patch("pretor.core.global_state_machine.model_provider.openai_provider.httpx.AsyncClient") @patch(
"pretor.core.global_state_machine.model_provider.openai_provider.httpx.AsyncClient"
)
async def test_load_models_generic_error(mock_client, provider_args): async def test_load_models_generic_error(mock_client, provider_args):
mock_client_instance = AsyncMock() mock_client_instance = AsyncMock()
mock_client_instance.get.side_effect = Exception("generic error") mock_client_instance.get.side_effect = Exception("generic error")
@@ -103,8 +120,10 @@ async def test_load_models_generic_error(mock_client, provider_args):
@pytest.mark.asyncio @pytest.mark.asyncio
@patch("pretor.core.global_state_machine.model_provider.openai_provider.OpenAIProvider._load_models", @patch(
return_value=["gpt-4"]) "pretor.core.global_state_machine.model_provider.openai_provider.OpenAIProvider._load_models",
return_value=["gpt-4"],
)
async def test_create_provider(mock_load, provider_args): async def test_create_provider(mock_load, provider_args):
provider = await OpenAIProvider.create_provider(provider_args) provider = await OpenAIProvider.create_provider(provider_args)
assert provider.provider_title == "TestOpenAI" assert provider.provider_title == "TestOpenAI"
@@ -13,11 +13,15 @@ async def test_provider_manager_init():
mock_provider2.provider_title = "title2" mock_provider2.provider_title = "title2"
mock_postgres.get_provider = MagicMock() mock_postgres.get_provider = MagicMock()
mock_postgres.get_provider.remote = AsyncMock(return_value=[mock_provider1, mock_provider2]) mock_postgres.get_provider.remote = AsyncMock(
return_value=[mock_provider1, mock_provider2]
)
manager = ProviderManager(mock_postgres) manager = ProviderManager(mock_postgres)
mock_postgres.provider_database = MagicMock() mock_postgres.provider_database = MagicMock()
mock_postgres.provider_database.remote = AsyncMock(return_value=[mock_provider1, mock_provider2]) mock_postgres.provider_database.remote = AsyncMock(
return_value=[mock_provider1, mock_provider2]
)
await manager.init_provider_register(mock_postgres) await manager.init_provider_register(mock_postgres)
assert "openai" in manager.provider_mapper assert "openai" in manager.provider_mapper
@@ -1,5 +1,6 @@
from pretor.core.global_state_machine.tool_manager import GlobalToolManager from pretor.core.global_state_machine.tool_manager import GlobalToolManager
def test_global_tool_manager_init(): def test_global_tool_manager_init():
manager = GlobalToolManager() manager = GlobalToolManager()
assert isinstance(manager, GlobalToolManager) assert isinstance(manager, GlobalToolManager)
@@ -7,14 +7,16 @@ real_import = builtins.__import__
def mock_import(name, globals=None, locals=None, fromlist=(), level=0): def mock_import(name, globals=None, locals=None, fromlist=(), level=0):
if name == 'ray': if name == "ray":
mock_ray = MagicMock() mock_ray = MagicMock()
def mock_remote(*args, **kwargs): def mock_remote(*args, **kwargs):
if len(args) == 1 and callable(args[0]): if len(args) == 1 and callable(args[0]):
return args[0] return args[0]
def decorator(cls): def decorator(cls):
return cls return cls
return decorator return decorator
mock_ray.remote = mock_remote mock_ray.remote = mock_remote
@@ -24,28 +26,30 @@ def mock_import(name, globals=None, locals=None, fromlist=(), level=0):
builtins.__import__ = mock_import builtins.__import__ = mock_import
for mod in list(sys.modules.keys()): for mod in list(sys.modules.keys()):
if 'pretor.core.database.postgres' in mod or 'ray' in mod: if "pretor.core.postgres_database.postgres" in mod or "ray" in mod:
del sys.modules[mod] del sys.modules[mod]
from pretor.core.database.postgres import PostgresDatabase from pretor.core.postgres_database.postgres import PostgresDatabase # noqa: E402
builtins.__import__ = real_import builtins.__import__ = real_import
@patch("pretor.core.database.postgres.create_async_engine") @patch("pretor.core.postgres_database.postgres.create_async_engine")
@patch("pretor.core.database.postgres.sessionmaker") @patch("pretor.core.postgres_database.postgres.sessionmaker")
@patch("pretor.core.database.postgres.AuthDatabase") @patch("pretor.core.postgres_database.postgres.AuthDatabase")
@patch("pretor.core.database.postgres.ProviderDatabase") @patch("pretor.core.postgres_database.postgres.ProviderDatabase")
@patch("pretor.core.database.postgres.os.environ.get") @patch("pretor.core.postgres_database.postgres.os.environ.get")
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_postgres_database(mock_env_get, mock_provider_db, mock_auth_db, mock_sessionmaker, mock_create_engine): async def test_postgres_database(
mock_env_get, mock_provider_db, mock_auth_db, mock_sessionmaker, mock_create_engine
):
def env_side_effect(key): def env_side_effect(key):
return { return {
"POSTGRES_USER": "testuser", "POSTGRES_USER": "testuser",
"POSTGRES_PASSWORD": "testpassword", "POSTGRES_PASSWORD": "testpassword",
"POSTGRES_HOST": "localhost", "POSTGRES_HOST": "localhost",
"POSTGRES_PORT": "5432", "POSTGRES_PORT": "5432",
"POSTGRES_DB": "testdb" "POSTGRES_DB": "testdb",
}.get(key) }.get(key)
mock_env_get.side_effect = env_side_effect mock_env_get.side_effect = env_side_effect
@@ -53,6 +57,7 @@ async def test_postgres_database(mock_env_get, mock_provider_db, mock_auth_db, m
mock_engine = MagicMock() mock_engine = MagicMock()
mock_conn = MagicMock() mock_conn = MagicMock()
from unittest.mock import AsyncMock from unittest.mock import AsyncMock
mock_conn.run_sync = AsyncMock() mock_conn.run_sync = AsyncMock()
mock_begin_ctx = MagicMock() mock_begin_ctx = MagicMock()
@@ -64,15 +69,17 @@ async def test_postgres_database(mock_env_get, mock_provider_db, mock_auth_db, m
db = PostgresDatabase() db = PostgresDatabase()
mock_create_engine.assert_called_once_with( mock_create_engine.assert_called_once_with(
"postgresql+asyncpg://testuser:testpassword@localhost:5432/testdb", "postgresql+asyncpg://testuser:testpassword@localhost:5432/testdb", echo=True
echo=True
) )
mock_auth_db.assert_called_once() mock_auth_db.assert_called_once()
mock_provider_db.assert_called_once() mock_provider_db.assert_called_once()
mock_auth_db.return_value.get_user_authority = AsyncMock(return_value="test_auth") mock_auth_db.return_value.get_user_authority = AsyncMock(return_value="test_auth")
with patch("pretor.core.database.postgres.SQLModel.metadata.create_all") as mock_create_all: with patch(
"pretor.core.postgres_database.postgres.SQLModel.metadata.create_all"
) as mock_create_all:
await db.init_db() await db.init_db()
mock_conn.run_sync.assert_called_once_with(mock_create_all) mock_conn.run_sync.assert_called_once_with(mock_create_all)
assert await db.get_user_authority(user_id="123") == "test_auth" assert await db.get_user_authority(user_id="123") == "test_auth"
# noqa: E402
@@ -1,43 +0,0 @@
from unittest.mock import patch, MagicMock
from pretor.core.workflow.workflow_template_generator.workflow_template_generator import WorkflowTemplateGenerator
@patch("pretor.core.workflow.workflow_template_generator.workflow_template_generator.Path")
def test_generate_workflow_template(mock_path):
mock_dir = MagicMock()
mock_dir.exists.return_value = False
mock_file = MagicMock()
mock_dir.__truediv__.return_value = mock_file
mock_open_ctx = MagicMock()
mock_file.open.return_value.__enter__.return_value = mock_open_ctx
mock_path_root = MagicMock()
mock_path_root.__truediv__.return_value = mock_dir
mock_path.return_value = mock_path_root
from pretor.core.workflow.workflow_template_generator.workflow_template import WorkflowTemplate
generator = WorkflowTemplateGenerator()
mock_template = MagicMock(spec=WorkflowTemplate)
mock_template.name = "test_wf"
mock_template.desc = "test_desc"
import json
mock_template.model_dump_json.return_value = json.dumps({
"name": "test_wf",
"desc": "test_desc",
"work_link": [{"step": 1, "node": "n", "action": "a", "desc": "d", "input": [], "output": [], "logic_gate": {}}]
})
generator.generate_workflow_template(
workflow_template=mock_template
)
mock_dir.mkdir.assert_called_once_with(parents=True)
mock_file.open.assert_called_once_with("w", encoding="utf-8")
mock_open_ctx.write.assert_called_once()
write_arg = mock_open_ctx.write.call_args[0][0]
written_data = json.loads(write_arg)
assert written_data["name"] == "test_wf"
assert written_data["desc"] == "test_desc"
assert written_data["work_link"][0]["step"] == 1
@@ -1,36 +0,0 @@
import pytest
from pydantic import ValidationError
from pretor.core.workflow.workflow_template_generator.workflow_template import WorkflowTemplateStep, WorkflowTemplate
def test_workflow_template_step():
step = WorkflowTemplateStep(
step=1,
node="node_type",
action="act",
desc="desc",
input=["in1"],
output=["out1"],
logic_gate={"if_pass": "next"}
)
assert step.step == 1
assert step.node == "node_type"
def test_workflow_template_success():
step1 = WorkflowTemplateStep(
step=1, node="node1", action="a1", desc="d1", input=[], output=[], logic_gate={}
)
step2 = WorkflowTemplateStep(
step=2, node="node2", action="a2", desc="d2", input=[], output=[], logic_gate={}
)
wt = WorkflowTemplate(name="temp", desc="desc", work_link=[step1, step2])
assert wt.name == "temp"
def test_workflow_template_error_duplicate_steps():
step1 = WorkflowTemplateStep(
step=1, node="node1", action="a1", desc="d1", input=[], output=[], logic_gate={}
)
step2 = WorkflowTemplateStep(
step=1, node="node2", action="a2", desc="d2", input=[], output=[], logic_gate={}
)
with pytest.raises(ValidationError, match="Step numbers in work_link must be unique"):
WorkflowTemplate(name="temp", desc="desc", work_link=[step1, step2])
@@ -1,57 +0,0 @@
import json
from unittest.mock import MagicMock, patch, mock_open
from pathlib import Path
from pretor.core.workflow.workflow_template_manager import WorkflowManager
def test_workflow_manager_init_success():
mock_file1 = MagicMock(spec=Path)
mock_file1.open = mock_open(read_data=json.dumps({"name": "test1", "desc": "desc1"}))
mock_file2 = MagicMock(spec=Path)
mock_file2.open = mock_open(read_data=json.dumps({"name": "test2", "desc": "desc2"}))
with patch("pretor.core.workflow.workflow_template_manager.Path.glob", return_value=[mock_file1, mock_file2]):
with patch("pretor.core.workflow.workflow_template_manager.WorkflowTemplateGenerator"):
manager = WorkflowManager()
assert manager.workflow_templates_registry == {"test1": "desc1", "test2": "desc2"}
def test_workflow_manager_init_json_error():
mock_file1 = MagicMock(spec=Path)
mock_file1.open = mock_open(read_data="{invalid_json}")
with patch("pretor.core.workflow.workflow_template_manager.Path.glob", return_value=[mock_file1]):
with patch("pretor.core.workflow.workflow_template_manager.logger") as mock_logger:
with patch("pretor.core.workflow.workflow_template_manager.WorkflowTemplateGenerator"):
manager = WorkflowManager()
assert manager.workflow_templates_registry == {}
mock_logger.warning.assert_called_once()
assert "不是json文件或格式错误" in mock_logger.warning.call_args[0][0]
from pretor.core.workflow.workflow_template_generator.workflow_template import WorkflowTemplate
@patch("pretor.core.workflow.workflow_template_manager.WorkflowTemplateGenerator")
@patch("pretor.core.workflow.workflow_template_manager.Path.glob", return_value=[])
def test_generate_workflow_template_success(mock_glob, mock_generator_cls):
manager = WorkflowManager()
mock_template = MagicMock(spec=WorkflowTemplate)
mock_template.name = "name"
mock_template.desc = "desc"
mock_generator_cls.return_value.generate_workflow_template.return_value = mock_template
manager.generate_workflow_template(mock_template)
mock_generator_cls.return_value.generate_workflow_template.assert_called_once_with(workflow_template=mock_template)
assert manager.workflow_templates_registry["name"] == "desc"
@patch("pretor.core.workflow.workflow_template_manager.WorkflowTemplateGenerator")
@patch("pretor.core.workflow.workflow_template_manager.Path.glob", return_value=[])
@patch("pretor.core.workflow.workflow_template_manager.logger")
def test_generate_workflow_template_exception(mock_logger, mock_glob, mock_generator_cls):
mock_generator_cls.return_value.generate_workflow_template.side_effect = Exception("error")
manager = WorkflowManager()
mock_template = MagicMock(spec=WorkflowTemplate)
manager.generate_workflow_template(mock_template)
mock_logger.exception.assert_called_once_with("Failed to generate workflow template")
+43 -8
View File
@@ -1,5 +1,11 @@
import pytest import pytest
from pretor.core.workflow.workflow import WorkStep, PretorWorkflow, WorkflowStatus, LogicGate from pretor.core.workflow.workflow import (
WorkStep,
PretorWorkflow,
WorkflowStatus,
LogicGate,
)
def test_work_step(): def test_work_step():
ws = WorkStep( ws = WorkStep(
@@ -7,7 +13,7 @@ def test_work_step():
name="step1", name="step1",
node="control_node", node="control_node",
action="coding", action="coding",
desc="Write some code" desc="Write some code",
) )
assert ws.step == 1 assert ws.step == 1
assert ws.name == "step1" assert ws.name == "step1"
@@ -16,30 +22,59 @@ def test_work_step():
assert ws.desc == "Write some code" assert ws.desc == "Write some code"
assert ws.status == "waiting" assert ws.status == "waiting"
def test_pretor_workflow_validation_success(): def test_pretor_workflow_validation_success():
ws1 = WorkStep(step=1, name="s1", node="control_node", action="a1", desc="d1") ws1 = WorkStep(step=1, name="s1", node="control_node", action="a1", desc="d1")
ws2 = WorkStep(step=2, name="s2", node="supervisory_node", action="a2", desc="d2") ws2 = WorkStep(step=2, name="s2", node="supervisory_node", action="a2", desc="d2")
wf = PretorWorkflow(title="wf1", work_link=[ws1, ws2], trace_id="t", event_info={"platform":"a", "user_name":"b"}) wf = PretorWorkflow(
title="wf1",
work_link=[ws1, ws2],
trace_id="t",
event_info={"platform": "a", "user_name": "b"},
)
assert wf.title == "wf1" assert wf.title == "wf1"
def test_pretor_workflow_validation_error_step_discontinuous(): def test_pretor_workflow_validation_error_step_discontinuous():
ws1 = WorkStep(step=1, name="s1", node="control_node", action="a1", desc="d1") ws1 = WorkStep(step=1, name="s1", node="control_node", action="a1", desc="d1")
ws2 = WorkStep(step=3, name="s3", node="supervisory_node", action="a2", desc="d2") ws2 = WorkStep(step=3, name="s3", node="supervisory_node", action="a2", desc="d2")
with pytest.raises(ValueError, match="工作链步数不连续"): with pytest.raises(ValueError, match="工作链步数不连续"):
PretorWorkflow(title="wf1", work_link=[ws1, ws2], trace_id="t", event_info={"platform":"a", "user_name":"b"}) PretorWorkflow(
title="wf1",
work_link=[ws1, ws2],
trace_id="t",
event_info={"platform": "a", "user_name": "b"},
)
def test_pretor_workflow_validation_error_jump_out_of_bounds(): def test_pretor_workflow_validation_error_jump_out_of_bounds():
lg = LogicGate(if_fail="jump_to_step_3", if_pass="continue") lg = LogicGate(if_fail="jump_to_step_3", if_pass="continue")
ws1 = WorkStep(step=1, name="s1", node="control_node", action="a1", desc="d1", logic_gate=lg) ws1 = WorkStep(
step=1, name="s1", node="control_node", action="a1", desc="d1", logic_gate=lg
)
ws2 = WorkStep(step=2, name="s2", node="supervisory_node", action="a2", desc="d2") ws2 = WorkStep(step=2, name="s2", node="supervisory_node", action="a2", desc="d2")
with pytest.raises(ValueError, match="跳转目标 Step 3 越界了"): with pytest.raises(ValueError, match="跳转目标 Step 3 越界了"):
PretorWorkflow(title="wf1", work_link=[ws1, ws2], trace_id="t", event_info={"platform":"a", "user_name":"b"}) PretorWorkflow(
title="wf1",
work_link=[ws1, ws2],
trace_id="t",
event_info={"platform": "a", "user_name": "b"},
)
def test_pretor_workflow_validation_error_jump_format_error(): def test_pretor_workflow_validation_error_jump_format_error():
lg = LogicGate(if_fail="jump_to_step_invalid", if_pass="continue") lg = LogicGate(if_fail="jump_to_step_invalid", if_pass="continue")
ws1 = WorkStep(step=1, name="s1", node="control_node", action="a1", desc="d1", logic_gate=lg) ws1 = WorkStep(
step=1, name="s1", node="control_node", action="a1", desc="d1", logic_gate=lg
)
with pytest.raises(ValueError, match="LogicGate 格式错误"): with pytest.raises(ValueError, match="LogicGate 格式错误"):
PretorWorkflow(title="wf1", work_link=[ws1], trace_id="t", event_info={"platform":"a", "user_name":"b"}) PretorWorkflow(
title="wf1",
work_link=[ws1],
trace_id="t",
event_info={"platform": "a", "user_name": "b"},
)
def test_workflow_status(): def test_workflow_status():
status = WorkflowStatus() status = WorkflowStatus()
@@ -9,14 +9,16 @@ real_import = builtins.__import__
def mock_import(name, globals=None, locals=None, fromlist=(), level=0): def mock_import(name, globals=None, locals=None, fromlist=(), level=0):
if name == 'ray': if name == "ray":
mock_ray = MagicMock() mock_ray = MagicMock()
def mock_remote(*args, **kwargs): def mock_remote(*args, **kwargs):
if len(args) == 1 and callable(args[0]): if len(args) == 1 and callable(args[0]):
return args[0] return args[0]
def decorator(cls): def decorator(cls):
return cls return cls
return decorator return decorator
mock_ray.remote = mock_remote mock_ray.remote = mock_remote
@@ -26,16 +28,19 @@ def mock_import(name, globals=None, locals=None, fromlist=(), level=0):
builtins.__import__ = mock_import builtins.__import__ = mock_import
for mod in list(sys.modules.keys()): for mod in list(sys.modules.keys()):
if 'pretor.core.workflow.workflow_runner' in mod or 'ray' in mod: if "pretor.core.workflow_running_engine.workflow_runner" in mod or "ray" in mod:
del sys.modules[mod] del sys.modules[mod]
from pretor.core.workflow.workflow_runner import WorkflowEngine, WorkflowRunningEngine from pretor.core.workflow_running_engine.workflow_runner import ( # noqa: E402
WorkflowEngine,
WorkflowRunningEngine,
)
builtins.__import__ = real_import builtins.__import__ = real_import
@pytest.fixture @pytest.fixture
def mock_ray(): def mock_ray():
with patch("pretor.core.workflow.workflow_runner.ray") as mock_ray: with patch("pretor.core.workflow_running_engine.workflow_runner.ray") as mock_ray:
mock_ray.get = lambda x: x mock_ray.get = lambda x: x
yield mock_ray yield mock_ray
@@ -91,7 +96,9 @@ async def test_workflow_engine_run():
engine = WorkflowEngine(mock_wf, mock_conscious, mock_control, mock_supervisor) engine = WorkflowEngine(mock_wf, mock_conscious, mock_control, mock_supervisor)
with patch("pretor.core.workflow.workflow_runner.ray") as mock_ray_patch: with patch(
"pretor.core.workflow_running_engine.workflow_runner.ray"
) as mock_ray_patch:
mock_gsm = MagicMock() mock_gsm = MagicMock()
mock_ray_patch.get_actor.return_value = mock_gsm mock_ray_patch.get_actor.return_value = mock_gsm
await engine.run() await engine.run()
@@ -141,22 +148,36 @@ async def test_workflow_running_engine_runner():
user_id="test_user", user_id="test_user",
user_name="test_user", user_name="test_user",
message="test_message", message="test_message",
context={"workflow_template": "test_template"} context={},
) )
await engine.workflow_queue.put(mock_event) await engine.workflow_queue.put(mock_event)
# Mock the global_state_machine get_skill_list.remote method properly # Mock the global_state_machine get_skill_list.remote method properly
mock_gsm = MagicMock() mock_gsm = MagicMock()
mock_gsm.list_individuals.remote = AsyncMock(return_value={"test_skill": {"agent_type": "skill_individual", "agent_name": "TestSkill", "description": "desc"}}) mock_gsm.list_individuals.remote = AsyncMock(
return_value={
"test_skill": {
"agent_type": "skill_individual",
"agent_name": "TestSkill",
"description": "desc",
}
}
)
engine.global_state_machine = mock_gsm engine.global_state_machine = mock_gsm
with patch("pretor.core.workflow.workflow_runner.WorkflowEngine") as mock_wf_engine_cls, patch("builtins.open", new_callable=MagicMock) as mock_open, \ with (
patch("pretor.core.workflow.workflow_runner.ray_actor_hook") as mock_hook: patch(
"pretor.core.workflow_running_engine.workflow_runner.WorkflowEngine"
) as mock_wf_engine_cls,
patch("builtins.open", new_callable=MagicMock) as mock_open,
patch(
"pretor.core.workflow_running_engine.workflow_runner.ray_actor_hook"
) as mock_hook,
):
# Instead of patching hook, we inject it directly # Instead of patching hook, we inject it directly
# engine.global_state_machine = AsyncMock() # engine.global_state_machine = AsyncMock()
mock_open.return_value.__enter__.return_value.read.return_value = '{}' mock_open.return_value.__enter__.return_value.read.return_value = "{}"
mock_gwm = MagicMock() mock_gwm = MagicMock()
mock_gwm.update_workflow.remote = AsyncMock() mock_gwm.update_workflow.remote = AsyncMock()
@@ -170,4 +191,7 @@ async def test_workflow_running_engine_runner():
await asyncio.sleep(0.05) await asyncio.sleep(0.05)
task.cancel() task.cancel()
mock_wf_engine_cls.assert_called_with(mock_wf, mock_consciousness, "control", "supervisor") mock_wf_engine_cls.assert_called_with(
mock_wf, mock_consciousness, "control", "supervisor"
)
# noqa: E402
+7 -4
View File
@@ -28,9 +28,9 @@ sys.modules["passlib"] = MagicMock()
sys.modules["passlib.context"] = MagicMock() sys.modules["passlib.context"] = MagicMock()
sys.modules["pretor.core.database.table.user"] = MagicMock() sys.modules["pretor.core.database.table.user"] = MagicMock()
import pytest import pytest # noqa: E402
import jwt import jwt # noqa: E402
from pretor.utils.access import Accessor from pretor.utils.access import Accessor # noqa: E402
def test_decode_token_success(): def test_decode_token_success():
@@ -55,6 +55,7 @@ def test_decode_token_expired():
token = "expired.token.here" token = "expired.token.here"
from fastapi import HTTPException from fastapi import HTTPException
with patch("jwt.decode", side_effect=jwt.ExpiredSignatureError): with patch("jwt.decode", side_effect=jwt.ExpiredSignatureError):
with patch("pretor.utils.access.HTTPException", HTTPException): with patch("pretor.utils.access.HTTPException", HTTPException):
with pytest.raises(HTTPException) as excinfo: with pytest.raises(HTTPException) as excinfo:
@@ -69,6 +70,7 @@ def test_decode_token_invalid():
token = "invalid.token.here" token = "invalid.token.here"
from fastapi import HTTPException from fastapi import HTTPException
with patch("jwt.decode", side_effect=jwt.InvalidTokenError): with patch("jwt.decode", side_effect=jwt.InvalidTokenError):
with patch("pretor.utils.access.HTTPException", HTTPException): with patch("pretor.utils.access.HTTPException", HTTPException):
with pytest.raises(HTTPException) as excinfo: with pytest.raises(HTTPException) as excinfo:
@@ -93,4 +95,5 @@ def test_decode_token_validation_error():
Accessor._decode_token(token) Accessor._decode_token(token)
assert excinfo.value.status_code == 401 assert excinfo.value.status_code == 401
assert excinfo.value.detail == "无效的认证凭证" assert excinfo.value.detail == "无效的认证凭证"
# noqa: E402