feat(system):优化后端

1.新增后端测试
2.增加了后端的加密
3.增加了i18n(国际化)
This commit is contained in:
2026-05-31 15:39:34 +00:00
parent affe460180
commit 99520c69d7
118 changed files with 8174 additions and 1491 deletions
@@ -13,47 +13,129 @@
# limitations under the License.
import ray
from kilostar.core.global_state_machine.provider_manager import ProviderManager
from kilostar.core.global_state_machine.tool_manager import GlobalToolManager
from kilostar.core.postgres_database import PostgresDatabase
from kilostar.core.global_state_machine.skill_manager import GlobalSkillManager
from typing import Any, Dict, List, Optional, Tuple
from kilostar.core.global_state_machine.individual_manager import (
GlobalIndividualManager,
)
from kilostar.core.global_state_machine.provider_manager import ProviderManager
from kilostar.core.global_state_machine.skill_manager import GlobalSkillManager
from kilostar.core.global_state_machine.tool_manager import GlobalToolManager
from kilostar.core.global_state_machine.gsm_snapshot import GSMSnapshot
from kilostar.core.postgres_database import PostgresDatabase
@ray.remote
class GlobalStateMachine:
"""全局状态机 Actor,统一持有 Provider/Tool/Skill/Individual 四个注册表。
"""全局状态机 Actor,统一持有 Provider/Tool/Skill/Individual/MCP/CustomToolset 注册表。
其它 Actor 通过 ``ray.get_actor("global_state_machine")`` 拿到本实例,
再调用本类暴露的方法来读写各注册表,避免每个 Actor 各自维护一份状态
所有持久化都走 PostgresDatabase;启动时由 ``init_state_machine`` 一次性把
Provider / MCP / Tool config / Custom toolset 拉到内存,保证后续读操作零等待
"""
def __init__(self, postgres_database: PostgresDatabase):
    """Create the sub-managers and the empty in-memory registries.

    A progress marker is written to stderr (and flushed) after each
    construction step. The registries start empty; ``init_state_machine``
    is what fills them from the database.

    Args:
        postgres_database: Database handle, stored on the instance and also
            handed to ``ProviderManager``.
    """
    import sys

    def _trace(message: str) -> None:
        # Flushed on every call so the marker is visible even if a later step hangs.
        print(message, file=sys.stderr, flush=True)

    _trace("GSM __init__ START")
    _trace(" event_dict done")
    self._global_provider_manager = ProviderManager(postgres_database)
    _trace(" provider_manager done")
    self._global_tool_manager = GlobalToolManager()
    _trace(" tool_manager done")
    self._global_skill_manager = GlobalSkillManager()
    _trace(" skill_manager done")
    self._global_individual_manager = GlobalIndividualManager()
    _trace(" individual_manager done")

    # In-memory registries, empty until init_state_machine loads them.
    self._mcp_servers: Dict[str, Dict[str, Any]] = {}
    self._tool_configs: Dict[str, Dict[str, Any]] = {}
    self._custom_toolsets: Dict[str, Dict[str, Any]] = {}

    # Snapshot bookkeeping: _publish_snapshot bumps the version and stores
    # a fresh object-store reference here.
    self._config_version: int = 0
    self._current_ref: Optional[ray.ObjectRef] = None

    self.postgres_database = postgres_database
    _trace("GSM __init__ DONE")
async def init_state_machine(self):
    """Load every registry from the database once, then publish a snapshot.

    Order: provider register, individual register, MCP servers, tool
    configs, custom toolsets. The tool manager then rebuilds its custom
    toolsets from the loaded definitions, and one snapshot is published.
    """
    db = self.postgres_database
    await self._global_provider_manager.init_provider_register(db)
    await self._global_individual_manager.init_individual_register(db)

    # MCP servers, keyed by server_id (the full row is kept as the value).
    mcp_rows = await db.list_mcp_servers_db.remote()
    servers = {}
    for mcp_row in mcp_rows:
        servers[mcp_row["server_id"]] = mcp_row
    self._mcp_servers = servers

    # Tool configs, keyed by tool_name (only the "config" column is kept).
    config_rows = await db.list_tool_configs_db.remote()
    configs = {}
    for config_row in config_rows:
        configs[config_row["tool_name"]] = config_row["config"]
    self._tool_configs = configs

    # Custom toolsets, keyed by toolset_id.
    toolset_rows = await db.list_custom_toolsets.remote()
    toolsets = {}
    for toolset_row in toolset_rows:
        toolsets[toolset_row["toolset_id"]] = toolset_row
    self._custom_toolsets = toolsets

    # Let the tool manager assemble the custom toolsets right away.
    self._global_tool_manager.rebuild_custom_toolsets(self._custom_toolsets)

    # First publication, so readers have something to fetch.
    self._publish_snapshot()
# ─── Snapshot 发布(Object Store 读路径) ────────────────────
def _build_snapshot(self) -> GSMSnapshot:
    """Pack the current in-memory state into a ``GSMSnapshot``.

    ``tool_funcs`` is flattened to ``{tool_name: callable}``: the per-scope
    buckets of the tool manager are merged, and when a name occurs in more
    than one scope the bucket iterated last wins.
    ``system_tools_by_scope`` keeps the per-scope list of tool names so a
    reader can rebuild the scope grouping on its side.

    Every registry is shallow-copied with ``dict(...)``; the contained
    values are shared with the live state.
    """
    tool_manager = self._global_tool_manager

    names_by_scope: Dict[str, List[str]] = {}
    merged_funcs: Dict[str, Any] = {}
    for scope_name, funcs_in_scope in tool_manager._tool_funcs.items():
        names_by_scope[scope_name] = [*funcs_in_scope]
        # update() iterates in insertion order, so later scopes overwrite earlier ones.
        merged_funcs.update(funcs_in_scope)

    mapper_copy = {}
    for scope_name, name_to_cls in tool_manager.tool_mapper.items():
        mapper_copy[scope_name] = dict(name_to_cls)

    return GSMSnapshot(
        version=self._config_version,
        providers=dict(self._global_provider_manager.provider_register),
        individuals=dict(self._global_individual_manager._individuals),
        mcp_servers=dict(self._mcp_servers),
        tool_configs=dict(self._tool_configs),
        custom_toolsets=dict(self._custom_toolsets),
        skills=dict(self._global_skill_manager.skill_mapper),
        tool_metadata=dict(tool_manager.tool_metadata),
        tool_funcs=merged_funcs,
        third_party_funcs=dict(tool_manager._third_party_funcs),
        tool_mapper=mapper_copy,
        system_tools_by_scope=names_by_scope,
    )
def _publish_snapshot(self) -> None:
    """Bump the config version and put a fresh snapshot into the Ray object store.

    The version is incremented first, so the snapshot built afterwards
    carries the new version number. The resulting reference replaces
    ``self._current_ref``.
    """
    self._config_version = self._config_version + 1
    snapshot = self._build_snapshot()
    self._current_ref = ray.put(snapshot)
async def current_config_ref(self) -> Tuple[int, ray.ObjectRef]:
    """Return ``(version, ObjectRef)`` for the latest published snapshot.

    Only the reference is returned, not the snapshot object; the caller
    resolves it on its own side. If nothing has been published yet, a
    snapshot is published first.
    """
    nothing_published = self._current_ref is None
    if nothing_published:
        self._publish_snapshot()
    return (self._config_version, self._current_ref)
async def current_version(self) -> int:
    """Return only the current config version, so readers can check cache freshness."""
    version = self._config_version
    return version
# ─── Provider ──────────────────────────────────────────────
async def add_provider_wrap(
self,
@@ -64,7 +146,7 @@ class GlobalStateMachine:
provider_owner,
):
"""新增一个模型 Provider:内存注册 + 数据库持久化一并完成。"""
return await self._global_provider_manager.add_provider(
result = await self._global_provider_manager.add_provider(
provider_type=provider_type,
provider_title=provider_title,
provider_url=provider_url,
@@ -72,8 +154,9 @@ class GlobalStateMachine:
provider_owner=provider_owner,
postgres_database=self.postgres_database,
)
self._publish_snapshot()
return result
# Provider Manager Methods
def get_provider_list(self):
    """Return the provider list as reported by the provider manager."""
    provider_manager = self._global_provider_manager
    return provider_manager.get_provider_list()
@@ -84,11 +167,14 @@ class GlobalStateMachine:
async def delete_provider(self, provider_title: str):
    """Delete a provider and publish a refreshed snapshot.

    The provider manager performs the deletion and receives the database
    handle. The stale early ``return`` is removed so the snapshot is
    actually republished and readers stop seeing the deleted provider.

    Args:
        provider_title: Title of the provider to delete.

    Returns:
        Whatever ``ProviderManager.delete_provider`` returns.
    """
    result = await self._global_provider_manager.delete_provider(
        provider_title, self.postgres_database
    )
    self._publish_snapshot()
    return result
# ─── Tool / Toolset ────────────────────────────────────────
# Tool Manager Methods
def get_tool_mapper(self):
    """Return the tool manager's ``tool_mapper`` object itself (not a copy)."""
    tool_manager = self._global_tool_manager
    return tool_manager.tool_mapper
