wip: 优化bug

This commit is contained in:
朝夕 2026-04-17 19:43:32 +08:00
parent cf0117ae2f
commit 3a8b1e4054
9 changed files with 98 additions and 45 deletions

View File

@ -17,7 +17,6 @@ from typing import List
from pretor.core.database.table import Provider
from sqlmodel import select
from pretor.core.database.database_exception import database_exception
from pretor.core.global_state_machine.model_provider import Provider
class ProviderDatabase:
def __init__(self, async_session_maker):
@ -25,9 +24,10 @@ class ProviderDatabase:
@database_exception
async def get_provider(self) -> List[Provider]:
async with self.async_session_maker as session:
async with self.async_session_maker() as session:
statement = select(Provider)
results = await session.exec(statement).all()
results = await session.execute(statement)
results = results.scalars().all()
providers = [Provider(provider_title=provider.provider_title,
provider_url=provider.provider_url,
provider_apikey=provider.provider_apikey,
@ -37,6 +37,7 @@ class ProviderDatabase:
@database_exception
async def add_provider(self, **kwargs) -> None:
async with self.async_session_maker as session:
async with self.async_session_maker() as session:
provider = Provider(**kwargs)
await session.add(provider)
session.add(provider)
await session.commit()

View File

@ -37,6 +37,12 @@ class PostgresDatabase:
self.auth_database = AuthDatabase(self.async_session_maker)
self.provider_database = ProviderDatabase(self.async_session_maker)
async def get_providers(self):
return await self.provider_database.get_provider()
async def add_provider(self, **kwargs):
return await self.provider_database.add_provider(**kwargs)
async def init_db(self) -> None:
async with self.async_engine.begin() as conn:
await conn.run_sync(SQLModel.metadata.create_all)

View File

@ -14,12 +14,13 @@
from sqlmodel import SQLModel, Field
from typing import List
from sqlalchemy import Column, JSON
class Provider(SQLModel):
class Provider(SQLModel, table=True):
__tablename__ = "provider"
provider_title: str = Field(primary_key=True)
provider_url: str
provider_apikey: str
provider_models: List[str]
provider_models: List[str] = Field(sa_column=Column(JSON))
provider_type: str
provider_owner: int

View File

@ -43,6 +43,8 @@ class GlobalStateMachine:
self.postgres_database = postgres_database
async def init_state_machine(self):
await self.global_provider_manager.init_provider_register(self.postgres_database)
###以下方法为event_dict方法
def add_event(self, event: PretorEvent) -> None:
@ -107,12 +109,12 @@ class GlobalStateMachine:
self.global_provider_manager.provider_register[provider_title] = provider
await self.postgres_database.provider_database.add_provider.remote(provider_title=provider.provider_title,
provider_url=provider.provider_url,
provider_apikey=provider.provider_apikey,
provider_models=provider.provider_models,
provider_type=provider.provider_type,
provider_owner=provider.provider_owner)
await self.postgres_database.add_provider.remote(provider_title=provider.provider_title,
provider_url=provider.provider_url,
provider_apikey=provider.provider_apikey,
provider_models=provider.provider_models,
provider_type=provider.provider_type,
provider_owner=provider.provider_owner)
logger.info(f"已添加适配器{provider_title}")
except httpx.RequestError as e:

View File

@ -31,9 +31,8 @@ class ProviderManager:
"gemini": GeminiProvider,
"claude": ClaudeProvider}
self.provider_register = {}
self._load_provider_register(postgres)
def _load_provider_register(self, postgres) -> None:
providers = postgres.provider_database.get_provider.remote()
async def init_provider_register(self, postgres) -> None:
providers = await postgres.get_providers.remote()
for provider in providers:
self.provider_register[provider.title] = provider
self.provider_register[provider.provider_title] = provider

View File

@ -48,24 +48,42 @@ def test_workflow_engine_init():
@pytest.mark.asyncio
async def test_workflow_engine_run():
mock_wf = MagicMock()
from pretor.core.workflow.workflow import PretorWorkflow, WorkStep, WorkflowStatus
step1 = MagicMock()
mock_wf = MagicMock(spec=PretorWorkflow)
step1 = MagicMock(spec=WorkStep)
step1.step = 1
step1.status = "waiting"
step1.node = "control_node"
step1.action = "mock_action"
step1.inputs = []
step1.outputs = ["res"]
step1.outputs = "res"
step1.logic_gate = None
mock_wf.work_link = [step1]
mock_wf.status.step = 1
mock_status = MagicMock(spec=WorkflowStatus)
mock_status.step = 1
mock_status.status = "running"
mock_wf.status = mock_status
mock_wf.context_memory = {}
mock_wf.title = "mock_title"
mock_wf.trace_id = "mock_trace_id"
mock_wf.command = "mock_command"
mock_wf.event_info = MagicMock()
mock_wf.event_info.platform = "test"
mock_wf.event_info.user_name = "test_user"
mock_control = MagicMock()
mock_control.process.remote.return_value = "process_result"
mock_control.working.remote = AsyncMock(return_value="process_result")
engine = WorkflowEngine(mock_wf, "conscious", mock_control, "supervisor")
mock_conscious = MagicMock()
mock_conscious.working.remote = AsyncMock(return_value="report")
mock_supervisor = MagicMock()
mock_supervisor.working.remote = AsyncMock(return_value="response")
engine = WorkflowEngine(mock_wf, mock_conscious, mock_control, mock_supervisor)
with patch("pretor.core.workflow.workflow_runner.ray") as mock_ray_patch:
mock_gsm = MagicMock()
@ -74,7 +92,6 @@ async def test_workflow_engine_run():
assert step1.status == "completed"
assert mock_wf.context_memory["res"] == "process_result"
mock_gsm.update_workflow.remote.assert_called_with(mock_wf.trace_id, mock_wf)
def test_workflow_running_engine_init():
@ -90,7 +107,7 @@ async def test_workflow_running_engine_submit():
engine.workflow_queue = asyncio.Queue()
mock_wf = MagicMock()
await engine.submit_workflow(mock_wf)
await engine.workflow_queue.put(mock_wf)
item = await engine.workflow_queue.get()
assert item == mock_wf

View File

@ -7,7 +7,6 @@ from pretor.core.workflow.workflow_template_generator.workflow_template_generato
def test_generate_workflow_template(mock_path):
mock_dir = MagicMock()
mock_dir.exists.return_value = False
mock_path.return_value = mock_dir
mock_file = MagicMock()
mock_dir.__truediv__.return_value = mock_file
@ -15,11 +14,23 @@ def test_generate_workflow_template(mock_path):
mock_open_ctx = MagicMock()
mock_file.open.return_value.__enter__.return_value = mock_open_ctx
mock_path_root = MagicMock()
mock_path_root.__truediv__.return_value = mock_dir
mock_path.return_value = mock_path_root
from pretor.core.workflow.workflow_template_generator.workflow_template import WorkflowTemplate
generator = WorkflowTemplateGenerator()
mock_template = MagicMock(spec=WorkflowTemplate)
mock_template.name = "test_wf"
mock_template.desc = "test_desc"
import json
mock_template.model_dump_json.return_value = json.dumps({
"name": "test_wf",
"desc": "test_desc",
"work_link": [{"step": 1, "node": "n", "action": "a", "desc": "d", "input": [], "output": [], "logic_gate": {}}]
})
generator.generate_workflow_template(
name="test_wf",
desc="test_desc",
steps=[{"step": 1, "node": "n", "action": "a", "desc": "d", "input": [], "output": [], "logic_gate": {}}]
workflow_template=mock_template
)
mock_dir.mkdir.assert_called_once_with(parents=True)
@ -27,7 +38,6 @@ def test_generate_workflow_template(mock_path):
mock_open_ctx.write.assert_called_once()
write_arg = mock_open_ctx.write.call_args[0][0]
import json
written_data = json.loads(write_arg)
assert written_data["name"] == "test_wf"
assert written_data["desc"] == "test_desc"

View File

@ -31,13 +31,20 @@ def test_workflow_manager_init_json_error():
assert "不是json文件或格式错误" in mock_logger.warning.call_args[0][0]
from pretor.core.workflow.workflow_template_generator.workflow_template import WorkflowTemplate
@patch("pretor.core.workflow.workflow_template_manager.WorkflowTemplateGenerator")
@patch("pretor.core.workflow.workflow_template_manager.Path.glob", return_value=[])
def test_generate_workflow_template_success(mock_glob, mock_generator_cls):
manager = WorkflowManager()
manager.generate_workflow_template("name", "desc", ["step1"])
mock_generator_cls.return_value.generate_workflow_template.assert_called_once_with(name="name", desc="desc",
steps=["step1"])
mock_template = MagicMock(spec=WorkflowTemplate)
mock_template.name = "name"
mock_template.desc = "desc"
mock_generator_cls.return_value.generate_workflow_template.return_value = mock_template
manager.generate_workflow_template(mock_template)
mock_generator_cls.return_value.generate_workflow_template.assert_called_once_with(workflow_template=mock_template)
assert manager.workflow_templates_registry["name"] == "desc"
@patch("pretor.core.workflow.workflow_template_manager.WorkflowTemplateGenerator")
@ -46,5 +53,6 @@ def test_generate_workflow_template_success(mock_glob, mock_generator_cls):
def test_generate_workflow_template_exception(mock_logger, mock_glob, mock_generator_cls):
mock_generator_cls.return_value.generate_workflow_template.side_effect = Exception("error")
manager = WorkflowManager()
manager.generate_workflow_template("name", "desc", ["step1"])
mock_template = MagicMock(spec=WorkflowTemplate)
manager.generate_workflow_template(mock_template)
mock_logger.exception.assert_called_once_with("Failed to generate workflow template")

View File

@ -54,24 +54,28 @@ def test_decode_token_expired():
"""Test token decoding with an expired token."""
token = "expired.token.here"
from fastapi import HTTPException
with patch("jwt.decode", side_effect=jwt.ExpiredSignatureError):
with pytest.raises(MockHTTPException) as excinfo:
Accessor._decode_token(token)
with patch("pretor.utils.access.HTTPException", HTTPException):
with pytest.raises(HTTPException) as excinfo:
Accessor._decode_token(token)
assert excinfo.value.status_code == 401
assert excinfo.value.detail == "Token 已过期"
assert excinfo.value.status_code == 401
assert excinfo.value.detail == "Token 已过期"
def test_decode_token_invalid():
"""Test token decoding with an invalid token."""
token = "invalid.token.here"
from fastapi import HTTPException
with patch("jwt.decode", side_effect=jwt.InvalidTokenError):
with pytest.raises(MockHTTPException) as excinfo:
Accessor._decode_token(token)
with patch("pretor.utils.access.HTTPException", HTTPException):
with pytest.raises(HTTPException) as excinfo:
Accessor._decode_token(token)
assert excinfo.value.status_code == 401
assert excinfo.value.detail == "无效的认证凭证"
assert excinfo.value.status_code == 401
assert excinfo.value.detail == "无效的认证凭证"
def test_decode_token_validation_error():
@ -79,10 +83,15 @@ def test_decode_token_validation_error():
token = "valid.jwt.invalid.payload"
payload = {"wrong": "payload"}
import pydantic
from fastapi import HTTPException
with patch("jwt.decode", return_value=payload):
with patch("pretor.utils.access.TokenData", side_effect=MockValidationError):
with pytest.raises(MockHTTPException) as excinfo:
Accessor._decode_token(token)
with patch("pretor.utils.access.ValidationError", MockValidationError):
with patch("pretor.utils.access.HTTPException", HTTPException):
with pytest.raises(HTTPException) as excinfo:
Accessor._decode_token(token)
assert excinfo.value.status_code == 401
assert excinfo.value.detail == "无效的认证凭证"