"""Unit tests for ``pretor.core.database.postgres.PostgresDatabase``.

The ``ray`` package is faked while the module under test is imported, so the
tests run without a Ray installation. ``ray.remote`` is stubbed as an identity
decorator, which leaves ``PostgresDatabase`` importable as a plain class.
"""

import builtins
import sys
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

real_import = builtins.__import__


def mock_import(name, globals=None, locals=None, fromlist=(), level=0):
    """Stand-in for ``builtins.__import__`` that fakes the ``ray`` package.

    An import whose requested name is exactly ``"ray"`` returns a fresh
    ``MagicMock`` whose ``remote`` attribute is an identity decorator. Every
    other import is delegated unchanged to the real import machinery.
    """
    if name == "ray":
        mock_ray = MagicMock()

        def mock_remote(cls):
            # Identity decorator: ``@ray.remote`` leaves the class untouched.
            return cls

        mock_ray.remote = mock_remote
        return mock_ray
    return real_import(name, globals, locals, fromlist, level)


def _must_reimport(module_name):
    """Return True if a cached module must be evicted before the import.

    Only ``ray`` itself, ``ray.*`` submodules, and the module under test (plus
    its submodules) qualify. A plain substring check (``"ray" in name``) would
    also evict unrelated modules such as ``array`` or ``numpy.core.multiarray``,
    and re-importing C extensions that way can fail.
    """
    return any(
        module_name == root or module_name.startswith(root + ".")
        for root in ("ray", "pretor.core.database.postgres")
    )


# Evict stale copies so the import below goes through ``mock_import``.
for _cached in [m for m in sys.modules if _must_reimport(m)]:
    del sys.modules[_cached]

# ``patch.object`` restores the real ``__import__`` even if this import raises.
# Otherwise a failed import would leave every later import in the session hooked.
with patch.object(builtins, "__import__", mock_import):
    from pretor.core.database.postgres import PostgresDatabase


@patch("pretor.core.database.postgres.create_async_engine")
@patch("pretor.core.database.postgres.sessionmaker")
@patch("pretor.core.database.postgres.AuthDatabase")
@patch("pretor.core.database.postgres.ProviderDatabase")
@patch("pretor.core.database.postgres.os.environ.get")
@pytest.mark.asyncio
async def test_postgres_database(
    mock_env_get, mock_provider_db, mock_auth_db, mock_sessionmaker, mock_create_engine
):
    """Exercise construction, auth dispatch and schema creation with mocks.

    Checks that:
    - the asyncpg URL is built from the ``POSTGRES_*`` variables and passed to
      ``create_async_engine`` with ``echo=True``;
    - ``AuthDatabase`` and ``ProviderDatabase`` are each instantiated once;
    - ``auth_database("get_user_authority", ...)`` returns the result of the
      mocked coroutine (presumably it forwards to the named ``AuthDatabase``
      method — confirm against the implementation);
    - ``init_db`` awaits ``conn.run_sync(SQLModel.metadata.create_all)`` on the
      connection yielded by ``engine.begin()``.
    """
    env = {
        "POSTGRES_USER": "testuser",
        "POSTGRES_PASSWORD": "testpassword",
        "POSTGRES_HOST": "localhost",
        "POSTGRES_PORT": "5432",
        "POSTGRES_DB": "testdb",
    }

    def env_side_effect(key, default=None):
        # Mirror the real ``os.environ.get(key, default=None)`` signature.
        # The patch is process-wide, so any caller passing a fallback would
        # otherwise raise TypeError.
        return env.get(key, default)

    mock_env_get.side_effect = env_side_effect

    mock_engine = MagicMock()
    mock_conn = MagicMock()
    mock_conn.run_sync = AsyncMock()

    # ``engine.begin()`` is used as an async context manager yielding mock_conn.
    mock_begin_ctx = MagicMock()
    mock_begin_ctx.__aenter__ = AsyncMock(return_value=mock_conn)
    mock_begin_ctx.__aexit__ = AsyncMock()
    mock_engine.begin.return_value = mock_begin_ctx
    mock_create_engine.return_value = mock_engine

    db = PostgresDatabase()

    mock_create_engine.assert_called_once_with(
        "postgresql+asyncpg://testuser:testpassword@localhost:5432/testdb",
        echo=True,
    )
    mock_auth_db.assert_called_once()
    mock_provider_db.assert_called_once()

    mock_auth_db.return_value.get_user_authority = AsyncMock(return_value="test_auth")
    assert await db.auth_database("get_user_authority", user_id="123") == "test_auth"

    with patch(
        "pretor.core.database.postgres.SQLModel.metadata.create_all"
    ) as mock_create_all:
        await db.init_db()
        # ``assert_awaited_*`` (not ``assert_called_*``) also catches a missing
        # ``await`` on ``run_sync``.
        mock_conn.run_sync.assert_awaited_once_with(mock_create_all)