76 lines
2.9 KiB
Python
76 lines
2.9 KiB
Python
import httpx
|
|
import json
|
|
from typing import List, Dict, Any, AsyncGenerator
|
|
from archonbot.protocol__plugin.model_protocol.modelbase import ModelBase
|
|
|
|
class GeminiAdapter(ModelBase):
    """Async HTTP adapter for a chat-completions style model endpoint.

    Requests are sent to ``{base_url}/models`` and
    ``{base_url}/chat/completions`` and carry an
    ``Authorization: Bearer <api_key>`` header.
    """

    def __init__(self, base_url: str, adapter_title: str, api_key: str):
        """Store the connection settings and normalise the base URL.

        Args:
            base_url: Root URL of the service. Trailing slashes are stripped
                and ``/v1`` is appended unless the URL already ends with it.
            adapter_title: Label stored on the instance as ``adapter_title``.
            api_key: Token sent as the bearer credential on every request.
        """
        # NOTE(review): ModelBase.__init__ is not called here -- confirm the
        # base class needs no initialisation of its own.
        self.adapter_title: str = adapter_title
        self.base_url = base_url.rstrip('/')
        if not self.base_url.endswith('/v1'):
            self.base_url += '/v1'
        self.api_key = api_key
        # Filled in by get_model(); empty until the first successful call.
        self.model_list = []

    async def get_model(self) -> List[str]:
        """Fetch the model ids from ``{base_url}/models`` and cache them.

        Returns:
            The ``id`` of every entry in the response's ``data`` list (an
            empty string for entries without an ``id``). The same list is
            stored in ``self.model_list``.

        Raises:
            httpx.HTTPStatusError: If the server answers with a 4xx/5xx status.
        """
        url = f"{self.base_url}/models"
        headers = {"Authorization": f"Bearer {self.api_key}"}

        async with httpx.AsyncClient(timeout=10.0) as client:
            response = await client.get(url, headers=headers)
            response.raise_for_status()
            data = response.json()
            self.model_list = [m.get("id", "") for m in data.get("data", [])]
            return self.model_list

    async def post_message(
        self,
        model: str,
        messages: List[Dict[str, str]],
        stream: bool = False,
        temperature: float = 0.7,
        max_tokens: int = 4096,
        **kwargs
    ) -> Any:
        """Send a chat request to ``{base_url}/chat/completions``.

        Args:
            model: Value sent as the ``model`` field of the payload.
            messages: Value sent as the ``messages`` field of the payload.
            stream: When false, wait for the full response and return its
                decoded JSON. When true, return an async generator of text
                fragments instead.
            temperature: Value sent as the ``temperature`` field.
            max_tokens: Value sent as the ``max_tokens`` field.
            **kwargs: Extra payload fields. They are merged last, so a key
                that matches one of the named fields overrides it.

        Returns:
            The decoded JSON body when ``stream`` is false; otherwise the
            async generator produced by ``_handle_stream`` (iterate it with
            ``async for`` after awaiting this coroutine).

        Raises:
            httpx.HTTPStatusError: In non-streaming mode, if the server
                answers with a 4xx/5xx status. In streaming mode the request
                is only issued once the generator is iterated, so errors
                surface there.
        """
        url = f"{self.base_url}/chat/completions"
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        payload = {
            "model": model,
            "messages": messages,
            "stream": stream,
            "temperature": temperature,
            "max_tokens": max_tokens,
            **kwargs,
        }

        if stream:
            return self._handle_stream(url, headers, payload)

        # With 144 GB of VRAM or long cloud-side texts a generous timeout is
        # recommended: 120 s overall, 10 s to establish the connection.
        timeout = httpx.Timeout(120.0, connect=10.0)
        async with httpx.AsyncClient(timeout=timeout) as client:
            response = await client.post(url, headers=headers, json=payload)
            response.raise_for_status()
            return response.json()

    # This must be a regular instance method: post_message calls it as
    # ``self._handle_stream(url, headers, payload)``. Decorating it with
    # @staticmethod while keeping ``self`` in the signature made that call
    # fail with a TypeError (three arguments for four parameters).
    async def _handle_stream(self, url: str, headers: dict, payload: dict) -> AsyncGenerator[str, None]:
        """Stream a POST request and yield the text fragments it carries.

        Each response line starting with ``data: `` is decoded as JSON and
        the ``choices[0].delta.content`` value is yielded when it is
        non-empty. Blank lines, the ``data: [DONE]`` marker, lines without
        the prefix, and chunks that cannot be decoded or lack that path are
        skipped.

        Args:
            url: Endpoint to POST to.
            headers: HTTP headers for the request.
            payload: JSON-serialisable request body.

        Yields:
            Non-empty ``content`` values, in arrival order.

        Raises:
            httpx.HTTPStatusError: If the server answers with a 4xx/5xx status.
        """
        # timeout=None disables every httpx timeout for the streaming request.
        async with httpx.AsyncClient(timeout=None) as client:
            async with client.stream("POST", url, headers=headers, json=payload) as response:
                response.raise_for_status()
                async for line in response.aiter_lines():
                    if not line.strip() or line == "data: [DONE]":
                        continue
                    if not line.startswith("data: "):
                        continue

                    try:
                        chunk = json.loads(line[6:])
                    except json.JSONDecodeError:
                        continue

                    try:
                        delta = chunk["choices"][0]["delta"].get("content", "")
                    except (KeyError, IndexError, TypeError, AttributeError):
                        # Chunks with an empty ``choices`` list or a null
                        # ``delta`` carry no text; skip them instead of
                        # aborting the whole stream.
                        continue

                    if delta:
                        yield delta