73 lines
3.0 KiB
Python
73 lines
3.0 KiB
Python
import os

import ray
from sqlalchemy.engine import URL
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
from sqlmodel import SQLModel, select

from pretor.core.database.database_exception import database_exception
from pretor.core.database.memory import MemoryRAG
from pretor.core.database.table import User
from pretor.utils.error import UserNotExistError, UserPasswordError
|
|
|
|
async def _fetch_user(session: AsyncSession, user_name: str) -> User:
    """Load the ``User`` whose ``user_name`` matches, using ``session``.

    Raises:
        UserNotExistError: if no user with that name exists.
    """
    statement = select(User).where(User.user_name == user_name)
    # The session factory produces SQLAlchemy's AsyncSession, which exposes
    # ``execute`` (``exec`` only exists on sqlmodel's AsyncSession subclass).
    result = await session.execute(statement)
    user = result.scalar_one_or_none()
    if user is None:
        raise UserNotExistError()
    return user


@ray.remote
class PostgresDatabase:
    """Ray actor owning the async PostgreSQL engine and user-table operations.

    Connection settings are read from the ``POSTGRES_USER``,
    ``POSTGRES_PASSWORD``, ``POSTGRES_HOST``, ``POSTGRES_PORT`` and
    ``POSTGRES_DB`` environment variables when the actor is constructed.
    """

    def __init__(self, echo: bool = True):
        """Create the async engine, the session factory and the ``MemoryRAG``.

        Args:
            echo: forwarded to ``create_async_engine``; when true SQLAlchemy
                logs every emitted SQL statement (this was the previous,
                hard-coded behaviour).
        """
        port = os.environ.get('POSTGRES_PORT')
        # URL.create escapes special characters in the credentials
        # ("@", ":", "/", ...) that plain f-string interpolation would
        # leave in place and thereby corrupt the URL.
        database_url = URL.create(
            "postgresql+asyncpg",
            username=os.environ.get('POSTGRES_USER'),
            password=os.environ.get('POSTGRES_PASSWORD'),
            host=os.environ.get('POSTGRES_HOST'),
            port=int(port) if port else None,
            database=os.environ.get('POSTGRES_DB'),
        )
        self.async_engine = create_async_engine(database_url, echo=echo)
        self.async_session_maker = sessionmaker(
            self.async_engine, class_=AsyncSession, expire_on_commit=False
        )
        self.memory = MemoryRAG(self.async_session_maker)

    async def init_db(self) -> None:
        """Create every table registered on ``SQLModel.metadata``."""
        async with self.async_engine.begin() as conn:
            # create_all is a synchronous API; run_sync bridges it onto the
            # async connection.
            await conn.run_sync(SQLModel.metadata.create_all)

    @database_exception
    async def add_user(self, user_name: str, hashed_password: str) -> User:
        """Insert a new user and return it, refreshed from the database."""
        user = User(user_name=user_name, hashed_password=hashed_password)
        # The factory must be *called* to obtain a session; the sessionmaker
        # object itself is not an async context manager.
        async with self.async_session_maker() as session:
            session.add(user)
            await session.commit()
            await session.refresh(user)
        return user

    @database_exception
    async def change_password(self, user_name, old_password, new_password) -> User:
        """Replace a user's stored password after verifying the current one.

        ``old_password`` is compared directly with the stored
        ``hashed_password`` and ``new_password`` is stored unchanged, so
        callers presumably pass already-hashed values.
        NOTE(review): confirm, and consider ``hmac.compare_digest`` for a
        timing-safe comparison once the argument types are pinned down.

        Raises (before any translation by ``database_exception``):
            UserNotExistError: no user named ``user_name``.
            UserPasswordError: ``old_password`` does not match.
        """
        async with self.async_session_maker() as session:
            user = await _fetch_user(session, user_name)
            if old_password != user.hashed_password:
                raise UserPasswordError()
            # ``user`` was loaded through this session and is already
            # tracked, so the change is flushed on commit.
            user.hashed_password = new_password
            await session.commit()
            await session.refresh(user)
            return user

    @database_exception
    async def delete_user(self, user_name: str) -> None:
        """Delete the user named ``user_name``.

        Raises (before any translation by ``database_exception``):
            UserNotExistError: no user named ``user_name``.
        """
        async with self.async_session_maker() as session:
            user = await _fetch_user(session, user_name)
            # AsyncSession.delete is a coroutine: without ``await`` the row
            # is never marked for deletion and the commit does nothing.
            await session.delete(user)
            await session.commit()

    @database_exception
    async def get_user_password(self, user_name: str) -> str:
        """Return the stored ``hashed_password`` of the user named ``user_name``.

        Raises (before any translation by ``database_exception``):
            UserNotExistError: no user named ``user_name``.
        """
        async with self.async_session_maker() as session:
            user = await _fetch_user(session, user_name)
            return user.hashed_password
|
|
|