@@ -96,37 +182,152 @@ class GlobalStateMachine:
def get_tool_list(self, agent_name: str):
    """Return the tools available to ``agent_name``.

    The result merges the ``"default"`` bucket with the agent's own bucket;
    on a name clash the agent-specific entry wins. A new dict is returned
    on every call.

    Args:
        agent_name: Key into the tool manager's ``tool_mapper``.
    """
    mapper = self._global_tool_manager.tool_mapper
    tools = mapper.get(agent_name, {})
    default_tools = mapper.get("default", {})
    # Single return: the duplicated, unreachable second return was removed.
    return {**default_tools, **tools}
def get_tool_categories(self):
    """Return tool metadata grouped several ways.

    Keys: ``system``, ``third_party``, ``by_category`` (with the fixed
    categories ``system``/``search``/``mcp``/``other``) and ``all``.
    """
    tool_manager = self._global_tool_manager
    system_tools = tool_manager.get_system_tools()
    third_party_tools = tool_manager.get_third_party_tools()
    by_category = {}
    for category in ("system", "search", "mcp", "other"):
        by_category[category] = tool_manager.get_tools_by_category(category)
    return {
        "system": system_tools,
        "third_party": third_party_tools,
        "by_category": by_category,
        "all": tool_manager.get_all_tools(),
    }
