# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Resolve tool callables for agents from the ``pretor.plugin.tool_plugin`` package.

Two in-process caches are kept:

* ``_tool_cache`` maps a tool name to the callable loaded from its plugin module.
* ``_agent_tool_result_cache`` maps an agent name to the list of callables last
  resolved for that agent.
"""
import importlib
from typing import Callable, Dict, List
import pathlib

from pretor.utils.ray_hook import ray_actor_hook
from pretor.utils.logger import get_logger

logger = get_logger('get_tool')

# Directory expected to contain one sub-directory per tool plugin:
# <three levels above this file>/plugin/tool_plugin/<tool_name>
_TOOL_PLUGIN_ROOT = pathlib.Path(__file__).parent.parent.parent / "plugin" / "tool_plugin"
# Dotted package path under which each tool plugin is imported.
_TOOL_PLUGIN_PACKAGE = "pretor.plugin.tool_plugin"

# tool name -> callable loaded from that tool's plugin module.
_tool_cache: Dict[str, Callable] = {}
# agent name -> tool callables last resolved for that agent.
_agent_tool_result_cache: Dict[str, List[Callable]] = {}


def _get_tool_func(tool_name: str) -> Callable | None:
    """Return the callable implementing ``tool_name``, or ``None`` if unavailable.

    The callable is the attribute named ``tool_name`` inside the module
    ``pretor.plugin.tool_plugin.<tool_name>``. Successful lookups are memoised
    in ``_tool_cache``; failures are not cached, so a tool installed later can
    still be picked up on a subsequent call.

    Args:
        tool_name: Name of both the plugin sub-package and the function in it.

    Returns:
        The tool callable, or ``None`` when the plugin directory is missing,
        the module cannot be found, the module has no attribute ``tool_name``,
        or that attribute is not callable.
    """
    cached = _tool_cache.get(tool_name)
    if cached is not None:
        return cached

    tool_plugin_dir = _TOOL_PLUGIN_ROOT / tool_name
    # is_dir() is False for non-existent paths, so it also covers exists().
    if not tool_plugin_dir.is_dir():
        return None

    module_name = f"{_TOOL_PLUGIN_PACKAGE}.{tool_name}"
    try:
        module = importlib.import_module(module_name)
    except ModuleNotFoundError as exc:
        # Include the exception: it may name a missing dependency of the
        # plugin rather than the plugin module itself.
        logger.error(f"Module {module_name} not found: {exc}")
        return None

    try:
        func = getattr(module, tool_name)
    except AttributeError:
        logger.error(f"Module {module_name} has no attribute {tool_name}")
        return None

    if not callable(func):
        return None

    _tool_cache[tool_name] = func
    return func


def del_tool_cache(tool_name: str) -> None:
    """Evict ``tool_name`` from the tool cache and invalidate all agent tool lists.

    Agent lists are invalidated even when ``tool_name`` was not cached. An
    agent list may have been built while the tool failed to load, and the tool
    was then skipped. Such a list must be rebuilt for the tool to appear.

    Note: the plugin module itself stays in ``sys.modules``. The next lookup
    re-reads the attribute from the already-imported module instead of
    re-executing the plugin source.
    """
    _tool_cache.pop(tool_name, None)
    refresh_agent_tools()


async def get_tool(agent_name: str) -> List[Callable]:
    """Return the tool callables available to ``agent_name``.

    Tool names come from the ``get_tool_list`` method of the
    ``global_state_machine`` Ray actor. Each name is resolved with
    ``_get_tool_func``; names that cannot be resolved are skipped with a
    warning. The resolved list is cached per agent until
    ``refresh_agent_tools`` (or ``del_tool_cache``) clears it.

    Args:
        agent_name: Agent whose tools should be resolved.

    Returns:
        A fresh list on every call, so callers may mutate it without
        affecting the cached copy.
    """
    cached = _agent_tool_result_cache.get(agent_name)
    if cached is not None:
        return list(cached)

    global_state_machine = ray_actor_hook("global_state_machine")
    # NOTE(review): only the keys of the remote result are used as tool
    # names; presumably it is a dict keyed by tool name — confirm with the actor.
    remote_tools = await global_state_machine.get_tool_list.remote(agent_name)

    tool_list: List[Callable] = []
    for tool_name in remote_tools.keys():
        tool_func = _get_tool_func(tool_name)
        if tool_func is None:
            logger.warning(f"Tool {tool_name} for agent {agent_name} could not be loaded; skipping")
            continue
        tool_list.append(tool_func)

    _agent_tool_result_cache[agent_name] = tool_list
    return list(tool_list)


def refresh_agent_tools() -> None:
    """Drop every cached per-agent tool list so the next ``get_tool`` call rebuilds it."""
    _agent_tool_result_cache.clear()