# File: Pretor/pretor/core/database/postgres.py
# (73 lines, 3.0 KiB, Python)

import os

import ray
from sqlalchemy.engine import URL
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
from sqlalchemy.orm import sessionmaker
from sqlmodel import SQLModel, select

from pretor.core.database.table import User
from pretor.utils.error import UserNotExistError, UserPasswordError
from pretor.core.database.database_exception import database_exception
from pretor.core.database.memory import MemoryRAG

@ray.remote
class PostgresDatabase:
    """Ray actor that owns the async SQLAlchemy engine for the user table.

    Connection settings are read from the ``POSTGRES_USER``,
    ``POSTGRES_PASSWORD``, ``POSTGRES_HOST``, ``POSTGRES_PORT`` and
    ``POSTGRES_DB`` environment variables when the actor is constructed.
    """

    def __init__(self):
        user = os.environ.get('POSTGRES_USER')
        password = os.environ.get('POSTGRES_PASSWORD')
        host = os.environ.get('POSTGRES_HOST')
        port = os.environ.get('POSTGRES_PORT')
        database = os.environ.get('POSTGRES_DB')
        # URL.create escapes reserved characters ('@', ':', '/', ...) in the
        # credentials; interpolating them into an f-string would corrupt the
        # URL, and unset variables would become the literal text "None".
        database_url = URL.create(
            drivername="postgresql+asyncpg",
            username=user,
            password=password,
            host=host,
            port=int(port) if port else None,
            database=database,
        )
        # NOTE(review): per the SQLAlchemy docs, echo=True logs every statement
        # together with its bound parameters (including hashed passwords) —
        # confirm this is intended outside development.
        self.async_engine = create_async_engine(database_url, echo=True)
        # expire_on_commit=False keeps loaded attributes after commit, so the
        # User objects returned below stay readable once their session closes.
        self.async_session_maker = sessionmaker(
            self.async_engine, class_=AsyncSession, expire_on_commit=False
        )
        self.memory = MemoryRAG(self.async_session_maker)

    async def init_db(self) -> None:
        """Create every table registered on ``SQLModel.metadata``."""
        async with self.async_engine.begin() as conn:
            # create_all is synchronous; run_sync bridges it onto the async connection.
            await conn.run_sync(SQLModel.metadata.create_all)

    async def _get_user_or_raise(self, session: AsyncSession, user_name: str) -> User:
        """Load the User named ``user_name`` within ``session``.

        Raises:
            UserNotExistError: if no row matches ``user_name``.
        """
        statement = select(User).where(User.user_name == user_name)
        # AsyncSession (sqlalchemy.ext.asyncio) exposes ``execute``, not the
        # SQLModel-only ``exec``; scalar_one_or_none unwraps the single User.
        result = await session.execute(statement)
        user = result.scalar_one_or_none()
        if user is None:
            raise UserNotExistError()
        return user

    @database_exception
    async def add_user(self, user_name: str, hashed_password: str) -> User:
        """Insert a new user and return the persisted, refreshed row."""
        user = User(user_name=user_name, hashed_password=hashed_password)
        # The session maker must be *called* to obtain a session; the maker
        # itself is not an async context manager.
        async with self.async_session_maker() as session:
            session.add(user)
            await session.commit()
            await session.refresh(user)
            return user

    @database_exception
    async def change_password(self, user_name: str, old_password: str, new_password: str) -> User:
        """Replace a user's stored password after checking the old one.

        Raises:
            UserNotExistError: if ``user_name`` is unknown.
            UserPasswordError: if ``old_password`` does not equal the stored value.
        """
        async with self.async_session_maker() as session:
            user = await self._get_user_or_raise(session, user_name)
            # NOTE(review): this compares ``old_password`` directly with the
            # stored hash, so callers presumably pass an already-hashed value —
            # confirm. Consider hmac.compare_digest for a constant-time check.
            if old_password != user.hashed_password:
                raise UserPasswordError()
            # ``user`` was loaded through this session, so it is already tracked.
            user.hashed_password = new_password
            await session.commit()
            await session.refresh(user)
            return user

    @database_exception
    async def delete_user(self, user_name: str) -> None:
        """Delete the user named ``user_name``.

        Raises:
            UserNotExistError: if ``user_name`` is unknown.
        """
        async with self.async_session_maker() as session:
            user = await self._get_user_or_raise(session, user_name)
            # AsyncSession.delete is a coroutine; without await nothing is deleted.
            await session.delete(user)
            await session.commit()

    @database_exception
    async def get_user_password(self, user_name: str) -> str:
        """Return the stored ``hashed_password`` for ``user_name``.

        Raises:
            UserNotExistError: if ``user_name`` is unknown.
        """
        async with self.async_session_maker() as session:
            user = await self._get_user_or_raise(session, user_name)
            return user.hashed_password