def get_toolsets_for_scope(self, scope: str) -> List[Any]:
    """Return the toolsets the tool manager provides for ``scope`` (pure delegation)."""
    tool_manager = self._global_tool_manager
    return tool_manager.get_toolsets_for_scope(scope)
# ─── MCP Server Registry ───────────────────────────────────
async def add_mcp_server(self, server_id: str, config: Dict[str, Any]) -> bool:
    """Register an MCP server config: database first, then memory, then republish.

    The value kept in memory is what the database upsert returned, not the
    ``config`` argument.

    Returns:
        Always ``True`` (a failing upsert propagates as an exception).
    """
    upsert = self.postgres_database.upsert_mcp_server
    stored_row = await upsert.remote(server_id, config)
    self._mcp_servers[server_id] = stored_row
    self._publish_snapshot()
    return True
def get_mcp_server(self, server_id: str) -> Optional[Dict[str, Any]]:
    """Return the in-memory config for ``server_id``, or ``None`` if unknown."""
    registry = self._mcp_servers
    return registry.get(server_id)
def list_mcp_servers(self) -> List[Dict[str, Any]]:
    """Return one dict per registered MCP server, each carrying a ``server_id`` key.

    The registry key is written first and the stored config is merged on
    top, so a ``server_id`` inside the config overrides the registry key.
    """
    listing: List[Dict[str, Any]] = []
    for server_id, server_cfg in self._mcp_servers.items():
        entry: Dict[str, Any] = {"server_id": server_id}
        entry.update(server_cfg)
        listing.append(entry)
    return listing
async def delete_mcp_server(self, server_id: str) -> bool:
    """Delete an MCP server from the database and from memory, then republish.

    The in-memory entry is dropped and the snapshot republished whatever
    the database reports.

    Returns:
        The database call's result coerced to ``bool``.
    """
    db_result = await self.postgres_database.delete_mcp_server_db.remote(server_id)
    if server_id in self._mcp_servers:
        del self._mcp_servers[server_id]
    self._publish_snapshot()
    return True if db_result else False
def get_mcp_server_configs(self) -> Dict[str, Dict[str, Any]]:
    """Return a shallow copy of the raw MCP server registry, keyed by server id."""
    registry_copy = dict(self._mcp_servers)
    return registry_copy
# ─── Tool Config(Tavily API key 等)─────────────────────
async def set_tool_config(self, tool_name: str, config: Dict[str, Any]) -> bool:
    """Replace the whole runtime config of a tool, then republish.

    The config is upserted in the database first; the value kept in memory
    is the ``"config"`` entry of the row the database returned.

    Returns:
        Always ``True`` (a failing upsert propagates as an exception).
    """
    upsert = self.postgres_database.upsert_tool_config
    stored_row = await upsert.remote(tool_name, config)
    self._tool_configs[tool_name] = stored_row["config"]
    self._publish_snapshot()
    return True
