# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional

from sqlmodel import select

from pretor.core.database.database_exception import database_exception
from pretor.core.database.table.provider import Provider


class ProviderDatabase:
    """Async CRUD access to the ``Provider`` table.

    Every method opens its own short-lived session from
    ``async_session_maker`` and is wrapped by ``database_exception``.
    NOTE(review): the error-handling contract of ``database_exception`` is
    defined elsewhere; confirm what it raises or returns on failure.
    """

    def __init__(self, async_session_maker):
        # Callable returning an async context-manager session (presumably a
        # SQLAlchemy ``async_sessionmaker``; verify against the caller).
        self.async_session_maker = async_session_maker

    @database_exception
    async def get_provider(self) -> List[Provider]:
        """Return every provider row as a new, session-free ``Provider``.

        Only ``provider_title``, ``provider_url``, ``provider_apikey``,
        ``provider_models`` and ``provider_type`` are copied into the new
        objects. NOTE(review): any other column (e.g. a primary key, if it
        is not one of these) is not carried over; confirm that is intended.
        """
        async with self.async_session_maker() as session:
            result = await session.execute(select(Provider))
            rows = result.scalars().all()
            return [
                Provider(
                    provider_title=row.provider_title,
                    provider_url=row.provider_url,
                    provider_apikey=row.provider_apikey,
                    provider_models=row.provider_models,
                    provider_type=row.provider_type,
                )
                for row in rows
            ]

    @database_exception
    async def add_provider(self, **kwargs) -> None:
        """Insert a new ``Provider`` built from ``kwargs`` and commit it."""
        async with self.async_session_maker() as session:
            session.add(Provider(**kwargs))
            await session.commit()

    @database_exception
    async def delete_provider(self, provider_id: str) -> None:
        """Delete the provider whose primary key is ``provider_id``.

        Does nothing if no such provider exists.
        """
        async with self.async_session_maker() as session:
            provider = await session.get(Provider, provider_id)
            if provider is None:
                return
            # AsyncSession.delete is a coroutine: without ``await`` the
            # instance is never marked for deletion and commit() is a no-op.
            await session.delete(provider)
            await session.commit()

    @database_exception
    async def update_provider(self, provider_id: str, **kwargs) -> Optional[Provider]:
        """Set each ``kwargs`` item as an attribute on the matching provider.

        Args:
            provider_id: primary key of the provider to update.
            **kwargs: attribute names and new values; applied via ``setattr``
                without validation.

        Returns:
            The updated, refreshed ``Provider``, or ``None`` if no provider
            has ``provider_id``.
        """
        async with self.async_session_maker() as session:
            provider = await session.get(Provider, provider_id)
            if provider is None:
                return None
            for key, value in kwargs.items():
                setattr(provider, key, value)
            session.add(provider)
            await session.commit()
            # commit() may expire loaded attributes (expire_on_commit);
            # refresh reloads them so the object is readable after the
            # session closes.
            await session.refresh(provider)
            return provider