203 lines
7.8 KiB
Python
203 lines
7.8 KiB
Python
"""Directory scanning + loading pipeline for organization plugins.

Main entry points are ``discover_plugins(plugin_root)`` and ``load_plugin(plugin_dir)``:

- discover: list every plugin *directory* under the root whose
  ``manifest.json`` passes ``OrgManifest`` validation.
- load: read ``manifest.json`` + ``agents.json``, resolve the entry class and
  return ``(class, manifest_dict, agents_dict, plugin_dir_str)`` ready to be
  instantiated.

The module also provides ``install_dependencies`` (install a plugin's python
deps via ``uv``), ``discover_plugin_api`` and ``collect_plugin_routers``
(optional per-plugin HTTP routers).
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import importlib.util
|
||
import json
|
||
import sys
|
||
from pathlib import Path
|
||
from typing import Any, Dict, List, Tuple, Type
|
||
|
||
from kilostar.plugin_runtime.manifest import OrgManifest
|
||
from kilostar.plugin_runtime.agents_config import AgentsConfig
|
||
from kilostar.utils.logger import get_logger
|
||
|
||
# Module-level logger (channel "plugin_loader") used by the functions in this module.
logger = get_logger("plugin_loader")
|
||
|
||
|
||
def discover_plugins(plugin_root: Path) -> List[Path]:
    """Scan ``plugin_root`` and return every valid plugin directory.

    A directory is a valid plugin when it contains a ``manifest.json`` that
    parses as JSON and passes ``OrgManifest`` (pydantic) validation.

    Skipped entries:

    - non-directories and names starting with ``__`` (e.g. ``__pycache__``);
    - the ``skill/`` sub-directory (a skill repository, not an organization);
    - directories without ``manifest.json``;
    - directories whose manifest fails to load/validate (logged as a warning).

    Args:
        plugin_root: Directory holding one sub-directory per plugin.

    Returns:
        Plugin directories sorted by directory name, so discovery (and
        therefore load / router-mount order) is deterministic across
        filesystems. An empty list when ``plugin_root`` does not exist or is
        not a directory.
    """
    # is_dir() is False for missing paths too, so this covers both cases.
    if not plugin_root.is_dir():
        return []
    results: List[Path] = []
    # iterdir() order is filesystem-dependent; sort for reproducible results.
    for entry in sorted(plugin_root.iterdir(), key=lambda p: p.name):
        if not entry.is_dir() or entry.name.startswith("__"):
            continue
        if entry.name in ("skill",):
            continue
        manifest_path = entry / "manifest.json"
        if not manifest_path.exists():
            continue
        try:
            with open(manifest_path, "r", encoding="utf-8") as f:
                data = json.load(f)
            OrgManifest.model_validate(data)
        except Exception as e:
            # Best-effort scan: one broken plugin must not hide the others.
            logger.warning(f"skip plugin {entry.name}: invalid manifest ({e})")
            continue
        results.append(entry)
    return results
|
||
|
||
|
||
def load_plugin(
    plugin_dir: Path,
) -> Tuple[Type[Any], Dict[str, Any], Dict[str, Any], str]:
    """Load one plugin and return ``(Class, manifest_dict, agents_dict, plugin_dir_str)``.

    Steps:

    - read ``manifest.json`` and validate it against ``OrgManifest``;
    - read ``agents.json`` (``FileNotFoundError`` if absent) and validate it
      against ``AgentsConfig``;
    - resolve the class: an empty ``manifest.entry`` falls back to the default
      ``BaseOrganization``; otherwise an entry such as
      ``"core.organization:DataCleaningOrg"`` is imported dynamically.

    The returned dicts are the raw JSON payloads, not the validated models.
    """
    manifest_file = plugin_dir / "manifest.json"
    with open(manifest_file, "r", encoding="utf-8") as fh:
        raw_manifest = json.load(fh)
    manifest = OrgManifest.model_validate(raw_manifest)

    agents_file = plugin_dir / "agents.json"
    if not agents_file.exists():
        raise FileNotFoundError(f"plugin {manifest.name} missing agents.json")
    with open(agents_file, "r", encoding="utf-8") as fh:
        raw_agents = json.load(fh)
    # Validation only: the model instance is discarded, callers get the raw dict.
    AgentsConfig.model_validate(raw_agents)

    if not manifest.entry:
        from kilostar.plugin_runtime.base_organization import BaseOrganization

        org_cls: Type[Any] = BaseOrganization
    else:
        org_cls = _import_entry_class(plugin_dir, manifest.entry, manifest.name)

    return org_cls, raw_manifest, raw_agents, str(plugin_dir)
|
||
|
||
|
||
def _import_entry_class(plugin_dir: Path, entry: str, plugin_name: str) -> Type[Any]:
|
||
"""形如 ``core.organization:DataCleaningOrg`` 的入口字符串解析。
|
||
|
||
``:`` 左边是相对插件根的模块路径(用 / 或 . 分隔均可),右边是类名。
|
||
会预先把插件根 + 入口模块所在子目录注册成虚拟 package,让相对导入
|
||
(``from .db import ...``)能正常工作。
|
||
"""
|
||
if ":" not in entry:
|
||
raise ValueError(f"invalid entry {entry!r}: missing ':<ClassName>'")
|
||
mod_path, class_name = entry.split(":", 1)
|
||
rel = mod_path.replace(".", "/").lstrip("/")
|
||
file_path = plugin_dir / f"{rel}.py"
|
||
if not file_path.exists():
|
||
raise FileNotFoundError(f"plugin {plugin_name} entry file not found: {file_path}")
|
||
|
||
# 注册虚拟 root package(如 ``_kilostar_plugin_data_analytics``)+ 入口所在子包
|
||
# (如 ``_kilostar_plugin_data_analytics.core``),这样 ``from .db import Base``
|
||
# 才能在 spec_from_file_location 加载的模块里正常解析。
|
||
import types as _types
|
||
|
||
root_pkg = f"_kilostar_plugin_{plugin_name}"
|
||
if root_pkg not in sys.modules:
|
||
root_mod = _types.ModuleType(root_pkg)
|
||
root_mod.__path__ = [str(plugin_dir)]
|
||
sys.modules[root_pkg] = root_mod
|
||
|
||
parts = mod_path.replace("/", ".").split(".")
|
||
cur_pkg = root_pkg
|
||
cur_dir = plugin_dir
|
||
for p in parts[:-1]:
|
||
cur_pkg = f"{cur_pkg}.{p}"
|
||
cur_dir = cur_dir / p
|
||
if cur_pkg not in sys.modules:
|
||
sub_mod = _types.ModuleType(cur_pkg)
|
||
sub_mod.__path__ = [str(cur_dir)]
|
||
sys.modules[cur_pkg] = sub_mod
|
||
|
||
module_name = f"{root_pkg}.{mod_path.replace('/', '.')}"
|
||
spec = importlib.util.spec_from_file_location(module_name, str(file_path))
|
||
if spec is None or spec.loader is None:
|
||
raise RuntimeError(f"cannot load module {module_name}")
|
||
mod = importlib.util.module_from_spec(spec)
|
||
mod.__package__ = ".".join(module_name.split(".")[:-1])
|
||
sys.modules[module_name] = mod
|
||
spec.loader.exec_module(mod)
|
||
|
||
cls = getattr(mod, class_name, None)
|
||
if cls is None:
|
||
raise AttributeError(f"plugin {plugin_name}: {class_name} not found in {file_path}")
|
||
return cls
|
||
|
||
|
||
async def install_dependencies(deps_python: List[str]) -> None:
    """Install the python dependencies declared by an organization using ``uv``.

    First version: packages are installed straight into the main venv (simple
    and blunt); this step is to be replaced once viceroy takes over.

    Args:
        deps_python: Requirement strings passed verbatim to ``uv pip install``.
            An empty list is a no-op.

    Raises:
        RuntimeError: ``uv pip install`` exited with a non-zero status; the
            message carries its stderr.
        OSError: (e.g. ``FileNotFoundError``) if the ``uv`` executable cannot
            be started.
    """
    if not deps_python:
        return
    import asyncio as _asyncio

    cmd = ["uv", "pip", "install", *deps_python]
    proc = await _asyncio.create_subprocess_exec(
        *cmd,
        stdout=_asyncio.subprocess.PIPE,
        stderr=_asyncio.subprocess.PIPE,
    )
    _stdout, stderr = await proc.communicate()
    if proc.returncode != 0:
        # errors="replace": uv output is not guaranteed to be UTF-8 (e.g. on
        # localized consoles); a decode failure must not mask the real error.
        raise RuntimeError(
            f"uv pip install failed (rc={proc.returncode}): "
            f"{stderr.decode('utf-8', errors='replace')}"
        )
    logger.info(f"installed deps: {deps_python}")
|
||
|
||
|
||
def discover_plugin_api(plugin_dir: Path, plugin_name: str) -> Any:
    """Load ``<plugin_dir>/api.py`` and return its ``router`` attribute (or ``None``).

    Convention: a plugin that exposes HTTP routes ships an ``api.py`` at its
    root that instantiates ``router = APIRouter(...)`` and attaches its
    endpoints; at startup the host includes it into the FastAPI app under
    ``manifest.api_prefix``.

    The module is registered in ``sys.modules`` as
    ``data.plugin.<plugin_name>.api``.

    Returns:
        The module's ``router`` attribute; ``None`` when ``api.py`` is absent,
        cannot be loaded, raises during import (logged as a warning), or
        defines no ``router``.
    """
    api_path = plugin_dir / "api.py"
    if not api_path.exists():
        return None
    module_name = f"data.plugin.{plugin_name}.api"
    spec = importlib.util.spec_from_file_location(module_name, str(api_path))
    if spec is None or spec.loader is None:
        logger.warning(f"plugin {plugin_name}: cannot load api.py at {api_path}")
        return None
    mod = importlib.util.module_from_spec(spec)
    sys.modules[module_name] = mod
    try:
        spec.loader.exec_module(mod)
    except Exception as e:
        # Best-effort: a broken api.py only disables this plugin's routes.
        # Drop the half-initialized module so nothing can import it later.
        sys.modules.pop(module_name, None)
        logger.warning(f"plugin {plugin_name}: api.py import failed: {e}")
        return None
    return getattr(mod, "router", None)
|
||
|
||
|
||
def collect_plugin_routers(plugin_root: Path) -> List[tuple]:
    """Return ``[(api_prefix, router), ...]`` for every plugin exposing an API.

    Used to mount all plugin routers when FastAPI starts. It only reads files
    and depends on no actor, so it is not coupled to startup ordering.

    Skipped plugins: a manifest that cannot be re-read/validated is skipped
    with a warning; plugins with an empty ``api_prefix`` or without a router
    (no ``api.py``, failed load, no ``router``) are skipped silently.
    """
    routers: List[tuple] = []
    for candidate in discover_plugins(plugin_root):
        manifest_file = candidate / "manifest.json"
        try:
            with open(manifest_file, "r", encoding="utf-8") as fh:
                payload = json.load(fh)
            manifest = OrgManifest.model_validate(payload)
        except Exception as exc:
            logger.warning(f"skip plugin {candidate.name} (manifest invalid): {exc}")
            continue
        prefix = manifest.api_prefix
        if not prefix:
            continue
        router = discover_plugin_api(candidate, manifest.name)
        if router is not None:
            routers.append((prefix, router))
            logger.info(f"discovered plugin router: {manifest.name} @ {prefix}")
    return routers
|