def get_tool_config(self, tool_name: str) -> Dict[str, Any]:
    """Return a shallow copy of a tool's config, or an empty dict if none is stored."""
    stored = self._tool_configs.get(tool_name, {})
    return dict(stored)
def list_tool_configs(self) -> Dict[str, Dict[str, Any]]:
    """Return a shallow copy of all tool configs.

    Nothing is masked here; per the original author's note, values may
    include sensitive fields and callers are expected to redact them.
    """
    configs_copy = dict(self._tool_configs)
    return configs_copy
async def delete_tool_config(self, tool_name: str) -> bool:
    """Delete a tool config from the database and from memory, then republish.

    Returns:
        The database call's result coerced to ``bool``.
    """
    db_result = await self.postgres_database.delete_tool_config_db.remote(tool_name)
    if tool_name in self._tool_configs:
        del self._tool_configs[tool_name]
    self._publish_snapshot()
    return True if db_result else False
# ─── Custom Toolset(用户自定义工具组)──────────────────
async def add_custom_toolset(
    self,
    toolset_id: str,
    name: str,
    tools: List[str],
    description: Optional[str] = None,
    owner_id: Optional[str] = None,
) -> Dict[str, Any]:
    """Create or update a user-defined toolset.

    Every tool name must satisfy the tool manager's ``is_third_party_tool``
    check. After validation the definition is upserted in the database,
    stored in memory, the tool manager rebuilds its custom toolsets, and a
    new snapshot is published.

    Returns:
        The row returned by the database upsert.

    Raises:
        ValueError: If any name in ``tools`` is not a third-party tool.
    """
    is_third_party = self._global_tool_manager.is_third_party_tool
    invalid = []
    for tool_name in tools:
        if not is_third_party(tool_name):
            invalid.append(tool_name)
    if invalid:
        raise ValueError(
            f"自定义工具组只允许包含第三方工具,以下不合法:{invalid}"
        )

    saved = await self.postgres_database.upsert_custom_toolset.remote(
        toolset_id=toolset_id,
        name=name,
        tools=list(tools),
        description=description,
        owner_id=owner_id,
    )
    self._custom_toolsets[toolset_id] = saved
    self._global_tool_manager.rebuild_custom_toolsets(self._custom_toolsets)
    self._publish_snapshot()
    return saved
def list_custom_toolsets(self) -> List[Dict[str, Any]]:
    """Return all custom toolset definitions as a new list."""
    definitions = [*self._custom_toolsets.values()]
    return definitions
def get_custom_toolset(self, toolset_id: str) -> Optional[Dict[str, Any]]:
    """Return the stored definition for ``toolset_id``, or ``None`` if unknown."""
    registry = self._custom_toolsets
    return registry.get(toolset_id)
async def delete_custom_toolset(self, toolset_id: str) -> bool:
    """Delete a custom toolset from the database and memory, rebuild, republish.

    Returns:
        The database call's result coerced to ``bool``.
    """
    db_result = await self.postgres_database.delete_custom_toolset.remote(toolset_id)
    if toolset_id in self._custom_toolsets:
        del self._custom_toolsets[toolset_id]
    # The tool manager must rebuild so it stops serving the removed toolset.
    self._global_tool_manager.rebuild_custom_toolsets(self._custom_toolsets)
    self._publish_snapshot()
    return True if db_result else False
# ─── Skill ────────────────────────────────────────────────
# Skill Manager Methods
def add_skill(self, skill_name: str):
    """Register a skill name and publish a refreshed snapshot.

    The stale early ``return`` is removed; it made the publish step
    unreachable, so snapshot readers never saw the new skill.

    Returns:
        Whatever the skill manager's ``add_skill`` returns.
    """
    result = self._global_skill_manager.add_skill(skill_name)
    self._publish_snapshot()
    return result
def get_skill_list(self):
    """Return the skill list as reported by the skill manager."""
    skill_manager = self._global_skill_manager
    return skill_manager.get_skill_list()
def remove_skill(self, skill_name: str):
    """Remove a skill and publish a refreshed snapshot.

    The stale early ``return`` is removed; it made the publish step
    unreachable, so snapshot readers kept seeing the removed skill.

    Returns:
        Whatever the skill manager's ``remove_skill`` returns.
    """
    result = self._global_skill_manager.remove_skill(skill_name)
    self._publish_snapshot()
    return result
