feat: 人设模板系统、节点调度标签、pydantic-settings收敛、错误处理增强
新增persona_template表和CRUD API,BaseIndividualModel增加node_affinity和template_origin_id字段, WorkerCluster支持多集群Ray资源调度,环境变量收敛到pydantic-settings统一校验, 数据库异常转换为结构化BusinessError/RetryableError,系统节点支持custom_system_prompt。 Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,69 @@
|
|||||||
|
"""add persona_template table, node_affinity and template_origin_id to base_individual
|
||||||
|
|
||||||
|
Revision ID: 0003_persona_template
|
||||||
|
Revises: 0002_graph_and_logs
|
||||||
|
Create Date: 2026-06-04 00:00:00
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from sqlalchemy.dialects import postgresql
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
revision: str = "0003_persona_template"
|
||||||
|
down_revision: Union[str, None] = "0002_graph_and_logs"
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
op.create_table(
|
||||||
|
"persona_template",
|
||||||
|
sa.Column("template_id", sa.String(64), primary_key=True),
|
||||||
|
sa.Column("name", sa.String(100), nullable=False),
|
||||||
|
sa.Column("description", sa.Text(), nullable=False, server_default=""),
|
||||||
|
sa.Column("system_prompt", sa.Text(), nullable=False, server_default=""),
|
||||||
|
sa.Column("agent_type", sa.String(32), nullable=False, server_default="ordinary"),
|
||||||
|
sa.Column("provider_title", sa.String(50), nullable=True),
|
||||||
|
sa.Column("model_id", sa.String(100), nullable=True),
|
||||||
|
sa.Column("tools", postgresql.JSONB(), nullable=True, server_default="'[]'::jsonb"),
|
||||||
|
sa.Column("tags", postgresql.JSONB(), nullable=True, server_default="'[]'::jsonb"),
|
||||||
|
sa.Column("is_builtin", sa.Boolean(), nullable=False, server_default="false"),
|
||||||
|
sa.Column("owner_id", sa.String(64), nullable=True),
|
||||||
|
)
|
||||||
|
op.create_index("ix_persona_template_name", "persona_template", ["name"])
|
||||||
|
op.create_index("ix_persona_template_owner_id", "persona_template", ["owner_id"])
|
||||||
|
|
||||||
|
op.add_column(
|
||||||
|
"base_individual",
|
||||||
|
sa.Column("node_affinity", sa.String(32), nullable=False, server_default="cpu"),
|
||||||
|
)
|
||||||
|
op.add_column(
|
||||||
|
"base_individual",
|
||||||
|
sa.Column("template_origin_id", sa.String(64), nullable=True),
|
||||||
|
)
|
||||||
|
op.create_foreign_key(
|
||||||
|
"fk_base_individual_template_origin",
|
||||||
|
"base_individual",
|
||||||
|
"persona_template",
|
||||||
|
["template_origin_id"],
|
||||||
|
["template_id"],
|
||||||
|
ondelete="SET NULL",
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
"ix_base_individual_template_origin_id",
|
||||||
|
"base_individual",
|
||||||
|
["template_origin_id"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.drop_index("ix_base_individual_template_origin_id", "base_individual")
|
||||||
|
op.drop_constraint("fk_base_individual_template_origin", "base_individual", type_="foreignkey")
|
||||||
|
op.drop_column("base_individual", "template_origin_id")
|
||||||
|
op.drop_column("base_individual", "node_affinity")
|
||||||
|
|
||||||
|
op.drop_index("ix_persona_template_owner_id", "persona_template")
|
||||||
|
op.drop_index("ix_persona_template_name", "persona_template")
|
||||||
|
op.drop_table("persona_template")
|
||||||
@@ -0,0 +1,27 @@
|
|||||||
|
"""add custom_system_prompt to system_node_config
|
||||||
|
|
||||||
|
Revision ID: 0004_system_node_custom_prompt
|
||||||
|
Revises: 0003_persona_template
|
||||||
|
Create Date: 2026-06-04 00:01:00
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
revision: str = "0004_system_node_custom_prompt"
|
||||||
|
down_revision: Union[str, None] = "0003_persona_template"
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
op.add_column(
|
||||||
|
"system_node_config",
|
||||||
|
sa.Column("custom_system_prompt", sa.Text(), nullable=True),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.drop_column("system_node_config", "custom_system_prompt")
|
||||||
@@ -21,6 +21,7 @@ from fastapi.responses import FileResponse, JSONResponse
|
|||||||
from fastapi.staticfiles import StaticFiles
|
from fastapi.staticfiles import StaticFiles
|
||||||
|
|
||||||
from kilostar.utils.standalone_proxy import _STANDALONE
|
from kilostar.utils.standalone_proxy import _STANDALONE
|
||||||
|
from kilostar.utils.settings import get_settings
|
||||||
|
|
||||||
if not _STANDALONE:
|
if not _STANDALONE:
|
||||||
from ray import serve
|
from ray import serve
|
||||||
@@ -51,13 +52,13 @@ _api_logger = get_logger("api")
|
|||||||
|
|
||||||
|
|
||||||
def _get_locale(request: Request) -> str | None:
|
def _get_locale(request: Request) -> str | None:
|
||||||
"""从请求头解析首选语言,供异常 handler 使用。"""
|
|
||||||
return request.headers.get("accept-language") or None
|
return request.headers.get("accept-language") or None
|
||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
|
|
||||||
_cors_origins_env = os.environ.get("KILOSTAR_CORS_ORIGINS", "")
|
_settings = get_settings()
|
||||||
_is_dev = os.environ.get("KILOSTAR_ENV", "production").lower() in ("dev", "development")
|
_cors_origins_env = _settings.kilostar_cors_origins
|
||||||
|
_is_dev = _settings.security.kilostar_env.lower() in ("dev", "development")
|
||||||
if not _cors_origins_env and _is_dev:
|
if not _cors_origins_env and _is_dev:
|
||||||
_cors_origins_env = "*"
|
_cors_origins_env = "*"
|
||||||
elif not _cors_origins_env:
|
elif not _cors_origins_env:
|
||||||
|
|||||||
@@ -266,8 +266,10 @@ async def send_message(
|
|||||||
if not user_id and not group_id:
|
if not user_id and not group_id:
|
||||||
raise ValueError("必须指定 user_id 或 group_id 之一")
|
raise ValueError("必须指定 user_id 或 group_id 之一")
|
||||||
|
|
||||||
base = base_url or os.environ.get("ONEBOT_HTTP_URL", "http://127.0.0.1:5700")
|
from kilostar.utils.settings import get_settings
|
||||||
token = access_token or os.environ.get("ONEBOT_ACCESS_TOKEN")
|
_ob = get_settings().onebot
|
||||||
|
base = base_url or _ob.onebot_http_url
|
||||||
|
token = access_token or _ob.onebot_access_token or None
|
||||||
|
|
||||||
if group_id:
|
if group_id:
|
||||||
action = "send_group_msg"
|
action = "send_group_msg"
|
||||||
|
|||||||
@@ -106,3 +106,10 @@ async def query_system_logs(
|
|||||||
offset=offset,
|
offset=offset,
|
||||||
)
|
)
|
||||||
return {"logs": logs, "count": len(logs)}
|
return {"logs": logs, "count": len(logs)}
|
||||||
|
|
||||||
|
|
||||||
|
@system_router.get("/api/v1/system/node-labels")
|
||||||
|
async def get_node_labels(
|
||||||
|
_: TokenData = Depends(Accessor.get_current_user),
|
||||||
|
):
|
||||||
|
return {"node_labels": ["cpu", "core", "gpu"]}
|
||||||
@@ -48,8 +48,9 @@ class ConsciousnessNode:
|
|||||||
tools_list: list[str] = None,
|
tools_list: list[str] = None,
|
||||||
toolsets=None,
|
toolsets=None,
|
||||||
locale: str | None = None,
|
locale: str | None = None,
|
||||||
|
custom_system_prompt: str | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
system_prompt: str = agent_prompt("consciousness_node", locale=locale)
|
system_prompt: str = agent_prompt("consciousness_node", locale=locale, custom_system_prompt=custom_system_prompt)
|
||||||
output_type = Union[ForregulatoryNode, ForWorkflow, ForWorkflowEngine]
|
output_type = Union[ForregulatoryNode, ForWorkflow, ForWorkflowEngine]
|
||||||
from kilostar.utils.get_tool import load_tools_from_list
|
from kilostar.utils.get_tool import load_tools_from_list
|
||||||
from kilostar.core.global_state_machine.gsm_snapshot import fetch_snapshot
|
from kilostar.core.global_state_machine.gsm_snapshot import fetch_snapshot
|
||||||
|
|||||||
@@ -47,6 +47,7 @@ class ControlNode:
|
|||||||
tools_list: list[str] = None,
|
tools_list: list[str] = None,
|
||||||
toolsets=None,
|
toolsets=None,
|
||||||
locale: str | None = None,
|
locale: str | None = None,
|
||||||
|
custom_system_prompt: str | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
create_agent方法,将agent对象装配到Control的属性内
|
create_agent方法,将agent对象装配到Control的属性内
|
||||||
@@ -58,11 +59,12 @@ class ControlNode:
|
|||||||
provider_title: 供应商名
|
provider_title: 供应商名
|
||||||
model_id: 模型id
|
model_id: 模型id
|
||||||
locale: 语言代码(zh/en),控制system prompt语言
|
locale: 语言代码(zh/en),控制system prompt语言
|
||||||
|
custom_system_prompt: 管理员自定义追加提示词(可选)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
无返回
|
无返回
|
||||||
"""
|
"""
|
||||||
system_prompt: str = agent_prompt("control_node", locale=locale)
|
system_prompt: str = agent_prompt("control_node", locale=locale, custom_system_prompt=custom_system_prompt)
|
||||||
output_type = ForWorkflow
|
output_type = ForWorkflow
|
||||||
from kilostar.utils.get_tool import load_tools_from_list
|
from kilostar.utils.get_tool import load_tools_from_list
|
||||||
from kilostar.core.global_state_machine.gsm_snapshot import fetch_snapshot
|
from kilostar.core.global_state_machine.gsm_snapshot import fetch_snapshot
|
||||||
|
|||||||
@@ -49,6 +49,7 @@ class RegulatoryNode:
|
|||||||
tools_list: list[str] = None,
|
tools_list: list[str] = None,
|
||||||
toolsets=None,
|
toolsets=None,
|
||||||
locale: str | None = None,
|
locale: str | None = None,
|
||||||
|
custom_system_prompt: str | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
create_agent方法,将agent对象装配到regulatoryNode的属性内
|
create_agent方法,将agent对象装配到regulatoryNode的属性内
|
||||||
@@ -60,10 +61,11 @@ class RegulatoryNode:
|
|||||||
model_id: 模型id
|
model_id: 模型id
|
||||||
tools_list: 工具列表
|
tools_list: 工具列表
|
||||||
locale: 语言代码(zh/en),控制system prompt语言
|
locale: 语言代码(zh/en),控制system prompt语言
|
||||||
|
custom_system_prompt: 管理员自定义追加提示词(可选)
|
||||||
Returns:
|
Returns:
|
||||||
无返回
|
无返回
|
||||||
"""
|
"""
|
||||||
system_prompt: str = agent_prompt("regulatory_node", locale=locale)
|
system_prompt: str = agent_prompt("regulatory_node", locale=locale, custom_system_prompt=custom_system_prompt)
|
||||||
output_type = Union[MessageResponse]
|
output_type = Union[MessageResponse]
|
||||||
from kilostar.utils.get_tool import load_tools_from_list
|
from kilostar.utils.get_tool import load_tools_from_list
|
||||||
from kilostar.core.global_state_machine.gsm_snapshot import fetch_snapshot
|
from kilostar.core.global_state_machine.gsm_snapshot import fetch_snapshot
|
||||||
|
|||||||
@@ -14,7 +14,7 @@
|
|||||||
|
|
||||||
from sqlalchemy.exc import IntegrityError, OperationalError
|
from sqlalchemy.exc import IntegrityError, OperationalError
|
||||||
from pydantic import ValidationError
|
from pydantic import ValidationError
|
||||||
from kilostar.utils.error import UserNotExistError
|
from kilostar.utils.error import UserNotExistError, BusinessError, RetryableError
|
||||||
|
|
||||||
from kilostar.utils.logger import get_logger
|
from kilostar.utils.logger import get_logger
|
||||||
|
|
||||||
@@ -31,14 +31,16 @@ def database_exception(func):
|
|||||||
logger.error(f"对象校验失败:{e}")
|
logger.error(f"对象校验失败:{e}")
|
||||||
raise e
|
raise e
|
||||||
except IntegrityError as e:
|
except IntegrityError as e:
|
||||||
logger.error(f"数据库完整性错误 (如重复记录): {e}")
|
logger.warning(f"数据库完整性冲突: {e.orig}")
|
||||||
raise e
|
err = BusinessError(str(e.orig))
|
||||||
|
err.http_status = 409
|
||||||
|
err.code = "conflict"
|
||||||
|
raise err from e
|
||||||
except OperationalError as e:
|
except OperationalError as e:
|
||||||
logger.error(f"数据库连接异常: {e}")
|
logger.error(f"数据库连接异常: {e}")
|
||||||
raise e
|
raise RetryableError(f"数据库暂时不可用,请稍后重试: {e}") from e
|
||||||
except UserNotExistError as e:
|
except (UserNotExistError, BusinessError):
|
||||||
logger.error(f"更改密码失败,用户不存在:{e}")
|
raise
|
||||||
raise e
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception(f"未预期的数据库错误: {e}")
|
logger.exception(f"未预期的数据库错误: {e}")
|
||||||
raise e
|
raise e
|
||||||
|
|||||||
@@ -34,6 +34,7 @@ from kilostar.core.postgres_database.model.mcp_server import MCPServerModel
|
|||||||
from kilostar.core.postgres_database.model.tool_config import ToolConfigModel
|
from kilostar.core.postgres_database.model.tool_config import ToolConfigModel
|
||||||
from kilostar.core.postgres_database.model.custom_toolset import CustomToolsetModel
|
from kilostar.core.postgres_database.model.custom_toolset import CustomToolsetModel
|
||||||
from kilostar.core.postgres_database.model.system_event_log import SystemEventLog
|
from kilostar.core.postgres_database.model.system_event_log import SystemEventLog
|
||||||
|
from kilostar.core.postgres_database.model.persona_template import PersonaTemplate
|
||||||
|
|
||||||
# 兼容旧代码的别名
|
# 兼容旧代码的别名
|
||||||
Provider = ProviderModel
|
Provider = ProviderModel
|
||||||
@@ -63,5 +64,6 @@ __all__ = [
|
|||||||
"ToolConfigModel",
|
"ToolConfigModel",
|
||||||
"CustomToolsetModel",
|
"CustomToolsetModel",
|
||||||
"SystemEventLog",
|
"SystemEventLog",
|
||||||
|
"PersonaTemplate",
|
||||||
"AgentType",
|
"AgentType",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -43,6 +43,12 @@ class BaseIndividualModel(BaseDataModel):
|
|||||||
owner_id: Mapped[str] = mapped_column(String(64), index=True)
|
owner_id: Mapped[str] = mapped_column(String(64), index=True)
|
||||||
|
|
||||||
agent_type: Mapped[str] = mapped_column(String(32))
|
agent_type: Mapped[str] = mapped_column(String(32))
|
||||||
|
node_affinity: Mapped[str] = mapped_column(String(32), nullable=False, default="cpu")
|
||||||
|
template_origin_id: Mapped[Optional[str]] = mapped_column(
|
||||||
|
ForeignKey("persona_template.template_id", ondelete="SET NULL"),
|
||||||
|
nullable=True,
|
||||||
|
index=True,
|
||||||
|
)
|
||||||
|
|
||||||
__mapper_args__ = {"polymorphic_on": "agent_type", "polymorphic_identity": "base"}
|
__mapper_args__ = {"polymorphic_on": "agent_type", "polymorphic_identity": "base"}
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,26 @@
|
|||||||
|
from typing import List, Optional
|
||||||
|
from sqlalchemy import String, Text, Boolean, text
|
||||||
|
from sqlalchemy.dialects.postgresql import JSONB
|
||||||
|
from sqlalchemy.orm import Mapped, mapped_column
|
||||||
|
|
||||||
|
from .base import BaseDataModel
|
||||||
|
|
||||||
|
|
||||||
|
class PersonaTemplate(BaseDataModel):
|
||||||
|
__tablename__ = "persona_template"
|
||||||
|
|
||||||
|
template_id: Mapped[str] = mapped_column(String(64), primary_key=True)
|
||||||
|
name: Mapped[str] = mapped_column(String(100), nullable=False, index=True)
|
||||||
|
description: Mapped[str] = mapped_column(Text, nullable=False, default="")
|
||||||
|
system_prompt: Mapped[str] = mapped_column(Text, nullable=False, default="")
|
||||||
|
agent_type: Mapped[str] = mapped_column(String(32), nullable=False, default="ordinary")
|
||||||
|
provider_title: Mapped[Optional[str]] = mapped_column(String(50))
|
||||||
|
model_id: Mapped[Optional[str]] = mapped_column(String(100))
|
||||||
|
tools: Mapped[Optional[List[str]]] = mapped_column(
|
||||||
|
JSONB, default=list, server_default=text("'[]'::jsonb")
|
||||||
|
)
|
||||||
|
tags: Mapped[Optional[List[str]]] = mapped_column(
|
||||||
|
JSONB, default=list, server_default=text("'[]'::jsonb")
|
||||||
|
)
|
||||||
|
is_builtin: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
|
||||||
|
owner_id: Mapped[Optional[str]] = mapped_column(String(64), index=True)
|
||||||
@@ -0,0 +1,73 @@
|
|||||||
|
from sqlalchemy import select
|
||||||
|
from ulid import ULID
|
||||||
|
|
||||||
|
from kilostar.core.postgres_database.model.persona_template import PersonaTemplate
|
||||||
|
from kilostar.core.postgres_database.database_exception import database_exception
|
||||||
|
|
||||||
|
|
||||||
|
class PersonaTemplateDatabase:
|
||||||
|
def __init__(self, async_session_maker):
|
||||||
|
self.async_session_maker = async_session_maker
|
||||||
|
|
||||||
|
@database_exception
|
||||||
|
async def add_template(self, **kwargs) -> PersonaTemplate:
|
||||||
|
async with self.async_session_maker() as session:
|
||||||
|
tpl = PersonaTemplate(template_id=str(ULID()), **kwargs)
|
||||||
|
session.add(tpl)
|
||||||
|
await session.commit()
|
||||||
|
await session.refresh(tpl)
|
||||||
|
return tpl
|
||||||
|
|
||||||
|
@database_exception
|
||||||
|
async def get_template(self, template_id: str):
|
||||||
|
async with self.async_session_maker() as session:
|
||||||
|
result = await session.execute(
|
||||||
|
select(PersonaTemplate).where(PersonaTemplate.template_id == template_id)
|
||||||
|
)
|
||||||
|
return result.scalar_one_or_none()
|
||||||
|
|
||||||
|
@database_exception
|
||||||
|
async def list_templates(self, owner_id: str = None, include_builtin: bool = True):
|
||||||
|
async with self.async_session_maker() as session:
|
||||||
|
stmt = select(PersonaTemplate)
|
||||||
|
if owner_id and include_builtin:
|
||||||
|
from sqlalchemy import or_
|
||||||
|
stmt = stmt.where(
|
||||||
|
or_(PersonaTemplate.owner_id == owner_id, PersonaTemplate.is_builtin == True)
|
||||||
|
)
|
||||||
|
elif owner_id:
|
||||||
|
stmt = stmt.where(PersonaTemplate.owner_id == owner_id)
|
||||||
|
elif include_builtin:
|
||||||
|
stmt = stmt.where(PersonaTemplate.is_builtin == True)
|
||||||
|
result = await session.execute(stmt)
|
||||||
|
return list(result.scalars().all())
|
||||||
|
|
||||||
|
@database_exception
|
||||||
|
async def update_template(self, template_id: str, **kwargs):
|
||||||
|
async with self.async_session_maker() as session:
|
||||||
|
result = await session.execute(
|
||||||
|
select(PersonaTemplate).where(PersonaTemplate.template_id == template_id)
|
||||||
|
)
|
||||||
|
tpl = result.scalar_one_or_none()
|
||||||
|
if not tpl:
|
||||||
|
return None
|
||||||
|
for k, v in kwargs.items():
|
||||||
|
if v is not None:
|
||||||
|
setattr(tpl, k, v)
|
||||||
|
session.add(tpl)
|
||||||
|
await session.commit()
|
||||||
|
await session.refresh(tpl)
|
||||||
|
return tpl
|
||||||
|
|
||||||
|
@database_exception
|
||||||
|
async def delete_template(self, template_id: str) -> bool:
|
||||||
|
async with self.async_session_maker() as session:
|
||||||
|
result = await session.execute(
|
||||||
|
select(PersonaTemplate).where(PersonaTemplate.template_id == template_id)
|
||||||
|
)
|
||||||
|
tpl = result.scalar_one_or_none()
|
||||||
|
if not tpl:
|
||||||
|
return False
|
||||||
|
await session.delete(tpl)
|
||||||
|
await session.commit()
|
||||||
|
return True
|
||||||
@@ -42,12 +42,8 @@ class TokenData(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
def _get_secret_key() -> str:
|
def _get_secret_key() -> str:
|
||||||
"""读取并校验 SECRET_KEY 环境变量。
|
from kilostar.utils.settings import get_settings
|
||||||
|
key = get_settings().security.secret_key
|
||||||
校验在首次实际使用 JWT 时进行,避免在模块导入阶段抛错,
|
|
||||||
从而把"环境约束"和"模块加载"解耦。
|
|
||||||
"""
|
|
||||||
key = os.getenv("SECRET_KEY")
|
|
||||||
if not key or key in _INSECURE_SECRETS:
|
if not key or key in _INSECURE_SECRETS:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"未提供有效的 SECRET_KEY 或使用了不安全的默认值,请设置一个高熵的随机字符串"
|
"未提供有效的 SECRET_KEY 或使用了不安全的默认值,请设置一个高熵的随机字符串"
|
||||||
|
|||||||
+15
-11
@@ -25,10 +25,11 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import os
|
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
|
|
||||||
_DEFAULT_LOCALE: str = os.getenv("KILOSTAR_LANG", "zh")
|
from kilostar.utils.settings import get_settings
|
||||||
|
|
||||||
|
_DEFAULT_LOCALE: str = get_settings().kilostar_lang
|
||||||
|
|
||||||
# ─── Agent System Prompts ──────────────────────────────────────────────────
|
# ─── Agent System Prompts ──────────────────────────────────────────────────
|
||||||
|
|
||||||
@@ -163,16 +164,16 @@ def t(key: str, locale: str | None = None, accept_language: str | None = None, *
|
|||||||
return text.format(**kwargs) if kwargs else text
|
return text.format(**kwargs) if kwargs else text
|
||||||
|
|
||||||
|
|
||||||
def agent_prompt(agent_name: str, locale: str | None = None, accept_language: str | None = None) -> str:
|
def agent_prompt(
|
||||||
|
agent_name: str,
|
||||||
|
locale: str | None = None,
|
||||||
|
accept_language: str | None = None,
|
||||||
|
custom_system_prompt: str | None = None,
|
||||||
|
) -> str:
|
||||||
"""获取指定 Agent 的 system prompt,并追加语言指令。
|
"""获取指定 Agent 的 system prompt,并追加语言指令。
|
||||||
|
|
||||||
Args:
|
若 ``custom_system_prompt`` 不为空,追加在默认 prompt 和语言指令之后,
|
||||||
agent_name: ``regulatory_node`` / ``consciousness_node`` / ``control_node``
|
使管理员自定义内容能够覆盖/补充默认行为,同时保留角色定义。
|
||||||
locale: 显式指定语言代码。
|
|
||||||
accept_language: ``Accept-Language`` 头内容。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
完整 system prompt(含 "请使用 XX 语言回复" 的追加指令)。
|
|
||||||
"""
|
"""
|
||||||
loc = _resolve_locale(locale, accept_language)
|
loc = _resolve_locale(locale, accept_language)
|
||||||
prompt = _PROMPTS.get(agent_name, {}).get(loc) or _PROMPTS.get(agent_name, {}).get(_DEFAULT_LOCALE, "")
|
prompt = _PROMPTS.get(agent_name, {}).get(loc) or _PROMPTS.get(agent_name, {}).get(_DEFAULT_LOCALE, "")
|
||||||
@@ -180,4 +181,7 @@ def agent_prompt(agent_name: str, locale: str | None = None, accept_language: st
|
|||||||
"zh": "\n\n【重要】请始终使用简体中文进行思考和回复。",
|
"zh": "\n\n【重要】请始终使用简体中文进行思考和回复。",
|
||||||
"en": "\n\n[Important] Please always think and reply in English.",
|
"en": "\n\n[Important] Please always think and reply in English.",
|
||||||
}.get(loc, "")
|
}.get(loc, "")
|
||||||
return prompt + lang_instruction
|
result = prompt + lang_instruction
|
||||||
|
if custom_system_prompt and custom_system_prompt.strip():
|
||||||
|
result += f"\n\n{custom_system_prompt.strip()}"
|
||||||
|
return result
|
||||||
|
|||||||
@@ -12,7 +12,6 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import os
|
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
@@ -23,15 +22,11 @@ from kilostar.utils.request_context import get_request_id, get_trace_id
|
|||||||
|
|
||||||
|
|
||||||
def _is_json_mode() -> bool:
|
def _is_json_mode() -> bool:
|
||||||
"""根据环境变量决定是否启用 JSON 结构化日志。
|
from kilostar.utils.settings import get_settings
|
||||||
|
s = get_settings().log
|
||||||
支持开关:``KILOSTAR_LOG_FORMAT=json`` 或 ``KILOSTAR_LOG_JSON=1/true``。
|
if s.kilostar_log_format.lower() == "json":
|
||||||
"""
|
|
||||||
fmt = os.environ.get("KILOSTAR_LOG_FORMAT", "").lower()
|
|
||||||
if fmt == "json":
|
|
||||||
return True
|
return True
|
||||||
flag = os.environ.get("KILOSTAR_LOG_JSON", "").lower()
|
return s.kilostar_log_json.lower() in {"1", "true", "yes", "on"}
|
||||||
return flag in {"1", "true", "yes", "on"}
|
|
||||||
|
|
||||||
|
|
||||||
def _ctx_patcher(record):
|
def _ctx_patcher(record):
|
||||||
@@ -58,7 +53,8 @@ def setup_logger() -> Logger:
|
|||||||
"""
|
"""
|
||||||
logger.remove()
|
logger.remove()
|
||||||
|
|
||||||
log_level = os.environ.get("KILOSTAR_LOG_LEVEL", "DEBUG").upper()
|
from kilostar.utils.settings import get_settings
|
||||||
|
log_level = get_settings().log.kilostar_log_level.upper()
|
||||||
|
|
||||||
if _is_json_mode():
|
if _is_json_mode():
|
||||||
logger.configure(
|
logger.configure(
|
||||||
|
|||||||
@@ -125,3 +125,18 @@ def ray_actor_hook(*actor_names: str, timeout: float = 0.0, interval: float = 0.
|
|||||||
handle = _get_cached_actor_handle(actor_name)
|
handle = _get_cached_actor_handle(actor_name)
|
||||||
setattr(actor_list, actor_name, handle)
|
setattr(actor_list, actor_name, handle)
|
||||||
return actor_list
|
return actor_list
|
||||||
|
|
||||||
|
|
||||||
|
def get_worker_cluster(affinity: str = "cpu"):
|
||||||
|
"""按 node_affinity 标签取对应的 WorkerCluster actor 句柄。
|
||||||
|
|
||||||
|
单机模式统一返回唯一的 worker_cluster 实例。
|
||||||
|
分布式模式按 affinity 路由到 worker_cluster_cpu / _core / _gpu。
|
||||||
|
未知标签降级到 cpu。
|
||||||
|
"""
|
||||||
|
if _STANDALONE:
|
||||||
|
return _standalone_registry.get("worker_cluster")
|
||||||
|
|
||||||
|
_valid = {"cpu", "core", "gpu"}
|
||||||
|
node_type = affinity if affinity in _valid else "cpu"
|
||||||
|
return _get_cached_actor_handle(f"worker_cluster_{node_type}")
|
||||||
|
|||||||
@@ -0,0 +1,55 @@
|
|||||||
|
"""KiloStar 集中式环境变量管理。
|
||||||
|
|
||||||
|
所有散落在各模块的 os.getenv/os.environ 收敛到此处,
|
||||||
|
通过 pydantic-settings 统一校验、类型转换、默认值管理。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from functools import lru_cache
|
||||||
|
|
||||||
|
from pydantic import Field
|
||||||
|
from pydantic_settings import BaseSettings
|
||||||
|
|
||||||
|
|
||||||
|
class DatabaseSettings(BaseSettings):
|
||||||
|
postgres_user: str = "postgres"
|
||||||
|
postgres_password: str = ""
|
||||||
|
postgres_host: str = "db"
|
||||||
|
postgres_port: int = 5432
|
||||||
|
postgres_db: str = "postgres"
|
||||||
|
|
||||||
|
|
||||||
|
class SecuritySettings(BaseSettings):
|
||||||
|
secret_key: str = ""
|
||||||
|
kilostar_secret_key: str = ""
|
||||||
|
kilostar_env: str = "production"
|
||||||
|
|
||||||
|
|
||||||
|
class LogSettings(BaseSettings):
|
||||||
|
kilostar_log_level: str = "DEBUG"
|
||||||
|
kilostar_log_format: str = ""
|
||||||
|
kilostar_log_json: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
class OnebotSettings(BaseSettings):
|
||||||
|
onebot_access_token: str = ""
|
||||||
|
onebot_http_url: str = "http://127.0.0.1:5700"
|
||||||
|
|
||||||
|
|
||||||
|
class AppSettings(BaseSettings):
|
||||||
|
kilostar_mode: str = "distributed"
|
||||||
|
kilostar_lang: str = "zh"
|
||||||
|
kilostar_cors_origins: str = ""
|
||||||
|
|
||||||
|
db: DatabaseSettings = Field(default_factory=DatabaseSettings)
|
||||||
|
security: SecuritySettings = Field(default_factory=SecuritySettings)
|
||||||
|
log: LogSettings = Field(default_factory=LogSettings)
|
||||||
|
onebot: OnebotSettings = Field(default_factory=OnebotSettings)
|
||||||
|
|
||||||
|
model_config = {"env_nested_delimiter": "__"}
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache(maxsize=1)
|
||||||
|
def get_settings() -> AppSettings:
|
||||||
|
return AppSettings()
|
||||||
@@ -36,10 +36,15 @@ class WorkerCluster:
|
|||||||
"""
|
"""
|
||||||
工作集群 Actor:管理和调度所有的 worker_individual
|
工作集群 Actor:管理和调度所有的 worker_individual
|
||||||
设计理念:按需加载,内存 LRU 淘汰,避免 Actor 爆炸
|
设计理念:按需加载,内存 LRU 淘汰,避免 Actor 爆炸
|
||||||
|
|
||||||
|
分布式模式下每种 node_type 对应一个独立实例,Ray 根据自定义资源
|
||||||
|
``kilostar_node_cpu`` / ``kilostar_node_core`` / ``kilostar_node_gpu``
|
||||||
|
将 Actor 调度到声明了对应资源的节点上。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, max_capacity: int = 200, num_runners: int = 10):
|
def __init__(self, max_capacity: int = 200, num_runners: int = 10, node_type: str = "cpu"):
|
||||||
self.max_capacity = max_capacity
|
self.max_capacity = max_capacity
|
||||||
|
self.node_type = node_type
|
||||||
self._active_workers: OrderedDict[str, BaseIndividual] = OrderedDict()
|
self._active_workers: OrderedDict[str, BaseIndividual] = OrderedDict()
|
||||||
self.status = "running"
|
self.status = "running"
|
||||||
self.task_queue = None
|
self.task_queue = None
|
||||||
@@ -76,6 +81,8 @@ class WorkerCluster:
|
|||||||
raise ValueError(f"无法唤醒 Agent {agent_id}:数据库中不存在该档案")
|
raise ValueError(f"无法唤醒 Agent {agent_id}:数据库中不存在该档案")
|
||||||
|
|
||||||
worker_type = agent_config.get("type", "ordinary")
|
worker_type = agent_config.get("type", "ordinary")
|
||||||
|
node_affinity = agent_config.get("node_affinity", "cpu")
|
||||||
|
self.logger.debug(f"[WorkerCluster] 唤醒 Agent {agent_id}, node_affinity={node_affinity}")
|
||||||
if worker_type == "skill":
|
if worker_type == "skill":
|
||||||
worker = SkillIndividual(agent_config)
|
worker = SkillIndividual(agent_config)
|
||||||
elif worker_type == "special":
|
elif worker_type == "special":
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ dependencies = [
|
|||||||
"pretor-viceroy>=0.2.0",
|
"pretor-viceroy>=0.2.0",
|
||||||
"pwdlib[argon2,bcrypt]>=0.3.0",
|
"pwdlib[argon2,bcrypt]>=0.3.0",
|
||||||
"pydantic-ai>=1.73.0",
|
"pydantic-ai>=1.73.0",
|
||||||
|
"pydantic-settings>=2.0",
|
||||||
"pyfiglet>=1.0.4",
|
"pyfiglet>=1.0.4",
|
||||||
"pyjwt>=2.12.1",
|
"pyjwt>=2.12.1",
|
||||||
"python-ulid>=3.1.0",
|
"python-ulid>=3.1.0",
|
||||||
|
|||||||
@@ -0,0 +1,165 @@
|
|||||||
|
"""``api/agent.py`` persona template 路由:CRUD 鉴权与 node_affinity 校验。"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import types
|
||||||
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from httpx import AsyncClient, ASGITransport
|
||||||
|
|
||||||
|
from kilostar.api.agent import agent_router, _VALID_AFFINITIES
|
||||||
|
from kilostar.utils.access import Accessor, TokenData
|
||||||
|
from kilostar.core.postgres_database.model import UserAuthority
|
||||||
|
|
||||||
|
|
||||||
|
def _fake_user(user_id: str = "alice"):
|
||||||
|
return TokenData(user_id=user_id, username=user_id)
|
||||||
|
|
||||||
|
|
||||||
|
def _tpl(owner: str = "alice", is_builtin: bool = False):
|
||||||
|
return types.SimpleNamespace(
|
||||||
|
template_id="tpl1", name="MyBot", agent_type="ordinary",
|
||||||
|
description="d", system_prompt="s", provider_title="openai",
|
||||||
|
model_id="gpt-4", tools=[], tags=[], is_builtin=is_builtin, owner_id=owner,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def app(monkeypatch):
|
||||||
|
import kilostar.utils.check_user.role_check as rc
|
||||||
|
monkeypatch.setattr(rc, "get_authority", AsyncMock(return_value=UserAuthority.USER))
|
||||||
|
_app = FastAPI()
|
||||||
|
_app.include_router(agent_router)
|
||||||
|
_app.dependency_overrides[Accessor.get_current_user] = lambda: _fake_user()
|
||||||
|
return _app
|
||||||
|
|
||||||
|
|
||||||
|
def _register_pg(fake_actors, **kwargs):
|
||||||
|
pg = types.SimpleNamespace(**kwargs)
|
||||||
|
fake_actors.register("postgres_database", pg)
|
||||||
|
return pg
|
||||||
|
|
||||||
|
|
||||||
|
# ── node_affinity 校验(纯 pydantic) ─────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_valid_affinities_accepted():
|
||||||
|
from kilostar.api.agent import WorkerIndividualCreate
|
||||||
|
for aff in _VALID_AFFINITIES:
|
||||||
|
m = WorkerIndividualCreate(
|
||||||
|
agent_name="x", agent_type="ordinary", description="d",
|
||||||
|
provider_title="p", model_id="m", system_prompt="s",
|
||||||
|
output_template={}, bound_skill={}, workspace=[], node_affinity=aff,
|
||||||
|
)
|
||||||
|
assert m.node_affinity == aff
|
||||||
|
|
||||||
|
|
||||||
|
def test_invalid_affinity_raises():
|
||||||
|
from pydantic import ValidationError
|
||||||
|
from kilostar.api.agent import WorkerIndividualCreate
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
WorkerIndividualCreate(
|
||||||
|
agent_name="x", agent_type="ordinary", description="d",
|
||||||
|
provider_title="p", model_id="m", system_prompt="s",
|
||||||
|
output_template={}, bound_skill={}, workspace=[], node_affinity="bad",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_update_invalid_affinity_raises():
|
||||||
|
from pydantic import ValidationError
|
||||||
|
from kilostar.api.agent import WorkerIndividualUpdate
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
WorkerIndividualUpdate(node_affinity="bad")
|
||||||
|
|
||||||
|
|
||||||
|
def test_update_none_affinity_ok():
|
||||||
|
from kilostar.api.agent import WorkerIndividualUpdate
|
||||||
|
assert WorkerIndividualUpdate(node_affinity=None).node_affinity is None
|
||||||
|
|
||||||
|
|
||||||
|
# ── template API 路由 ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_templates(app, fake_actors):
|
||||||
|
pg = types.SimpleNamespace(
|
||||||
|
list_templates=types.SimpleNamespace(remote=AsyncMock(return_value=[]))
|
||||||
|
)
|
||||||
|
fake_actors.register("postgres_database", pg)
|
||||||
|
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://t") as c:
|
||||||
|
r = await c.get("/api/v1/agent/template")
|
||||||
|
assert r.status_code == 200
|
||||||
|
assert r.json()["templates"] == []
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_template(app, fake_actors):
|
||||||
|
tpl = _tpl()
|
||||||
|
pg = types.SimpleNamespace(
|
||||||
|
add_template=types.SimpleNamespace(remote=AsyncMock(return_value=tpl))
|
||||||
|
)
|
||||||
|
fake_actors.register("postgres_database", pg)
|
||||||
|
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://t") as c:
|
||||||
|
r = await c.post("/api/v1/agent/template", json={"name": "test"})
|
||||||
|
assert r.status_code == 200
|
||||||
|
assert r.json()["template_id"] == "tpl1"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_delete_template_not_found(app, fake_actors):
|
||||||
|
pg = types.SimpleNamespace(
|
||||||
|
get_template=types.SimpleNamespace(remote=AsyncMock(return_value=None))
|
||||||
|
)
|
||||||
|
fake_actors.register("postgres_database", pg)
|
||||||
|
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://t") as c:
|
||||||
|
r = await c.delete("/api/v1/agent/template/missing")
|
||||||
|
assert r.status_code == 404
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_delete_builtin_template_forbidden(app, fake_actors):
|
||||||
|
pg = types.SimpleNamespace(
|
||||||
|
get_template=types.SimpleNamespace(remote=AsyncMock(return_value=_tpl(is_builtin=True)))
|
||||||
|
)
|
||||||
|
fake_actors.register("postgres_database", pg)
|
||||||
|
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://t") as c:
|
||||||
|
r = await c.delete("/api/v1/agent/template/tpl1")
|
||||||
|
assert r.status_code == 403
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_delete_other_users_template_forbidden(app, fake_actors):
|
||||||
|
pg = types.SimpleNamespace(
|
||||||
|
get_template=types.SimpleNamespace(remote=AsyncMock(return_value=_tpl(owner="bob")))
|
||||||
|
)
|
||||||
|
fake_actors.register("postgres_database", pg)
|
||||||
|
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://t") as c:
|
||||||
|
r = await c.delete("/api/v1/agent/template/tpl1")
|
||||||
|
assert r.status_code == 403
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_worker_from_template(app, fake_actors):
|
||||||
|
worker = types.SimpleNamespace(agent_id="w1")
|
||||||
|
pg = types.SimpleNamespace(
|
||||||
|
get_template=types.SimpleNamespace(remote=AsyncMock(return_value=_tpl())),
|
||||||
|
add_worker_individual=types.SimpleNamespace(remote=AsyncMock(return_value=worker)),
|
||||||
|
)
|
||||||
|
fake_actors.register("postgres_database", pg)
|
||||||
|
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://t") as c:
|
||||||
|
r = await c.post("/api/v1/agent/worker/from-template/tpl1")
|
||||||
|
assert r.status_code == 200
|
||||||
|
assert r.json()["agent_id"] == "w1"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_worker_from_missing_template(app, fake_actors):
|
||||||
|
pg = types.SimpleNamespace(
|
||||||
|
get_template=types.SimpleNamespace(remote=AsyncMock(return_value=None))
|
||||||
|
)
|
||||||
|
fake_actors.register("postgres_database", pg)
|
||||||
|
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://t") as c:
|
||||||
|
r = await c.post("/api/v1/agent/worker/from-template/nope")
|
||||||
|
assert r.status_code == 404
|
||||||
@@ -5,7 +5,7 @@ from pydantic import BaseModel, ValidationError
|
|||||||
from sqlalchemy.exc import IntegrityError, OperationalError
|
from sqlalchemy.exc import IntegrityError, OperationalError
|
||||||
|
|
||||||
from kilostar.core.postgres_database.database_exception import database_exception
|
from kilostar.core.postgres_database.database_exception import database_exception
|
||||||
from kilostar.utils.error import UserNotExistError
|
from kilostar.utils.error import UserNotExistError, BusinessError, RetryableError
|
||||||
|
|
||||||
|
|
||||||
async def test_normal_path_returns_value():
|
async def test_normal_path_returns_value():
|
||||||
@@ -36,21 +36,22 @@ async def test_validation_error_propagates():
|
|||||||
await boom()
|
await boom()
|
||||||
|
|
||||||
|
|
||||||
async def test_integrity_error_propagates():
|
async def test_integrity_error_becomes_business_error():
|
||||||
@database_exception
|
@database_exception
|
||||||
async def boom() -> None:
|
async def boom() -> None:
|
||||||
raise IntegrityError("stmt", {}, Exception("dup"))
|
raise IntegrityError("stmt", {}, Exception("dup"))
|
||||||
|
|
||||||
with pytest.raises(IntegrityError):
|
with pytest.raises(BusinessError) as exc_info:
|
||||||
await boom()
|
await boom()
|
||||||
|
assert exc_info.value.http_status == 409
|
||||||
|
|
||||||
|
|
||||||
async def test_operational_error_propagates():
|
async def test_operational_error_becomes_retryable():
|
||||||
@database_exception
|
@database_exception
|
||||||
async def boom() -> None:
|
async def boom() -> None:
|
||||||
raise OperationalError("stmt", {}, Exception("conn"))
|
raise OperationalError("stmt", {}, Exception("conn"))
|
||||||
|
|
||||||
with pytest.raises(OperationalError):
|
with pytest.raises(RetryableError):
|
||||||
await boom()
|
await boom()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,74 @@
|
|||||||
|
"""``PersonaTemplateDatabase`` — list_templates 查询逻辑单元测试。"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
from kilostar.core.postgres_database.module.persona_template import PersonaTemplateDatabase
|
||||||
|
|
||||||
|
|
||||||
|
def _make_db():
|
||||||
|
session = AsyncMock()
|
||||||
|
session.__aenter__ = AsyncMock(return_value=session)
|
||||||
|
session.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
session_maker = MagicMock(return_value=session)
|
||||||
|
return PersonaTemplateDatabase(session_maker), session
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_list_owner_and_builtin_uses_or():
|
||||||
|
"""owner_id + include_builtin=True 应构造 OR 条件,不拉出其他人的模板。"""
|
||||||
|
db, session = _make_db()
|
||||||
|
execute_result = MagicMock()
|
||||||
|
execute_result.scalars.return_value.all.return_value = []
|
||||||
|
session.execute = AsyncMock(return_value=execute_result)
|
||||||
|
|
||||||
|
result = await db.list_templates(owner_id="alice", include_builtin=True)
|
||||||
|
assert result == []
|
||||||
|
# 确认 execute 被调用(OR 条件路径走通)
|
||||||
|
session.execute.assert_awaited_once()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_list_owner_only():
|
||||||
|
db, session = _make_db()
|
||||||
|
execute_result = MagicMock()
|
||||||
|
execute_result.scalars.return_value.all.return_value = []
|
||||||
|
session.execute = AsyncMock(return_value=execute_result)
|
||||||
|
|
||||||
|
await db.list_templates(owner_id="alice", include_builtin=False)
|
||||||
|
session.execute.assert_awaited_once()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_list_builtin_only():
|
||||||
|
db, session = _make_db()
|
||||||
|
execute_result = MagicMock()
|
||||||
|
execute_result.scalars.return_value.all.return_value = []
|
||||||
|
session.execute = AsyncMock(return_value=execute_result)
|
||||||
|
|
||||||
|
await db.list_templates(owner_id=None, include_builtin=True)
|
||||||
|
session.execute.assert_awaited_once()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_delete_nonexistent_returns_false():
|
||||||
|
db, session = _make_db()
|
||||||
|
execute_result = MagicMock()
|
||||||
|
execute_result.scalar_one_or_none.return_value = None
|
||||||
|
session.execute = AsyncMock(return_value=execute_result)
|
||||||
|
|
||||||
|
result = await db.delete_template("missing")
|
||||||
|
assert result is False
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_update_nonexistent_returns_none():
|
||||||
|
db, session = _make_db()
|
||||||
|
execute_result = MagicMock()
|
||||||
|
execute_result.scalar_one_or_none.return_value = None
|
||||||
|
session.execute = AsyncMock(return_value=execute_result)
|
||||||
|
|
||||||
|
result = await db.update_template("missing", name="new")
|
||||||
|
assert result is None
|
||||||
Reference in New Issue
Block a user