Pretor/archonbot/protocol__plugin/model_protocol/gemini.py

76 lines
2.9 KiB
Python

import httpx
import json
from typing import List, Dict, Any, AsyncGenerator
from archonbot.protocol__plugin.model_protocol.modelbase import ModelBase
class GeminiAdapter(ModelBase):
    """Adapter for a chat model served behind an OpenAI-compatible HTTP API.

    Requests go to ``{base_url}/models`` and ``{base_url}/chat/completions``
    and authenticate with an ``Authorization: Bearer <api_key>`` header.
    """

    def __init__(self, base_url: str, adapter_title: str, api_key: str):
        """Store connection settings.

        Args:
            base_url: Root URL of the API. A trailing ``/`` is removed and
                ``/v1`` is appended unless the URL already ends with ``/v1``.
            adapter_title: Title for this adapter, stored as-is.
            api_key: Token sent as a bearer credential on every request.
        """
        # NOTE(review): ModelBase.__init__ is not called here; confirm the base
        # class needs no initialisation of its own.
        self.adapter_title: str = adapter_title
        self.base_url = base_url.rstrip('/')
        # NOTE(review): Google's own OpenAI-compatible Gemini endpoint is rooted
        # at ".../v1beta/openai"; appending "/v1" to that would produce a wrong
        # URL. Confirm callers pass a proxy/gateway root that expects "/v1".
        if not self.base_url.endswith('/v1'):
            self.base_url += '/v1'
        self.api_key = api_key
        # Populated by get_model().
        self.model_list: List[str] = []

    async def get_model(self) -> List[str]:
        """Fetch the model ids advertised by the server and cache them.

        Returns:
            The ``id`` of every entry in the response's ``data`` list (an
            entry without ``id`` contributes ``""``). The same list is also
            stored on ``self.model_list``.

        Raises:
            httpx.HTTPStatusError: if the server answers with a 4xx/5xx status.
        """
        url = f"{self.base_url}/models"
        headers = {"Authorization": f"Bearer {self.api_key}"}
        async with httpx.AsyncClient(timeout=10.0) as client:
            response = await client.get(url, headers=headers)
            response.raise_for_status()
            data = response.json()
        self.model_list = [m.get("id", "") for m in data.get("data", [])]
        return self.model_list

    async def post_message(
        self,
        model: str,
        messages: List[Dict[str, str]],
        stream: bool = False,
        temperature: float = 0.7,
        max_tokens: int = 4096,
        **kwargs
    ) -> Any:
        """Send a chat-completion request.

        Args:
            model: Model id to request.
            messages: Chat history, sent verbatim as the ``messages`` field.
            stream: If True, return an async generator of text deltas instead
                of the decoded response body.
            temperature: Sampling temperature.
            max_tokens: Value sent as the ``max_tokens`` field.
            **kwargs: Extra fields merged into the JSON request body.

        Returns:
            The decoded JSON response when ``stream`` is False. Otherwise it
            returns an async generator that yields content strings. In the
            streaming case the HTTP request is only sent once the generator
            is iterated.

        Raises:
            httpx.HTTPStatusError: on a 4xx/5xx status. For streaming, this is
                raised from the generator during iteration.
        """
        url = f"{self.base_url}/chat/completions"
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json"
        }
        payload = {
            "model": model,
            "messages": messages,
            "stream": stream,
            "temperature": temperature,
            "max_tokens": max_tokens,
            **kwargs
        }
        if stream:
            return self._handle_stream(url, headers, payload)
        # Use a long read timeout: large local models or long cloud-side
        # generations can take a long time to return the full response.
        timeout = httpx.Timeout(120.0, connect=10.0)
        async with httpx.AsyncClient(timeout=timeout) as client:
            response = await client.post(url, headers=headers, json=payload)
            response.raise_for_status()
            return response.json()

    # Instance method (not @staticmethod): it is called as
    # self._handle_stream(url, headers, payload), and a staticmethod with a
    # `self` parameter would mis-bind the arguments and raise TypeError.
    async def _handle_stream(self, url: str, headers: dict, payload: dict) -> AsyncGenerator[str, None]:
        """POST ``payload`` as a streaming request and yield content deltas.

        The response is read as a server-sent-event stream. Each ``data:``
        line carries one JSON chunk, and ``data: [DONE]`` ends the stream.
        Chunks that are not valid JSON, or that carry no text content, are
        skipped.
        """
        # NOTE(review): timeout=None disables every httpx timeout, so a stalled
        # server can block this generator indefinitely. Consider a finite read
        # timeout.
        async with httpx.AsyncClient(timeout=None) as client:
            async with client.stream("POST", url, headers=headers, json=payload) as response:
                response.raise_for_status()
                async for line in response.aiter_lines():
                    data = self._sse_data(line)
                    if data is None:
                        # Blank separator, comment, or a non-data SSE field.
                        continue
                    if data.strip() == "[DONE]":
                        break
                    delta = self._extract_delta(data)
                    if delta:
                        yield delta

    @staticmethod
    def _sse_data(line: str):
        """Return the value of an SSE ``data:`` line, or None for any other line."""
        if not line.startswith("data:"):
            return None
        value = line[len("data:"):]
        # The SSE format allows (but does not require) one space after the colon.
        if value.startswith(" "):
            value = value[1:]
        return value

    @staticmethod
    def _extract_delta(data: str) -> str:
        """Return the text content of one chat-completion stream chunk, or ""."""
        try:
            chunk = json.loads(data)
        except json.JSONDecodeError:
            return ""
        try:
            content = chunk["choices"][0]["delta"].get("content")
        except (KeyError, IndexError, TypeError, AttributeError):
            # Examples: a usage-only chunk with "choices": [], a null "delta",
            # or a chunk whose JSON is not an object.
            return ""
        return content if isinstance(content, str) else ""