# ─── Individual ───────────────────────────────────────────
# Individual Manager Methods
def add_individual(self, agent_id: str, config):
    """Register a worker individual's config and publish a refreshed snapshot.

    The stale early ``return`` is removed; it made the publish step
    unreachable, so snapshot readers never saw the new individual.

    Returns:
        Whatever the individual manager's ``add_individual`` returns.
    """
    result = self._global_individual_manager.add_individual(agent_id, config)
    self._publish_snapshot()
    return result
def get_individual(self, agent_id: str):
    """Return the individual manager's entry for ``agent_id`` (pure delegation)."""
    individual_manager = self._global_individual_manager
    return individual_manager.get_individual(agent_id)
def remove_individual(self, agent_id: str):
    """Remove a worker individual and publish a refreshed snapshot.

    The stale early ``return`` is removed; it made the publish step
    unreachable, so snapshot readers kept seeing the removed individual.

    Returns:
        Whatever the individual manager's ``remove_individual`` returns.
    """
    result = self._global_individual_manager.remove_individual(agent_id)
    self._publish_snapshot()
    return result
def list_individuals(self):
    """Return the list of individuals as reported by the individual manager."""
    individual_manager = self._global_individual_manager
    return individual_manager.list_individuals()
@@ -0,0 +1,184 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""GSM 快照对象与客户端拉取工具。
设计动机:把 GSM actor 内存里的"读路径配置"打包成可放进 Ray Object Store
的不可变快照。读端不再走 actor RPC,而是 ``fetch_snapshot()`` 一次拿到全量
当前配置,亚毫秒级共享内存读,绕开单 actor 的吞吐瓶颈。
GSM 仍然是 source of truth + 写入串行化器,但读路径解耦:
- 写入 → GSM 内存更新 → ``ray.put(snapshot)`` 拿到新 ObjectRef → 版本号 +1
- 读取 → ``current_config_ref()`` 拿 (version, ref) → ``ray.get(ref)`` 直读
旧的 ``get_provider / get_individual / ...`` 接口保留不动,是低频路径的兜底;
新代码(特别是 skill task 这种高并发热路径)应优先走 ``fetch_snapshot``。
"""
from __future__ import annotations
import asyncio
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional, Tuple
import ray
from kilostar.core.global_state_machine.model_provider.base_provider import Provider
from kilostar.utils.logger import get_logger
_logger = get_logger("gsm_snapshot")
@dataclass
class GSMSnapshot:
    """Snapshot of the GSM configuration, meant to be treated as read-only.

    Per the original author's note, every field must be cloudpickle-friendly.
    The dataclass is not declared ``frozen``, so immutability is a convention
    rather than something enforced here.

    ``FunctionToolset`` instances are deliberately left out: the author notes
    that their serializability may vary across pydantic-ai versions.
    Consumers assemble toolsets themselves from the function dictionaries and
    name lists stored here.
    """

    # Config version this snapshot was built at.
    version: int = 0
    # Shallow copies of the corresponding GSM registries.
    providers: Dict[str, Provider] = field(default_factory=dict)
    individuals: Dict[str, Dict[str, Any]] = field(default_factory=dict)
    mcp_servers: Dict[str, Dict[str, Any]] = field(default_factory=dict)
    tool_configs: Dict[str, Dict[str, Any]] = field(default_factory=dict)
    custom_toolsets: Dict[str, Dict[str, Any]] = field(default_factory=dict)
    skills: Dict[str, Tuple[str, str]] = field(default_factory=dict)
    tool_metadata: Dict[str, Dict[str, Any]] = field(default_factory=dict)
    # System tool callables flattened to {tool_name: callable}.
    tool_funcs: Dict[str, Callable[..., Any]] = field(default_factory=dict)
    # Third-party tool callables; custom toolsets are built from these.
    third_party_funcs: Dict[str, Callable[..., Any]] = field(default_factory=dict)
    tool_mapper: Dict[str, Dict[str, type]] = field(default_factory=dict)
    # ``{scope: [tool_name, ...]}``: system tool names grouped by scope.
    # Clients combine these names with ``tool_funcs`` to rebuild a
    # FunctionToolset in their own process, so no toolset instance has to
    # be stored in the snapshot.
    system_tools_by_scope: Dict[str, List[str]] = field(default_factory=dict)
_local_cache: Dict[str, Any] = {"version": -1, "snapshot": None}
_cache_lock = asyncio.Lock()
async def fetch_snapshot(
    *,
    use_cache: bool = True,
    gsm_actor: Optional[Any] = None,
) -> GSMSnapshot:
    """Fetch the current GSM snapshot, using a single-slot per-process cache.

    With ``use_cache=True``:

    1. Ask the actor for its current version (one lightweight RPC).
    2. If the cached snapshot has that version, return it.
    3. Otherwise call ``current_config_ref`` for ``(version, ref)``, resolve
       the reference, and store the result in the cache.

    The reference is resolved with ``await ref`` rather than ``ray.get``.
    ``ray.get`` blocks the calling thread, which inside a coroutine stalls
    the whole event loop, here while the cache lock is held.

    Args:
        use_cache: When ``False`` the cache is neither read nor written and
            the snapshot is always fetched again.
        gsm_actor: Optional GSM actor handle. When omitted it is looked up
            through ``ray_actor_hook("global_state_machine")``.

    Returns:
        The snapshot object the reference resolves to.
    """
    if gsm_actor is None:
        from kilostar.utils.ray_hook import ray_actor_hook

        gsm_actor = ray_actor_hook("global_state_machine").global_state_machine

    if not use_cache:
        _version, ref = await gsm_actor.current_config_ref.remote()
        return await ref

    async with _cache_lock:
        try:
            latest_version = await gsm_actor.current_version.remote()
        except Exception as exc:
            # Best effort: a failed version check only forces a refetch, but
            # it is logged so a broken actor does not go unnoticed.
            _logger.warning(f"GSM version check failed, refetching snapshot: {exc}")
            latest_version = None

        cached = _local_cache.get("snapshot")
        if (
            latest_version is not None
            and cached is not None
            and _local_cache.get("version") == latest_version
        ):
            return cached

        version, ref = await gsm_actor.current_config_ref.remote()
        snapshot = await ref
        _local_cache["version"] = version
        _local_cache["snapshot"] = snapshot
        return snapshot
def reset_local_cache() -> None:
    """Reset this process's snapshot cache to its initial empty state (for tests)."""
    _local_cache.update(version=-1, snapshot=None)
