chore: initial commit for Pretor v0.1.0-alpha
正式发布 Pretor 平台的首个 alpha 版本。本项目旨在构建一个基于分布式架构的多智能体协同工作流水线。 核心功能实现: 1. 建立基于 BaseIndividual 的动态插件加载机制。 2. 实现三类核心 worker_individual 子个体。 3. 集成 Ray 框架支持分布式集群调度。 4. 基于 PostgreSQL 的全量持久化存储方案。 5. 提供完整的 FastAPI 后端与 React 前端交互界面。
This commit is contained in:
@@ -0,0 +1,14 @@
|
||||
# Copyright 2026 zhaoxi826
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
@@ -0,0 +1,166 @@
|
||||
# Copyright 2026 zhaoxi826
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import ray
|
||||
from pretor.core.global_state_machine.provider_manager import ProviderManager
|
||||
from pretor.core.global_state_machine.tool_manager import GlobalToolManager
|
||||
from typing import Dict
|
||||
from pretor.core.database.postgres import PostgresDatabase
|
||||
from pretor.api.platform.event import PretorEvent
|
||||
import asyncio
|
||||
from pretor.core.workflow.workflow import PretorWorkflow
|
||||
from pretor.core.workflow.workflow_template_manager import WorkflowManager
|
||||
from pretor.core.global_state_machine.skill_manager import GlobalSkillManager
|
||||
from pretor.core.global_state_machine.individual_manager import GlobalIndividualManager
|
||||
|
||||
|
||||
@ray.remote
|
||||
class GlobalStateMachine:
|
||||
def __init__(self, postgres_database: PostgresDatabase):
|
||||
import sys
|
||||
print("GSM __init__ START", file=sys.stderr, flush=True)
|
||||
self.event_dict: Dict[str, PretorEvent] = {}
|
||||
print(" event_dict done", file=sys.stderr, flush=True)
|
||||
self._global_provider_manager = ProviderManager(postgres_database)
|
||||
print(" provider_manager done", file=sys.stderr, flush=True)
|
||||
self._global_tool_manager = GlobalToolManager()
|
||||
print(" tool_manager done", file=sys.stderr, flush=True)
|
||||
self._global_workflow_template_manager = WorkflowManager()
|
||||
print(" workflow_template_manager done", file=sys.stderr, flush=True)
|
||||
self._global_skill_manager = GlobalSkillManager()
|
||||
print(" skill_manager done", file=sys.stderr, flush=True)
|
||||
self._global_individual_manager = GlobalIndividualManager()
|
||||
print(" individual_manager done", file=sys.stderr, flush=True)
|
||||
self.postgres_database = postgres_database
|
||||
print("GSM __init__ DONE", file=sys.stderr, flush=True)
|
||||
|
||||
async def init_state_machine(self):
|
||||
await self._global_provider_manager.init_provider_register(self.postgres_database)
|
||||
await self._global_individual_manager.init_individual_register(self.postgres_database)
|
||||
|
||||
async def add_provider_wrap(self, provider_type, provider_title, provider_url, provider_apikey, provider_owner):
|
||||
return await self._global_provider_manager.add_provider(
|
||||
provider_type=provider_type,
|
||||
provider_title=provider_title,
|
||||
provider_url=provider_url,
|
||||
provider_apikey=provider_apikey,
|
||||
provider_owner=provider_owner,
|
||||
postgres_database=self.postgres_database
|
||||
)
|
||||
|
||||
# Provider Manager Methods
|
||||
def get_provider_list(self):
|
||||
return self._global_provider_manager.get_provider_list()
|
||||
|
||||
def get_provider(self, provider_title):
|
||||
return self._global_provider_manager.get_provider(provider_title)
|
||||
|
||||
async def delete_provider(self, provider_title: str):
|
||||
return await self._global_provider_manager.delete_provider(provider_title, self.postgres_database)
|
||||
|
||||
# Tool Manager Methods
|
||||
def get_tool_mapper(self):
|
||||
return self._global_tool_manager.tool_mapper
|
||||
|
||||
def get_tool_list(self, agent_name: str):
|
||||
# get_tool_list didn't actually exist on tool_manager, let's implement it to return the tools
|
||||
# for a specific agent name (or scope)
|
||||
tools = self._global_tool_manager.tool_mapper.get(agent_name, {})
|
||||
# also include default tools
|
||||
default_tools = self._global_tool_manager.tool_mapper.get("default", {})
|
||||
merged_tools = {**default_tools, **tools}
|
||||
return merged_tools
|
||||
|
||||
# Workflow Template Manager Methods
|
||||
def get_all_workflow_templates(self):
|
||||
return self._global_workflow_template_manager.get_all_workflow_templates()
|
||||
|
||||
def add_workflow_template(self, template_name: str, workflow_template):
|
||||
return self._global_workflow_template_manager.add_workflow_template(template_name, workflow_template)
|
||||
|
||||
def delete_workflow_template(self, template_name: str):
|
||||
return self._global_workflow_template_manager.delete_workflow_template(template_name)
|
||||
|
||||
def generate_workflow_template(self, workflow_template):
|
||||
return self._global_workflow_template_manager.generate_workflow_template(workflow_template)
|
||||
|
||||
# Skill Manager Methods
|
||||
def add_skill(self, skill_name: str):
|
||||
return self._global_skill_manager.add_skill(skill_name)
|
||||
|
||||
def get_skill_list(self):
|
||||
return self._global_skill_manager.get_skill_list()
|
||||
|
||||
def remove_skill(self, skill_name: str):
|
||||
return self._global_skill_manager.remove_skill(skill_name)
|
||||
|
||||
# Individual Manager Methods
|
||||
def add_individual(self, agent_id: str, config):
|
||||
return self._global_individual_manager.add_individual(agent_id, config)
|
||||
|
||||
def get_individual(self, agent_id: str):
|
||||
return self._global_individual_manager.get_individual(agent_id)
|
||||
|
||||
def remove_individual(self, agent_id: str):
|
||||
return self._global_individual_manager.remove_individual(agent_id)
|
||||
|
||||
def list_individuals(self):
|
||||
return self._global_individual_manager.list_individuals()
|
||||
|
||||
###以下方法为event_dict方法
|
||||
def add_event(self, event: PretorEvent) -> None:
|
||||
event.pending_queue = asyncio.Queue()
|
||||
event.receive_queue = asyncio.Queue()
|
||||
self.event_dict[event.trace_id] = event
|
||||
|
||||
def delete_event(self, trace_id: str) -> None:
|
||||
del self.event_dict[trace_id]
|
||||
|
||||
def get_event(self, trace_id: str) -> PretorEvent:
|
||||
return self.event_dict.get(trace_id, None)
|
||||
|
||||
def update_attachment(self, trace_id: str, attachment: Dict[str, str]) -> None:
|
||||
self.event_dict[trace_id].attachment = attachment
|
||||
|
||||
def update_workflow(self, trace_id: str, workflow: PretorWorkflow) -> None:
|
||||
self.event_dict[trace_id].workflow = workflow
|
||||
|
||||
def get_workflow(self, trace_id: str) -> PretorWorkflow:
|
||||
return self.event_dict[trace_id].workflow
|
||||
|
||||
def list_events(self) -> list[dict]:
|
||||
result = []
|
||||
for trace_id, event in self.event_dict.items():
|
||||
workflow_title = event.workflow.title if event.workflow else None
|
||||
workflow_status = event.workflow.status.status if event.workflow and event.workflow.status else None
|
||||
result.append({
|
||||
"event_id": trace_id,
|
||||
"workflow_title": workflow_title,
|
||||
"status": workflow_status,
|
||||
"user_name": event.user_name,
|
||||
"message": event.message,
|
||||
})
|
||||
return result
|
||||
|
||||
async def put_pending(self, trace_id, item) -> None:
|
||||
await self.event_dict[trace_id].pending_queue.put(item)
|
||||
|
||||
async def get_pending(self, trace_id) -> str:
|
||||
return await self.event_dict[trace_id].pending_queue.get()
|
||||
|
||||
async def put_received(self, trace_id, item) -> None:
|
||||
await self.event_dict[trace_id].receive_queue.put(item)
|
||||
|
||||
async def get_received(self, trace_id) -> str:
|
||||
return await self.event_dict[trace_id].receive_queue.get()
|
||||
@@ -0,0 +1,62 @@
|
||||
# Copyright 2026 zhaoxi826
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Dict, Any
|
||||
from pretor.utils.logger import get_logger
|
||||
logger = get_logger('individual_manager')
|
||||
|
||||
class GlobalIndividualManager:
|
||||
def __init__(self):
|
||||
self._individuals: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
async def init_individual_register(self, postgres) -> None:
|
||||
try:
|
||||
try:
|
||||
individuals = await postgres.get_all_worker_individual.remote()
|
||||
for ind in individuals:
|
||||
agent_id = getattr(ind, 'agent_id', None)
|
||||
if agent_id:
|
||||
self._individuals[agent_id] = ind.model_dump() if hasattr(ind, 'model_dump') else dict(ind)
|
||||
logger.info(f"成功从数据库拉取了 {len(self._individuals)} 个 Worker Individual 配置。")
|
||||
except AttributeError:
|
||||
logger.warning("数据库中 get_all_worker_individual 方法未实现,跳过全量加载。可以在将来完善该接口。")
|
||||
except Exception as e:
|
||||
# 捕获因 Ray 调用目标方法不存在引发的异常
|
||||
if "has no attribute 'get_all_worker_individual'" in str(e):
|
||||
logger.warning("数据库 individual_database 中缺少 get_all_worker_individual 方法,无法全量拉取。")
|
||||
else:
|
||||
raise e
|
||||
except Exception as e:
|
||||
logger.error(f"从数据库拉取 Worker Individual 配置失败: {e}")
|
||||
|
||||
def add_individual(self, agent_id: str, config: Dict[str, Any]) -> None:
|
||||
"""
|
||||
注册一个 worker individual
|
||||
config 可以包含 type, prompt, provider_title, model_id 等
|
||||
"""
|
||||
config["agent_id"] = agent_id
|
||||
self._individuals[agent_id] = config
|
||||
|
||||
def get_individual(self, agent_id: str) -> Dict[str, Any]:
|
||||
"""
|
||||
获取一个 worker individual 的配置
|
||||
"""
|
||||
return self._individuals.get(agent_id, None)
|
||||
|
||||
def remove_individual(self, agent_id: str) -> None:
|
||||
if agent_id in self._individuals:
|
||||
del self._individuals[agent_id]
|
||||
|
||||
def list_individuals(self) -> Dict[str, Dict[str, Any]]:
|
||||
return self._individuals
|
||||
@@ -0,0 +1,19 @@
|
||||
# Copyright 2026 zhaoxi826
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from pretor.core.global_state_machine.model_provider.base_provider import Provider, ProviderArgs
|
||||
from pretor.core.global_state_machine.model_provider.openai_provider import OpenAIProvider
|
||||
from pretor.core.global_state_machine.model_provider.claude_provider import ClaudeProvider
|
||||
from pretor.core.global_state_machine.model_provider.deepseek_provider import DeepseekProvider
|
||||
__all__ = ["Provider", "ProviderArgs", "OpenAIProvider", "ClaudeProvider", "DeepseekProvider"]
|
||||
@@ -0,0 +1,96 @@
|
||||
# Copyright 2026 zhaoxi826
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from pydantic import BaseModel
|
||||
from typing import List
|
||||
from enum import Enum
|
||||
|
||||
class ProviderStatus(str, Enum):
|
||||
UP = "up"
|
||||
DOWN = "down"
|
||||
|
||||
class Provider(BaseModel):
|
||||
provider_title: str
|
||||
provider_url: str
|
||||
provider_apikey: str
|
||||
provider_models: List[str]
|
||||
provider_type: str
|
||||
provider_owner: str | None = None
|
||||
provider_status: ProviderStatus = ProviderStatus.UP
|
||||
|
||||
class ProviderArgs(BaseModel):
|
||||
provider_title: str
|
||||
provider_url: str
|
||||
provider_apikey: str
|
||||
provider_owner: str
|
||||
|
||||
class BaseProvider(ABC):
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
async def create_provider(provider_args: ProviderArgs) -> Provider:
|
||||
"""
|
||||
创建一个供应商,传入provider_args参数,打包为一个Provider对象
|
||||
|
||||
Args:
|
||||
provider_args: 参数包,包含以下几个参数
|
||||
provider_title: 供应商的别名
|
||||
provider_url: 供应商的url
|
||||
provider_apikey:供应商的apikey
|
||||
|
||||
Returns:
|
||||
返回一个Provider对象,由provider_manager管理
|
||||
"""
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
async def _load_models(provider_args: ProviderArgs) -> List[str]:
|
||||
"""
|
||||
加载模型列表
|
||||
base_provider的字类应当按照供应商的api标准,向供应商的接口发送http请求从而或者供应商所提供的模型列表
|
||||
|
||||
Args:
|
||||
provider_args: 参数包,包含以下几个参数
|
||||
provider_title: 供应商的别名
|
||||
provider_url: 供应商的url
|
||||
provider_apikey:供应商的apikey
|
||||
|
||||
Returns:
|
||||
返回一个列表,为http请求获取的模型列表
|
||||
|
||||
"""
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def _return_provider(provider_args: ProviderArgs, provider_models: List[str]) -> Provider:
|
||||
"""
|
||||
包装Provider对象并返回
|
||||
将provider_args和_load_models获取的方法包装为provider对象
|
||||
|
||||
Args:
|
||||
provider_args: 参数包,包含以下几个参数
|
||||
provider_title: 供应商的别名
|
||||
provider_url: 供应商的url
|
||||
provider_apikey:供应商的apikey
|
||||
|
||||
provider_models: 模型列表,为该供应商包含的模型列表
|
||||
|
||||
Returns:
|
||||
返回一个Provider对象
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@@ -0,0 +1,60 @@
|
||||
from pretor.utils.retry import retry_on_retryable_error
|
||||
# Copyright 2026 zhaoxi826
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from pretor.core.global_state_machine.model_provider.base_provider import BaseProvider, Provider, ProviderArgs
|
||||
import httpx
|
||||
from typing import List
|
||||
|
||||
class ClaudeProvider(BaseProvider):
|
||||
@staticmethod
|
||||
async def create_provider(provider_args: ProviderArgs) -> Provider:
|
||||
provider_models: List[str] = await ClaudeProvider._load_models(provider_args)
|
||||
provider: Provider = ClaudeProvider._return_provider(provider_args, provider_models)
|
||||
return provider
|
||||
|
||||
@staticmethod
|
||||
@retry_on_retryable_error()
|
||||
async def _load_models(provider_args: ProviderArgs) -> List[str]:
|
||||
# Anthropic 官方需要 version 头
|
||||
headers = {
|
||||
"x-api-key": provider_args.provider_apikey,
|
||||
"anthropic-version": "2023-06-01",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
# 如果是官方 API,通常使用 /v1/models (如果支持)
|
||||
# 注意:很多时候 Anthropic 并不返回完整列表,如果请求失败,建议返回硬编码的常用模型
|
||||
url = f"{provider_args.provider_url.rstrip('/')}/v1/models"
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
response = await client.get(url, headers=headers)
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
model_ids = [m["id"] for m in data.get("data", [])]
|
||||
return sorted(model_ids)
|
||||
else:
|
||||
# 如果官方列表接口不可用,fallback 到已知常用模型
|
||||
return ["claude-3-5-sonnet-20240620", "claude-3-opus-20240229", "claude-3-haiku-20240307"]
|
||||
except Exception as e:
|
||||
print(f"[{provider_args.provider_title}] 获取 Claude 模型列表错误: {e}")
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
def _return_provider(provider_args: ProviderArgs, provider_models: List[str]) -> Provider:
|
||||
return Provider(provider_title=provider_args.provider_title,
|
||||
provider_apikey=provider_args.provider_apikey,
|
||||
provider_url=provider_args.provider_url,
|
||||
provider_models=provider_models,
|
||||
provider_type="claude")
|
||||
@@ -0,0 +1,59 @@
|
||||
# Copyright 2026 zhaoxi826
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from pretor.utils.retry import retry_on_retryable_error
|
||||
from pretor.core.global_state_machine.model_provider.base_provider import BaseProvider, Provider, ProviderArgs
|
||||
import httpx
|
||||
from typing import List
|
||||
|
||||
class DeepseekProvider(BaseProvider):
|
||||
@staticmethod
|
||||
async def create_provider(provider_args: ProviderArgs) -> Provider:
|
||||
provider_models: List[str] = await DeepseekProvider._load_models(provider_args)
|
||||
provider: Provider = DeepseekProvider._return_provider(provider_args, provider_models)
|
||||
return provider
|
||||
|
||||
@staticmethod
|
||||
@retry_on_retryable_error()
|
||||
async def _load_models(provider_args: ProviderArgs) -> List[str]:
|
||||
headers = {
|
||||
"Authorization": f"Bearer {provider_args.provider_apikey}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
url = f"{provider_args.provider_url}/models" if "/v1" in provider_args.provider_url else f"{provider_args.provider_url}/v1/models"
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
response = await client.get(url, headers=headers)
|
||||
if response.status_code != 200:
|
||||
print(f"[{provider_args.provider_title}] 获取模型失败: {response.status_code}")
|
||||
return []
|
||||
data = response.json()
|
||||
raw_models = data.get("data", [])
|
||||
model_ids = [m["id"] for m in raw_models]
|
||||
return sorted(model_ids)
|
||||
except httpx.RequestError as e:
|
||||
from pretor.utils.error import RetryableError
|
||||
print(f"[{provider_args.provider_title}] 网络请求异常: {e}")
|
||||
raise RetryableError(f"[{provider_args.provider_title}] 网络请求异常: {e}") from e
|
||||
except Exception as e:
|
||||
print(f"[{provider_args.provider_title}] 解析模型列表时发生错误: {e}")
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
def _return_provider(provider_args: ProviderArgs, provider_models: List[str]) -> Provider:
|
||||
return Provider(provider_title=provider_args.provider_title,
|
||||
provider_apikey=provider_args.provider_apikey,
|
||||
provider_url=provider_args.provider_url,
|
||||
provider_models=provider_models,
|
||||
provider_type="deepseek")
|
||||
@@ -0,0 +1,59 @@
|
||||
# Copyright 2026 zhaoxi826
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from pretor.utils.retry import retry_on_retryable_error
|
||||
from pretor.core.global_state_machine.model_provider.base_provider import BaseProvider, Provider, ProviderArgs
|
||||
import httpx
|
||||
from typing import List
|
||||
|
||||
class OpenAIProvider(BaseProvider):
|
||||
@staticmethod
|
||||
async def create_provider(provider_args: ProviderArgs) -> Provider:
|
||||
provider_models: List[str] = await OpenAIProvider._load_models(provider_args)
|
||||
provider: Provider = OpenAIProvider._return_provider(provider_args, provider_models)
|
||||
return provider
|
||||
|
||||
@staticmethod
|
||||
@retry_on_retryable_error()
|
||||
async def _load_models(provider_args: ProviderArgs) -> List[str]:
|
||||
headers = {
|
||||
"Authorization": f"Bearer {provider_args.provider_apikey}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
url = f"{provider_args.provider_url}/models" if "/v1" in provider_args.provider_url else f"{provider_args.provider_url}/v1/models"
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
response = await client.get(url, headers=headers)
|
||||
if response.status_code != 200:
|
||||
print(f"[{provider_args.provider_title}] 获取模型失败: {response.status_code}")
|
||||
return []
|
||||
data = response.json()
|
||||
raw_models = data.get("data", [])
|
||||
model_ids = [m["id"] for m in raw_models]
|
||||
return sorted(model_ids)
|
||||
except httpx.RequestError as e:
|
||||
from pretor.utils.error import RetryableError
|
||||
print(f"[{provider_args.provider_title}] 网络请求异常: {e}")
|
||||
raise RetryableError(f"[{provider_args.provider_title}] 网络请求异常: {e}") from e
|
||||
except Exception as e:
|
||||
print(f"[{provider_args.provider_title}] 解析模型列表时发生错误: {e}")
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
def _return_provider(provider_args: ProviderArgs, provider_models: List[str]) -> Provider:
|
||||
return Provider(provider_title=provider_args.provider_title,
|
||||
provider_apikey=provider_args.provider_apikey,
|
||||
provider_url=provider_args.provider_url,
|
||||
provider_models=provider_models,
|
||||
provider_type="openai")
|
||||
@@ -0,0 +1,86 @@
|
||||
# Copyright 2026 zhaoxi826
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from pretor.core.global_state_machine.model_provider import Provider, OpenAIProvider, ClaudeProvider, DeepseekProvider
|
||||
from typing import Dict, Type
|
||||
|
||||
class ProviderManager:
|
||||
"""
|
||||
模型供应商管理器 (ProviderManager)。
|
||||
负责维护不同的 LLM 协议适配器,提供从配置注册到 Agent 实例化的全生命周期管理。
|
||||
"""
|
||||
# --- 类属性显式标注 (IDE 友好) ---
|
||||
provider_mapper: Dict[str, Type[Provider]]
|
||||
"""协议映射表:键为协议名(如 'openai'),值为对应的 Provider 类。"""
|
||||
|
||||
provider_register: Dict[str, Provider]
|
||||
"""供应商注册表:键为用户自定义别名,值为已实例化的 Provider 对象。"""
|
||||
def __init__(self, postgres):
|
||||
self.provider_mapper = {"openai": OpenAIProvider,
|
||||
"claude": ClaudeProvider,
|
||||
"deepseek": DeepseekProvider}
|
||||
self.provider_register = {}
|
||||
|
||||
async def init_provider_register(self, postgres) -> None:
|
||||
providers = await postgres.get_provider.remote()
|
||||
for provider in providers:
|
||||
self.provider_register[provider.provider_title] = provider
|
||||
|
||||
async def add_provider(self, provider_type, provider_title, provider_url, provider_apikey, provider_owner, postgres_database) -> None:
|
||||
from pretor.core.global_state_machine.model_provider import ProviderArgs
|
||||
from pretor.utils.logger import get_logger
|
||||
logger = get_logger('provider_manager')
|
||||
import httpx
|
||||
|
||||
provider_args: ProviderArgs = ProviderArgs(provider_title=provider_title,
|
||||
provider_url=provider_url,
|
||||
provider_apikey=provider_apikey,
|
||||
provider_owner=provider_owner)
|
||||
try:
|
||||
import ulid
|
||||
provider_class = self.provider_mapper.get(provider_type, None)
|
||||
if provider_class is None:
|
||||
logger.warning(f"Provider type {provider_type} is not supported.")
|
||||
return None
|
||||
provider: Provider = await provider_class.create_provider(provider_args)
|
||||
provider.provider_owner = provider_owner
|
||||
self.provider_register[provider_title] = provider
|
||||
await postgres_database.add_provider_db.remote(
|
||||
provider_id=str(ulid.ULID()),
|
||||
provider_title=provider.provider_title,
|
||||
provider_url=provider.provider_url,
|
||||
provider_apikey=provider.provider_apikey,
|
||||
provider_models=provider.provider_models,
|
||||
provider_type=provider.provider_type,
|
||||
provider_owner=provider.provider_owner)
|
||||
|
||||
logger.info(f"已添加适配器{provider_title}")
|
||||
except httpx.RequestError as e:
|
||||
from pretor.utils.error import RetryableError
|
||||
logger.warning(f"[{provider_args.provider_title}] 网络请求异常: {e}")
|
||||
raise RetryableError(f"[{provider_args.provider_title}] 网络请求异常: {e}") from e
|
||||
except Exception as e:
|
||||
logger.warning(f"[{provider_args.provider_title}] 解析模型列表时发生错误: {e}")
|
||||
|
||||
def get_provider_list(self):
|
||||
return self.provider_register
|
||||
|
||||
def get_provider(self, provider_title):
|
||||
return self.provider_register.get(provider_title)
|
||||
|
||||
async def delete_provider(self, provider_title: str, postgres_database) -> None:
|
||||
if provider_title in self.provider_register:
|
||||
provider = self.provider_register[provider_title]
|
||||
await postgres_database.delete_provider_db.remote( provider_id=provider.provider_id)
|
||||
del self.provider_register[provider_title]
|
||||
@@ -0,0 +1,75 @@
|
||||
# Copyright 2026 zhaoxi826
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Tuple, Dict
|
||||
from collections import defaultdict
|
||||
import pathlib
|
||||
import json
|
||||
|
||||
class GlobalSkillManager:
|
||||
skill_mapper = Dict[str,Tuple[str]]
|
||||
"""skill的存储表"""
|
||||
|
||||
def __init__(self):
|
||||
self.skill_mapper = defaultdict(tuple)
|
||||
|
||||
import os
|
||||
skill_plugin_dir = pathlib.Path(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "plugin", "skill")))
|
||||
if not skill_plugin_dir.exists() or not skill_plugin_dir.is_dir():
|
||||
return
|
||||
for item in skill_plugin_dir.iterdir():
|
||||
if item.is_dir() and not item.name.startswith((".", "__")):
|
||||
json_path = item / "skill.json" # 拼接文件路径
|
||||
if json_path.exists():
|
||||
try:
|
||||
with open(json_path, "r", encoding="utf-8") as f:
|
||||
skill = json.load(f)
|
||||
# 提取并映射
|
||||
name = skill.get("name")
|
||||
if name:
|
||||
self.skill_mapper[name] = (
|
||||
skill.get("description", ""),
|
||||
skill.get("instructions", "")
|
||||
)
|
||||
except (json.JSONDecodeError, OSError) as e:
|
||||
print(f"警告: 加载插件 {item.name} 失败: {e}")
|
||||
|
||||
def add_skill(self, skill_name: str) -> None:
|
||||
"""Add a skill to the manager by reading its skill.json from the path"""
|
||||
import os
|
||||
skill_plugin_dir = pathlib.Path(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "plugin", "skill")))
|
||||
item = skill_plugin_dir / skill_name
|
||||
if item.is_dir() and not item.name.startswith((".", "__")):
|
||||
json_path = item / "skill.json"
|
||||
if json_path.exists():
|
||||
try:
|
||||
with open(json_path, "r", encoding="utf-8") as f:
|
||||
skill = json.load(f)
|
||||
name = skill.get("name")
|
||||
if name:
|
||||
self.skill_mapper[name] = (
|
||||
skill.get("description", ""),
|
||||
skill.get("instructions", "")
|
||||
)
|
||||
except (json.JSONDecodeError, OSError) as e:
|
||||
print(f"警告: 加载插件 {item.name} 失败: {e}")
|
||||
|
||||
def get_skill_list(self) -> dict:
|
||||
"""Return all skills currently loaded."""
|
||||
return self.skill_mapper
|
||||
|
||||
def remove_skill(self, skill_name: str) -> None:
|
||||
"""Remove a skill from the manager mapping."""
|
||||
if skill_name in self.skill_mapper:
|
||||
del self.skill_mapper[skill_name]
|
||||
@@ -0,0 +1,52 @@
|
||||
# Copyright 2026 zhaoxi826
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import pathlib
|
||||
import importlib
|
||||
import inspect
|
||||
from collections import defaultdict
|
||||
from pretor.plugin.tool_plugin.base_tool import BaseToolData
|
||||
from typing import Dict, Type
|
||||
from pretor.utils.logger import get_logger
|
||||
logger = get_logger('tool_manager')
|
||||
|
||||
class GlobalToolManager:
|
||||
tool_mapper: Dict[str, Dict[str, Type[BaseToolData]]]
|
||||
|
||||
def __init__(self):
|
||||
self.tool_mapper = defaultdict(dict)
|
||||
|
||||
tool_plugin_dir = pathlib.Path(__file__).parent.parent.parent / "plugin" / "tool_plugin"
|
||||
if not tool_plugin_dir.exists() or not tool_plugin_dir.is_dir():
|
||||
return
|
||||
|
||||
for item in tool_plugin_dir.iterdir():
|
||||
if item.is_dir() and not item.name.startswith("__"):
|
||||
plugin_name = item.name
|
||||
module_name = f"pretor.plugin.tool_plugin.{plugin_name}"
|
||||
|
||||
try:
|
||||
module = importlib.import_module(module_name)
|
||||
for name, obj in inspect.getmembers(module, inspect.isclass):
|
||||
if issubclass(obj, BaseToolData) and obj is not BaseToolData:
|
||||
# It's a valid tool class
|
||||
action_scopes = obj.model_fields.get("action_scope").default
|
||||
|
||||
if not action_scopes:
|
||||
self.tool_mapper["default"][plugin_name] = obj
|
||||
else:
|
||||
for scope in action_scopes:
|
||||
self.tool_mapper[scope][plugin_name] = obj
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load tool plugin {plugin_name}: {e}")
|
||||
Reference in New Issue
Block a user