import os

import ray
from sqlalchemy.engine import URL
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
from sqlalchemy.orm import sessionmaker
from sqlmodel import SQLModel, select

from pretor.core.database.table import User
from pretor.utils.error import UserNotExistError, UserPasswordError
from pretor.core.database.database_exception import database_exception
from pretor.core.database.memory import MemoryRAG


async def _fetch_user(session: AsyncSession, user_name: str) -> User:
    """Load the ``User`` row whose ``user_name`` equals *user_name*.

    Args:
        session: An open SQLAlchemy ``AsyncSession``.
        user_name: The user name to look up.

    Returns:
        The matching ``User`` instance, attached to *session*.

    Raises:
        UserNotExistError: If no row matches *user_name*.
    """
    statement = select(User).where(User.user_name == user_name)
    # SQLAlchemy's AsyncSession exposes ``execute``; it has no SQLModel-style
    # ``exec`` method, so ``execute`` + ``scalar_one_or_none`` is used here.
    result = await session.execute(statement)
    user = result.scalar_one_or_none()
    if user is None:
        raise UserNotExistError()
    return user


@ray.remote
class PostgresDatabase:
    """Ray actor providing async PostgreSQL access for ``User`` records.

    Connection settings are read from the ``POSTGRES_USER``,
    ``POSTGRES_PASSWORD``, ``POSTGRES_HOST``, ``POSTGRES_PORT`` and
    ``POSTGRES_DB`` environment variables when the actor is constructed.
    """

    def __init__(self):
        user = os.environ.get('POSTGRES_USER')
        password = os.environ.get('POSTGRES_PASSWORD')
        host = os.environ.get('POSTGRES_HOST')
        port = os.environ.get('POSTGRES_PORT')
        database = os.environ.get('POSTGRES_DB')
        # Build the URL structurally so credentials containing URL-reserved
        # characters ('@', ':', '/', '%') are escaped rather than corrupting
        # the connection string, as plain f-string interpolation would.
        database_url = URL.create(
            "postgresql+asyncpg",
            username=user,
            password=password,
            host=host,
            port=int(port) if port else None,
            database=database,
        )
        # NOTE(review): echo=True logs every SQL statement together with its
        # bound parameters (including hashed passwords); consider disabling
        # it outside development.
        self.async_engine = create_async_engine(database_url, echo=True)
        # expire_on_commit=False keeps ORM attributes readable after commit,
        # so returned objects stay usable once the session is closed.
        self.async_session_maker = sessionmaker(
            self.async_engine, class_=AsyncSession, expire_on_commit=False
        )
        self.memory = MemoryRAG(self.async_session_maker)

    async def init_db(self) -> None:
        """Create all tables registered on ``SQLModel.metadata`` that are missing."""
        async with self.async_engine.begin() as conn:
            await conn.run_sync(SQLModel.metadata.create_all)

    @database_exception
    async def add_user(self, user_name: str, hashed_password: str) -> User:
        """Insert a new user row and return it after refreshing from the DB.

        Args:
            user_name: Name of the new user.
            hashed_password: The already-hashed password to store.

        Returns:
            The persisted ``User``.
        """
        user = User(user_name=user_name, hashed_password=hashed_password)
        # The maker must be *called* to produce a session; the sessionmaker
        # object itself is not an async context manager.
        async with self.async_session_maker() as session:
            session.add(user)
            await session.commit()
            await session.refresh(user)
            return user

    @database_exception
    async def change_password(self, user_name, old_password, new_password) -> User:
        """Replace a user's stored password after checking the old one.

        Args:
            user_name: Name of the user to update.
            old_password: Value compared against the stored ``hashed_password``.
            new_password: Value stored as the new ``hashed_password``.

        Returns:
            The updated ``User``.

        Raises:
            UserNotExistError: If the user does not exist.
            UserPasswordError: If *old_password* does not match the stored value.
        """
        async with self.async_session_maker() as session:
            user = await _fetch_user(session, user_name)
            # NOTE(review): plain equality against the stored value; this
            # presumes callers pass an already-hashed old_password — confirm.
            if old_password != user.hashed_password:
                raise UserPasswordError()
            user.hashed_password = new_password
            session.add(user)
            await session.commit()
            await session.refresh(user)
            return user

    @database_exception
    async def delete_user(self, user_name: str) -> None:
        """Delete the named user.

        Raises:
            UserNotExistError: If the user does not exist.
        """
        async with self.async_session_maker() as session:
            user = await _fetch_user(session, user_name)
            # AsyncSession.delete is a coroutine; without ``await`` the delete
            # is never scheduled and the commit below persists nothing.
            await session.delete(user)
            await session.commit()

    @database_exception
    async def get_user_password(self, user_name: str) -> str:
        """Return the stored ``hashed_password`` for the named user.

        Raises:
            UserNotExistError: If the user does not exist.
        """
        async with self.async_session_maker() as session:
            user = await _fetch_user(session, user_name)
            return user.hashed_password