# ─── 客户端 helper:从快照重建本地视图 ─────────────────────────────
def build_toolsets_for_scope(
    snapshot: GSMSnapshot, scope: str
) -> List[Any]:
    """Assemble the ``FunctionToolset`` list for ``scope`` from a snapshot.

    Mirrors the merge done by ``GlobalToolManager.get_toolsets_for_scope``:

    - system toolsets: one for the ``"default"`` bucket and one for
      ``scope``, each built from ``system_tools_by_scope`` and ``tool_funcs``;
    - custom toolsets: one per ``custom_toolsets`` entry that resolves to at
      least one function in ``third_party_funcs``.

    When ``scope`` is ``"default"`` the bucket is now built once only.
    Previously two toolsets with the id ``system::default`` and the same
    tools were returned.

    Buckets or definitions that resolve to no function are skipped. A
    failure while constructing one toolset is logged and that toolset is
    left out.

    Returns:
        Newly built toolsets, or ``[]`` when ``pydantic_ai.toolsets`` cannot
        be imported.
    """
    try:
        from pydantic_ai.toolsets import FunctionToolset
    except ImportError:
        _logger.warning("pydantic_ai.toolsets unavailable; cannot build toolsets")
        return []

    result: List[Any] = []

    # dict.fromkeys keeps order and drops the duplicate when scope == "default".
    for bucket in dict.fromkeys(("default", scope)):
        names = snapshot.system_tools_by_scope.get(bucket) or []
        funcs = [snapshot.tool_funcs[n] for n in names if n in snapshot.tool_funcs]
        if not funcs:
            continue
        try:
            result.append(
                FunctionToolset(tools=funcs, id=f"system::{bucket}")
            )
        except Exception as e:  # pragma: no cover - defensive
            _logger.error(f"build system toolset {bucket} failed: {e}")

    for toolset_id, defn in snapshot.custom_toolsets.items():
        names = defn.get("tools") or []
        funcs = [
            snapshot.third_party_funcs[n]
            for n in names
            if n in snapshot.third_party_funcs
        ]
        if not funcs:
            continue
        try:
            result.append(
                FunctionToolset(tools=funcs, id=f"custom::{toolset_id}")
            )
        except Exception as e:  # pragma: no cover - defensive
            _logger.error(f"build custom toolset {toolset_id} failed: {e}")

    return result
