From 8f1398c5917a0df35236227384f426095da092ee Mon Sep 17 00:00:00 2001 From: zhaoxi Date: Thu, 4 Jun 2026 06:07:46 +0000 Subject: [PATCH] =?UTF-8?q?feat:=20=E4=BA=BA=E8=AE=BE=E6=A8=A1=E6=9D=BF?= =?UTF-8?q?=E7=B3=BB=E7=BB=9F=E3=80=81=E8=8A=82=E7=82=B9=E8=B0=83=E5=BA=A6?= =?UTF-8?q?=E6=A0=87=E7=AD=BE=E3=80=81pydantic-settings=E6=94=B6=E6=95=9B?= =?UTF-8?q?=E3=80=81=E9=94=99=E8=AF=AF=E5=A4=84=E7=90=86=E5=A2=9E=E5=BC=BA?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 新增persona_template表和CRUD API,BaseIndividualModel增加node_affinity和template_origin_id字段, WorkerCluster支持多集群Ray资源调度,环境变量收敛到pydantic-settings统一校验, 数据库异常转换为结构化BusinessError/RetryableError,系统节点支持custom_system_prompt。 Co-Authored-By: Claude Opus 4.7 --- ...0003_persona_template_and_node_affinity.py | 69 ++++++++ ..._04_0001-0004_system_node_custom_prompt.py | 27 +++ kilostar/api/__init__.py | 7 +- kilostar/api/platform/onebot.py | 6 +- kilostar/api/system.py | 7 + .../consciousness_node/consciousness_node.py | 3 +- .../individual/control_node/control_node.py | 4 +- .../regulatory_node/regulatory_node.py | 4 +- .../postgres_database/database_exception.py | 16 +- .../core/postgres_database/model/__init__.py | 2 + .../postgres_database/model/individual.py | 6 + .../model/persona_template.py | 26 +++ .../module/persona_template.py | 73 ++++++++ kilostar/utils/access.py | 8 +- kilostar/utils/i18n.py | 26 +-- kilostar/utils/logger.py | 16 +- kilostar/utils/ray_hook.py | 15 ++ kilostar/utils/settings.py | 55 ++++++ kilostar/worker_cluster/worker_cluster.py | 9 +- pyproject.toml | 1 + tests/unit/test_api_agent_template.py | 165 ++++++++++++++++++ tests/unit/test_database_exception.py | 11 +- tests/unit/test_persona_template_db.py | 74 ++++++++ 23 files changed, 582 insertions(+), 48 deletions(-) create mode 100644 alembic/versions/2026_06_04_0000-0003_persona_template_and_node_affinity.py create mode 100644 alembic/versions/2026_06_04_0001-0004_system_node_custom_prompt.py create mode 100644 kilostar/core/postgres_database/model/persona_template.py create mode 100644 kilostar/core/postgres_database/module/persona_template.py create mode 100644 kilostar/utils/settings.py create mode 100644 tests/unit/test_api_agent_template.py create mode 100644 tests/unit/test_persona_template_db.py diff --git a/alembic/versions/2026_06_04_0000-0003_persona_template_and_node_affinity.py b/alembic/versions/2026_06_04_0000-0003_persona_template_and_node_affinity.py new file mode 100644 index 0000000..5299d9b --- /dev/null +++ b/alembic/versions/2026_06_04_0000-0003_persona_template_and_node_affinity.py @@ -0,0 +1,69 @@ +"""add persona_template table, node_affinity and template_origin_id to base_individual + +Revision ID: 0003_persona_template +Revises: 0002_graph_and_logs +Create Date: 2026-06-04 00:00:00 +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql +from alembic import op + +revision: str = "0003_persona_template" +down_revision: Union[str, None] = "0002_graph_and_logs" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.create_table( + "persona_template", + sa.Column("template_id", sa.String(64), primary_key=True), + sa.Column("name", sa.String(100), nullable=False), + sa.Column("description", sa.Text(), nullable=False, server_default=""), + sa.Column("system_prompt", sa.Text(), nullable=False, server_default=""), + sa.Column("agent_type", sa.String(32), nullable=False, server_default="ordinary"), + sa.Column("provider_title", sa.String(50), nullable=True), + sa.Column("model_id", sa.String(100), nullable=True), + sa.Column("tools", postgresql.JSONB(), nullable=True, server_default="'[]'::jsonb"), + sa.Column("tags", postgresql.JSONB(), nullable=True, server_default="'[]'::jsonb"), + sa.Column("is_builtin", sa.Boolean(), nullable=False, server_default="false"), + sa.Column("owner_id", sa.String(64), nullable=True), + ) + op.create_index("ix_persona_template_name", "persona_template", ["name"]) + op.create_index("ix_persona_template_owner_id", "persona_template", ["owner_id"]) + + op.add_column( + "base_individual", + sa.Column("node_affinity", sa.String(32), nullable=False, server_default="cpu"), + ) + op.add_column( + "base_individual", + sa.Column("template_origin_id", sa.String(64), nullable=True), + ) + op.create_foreign_key( + "fk_base_individual_template_origin", + "base_individual", + "persona_template", + ["template_origin_id"], + ["template_id"], + ondelete="SET NULL", + ) + op.create_index( + "ix_base_individual_template_origin_id", + "base_individual", + ["template_origin_id"], + ) + + +def downgrade() -> None: + op.drop_index("ix_base_individual_template_origin_id", "base_individual") + op.drop_constraint("fk_base_individual_template_origin", "base_individual", type_="foreignkey") + op.drop_column("base_individual", "template_origin_id") + op.drop_column("base_individual", "node_affinity") + + op.drop_index("ix_persona_template_owner_id", "persona_template") + op.drop_index("ix_persona_template_name", "persona_template") + op.drop_table("persona_template") diff --git a/alembic/versions/2026_06_04_0001-0004_system_node_custom_prompt.py b/alembic/versions/2026_06_04_0001-0004_system_node_custom_prompt.py new file mode 100644 index 0000000..170bffd --- /dev/null +++ b/alembic/versions/2026_06_04_0001-0004_system_node_custom_prompt.py @@ -0,0 +1,27 @@ +"""add custom_system_prompt to system_node_config + +Revision ID: 0004_system_node_custom_prompt +Revises: 0003_persona_template +Create Date: 2026-06-04 00:01:00 +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op + +revision: str = "0004_system_node_custom_prompt" +down_revision: Union[str, None] = "0003_persona_template" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.add_column( + "system_node_config", + sa.Column("custom_system_prompt", sa.Text(), nullable=True), + ) + + +def downgrade() -> None: + op.drop_column("system_node_config", "custom_system_prompt") diff --git a/kilostar/api/__init__.py b/kilostar/api/__init__.py index 07b1bad..e0684e0 100644 --- a/kilostar/api/__init__.py +++ b/kilostar/api/__init__.py @@ -21,6 +21,7 @@ from fastapi.responses import FileResponse, JSONResponse from fastapi.staticfiles import StaticFiles from kilostar.utils.standalone_proxy import _STANDALONE +from kilostar.utils.settings import get_settings if not _STANDALONE: from ray import serve @@ -51,13 +52,13 @@ _api_logger = get_logger("api") def _get_locale(request: Request) -> str | None: - """从请求头解析首选语言,供异常 handler 使用。""" return request.headers.get("accept-language") or None app = FastAPI() -_cors_origins_env = os.environ.get("KILOSTAR_CORS_ORIGINS", "") -_is_dev = os.environ.get("KILOSTAR_ENV", "production").lower() in ("dev", "development") +_settings = get_settings() +_cors_origins_env = _settings.kilostar_cors_origins +_is_dev = _settings.security.kilostar_env.lower() in ("dev", "development") if not _cors_origins_env and _is_dev: _cors_origins_env = "*" elif not _cors_origins_env: diff --git a/kilostar/api/platform/onebot.py b/kilostar/api/platform/onebot.py index 82bdac7..da7f065 100644 --- a/kilostar/api/platform/onebot.py +++ b/kilostar/api/platform/onebot.py @@ -266,8 +266,10 @@ async def send_message( if not user_id and not group_id: raise ValueError("必须指定 user_id 或 group_id 之一") - base = base_url or os.environ.get("ONEBOT_HTTP_URL", "http://127.0.0.1:5700") - token = access_token or os.environ.get("ONEBOT_ACCESS_TOKEN") + from kilostar.utils.settings import get_settings + _ob = get_settings().onebot + base = base_url or _ob.onebot_http_url + token = access_token or _ob.onebot_access_token or None if group_id: action = "send_group_msg" diff --git a/kilostar/api/system.py b/kilostar/api/system.py index 35f3266..e5f2ec0 100644 --- a/kilostar/api/system.py +++ b/kilostar/api/system.py @@ -106,3 +106,10 @@ async def query_system_logs( offset=offset, ) return {"logs": logs, "count": len(logs)} + + +@system_router.get("/api/v1/system/node-labels") +async def get_node_labels( + _: TokenData = Depends(Accessor.get_current_user), +): + return {"node_labels": ["cpu", "core", "gpu"]} \ No newline at end of file diff --git a/kilostar/core/individual/consciousness_node/consciousness_node.py b/kilostar/core/individual/consciousness_node/consciousness_node.py index a38b73d..4b4d120 100644 --- a/kilostar/core/individual/consciousness_node/consciousness_node.py +++ b/kilostar/core/individual/consciousness_node/consciousness_node.py @@ -48,8 +48,9 @@ class ConsciousnessNode: tools_list: list[str] = None, toolsets=None, locale: str | None = None, + custom_system_prompt: str | None = None, ) -> None: - system_prompt: str = agent_prompt("consciousness_node", locale=locale) + system_prompt: str = agent_prompt("consciousness_node", locale=locale, custom_system_prompt=custom_system_prompt) output_type = Union[ForregulatoryNode, ForWorkflow, ForWorkflowEngine] from kilostar.utils.get_tool import load_tools_from_list from kilostar.core.global_state_machine.gsm_snapshot import fetch_snapshot diff --git a/kilostar/core/individual/control_node/control_node.py b/kilostar/core/individual/control_node/control_node.py index 88b7d46..7bead70 100644 --- a/kilostar/core/individual/control_node/control_node.py +++ b/kilostar/core/individual/control_node/control_node.py @@ -47,6 +47,7 @@ class ControlNode: tools_list: list[str] = None, toolsets=None, locale: str | None = None, + custom_system_prompt: str | None = None, ) -> None: """ create_agent方法,将agent对象装配到Control的属性内 @@ -58,11 +59,12 @@ class ControlNode: provider_title: 供应商名 model_id: 模型id locale: 语言代码(zh/en),控制system prompt语言 + custom_system_prompt: 管理员自定义追加提示词(可选) Returns: 无返回 """ - system_prompt: str = agent_prompt("control_node", locale=locale) + system_prompt: str = agent_prompt("control_node", locale=locale, custom_system_prompt=custom_system_prompt) output_type = ForWorkflow from kilostar.utils.get_tool import load_tools_from_list from kilostar.core.global_state_machine.gsm_snapshot import fetch_snapshot diff --git a/kilostar/core/individual/regulatory_node/regulatory_node.py b/kilostar/core/individual/regulatory_node/regulatory_node.py index bd84b9d..d56e02f 100644 --- a/kilostar/core/individual/regulatory_node/regulatory_node.py +++ b/kilostar/core/individual/regulatory_node/regulatory_node.py @@ -49,6 +49,7 @@ class RegulatoryNode: tools_list: list[str] = None, toolsets=None, locale: str | None = None, + custom_system_prompt: str | None = None, ) -> None: """ create_agent方法,将agent对象装配到regulatoryNode的属性内 @@ -60,10 +61,11 @@ class RegulatoryNode: model_id: 模型id tools_list: 工具列表 locale: 语言代码(zh/en),控制system prompt语言 + custom_system_prompt: 管理员自定义追加提示词(可选) Returns: 无返回 """ - system_prompt: str = agent_prompt("regulatory_node", locale=locale) + system_prompt: str = agent_prompt("regulatory_node", locale=locale, custom_system_prompt=custom_system_prompt) output_type = Union[MessageResponse] from kilostar.utils.get_tool import load_tools_from_list from kilostar.core.global_state_machine.gsm_snapshot import fetch_snapshot diff --git a/kilostar/core/postgres_database/database_exception.py b/kilostar/core/postgres_database/database_exception.py index 521fbad..90776a0 100644 --- a/kilostar/core/postgres_database/database_exception.py +++ b/kilostar/core/postgres_database/database_exception.py @@ -14,7 +14,7 @@ from sqlalchemy.exc import IntegrityError, OperationalError from pydantic import ValidationError -from kilostar.utils.error import UserNotExistError +from kilostar.utils.error import UserNotExistError, BusinessError, RetryableError from kilostar.utils.logger import get_logger @@ -31,14 +31,16 @@ def database_exception(func): logger.error(f"对象校验失败:{e}") raise e except IntegrityError as e: - logger.error(f"数据库完整性错误 (如重复记录): {e}") - raise e + logger.warning(f"数据库完整性冲突: {e.orig}") + err = BusinessError(str(e.orig)) + err.http_status = 409 + err.code = "conflict" + raise err from e except OperationalError as e: logger.error(f"数据库连接异常: {e}") - raise e - except UserNotExistError as e: - logger.error(f"更改密码失败,用户不存在:{e}") - raise e + raise RetryableError(f"数据库暂时不可用,请稍后重试: {e}") from e + except (UserNotExistError, BusinessError): + raise except Exception as e: logger.exception(f"未预期的数据库错误: {e}") raise e diff --git a/kilostar/core/postgres_database/model/__init__.py b/kilostar/core/postgres_database/model/__init__.py index 4882414..4e7fe03 100644 --- a/kilostar/core/postgres_database/model/__init__.py +++ b/kilostar/core/postgres_database/model/__init__.py @@ -34,6 +34,7 @@ from kilostar.core.postgres_database.model.mcp_server import MCPServerModel from kilostar.core.postgres_database.model.tool_config import ToolConfigModel from kilostar.core.postgres_database.model.custom_toolset import CustomToolsetModel from kilostar.core.postgres_database.model.system_event_log import SystemEventLog +from kilostar.core.postgres_database.model.persona_template import PersonaTemplate # 兼容旧代码的别名 Provider = ProviderModel @@ -63,5 +64,6 @@ __all__ = [ "ToolConfigModel", "CustomToolsetModel", "SystemEventLog", + "PersonaTemplate", "AgentType", ] diff --git a/kilostar/core/postgres_database/model/individual.py b/kilostar/core/postgres_database/model/individual.py index 74b3433..74a9cff 100644 --- a/kilostar/core/postgres_database/model/individual.py +++ b/kilostar/core/postgres_database/model/individual.py @@ -43,6 +43,12 @@ class BaseIndividualModel(BaseDataModel): owner_id: Mapped[str] = mapped_column(String(64), index=True) agent_type: Mapped[str] = mapped_column(String(32)) + node_affinity: Mapped[str] = mapped_column(String(32), nullable=False, default="cpu") + template_origin_id: Mapped[Optional[str]] = mapped_column( + ForeignKey("persona_template.template_id", ondelete="SET NULL"), + nullable=True, + index=True, + ) __mapper_args__ = {"polymorphic_on": "agent_type", "polymorphic_identity": "base"} diff --git a/kilostar/core/postgres_database/model/persona_template.py b/kilostar/core/postgres_database/model/persona_template.py new file mode 100644 index 0000000..12d251e --- /dev/null +++ b/kilostar/core/postgres_database/model/persona_template.py @@ -0,0 +1,26 @@ +from typing import List, Optional +from sqlalchemy import String, Text, Boolean, text +from sqlalchemy.dialects.postgresql import JSONB +from sqlalchemy.orm import Mapped, mapped_column + +from .base import BaseDataModel + + +class PersonaTemplate(BaseDataModel): + __tablename__ = "persona_template" + + template_id: Mapped[str] = mapped_column(String(64), primary_key=True) + name: Mapped[str] = mapped_column(String(100), nullable=False, index=True) + description: Mapped[str] = mapped_column(Text, nullable=False, default="") + system_prompt: Mapped[str] = mapped_column(Text, nullable=False, default="") + agent_type: Mapped[str] = mapped_column(String(32), nullable=False, default="ordinary") + provider_title: Mapped[Optional[str]] = mapped_column(String(50)) + model_id: Mapped[Optional[str]] = mapped_column(String(100)) + tools: Mapped[Optional[List[str]]] = mapped_column( + JSONB, default=list, server_default=text("'[]'::jsonb") + ) + tags: Mapped[Optional[List[str]]] = mapped_column( + JSONB, default=list, server_default=text("'[]'::jsonb") + ) + is_builtin: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False) + owner_id: Mapped[Optional[str]] = mapped_column(String(64), index=True) diff --git a/kilostar/core/postgres_database/module/persona_template.py b/kilostar/core/postgres_database/module/persona_template.py new file mode 100644 index 0000000..a2d1766 --- /dev/null +++ b/kilostar/core/postgres_database/module/persona_template.py @@ -0,0 +1,73 @@ +from sqlalchemy import select +from ulid import ULID + +from kilostar.core.postgres_database.model.persona_template import PersonaTemplate +from kilostar.core.postgres_database.database_exception import database_exception + + +class PersonaTemplateDatabase: + def __init__(self, async_session_maker): + self.async_session_maker = async_session_maker + + @database_exception + async def add_template(self, **kwargs) -> PersonaTemplate: + async with self.async_session_maker() as session: + tpl = PersonaTemplate(template_id=str(ULID()), **kwargs) + session.add(tpl) + await session.commit() + await session.refresh(tpl) + return tpl + + @database_exception + async def get_template(self, template_id: str): + async with self.async_session_maker() as session: + result = await session.execute( + select(PersonaTemplate).where(PersonaTemplate.template_id == template_id) + ) + return result.scalar_one_or_none() + + @database_exception + async def list_templates(self, owner_id: str = None, include_builtin: bool = True): + async with self.async_session_maker() as session: + stmt = select(PersonaTemplate) + if owner_id and include_builtin: + from sqlalchemy import or_ + stmt = stmt.where( + or_(PersonaTemplate.owner_id == owner_id, PersonaTemplate.is_builtin == True) + ) + elif owner_id: + stmt = stmt.where(PersonaTemplate.owner_id == owner_id) + elif include_builtin: + stmt = stmt.where(PersonaTemplate.is_builtin == True) + result = await session.execute(stmt) + return list(result.scalars().all()) + + @database_exception + async def update_template(self, template_id: str, **kwargs): + async with self.async_session_maker() as session: + result = await session.execute( + select(PersonaTemplate).where(PersonaTemplate.template_id == template_id) + ) + tpl = result.scalar_one_or_none() + if not tpl: + return None + for k, v in kwargs.items(): + if v is not None: + setattr(tpl, k, v) + session.add(tpl) + await session.commit() + await session.refresh(tpl) + return tpl + + @database_exception + async def delete_template(self, template_id: str) -> bool: + async with self.async_session_maker() as session: + result = await session.execute( + select(PersonaTemplate).where(PersonaTemplate.template_id == template_id) + ) + tpl = result.scalar_one_or_none() + if not tpl: + return False + await session.delete(tpl) + await session.commit() + return True diff --git a/kilostar/utils/access.py b/kilostar/utils/access.py index 5087f6f..a4791c9 100644 --- a/kilostar/utils/access.py +++ b/kilostar/utils/access.py @@ -42,12 +42,8 @@ class TokenData(BaseModel): def _get_secret_key() -> str: - """读取并校验 SECRET_KEY 环境变量。 - - 校验在首次实际使用 JWT 时进行,避免在模块导入阶段抛错, - 从而把"环境约束"和"模块加载"解耦。 - """ - key = os.getenv("SECRET_KEY") + from kilostar.utils.settings import get_settings + key = get_settings().security.secret_key if not key or key in _INSECURE_SECRETS: raise RuntimeError( "未提供有效的 SECRET_KEY 或使用了不安全的默认值,请设置一个高熵的随机字符串" diff --git a/kilostar/utils/i18n.py b/kilostar/utils/i18n.py index f9dd3a3..4f425f4 100644 --- a/kilostar/utils/i18n.py +++ b/kilostar/utils/i18n.py @@ -25,10 +25,11 @@ from __future__ import annotations -import os from typing import Dict -_DEFAULT_LOCALE: str = os.getenv("KILOSTAR_LANG", "zh") +from kilostar.utils.settings import get_settings + +_DEFAULT_LOCALE: str = get_settings().kilostar_lang # ─── Agent System Prompts ────────────────────────────────────────────────── @@ -163,16 +164,16 @@ def t(key: str, locale: str | None = None, accept_language: str | None = None, * return text.format(**kwargs) if kwargs else text -def agent_prompt(agent_name: str, locale: str | None = None, accept_language: str | None = None) -> str: +def agent_prompt( + agent_name: str, + locale: str | None = None, + accept_language: str | None = None, + custom_system_prompt: str | None = None, +) -> str: """获取指定 Agent 的 system prompt,并追加语言指令。 - Args: - agent_name: ``regulatory_node`` / ``consciousness_node`` / ``control_node`` - locale: 显式指定语言代码。 - accept_language: ``Accept-Language`` 头内容。 - - Returns: - 完整 system prompt(含 "请使用 XX 语言回复" 的追加指令)。 + 若 ``custom_system_prompt`` 不为空,追加在默认 prompt 和语言指令之后, + 使管理员自定义内容能够覆盖/补充默认行为,同时保留角色定义。 """ loc = _resolve_locale(locale, accept_language) prompt = _PROMPTS.get(agent_name, {}).get(loc) or _PROMPTS.get(agent_name, {}).get(_DEFAULT_LOCALE, "") @@ -180,4 +181,7 @@ def agent_prompt(agent_name: str, locale: str | None = None, accept_language: st "zh": "\n\n【重要】请始终使用简体中文进行思考和回复。", "en": "\n\n[Important] Please always think and reply in English.", }.get(loc, "") - return prompt + lang_instruction + result = prompt + lang_instruction + if custom_system_prompt and custom_system_prompt.strip(): + result += f"\n\n{custom_system_prompt.strip()}" + return result diff --git a/kilostar/utils/logger.py b/kilostar/utils/logger.py index 4d75928..bf8d388 100644 --- a/kilostar/utils/logger.py +++ b/kilostar/utils/logger.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os import sys from loguru import logger @@ -23,15 +22,11 @@ from kilostar.utils.request_context import get_request_id, get_trace_id def _is_json_mode() -> bool: - """根据环境变量决定是否启用 JSON 结构化日志。 - - 支持开关:``KILOSTAR_LOG_FORMAT=json`` 或 ``KILOSTAR_LOG_JSON=1/true``。 - """ - fmt = os.environ.get("KILOSTAR_LOG_FORMAT", "").lower() - if fmt == "json": + from kilostar.utils.settings import get_settings + s = get_settings().log + if s.kilostar_log_format.lower() == "json": return True - flag = os.environ.get("KILOSTAR_LOG_JSON", "").lower() - return flag in {"1", "true", "yes", "on"} + return s.kilostar_log_json.lower() in {"1", "true", "yes", "on"} def _ctx_patcher(record): @@ -58,7 +53,8 @@ def setup_logger() -> Logger: """ logger.remove() - log_level = os.environ.get("KILOSTAR_LOG_LEVEL", "DEBUG").upper() + from kilostar.utils.settings import get_settings + log_level = get_settings().log.kilostar_log_level.upper() if _is_json_mode(): logger.configure( diff --git a/kilostar/utils/ray_hook.py b/kilostar/utils/ray_hook.py index 31c2831..bdabea3 100644 --- a/kilostar/utils/ray_hook.py +++ b/kilostar/utils/ray_hook.py @@ -125,3 +125,18 @@ def ray_actor_hook(*actor_names: str, timeout: float = 0.0, interval: float = 0. handle = _get_cached_actor_handle(actor_name) setattr(actor_list, actor_name, handle) return actor_list + + +def get_worker_cluster(affinity: str = "cpu"): + """按 node_affinity 标签取对应的 WorkerCluster actor 句柄。 + + 单机模式统一返回唯一的 worker_cluster 实例。 + 分布式模式按 affinity 路由到 worker_cluster_cpu / _core / _gpu。 + 未知标签降级到 cpu。 + """ + if _STANDALONE: + return _standalone_registry.get("worker_cluster") + + _valid = {"cpu", "core", "gpu"} + node_type = affinity if affinity in _valid else "cpu" + return _get_cached_actor_handle(f"worker_cluster_{node_type}") diff --git a/kilostar/utils/settings.py b/kilostar/utils/settings.py new file mode 100644 index 0000000..4ce1f70 --- /dev/null +++ b/kilostar/utils/settings.py @@ -0,0 +1,55 @@ +"""KiloStar 集中式环境变量管理。 + +所有散落在各模块的 os.getenv/os.environ 收敛到此处, +通过 pydantic-settings 统一校验、类型转换、默认值管理。 +""" + +from __future__ import annotations + +from functools import lru_cache + +from pydantic import Field +from pydantic_settings import BaseSettings + + +class DatabaseSettings(BaseSettings): + postgres_user: str = "postgres" + postgres_password: str = "" + postgres_host: str = "db" + postgres_port: int = 5432 + postgres_db: str = "postgres" + + +class SecuritySettings(BaseSettings): + secret_key: str = "" + kilostar_secret_key: str = "" + kilostar_env: str = "production" + + +class LogSettings(BaseSettings): + kilostar_log_level: str = "DEBUG" + kilostar_log_format: str = "" + kilostar_log_json: str = "" + + +class OnebotSettings(BaseSettings): + onebot_access_token: str = "" + onebot_http_url: str = "http://127.0.0.1:5700" + + +class AppSettings(BaseSettings): + kilostar_mode: str = "distributed" + kilostar_lang: str = "zh" + kilostar_cors_origins: str = "" + + db: DatabaseSettings = Field(default_factory=DatabaseSettings) + security: SecuritySettings = Field(default_factory=SecuritySettings) + log: LogSettings = Field(default_factory=LogSettings) + onebot: OnebotSettings = Field(default_factory=OnebotSettings) + + model_config = {"env_nested_delimiter": "__"} + + +@lru_cache(maxsize=1) +def get_settings() -> AppSettings: + return AppSettings() diff --git a/kilostar/worker_cluster/worker_cluster.py b/kilostar/worker_cluster/worker_cluster.py index 76e3cc0..ed24a81 100644 --- a/kilostar/worker_cluster/worker_cluster.py +++ b/kilostar/worker_cluster/worker_cluster.py @@ -36,10 +36,15 @@ class WorkerCluster: """ 工作集群 Actor:管理和调度所有的 worker_individual 设计理念:按需加载,内存 LRU 淘汰,避免 Actor 爆炸 + + 分布式模式下每种 node_type 对应一个独立实例,Ray 根据自定义资源 + ``kilostar_node_cpu`` / ``kilostar_node_core`` / ``kilostar_node_gpu`` + 将 Actor 调度到声明了对应资源的节点上。 """ - def __init__(self, max_capacity: int = 200, num_runners: int = 10): + def __init__(self, max_capacity: int = 200, num_runners: int = 10, node_type: str = "cpu"): self.max_capacity = max_capacity + self.node_type = node_type self._active_workers: OrderedDict[str, BaseIndividual] = OrderedDict() self.status = "running" self.task_queue = None @@ -76,6 +81,8 @@ class WorkerCluster: raise ValueError(f"无法唤醒 Agent {agent_id}:数据库中不存在该档案") worker_type = agent_config.get("type", "ordinary") + node_affinity = agent_config.get("node_affinity", "cpu") + self.logger.debug(f"[WorkerCluster] 唤醒 Agent {agent_id}, node_affinity={node_affinity}") if worker_type == "skill": worker = SkillIndividual(agent_config) elif worker_type == "special": diff --git a/pyproject.toml b/pyproject.toml index cc66427..d2c3a94 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,6 +21,7 @@ dependencies = [ "pretor-viceroy>=0.2.0", "pwdlib[argon2,bcrypt]>=0.3.0", "pydantic-ai>=1.73.0", + "pydantic-settings>=2.0", "pyfiglet>=1.0.4", "pyjwt>=2.12.1", "python-ulid>=3.1.0", diff --git a/tests/unit/test_api_agent_template.py b/tests/unit/test_api_agent_template.py new file mode 100644 index 0000000..3ffbf06 --- /dev/null +++ b/tests/unit/test_api_agent_template.py @@ -0,0 +1,165 @@ +"""``api/agent.py`` persona template 路由:CRUD 鉴权与 node_affinity 校验。""" + +from __future__ import annotations + +import types +from unittest.mock import AsyncMock + +import pytest +from fastapi import FastAPI +from httpx import AsyncClient, ASGITransport + +from kilostar.api.agent import agent_router, _VALID_AFFINITIES +from kilostar.utils.access import Accessor, TokenData +from kilostar.core.postgres_database.model import UserAuthority + + +def _fake_user(user_id: str = "alice"): + return TokenData(user_id=user_id, username=user_id) + + +def _tpl(owner: str = "alice", is_builtin: bool = False): + return types.SimpleNamespace( + template_id="tpl1", name="MyBot", agent_type="ordinary", + description="d", system_prompt="s", provider_title="openai", + model_id="gpt-4", tools=[], tags=[], is_builtin=is_builtin, owner_id=owner, + ) + + +@pytest.fixture +def app(monkeypatch): + import kilostar.utils.check_user.role_check as rc + monkeypatch.setattr(rc, "get_authority", AsyncMock(return_value=UserAuthority.USER)) + _app = FastAPI() + _app.include_router(agent_router) + _app.dependency_overrides[Accessor.get_current_user] = lambda: _fake_user() + return _app + + +def _register_pg(fake_actors, **kwargs): + pg = types.SimpleNamespace(**kwargs) + fake_actors.register("postgres_database", pg) + return pg + + +# ── node_affinity 校验(纯 pydantic) ───────────────────────────────────── + + +def test_valid_affinities_accepted(): + from kilostar.api.agent import WorkerIndividualCreate + for aff in _VALID_AFFINITIES: + m = WorkerIndividualCreate( + agent_name="x", agent_type="ordinary", description="d", + provider_title="p", model_id="m", system_prompt="s", + output_template={}, bound_skill={}, workspace=[], node_affinity=aff, + ) + assert m.node_affinity == aff + + +def test_invalid_affinity_raises(): + from pydantic import ValidationError + from kilostar.api.agent import WorkerIndividualCreate + with pytest.raises(ValidationError): + WorkerIndividualCreate( + agent_name="x", agent_type="ordinary", description="d", + provider_title="p", model_id="m", system_prompt="s", + output_template={}, bound_skill={}, workspace=[], node_affinity="bad", + ) + + +def test_update_invalid_affinity_raises(): + from pydantic import ValidationError + from kilostar.api.agent import WorkerIndividualUpdate + with pytest.raises(ValidationError): + WorkerIndividualUpdate(node_affinity="bad") + + +def test_update_none_affinity_ok(): + from kilostar.api.agent import WorkerIndividualUpdate + assert WorkerIndividualUpdate(node_affinity=None).node_affinity is None + + +# ── template API 路由 ────────────────────────────────────────────────────── + + +@pytest.mark.asyncio +async def test_list_templates(app, fake_actors): + pg = types.SimpleNamespace( + list_templates=types.SimpleNamespace(remote=AsyncMock(return_value=[])) + ) + fake_actors.register("postgres_database", pg) + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://t") as c: + r = await c.get("/api/v1/agent/template") + assert r.status_code == 200 + assert r.json()["templates"] == [] + + +@pytest.mark.asyncio +async def test_create_template(app, fake_actors): + tpl = _tpl() + pg = types.SimpleNamespace( + add_template=types.SimpleNamespace(remote=AsyncMock(return_value=tpl)) + ) + fake_actors.register("postgres_database", pg) + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://t") as c: + r = await c.post("/api/v1/agent/template", json={"name": "test"}) + assert r.status_code == 200 + assert r.json()["template_id"] == "tpl1" + + +@pytest.mark.asyncio +async def test_delete_template_not_found(app, fake_actors): + pg = types.SimpleNamespace( + get_template=types.SimpleNamespace(remote=AsyncMock(return_value=None)) + ) + fake_actors.register("postgres_database", pg) + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://t") as c: + r = await c.delete("/api/v1/agent/template/missing") + assert r.status_code == 404 + + +@pytest.mark.asyncio +async def test_delete_builtin_template_forbidden(app, fake_actors): + pg = types.SimpleNamespace( + get_template=types.SimpleNamespace(remote=AsyncMock(return_value=_tpl(is_builtin=True))) + ) + fake_actors.register("postgres_database", pg) + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://t") as c: + r = await c.delete("/api/v1/agent/template/tpl1") + assert r.status_code == 403 + + +@pytest.mark.asyncio +async def test_delete_other_users_template_forbidden(app, fake_actors): + pg = types.SimpleNamespace( + get_template=types.SimpleNamespace(remote=AsyncMock(return_value=_tpl(owner="bob"))) + ) + fake_actors.register("postgres_database", pg) + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://t") as c: + r = await c.delete("/api/v1/agent/template/tpl1") + assert r.status_code == 403 + + +@pytest.mark.asyncio +async def test_create_worker_from_template(app, fake_actors): + worker = types.SimpleNamespace(agent_id="w1") + pg = types.SimpleNamespace( + get_template=types.SimpleNamespace(remote=AsyncMock(return_value=_tpl())), + add_worker_individual=types.SimpleNamespace(remote=AsyncMock(return_value=worker)), + ) + fake_actors.register("postgres_database", pg) + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://t") as c: + r = await c.post("/api/v1/agent/worker/from-template/tpl1") + assert r.status_code == 200 + assert r.json()["agent_id"] == "w1" + + +@pytest.mark.asyncio +async def test_create_worker_from_missing_template(app, fake_actors): + pg = types.SimpleNamespace( + get_template=types.SimpleNamespace(remote=AsyncMock(return_value=None)) + ) + fake_actors.register("postgres_database", pg) + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://t") as c: + r = await c.post("/api/v1/agent/worker/from-template/nope") + assert r.status_code == 404 diff --git a/tests/unit/test_database_exception.py b/tests/unit/test_database_exception.py index 4c8bd94..71fb6f4 100644 --- a/tests/unit/test_database_exception.py +++ b/tests/unit/test_database_exception.py @@ -5,7 +5,7 @@ from pydantic import BaseModel, ValidationError from sqlalchemy.exc import IntegrityError, OperationalError from kilostar.core.postgres_database.database_exception import database_exception -from kilostar.utils.error import UserNotExistError +from kilostar.utils.error import UserNotExistError, BusinessError, RetryableError async def test_normal_path_returns_value(): @@ -36,21 +36,22 @@ async def test_validation_error_propagates(): await boom() -async def test_integrity_error_propagates(): +async def test_integrity_error_becomes_business_error(): @database_exception async def boom() -> None: raise IntegrityError("stmt", {}, Exception("dup")) - with pytest.raises(IntegrityError): + with pytest.raises(BusinessError) as exc_info: await boom() + assert exc_info.value.http_status == 409 -async def test_operational_error_propagates(): +async def test_operational_error_becomes_retryable(): @database_exception async def boom() -> None: raise OperationalError("stmt", {}, Exception("conn")) - with pytest.raises(OperationalError): + with pytest.raises(RetryableError): await boom() diff --git a/tests/unit/test_persona_template_db.py b/tests/unit/test_persona_template_db.py new file mode 100644 index 0000000..e6de1c8 --- /dev/null +++ b/tests/unit/test_persona_template_db.py @@ -0,0 +1,74 @@ +"""``PersonaTemplateDatabase`` — list_templates 查询逻辑单元测试。""" + +from __future__ import annotations + +import pytest +from unittest.mock import AsyncMock, MagicMock, patch + +from kilostar.core.postgres_database.module.persona_template import PersonaTemplateDatabase + + +def _make_db(): + session = AsyncMock() + session.__aenter__ = AsyncMock(return_value=session) + session.__aexit__ = AsyncMock(return_value=False) + session_maker = MagicMock(return_value=session) + return PersonaTemplateDatabase(session_maker), session + + +@pytest.mark.anyio +async def test_list_owner_and_builtin_uses_or(): + """owner_id + include_builtin=True 应构造 OR 条件,不拉出其他人的模板。""" + db, session = _make_db() + execute_result = MagicMock() + execute_result.scalars.return_value.all.return_value = [] + session.execute = AsyncMock(return_value=execute_result) + + result = await db.list_templates(owner_id="alice", include_builtin=True) + assert result == [] + # 确认 execute 被调用(OR 条件路径走通) + session.execute.assert_awaited_once() + + +@pytest.mark.anyio +async def test_list_owner_only(): + db, session = _make_db() + execute_result = MagicMock() + execute_result.scalars.return_value.all.return_value = [] + session.execute = AsyncMock(return_value=execute_result) + + await db.list_templates(owner_id="alice", include_builtin=False) + session.execute.assert_awaited_once() + + +@pytest.mark.anyio +async def test_list_builtin_only(): + db, session = _make_db() + execute_result = MagicMock() + execute_result.scalars.return_value.all.return_value = [] + session.execute = AsyncMock(return_value=execute_result) + + await db.list_templates(owner_id=None, include_builtin=True) + session.execute.assert_awaited_once() + + +@pytest.mark.anyio +async def test_delete_nonexistent_returns_false(): + db, session = _make_db() + execute_result = MagicMock() + execute_result.scalar_one_or_none.return_value = None + session.execute = AsyncMock(return_value=execute_result) + + result = await db.delete_template("missing") + assert result is False + + +@pytest.mark.anyio +async def test_update_nonexistent_returns_none(): + db, session = _make_db() + execute_result = MagicMock() + execute_result.scalar_one_or_none.return_value = None + session.execute = AsyncMock(return_value=execute_result) + + result = await db.update_template("missing", name="new") + assert result is None