chore: initial commit for Pretor v0.1.0-alpha

正式发布 Pretor 平台的首个 alpha 版本。本项目旨在构建一个基于分布式架构的多智能体协同工作流水线。

核心功能实现:
1. 建立基于 BaseIndividual 的动态插件加载机制。
2. 实现三类核心 worker_individual 子个体。
3. 集成 Ray 框架支持分布式集群调度。
4. 基于 PostgreSQL 的全量持久化存储方案。
5. 提供完整的 FastAPI 后端与 React 前端交互界面。
This commit is contained in:
2026-04-29 10:09:07 +08:00
commit d84212f780
163 changed files with 19251 additions and 0 deletions
@@ -0,0 +1,14 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@@ -0,0 +1,166 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import ray
from pretor.core.global_state_machine.provider_manager import ProviderManager
from pretor.core.global_state_machine.tool_manager import GlobalToolManager
from typing import Dict
from pretor.core.database.postgres import PostgresDatabase
from pretor.api.platform.event import PretorEvent
import asyncio
from pretor.core.workflow.workflow import PretorWorkflow
from pretor.core.workflow.workflow_template_manager import WorkflowManager
from pretor.core.global_state_machine.skill_manager import GlobalSkillManager
from pretor.core.global_state_machine.individual_manager import GlobalIndividualManager
@ray.remote
class GlobalStateMachine:
def __init__(self, postgres_database: PostgresDatabase):
import sys
print("GSM __init__ START", file=sys.stderr, flush=True)
self.event_dict: Dict[str, PretorEvent] = {}
print(" event_dict done", file=sys.stderr, flush=True)
self._global_provider_manager = ProviderManager(postgres_database)
print(" provider_manager done", file=sys.stderr, flush=True)
self._global_tool_manager = GlobalToolManager()
print(" tool_manager done", file=sys.stderr, flush=True)
self._global_workflow_template_manager = WorkflowManager()
print(" workflow_template_manager done", file=sys.stderr, flush=True)
self._global_skill_manager = GlobalSkillManager()
print(" skill_manager done", file=sys.stderr, flush=True)
self._global_individual_manager = GlobalIndividualManager()
print(" individual_manager done", file=sys.stderr, flush=True)
self.postgres_database = postgres_database
print("GSM __init__ DONE", file=sys.stderr, flush=True)
async def init_state_machine(self):
await self._global_provider_manager.init_provider_register(self.postgres_database)
await self._global_individual_manager.init_individual_register(self.postgres_database)
async def add_provider_wrap(self, provider_type, provider_title, provider_url, provider_apikey, provider_owner):
return await self._global_provider_manager.add_provider(
provider_type=provider_type,
provider_title=provider_title,
provider_url=provider_url,
provider_apikey=provider_apikey,
provider_owner=provider_owner,
postgres_database=self.postgres_database
)
# Provider Manager Methods
def get_provider_list(self):
return self._global_provider_manager.get_provider_list()
def get_provider(self, provider_title):
return self._global_provider_manager.get_provider(provider_title)
async def delete_provider(self, provider_title: str):
return await self._global_provider_manager.delete_provider(provider_title, self.postgres_database)
# Tool Manager Methods
def get_tool_mapper(self):
return self._global_tool_manager.tool_mapper
def get_tool_list(self, agent_name: str):
# get_tool_list didn't actually exist on tool_manager, let's implement it to return the tools
# for a specific agent name (or scope)
tools = self._global_tool_manager.tool_mapper.get(agent_name, {})
# also include default tools
default_tools = self._global_tool_manager.tool_mapper.get("default", {})
merged_tools = {**default_tools, **tools}
return merged_tools
# Workflow Template Manager Methods
def get_all_workflow_templates(self):
return self._global_workflow_template_manager.get_all_workflow_templates()
def add_workflow_template(self, template_name: str, workflow_template):
return self._global_workflow_template_manager.add_workflow_template(template_name, workflow_template)
def delete_workflow_template(self, template_name: str):
return self._global_workflow_template_manager.delete_workflow_template(template_name)
def generate_workflow_template(self, workflow_template):
return self._global_workflow_template_manager.generate_workflow_template(workflow_template)
# Skill Manager Methods
def add_skill(self, skill_name: str):
return self._global_skill_manager.add_skill(skill_name)
def get_skill_list(self):
return self._global_skill_manager.get_skill_list()
def remove_skill(self, skill_name: str):
return self._global_skill_manager.remove_skill(skill_name)
# Individual Manager Methods
def add_individual(self, agent_id: str, config):
return self._global_individual_manager.add_individual(agent_id, config)
def get_individual(self, agent_id: str):
return self._global_individual_manager.get_individual(agent_id)
def remove_individual(self, agent_id: str):
return self._global_individual_manager.remove_individual(agent_id)
def list_individuals(self):
return self._global_individual_manager.list_individuals()
###以下方法为event_dict方法
def add_event(self, event: PretorEvent) -> None:
event.pending_queue = asyncio.Queue()
event.receive_queue = asyncio.Queue()
self.event_dict[event.trace_id] = event
def delete_event(self, trace_id: str) -> None:
del self.event_dict[trace_id]
def get_event(self, trace_id: str) -> PretorEvent:
return self.event_dict.get(trace_id, None)
def update_attachment(self, trace_id: str, attachment: Dict[str, str]) -> None:
self.event_dict[trace_id].attachment = attachment
def update_workflow(self, trace_id: str, workflow: PretorWorkflow) -> None:
self.event_dict[trace_id].workflow = workflow
def get_workflow(self, trace_id: str) -> PretorWorkflow:
return self.event_dict[trace_id].workflow
def list_events(self) -> list[dict]:
result = []
for trace_id, event in self.event_dict.items():
workflow_title = event.workflow.title if event.workflow else None
workflow_status = event.workflow.status.status if event.workflow and event.workflow.status else None
result.append({
"event_id": trace_id,
"workflow_title": workflow_title,
"status": workflow_status,
"user_name": event.user_name,
"message": event.message,
})
return result
async def put_pending(self, trace_id, item) -> None:
await self.event_dict[trace_id].pending_queue.put(item)
async def get_pending(self, trace_id) -> str:
return await self.event_dict[trace_id].pending_queue.get()
async def put_received(self, trace_id, item) -> None:
await self.event_dict[trace_id].receive_queue.put(item)
async def get_received(self, trace_id) -> str:
return await self.event_dict[trace_id].receive_queue.get()
@@ -0,0 +1,62 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, Any
from pretor.utils.logger import get_logger
logger = get_logger('individual_manager')
class GlobalIndividualManager:
def __init__(self):
self._individuals: Dict[str, Dict[str, Any]] = {}
async def init_individual_register(self, postgres) -> None:
try:
try:
individuals = await postgres.get_all_worker_individual.remote()
for ind in individuals:
agent_id = getattr(ind, 'agent_id', None)
if agent_id:
self._individuals[agent_id] = ind.model_dump() if hasattr(ind, 'model_dump') else dict(ind)
logger.info(f"成功从数据库拉取了 {len(self._individuals)} 个 Worker Individual 配置。")
except AttributeError:
logger.warning("数据库中 get_all_worker_individual 方法未实现,跳过全量加载。可以在将来完善该接口。")
except Exception as e:
# 捕获因 Ray 调用目标方法不存在引发的异常
if "has no attribute 'get_all_worker_individual'" in str(e):
logger.warning("数据库 individual_database 中缺少 get_all_worker_individual 方法,无法全量拉取。")
else:
raise e
except Exception as e:
logger.error(f"从数据库拉取 Worker Individual 配置失败: {e}")
def add_individual(self, agent_id: str, config: Dict[str, Any]) -> None:
"""
注册一个 worker individual
config 可以包含 type, prompt, provider_title, model_id 等
"""
config["agent_id"] = agent_id
self._individuals[agent_id] = config
def get_individual(self, agent_id: str) -> Dict[str, Any]:
"""
获取一个 worker individual 的配置
"""
return self._individuals.get(agent_id, None)
def remove_individual(self, agent_id: str) -> None:
if agent_id in self._individuals:
del self._individuals[agent_id]
def list_individuals(self) -> Dict[str, Dict[str, Any]]:
return self._individuals
@@ -0,0 +1,19 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pretor.core.global_state_machine.model_provider.base_provider import Provider, ProviderArgs
from pretor.core.global_state_machine.model_provider.openai_provider import OpenAIProvider
from pretor.core.global_state_machine.model_provider.claude_provider import ClaudeProvider
from pretor.core.global_state_machine.model_provider.deepseek_provider import DeepseekProvider
__all__ = ["Provider", "ProviderArgs", "OpenAIProvider", "ClaudeProvider", "DeepseekProvider"]
@@ -0,0 +1,96 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
from pydantic import BaseModel
from typing import List
from enum import Enum
class ProviderStatus(str, Enum):
UP = "up"
DOWN = "down"
class Provider(BaseModel):
provider_title: str
provider_url: str
provider_apikey: str
provider_models: List[str]
provider_type: str
provider_owner: str | None = None
provider_status: ProviderStatus = ProviderStatus.UP
class ProviderArgs(BaseModel):
provider_title: str
provider_url: str
provider_apikey: str
provider_owner: str
class BaseProvider(ABC):
@staticmethod
@abstractmethod
async def create_provider(provider_args: ProviderArgs) -> Provider:
"""
创建一个供应商,传入provider_args参数,打包为一个Provider对象
Args:
provider_args: 参数包,包含以下几个参数
provider_title: 供应商的别名
provider_url: 供应商的url
provider_apikey:供应商的apikey
Returns:
返回一个Provider对象,由provider_manager管理
"""
pass
@staticmethod
@abstractmethod
async def _load_models(provider_args: ProviderArgs) -> List[str]:
"""
加载模型列表
base_provider的字类应当按照供应商的api标准,向供应商的接口发送http请求从而或者供应商所提供的模型列表
Args:
provider_args: 参数包,包含以下几个参数
provider_title: 供应商的别名
provider_url: 供应商的url
provider_apikey:供应商的apikey
Returns:
返回一个列表,为http请求获取的模型列表
"""
pass
@staticmethod
@abstractmethod
def _return_provider(provider_args: ProviderArgs, provider_models: List[str]) -> Provider:
"""
包装Provider对象并返回
将provider_args和_load_models获取的方法包装为provider对象
Args:
provider_args: 参数包,包含以下几个参数
provider_title: 供应商的别名
provider_url: 供应商的url
provider_apikey:供应商的apikey
provider_models: 模型列表,为该供应商包含的模型列表
Returns:
返回一个Provider对象
"""
pass
@@ -0,0 +1,60 @@
from pretor.utils.retry import retry_on_retryable_error
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pretor.core.global_state_machine.model_provider.base_provider import BaseProvider, Provider, ProviderArgs
import httpx
from typing import List
class ClaudeProvider(BaseProvider):
@staticmethod
async def create_provider(provider_args: ProviderArgs) -> Provider:
provider_models: List[str] = await ClaudeProvider._load_models(provider_args)
provider: Provider = ClaudeProvider._return_provider(provider_args, provider_models)
return provider
@staticmethod
@retry_on_retryable_error()
async def _load_models(provider_args: ProviderArgs) -> List[str]:
# Anthropic 官方需要 version 头
headers = {
"x-api-key": provider_args.provider_apikey,
"anthropic-version": "2023-06-01",
"Content-Type": "application/json"
}
# 如果是官方 API,通常使用 /v1/models (如果支持)
# 注意:很多时候 Anthropic 并不返回完整列表,如果请求失败,建议返回硬编码的常用模型
url = f"{provider_args.provider_url.rstrip('/')}/v1/models"
try:
async with httpx.AsyncClient(timeout=10.0) as client:
response = await client.get(url, headers=headers)
if response.status_code == 200:
data = response.json()
model_ids = [m["id"] for m in data.get("data", [])]
return sorted(model_ids)
else:
# 如果官方列表接口不可用,fallback 到已知常用模型
return ["claude-3-5-sonnet-20240620", "claude-3-opus-20240229", "claude-3-haiku-20240307"]
except Exception as e:
print(f"[{provider_args.provider_title}] 获取 Claude 模型列表错误: {e}")
return []
@staticmethod
def _return_provider(provider_args: ProviderArgs, provider_models: List[str]) -> Provider:
return Provider(provider_title=provider_args.provider_title,
provider_apikey=provider_args.provider_apikey,
provider_url=provider_args.provider_url,
provider_models=provider_models,
provider_type="claude")
@@ -0,0 +1,59 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pretor.utils.retry import retry_on_retryable_error
from pretor.core.global_state_machine.model_provider.base_provider import BaseProvider, Provider, ProviderArgs
import httpx
from typing import List
class DeepseekProvider(BaseProvider):
@staticmethod
async def create_provider(provider_args: ProviderArgs) -> Provider:
provider_models: List[str] = await DeepseekProvider._load_models(provider_args)
provider: Provider = DeepseekProvider._return_provider(provider_args, provider_models)
return provider
@staticmethod
@retry_on_retryable_error()
async def _load_models(provider_args: ProviderArgs) -> List[str]:
headers = {
"Authorization": f"Bearer {provider_args.provider_apikey}",
"Content-Type": "application/json"
}
url = f"{provider_args.provider_url}/models" if "/v1" in provider_args.provider_url else f"{provider_args.provider_url}/v1/models"
try:
async with httpx.AsyncClient(timeout=10.0) as client:
response = await client.get(url, headers=headers)
if response.status_code != 200:
print(f"[{provider_args.provider_title}] 获取模型失败: {response.status_code}")
return []
data = response.json()
raw_models = data.get("data", [])
model_ids = [m["id"] for m in raw_models]
return sorted(model_ids)
except httpx.RequestError as e:
from pretor.utils.error import RetryableError
print(f"[{provider_args.provider_title}] 网络请求异常: {e}")
raise RetryableError(f"[{provider_args.provider_title}] 网络请求异常: {e}") from e
except Exception as e:
print(f"[{provider_args.provider_title}] 解析模型列表时发生错误: {e}")
return []
@staticmethod
def _return_provider(provider_args: ProviderArgs, provider_models: List[str]) -> Provider:
return Provider(provider_title=provider_args.provider_title,
provider_apikey=provider_args.provider_apikey,
provider_url=provider_args.provider_url,
provider_models=provider_models,
provider_type="deepseek")
@@ -0,0 +1,59 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pretor.utils.retry import retry_on_retryable_error
from pretor.core.global_state_machine.model_provider.base_provider import BaseProvider, Provider, ProviderArgs
import httpx
from typing import List
class OpenAIProvider(BaseProvider):
@staticmethod
async def create_provider(provider_args: ProviderArgs) -> Provider:
provider_models: List[str] = await OpenAIProvider._load_models(provider_args)
provider: Provider = OpenAIProvider._return_provider(provider_args, provider_models)
return provider
@staticmethod
@retry_on_retryable_error()
async def _load_models(provider_args: ProviderArgs) -> List[str]:
headers = {
"Authorization": f"Bearer {provider_args.provider_apikey}",
"Content-Type": "application/json"
}
url = f"{provider_args.provider_url}/models" if "/v1" in provider_args.provider_url else f"{provider_args.provider_url}/v1/models"
try:
async with httpx.AsyncClient(timeout=10.0) as client:
response = await client.get(url, headers=headers)
if response.status_code != 200:
print(f"[{provider_args.provider_title}] 获取模型失败: {response.status_code}")
return []
data = response.json()
raw_models = data.get("data", [])
model_ids = [m["id"] for m in raw_models]
return sorted(model_ids)
except httpx.RequestError as e:
from pretor.utils.error import RetryableError
print(f"[{provider_args.provider_title}] 网络请求异常: {e}")
raise RetryableError(f"[{provider_args.provider_title}] 网络请求异常: {e}") from e
except Exception as e:
print(f"[{provider_args.provider_title}] 解析模型列表时发生错误: {e}")
return []
@staticmethod
def _return_provider(provider_args: ProviderArgs, provider_models: List[str]) -> Provider:
return Provider(provider_title=provider_args.provider_title,
provider_apikey=provider_args.provider_apikey,
provider_url=provider_args.provider_url,
provider_models=provider_models,
provider_type="openai")
@@ -0,0 +1,86 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pretor.core.global_state_machine.model_provider import Provider, OpenAIProvider, ClaudeProvider, DeepseekProvider
from typing import Dict, Type
class ProviderManager:
"""
模型供应商管理器 (ProviderManager)。
负责维护不同的 LLM 协议适配器,提供从配置注册到 Agent 实例化的全生命周期管理。
"""
# --- 类属性显式标注 (IDE 友好) ---
provider_mapper: Dict[str, Type[Provider]]
"""协议映射表:键为协议名(如 'openai'),值为对应的 Provider 类。"""
provider_register: Dict[str, Provider]
"""供应商注册表:键为用户自定义别名,值为已实例化的 Provider 对象。"""
def __init__(self, postgres):
self.provider_mapper = {"openai": OpenAIProvider,
"claude": ClaudeProvider,
"deepseek": DeepseekProvider}
self.provider_register = {}
async def init_provider_register(self, postgres) -> None:
providers = await postgres.get_provider.remote()
for provider in providers:
self.provider_register[provider.provider_title] = provider
async def add_provider(self, provider_type, provider_title, provider_url, provider_apikey, provider_owner, postgres_database) -> None:
from pretor.core.global_state_machine.model_provider import ProviderArgs
from pretor.utils.logger import get_logger
logger = get_logger('provider_manager')
import httpx
provider_args: ProviderArgs = ProviderArgs(provider_title=provider_title,
provider_url=provider_url,
provider_apikey=provider_apikey,
provider_owner=provider_owner)
try:
import ulid
provider_class = self.provider_mapper.get(provider_type, None)
if provider_class is None:
logger.warning(f"Provider type {provider_type} is not supported.")
return None
provider: Provider = await provider_class.create_provider(provider_args)
provider.provider_owner = provider_owner
self.provider_register[provider_title] = provider
await postgres_database.add_provider_db.remote(
provider_id=str(ulid.ULID()),
provider_title=provider.provider_title,
provider_url=provider.provider_url,
provider_apikey=provider.provider_apikey,
provider_models=provider.provider_models,
provider_type=provider.provider_type,
provider_owner=provider.provider_owner)
logger.info(f"已添加适配器{provider_title}")
except httpx.RequestError as e:
from pretor.utils.error import RetryableError
logger.warning(f"[{provider_args.provider_title}] 网络请求异常: {e}")
raise RetryableError(f"[{provider_args.provider_title}] 网络请求异常: {e}") from e
except Exception as e:
logger.warning(f"[{provider_args.provider_title}] 解析模型列表时发生错误: {e}")
def get_provider_list(self):
return self.provider_register
def get_provider(self, provider_title):
return self.provider_register.get(provider_title)
async def delete_provider(self, provider_title: str, postgres_database) -> None:
if provider_title in self.provider_register:
provider = self.provider_register[provider_title]
await postgres_database.delete_provider_db.remote( provider_id=provider.provider_id)
del self.provider_register[provider_title]
@@ -0,0 +1,75 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple, Dict
from collections import defaultdict
import pathlib
import json
class GlobalSkillManager:
skill_mapper = Dict[str,Tuple[str]]
"""skill的存储表"""
def __init__(self):
self.skill_mapper = defaultdict(tuple)
import os
skill_plugin_dir = pathlib.Path(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "plugin", "skill")))
if not skill_plugin_dir.exists() or not skill_plugin_dir.is_dir():
return
for item in skill_plugin_dir.iterdir():
if item.is_dir() and not item.name.startswith((".", "__")):
json_path = item / "skill.json" # 拼接文件路径
if json_path.exists():
try:
with open(json_path, "r", encoding="utf-8") as f:
skill = json.load(f)
# 提取并映射
name = skill.get("name")
if name:
self.skill_mapper[name] = (
skill.get("description", ""),
skill.get("instructions", "")
)
except (json.JSONDecodeError, OSError) as e:
print(f"警告: 加载插件 {item.name} 失败: {e}")
def add_skill(self, skill_name: str) -> None:
"""Add a skill to the manager by reading its skill.json from the path"""
import os
skill_plugin_dir = pathlib.Path(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "plugin", "skill")))
item = skill_plugin_dir / skill_name
if item.is_dir() and not item.name.startswith((".", "__")):
json_path = item / "skill.json"
if json_path.exists():
try:
with open(json_path, "r", encoding="utf-8") as f:
skill = json.load(f)
name = skill.get("name")
if name:
self.skill_mapper[name] = (
skill.get("description", ""),
skill.get("instructions", "")
)
except (json.JSONDecodeError, OSError) as e:
print(f"警告: 加载插件 {item.name} 失败: {e}")
def get_skill_list(self) -> dict:
"""Return all skills currently loaded."""
return self.skill_mapper
def remove_skill(self, skill_name: str) -> None:
"""Remove a skill from the manager mapping."""
if skill_name in self.skill_mapper:
del self.skill_mapper[skill_name]
@@ -0,0 +1,52 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pathlib
import importlib
import inspect
from collections import defaultdict
from pretor.plugin.tool_plugin.base_tool import BaseToolData
from typing import Dict, Type
from pretor.utils.logger import get_logger
logger = get_logger('tool_manager')
class GlobalToolManager:
tool_mapper: Dict[str, Dict[str, Type[BaseToolData]]]
def __init__(self):
self.tool_mapper = defaultdict(dict)
tool_plugin_dir = pathlib.Path(__file__).parent.parent.parent / "plugin" / "tool_plugin"
if not tool_plugin_dir.exists() or not tool_plugin_dir.is_dir():
return
for item in tool_plugin_dir.iterdir():
if item.is_dir() and not item.name.startswith("__"):
plugin_name = item.name
module_name = f"pretor.plugin.tool_plugin.{plugin_name}"
try:
module = importlib.import_module(module_name)
for name, obj in inspect.getmembers(module, inspect.isclass):
if issubclass(obj, BaseToolData) and obj is not BaseToolData:
# It's a valid tool class
action_scopes = obj.model_fields.get("action_scope").default
if not action_scopes:
self.tool_mapper["default"][plugin_name] = obj
else:
for scope in action_scopes:
self.tool_mapper[scope][plugin_name] = obj
except Exception as e:
logger.warning(f"Failed to load tool plugin {plugin_name}: {e}")