@@ -1,35 +1,42 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pathlib
import importlib
import inspect
from collections import defaultdict
from typing import Any, Callable, Dict, List, Type
from kilostar.plugin.tool_plugin.base_tool import BaseToolData
from typing import Dict, Type
from kilostar.utils.logger import get_logger
logger = get_logger("tool_manager")
_SYSTEM_BUCKET = "system"
class GlobalToolManager:
"""工具注册表:扫描 ``kilostar/plugin/tool_plugin/`` 下所有 BaseToolData 子类,
按 ``action_scope`` 分桶到 ``tool_mapper[scope][plugin_name]``;无 scope 的归入 ``default``。"""
按 ``action_scope`` 打包成 ``FunctionToolset``。
三类 toolset
- **system**``is_system=True`` 的工具,按 scope 分组
- **custom**:用户自定义工具组(由 ``rebuild_custom_toolsets`` 动态构建)
- **mcp**:由 ``mcp_helper`` 独立管理,不经过本类
``category="mcp"`` 的工具不会被本类管理。
"""
tool_metadata: Dict[str, Dict[str, Any]]
_tool_funcs: Dict[str, Dict[str, Callable]]
_system_toolsets: Dict[str, Any]
_custom_toolsets: Dict[str, Any]
_third_party_funcs: Dict[str, Callable]
tool_mapper: Dict[str, Dict[str, Type[BaseToolData]]]
def __init__(self):
def __init__(self) -> None:
self.tool_metadata = {}
self._tool_funcs = defaultdict(dict)
self._system_toolsets = {}
self._custom_toolsets = {}
self._third_party_funcs = {}
self.tool_mapper = defaultdict(dict)
tool_plugin_dir = (
@@ -39,21 +46,154 @@ class GlobalToolManager:
return
for item in tool_plugin_dir.iterdir():
if item.is_dir() and not item.name.startswith("__"):
plugin_name = item.name
module_name = f"kilostar.plugin.tool_plugin.{plugin_name}"
if not (item.is_dir() and not item.name.startswith("__")):
continue
plugin_name = item.name
module_name = f"kilostar.plugin.tool_plugin.{plugin_name}"
try:
module = importlib.import_module(module_name)
for name, obj in inspect.getmembers(module, inspect.isclass):
if issubclass(obj, BaseToolData) and obj is not BaseToolData:
# It's a valid tool class
action_scopes = obj.model_fields.get("action_scope").default
try:
module = importlib.import_module(module_name)
except Exception as e:
logger.warning(f"Failed to import tool plugin {plugin_name}: {e}")
continue
if not action_scopes:
self.tool_mapper["default"][plugin_name] = obj
else:
for scope in action_scopes:
self.tool_mapper[scope][plugin_name] = obj
except Exception as e:
logger.warning(f"Failed to load tool plugin {plugin_name}: {e}")
tool_data_cls = self._find_tool_data_class(module)
if tool_data_cls is None:
continue
tool_func = getattr(module, plugin_name, None)
if not callable(tool_func):
logger.warning(
f"Tool plugin '{plugin_name}' has no callable named "
f"'{plugin_name}' in its module; skipped."
)
continue
action_scopes = (
tool_data_cls.model_fields.get("action_scope").default or []
)
is_system = bool(tool_data_cls.model_fields.get("is_system").default)
category_field = tool_data_cls.model_fields.get("category")
category = (category_field.default if category_field else "other") or "other"
self.tool_metadata[plugin_name] = {
"name": plugin_name,
"is_system": is_system,
"category": category,
"action_scope": list(action_scopes),
}
if category == "mcp":
continue
scopes = [s for s in action_scopes if s] or ["default"]
if is_system:
for scope in scopes:
self._tool_funcs[scope][plugin_name] = tool_func
self.tool_mapper[scope][plugin_name] = tool_data_cls
else:
self._third_party_funcs[plugin_name] = tool_func
for scope in scopes:
self.tool_mapper[scope][plugin_name] = tool_data_cls
self._build_system_toolsets()
def _build_system_toolsets(self) -> None:
FunctionToolset = self._import_function_toolset()
if FunctionToolset is None:
return
for scope, name_to_func in self._tool_funcs.items():
if not name_to_func:
continue
try:
self._system_toolsets[scope] = FunctionToolset(
tools=list(name_to_func.values()),
id=f"system::{scope}",
)
except Exception as e:
logger.error(f"Failed to build system toolset {scope}: {e}")
def rebuild_custom_toolsets(self, custom_defs: Dict[str, Dict[str, Any]]) -> None:
    """Rebuild all custom ``FunctionToolset`` objects from their definitions.

    For each definition, the names under ``"tools"`` are resolved against
    the registered third-party functions and unknown names are ignored. A
    definition that resolves to no function produces no toolset. The whole
    map is replaced at the end; it becomes empty when ``FunctionToolset``
    cannot be imported.

    Args:
        custom_defs: ``{toolset_id: definition}``; a definition is read
            only through its ``"tools"`` key.
    """
    toolset_cls = self._import_function_toolset()
    if toolset_cls is None:
        self._custom_toolsets = {}
        return

    rebuilt: Dict[str, Any] = {}
    known_funcs = self._third_party_funcs
    for toolset_id, definition in custom_defs.items():
        resolved = []
        for tool_name in definition.get("tools") or []:
            if tool_name in known_funcs:
                resolved.append(known_funcs[tool_name])
        if not resolved:
            continue
        try:
            rebuilt[toolset_id] = toolset_cls(
                tools=resolved,
                id=f"custom::{toolset_id}",
            )
        except Exception as e:
            logger.error(f"Failed to build custom toolset {toolset_id}: {e}")
    self._custom_toolsets = rebuilt
@staticmethod
def _import_function_toolset():
    """Return the pydantic-ai ``FunctionToolset`` class, or ``None`` if it cannot be imported.

    An import failure is logged as a warning rather than raised.
    """
    try:
        from pydantic_ai.toolsets import FunctionToolset
    except ImportError:
        logger.warning("pydantic_ai.toolsets unavailable")
        return None
    else:
        return FunctionToolset
@staticmethod
def _find_tool_data_class(module) -> Type[BaseToolData] | None:
    """Return the first strict ``BaseToolData`` subclass found in ``module``, or ``None``.

    Candidates are visited in ``inspect.getmembers`` order; ``BaseToolData``
    itself is never returned.
    """
    candidates = (
        member
        for _name, member in inspect.getmembers(module, inspect.isclass)
        if member is not BaseToolData and issubclass(member, BaseToolData)
    )
    return next(candidates, None)
# ─── Toolset accessors ───
def get_system_toolset(self, scope: str) -> Any | None:
return self._system_toolsets.get(scope)
def get_toolsets_for_scope(self, scope: str) -> List[Any]:
    """Return the system toolsets for ``default`` and ``scope``, then all custom toolsets.

    When ``scope`` is ``"default"`` the default system toolset is included
    once only. Previously the same toolset instance appeared twice in the
    result.

    Args:
        scope: Scope whose system toolset is added after the default one.

    Returns:
        A new list; scopes with no system toolset contribute nothing.
    """
    result: List[Any] = []
    # dict.fromkeys keeps order and drops the duplicate when scope == "default".
    for s in dict.fromkeys(("default", scope)):
        ts = self._system_toolsets.get(s)
        if ts is not None:
            result.append(ts)
    result.extend(self._custom_toolsets.values())
    return result
# ─── Metadata accessors ───
def is_third_party_tool(self, tool_name: str) -> bool:
    """Return whether ``tool_name`` is registered among the third-party tool functions."""
    registered = self._third_party_funcs
    return tool_name in registered
def get_tools_by_category(self, category: str) -> List[Dict[str, Any]]:
    """Return the metadata entries whose ``"category"`` equals ``category``."""
    matches: List[Dict[str, Any]] = []
    for meta in self.tool_metadata.values():
        if meta.get("category") == category:
            matches.append(meta)
    return matches
def get_system_tools(self) -> List[Dict[str, Any]]:
    """Return the metadata entries whose ``"is_system"`` flag is exactly ``True``."""
    system_entries: List[Dict[str, Any]] = []
    for meta in self.tool_metadata.values():
        if meta.get("is_system") is True:
            system_entries.append(meta)
    return system_entries
def get_third_party_tools(self) -> List[Dict[str, Any]]:
    """Return the metadata entries that are neither system tools nor in the ``"mcp"`` category."""
    third_party: List[Dict[str, Any]] = []
    for meta in self.tool_metadata.values():
        excluded = meta.get("is_system") is True or meta.get("category") == "mcp"
        if not excluded:
            third_party.append(meta)
    return third_party
def get_all_tools(self) -> List[Dict[str, Any]]:
    """Return every tool's metadata entry as a new list."""
    everything = [*self.tool_metadata.values()]
    return everything
# 兼容旧接口
def get_non_system_tools(self) -> List[Dict[str, Any]]:
    """Legacy alias kept for older callers; same result as ``get_third_party_tools``."""
    third_party = self.get_third_party_tools()
    return third_party
def get_personal_tools(self) -> List[Dict[str, Any]]:
    """Legacy alias kept for older callers; same result as ``get_third_party_tools``."""
    third_party = self.get_third_party_tools()
    return third_party