Files
2026-07-01 09:22:26 +00:00

553 lines
22 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""BaseOrganization:重型插件基类。
设计要点:
- 单机模式 = 普通 Python 对象,分布式 = ray actor(``@actor_class`` 装饰子类)
- 内置 ``asyncio.Queue`` 输入队列 + 任务表
- 对外两条通道:``dispatch`` (阻塞) / ``submit`` (射后不管),底层都汇集到 ``_run_task``
- 子类只需覆写 ``setup`` / ``react`` 两个钩子;零代码插件由 ``agents.json`` 声明驱动
"""
from __future__ import annotations
import asyncio
import json
import time
from typing import Any, AsyncGenerator, Callable, Dict, List, Optional, Type
from ulid import ULID
from kilostar.plugin_runtime.event import OrgEvent, TaskState
from kilostar.plugin_runtime.manifest import OrgManifest
from kilostar.plugin_runtime.agents_config import AgentsConfig, AgentDef
from kilostar.utils.logger import get_logger
from kilostar.utils.settings import get_artifact_dir, get_plugin_data_dir
class BaseOrganization:
    """Base class for heavyweight plugins ("organizations").

    Lifecycle::

        __init__(manifest_dict, agents_dict, plugin_dir) -> setup() -> running
        -> shutdown()

    ``setup`` loads this organization's ``toolset/`` directory, builds its
    agent instances (each gets a ``consult`` tool when it declares peers) and
    starts a background worker coroutine that drains the input queue.

    Work enters through ``dispatch`` (awaits the result) or ``submit``
    (returns the task id immediately); both are executed by ``_run_task``.
    Subclasses normally only override ``setup`` and ``react``.
    """

    # Task statuses after which a task produces no further events.
    _TERMINAL_STATUSES = ("completed", "failed")

    def __init__(
        self,
        manifest_dict: Dict[str, Any],
        agents_dict: Dict[str, Any],
        plugin_dir: str,
    ) -> None:
        self.manifest = OrgManifest.model_validate(manifest_dict)
        self.agents_config = AgentsConfig.model_validate(agents_dict)
        self.plugin_dir = plugin_dir
        self.name = self.manifest.name
        self.logger = get_logger(f"org.{self.name}")

        # Task queue and state tables.
        self._queue: asyncio.Queue = asyncio.Queue()
        self._tasks: Dict[str, TaskState] = {}
        # Futures awaited by ``dispatch``; ``dispatch`` removes its entry.
        self._futures: Dict[str, asyncio.Future] = {}
        # task_id -> one queue per live ``stream`` subscriber.
        self._streams: Dict[str, List[asyncio.Queue]] = {}

        # Background consumer coroutine.
        self._worker_task: Optional[asyncio.Task] = None
        self._stopped = False

        # Populated by setup().
        self._tools_by_name: Dict[str, Callable] = {}
        self._agents: Dict[str, Any] = {}  # name -> pydantic-ai Agent

        # Plugin-local SQLite engine (opt-in, see init_local_db).
        self._engine: Any = None
        self._session_maker: Any = None

    # ─── Lifecycle ────────────────────────────────────────────────

    async def setup(self) -> None:
        """Load organization resources, build agents and start the worker.

        Subclasses may override this to add their own initialization
        (databases, subprocesses, ...) but should ``await super().setup()``.
        Calling it again while the worker is still alive does not start a
        second consumer on the same queue.
        """
        await self._load_local_tools()
        await self._build_agents()
        if self._worker_task is None or self._worker_task.done():
            self._worker_task = asyncio.create_task(self._consume_queue())

    async def shutdown(self) -> None:
        """Stop the worker, release waiters and dispose the local DB engine.

        Pending ``dispatch`` callers receive a ``"failed"`` result instead of
        waiting forever, and live ``stream`` subscribers are closed.
        """
        self._stopped = True
        if self._worker_task is not None:
            self._worker_task.cancel()

        # Wake up dispatch() callers whose task will now never finish.
        for task_id, fut in list(self._futures.items()):
            if not fut.done():
                fut.set_result(
                    {
                        "task_id": task_id,
                        "status": "failed",
                        "result": None,
                        "error": "organization shut down",
                    }
                )

        # End every open SSE stream; subscribers unregister themselves.
        for task_id in list(self._streams):
            self._broadcast(task_id, None)

        if self._engine is not None:
            try:
                await self._engine.dispose()
            except Exception:
                self.logger.debug("engine dispose failed; ignored")

    async def on_first_install(self) -> None:
        """One-time install hook, called once when the plugin is first installed.

        Typical uses: create tables, write default configuration, prompt the
        user to finish configuration in the frontend. Raising makes
        plugin_manager roll back (no marker is written, so it is retried on the
        next start). Subclasses override as needed; the default does nothing.
        """
        return None

    async def init_local_db(self, base_classes: List[Type[Any]]) -> None:
        """Create the plugin-private SQLite engine and its tables.

        ``base_classes`` are the plugin's own ``DeclarativeBase`` subclasses
        (each plugin uses a separate Base so its metadata does not mix with the
        core PG models). Entries without a ``metadata`` attribute are skipped.
        Safe to call on every setup: ``create_all`` only creates missing tables
        and leaves existing ones untouched; a previously created engine is
        disposed first.

        Afterwards ``self._session_maker`` can be used by tools/APIs as
        ``async with sm() as s``.
        """
        from sqlalchemy.ext.asyncio import (
            AsyncSession,
            async_sessionmaker,
            create_async_engine,
        )

        if self._engine is not None:
            # A second call would otherwise leak the previous connection pool.
            try:
                await self._engine.dispose()
            except Exception:
                self.logger.debug("engine dispose failed; ignored")

        db_path = get_plugin_data_dir(self.name) / f"{self.name}.db"
        db_path.parent.mkdir(parents=True, exist_ok=True)
        url = f"sqlite+aiosqlite:///{db_path}"
        self._engine = create_async_engine(url, future=True)
        self._session_maker = async_sessionmaker(
            self._engine, class_=AsyncSession, expire_on_commit=False
        )
        async with self._engine.begin() as conn:
            for base in base_classes:
                metadata = getattr(base, "metadata", None)
                if metadata is None:
                    continue
                await conn.run_sync(metadata.create_all)
        self.logger.info(f"local sqlite ready: {db_path}")

    # ─── Public channels ──────────────────────────────────────────

    async def dispatch(
        self, task_description: str, ctx: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """Synchronous entry point for the cabinet: wait until the task ends.

        Returns:
            ``{"task_id": ..., "status": ..., "result": ..., "error": ...}``
        """
        task_id = await self._enqueue(
            task_description, ctx or {}, source="cabinet", with_future=True
        )
        future = self._futures[task_id]
        try:
            return await future
        finally:
            self._futures.pop(task_id, None)

    async def submit(
        self, task_description: str, ctx: Optional[Dict[str, Any]] = None
    ) -> str:
        """User API entry point: enqueue and return the task id immediately.

        Progress is observed through ``status`` / ``stream``. No future is
        created, because nobody would ever await or remove it.
        """
        return await self._enqueue(
            task_description, ctx or {}, source="user", with_future=False
        )

    async def status(self, task_id: str) -> Optional[Dict[str, Any]]:
        """Return a snapshot of the task state, or ``None`` for unknown ids."""
        ts = self._tasks.get(task_id)
        if ts is None:
            return None
        return {
            "task_id": ts.task_id,
            "status": ts.status,
            "description": ts.description,
            "source": ts.source,
            "result": ts.result,
            "error": ts.error,
            "events": [e.to_dict() for e in ts.events],
        }

    async def stream(self, task_id: str) -> AsyncGenerator[Dict[str, Any], None]:
        """Async generator for the SSE endpoint; yields one event dict at a time.

        For a finished task the recorded history is replayed and the generator
        ends. Otherwise the history recorded so far is replayed, followed by
        live events until the task finishes. Any number of concurrent
        subscribers per task is supported. Unknown task ids yield nothing.
        """
        ts = self._tasks.get(task_id)
        if ts is None:
            return

        if ts.status in self._TERMINAL_STATUSES:
            for ev in list(ts.events):
                yield ev.to_dict()
            return

        # Register the subscriber queue and snapshot the history with no await
        # in between. Every event is then delivered exactly once: older ones
        # from the snapshot, newer ones through the queue.
        sub_queue: asyncio.Queue = asyncio.Queue()
        self._streams.setdefault(task_id, []).append(sub_queue)
        history = list(ts.events)
        try:
            for ev in history:
                yield ev.to_dict()
            while True:
                ev = await sub_queue.get()
                if ev is None:  # end-of-stream sentinel
                    break
                yield ev.to_dict()
        finally:
            # Remove only this subscriber; others may still be listening.
            queues = self._streams.get(task_id)
            if queues is not None:
                if sub_queue in queues:
                    queues.remove(sub_queue)
                if not queues:
                    self._streams.pop(task_id, None)

    async def list_tasks(self) -> List[Dict[str, Any]]:
        """Return a short summary of every task known to this organization."""
        return [
            {
                "task_id": ts.task_id,
                "status": ts.status,
                "source": ts.source,
                "description": ts.description,
            }
            for ts in self._tasks.values()
        ]

    # ─── Subclass hooks ───────────────────────────────────────────

    async def react(
        self,
        task_description: str,
        ctx: Dict[str, Any],
        emit: Callable[[OrgEvent], Any],
    ) -> Any:
        """Default ReAct implementation: run the entry agent once on the task.

        Emits a ``step`` event with ``phase="start"`` before the run and
        ``phase="end"`` after it. Subclasses may override this to implement
        custom orchestration (DAG / pipeline).

        Raises:
            RuntimeError: the configured entry agent was not built.
        """
        entry_name = self.agents_config.orchestration.entry
        entry_agent = self._agents.get(entry_name)
        if entry_agent is None:
            raise RuntimeError(f"entry agent {entry_name!r} not found in {self.name}")

        await emit(
            OrgEvent(
                task_id=ctx["task_id"],
                type="step",
                payload={"agent": entry_name, "phase": "start"},
            )
        )
        try:
            result = await entry_agent.run(user_prompt=task_description)
            output = self._agent_output(result)
        except Exception as e:
            self.logger.exception(f"entry agent {entry_name} run failed: {e}")
            raise
        await emit(
            OrgEvent(
                task_id=ctx["task_id"],
                type="step",
                payload={"agent": entry_name, "phase": "end"},
            )
        )
        return output

    # ─── Internals ────────────────────────────────────────────────

    async def _enqueue(
        self,
        task_description: str,
        ctx: Dict[str, Any],
        source: str,
        *,
        with_future: bool = True,
    ) -> str:
        """Register a new task, queue it for the worker and return its id.

        ``with_future`` controls whether a completion future is stored in
        ``self._futures``. Only callers that will await and remove it (like
        ``dispatch``) should request one.
        """
        task_id = str(ULID())
        trace_id = ctx.get("trace_id") or task_id
        user_id = ctx.get("user_id", "")
        # Sandbox directory: data/artifact/<trace>/<org>/
        # NOTE(review): trace_id may come from caller-supplied ctx and is used
        # unsanitized as a path component. Confirm callers validate it.
        artifact_dir = str(get_artifact_dir() / trace_id / self.name)

        ts = TaskState(
            task_id=task_id,
            org_name=self.name,
            trace_id=trace_id,
            user_id=user_id,
            description=task_description,
            source=source,  # type: ignore[arg-type]
        )
        self._tasks[task_id] = ts
        if with_future:
            # Created before queueing so the worker can never finish the task
            # before there is a future to resolve.
            self._futures[task_id] = asyncio.get_running_loop().create_future()

        full_ctx = {
            **ctx,
            "trace_id": trace_id,
            "user_id": user_id,
            "task_id": task_id,
            "source": source,
            "artifact_dir": artifact_dir,
        }
        await self._queue.put((task_id, task_description, full_ctx))
        # Best-effort persistence; silently skipped when PG is unavailable.
        await self._persist_task(ts)
        return task_id

    async def _consume_queue(self) -> None:
        """Worker loop: run queued tasks one at a time until stopped."""
        while not self._stopped:
            try:
                task_id, desc, ctx = await self._queue.get()
            except asyncio.CancelledError:
                break
            try:
                await self._run_task(task_id, desc, ctx)
            except Exception as e:
                self.logger.exception(f"task {task_id} crashed: {e}")
            finally:
                self._queue.task_done()

    async def _run_task(self, task_id: str, desc: str, ctx: Dict[str, Any]) -> None:
        """Execute one task through ``react`` and publish its outcome.

        Always persists the final state, closes live streams and resolves a
        pending ``dispatch`` future, even when ``react`` fails or the worker
        is cancelled.
        """
        ts = self._tasks[task_id]
        ts.status = "running"
        await self._persist_task(ts)

        async def _emit(ev: OrgEvent) -> None:
            # Record synchronously first so a subscriber that snapshots the
            # history right now cannot miss it, then fan out, then persist.
            ts.events.append(ev)
            self._broadcast(task_id, ev)
            await self._persist_event(ts, ev)

        try:
            result = await self.react(desc, ctx, _emit)
            ts.status = "completed"
            ts.result = result
            await _emit(
                OrgEvent(task_id=task_id, type="done", payload={"result": result})
            )
        except asyncio.CancelledError:
            # CancelledError is not an Exception subclass; without this branch
            # the task would stay "running" forever.
            if ts.status not in self._TERMINAL_STATUSES:
                ts.status = "failed"
                ts.error = "cancelled"
            raise
        except Exception as e:
            ts.status = "failed"
            ts.error = str(e)
            await _emit(
                OrgEvent(task_id=task_id, type="error", payload={"error": str(e)})
            )
        finally:
            await self._persist_task(ts)
            # Tell every live stream that the task is over.
            self._broadcast(task_id, None)
            # Wake up the dispatch() caller, if any.
            fut = self._futures.get(task_id)
            if fut is not None and not fut.done():
                fut.set_result(
                    {
                        "task_id": task_id,
                        "status": ts.status,
                        "result": ts.result,
                        "error": ts.error,
                    }
                )

    def _broadcast(self, task_id: str, item: Optional[OrgEvent]) -> None:
        """Push ``item`` to every live ``stream`` subscriber of ``task_id``.

        ``None`` is the end-of-stream sentinel. The subscriber queues are
        unbounded, so ``put_nowait`` never raises ``QueueFull``.
        """
        for queue in list(self._streams.get(task_id, ())):
            queue.put_nowait(item)

    @staticmethod
    def _agent_output(result: Any) -> Any:
        """Extract ``result.output``, falling back to ``str(result)``.

        The fallback applies only when there is no output at all. An empty
        string or other falsy output is returned as-is rather than replaced by
        the result object's repr.
        """
        output = getattr(result, "output", None)
        return str(result) if output is None else output

    # ─── PG persistence ───────────────────────────────────────────

    async def _persist_task(self, ts: TaskState) -> None:
        """Upsert the task state via the ``postgres_database`` ray actor.

        Best-effort: any failure is logged at debug level and never blocks
        execution. Results that are not ``str``/``dict``/``list``/``None`` are
        stringified.
        """
        try:
            from kilostar.utils.ray_hook import ray_actor_hook

            pg = ray_actor_hook("postgres_database").postgres_database
            result = ts.result
            if not isinstance(result, (str, dict, list, type(None))):
                result = str(result)
            await pg.upsert_org_task.remote(
                task_id=ts.task_id,
                org_name=ts.org_name,
                trace_id=ts.trace_id,
                user_id=ts.user_id,
                status=ts.status,
                description=ts.description,
                source=ts.source,
                result=result,
                error=ts.error,
            )
        except Exception:
            self.logger.debug("persist_task skipped (no DB / not ready)")

    async def _persist_event(self, ts: TaskState, ev: OrgEvent) -> None:
        """Append one event to the persisted task log (best-effort)."""
        try:
            from kilostar.utils.ray_hook import ray_actor_hook

            pg = ray_actor_hook("postgres_database").postgres_database
            await pg.append_org_task_event.remote(
                task_id=ts.task_id, event=ev.to_dict()
            )
        except Exception:
            self.logger.debug("persist_event skipped")

    # ─── Resource loading ─────────────────────────────────────────

    async def _load_local_tools(self) -> None:
        """Load the tools declared in this organization's ``toolset/`` dir.

        Mirrors the ``GlobalToolManager`` approach: read ``manifest.json`` and
        register the named functions in ``self._tools_by_name``. Global
        whitelisted tools (``python_executor`` etc.) are merged in afterwards
        as a fallback, whether or not a local toolset exists.
        """
        from pathlib import Path
        import importlib.util
        import sys
        import types

        toolset_dir = Path(self.plugin_dir) / "toolset"
        manifest_path = toolset_dir / "manifest.json"
        if toolset_dir.exists() and manifest_path.exists():
            with open(manifest_path, "r", encoding="utf-8") as f:
                manifest = json.load(f)

            # Share the virtual package chain used by
            # loader._import_entry_class:
            # ``_kilostar_plugin_<name>`` -> ``.toolset``. This lets relative
            # imports such as ``from ._s3_common import ...`` resolve.
            root_pkg = f"_kilostar_plugin_{self.name}"
            tool_pkg = f"{root_pkg}.toolset"
            if root_pkg not in sys.modules:
                root_mod = types.ModuleType(root_pkg)
                root_mod.__path__ = [str(Path(self.plugin_dir))]
                sys.modules[root_pkg] = root_mod
            if tool_pkg not in sys.modules:
                pkg = types.ModuleType(tool_pkg)
                pkg.__path__ = [str(toolset_dir)]
                sys.modules[tool_pkg] = pkg

            # Pass 1: register every .py in toolset/ as a submodule named
            # after its file. Shared helper modules (e.g. ``_s3_common``) are
            # therefore in place before the tools that import them.
            for py_path in sorted(toolset_dir.glob("*.py")):
                if py_path.name == "__init__.py":
                    continue
                sub_name = f"{tool_pkg}.{py_path.stem}"
                if sub_name in sys.modules:
                    continue
                spec = importlib.util.spec_from_file_location(sub_name, str(py_path))
                if spec is None or spec.loader is None:
                    continue
                mod = importlib.util.module_from_spec(spec)
                mod.__package__ = tool_pkg
                sys.modules[sub_name] = mod
                try:
                    spec.loader.exec_module(mod)
                except Exception as e:
                    self.logger.warning(f"failed to load tool module {py_path.name}: {e}")
                    sys.modules.pop(sub_name, None)

            # Pass 2: pick the tool functions listed in the manifest.
            for tool_def in manifest.get("tools", []):
                tname = tool_def.get("name")
                if not tname:
                    continue
                tfile = tool_def.get("file", f"{tname}.py")
                sub_name = f"{tool_pkg}.{Path(tfile).stem}"
                mod = sys.modules.get(sub_name)
                if mod is None:
                    self.logger.warning(f"tool module not loaded: {tfile}")
                    continue
                func = getattr(mod, tname, None)
                if callable(func):
                    self._tools_by_name[tname] = func
                else:
                    self.logger.warning(
                        f"tool {tname!r} not found or not callable in {tfile}"
                    )

        # Borrow generic tools from the global tool manager.
        await self._merge_global_tools()

    async def _merge_global_tools(self) -> None:
        """Merge the cabinet's global tool whitelist (python_executor etc.).

        Local tools win on name clashes (``setdefault``). If the global
        snapshot is unavailable, the organization runs with local tools only.
        """
        try:
            from kilostar.core.global_state_machine.gsm_snapshot import fetch_snapshot

            snapshot = await fetch_snapshot()
            for name, func in snapshot.all_funcs.items():
                self._tools_by_name.setdefault(name, func)
        except Exception:
            self.logger.debug("global tools not available; org runs with local only")

    async def _build_agents(self) -> None:
        """Instantiate the pydantic-ai Agents declared in agents.json.

        Each agent receives:
          - the tools it declares (looked up in ``_tools_by_name``; unknown
            names are skipped with a warning);
          - a ``consult`` tool for cross-agent collaboration, if it has peers.

        Where provider+model come from:
          1. A ``model`` hard-coded in agents.json is used directly
             (compatibility with older plugins).
          2. Otherwise the DB is queried by ``(plugin_name, slot_name)`` for
             the provider+model the user configured on the Agent settings
             page. If none is found, the slot is skipped with a warning so the
             user can configure it first.
        """
        from kilostar.adapter.model_adapter.agent_factory import AgentFactory
        from kilostar.core.global_state_machine.gsm_snapshot import fetch_snapshot

        snapshot = await fetch_snapshot()
        factory = AgentFactory()
        for adef in self.agents_config.agents:
            provider_title, model_id = await self._resolve_slot_model(adef)
            if not provider_title or not model_id:
                self.logger.warning(
                    f"agent slot {adef.name!r}: provider/model 未配置(请在 Agent 设置页装配)"
                )
                continue
            provider = snapshot.providers.get(provider_title)
            if provider is None:
                self.logger.warning(
                    f"provider {provider_title!r} not found; agent {adef.name} skipped"
                )
                continue

            missing = [t for t in adef.tools if t not in self._tools_by_name]
            if missing:
                self.logger.warning(
                    f"agent {adef.name}: unknown tools skipped: {missing}"
                )
            tools = [self._tools_by_name[t] for t in adef.tools if t in self._tools_by_name]
            consult_tool = self._make_consult_tool(adef)
            if consult_tool is not None:
                tools.append(consult_tool)

            try:
                agent = factory.create_agent(
                    provider=provider,
                    model_id=model_id,
                    output_type=str,
                    system_prompt=adef.system_prompt or f"You are {adef.role}.",
                    deps_type=type(None),
                    agent_name=f"{self.name}.{adef.name}",
                    tools=tools,
                    toolsets=None,
                )
                self._agents[adef.name] = agent
            except Exception as e:
                self.logger.warning(f"build agent {adef.name} failed: {e}")

    async def _resolve_slot_model(self, adef: AgentDef) -> tuple[str, str]:
        """Decide which provider+model a slot uses.

        A static binding in agents.json wins (backward compatibility with
        older plugins). Otherwise the user's per-slot configuration is read
        from the DB. Returns ``("", "")`` when nothing is configured or the DB
        is unavailable; the caller then skips the slot.
        """
        if adef.model and adef.model.provider_title and adef.model.model_id:
            return adef.model.provider_title, adef.model.model_id
        try:
            from kilostar.utils.ray_hook import ray_actor_hook

            pg = ray_actor_hook("postgres_database").postgres_database
            row = await pg.find_plugin_slot.remote(self.name, adef.name)
            if row is None:
                return "", ""
            return getattr(row, "provider_title", "") or "", getattr(row, "model_id", "") or ""
        except Exception as e:
            self.logger.debug(f"slot model lookup failed (DB?): {e}")
            return "", ""

    def _make_consult_tool(self, adef: AgentDef) -> Optional[Callable[..., Any]]:
        """Build a ``consult(peer, question)`` tool for ``adef``.

        Returns ``None`` when the agent has no peers. The tool awaits the peer
        agent's ``run`` directly. Errors are returned to the calling agent as
        ``"[error] ..."`` strings instead of being raised.
        """
        if not adef.peers:
            return None
        peers = list(adef.peers)
        org = self

        # The docstring below is runtime data: it is kept verbatim because
        # tool docstrings serve as the tool description shown to the model.
        async def consult(peer: str, question: str) -> str:
            """向同事 agent 提问以获取专业意见。

            Args:
                peer: 同事 agent 名字
                question: 要问的问题
            """
            if peer not in peers:
                return f"[error] {peer} 不在你的协作列表中: {peers}"
            target = org._agents.get(peer)
            if target is None:
                return f"[error] 同事 agent {peer} 未启动"
            try:
                resp = await target.run(user_prompt=question)
                return org._agent_output(resp)
            except Exception as e:
                return f"[error] {peer} 失败: {e}"

        return consult