wip: 增加了skill_manager
This commit is contained in:
parent
95ec019b5a
commit
cf0117ae2f
|
|
@ -1,4 +1,4 @@
|
|||
<div align="center">
|
||||
<div align="center">
|
||||
|
||||
# Pretor (执政官)
|
||||
|
||||
|
|
@ -21,8 +21,6 @@
|
|||
- (暂未实现)本项目适配多种消息平台,实现在外可通过多种方式给 **Pretor** 下达指令完成工作。
|
||||
- (暂未实现)本项目内置 **growth_node(生长节点)** ,实现傻瓜式微调模型操作,让你的 **Pretor** 自己学会一些独特的技能。
|
||||
|
||||
那么如何拥有属于自己的**执政官**呢?
|
||||
|
||||
---
|
||||
## 快速开始
|
||||
本项目正在开发中...
|
||||
|
|
|
|||
|
|
@ -12,6 +12,8 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from pydantic import BaseModel
|
||||
import viceroy
|
||||
from pretor.core.workflow.workflow_template_generator.workflow_template import WorkflowTemplate
|
||||
from pretor.utils.ray_hook import ray_actor_hook
|
||||
from fastapi import APIRouter, Depends
|
||||
|
|
@ -24,4 +26,23 @@ async def create_workflow_template(workflow_template: WorkflowTemplate,
|
|||
_: TokenData = Depends(Accessor.get_current_user)):
|
||||
global_state_machine = ray_actor_hook("global_state_machine")
|
||||
await global_state_machine.workflow_template_generate.remote(workflow_template)
|
||||
return {"message": "创建成功"}
|
||||
return {"message": "创建成功"}
|
||||
|
||||
class Skill(BaseModel):
|
||||
repo_url: str
|
||||
path: str | None
|
||||
|
||||
@resource_router.post("/skill")
|
||||
async def install_skill(skill: Skill,
|
||||
_: TokenData = Depends(Accessor.get_current_user)):
|
||||
global_state_machine = ray_actor_hook("global_state_machine")
|
||||
await viceroy.install_skill_async(url = skill.repo_url,
|
||||
path = skill.path,
|
||||
output = "./pretor/plugin/tool_plugin")
|
||||
if skill.path:
|
||||
skill_name = skill.path.split("/")[-1]
|
||||
else:
|
||||
skill_name = skill.repo_url.split("/")[-1]
|
||||
await global_state_machine.add_skill.remote(skill_name)
|
||||
return {"message": "创建成功"}
|
||||
|
||||
|
|
|
|||
|
|
@ -24,6 +24,7 @@ from pretor.core.individual.control_node.control_node import ControlNode
|
|||
from pretor.api.platform.frontend import client_router
|
||||
from pretor.api.auth import auth_router
|
||||
from pretor.api.provider import provider_router
|
||||
from pretor.api.resource import resource_router
|
||||
|
||||
@ray.remote
|
||||
class PretorGateway:
|
||||
|
|
@ -46,6 +47,7 @@ class PretorGateway:
|
|||
self.app.include_router(client_router)
|
||||
self.app.include_router(auth_router)
|
||||
self.app.include_router(provider_router)
|
||||
self.app.include_router(resource_router)
|
||||
|
||||
async def server_run(self, host="0.0.0.0", port=8000):
|
||||
config = uvicorn.Config(app=self.app, host=host, port=port, loop="asyncio")
|
||||
|
|
|
|||
|
|
@ -13,10 +13,12 @@
|
|||
# limitations under the License.
|
||||
|
||||
import ray
|
||||
import pathlib
|
||||
from pretor.core.global_state_machine.provider_manager import ProviderManager
|
||||
from pretor.core.global_state_machine.tool_manager import GlobalToolManager
|
||||
from pretor.core.global_state_machine.model_provider import Provider, ProviderArgs
|
||||
import httpx
|
||||
import json
|
||||
from loguru import logger
|
||||
from typing import Dict, Literal
|
||||
from pretor.core.database.postgres import PostgresDatabase
|
||||
|
|
@ -25,7 +27,8 @@ import asyncio
|
|||
from pretor.core.workflow.workflow import PretorWorkflow
|
||||
from pretor.core.workflow.workflow_template_generator.workflow_template import WorkflowTemplate
|
||||
from pretor.core.workflow.workflow_template_manager import WorkflowManager
|
||||
from pretor.tool_plugin.base_tool import BaseToolData
|
||||
from pretor.core.global_state_machine.skill_manager import GlobalSkillManager
|
||||
from pretor.plugin.tool_plugin.base_tool import BaseToolData
|
||||
|
||||
@ray.remote
|
||||
class GlobalStateMachine:
|
||||
|
|
@ -35,6 +38,8 @@ class GlobalStateMachine:
|
|||
self.global_provider_manager = ProviderManager(postgres_database)
|
||||
self.global_tool_manager = GlobalToolManager()
|
||||
self.global_workflow_template_manager = WorkflowManager()
|
||||
self.global_skill_manager = GlobalSkillManager()
|
||||
|
||||
|
||||
self.postgres_database = postgres_database
|
||||
|
||||
|
|
@ -152,4 +157,20 @@ class GlobalStateMachine:
|
|||
|
||||
###以下为workflow_template_manager方法
|
||||
def workflow_template_generator(self, workflow_template: WorkflowTemplate) -> None:
|
||||
self.global_workflow_template_manager.generate_workflow_template(workflow_template)
|
||||
self.global_workflow_template_manager.generate_workflow_template(workflow_template)
|
||||
|
||||
###以下为skill_manager方法
|
||||
def add_skill(self, skill_name: str):
|
||||
skill_plugin_dir = pathlib.Path(__file__).parent.parent.parent / "plugin" / "skill_plugin" / skill_name
|
||||
json_path = skill_plugin_dir / "skill.json"
|
||||
try:
|
||||
with open(json_path, "r", encoding="utf-8") as f:
|
||||
skill = json.load(f)
|
||||
name = skill.get("name")
|
||||
if name:
|
||||
self.global_skill_manager.skill_mapper[name] = (
|
||||
skill.get("description", ""),
|
||||
skill.get("instructions", "")
|
||||
)
|
||||
except (json.JSONDecodeError, OSError) as e:
|
||||
print(f"警告: 加载插件 {skill_name} 失败: {e}")
|
||||
|
|
|
|||
|
|
@ -0,0 +1,46 @@
|
|||
# Copyright 2026 zhaoxi826
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Tuple, Dict
|
||||
from collections import defaultdict
|
||||
import pathlib
|
||||
import json
|
||||
|
||||
class GlobalSkillManager:
|
||||
skill_mapper = Dict[str,Tuple[str]]
|
||||
"""skill的存储表"""
|
||||
|
||||
def __init__(self):
|
||||
self.skill_mapper = defaultdict(tuple)
|
||||
|
||||
skill_plugin_dir = pathlib.Path(__file__).parent.parent.parent / "plugin" / "skill_plugin"
|
||||
if not skill_plugin_dir.exists() or not skill_plugin_dir.is_dir():
|
||||
return
|
||||
for item in skill_plugin_dir.iterdir():
|
||||
if item.is_dir() and not item.name.startswith((".", "__")):
|
||||
json_path = item / "skill.json" # 拼接文件路径
|
||||
if json_path.exists():
|
||||
try:
|
||||
with open(json_path, "r", encoding="utf-8") as f:
|
||||
skill = json.load(f)
|
||||
# 提取并映射
|
||||
name = skill.get("name")
|
||||
if name:
|
||||
self.skill_mapper[name] = (
|
||||
skill.get("description", ""),
|
||||
skill.get("instructions", "")
|
||||
)
|
||||
except (json.JSONDecodeError, OSError) as e:
|
||||
print(f"警告: 加载插件 {item.name} 失败: {e}")
|
||||
|
||||
|
|
@ -16,9 +16,9 @@ import pathlib
|
|||
import importlib
|
||||
import inspect
|
||||
from collections import defaultdict
|
||||
from pretor.tool_plugin.base_tool import BaseToolData
|
||||
from pretor.plugin.tool_plugin.base_tool import BaseToolData
|
||||
from typing import Dict, Type
|
||||
|
||||
from loguru import logger
|
||||
|
||||
class GlobalToolManager:
|
||||
tool_mapper: Dict[str, Dict[str, Type[BaseToolData]]]
|
||||
|
|
@ -26,14 +26,14 @@ class GlobalToolManager:
|
|||
def __init__(self):
|
||||
self.tool_mapper = defaultdict(dict)
|
||||
|
||||
tool_plugin_dir = pathlib.Path(__file__).parent.parent.parent / "tool_plugin"
|
||||
tool_plugin_dir = pathlib.Path(__file__).parent.parent.parent / "plugin.tool_plugin"
|
||||
if not tool_plugin_dir.exists() or not tool_plugin_dir.is_dir():
|
||||
return
|
||||
|
||||
for item in tool_plugin_dir.iterdir():
|
||||
if item.is_dir() and not item.name.startswith("__"):
|
||||
plugin_name = item.name
|
||||
module_name = f"pretor.tool_plugin.{plugin_name}"
|
||||
module_name = f"pretor.plugin.tool_plugin.{plugin_name}"
|
||||
|
||||
try:
|
||||
module = importlib.import_module(module_name)
|
||||
|
|
@ -48,5 +48,4 @@ class GlobalToolManager:
|
|||
for scope in action_scopes:
|
||||
self.tool_mapper[scope][plugin_name] = obj
|
||||
except Exception as e:
|
||||
import logging
|
||||
logging.getLogger(__name__).warning(f"Failed to load tool plugin {plugin_name}: {e}")
|
||||
logger.warning(f"Failed to load tool plugin {plugin_name}: {e}")
|
||||
|
|
@ -18,7 +18,7 @@ from pretor.core.workflow.workflow_template_generator.workflow_template import W
|
|||
class WorkflowTemplateGenerator:
|
||||
@staticmethod
|
||||
def generate_workflow_template(workflow_template: WorkflowTemplate) -> WorkflowTemplate:
|
||||
output_dir = Path("pretor.workflow_template")
|
||||
output_dir = Path("pretor") / "workflow_template"
|
||||
if not output_dir.exists():
|
||||
output_dir.mkdir(parents=True)
|
||||
output_file = output_dir / f"{workflow_template.name}_workflow_template.json"
|
||||
|
|
|
|||
|
|
@ -11,9 +11,8 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from pydantic import Field
|
||||
|
||||
from pretor.tool_plugin.base_tool import BaseToolData
|
||||
from pretor.plugin.tool_plugin.base_tool import BaseToolData
|
||||
from pretor.utils.ray_hook import ray_actor_hook
|
||||
from typing import List, Literal, Dict
|
||||
|
||||
|
|
@ -12,7 +12,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import ray
|
||||
from typing import List
|
||||
from functools import lru_cache
|
||||
|
||||
class ActorList:
|
||||
def __init__(self):
|
||||
|
|
@ -32,10 +32,19 @@ class ActorList:
|
|||
else:
|
||||
raise AttributeError(f"ActorList对象没有属性 '{key}'")
|
||||
|
||||
@lru_cache(maxsize=128)
|
||||
def _get_cached_actor_handle(actor_name: str):
|
||||
"""缓存接口"""
|
||||
return ray.get_actor(actor_name)
|
||||
|
||||
def clear_actor_cache():
|
||||
"""清理接口"""
|
||||
_get_cached_actor_handle.cache_clear()
|
||||
|
||||
def ray_actor_hook(*actor_names: str):
|
||||
actor_list = ActorList()
|
||||
for actor_name in actor_names:
|
||||
handle = ray.get_actor(actor_name)
|
||||
handle = _get_cached_actor_handle(actor_name)
|
||||
setattr(actor_list, actor_name, handle)
|
||||
return actor_list
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
[project]
|
||||
name = "archonbot"
|
||||
name = "pretor"
|
||||
version = "0.1.0"
|
||||
description = "Add your description here"
|
||||
readme = "README.md"
|
||||
|
|
@ -12,13 +12,22 @@ dependencies = [
|
|||
"jinja2>=3.1.6",
|
||||
"loguru>=0.7.3",
|
||||
"passlib[bcrypt]>=1.7.4",
|
||||
"pretor-viceroy>=0.2.0",
|
||||
"pydantic-ai>=1.73.0",
|
||||
"pyfiglet>=1.0.4",
|
||||
"pytest>=9.0.3",
|
||||
"python-ulid>=3.1.0",
|
||||
"ray[default,serve]>=2.54.0",
|
||||
"rich>=14.3.3",
|
||||
"sqlmodel>=0.0.37",
|
||||
"types-docutils==0.22.3.20260408",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
gpu = [
|
||||
"vllm>=0.11.0",
|
||||
]
|
||||
|
||||
[dependency-groups]
|
||||
dev = [
|
||||
"pytest>=9.0.3",
|
||||
]
|
||||
|
|
|
|||
Loading…
Reference in New Issue