fix: 修复了get_tool调用缓存协程对象的bug
This commit is contained in:
parent
b934ee2e32
commit
d322826c87
|
|
@ -82,5 +82,8 @@ async def delete_skill(skill_name: str, _: TokenData = Depends(RoleChecker(allow
|
||||||
@resource_router.get("/tool")
|
@resource_router.get("/tool")
|
||||||
async def get_tools(_: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER))):
|
async def get_tools(_: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER))):
|
||||||
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
|
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
|
||||||
tools = await global_state_machine.get_tool_list.remote("default")
|
tool_mapper = await global_state_machine.get_tool_mapper.remote()
|
||||||
return {"tools": list(tools.keys())}
|
all_tool_names = set()
|
||||||
|
for scope_tools in tool_mapper.values():
|
||||||
|
all_tool_names.update(scope_tools.keys())
|
||||||
|
return {"tools": list(all_tool_names)}
|
||||||
|
|
@ -134,7 +134,7 @@ class ConsciousnessNode:
|
||||||
async def _run(self, payload: Union[ForSupervisoryInput, ForWorkflowInput, ForWorkflowEngineInput]) -> Union[ForSupervisoryNode, ForWorkflow, ForWorkflowEngine]:
|
async def _run(self, payload: Union[ForSupervisoryInput, ForWorkflowInput, ForWorkflowEngineInput]) -> Union[ForSupervisoryNode, ForWorkflow, ForWorkflowEngine]:
|
||||||
try:
|
try:
|
||||||
self.agent.retries = 3
|
self.agent.retries = 3
|
||||||
tool = get_tool("control_node")
|
tool = await get_tool("control_node")
|
||||||
if isinstance(payload, ForWorkflowEngineInput):
|
if isinstance(payload, ForWorkflowEngineInput):
|
||||||
deps = ConsciousnessNodeDeps(
|
deps = ConsciousnessNodeDeps(
|
||||||
original_command=payload.original_command,
|
original_command=payload.original_command,
|
||||||
|
|
|
||||||
|
|
@ -89,7 +89,7 @@ class ControlNode:
|
||||||
)
|
)
|
||||||
self.logger.debug(f"ControlNode: 开始执行工作流节点 [{payload.workflow_step.name}] (原生重试开启)")
|
self.logger.debug(f"ControlNode: 开始执行工作流节点 [{payload.workflow_step.name}] (原生重试开启)")
|
||||||
|
|
||||||
tool = get_tool("control_node")
|
tool = await get_tool("control_node")
|
||||||
|
|
||||||
result = await self.agent.run(
|
result = await self.agent.run(
|
||||||
f"请根据提供的 workflow_step 上下文,执行此步骤并输出结果。\n详细指令或附加数据:{payload.workflow_step.model_dump_json()}",
|
f"请根据提供的 workflow_step 上下文,执行此步骤并输出结果。\n详细指令或附加数据:{payload.workflow_step.model_dump_json()}",
|
||||||
|
|
|
||||||
|
|
@ -172,7 +172,7 @@ class SupervisoryNode:
|
||||||
if isinstance(payload, TerminationMessage):
|
if isinstance(payload, TerminationMessage):
|
||||||
prompt_message = f"【工作流执行结束报告】\n请将以下技术报告转化为对用户的友好回复:\n{message}"
|
prompt_message = f"【工作流执行结束报告】\n请将以下技术报告转化为对用户的友好回复:\n{message}"
|
||||||
self.agent.retries = 3
|
self.agent.retries = 3
|
||||||
tool = get_tool("supervisory_node")
|
tool = await get_tool("supervisory_node")
|
||||||
result = await self.agent.run(prompt_message,
|
result = await self.agent.run(prompt_message,
|
||||||
deps=deps,
|
deps=deps,
|
||||||
tools=tool)
|
tools=tool)
|
||||||
|
|
|
||||||
|
|
@ -17,6 +17,7 @@ from pretor.plugin.tool_plugin.base_tool import BaseToolData
|
||||||
import os
|
import os
|
||||||
|
|
||||||
class FileReaderData(BaseToolData):
|
class FileReaderData(BaseToolData):
|
||||||
|
is_system: bool = True
|
||||||
name: str = "file_reader"
|
name: str = "file_reader"
|
||||||
description: str = "读取本地文件的内容"
|
description: str = "读取本地文件的内容"
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -15,12 +15,12 @@
|
||||||
import importlib
|
import importlib
|
||||||
from typing import Callable, Dict, List
|
from typing import Callable, Dict, List
|
||||||
import pathlib
|
import pathlib
|
||||||
from functools import lru_cache
|
|
||||||
from pretor.utils.ray_hook import ray_actor_hook
|
from pretor.utils.ray_hook import ray_actor_hook
|
||||||
|
|
||||||
from pretor.utils.logger import get_logger
|
from pretor.utils.logger import get_logger
|
||||||
logger = get_logger('get_tool')
|
logger = get_logger('get_tool')
|
||||||
_tool_cache: Dict[str, Callable] = {}
|
_tool_cache: Dict[str, Callable] = {}
|
||||||
|
_agent_tool_result_cache: Dict[str, List[Callable]] = {}
|
||||||
|
|
||||||
|
|
||||||
def _get_tool_func(tool_name: str) -> Callable | None:
|
def _get_tool_func(tool_name: str) -> Callable | None:
|
||||||
|
|
@ -48,10 +48,13 @@ def del_tool_cache(tool_name: str) -> None:
|
||||||
del _tool_cache[tool_name]
|
del _tool_cache[tool_name]
|
||||||
refresh_agent_tools()
|
refresh_agent_tools()
|
||||||
|
|
||||||
@lru_cache(maxsize=1)
|
|
||||||
async def get_tool(agent_name: str) -> List[Callable]:
|
async def get_tool(agent_name: str) -> List[Callable]:
|
||||||
|
cached = _agent_tool_result_cache.get(agent_name)
|
||||||
|
if cached is not None:
|
||||||
|
return cached
|
||||||
global_state_machine = ray_actor_hook("global_state_machine")
|
global_state_machine = ray_actor_hook("global_state_machine")
|
||||||
_tool_list = await global_state_machine.get_tool_list.remote( agent_name)
|
_tool_list = await global_state_machine.get_tool_list.remote(agent_name)
|
||||||
tool_list = []
|
tool_list = []
|
||||||
for tool_name in _tool_list.keys():
|
for tool_name in _tool_list.keys():
|
||||||
tool_func = _get_tool_func(tool_name)
|
tool_func = _get_tool_func(tool_name)
|
||||||
|
|
@ -59,9 +62,8 @@ async def get_tool(agent_name: str) -> List[Callable]:
|
||||||
tool_list.append(tool_func)
|
tool_list.append(tool_func)
|
||||||
else:
|
else:
|
||||||
continue
|
continue
|
||||||
|
_agent_tool_result_cache[agent_name] = tool_list
|
||||||
return tool_list
|
return tool_list
|
||||||
|
|
||||||
def refresh_agent_tools() -> None:
|
def refresh_agent_tools() -> None:
|
||||||
get_tool.cache_clear()
|
_agent_tool_result_cache.clear()
|
||||||
|
|
||||||
|
|
||||||
Loading…
Reference in New Issue