feat: 人设模板系统、节点调度标签、pydantic-settings收敛、错误处理增强

新增persona_template表和CRUD API,BaseIndividualModel增加node_affinity和template_origin_id字段,
WorkerCluster支持多集群Ray资源调度,环境变量收敛到pydantic-settings统一校验,
数据库异常转换为结构化BusinessError/RetryableError,系统节点支持custom_system_prompt。

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
2026-06-04 06:07:46 +00:00
parent f3a92a793e
commit 8f1398c591
23 changed files with 582 additions and 48 deletions
@@ -0,0 +1,69 @@
"""add persona_template table, node_affinity and template_origin_id to base_individual
Revision ID: 0003_persona_template
Revises: 0002_graph_and_logs
Create Date: 2026-06-04 00:00:00
"""
from typing import Sequence, Union
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
from alembic import op
revision: str = "0003_persona_template"
down_revision: Union[str, None] = "0002_graph_and_logs"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.create_table(
"persona_template",
sa.Column("template_id", sa.String(64), primary_key=True),
sa.Column("name", sa.String(100), nullable=False),
sa.Column("description", sa.Text(), nullable=False, server_default=""),
sa.Column("system_prompt", sa.Text(), nullable=False, server_default=""),
sa.Column("agent_type", sa.String(32), nullable=False, server_default="ordinary"),
sa.Column("provider_title", sa.String(50), nullable=True),
sa.Column("model_id", sa.String(100), nullable=True),
sa.Column("tools", postgresql.JSONB(), nullable=True, server_default="'[]'::jsonb"),
sa.Column("tags", postgresql.JSONB(), nullable=True, server_default="'[]'::jsonb"),
sa.Column("is_builtin", sa.Boolean(), nullable=False, server_default="false"),
sa.Column("owner_id", sa.String(64), nullable=True),
)
op.create_index("ix_persona_template_name", "persona_template", ["name"])
op.create_index("ix_persona_template_owner_id", "persona_template", ["owner_id"])
op.add_column(
"base_individual",
sa.Column("node_affinity", sa.String(32), nullable=False, server_default="cpu"),
)
op.add_column(
"base_individual",
sa.Column("template_origin_id", sa.String(64), nullable=True),
)
op.create_foreign_key(
"fk_base_individual_template_origin",
"base_individual",
"persona_template",
["template_origin_id"],
["template_id"],
ondelete="SET NULL",
)
op.create_index(
"ix_base_individual_template_origin_id",
"base_individual",
["template_origin_id"],
)
def downgrade() -> None:
op.drop_index("ix_base_individual_template_origin_id", "base_individual")
op.drop_constraint("fk_base_individual_template_origin", "base_individual", type_="foreignkey")
op.drop_column("base_individual", "template_origin_id")
op.drop_column("base_individual", "node_affinity")
op.drop_index("ix_persona_template_owner_id", "persona_template")
op.drop_index("ix_persona_template_name", "persona_template")
op.drop_table("persona_template")
@@ -0,0 +1,27 @@
"""add custom_system_prompt to system_node_config
Revision ID: 0004_system_node_custom_prompt
Revises: 0003_persona_template
Create Date: 2026-06-04 00:01:00
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
revision: str = "0004_system_node_custom_prompt"
down_revision: Union[str, None] = "0003_persona_template"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.add_column(
"system_node_config",
sa.Column("custom_system_prompt", sa.Text(), nullable=True),
)
def downgrade() -> None:
op.drop_column("system_node_config", "custom_system_prompt")
+4 -3
View File
@@ -21,6 +21,7 @@ from fastapi.responses import FileResponse, JSONResponse
from fastapi.staticfiles import StaticFiles from fastapi.staticfiles import StaticFiles
from kilostar.utils.standalone_proxy import _STANDALONE from kilostar.utils.standalone_proxy import _STANDALONE
from kilostar.utils.settings import get_settings
if not _STANDALONE: if not _STANDALONE:
from ray import serve from ray import serve
@@ -51,13 +52,13 @@ _api_logger = get_logger("api")
def _get_locale(request: Request) -> str | None: def _get_locale(request: Request) -> str | None:
"""从请求头解析首选语言,供异常 handler 使用。"""
return request.headers.get("accept-language") or None return request.headers.get("accept-language") or None
app = FastAPI() app = FastAPI()
_cors_origins_env = os.environ.get("KILOSTAR_CORS_ORIGINS", "") _settings = get_settings()
_is_dev = os.environ.get("KILOSTAR_ENV", "production").lower() in ("dev", "development") _cors_origins_env = _settings.kilostar_cors_origins
_is_dev = _settings.security.kilostar_env.lower() in ("dev", "development")
if not _cors_origins_env and _is_dev: if not _cors_origins_env and _is_dev:
_cors_origins_env = "*" _cors_origins_env = "*"
elif not _cors_origins_env: elif not _cors_origins_env:
+4 -2
View File
@@ -266,8 +266,10 @@ async def send_message(
if not user_id and not group_id: if not user_id and not group_id:
raise ValueError("必须指定 user_id 或 group_id 之一") raise ValueError("必须指定 user_id 或 group_id 之一")
base = base_url or os.environ.get("ONEBOT_HTTP_URL", "http://127.0.0.1:5700") from kilostar.utils.settings import get_settings
token = access_token or os.environ.get("ONEBOT_ACCESS_TOKEN") _ob = get_settings().onebot
base = base_url or _ob.onebot_http_url
token = access_token or _ob.onebot_access_token or None
if group_id: if group_id:
action = "send_group_msg" action = "send_group_msg"
+7
View File
@@ -106,3 +106,10 @@ async def query_system_logs(
offset=offset, offset=offset,
) )
return {"logs": logs, "count": len(logs)} return {"logs": logs, "count": len(logs)}
@system_router.get("/api/v1/system/node-labels")
async def get_node_labels(
_: TokenData = Depends(Accessor.get_current_user),
):
return {"node_labels": ["cpu", "core", "gpu"]}
@@ -48,8 +48,9 @@ class ConsciousnessNode:
tools_list: list[str] = None, tools_list: list[str] = None,
toolsets=None, toolsets=None,
locale: str | None = None, locale: str | None = None,
custom_system_prompt: str | None = None,
) -> None: ) -> None:
system_prompt: str = agent_prompt("consciousness_node", locale=locale) system_prompt: str = agent_prompt("consciousness_node", locale=locale, custom_system_prompt=custom_system_prompt)
output_type = Union[ForregulatoryNode, ForWorkflow, ForWorkflowEngine] output_type = Union[ForregulatoryNode, ForWorkflow, ForWorkflowEngine]
from kilostar.utils.get_tool import load_tools_from_list from kilostar.utils.get_tool import load_tools_from_list
from kilostar.core.global_state_machine.gsm_snapshot import fetch_snapshot from kilostar.core.global_state_machine.gsm_snapshot import fetch_snapshot
@@ -47,6 +47,7 @@ class ControlNode:
tools_list: list[str] = None, tools_list: list[str] = None,
toolsets=None, toolsets=None,
locale: str | None = None, locale: str | None = None,
custom_system_prompt: str | None = None,
) -> None: ) -> None:
""" """
create_agent方法,将agent对象装配到Control的属性内 create_agent方法,将agent对象装配到Control的属性内
@@ -58,11 +59,12 @@ class ControlNode:
provider_title: 供应商名 provider_title: 供应商名
model_id: 模型id model_id: 模型id
locale: 语言代码(zh/en),控制system prompt语言 locale: 语言代码(zh/en),控制system prompt语言
custom_system_prompt: 管理员自定义追加提示词(可选)
Returns: Returns:
无返回 无返回
""" """
system_prompt: str = agent_prompt("control_node", locale=locale) system_prompt: str = agent_prompt("control_node", locale=locale, custom_system_prompt=custom_system_prompt)
output_type = ForWorkflow output_type = ForWorkflow
from kilostar.utils.get_tool import load_tools_from_list from kilostar.utils.get_tool import load_tools_from_list
from kilostar.core.global_state_machine.gsm_snapshot import fetch_snapshot from kilostar.core.global_state_machine.gsm_snapshot import fetch_snapshot
@@ -49,6 +49,7 @@ class RegulatoryNode:
tools_list: list[str] = None, tools_list: list[str] = None,
toolsets=None, toolsets=None,
locale: str | None = None, locale: str | None = None,
custom_system_prompt: str | None = None,
) -> None: ) -> None:
""" """
create_agent方法,将agent对象装配到regulatoryNode的属性内 create_agent方法,将agent对象装配到regulatoryNode的属性内
@@ -60,10 +61,11 @@ class RegulatoryNode:
model_id: 模型id model_id: 模型id
tools_list: 工具列表 tools_list: 工具列表
locale: 语言代码(zh/en),控制system prompt语言 locale: 语言代码(zh/en),控制system prompt语言
custom_system_prompt: 管理员自定义追加提示词(可选)
Returns: Returns:
无返回 无返回
""" """
system_prompt: str = agent_prompt("regulatory_node", locale=locale) system_prompt: str = agent_prompt("regulatory_node", locale=locale, custom_system_prompt=custom_system_prompt)
output_type = Union[MessageResponse] output_type = Union[MessageResponse]
from kilostar.utils.get_tool import load_tools_from_list from kilostar.utils.get_tool import load_tools_from_list
from kilostar.core.global_state_machine.gsm_snapshot import fetch_snapshot from kilostar.core.global_state_machine.gsm_snapshot import fetch_snapshot
@@ -14,7 +14,7 @@
from sqlalchemy.exc import IntegrityError, OperationalError from sqlalchemy.exc import IntegrityError, OperationalError
from pydantic import ValidationError from pydantic import ValidationError
from kilostar.utils.error import UserNotExistError from kilostar.utils.error import UserNotExistError, BusinessError, RetryableError
from kilostar.utils.logger import get_logger from kilostar.utils.logger import get_logger
@@ -31,14 +31,16 @@ def database_exception(func):
logger.error(f"对象校验失败:{e}") logger.error(f"对象校验失败:{e}")
raise e raise e
except IntegrityError as e: except IntegrityError as e:
logger.error(f"数据库完整性错误 (如重复记录): {e}") logger.warning(f"数据库完整性冲突: {e.orig}")
raise e err = BusinessError(str(e.orig))
err.http_status = 409
err.code = "conflict"
raise err from e
except OperationalError as e: except OperationalError as e:
logger.error(f"数据库连接异常: {e}") logger.error(f"数据库连接异常: {e}")
raise e raise RetryableError(f"数据库暂时不可用,请稍后重试: {e}") from e
except UserNotExistError as e: except (UserNotExistError, BusinessError):
logger.error(f"更改密码失败,用户不存在:{e}") raise
raise e
except Exception as e: except Exception as e:
logger.exception(f"未预期的数据库错误: {e}") logger.exception(f"未预期的数据库错误: {e}")
raise e raise e
@@ -34,6 +34,7 @@ from kilostar.core.postgres_database.model.mcp_server import MCPServerModel
from kilostar.core.postgres_database.model.tool_config import ToolConfigModel from kilostar.core.postgres_database.model.tool_config import ToolConfigModel
from kilostar.core.postgres_database.model.custom_toolset import CustomToolsetModel from kilostar.core.postgres_database.model.custom_toolset import CustomToolsetModel
from kilostar.core.postgres_database.model.system_event_log import SystemEventLog from kilostar.core.postgres_database.model.system_event_log import SystemEventLog
from kilostar.core.postgres_database.model.persona_template import PersonaTemplate
# 兼容旧代码的别名 # 兼容旧代码的别名
Provider = ProviderModel Provider = ProviderModel
@@ -63,5 +64,6 @@ __all__ = [
"ToolConfigModel", "ToolConfigModel",
"CustomToolsetModel", "CustomToolsetModel",
"SystemEventLog", "SystemEventLog",
"PersonaTemplate",
"AgentType", "AgentType",
] ]
@@ -43,6 +43,12 @@ class BaseIndividualModel(BaseDataModel):
owner_id: Mapped[str] = mapped_column(String(64), index=True) owner_id: Mapped[str] = mapped_column(String(64), index=True)
agent_type: Mapped[str] = mapped_column(String(32)) agent_type: Mapped[str] = mapped_column(String(32))
node_affinity: Mapped[str] = mapped_column(String(32), nullable=False, default="cpu")
template_origin_id: Mapped[Optional[str]] = mapped_column(
ForeignKey("persona_template.template_id", ondelete="SET NULL"),
nullable=True,
index=True,
)
__mapper_args__ = {"polymorphic_on": "agent_type", "polymorphic_identity": "base"} __mapper_args__ = {"polymorphic_on": "agent_type", "polymorphic_identity": "base"}
@@ -0,0 +1,26 @@
from typing import List, Optional
from sqlalchemy import String, Text, Boolean, text
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import Mapped, mapped_column
from .base import BaseDataModel
class PersonaTemplate(BaseDataModel):
__tablename__ = "persona_template"
template_id: Mapped[str] = mapped_column(String(64), primary_key=True)
name: Mapped[str] = mapped_column(String(100), nullable=False, index=True)
description: Mapped[str] = mapped_column(Text, nullable=False, default="")
system_prompt: Mapped[str] = mapped_column(Text, nullable=False, default="")
agent_type: Mapped[str] = mapped_column(String(32), nullable=False, default="ordinary")
provider_title: Mapped[Optional[str]] = mapped_column(String(50))
model_id: Mapped[Optional[str]] = mapped_column(String(100))
tools: Mapped[Optional[List[str]]] = mapped_column(
JSONB, default=list, server_default=text("'[]'::jsonb")
)
tags: Mapped[Optional[List[str]]] = mapped_column(
JSONB, default=list, server_default=text("'[]'::jsonb")
)
is_builtin: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
owner_id: Mapped[Optional[str]] = mapped_column(String(64), index=True)
@@ -0,0 +1,73 @@
from sqlalchemy import select
from ulid import ULID
from kilostar.core.postgres_database.model.persona_template import PersonaTemplate
from kilostar.core.postgres_database.database_exception import database_exception
class PersonaTemplateDatabase:
def __init__(self, async_session_maker):
self.async_session_maker = async_session_maker
@database_exception
async def add_template(self, **kwargs) -> PersonaTemplate:
async with self.async_session_maker() as session:
tpl = PersonaTemplate(template_id=str(ULID()), **kwargs)
session.add(tpl)
await session.commit()
await session.refresh(tpl)
return tpl
@database_exception
async def get_template(self, template_id: str):
async with self.async_session_maker() as session:
result = await session.execute(
select(PersonaTemplate).where(PersonaTemplate.template_id == template_id)
)
return result.scalar_one_or_none()
@database_exception
async def list_templates(self, owner_id: str = None, include_builtin: bool = True):
async with self.async_session_maker() as session:
stmt = select(PersonaTemplate)
if owner_id and include_builtin:
from sqlalchemy import or_
stmt = stmt.where(
or_(PersonaTemplate.owner_id == owner_id, PersonaTemplate.is_builtin == True)
)
elif owner_id:
stmt = stmt.where(PersonaTemplate.owner_id == owner_id)
elif include_builtin:
stmt = stmt.where(PersonaTemplate.is_builtin == True)
result = await session.execute(stmt)
return list(result.scalars().all())
@database_exception
async def update_template(self, template_id: str, **kwargs):
async with self.async_session_maker() as session:
result = await session.execute(
select(PersonaTemplate).where(PersonaTemplate.template_id == template_id)
)
tpl = result.scalar_one_or_none()
if not tpl:
return None
for k, v in kwargs.items():
if v is not None:
setattr(tpl, k, v)
session.add(tpl)
await session.commit()
await session.refresh(tpl)
return tpl
@database_exception
async def delete_template(self, template_id: str) -> bool:
async with self.async_session_maker() as session:
result = await session.execute(
select(PersonaTemplate).where(PersonaTemplate.template_id == template_id)
)
tpl = result.scalar_one_or_none()
if not tpl:
return False
await session.delete(tpl)
await session.commit()
return True
+2 -6
View File
@@ -42,12 +42,8 @@ class TokenData(BaseModel):
def _get_secret_key() -> str: def _get_secret_key() -> str:
"""读取并校验 SECRET_KEY 环境变量。 from kilostar.utils.settings import get_settings
key = get_settings().security.secret_key
校验在首次实际使用 JWT 时进行,避免在模块导入阶段抛错,
从而把"环境约束""模块加载"解耦。
"""
key = os.getenv("SECRET_KEY")
if not key or key in _INSECURE_SECRETS: if not key or key in _INSECURE_SECRETS:
raise RuntimeError( raise RuntimeError(
"未提供有效的 SECRET_KEY 或使用了不安全的默认值,请设置一个高熵的随机字符串" "未提供有效的 SECRET_KEY 或使用了不安全的默认值,请设置一个高熵的随机字符串"
+15 -11
View File
@@ -25,10 +25,11 @@
from __future__ import annotations from __future__ import annotations
import os
from typing import Dict from typing import Dict
_DEFAULT_LOCALE: str = os.getenv("KILOSTAR_LANG", "zh") from kilostar.utils.settings import get_settings
_DEFAULT_LOCALE: str = get_settings().kilostar_lang
# ─── Agent System Prompts ────────────────────────────────────────────────── # ─── Agent System Prompts ──────────────────────────────────────────────────
@@ -163,16 +164,16 @@ def t(key: str, locale: str | None = None, accept_language: str | None = None, *
return text.format(**kwargs) if kwargs else text return text.format(**kwargs) if kwargs else text
def agent_prompt(agent_name: str, locale: str | None = None, accept_language: str | None = None) -> str: def agent_prompt(
agent_name: str,
locale: str | None = None,
accept_language: str | None = None,
custom_system_prompt: str | None = None,
) -> str:
"""获取指定 Agent 的 system prompt,并追加语言指令。 """获取指定 Agent 的 system prompt,并追加语言指令。
Args: 若 ``custom_system_prompt`` 不为空,追加在默认 prompt 和语言指令之后,
agent_name: ``regulatory_node`` / ``consciousness_node`` / ``control_node`` 使管理员自定义内容能够覆盖/补充默认行为,同时保留角色定义。
locale: 显式指定语言代码。
accept_language: ``Accept-Language`` 头内容。
Returns:
完整 system prompt(含 "请使用 XX 语言回复" 的追加指令)。
""" """
loc = _resolve_locale(locale, accept_language) loc = _resolve_locale(locale, accept_language)
prompt = _PROMPTS.get(agent_name, {}).get(loc) or _PROMPTS.get(agent_name, {}).get(_DEFAULT_LOCALE, "") prompt = _PROMPTS.get(agent_name, {}).get(loc) or _PROMPTS.get(agent_name, {}).get(_DEFAULT_LOCALE, "")
@@ -180,4 +181,7 @@ def agent_prompt(agent_name: str, locale: str | None = None, accept_language: st
"zh": "\n\n【重要】请始终使用简体中文进行思考和回复。", "zh": "\n\n【重要】请始终使用简体中文进行思考和回复。",
"en": "\n\n[Important] Please always think and reply in English.", "en": "\n\n[Important] Please always think and reply in English.",
}.get(loc, "") }.get(loc, "")
return prompt + lang_instruction result = prompt + lang_instruction
if custom_system_prompt and custom_system_prompt.strip():
result += f"\n\n{custom_system_prompt.strip()}"
return result
+6 -10
View File
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import os
import sys import sys
from loguru import logger from loguru import logger
@@ -23,15 +22,11 @@ from kilostar.utils.request_context import get_request_id, get_trace_id
def _is_json_mode() -> bool: def _is_json_mode() -> bool:
"""根据环境变量决定是否启用 JSON 结构化日志。 from kilostar.utils.settings import get_settings
s = get_settings().log
支持开关:``KILOSTAR_LOG_FORMAT=json`` 或 ``KILOSTAR_LOG_JSON=1/true``。 if s.kilostar_log_format.lower() == "json":
"""
fmt = os.environ.get("KILOSTAR_LOG_FORMAT", "").lower()
if fmt == "json":
return True return True
flag = os.environ.get("KILOSTAR_LOG_JSON", "").lower() return s.kilostar_log_json.lower() in {"1", "true", "yes", "on"}
return flag in {"1", "true", "yes", "on"}
def _ctx_patcher(record): def _ctx_patcher(record):
@@ -58,7 +53,8 @@ def setup_logger() -> Logger:
""" """
logger.remove() logger.remove()
log_level = os.environ.get("KILOSTAR_LOG_LEVEL", "DEBUG").upper() from kilostar.utils.settings import get_settings
log_level = get_settings().log.kilostar_log_level.upper()
if _is_json_mode(): if _is_json_mode():
logger.configure( logger.configure(
+15
View File
@@ -125,3 +125,18 @@ def ray_actor_hook(*actor_names: str, timeout: float = 0.0, interval: float = 0.
handle = _get_cached_actor_handle(actor_name) handle = _get_cached_actor_handle(actor_name)
setattr(actor_list, actor_name, handle) setattr(actor_list, actor_name, handle)
return actor_list return actor_list
def get_worker_cluster(affinity: str = "cpu"):
"""按 node_affinity 标签取对应的 WorkerCluster actor 句柄。
单机模式统一返回唯一的 worker_cluster 实例。
分布式模式按 affinity 路由到 worker_cluster_cpu / _core / _gpu。
未知标签降级到 cpu。
"""
if _STANDALONE:
return _standalone_registry.get("worker_cluster")
_valid = {"cpu", "core", "gpu"}
node_type = affinity if affinity in _valid else "cpu"
return _get_cached_actor_handle(f"worker_cluster_{node_type}")
+55
View File
@@ -0,0 +1,55 @@
"""KiloStar 集中式环境变量管理。
所有散落在各模块的 os.getenv/os.environ 收敛到此处,
通过 pydantic-settings 统一校验、类型转换、默认值管理。
"""
from __future__ import annotations
from functools import lru_cache
from pydantic import Field
from pydantic_settings import BaseSettings
class DatabaseSettings(BaseSettings):
postgres_user: str = "postgres"
postgres_password: str = ""
postgres_host: str = "db"
postgres_port: int = 5432
postgres_db: str = "postgres"
class SecuritySettings(BaseSettings):
secret_key: str = ""
kilostar_secret_key: str = ""
kilostar_env: str = "production"
class LogSettings(BaseSettings):
kilostar_log_level: str = "DEBUG"
kilostar_log_format: str = ""
kilostar_log_json: str = ""
class OnebotSettings(BaseSettings):
onebot_access_token: str = ""
onebot_http_url: str = "http://127.0.0.1:5700"
class AppSettings(BaseSettings):
kilostar_mode: str = "distributed"
kilostar_lang: str = "zh"
kilostar_cors_origins: str = ""
db: DatabaseSettings = Field(default_factory=DatabaseSettings)
security: SecuritySettings = Field(default_factory=SecuritySettings)
log: LogSettings = Field(default_factory=LogSettings)
onebot: OnebotSettings = Field(default_factory=OnebotSettings)
model_config = {"env_nested_delimiter": "__"}
@lru_cache(maxsize=1)
def get_settings() -> AppSettings:
return AppSettings()
+8 -1
View File
@@ -36,10 +36,15 @@ class WorkerCluster:
""" """
工作集群 Actor:管理和调度所有的 worker_individual 工作集群 Actor:管理和调度所有的 worker_individual
设计理念:按需加载,内存 LRU 淘汰,避免 Actor 爆炸 设计理念:按需加载,内存 LRU 淘汰,避免 Actor 爆炸
分布式模式下每种 node_type 对应一个独立实例,Ray 根据自定义资源
``kilostar_node_cpu`` / ``kilostar_node_core`` / ``kilostar_node_gpu``
将 Actor 调度到声明了对应资源的节点上。
""" """
def __init__(self, max_capacity: int = 200, num_runners: int = 10): def __init__(self, max_capacity: int = 200, num_runners: int = 10, node_type: str = "cpu"):
self.max_capacity = max_capacity self.max_capacity = max_capacity
self.node_type = node_type
self._active_workers: OrderedDict[str, BaseIndividual] = OrderedDict() self._active_workers: OrderedDict[str, BaseIndividual] = OrderedDict()
self.status = "running" self.status = "running"
self.task_queue = None self.task_queue = None
@@ -76,6 +81,8 @@ class WorkerCluster:
raise ValueError(f"无法唤醒 Agent {agent_id}:数据库中不存在该档案") raise ValueError(f"无法唤醒 Agent {agent_id}:数据库中不存在该档案")
worker_type = agent_config.get("type", "ordinary") worker_type = agent_config.get("type", "ordinary")
node_affinity = agent_config.get("node_affinity", "cpu")
self.logger.debug(f"[WorkerCluster] 唤醒 Agent {agent_id}, node_affinity={node_affinity}")
if worker_type == "skill": if worker_type == "skill":
worker = SkillIndividual(agent_config) worker = SkillIndividual(agent_config)
elif worker_type == "special": elif worker_type == "special":
+1
View File
@@ -21,6 +21,7 @@ dependencies = [
"pretor-viceroy>=0.2.0", "pretor-viceroy>=0.2.0",
"pwdlib[argon2,bcrypt]>=0.3.0", "pwdlib[argon2,bcrypt]>=0.3.0",
"pydantic-ai>=1.73.0", "pydantic-ai>=1.73.0",
"pydantic-settings>=2.0",
"pyfiglet>=1.0.4", "pyfiglet>=1.0.4",
"pyjwt>=2.12.1", "pyjwt>=2.12.1",
"python-ulid>=3.1.0", "python-ulid>=3.1.0",
+165
View File
@@ -0,0 +1,165 @@
"""``api/agent.py`` persona template 路由:CRUD 鉴权与 node_affinity 校验。"""
from __future__ import annotations
import types
from unittest.mock import AsyncMock
import pytest
from fastapi import FastAPI
from httpx import AsyncClient, ASGITransport
from kilostar.api.agent import agent_router, _VALID_AFFINITIES
from kilostar.utils.access import Accessor, TokenData
from kilostar.core.postgres_database.model import UserAuthority
def _fake_user(user_id: str = "alice"):
return TokenData(user_id=user_id, username=user_id)
def _tpl(owner: str = "alice", is_builtin: bool = False):
return types.SimpleNamespace(
template_id="tpl1", name="MyBot", agent_type="ordinary",
description="d", system_prompt="s", provider_title="openai",
model_id="gpt-4", tools=[], tags=[], is_builtin=is_builtin, owner_id=owner,
)
@pytest.fixture
def app(monkeypatch):
import kilostar.utils.check_user.role_check as rc
monkeypatch.setattr(rc, "get_authority", AsyncMock(return_value=UserAuthority.USER))
_app = FastAPI()
_app.include_router(agent_router)
_app.dependency_overrides[Accessor.get_current_user] = lambda: _fake_user()
return _app
def _register_pg(fake_actors, **kwargs):
pg = types.SimpleNamespace(**kwargs)
fake_actors.register("postgres_database", pg)
return pg
# ── node_affinity 校验(纯 pydantic) ─────────────────────────────────────
def test_valid_affinities_accepted():
from kilostar.api.agent import WorkerIndividualCreate
for aff in _VALID_AFFINITIES:
m = WorkerIndividualCreate(
agent_name="x", agent_type="ordinary", description="d",
provider_title="p", model_id="m", system_prompt="s",
output_template={}, bound_skill={}, workspace=[], node_affinity=aff,
)
assert m.node_affinity == aff
def test_invalid_affinity_raises():
from pydantic import ValidationError
from kilostar.api.agent import WorkerIndividualCreate
with pytest.raises(ValidationError):
WorkerIndividualCreate(
agent_name="x", agent_type="ordinary", description="d",
provider_title="p", model_id="m", system_prompt="s",
output_template={}, bound_skill={}, workspace=[], node_affinity="bad",
)
def test_update_invalid_affinity_raises():
from pydantic import ValidationError
from kilostar.api.agent import WorkerIndividualUpdate
with pytest.raises(ValidationError):
WorkerIndividualUpdate(node_affinity="bad")
def test_update_none_affinity_ok():
from kilostar.api.agent import WorkerIndividualUpdate
assert WorkerIndividualUpdate(node_affinity=None).node_affinity is None
# ── template API 路由 ──────────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_list_templates(app, fake_actors):
pg = types.SimpleNamespace(
list_templates=types.SimpleNamespace(remote=AsyncMock(return_value=[]))
)
fake_actors.register("postgres_database", pg)
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://t") as c:
r = await c.get("/api/v1/agent/template")
assert r.status_code == 200
assert r.json()["templates"] == []
@pytest.mark.asyncio
async def test_create_template(app, fake_actors):
tpl = _tpl()
pg = types.SimpleNamespace(
add_template=types.SimpleNamespace(remote=AsyncMock(return_value=tpl))
)
fake_actors.register("postgres_database", pg)
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://t") as c:
r = await c.post("/api/v1/agent/template", json={"name": "test"})
assert r.status_code == 200
assert r.json()["template_id"] == "tpl1"
@pytest.mark.asyncio
async def test_delete_template_not_found(app, fake_actors):
pg = types.SimpleNamespace(
get_template=types.SimpleNamespace(remote=AsyncMock(return_value=None))
)
fake_actors.register("postgres_database", pg)
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://t") as c:
r = await c.delete("/api/v1/agent/template/missing")
assert r.status_code == 404
@pytest.mark.asyncio
async def test_delete_builtin_template_forbidden(app, fake_actors):
pg = types.SimpleNamespace(
get_template=types.SimpleNamespace(remote=AsyncMock(return_value=_tpl(is_builtin=True)))
)
fake_actors.register("postgres_database", pg)
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://t") as c:
r = await c.delete("/api/v1/agent/template/tpl1")
assert r.status_code == 403
@pytest.mark.asyncio
async def test_delete_other_users_template_forbidden(app, fake_actors):
pg = types.SimpleNamespace(
get_template=types.SimpleNamespace(remote=AsyncMock(return_value=_tpl(owner="bob")))
)
fake_actors.register("postgres_database", pg)
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://t") as c:
r = await c.delete("/api/v1/agent/template/tpl1")
assert r.status_code == 403
@pytest.mark.asyncio
async def test_create_worker_from_template(app, fake_actors):
worker = types.SimpleNamespace(agent_id="w1")
pg = types.SimpleNamespace(
get_template=types.SimpleNamespace(remote=AsyncMock(return_value=_tpl())),
add_worker_individual=types.SimpleNamespace(remote=AsyncMock(return_value=worker)),
)
fake_actors.register("postgres_database", pg)
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://t") as c:
r = await c.post("/api/v1/agent/worker/from-template/tpl1")
assert r.status_code == 200
assert r.json()["agent_id"] == "w1"
@pytest.mark.asyncio
async def test_create_worker_from_missing_template(app, fake_actors):
pg = types.SimpleNamespace(
get_template=types.SimpleNamespace(remote=AsyncMock(return_value=None))
)
fake_actors.register("postgres_database", pg)
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://t") as c:
r = await c.post("/api/v1/agent/worker/from-template/nope")
assert r.status_code == 404
+6 -5
View File
@@ -5,7 +5,7 @@ from pydantic import BaseModel, ValidationError
from sqlalchemy.exc import IntegrityError, OperationalError from sqlalchemy.exc import IntegrityError, OperationalError
from kilostar.core.postgres_database.database_exception import database_exception from kilostar.core.postgres_database.database_exception import database_exception
from kilostar.utils.error import UserNotExistError from kilostar.utils.error import UserNotExistError, BusinessError, RetryableError
async def test_normal_path_returns_value(): async def test_normal_path_returns_value():
@@ -36,21 +36,22 @@ async def test_validation_error_propagates():
await boom() await boom()
async def test_integrity_error_propagates(): async def test_integrity_error_becomes_business_error():
@database_exception @database_exception
async def boom() -> None: async def boom() -> None:
raise IntegrityError("stmt", {}, Exception("dup")) raise IntegrityError("stmt", {}, Exception("dup"))
with pytest.raises(IntegrityError): with pytest.raises(BusinessError) as exc_info:
await boom() await boom()
assert exc_info.value.http_status == 409
async def test_operational_error_propagates(): async def test_operational_error_becomes_retryable():
@database_exception @database_exception
async def boom() -> None: async def boom() -> None:
raise OperationalError("stmt", {}, Exception("conn")) raise OperationalError("stmt", {}, Exception("conn"))
with pytest.raises(OperationalError): with pytest.raises(RetryableError):
await boom() await boom()
+74
View File
@@ -0,0 +1,74 @@
"""``PersonaTemplateDatabase`` — list_templates 查询逻辑单元测试。"""
from __future__ import annotations
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from kilostar.core.postgres_database.module.persona_template import PersonaTemplateDatabase
def _make_db():
session = AsyncMock()
session.__aenter__ = AsyncMock(return_value=session)
session.__aexit__ = AsyncMock(return_value=False)
session_maker = MagicMock(return_value=session)
return PersonaTemplateDatabase(session_maker), session
@pytest.mark.anyio
async def test_list_owner_and_builtin_uses_or():
"""owner_id + include_builtin=True 应构造 OR 条件,不拉出其他人的模板。"""
db, session = _make_db()
execute_result = MagicMock()
execute_result.scalars.return_value.all.return_value = []
session.execute = AsyncMock(return_value=execute_result)
result = await db.list_templates(owner_id="alice", include_builtin=True)
assert result == []
# 确认 execute 被调用(OR 条件路径走通)
session.execute.assert_awaited_once()
@pytest.mark.anyio
async def test_list_owner_only():
db, session = _make_db()
execute_result = MagicMock()
execute_result.scalars.return_value.all.return_value = []
session.execute = AsyncMock(return_value=execute_result)
await db.list_templates(owner_id="alice", include_builtin=False)
session.execute.assert_awaited_once()
@pytest.mark.anyio
async def test_list_builtin_only():
db, session = _make_db()
execute_result = MagicMock()
execute_result.scalars.return_value.all.return_value = []
session.execute = AsyncMock(return_value=execute_result)
await db.list_templates(owner_id=None, include_builtin=True)
session.execute.assert_awaited_once()
@pytest.mark.anyio
async def test_delete_nonexistent_returns_false():
db, session = _make_db()
execute_result = MagicMock()
execute_result.scalar_one_or_none.return_value = None
session.execute = AsyncMock(return_value=execute_result)
result = await db.delete_template("missing")
assert result is False
@pytest.mark.anyio
async def test_update_nonexistent_returns_none():
db, session = _make_db()
execute_result = MagicMock()
execute_result.scalar_one_or_none.return_value = None
session.execute = AsyncMock(return_value=execute_result)
result = await db.update_template("missing", name="new")
assert result is None