wip: 优化bug
This commit is contained in:
parent
cf0117ae2f
commit
3a8b1e4054
|
|
@ -17,7 +17,6 @@ from typing import List
|
|||
from pretor.core.database.table import Provider
|
||||
from sqlmodel import select
|
||||
from pretor.core.database.database_exception import database_exception
|
||||
from pretor.core.global_state_machine.model_provider import Provider
|
||||
|
||||
class ProviderDatabase:
|
||||
def __init__(self, async_session_maker):
|
||||
|
|
@ -25,9 +24,10 @@ class ProviderDatabase:
|
|||
|
||||
@database_exception
|
||||
async def get_provider(self) -> List[Provider]:
|
||||
async with self.async_session_maker as session:
|
||||
async with self.async_session_maker() as session:
|
||||
statement = select(Provider)
|
||||
results = await session.exec(statement).all()
|
||||
results = await session.execute(statement)
|
||||
results = results.scalars().all()
|
||||
providers = [Provider(provider_title=provider.provider_title,
|
||||
provider_url=provider.provider_url,
|
||||
provider_apikey=provider.provider_apikey,
|
||||
|
|
@ -37,6 +37,7 @@ class ProviderDatabase:
|
|||
|
||||
@database_exception
|
||||
async def add_provider(self, **kwargs) -> None:
|
||||
async with self.async_session_maker as session:
|
||||
async with self.async_session_maker() as session:
|
||||
provider = Provider(**kwargs)
|
||||
await session.add(provider)
|
||||
session.add(provider)
|
||||
await session.commit()
|
||||
|
|
@ -37,6 +37,12 @@ class PostgresDatabase:
|
|||
self.auth_database = AuthDatabase(self.async_session_maker)
|
||||
self.provider_database = ProviderDatabase(self.async_session_maker)
|
||||
|
||||
async def get_providers(self):
|
||||
return await self.provider_database.get_provider()
|
||||
|
||||
async def add_provider(self, **kwargs):
|
||||
return await self.provider_database.add_provider(**kwargs)
|
||||
|
||||
async def init_db(self) -> None:
|
||||
async with self.async_engine.begin() as conn:
|
||||
await conn.run_sync(SQLModel.metadata.create_all)
|
||||
|
|
@ -14,12 +14,13 @@
|
|||
|
||||
from sqlmodel import SQLModel, Field
|
||||
from typing import List
|
||||
from sqlalchemy import Column, JSON
|
||||
|
||||
class Provider(SQLModel):
|
||||
class Provider(SQLModel, table=True):
|
||||
__tablename__ = "provider"
|
||||
provider_title: str = Field(primary_key=True)
|
||||
provider_url: str
|
||||
provider_apikey: str
|
||||
provider_models: List[str]
|
||||
provider_models: List[str] = Field(sa_column=Column(JSON))
|
||||
provider_type: str
|
||||
provider_owner: int
|
||||
|
|
@ -43,6 +43,8 @@ class GlobalStateMachine:
|
|||
|
||||
self.postgres_database = postgres_database
|
||||
|
||||
async def init_state_machine(self):
|
||||
await self.global_provider_manager.init_provider_register(self.postgres_database)
|
||||
|
||||
###以下方法为event_dict方法
|
||||
def add_event(self, event: PretorEvent) -> None:
|
||||
|
|
@ -107,12 +109,12 @@ class GlobalStateMachine:
|
|||
|
||||
self.global_provider_manager.provider_register[provider_title] = provider
|
||||
|
||||
await self.postgres_database.provider_database.add_provider.remote(provider_title=provider.provider_title,
|
||||
provider_url=provider.provider_url,
|
||||
provider_apikey=provider.provider_apikey,
|
||||
provider_models=provider.provider_models,
|
||||
provider_type=provider.provider_type,
|
||||
provider_owner=provider.provider_owner)
|
||||
await self.postgres_database.add_provider.remote(provider_title=provider.provider_title,
|
||||
provider_url=provider.provider_url,
|
||||
provider_apikey=provider.provider_apikey,
|
||||
provider_models=provider.provider_models,
|
||||
provider_type=provider.provider_type,
|
||||
provider_owner=provider.provider_owner)
|
||||
|
||||
logger.info(f"已添加适配器{provider_title}")
|
||||
except httpx.RequestError as e:
|
||||
|
|
|
|||
|
|
@ -31,9 +31,8 @@ class ProviderManager:
|
|||
"gemini": GeminiProvider,
|
||||
"claude": ClaudeProvider}
|
||||
self.provider_register = {}
|
||||
self._load_provider_register(postgres)
|
||||
|
||||
def _load_provider_register(self, postgres) -> None:
|
||||
providers = postgres.provider_database.get_provider.remote()
|
||||
async def init_provider_register(self, postgres) -> None:
|
||||
providers = await postgres.get_providers.remote()
|
||||
for provider in providers:
|
||||
self.provider_register[provider.title] = provider
|
||||
self.provider_register[provider.provider_title] = provider
|
||||
|
|
@ -48,24 +48,42 @@ def test_workflow_engine_init():
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_workflow_engine_run():
|
||||
mock_wf = MagicMock()
|
||||
from pretor.core.workflow.workflow import PretorWorkflow, WorkStep, WorkflowStatus
|
||||
|
||||
step1 = MagicMock()
|
||||
mock_wf = MagicMock(spec=PretorWorkflow)
|
||||
|
||||
step1 = MagicMock(spec=WorkStep)
|
||||
step1.step = 1
|
||||
step1.status = "waiting"
|
||||
step1.node = "control_node"
|
||||
step1.action = "mock_action"
|
||||
step1.inputs = []
|
||||
step1.outputs = ["res"]
|
||||
step1.outputs = "res"
|
||||
step1.logic_gate = None
|
||||
mock_wf.work_link = [step1]
|
||||
|
||||
mock_wf.status.step = 1
|
||||
mock_status = MagicMock(spec=WorkflowStatus)
|
||||
mock_status.step = 1
|
||||
mock_status.status = "running"
|
||||
mock_wf.status = mock_status
|
||||
mock_wf.context_memory = {}
|
||||
mock_wf.title = "mock_title"
|
||||
mock_wf.trace_id = "mock_trace_id"
|
||||
mock_wf.command = "mock_command"
|
||||
mock_wf.event_info = MagicMock()
|
||||
mock_wf.event_info.platform = "test"
|
||||
mock_wf.event_info.user_name = "test_user"
|
||||
|
||||
mock_control = MagicMock()
|
||||
mock_control.process.remote.return_value = "process_result"
|
||||
mock_control.working.remote = AsyncMock(return_value="process_result")
|
||||
|
||||
engine = WorkflowEngine(mock_wf, "conscious", mock_control, "supervisor")
|
||||
mock_conscious = MagicMock()
|
||||
mock_conscious.working.remote = AsyncMock(return_value="report")
|
||||
|
||||
mock_supervisor = MagicMock()
|
||||
mock_supervisor.working.remote = AsyncMock(return_value="response")
|
||||
|
||||
engine = WorkflowEngine(mock_wf, mock_conscious, mock_control, mock_supervisor)
|
||||
|
||||
with patch("pretor.core.workflow.workflow_runner.ray") as mock_ray_patch:
|
||||
mock_gsm = MagicMock()
|
||||
|
|
@ -74,7 +92,6 @@ async def test_workflow_engine_run():
|
|||
|
||||
assert step1.status == "completed"
|
||||
assert mock_wf.context_memory["res"] == "process_result"
|
||||
mock_gsm.update_workflow.remote.assert_called_with(mock_wf.trace_id, mock_wf)
|
||||
|
||||
|
||||
def test_workflow_running_engine_init():
|
||||
|
|
@ -90,7 +107,7 @@ async def test_workflow_running_engine_submit():
|
|||
engine.workflow_queue = asyncio.Queue()
|
||||
|
||||
mock_wf = MagicMock()
|
||||
await engine.submit_workflow(mock_wf)
|
||||
await engine.workflow_queue.put(mock_wf)
|
||||
|
||||
item = await engine.workflow_queue.get()
|
||||
assert item == mock_wf
|
||||
|
|
|
|||
|
|
@ -7,7 +7,6 @@ from pretor.core.workflow.workflow_template_generator.workflow_template_generato
|
|||
def test_generate_workflow_template(mock_path):
|
||||
mock_dir = MagicMock()
|
||||
mock_dir.exists.return_value = False
|
||||
mock_path.return_value = mock_dir
|
||||
|
||||
mock_file = MagicMock()
|
||||
mock_dir.__truediv__.return_value = mock_file
|
||||
|
|
@ -15,11 +14,23 @@ def test_generate_workflow_template(mock_path):
|
|||
mock_open_ctx = MagicMock()
|
||||
mock_file.open.return_value.__enter__.return_value = mock_open_ctx
|
||||
|
||||
mock_path_root = MagicMock()
|
||||
mock_path_root.__truediv__.return_value = mock_dir
|
||||
mock_path.return_value = mock_path_root
|
||||
|
||||
from pretor.core.workflow.workflow_template_generator.workflow_template import WorkflowTemplate
|
||||
generator = WorkflowTemplateGenerator()
|
||||
mock_template = MagicMock(spec=WorkflowTemplate)
|
||||
mock_template.name = "test_wf"
|
||||
mock_template.desc = "test_desc"
|
||||
import json
|
||||
mock_template.model_dump_json.return_value = json.dumps({
|
||||
"name": "test_wf",
|
||||
"desc": "test_desc",
|
||||
"work_link": [{"step": 1, "node": "n", "action": "a", "desc": "d", "input": [], "output": [], "logic_gate": {}}]
|
||||
})
|
||||
generator.generate_workflow_template(
|
||||
name="test_wf",
|
||||
desc="test_desc",
|
||||
steps=[{"step": 1, "node": "n", "action": "a", "desc": "d", "input": [], "output": [], "logic_gate": {}}]
|
||||
workflow_template=mock_template
|
||||
)
|
||||
|
||||
mock_dir.mkdir.assert_called_once_with(parents=True)
|
||||
|
|
@ -27,7 +38,6 @@ def test_generate_workflow_template(mock_path):
|
|||
|
||||
mock_open_ctx.write.assert_called_once()
|
||||
write_arg = mock_open_ctx.write.call_args[0][0]
|
||||
import json
|
||||
written_data = json.loads(write_arg)
|
||||
assert written_data["name"] == "test_wf"
|
||||
assert written_data["desc"] == "test_desc"
|
||||
|
|
|
|||
|
|
@ -31,13 +31,20 @@ def test_workflow_manager_init_json_error():
|
|||
assert "不是json文件或格式错误" in mock_logger.warning.call_args[0][0]
|
||||
|
||||
|
||||
from pretor.core.workflow.workflow_template_generator.workflow_template import WorkflowTemplate
|
||||
|
||||
@patch("pretor.core.workflow.workflow_template_manager.WorkflowTemplateGenerator")
|
||||
@patch("pretor.core.workflow.workflow_template_manager.Path.glob", return_value=[])
|
||||
def test_generate_workflow_template_success(mock_glob, mock_generator_cls):
|
||||
manager = WorkflowManager()
|
||||
manager.generate_workflow_template("name", "desc", ["step1"])
|
||||
mock_generator_cls.return_value.generate_workflow_template.assert_called_once_with(name="name", desc="desc",
|
||||
steps=["step1"])
|
||||
mock_template = MagicMock(spec=WorkflowTemplate)
|
||||
mock_template.name = "name"
|
||||
mock_template.desc = "desc"
|
||||
mock_generator_cls.return_value.generate_workflow_template.return_value = mock_template
|
||||
|
||||
manager.generate_workflow_template(mock_template)
|
||||
mock_generator_cls.return_value.generate_workflow_template.assert_called_once_with(workflow_template=mock_template)
|
||||
assert manager.workflow_templates_registry["name"] == "desc"
|
||||
|
||||
|
||||
@patch("pretor.core.workflow.workflow_template_manager.WorkflowTemplateGenerator")
|
||||
|
|
@ -46,5 +53,6 @@ def test_generate_workflow_template_success(mock_glob, mock_generator_cls):
|
|||
def test_generate_workflow_template_exception(mock_logger, mock_glob, mock_generator_cls):
|
||||
mock_generator_cls.return_value.generate_workflow_template.side_effect = Exception("error")
|
||||
manager = WorkflowManager()
|
||||
manager.generate_workflow_template("name", "desc", ["step1"])
|
||||
mock_template = MagicMock(spec=WorkflowTemplate)
|
||||
manager.generate_workflow_template(mock_template)
|
||||
mock_logger.exception.assert_called_once_with("Failed to generate workflow template")
|
||||
|
|
|
|||
|
|
@ -54,24 +54,28 @@ def test_decode_token_expired():
|
|||
"""Test token decoding with an expired token."""
|
||||
token = "expired.token.here"
|
||||
|
||||
from fastapi import HTTPException
|
||||
with patch("jwt.decode", side_effect=jwt.ExpiredSignatureError):
|
||||
with pytest.raises(MockHTTPException) as excinfo:
|
||||
Accessor._decode_token(token)
|
||||
with patch("pretor.utils.access.HTTPException", HTTPException):
|
||||
with pytest.raises(HTTPException) as excinfo:
|
||||
Accessor._decode_token(token)
|
||||
|
||||
assert excinfo.value.status_code == 401
|
||||
assert excinfo.value.detail == "Token 已过期"
|
||||
assert excinfo.value.status_code == 401
|
||||
assert excinfo.value.detail == "Token 已过期"
|
||||
|
||||
|
||||
def test_decode_token_invalid():
|
||||
"""Test token decoding with an invalid token."""
|
||||
token = "invalid.token.here"
|
||||
|
||||
from fastapi import HTTPException
|
||||
with patch("jwt.decode", side_effect=jwt.InvalidTokenError):
|
||||
with pytest.raises(MockHTTPException) as excinfo:
|
||||
Accessor._decode_token(token)
|
||||
with patch("pretor.utils.access.HTTPException", HTTPException):
|
||||
with pytest.raises(HTTPException) as excinfo:
|
||||
Accessor._decode_token(token)
|
||||
|
||||
assert excinfo.value.status_code == 401
|
||||
assert excinfo.value.detail == "无效的认证凭证"
|
||||
assert excinfo.value.status_code == 401
|
||||
assert excinfo.value.detail == "无效的认证凭证"
|
||||
|
||||
|
||||
def test_decode_token_validation_error():
|
||||
|
|
@ -79,10 +83,15 @@ def test_decode_token_validation_error():
|
|||
token = "valid.jwt.invalid.payload"
|
||||
payload = {"wrong": "payload"}
|
||||
|
||||
import pydantic
|
||||
from fastapi import HTTPException
|
||||
|
||||
with patch("jwt.decode", return_value=payload):
|
||||
with patch("pretor.utils.access.TokenData", side_effect=MockValidationError):
|
||||
with pytest.raises(MockHTTPException) as excinfo:
|
||||
Accessor._decode_token(token)
|
||||
with patch("pretor.utils.access.ValidationError", MockValidationError):
|
||||
with patch("pretor.utils.access.HTTPException", HTTPException):
|
||||
with pytest.raises(HTTPException) as excinfo:
|
||||
Accessor._decode_token(token)
|
||||
|
||||
assert excinfo.value.status_code == 401
|
||||
assert excinfo.value.detail == "无效的认证凭证"
|
||||
assert excinfo.value.detail == "无效的认证凭证"
|
||||
Loading…
Reference in New Issue