74 lines
2.4 KiB
Python
74 lines
2.4 KiB
Python
import builtins
import sys
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
|
|
|
|
# Capture the genuine import hook before anything replaces it, so the stub
# below can delegate to it and the caller can restore it afterwards.
real_import = builtins.__import__


def mock_import(name, globals=None, locals=None, fromlist=(), level=0):
    """Drop-in replacement for ``builtins.__import__`` that fakes ``ray``.

    Absolute imports of ``ray`` or of any ``ray.*`` submodule return a fresh
    ``MagicMock`` whose ``remote`` attribute is a pass-through decorator, so
    classes decorated with ``@ray.remote`` or ``@ray.remote(...)`` stay plain
    local classes under test. Every other import, including relative imports
    such as ``from .ray import x``, is delegated to ``real_import``.

    Parameters mirror :func:`builtins.__import__`.
    """
    # A relative import (level > 0) names a sibling module that merely happens
    # to be called "ray"; only level-0 imports can mean the real package.
    wants_ray = level == 0 and (name == "ray" or name.startswith("ray."))
    if not wants_ray:
        return real_import(name, globals, locals, fromlist, level)

    mock_ray = MagicMock()

    def mock_remote(*args, **kwargs):
        # Bare form: ``@ray.remote`` passes the decorated class/function itself.
        if len(args) == 1 and not kwargs and callable(args[0]):
            return args[0]
        # Options form: ``@ray.remote(num_cpus=1)`` must return a decorator.
        return lambda target: target

    mock_ray.remote = mock_remote

    # ``import ray.x.y`` binds the top-level package, whereas
    # ``from ray.x.y import z`` expects the innermost submodule back.
    module = mock_ray
    if fromlist:
        for part in name.split(".")[1:]:
            module = getattr(module, part)
    return module
|
|
|
|
|
|
# Import the module under test with ``ray`` stubbed out. Cached copies of the
# module and of ray are evicted first so the import really re-executes under
# the stub. Names are matched exactly or by dotted prefix: a bare substring
# test for "ray" would also evict unrelated modules such as ``array``,
# ``numpy.core.multiarray`` or ``sqlalchemy.dialects.postgresql.array``.
_STUBBED_PACKAGES = ("pretor.core.database.postgres", "ray")

for _mod in list(sys.modules):
    if any(_mod == pkg or _mod.startswith(pkg + ".") for pkg in _STUBBED_PACKAGES):
        del sys.modules[_mod]

builtins.__import__ = mock_import
try:
    from pretor.core.database.postgres import PostgresDatabase
finally:
    # Always restore the real importer, even when the import above fails;
    # otherwise every later import in the session keeps going through the stub.
    builtins.__import__ = real_import
|
|
|
|
|
|
@patch("pretor.core.database.postgres.create_async_engine")
@patch("pretor.core.database.postgres.sessionmaker")
@patch("pretor.core.database.postgres.AuthDatabase")
@patch("pretor.core.database.postgres.ProviderDatabase")
# NOTE(review): this target resolves through the module's ``os`` attribute,
# presumably the shared ``os`` module, so ``os.environ.get`` is replaced
# process-wide while the test runs. ``patch.dict(os.environ, ...)`` would be a
# narrower alternative if the code only reads the environment.
@patch("pretor.core.database.postgres.os.environ.get")
@pytest.mark.asyncio
async def test_postgres_database(mock_env_get, mock_provider_db, mock_auth_db, mock_sessionmaker, mock_create_engine):
    """Check PostgresDatabase construction, dispatch and schema initialisation.

    Stacked ``patch`` decorators inject their mocks bottom-up, so the first
    parameter belongs to the lowest ``patch``. The test verifies that:

    * the asyncpg engine URL is assembled from the POSTGRES_* variables and the
      engine is created with ``echo=True``;
    * ``AuthDatabase`` and ``ProviderDatabase`` are each instantiated once;
    * ``db.auth_database(<method name>, **kwargs)`` resolves to the result of
      the named method on the ``AuthDatabase`` instance;
    * ``init_db`` hands ``SQLModel.metadata.create_all`` to ``conn.run_sync``
      on the connection obtained from ``engine.begin()``.

    ``mock_sessionmaker`` is not asserted on; patching it only keeps the real
    ``sessionmaker`` from running.
    """
    fake_env = {
        "POSTGRES_USER": "testuser",
        "POSTGRES_PASSWORD": "testpassword",
        "POSTGRES_HOST": "localhost",
        "POSTGRES_PORT": "5432",
        "POSTGRES_DB": "testdb",
    }

    def env_side_effect(key, default=None):
        # Mirror ``os.environ.get(key, default=None)`` so the code under test
        # may pass a fallback value without triggering a TypeError.
        return fake_env.get(key, default)

    mock_env_get.side_effect = env_side_effect

    # ``engine.begin()`` is presumably entered with ``async with``; its
    # ``__aenter__`` yields the connection whose ``run_sync`` is awaited.
    mock_conn = MagicMock()
    mock_conn.run_sync = AsyncMock()

    mock_begin_ctx = MagicMock()
    mock_begin_ctx.__aenter__ = AsyncMock(return_value=mock_conn)
    mock_begin_ctx.__aexit__ = AsyncMock()

    mock_engine = MagicMock()
    mock_engine.begin.return_value = mock_begin_ctx
    mock_create_engine.return_value = mock_engine

    db = PostgresDatabase()

    mock_create_engine.assert_called_once_with(
        "postgresql+asyncpg://testuser:testpassword@localhost:5432/testdb",
        echo=True,
    )
    mock_auth_db.assert_called_once()
    mock_provider_db.assert_called_once()

    mock_auth_db.return_value.get_user_authority = AsyncMock(return_value="test_auth")
    assert await db.auth_database("get_user_authority", user_id="123") == "test_auth"

    with patch("pretor.core.database.postgres.SQLModel.metadata.create_all") as mock_create_all:
        await db.init_db()
        mock_conn.run_sync.assert_called_once_with(mock_create_all)
|