Compare commits

..

No commits in common. "e6bf9e2ce426e652595d2f98b2c9c70e5fda1684" and "bfdb0db933b066e73c0133f7cb1ec23ab2ef01aa" have entirely different histories.

39 changed files with 427 additions and 747 deletions

View File

@ -1,9 +1,9 @@
import axios from 'axios'; import axios from 'axios';
// The base URL should typically come from an environment variable in a real app. // The base URL should typically come from an environment variable in a real app,
// If missing, defaulting to '' means requests will be relative to the current browser origin. // but for development we can default to localhost.
export const apiClient = axios.create({ export const apiClient = axios.create({
baseURL: import.meta.env.VITE_API_BASE_URL || '', baseURL: import.meta.env.VITE_API_BASE_URL || 'http://localhost:8000',
timeout: 10000, timeout: 10000,
headers: { headers: {
'Content-Type': 'application/json', 'Content-Type': 'application/json',

View File

@ -1,331 +1,285 @@
import { useState, useEffect } from 'react'; import { useState, useEffect } from 'react';
import apiClient from '../../api/client'; import apiClient from '../../api/client';
import { Save, Plus, Edit2, Trash2, X } from 'lucide-react'; import { Bot, Save } from 'lucide-react';
import type { Provider } from '../../types'; import type { Provider } from '../../types';
interface WorkerIndividual { function WorkerIndividualForm({ providers }: { providers: Provider[] }) {
agent_id: string; const [formData, setFormData] = useState({
agent_name: string; agent_name: '',
agent_type: string; agent_type: 'OrdinaryIndividual',
description?: string; description: '',
provider_title: string; provider_title: providers.length > 0 ? providers[0].provider_title : '',
model_id: string; model_id: '',
system_prompt?: string; system_prompt: '',
output_template?: string; // Change to string for the form state output_template: '{}',
bound_skill?: string; // Change to string for the form state bound_skill: '{}',
workspace?: string; // Change to string for the form state workspace: '[]'
} });
const [loading, setLoading] = useState(false);
const [message, setMessage] = useState('');
export function WorkerIndividualSettings() { // Update initial provider_title when providers load
const [providers, setProviders] = useState<Provider[]>([]); useEffect(() => {
const [workers, setWorkers] = useState<WorkerIndividual[]>([]); if (providers.length > 0 && !formData.provider_title) {
const [loading, setLoading] = useState(true); setFormData(prev => ({ ...prev, provider_title: providers[0].provider_title }));
const [error, setError] = useState(''); }
}, [providers, formData.provider_title]);
const [isEditing, setIsEditing] = useState(false); const handleChange = (e: React.ChangeEvent<HTMLInputElement | HTMLSelectElement | HTMLTextAreaElement>) => {
const [editData, setEditData] = useState<Partial<WorkerIndividual>>({}); setFormData({ ...formData, [e.target.name]: e.target.value });
const [isNew, setIsNew] = useState(false); };
const [modalMessage, setModalMessage] = useState(''); const handleSubmit = async (e: React.FormEvent) => {
e.preventDefault();
const fetchData = async () => {
setLoading(true); setLoading(true);
setMessage('');
try { try {
const [provRes, workRes] = await Promise.all([ const payload = {
apiClient.get('/api/v1/provider/list'), ...formData,
apiClient.get('/api/v1/agent/worker') output_template: JSON.parse(formData.output_template),
]); bound_skill: JSON.parse(formData.bound_skill),
setProviders(Object.values(provRes.data.provider_list || {})); workspace: JSON.parse(formData.workspace)
setWorkers(workRes.data.workers || []); };
await apiClient.post('/api/v1/agent/worker', payload);
setMessage('Successfully created worker individual');
setFormData({
agent_name: '',
agent_type: 'OrdinaryIndividual',
description: '',
provider_title: '',
model_id: '',
system_prompt: '',
output_template: '{}',
bound_skill: '{}',
workspace: '[]'
});
} catch (err: any) { } catch (err: any) {
console.error(err); console.error(err);
setError('Failed to load data'); setMessage(err.response?.data?.detail || 'Failed to create worker individual. Ensure JSON fields are valid.');
} finally { } finally {
setLoading(false); setLoading(false);
} }
}; };
return (
<form onSubmit={handleSubmit} className="space-y-4">
<div className="grid grid-cols-1 md:grid-cols-2 gap-4">
<div>
<label className="block text-sm font-medium text-slate-700 mb-1">Agent Name</label>
<input required type="text" name="agent_name" value={formData.agent_name} onChange={handleChange} className="w-full px-4 py-2 border border-slate-200 rounded-lg focus:outline-none focus:ring-2 focus:ring-indigo-500" />
</div>
<div>
<label className="block text-sm font-medium text-slate-700 mb-1">Agent Type</label>
<select name="agent_type" value={formData.agent_type} onChange={handleChange} className="w-full px-4 py-2 border border-slate-200 rounded-lg focus:outline-none focus:ring-2 focus:ring-indigo-500">
<option value="OrdinaryIndividual">OrdinaryIndividual</option>
<option value="SkillIndividual">SkillIndividual</option>
<option value="SpecialIndividual">SpecialIndividual</option>
</select>
</div>
<div>
<label className="block text-sm font-medium text-slate-700 mb-1">Description</label>
<input required type="text" name="description" value={formData.description} onChange={handleChange} className="w-full px-4 py-2 border border-slate-200 rounded-lg focus:outline-none focus:ring-2 focus:ring-indigo-500" />
</div>
<div>
<label className="block text-sm font-medium text-slate-700 mb-1">Provider Title</label>
<select
required
name="provider_title"
value={formData.provider_title}
onChange={handleChange}
className="w-full px-4 py-2 border border-slate-200 rounded-lg focus:outline-none focus:ring-2 focus:ring-indigo-500"
>
{providers.length === 0 ? (
<option value="" disabled>No providers available. Create one first.</option>
) : (
providers.map((p) => (
<option key={p.provider_title} value={p.provider_title}>
{p.provider_title}
</option>
))
)}
</select>
</div>
<div>
<label className="block text-sm font-medium text-slate-700 mb-1">Model ID</label>
<input required type="text" name="model_id" value={formData.model_id} onChange={handleChange} className="w-full px-4 py-2 border border-slate-200 rounded-lg focus:outline-none focus:ring-2 focus:ring-indigo-500" />
</div>
<div className="md:col-span-2">
<label className="block text-sm font-medium text-slate-700 mb-1">System Prompt</label>
<textarea required name="system_prompt" value={formData.system_prompt} onChange={handleChange} rows={3} className="w-full px-4 py-2 border border-slate-200 rounded-lg focus:outline-none focus:ring-2 focus:ring-indigo-500" />
</div>
<div>
<label className="block text-sm font-medium text-slate-700 mb-1">Output Template (JSON dict)</label>
<input required type="text" name="output_template" value={formData.output_template} onChange={handleChange} className="w-full px-4 py-2 border border-slate-200 rounded-lg focus:outline-none focus:ring-2 focus:ring-indigo-500 font-mono text-sm" />
</div>
<div>
<label className="block text-sm font-medium text-slate-700 mb-1">Bound Skill (JSON dict)</label>
<input required type="text" name="bound_skill" value={formData.bound_skill} onChange={handleChange} className="w-full px-4 py-2 border border-slate-200 rounded-lg focus:outline-none focus:ring-2 focus:ring-indigo-500 font-mono text-sm" />
</div>
<div>
<label className="block text-sm font-medium text-slate-700 mb-1">Workspace (JSON list)</label>
<input required type="text" name="workspace" value={formData.workspace} onChange={handleChange} className="w-full px-4 py-2 border border-slate-200 rounded-lg focus:outline-none focus:ring-2 focus:ring-indigo-500 font-mono text-sm" />
</div>
</div>
{message && (
<div className={`p-3 rounded-lg text-sm ${message.includes('Success') ? 'bg-green-50 text-green-700' : 'bg-red-50 text-red-700'}`}>
{message}
</div>
)}
<div className="flex justify-end pt-2">
<button type="submit" disabled={loading} className="flex items-center space-x-2 px-4 py-2 bg-indigo-600 text-white rounded-lg hover:bg-indigo-700 transition-colors disabled:opacity-50">
<Save size={16} />
<span>{loading ? 'Creating...' : 'Create Worker'}</span>
</button>
</div>
</form>
);
}
export function WorkerIndividualSettings() {
const [nodeType, setNodeType] = useState('supervisory_node');
const [providerTitle, setProviderTitle] = useState('');
const [modelId, setModelId] = useState('');
const [loading, setLoading] = useState(false);
const [message, setMessage] = useState('');
const [providers, setProviders] = useState<Provider[]>([]);
useEffect(() => { useEffect(() => {
fetchData(); const fetchProviders = async () => {
try {
const response = await apiClient.get('/api/v1/provider/list');
const data = response.data.provider_list || {};
const providerArray: Provider[] = Object.values(data);
setProviders(providerArray);
if (providerArray.length > 0) {
setProviderTitle(providerArray[0].provider_title);
}
} catch (error) {
console.error("Failed to fetch providers", error);
setProviders([]);
}
};
fetchProviders();
}, []); }, []);
const handleEdit = (worker: any) => { // Accept the backend object which might have objects instead of strings const handleCreateNode = async (e: React.FormEvent) => {
setEditData({
...worker,
output_template: typeof worker.output_template === 'string' ? worker.output_template : JSON.stringify(worker.output_template || {}),
bound_skill: typeof worker.bound_skill === 'string' ? worker.bound_skill : JSON.stringify(worker.bound_skill || {}),
workspace: typeof worker.workspace === 'string' ? worker.workspace : JSON.stringify(worker.workspace || [])
});
setIsNew(false);
setIsEditing(true);
setModalMessage('');
};
const handleAddNew = () => {
setEditData({
agent_name: '',
agent_type: 'OrdinaryIndividual',
description: '',
provider_title: providers.length > 0 ? providers[0].provider_title : '',
model_id: '',
system_prompt: '',
output_template: '{}',
bound_skill: '{}',
workspace: '[]'
});
setIsNew(true);
setIsEditing(true);
setModalMessage('');
};
const handleDelete = async (agent_id: string) => {
if (!confirm('Are you sure you want to delete this agent?')) return;
try {
await apiClient.delete(`/api/v1/agent/worker/${agent_id}`);
fetchData();
} catch (err: any) {
console.error(err);
alert('Failed to delete agent');
}
};
const handleModalSave = async (e: React.FormEvent) => {
e.preventDefault(); e.preventDefault();
setModalMessage(''); setLoading(true);
setMessage('');
try { try {
const payload = { await apiClient.post('/api/v1/agent', {
...editData, provider_title: providerTitle,
output_template: JSON.parse(editData.output_template || '{}'), model_id: modelId,
bound_skill: JSON.parse(editData.bound_skill || '{}'), individual_name: nodeType
workspace: JSON.parse(editData.workspace || '[]') });
}; setMessage(`Successfully loaded ${nodeType}`);
setProviderTitle('');
if (isNew) { setModelId('');
await apiClient.post('/api/v1/agent/worker', payload);
} else {
await apiClient.put(`/api/v1/agent/worker/${editData.agent_id}`, payload);
}
setIsEditing(false);
fetchData();
} catch (err: any) { } catch (err: any) {
console.error(err); console.error(err);
setModalMessage(err.response?.data?.detail || err.message || 'Failed to save'); setMessage(err.response?.data?.detail || 'Failed to load agent node');
} finally {
setLoading(false);
} }
}; };
return ( return (
<div className="max-w-5xl space-y-6 relative"> <div className="max-w-4xl space-y-6">
<div className="mb-8 flex justify-between items-end"> <div className="mb-8">
<div> <h1 className="text-2xl font-bold text-slate-800">Worker Individual Settings</h1>
<h1 className="text-2xl font-bold text-slate-800">Worker Individuals</h1> <p className="text-slate-500 mt-1">Configure your system agents and custom workers.</p>
<p className="text-slate-500 mt-1">Manage all system nodes and custom workers.</p>
</div>
<button
onClick={handleAddNew}
className="flex items-center px-4 py-2 bg-indigo-600 text-white rounded-lg hover:bg-indigo-700 transition-colors"
>
<Plus size={16} className="mr-2" />
Add Worker
</button>
</div> </div>
{error && <div className="text-red-600">{error}</div>}
<div className="bg-white rounded-xl shadow-sm border border-slate-200 overflow-hidden"> <div className="bg-white rounded-xl shadow-sm border border-slate-200 overflow-hidden">
<div className="p-0"> <div className="p-6 border-b border-slate-100 flex items-center justify-between">
{loading ? ( <div className="flex items-center space-x-3">
<div className="p-6 text-slate-500">Loading...</div> <div className="w-10 h-10 bg-indigo-50 text-indigo-600 rounded-lg flex items-center justify-center">
) : workers.length === 0 ? ( <Bot size={20} />
<div className="p-6 text-slate-500">No workers found.</div> </div>
) : ( <div>
<table className="w-full text-left border-collapse"> <h2 className="text-lg font-semibold text-slate-800">System Nodes</h2>
<thead> <p className="text-sm text-slate-500">Initialize core system agents</p>
<tr className="bg-slate-50 border-b border-slate-200 text-slate-600 text-sm"> </div>
<th className="p-4 font-semibold">Name</th> </div>
<th className="p-4 font-semibold">Type</th> </div>
<th className="p-4 font-semibold">Provider / Model ID</th> <div className="p-6">
<th className="p-4 font-semibold text-right">Actions</th> <form onSubmit={handleCreateNode} className="space-y-4">
</tr> <div className="grid grid-cols-1 md:grid-cols-3 gap-4">
</thead> <div>
<tbody> <label className="block text-sm font-medium text-slate-700 mb-1">Node Type</label>
{workers.map((w) => ( <select
<tr key={w.agent_id} className="border-b border-slate-100 hover:bg-slate-50 transition-colors"> value={nodeType}
<td className="p-4 font-medium text-slate-800">{w.agent_name}</td> onChange={(e) => setNodeType(e.target.value)}
<td className="p-4 text-slate-600"> className="w-full px-4 py-2 border border-slate-200 rounded-lg focus:outline-none focus:ring-2 focus:ring-indigo-500"
<span className="px-2 py-1 bg-slate-100 rounded text-xs">{w.agent_type}</span> >
</td> <option value="supervisory_node">Supervisory Node</option>
<td className="p-4 text-slate-600 text-sm"> <option value="consciousness_node">Consciousness Node</option>
{w.provider_title} <span className="text-slate-400">/</span> {w.model_id} <option value="control_node">Control Node</option>
</td> </select>
<td className="p-4 text-right space-x-2"> </div>
<button onClick={() => handleEdit(w)} className="p-2 text-indigo-600 hover:bg-indigo-50 rounded-lg transition-colors" title="Edit"> <div>
<Edit2 size={16} /> <label className="block text-sm font-medium text-slate-700 mb-1">Provider Title</label>
</button> <select
<button onClick={() => handleDelete(w.agent_id)} className="p-2 text-red-600 hover:bg-red-50 rounded-lg transition-colors" title="Delete"> value={providerTitle}
<Trash2 size={16} /> onChange={(e) => setProviderTitle(e.target.value)}
</button> required
</td> className="w-full px-4 py-2 border border-slate-200 rounded-lg focus:outline-none focus:ring-2 focus:ring-indigo-500"
</tr> >
))} {providers.length === 0 ? (
</tbody> <option value="" disabled>No providers available. Create one first.</option>
</table> ) : (
)} providers.map((p) => (
<option key={p.provider_title} value={p.provider_title}>
{p.provider_title}
</option>
))
)}
</select>
</div>
<div>
<label className="block text-sm font-medium text-slate-700 mb-1">Model ID</label>
<input
type="text"
value={modelId}
onChange={(e) => setModelId(e.target.value)}
placeholder="e.g. gpt-4"
required
className="w-full px-4 py-2 border border-slate-200 rounded-lg focus:outline-none focus:ring-2 focus:ring-indigo-500"
/>
</div>
</div>
{message && (
<div className={`p-3 rounded-lg text-sm ${message.includes('Success') ? 'bg-green-50 text-green-700' : 'bg-red-50 text-red-700'}`}>
{message}
</div>
)}
<div className="flex justify-end">
<button
type="submit"
disabled={loading}
className="flex items-center space-x-2 px-4 py-2 bg-indigo-600 text-white rounded-lg hover:bg-indigo-700 transition-colors disabled:opacity-50"
>
<Save size={16} />
<span>{loading ? 'Saving...' : 'Load Node'}</span>
</button>
</div>
</form>
</div> </div>
</div> </div>
{/* Edit/Create Modal */} <div className="bg-white rounded-xl shadow-sm border border-slate-200 overflow-hidden">
{isEditing && ( <div className="p-6 border-b border-slate-100">
<div className="fixed inset-0 bg-black/50 z-50 flex items-center justify-center p-4"> <h2 className="text-lg font-semibold text-slate-800">Create Worker Individual</h2>
<div className="bg-white rounded-xl shadow-xl w-full max-w-2xl max-h-[90vh] overflow-y-auto"> <p className="text-sm text-slate-500">Add a new custom worker to the system.</p>
<div className="flex justify-between items-center p-6 border-b border-slate-100 sticky top-0 bg-white z-10">
<h2 className="text-xl font-bold text-slate-800">{isNew ? 'Create Worker' : 'Edit Worker'}</h2>
<button onClick={() => setIsEditing(false)} className="text-slate-400 hover:text-slate-600">
<X size={24} />
</button>
</div>
<form onSubmit={handleModalSave} className="p-6 space-y-4">
<div className="grid grid-cols-1 md:grid-cols-2 gap-4">
<div>
<label className="block text-sm font-medium text-slate-700 mb-1">Agent Name</label>
<input
type="text"
required
value={editData.agent_name || ''}
onChange={(e) => setEditData({...editData, agent_name: e.target.value})}
className="w-full px-4 py-2 border border-slate-200 rounded-lg focus:ring-2 focus:ring-indigo-500"
/>
</div>
<div>
<label className="block text-sm font-medium text-slate-700 mb-1">Agent Type</label>
<select
value={editData.agent_type || 'OrdinaryIndividual'}
onChange={(e) => setEditData({...editData, agent_type: e.target.value})}
className="w-full px-4 py-2 border border-slate-200 rounded-lg focus:ring-2 focus:ring-indigo-500"
>
<option value="supervisory_node">Supervisory Node</option>
<option value="consciousness_node">Consciousness Node</option>
<option value="control_node">Control Node</option>
<option value="OrdinaryIndividual">Ordinary Individual</option>
<option value="SkillIndividual">Skill Individual</option>
<option value="SpecialIndividual">Special Individual</option>
</select>
</div>
</div>
<div className="grid grid-cols-1 md:grid-cols-2 gap-4">
<div>
<label className="block text-sm font-medium text-slate-700 mb-1">Provider Title</label>
<select
value={editData.provider_title || ''}
onChange={(e) => setEditData({...editData, provider_title: e.target.value, model_id: ''})}
required
className="w-full px-4 py-2 border border-slate-200 rounded-lg focus:ring-2 focus:ring-indigo-500"
>
<option value="" disabled>Select Provider</option>
{providers.map((p) => (
<option key={p.provider_title} value={p.provider_title}>{p.provider_title}</option>
))}
</select>
</div>
<div>
<label className="block text-sm font-medium text-slate-700 mb-1">Model ID</label>
{(() => {
const selectedProvider = providers.find(p => p.provider_title === editData.provider_title);
const models = selectedProvider?.provider_models || [];
return (
<select
value={editData.model_id || ''}
onChange={(e) => setEditData({...editData, model_id: e.target.value})}
required
className="w-full px-4 py-2 border border-slate-200 rounded-lg focus:ring-2 focus:ring-indigo-500"
>
<option value="" disabled>Select a model</option>
{models.map(m => <option key={m} value={m}>{m}</option>)}
</select>
);
})()}
</div>
</div>
<div>
<label className="block text-sm font-medium text-slate-700 mb-1">Description</label>
<textarea
value={editData.description || ''}
onChange={(e) => setEditData({...editData, description: e.target.value})}
rows={2}
className="w-full px-4 py-2 border border-slate-200 rounded-lg focus:ring-2 focus:ring-indigo-500"
/>
</div>
<div>
<label className="block text-sm font-medium text-slate-700 mb-1">System Prompt</label>
<textarea
value={editData.system_prompt || ''}
onChange={(e) => setEditData({...editData, system_prompt: e.target.value})}
rows={3}
className="w-full px-4 py-2 border border-slate-200 rounded-lg focus:ring-2 focus:ring-indigo-500 font-mono text-sm"
/>
</div>
<div className="grid grid-cols-1 md:grid-cols-2 gap-4">
<div>
<label className="block text-sm font-medium text-slate-700 mb-1">Output Template (JSON)</label>
<textarea
value={editData.output_template || '{}'}
onChange={(e) => setEditData({...editData, output_template: e.target.value})}
rows={3}
className="w-full px-4 py-2 border border-slate-200 rounded-lg focus:ring-2 focus:ring-indigo-500 font-mono text-sm"
/>
</div>
<div>
<label className="block text-sm font-medium text-slate-700 mb-1">Bound Skill (JSON)</label>
<textarea
value={editData.bound_skill || '{}'}
onChange={(e) => setEditData({...editData, bound_skill: e.target.value})}
rows={3}
className="w-full px-4 py-2 border border-slate-200 rounded-lg focus:ring-2 focus:ring-indigo-500 font-mono text-sm"
/>
</div>
</div>
<div>
<label className="block text-sm font-medium text-slate-700 mb-1">Workspace (JSON Array)</label>
<textarea
value={editData.workspace || '[]'}
onChange={(e) => setEditData({...editData, workspace: e.target.value})}
rows={2}
className="w-full px-4 py-2 border border-slate-200 rounded-lg focus:ring-2 focus:ring-indigo-500 font-mono text-sm"
/>
</div>
{modalMessage && (
<div className="p-3 bg-red-50 text-red-700 text-sm rounded-lg">
{modalMessage}
</div>
)}
<div className="pt-4 flex justify-end space-x-3 border-t border-slate-100">
<button
type="button"
onClick={() => setIsEditing(false)}
className="px-4 py-2 text-slate-600 hover:bg-slate-100 rounded-lg transition-colors"
>
Cancel
</button>
<button
type="submit"
className="flex items-center px-4 py-2 bg-indigo-600 text-white rounded-lg hover:bg-indigo-700 transition-colors"
>
<Save size={16} className="mr-2" />
Save Worker
</button>
</div>
</form>
</div>
</div> </div>
)} <div className="p-6">
<WorkerIndividualForm providers={providers} />
</div>
</div>
</div> </div>
); );
} }

View File

@ -16,57 +16,33 @@ export function RightPanel({ selectedWorkflow }: RightPanelProps) {
return; return;
} }
let ws: WebSocket | null = null; const wsBase = import.meta.env.VITE_API_BASE_URL
let reconnectTimeout: ReturnType<typeof setTimeout>; ? import.meta.env.VITE_API_BASE_URL.replace('http', 'ws')
let retryCount = 0; : `ws://localhost:8000`;
const maxRetryCount = 10;
const baseDelay = 1000;
const connect = () => { // Using the workflow router WS endpoint
const protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:'; const ws = new WebSocket(`${wsBase}/api/v1/workflow/ws/${selectedWorkflow}`);
const host = window.location.host;
const wsBase = import.meta.env.VITE_API_BASE_URL ws.onopen = () => {
? import.meta.env.VITE_API_BASE_URL.replace(/^http/, 'ws') setIsConnected(true);
: `${protocol}//${host}`;
// Using the workflow router WS endpoint setMessages([]); // clear previous traces
ws = new WebSocket(`${wsBase}/api/v1/workflow/ws/${selectedWorkflow}`);
ws.onopen = () => {
setIsConnected(true);
retryCount = 0; // reset on successful connection
setMessages([]); // clear previous traces
};
ws.onmessage = (event) => {
try {
setMessages(prev => [...prev, event.data]);
} catch (e) {
console.error("Error receiving workflow websocket message", e);
}
};
ws.onclose = () => {
setIsConnected(false);
if (retryCount < maxRetryCount) {
const delay = baseDelay * Math.pow(2, retryCount);
retryCount++;
console.log(`WebSocket closed. Reconnecting in ${delay}ms... (Attempt ${retryCount})`);
reconnectTimeout = setTimeout(connect, delay);
} else {
console.error("Max WebSocket reconnect attempts reached.");
}
};
}; };
connect(); ws.onmessage = (event) => {
try {
setMessages(prev => [...prev, event.data]);
} catch (e) {
console.error("Error receiving workflow websocket message", e);
}
};
ws.onclose = () => {
setIsConnected(false);
};
return () => { return () => {
clearTimeout(reconnectTimeout); ws.close();
if (ws) {
ws.close();
}
}; };
}, [selectedWorkflow]); }, [selectedWorkflow]);

View File

@ -15,14 +15,7 @@ export function SkillSettings() {
setLoading(true); setLoading(true);
try { try {
const response = await apiClient.get('/api/v1/resource/skill'); const response = await apiClient.get('/api/v1/resource/skill');
const skillsData = response.data.skills || {}; setSkills(response.data.skills || []);
// skillsData might be an object mapping skill names to their details, or it might be an array in some versions.
// We ensure it is an array of strings (skill names)
if (Array.isArray(skillsData)) {
setSkills(skillsData);
} else {
setSkills(Object.keys(skillsData));
}
} catch (err) { } catch (err) {
console.error('Failed to fetch skills:', err); console.error('Failed to fetch skills:', err);
} finally { } finally {

View File

@ -2,27 +2,19 @@ import { useState, useEffect } from 'react';
import apiClient from '../../api/client'; import apiClient from '../../api/client';
import { FileCode, Trash2, Plus, LayoutTemplate } from 'lucide-react'; import { FileCode, Trash2, Plus, LayoutTemplate } from 'lucide-react';
import type { WorkflowTemplate as ParsedWorkflowTemplate } from '../../types'; interface WorkflowTemplate {
name: string;
[key: string]: any;
}
export function WorkflowTemplateSettings() { export function WorkflowTemplateSettings() {
const [templates, setTemplates] = useState<Record<string, ParsedWorkflowTemplate>>({}); const [templates, setTemplates] = useState<Record<string, WorkflowTemplate>>({});
const [loading, setLoading] = useState(true); const [loading, setLoading] = useState(true);
const [templateJson, setTemplateJson] = useState('{\n "name": "my_template",\n "steps": [\n {\n "name": "step1",\n "actor": "actor_name"\n }\n ]\n}'); const [templateJson, setTemplateJson] = useState('{\n "name": "my_template"\n}');
const [creating, setCreating] = useState(false); const [creating, setCreating] = useState(false);
const [message, setMessage] = useState(''); const [message, setMessage] = useState('');
const [error, setError] = useState(''); const [error, setError] = useState('');
const validateTemplate = (data: any): data is ParsedWorkflowTemplate => {
if (!data || typeof data !== 'object') return false;
if (typeof data.name !== 'string') return false;
if (!Array.isArray(data.steps)) return false;
for (const step of data.steps) {
if (typeof step.name !== 'string') return false;
if (typeof step.actor !== 'string') return false;
}
return true;
};
const fetchTemplates = async () => { const fetchTemplates = async () => {
setLoading(true); setLoading(true);
try { try {
@ -47,21 +39,16 @@ export function WorkflowTemplateSettings() {
try { try {
const parsedJson = JSON.parse(templateJson); const parsedJson = JSON.parse(templateJson);
if (!validateTemplate(parsedJson)) {
throw new Error('JSON structure does not match WorkflowTemplate schema (requires name and steps array with name and actor).');
}
await apiClient.post('/api/v1/resource/workflow_template', parsedJson); await apiClient.post('/api/v1/resource/workflow_template', parsedJson);
setMessage('Workflow template created successfully'); setMessage('Workflow template created successfully');
setTemplateJson('{\n "name": "my_template",\n "steps": []\n}'); setTemplateJson('{\n "name": "my_template"\n}');
fetchTemplates(); fetchTemplates();
} catch (err: any) { } catch (err: any) {
console.error(err); console.error(err);
if (err instanceof SyntaxError) { if (err instanceof SyntaxError) {
setError('Invalid JSON format'); setError('Invalid JSON format');
} else { } else {
setError(err.message || err.response?.data?.message || 'Failed to create workflow template'); setError(err.response?.data?.message || 'Failed to create workflow template');
} }
} finally { } finally {
setCreating(false); setCreating(false);

View File

@ -6,59 +6,34 @@ export function useClusterState() {
const [isConnected, setIsConnected] = useState(false); const [isConnected, setIsConnected] = useState(false);
useEffect(() => { useEffect(() => {
let ws: WebSocket | null = null; // Determine WS URL based on API base URL or window location
let reconnectTimeout: ReturnType<typeof setTimeout>; const wsBase = import.meta.env.VITE_API_BASE_URL
let retryCount = 0; ? import.meta.env.VITE_API_BASE_URL.replace('http', 'ws')
const maxRetryCount = 10; : `ws://localhost:8000`;
const baseDelay = 1000;
const connect = () => { const ws = new WebSocket(`${wsBase}/api/v1/cluster/ws/state`);
// Determine WS URL based on API base URL or window location
const protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
const host = window.location.host;
const wsBase = import.meta.env.VITE_API_BASE_URL ws.onopen = () => {
? import.meta.env.VITE_API_BASE_URL.replace(/^http/, 'ws') setIsConnected(true);
: `${protocol}//${host}`;
ws = new WebSocket(`${wsBase}/api/v1/cluster/ws/state`);
ws.onopen = () => {
setIsConnected(true);
retryCount = 0; // Reset retry count on successful connection
};
ws.onmessage = (event) => {
try {
const data = JSON.parse(event.data);
if (Array.isArray(data)) {
setNodes(data);
}
} catch (e) {
console.error("Error parsing cluster state websocket message", e);
}
};
ws.onclose = () => {
setIsConnected(false);
if (retryCount < maxRetryCount) {
const delay = baseDelay * Math.pow(2, retryCount);
retryCount++;
console.log(`WebSocket closed. Reconnecting in ${delay}ms... (Attempt ${retryCount})`);
reconnectTimeout = setTimeout(connect, delay);
} else {
console.error("Max WebSocket reconnect attempts reached.");
}
};
}; };
connect(); ws.onmessage = (event) => {
try {
const data = JSON.parse(event.data);
if (Array.isArray(data)) {
setNodes(data);
}
} catch (e) {
console.error("Error parsing cluster state websocket message", e);
}
};
ws.onclose = () => {
setIsConnected(false);
};
return () => { return () => {
clearTimeout(reconnectTimeout); ws.close();
if (ws) {
ws.close();
}
}; };
}, []); }, []);

View File

@ -18,7 +18,6 @@ export interface Provider {
provider_title: string; provider_title: string;
provider_url?: string; provider_url?: string;
provider_owner?: string; provider_owner?: string;
provider_models?: string[];
// Based on your UI needs we might infer some local status fields // Based on your UI needs we might infer some local status fields
status?: string; status?: string;
model?: string; model?: string;
@ -66,18 +65,3 @@ export interface Workflow {
workflow_title: string; workflow_title: string;
status?: string; status?: string;
} }
// Workflow Template Validation
export interface WorkStep {
name: string;
desc?: string;
actor: string; // the name of the worker individual
inputs?: string[];
outputs?: string[];
}
export interface WorkflowTemplate {
name: string;
desc?: string;
steps: WorkStep[];
}

View File

@ -53,7 +53,7 @@ class AgentFactory:
if provider.provider_type not in self._models_mapping: if provider.provider_type not in self._models_mapping:
raise ValueError(f"不支持的协议类型: {provider.provider_type}") raise ValueError(f"不支持的协议类型: {provider.provider_type}")
model_class, provider_class = self._models_mapping[provider.provider_type] model_class, provider_class = self._models_mapping[provider.provider_type]
model = model_class(model_id, provider=provider_class(api_key=provider.provider_apikey, base_url=provider.provider_url)) model = model_class(model_id, provider_class(api_key=provider.api_key, url=provider.url))
agent = Agent(model=model, agent = Agent(model=model,
name=agent_name, name=agent_name,
system_prompt=system_prompt, system_prompt=system_prompt,

View File

@ -43,21 +43,18 @@ async def load_agent(agent_register: Union[AgentRegister, AgentLocalRegister],
pass pass
elif isinstance(agent_register, AgentRegister): elif isinstance(agent_register, AgentRegister):
try: match agent_register.individual_name:
match agent_register.individual_name: case "supervisory_node":
case "supervisory_node": node = ray_actor_hook("supervisory_node").supervisory_node
node = ray_actor_hook("supervisory_node").supervisory_node node.create_agent.remote(global_state_machine,agent_register.provider_title,agent_register.model_id)
await node.create_agent.remote(global_state_machine,agent_register.provider_title,agent_register.model_id) case "consciousness_node":
case "consciousness_node": node = ray_actor_hook("consciousness_node").consciousness_node
node = ray_actor_hook("consciousness_node").consciousness_node node.create_agent.remote(global_state_machine,agent_register.provider_title,agent_register.model_id)
await node.create_agent.remote(global_state_machine,agent_register.provider_title,agent_register.model_id) case "control_node":
case "control_node": node = ray_actor_hook("control_node").control_node
node = ray_actor_hook("control_node").control_node node.create_agent.remote(global_state_machine,agent_register.provider_title,agent_register.model_id)
await node.create_agent.remote(global_state_machine,agent_register.provider_title,agent_register.model_id) case _:
case _: pass
pass
except Exception as e:
raise HTTPException(status_code=500, detail=f"加载节点失败: {str(e)}")
return {"message": "创建成功"} return {"message": "创建成功"}
@ -88,25 +85,25 @@ class WorkerIndividualUpdate(BaseModel):
@agent_router.post("/worker") @agent_router.post("/worker")
async def create_worker_individual(worker_data: WorkerIndividualCreate, async def create_worker_individual(worker_data: WorkerIndividualCreate,
token_data: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER))): token_data: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER))):
postgres_database = ray_actor_hook("postgres_database").postgres_database postgres_database = ray_actor_hook("postgres_database")
data_dict = worker_data.model_dump() data_dict = worker_data.model_dump()
data_dict["owner_id"] = token_data.user_id data_dict["owner_id"] = token_data.user_id
worker = await postgres_database.add_worker_individual.remote( **data_dict) worker = await postgres_database.individual_database.remote("add_worker_individual", **data_dict)
return {"message": "success", "agent_id": worker.agent_id} return {"message": "success", "agent_id": worker.agent_id}
@agent_router.get("/worker") @agent_router.get("/worker")
async def get_worker_individual_list(token_data: TokenData = Depends(Accessor.get_current_user)): async def get_worker_individual_list(token_data: TokenData = Depends(Accessor.get_current_user)):
postgres_database = ray_actor_hook("postgres_database").postgres_database postgres_database = ray_actor_hook("postgres_database")
workers = await postgres_database.get_worker_individual_list.remote( owner_id=token_data.user_id) workers = await postgres_database.individual_database.remote("get_worker_individual_list", owner_id=token_data.user_id)
return {"workers": workers} return {"workers": workers}
@agent_router.get("/worker/{agent_id}") @agent_router.get("/worker/{agent_id}")
async def get_worker_individual(agent_id: str, async def get_worker_individual(agent_id: str,
token_data: TokenData = Depends(Accessor.get_current_user)): token_data: TokenData = Depends(Accessor.get_current_user)):
postgres_database = ray_actor_hook("postgres_database").postgres_database postgres_database = ray_actor_hook("postgres_database")
worker = await postgres_database.get_worker_individual.remote( agent_id=agent_id) worker = await postgres_database.individual_database.remote("get_worker_individual", agent_id=agent_id)
if not worker: if not worker:
raise HTTPException(status_code=404, detail="Agent not found") raise HTTPException(status_code=404, detail="Agent not found")
if worker.owner_id != token_data.user_id: if worker.owner_id != token_data.user_id:
@ -118,26 +115,26 @@ async def get_worker_individual(agent_id: str,
async def update_worker_individual(agent_id: str, async def update_worker_individual(agent_id: str,
worker_data: WorkerIndividualUpdate, worker_data: WorkerIndividualUpdate,
token_data: TokenData = Depends(Accessor.get_current_user)): token_data: TokenData = Depends(Accessor.get_current_user)):
postgres_database = ray_actor_hook("postgres_database").postgres_database postgres_database = ray_actor_hook("postgres_database")
worker = await postgres_database.get_worker_individual.remote( agent_id=agent_id) worker = await postgres_database.individual_database.remote("get_worker_individual", agent_id=agent_id)
if not worker: if not worker:
raise HTTPException(status_code=404, detail="Agent not found") raise HTTPException(status_code=404, detail="Agent not found")
if worker.owner_id != token_data.user_id: if worker.owner_id != token_data.user_id:
raise HTTPException(status_code=403, detail="Forbidden: You do not own this agent") raise HTTPException(status_code=403, detail="Forbidden: You do not own this agent")
update_data = worker_data.model_dump(exclude_unset=True) update_data = worker_data.model_dump(exclude_unset=True)
updated_worker = await postgres_database.update_worker_individual.remote( agent_id=agent_id, **update_data) updated_worker = await postgres_database.individual_database.remote("update_worker_individual", agent_id=agent_id, **update_data)
return {"message": "success", "worker": updated_worker} return {"message": "success", "worker": updated_worker}
@agent_router.delete("/worker/{agent_id}") @agent_router.delete("/worker/{agent_id}")
async def delete_worker_individual(agent_id: str, async def delete_worker_individual(agent_id: str,
token_data: TokenData = Depends(Accessor.get_current_user)): token_data: TokenData = Depends(Accessor.get_current_user)):
postgres_database = ray_actor_hook("postgres_database").postgres_database postgres_database = ray_actor_hook("postgres_database")
worker = await postgres_database.get_worker_individual.remote( agent_id=agent_id) worker = await postgres_database.individual_database.remote("get_worker_individual", agent_id=agent_id)
if not worker: if not worker:
raise HTTPException(status_code=404, detail="Agent not found") raise HTTPException(status_code=404, detail="Agent not found")
if worker.owner_id != token_data.user_id: if worker.owner_id != token_data.user_id:
raise HTTPException(status_code=403, detail="Forbidden: You do not own this agent") raise HTTPException(status_code=403, detail="Forbidden: You do not own this agent")
await postgres_database.delete_worker_individual.remote( agent_id=agent_id) await postgres_database.individual_database.remote("delete_worker_individual", agent_id=agent_id)
return {"message": "success"} return {"message": "success"}

View File

@ -32,7 +32,7 @@ class UserRegister(BaseModel):
async def create_user(user_register: UserRegister): async def create_user(user_register: UserRegister):
postgres_database = ray_actor_hook("postgres_database").postgres_database postgres_database = ray_actor_hook("postgres_database").postgres_database
hashed_password = await run_in_threadpool(Accessor.hash_password, user_register.password) hashed_password = await run_in_threadpool(Accessor.hash_password, user_register.password)
user = await postgres_database.add_user.remote( user_register.user_name, hashed_password) user = await postgres_database.auth_database.remote("add_user", user_register.user_name, hashed_password)
return {"message": "success", "user_id": user.user_id} return {"message": "success", "user_id": user.user_id}
class UserLogin(BaseModel): class UserLogin(BaseModel):
@ -42,7 +42,7 @@ class UserLogin(BaseModel):
@auth_router.post("/login") @auth_router.post("/login")
async def login_user(user_login: UserLogin): async def login_user(user_login: UserLogin):
postgres_database = ray_actor_hook("postgres_database").postgres_database postgres_database = ray_actor_hook("postgres_database").postgres_database
user = await postgres_database.login_user.remote( user_login.user_name) user = await postgres_database.auth_database.remote("login_user", user_login.user_name)
if not user: if not user:
raise UserNotExistError() raise UserNotExistError()
token = await run_in_threadpool(Accessor.login_hashed_password, user, user_login.password) token = await run_in_threadpool(Accessor.login_hashed_password, user, user_login.password)
@ -61,7 +61,7 @@ async def change_authority(
Update a user's authority level. Only accessible by SUPER_ADMINISTRATOR. Update a user's authority level. Only accessible by SUPER_ADMINISTRATOR.
""" """
postgres_database = ray_actor_hook("postgres_database").postgres_database postgres_database = ray_actor_hook("postgres_database").postgres_database
user = await postgres_database.change_user_authority.remote( user_id=request.user_id, new_authority=request.new_authority) user = await postgres_database.auth_database.remote("change_user_authority", user_id=request.user_id, new_authority=request.new_authority)
return {"message": "success", "user_id": user.user_id, "new_authority": user.user_authority} return {"message": "success", "user_id": user.user_id, "new_authority": user.user_authority}
@auth_router.get("/list") @auth_router.get("/list")
@ -72,7 +72,7 @@ async def get_user_list(
Get a list of all users. Only accessible by SUPER_ADMINISTRATOR. Get a list of all users. Only accessible by SUPER_ADMINISTRATOR.
""" """
postgres_database = ray_actor_hook("postgres_database").postgres_database postgres_database = ray_actor_hook("postgres_database").postgres_database
users = await postgres_database.get_all_users.remote() users = await postgres_database.auth_database.remote("get_all_users")
return {"users": [{"user_id": u.user_id, "user_name": u.user_name, "role": u.user_authority} for u in users]} return {"users": [{"user_id": u.user_id, "user_name": u.user_name, "role": u.user_authority} for u in users]}
@auth_router.delete("/{user_id}") @auth_router.delete("/{user_id}")
@ -84,5 +84,5 @@ async def delete_user(
Delete a user. Only accessible by SUPER_ADMINISTRATOR. Delete a user. Only accessible by SUPER_ADMINISTRATOR.
""" """
postgres_database = ray_actor_hook("postgres_database").postgres_database postgres_database = ray_actor_hook("postgres_database").postgres_database
await postgres_database.delete_user_by_id.remote( user_id=user_id) await postgres_database.auth_database.remote("delete_user_by_id", user_id=user_id)
return {"message": "success"} return {"message": "success"}

View File

@ -36,10 +36,5 @@ async def update_cluster_state(websocket: WebSocket):
] ]
await websocket.send_json(payload) await websocket.send_json(payload)
await asyncio.sleep(10) await asyncio.sleep(10)
except WebSocketDisconnect: except (WebSocketDisconnect, RuntimeError):
pass
except RuntimeError as e:
if "closed" not in str(e) and "GeneratorExit" not in str(e):
raise
except Exception:
pass pass

View File

@ -34,7 +34,7 @@ async def create_message(message: Message,
user_id=str(token_data.user_id), user_id=str(token_data.user_id),
user_name=token_data.username, user_name=token_data.username,
message=message.message) message=message.message)
supervisory_node = ray_actor_hook("supervisory_node").supervisory_node supervisory_node = ray_actor_hook("supervisor_node")
message = await supervisory_node.working.remote(event) message = await supervisory_node.working.remote(event)
if message == "任务已创建": if message == "任务已创建":
return {"message": event.trace_id} return {"message": event.trace_id}

View File

@ -44,12 +44,13 @@ async def create_provider(provider_register: ProviderRegister,
@provider_router.get("/list") @provider_router.get("/list")
async def get_provider_list(_: TokenData = Depends(Accessor.get_current_user)) -> Dict[str, Dict[str, Provider]]: async def get_provider_list(_: TokenData = Depends(Accessor.get_current_user)) -> Dict[str, Dict[str, Provider]]:
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
provider_list: Dict[str, Provider] = await global_state_machine.get_provider_list.remote() provider_list: Dict[str, Provider] = await global_state_machine.provider_manager.remote("get_provider_list")
providers = list(provider_list.values()) if provider_list else []
return {"provider_list": provider_list} return {"provider_list": provider_list}
@provider_router.delete("/{provider_title}") @provider_router.delete("/{provider_title}")
async def delete_provider(provider_title: str, _: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.SUPER_ADMINISTRATOR))) -> dict: async def delete_provider(provider_title: str, _: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.SUPER_ADMINISTRATOR))) -> dict:
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
postgres_database = ray_actor_hook("postgres_database").postgres_database postgres_database = ray_actor_hook("postgres_database").postgres_database
await global_state_machine.delete_provider.remote( provider_title=provider_title, postgres_database=postgres_database) await global_state_machine.provider_manager.remote("delete_provider", provider_title=provider_title, postgres_database=postgres_database)
return {"message": "success"} return {"message": "success"}

View File

@ -27,19 +27,19 @@ resource_router = APIRouter(prefix="/api/v1/resource")
async def create_workflow_template(workflow_template: WorkflowTemplate, async def create_workflow_template(workflow_template: WorkflowTemplate,
_: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER))): _: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER))):
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
await global_state_machine.add_workflow_template.remote( workflow_template.name, workflow_template) await global_state_machine.workflow_template_manager.remote("add_workflow_template", workflow_template.name, workflow_template)
return {"message": "创建成功"} return {"message": "创建成功"}
@resource_router.get("/workflow_template") @resource_router.get("/workflow_template")
async def get_workflow_templates(_: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER))): async def get_workflow_templates(_: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER))):
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
templates = await global_state_machine.get_all_workflow_templates.remote() templates = await global_state_machine.workflow_template_manager.remote("get_all_workflow_templates")
return {"templates": templates} return {"templates": templates}
@resource_router.delete("/workflow_template/{template_name}") @resource_router.delete("/workflow_template/{template_name}")
async def delete_workflow_template(template_name: str, _: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.SUPER_ADMINISTRATOR))): async def delete_workflow_template(template_name: str, _: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.SUPER_ADMINISTRATOR))):
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
await global_state_machine.delete_workflow_template.remote( template_name) await global_state_machine.workflow_template_manager.remote("delete_workflow_template", template_name)
return {"message": "success"} return {"message": "success"}
@ -60,18 +60,18 @@ async def install_skill(skill: Skill,
skill_name = skill.path.split("/")[-1] skill_name = skill.path.split("/")[-1]
else: else:
skill_name = skill.repo_url.split("/")[-1] skill_name = skill.repo_url.split("/")[-1]
await global_state_machine.add_skill.remote( skill_name) await global_state_machine.skill_manager.remote("add_skill", skill_name)
return {"message": "创建成功"} return {"message": "创建成功"}
@resource_router.get("/skill") @resource_router.get("/skill")
async def get_skills(_: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER))): async def get_skills(_: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER))):
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
skills = await global_state_machine.get_skill_list.remote() skills = await global_state_machine.skill_manager.remote("get_skill_list")
return {"skills": skills} return {"skills": skills}
@resource_router.delete("/skill/{skill_name}") @resource_router.delete("/skill/{skill_name}")
async def delete_skill(skill_name: str, _: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.SUPER_ADMINISTRATOR))): async def delete_skill(skill_name: str, _: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.SUPER_ADMINISTRATOR))):
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
# Note: this only removes it from the state machine manager. # Note: this only removes it from the state machine manager.
await global_state_machine.remove_skill.remote( skill_name) await global_state_machine.skill_manager.remote("remove_skill", skill_name)
return {"message": "success"} return {"message": "success"}

View File

@ -29,11 +29,6 @@ async def get_workflow(websocket: WebSocket, event_id: str):
await websocket.send_text(await global_state_machine.get_pending.remote(event_id)) await websocket.send_text(await global_state_machine.get_pending.remote(event_id))
response = await websocket.receive_text() response = await websocket.receive_text()
await global_state_machine.put_received(event_id, response) await global_state_machine.put_received(event_id, response)
except WebSocketDisconnect: except (WebSocketDisconnect, RuntimeError):
pass
except RuntimeError as e:
if "closed" not in str(e) and "GeneratorExit" not in str(e):
raise
except Exception:
pass pass

View File

@ -53,77 +53,17 @@ class PostgresDatabase:
finally: finally:
self.ready_event.set() self.ready_event.set()
# Auth Database Methods async def auth_database(self, method_name: str, *args, **kwargs):
async def add_user(self, user_name: str, hashed_password: str):
await self.ready_event.wait() await self.ready_event.wait()
return await self._auth_database.add_user(user_name, hashed_password) method = getattr(self._auth_database, method_name)
return await method(*args, **kwargs)
async def change_password(self, user_name, old_password, new_password): async def provider_database(self, method_name: str, *args, **kwargs):
await self.ready_event.wait() await self.ready_event.wait()
return await self._auth_database.change_password(user_name, old_password, new_password) method = getattr(self._provider_database, method_name)
return await method(*args, **kwargs)
async def delete_user(self, user_name: str): async def individual_database(self, method_name: str, *args, **kwargs):
await self.ready_event.wait() await self.ready_event.wait()
return await self._auth_database.delete_user(user_name) method = getattr(self._individual_database, method_name)
return await method(*args, **kwargs)
async def delete_user_by_id(self, user_id: str):
await self.ready_event.wait()
return await self._auth_database.delete_user_by_id(user_id)
async def login_user(self, user_name: str):
await self.ready_event.wait()
return await self._auth_database.login_user(user_name)
async def get_all_users(self):
await self.ready_event.wait()
return await self._auth_database.get_all_users()
async def get_user_authority(self, user_id: str):
await self.ready_event.wait()
return await self._auth_database.get_user_authority(user_id)
async def change_user_authority(self, user_id: str, new_authority):
await self.ready_event.wait()
return await self._auth_database.change_user_authority(user_id, new_authority)
# Provider Database Methods
async def get_provider(self):
await self.ready_event.wait()
return await self._provider_database.get_provider()
async def add_provider_db(self, **kwargs):
await self.ready_event.wait()
return await self._provider_database.add_provider(**kwargs)
async def delete_provider_db(self, provider_id: str):
await self.ready_event.wait()
return await self._provider_database.delete_provider(provider_id)
async def update_provider_db(self, provider_id: str, **kwargs):
await self.ready_event.wait()
return await self._provider_database.update_provider(provider_id, **kwargs)
# Individual Database Methods
async def add_worker_individual(self, **kwargs):
await self.ready_event.wait()
return await self._individual_database.add_worker_individual(**kwargs)
async def get_worker_individual(self, agent_id: str):
await self.ready_event.wait()
return await self._individual_database.get_worker_individual(agent_id)
async def get_worker_individual_list(self, owner_id: str):
await self.ready_event.wait()
return await self._individual_database.get_worker_individual_list(owner_id)
async def update_worker_individual(self, agent_id: str, **kwargs):
await self.ready_event.wait()
return await self._individual_database.update_worker_individual(agent_id, **kwargs)
async def delete_worker_individual(self, agent_id: str):
await self.ready_event.wait()
return await self._individual_database.delete_worker_individual(agent_id)
async def get_all_worker_individual(self):
await self.ready_event.wait()
return await self._individual_database.get_all_worker_individual()

View File

@ -53,64 +53,35 @@ class GlobalStateMachine:
postgres_database=self.postgres_database postgres_database=self.postgres_database
) )
# Provider Manager Methods async def provider_manager(self, method_name: str, *args, **kwargs):
def get_provider_list(self): method = getattr(self._global_provider_manager, method_name)
return self._global_provider_manager.get_provider_list() if asyncio.iscoroutinefunction(method):
return await method(*args, **kwargs)
return method(*args, **kwargs)
def get_provider(self, provider_title): async def tool_manager(self, method_name: str, *args, **kwargs):
return self._global_provider_manager.get_provider(provider_title) method = getattr(self._global_tool_manager, method_name)
if asyncio.iscoroutinefunction(method):
return await method(*args, **kwargs)
return method(*args, **kwargs)
async def delete_provider(self, provider_title: str): async def workflow_template_manager(self, method_name: str, *args, **kwargs):
return await self._global_provider_manager.delete_provider(provider_title, self.postgres_database) method = getattr(self._global_workflow_template_manager, method_name)
if asyncio.iscoroutinefunction(method):
return await method(*args, **kwargs)
return method(*args, **kwargs)
# Tool Manager Methods async def skill_manager(self, method_name: str, *args, **kwargs):
def get_tool_mapper(self): method = getattr(self._global_skill_manager, method_name)
return self._global_tool_manager.tool_mapper if asyncio.iscoroutinefunction(method):
return await method(*args, **kwargs)
return method(*args, **kwargs)
def get_tool_list(self, agent_name: str): async def individual_manager(self, method_name: str, *args, **kwargs):
# get_tool_list didn't actually exist on tool_manager, let's implement it to return the tools method = getattr(self._global_individual_manager, method_name)
# for a specific agent name (or scope) if asyncio.iscoroutinefunction(method):
tools = self._global_tool_manager.tool_mapper.get(agent_name, {}) return await method(*args, **kwargs)
# also include default tools return method(*args, **kwargs)
default_tools = self._global_tool_manager.tool_mapper.get("default", {})
merged_tools = {**default_tools, **tools}
return merged_tools
# Workflow Template Manager Methods
def get_all_workflow_templates(self):
return self._global_workflow_template_manager.get_all_workflow_templates()
def add_workflow_template(self, template_name: str, workflow_template):
return self._global_workflow_template_manager.add_workflow_template(template_name, workflow_template)
def delete_workflow_template(self, template_name: str):
return self._global_workflow_template_manager.delete_workflow_template(template_name)
def generate_workflow_template(self, workflow_template):
return self._global_workflow_template_manager.generate_workflow_template(workflow_template)
# Skill Manager Methods
def add_skill(self, skill_name: str):
return self._global_skill_manager.add_skill(skill_name)
def get_skill_list(self):
return self._global_skill_manager.get_skill_list()
def remove_skill(self, skill_name: str):
return self._global_skill_manager.remove_skill(skill_name)
# Individual Manager Methods
def add_individual(self, agent_id: str, config):
return self._global_individual_manager.add_individual(agent_id, config)
def get_individual(self, agent_id: str):
return self._global_individual_manager.get_individual(agent_id)
def remove_individual(self, agent_id: str):
return self._global_individual_manager.remove_individual(agent_id)
def list_individuals(self):
return self._global_individual_manager.list_individuals()
###以下方法为event_dict方法 ###以下方法为event_dict方法
def add_event(self, event: PretorEvent) -> None: def add_event(self, event: PretorEvent) -> None:

View File

@ -23,7 +23,7 @@ class GlobalIndividualManager:
async def init_individual_register(self, postgres) -> None: async def init_individual_register(self, postgres) -> None:
try: try:
try: try:
individuals = await postgres.get_all_worker_individual.remote() individuals = await postgres.individual_database.remote("get_all_worker_individual")
for ind in individuals: for ind in individuals:
agent_id = getattr(ind, 'agent_id', None) agent_id = getattr(ind, 'agent_id', None)
if agent_id: if agent_id:

View File

@ -1,4 +1,3 @@
from pretor.utils.retry import retry_on_retryable_error
# Copyright 2026 zhaoxi826 # Copyright 2026 zhaoxi826
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
@ -25,7 +24,6 @@ class ClaudeProvider(BaseProvider):
return provider return provider
@staticmethod @staticmethod
@retry_on_retryable_error()
async def _load_models(provider_args: ProviderArgs) -> List[str]: async def _load_models(provider_args: ProviderArgs) -> List[str]:
# Anthropic 官方需要 version 头 # Anthropic 官方需要 version 头
headers = { headers = {

View File

@ -1,4 +1,3 @@
from pretor.utils.retry import retry_on_retryable_error
# Copyright 2026 zhaoxi826 # Copyright 2026 zhaoxi826
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
@ -25,7 +24,6 @@ class GeminiProvider(BaseProvider):
return provider return provider
@staticmethod @staticmethod
@retry_on_retryable_error()
async def _load_models(provider_args: ProviderArgs) -> List[str]: async def _load_models(provider_args: ProviderArgs) -> List[str]:
# Google Gemini 原生鉴权通常使用 x-goog-api-key 或 query parameter # Google Gemini 原生鉴权通常使用 x-goog-api-key 或 query parameter
headers = { headers = {
@ -48,10 +46,6 @@ class GeminiProvider(BaseProvider):
model_ids = [m["name"].split("/")[-1] for m in raw_models if model_ids = [m["name"].split("/")[-1] for m in raw_models if
"generateContent" in m.get("supportedGenerationMethods", [])] "generateContent" in m.get("supportedGenerationMethods", [])]
return sorted(list(set(model_ids))) return sorted(list(set(model_ids)))
except httpx.RequestError as e:
from pretor.utils.error import RetryableError
print(f"[{provider_args.provider_title}] 网络请求异常: {e}")
raise RetryableError(f"[{provider_args.provider_title}] 网络请求异常: {e}") from e
except Exception as e: except Exception as e:
print(f"[{provider_args.provider_title}] 获取 Gemini 模型列表错误: {e}") print(f"[{provider_args.provider_title}] 获取 Gemini 模型列表错误: {e}")
return [] return []

View File

@ -1,4 +1,3 @@
from pretor.utils.retry import retry_on_retryable_error
# Copyright 2026 zhaoxi826 # Copyright 2026 zhaoxi826
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
@ -25,7 +24,6 @@ class OpenAIProvider(BaseProvider):
return provider return provider
@staticmethod @staticmethod
@retry_on_retryable_error()
async def _load_models(provider_args: ProviderArgs) -> List[str]: async def _load_models(provider_args: ProviderArgs) -> List[str]:
headers = { headers = {
"Authorization": f"Bearer {provider_args.provider_apikey}", "Authorization": f"Bearer {provider_args.provider_apikey}",
@ -43,9 +41,8 @@ class OpenAIProvider(BaseProvider):
model_ids = [m["id"] for m in raw_models] model_ids = [m["id"] for m in raw_models]
return sorted(model_ids) return sorted(model_ids)
except httpx.RequestError as e: except httpx.RequestError as e:
from pretor.utils.error import RetryableError
print(f"[{provider_args.provider_title}] 网络请求异常: {e}") print(f"[{provider_args.provider_title}] 网络请求异常: {e}")
raise RetryableError(f"[{provider_args.provider_title}] 网络请求异常: {e}") from e return []
except Exception as e: except Exception as e:
print(f"[{provider_args.provider_title}] 解析模型列表时发生错误: {e}") print(f"[{provider_args.provider_title}] 解析模型列表时发生错误: {e}")
return [] return []

View File

@ -33,7 +33,7 @@ class ProviderManager:
self.provider_register = {} self.provider_register = {}
async def init_provider_register(self, postgres) -> None: async def init_provider_register(self, postgres) -> None:
providers = await postgres.get_provider.remote() providers = await postgres.provider_database.remote("get_provider")
for provider in providers: for provider in providers:
self.provider_register[provider.provider_title] = provider self.provider_register[provider.provider_title] = provider
@ -48,17 +48,14 @@ class ProviderManager:
provider_apikey=provider_apikey, provider_apikey=provider_apikey,
provider_owner=provider_owner) provider_owner=provider_owner)
try: try:
import ulid
provider_class = self.provider_mapper.get(provider_type, None) provider_class = self.provider_mapper.get(provider_type, None)
if provider_class is None: if provider_class is None:
logger.warning(f"Provider type {provider_type} is not supported.") logger.warning(f"Provider type {provider_type} is not supported.")
return None return None
provider: Provider = await provider_class.create_provider(provider_args) provider: Provider = await provider_class.create_model(provider_args)
provider.provider_owner = provider_owner provider.provider_owner = provider_owner
self.provider_register[provider_title] = provider self.provider_register[provider_title] = provider
await postgres_database.add_provider_db.remote( await postgres_database.provider_database.remote("add_provider", provider_title=provider.provider_title,
provider_id=str(ulid.ULID()),
provider_title=provider.provider_title,
provider_url=provider.provider_url, provider_url=provider.provider_url,
provider_apikey=provider.provider_apikey, provider_apikey=provider.provider_apikey,
provider_models=provider.provider_models, provider_models=provider.provider_models,
@ -67,9 +64,7 @@ class ProviderManager:
logger.info(f"已添加适配器{provider_title}") logger.info(f"已添加适配器{provider_title}")
except httpx.RequestError as e: except httpx.RequestError as e:
from pretor.utils.error import RetryableError
logger.warning(f"[{provider_args.provider_title}] 网络请求异常: {e}") logger.warning(f"[{provider_args.provider_title}] 网络请求异常: {e}")
raise RetryableError(f"[{provider_args.provider_title}] 网络请求异常: {e}") from e
except Exception as e: except Exception as e:
logger.warning(f"[{provider_args.provider_title}] 解析模型列表时发生错误: {e}") logger.warning(f"[{provider_args.provider_title}] 解析模型列表时发生错误: {e}")
@ -82,5 +77,5 @@ class ProviderManager:
async def delete_provider(self, provider_title: str, postgres_database) -> None: async def delete_provider(self, provider_title: str, postgres_database) -> None:
if provider_title in self.provider_register: if provider_title in self.provider_register:
provider = self.provider_register[provider_title] provider = self.provider_register[provider_title]
await postgres_database.delete_provider_db.remote( provider_id=provider.provider_id) await postgres_database.provider_database.remote("delete_provider", provider_id=provider.provider_id)
del self.provider_register[provider_title] del self.provider_register[provider_title]

View File

@ -33,7 +33,7 @@ class ConsciousnessNode:
self.agent: None | Agent = None self.agent: None | Agent = None
async def create_agent(self, global_state_machine: GlobalStateMachine, provider_title: str, model_id: str) -> None: def create_agent(self, global_state_machine: GlobalStateMachine, provider_title: str, model_id: str) -> None:
""" """
create_agent方法将agent对象装配到ConsciousnessNode的属性内 create_agent方法将agent对象装配到ConsciousnessNode的属性内
该方法通过provider_title从global_state_machine中获取provider对象然后从provider对象中取出供应商形象装配为pydantic_ai的 该方法通过provider_title从global_state_machine中获取provider对象然后从provider对象中取出供应商形象装配为pydantic_ai的
@ -57,7 +57,7 @@ class ConsciousnessNode:
"请确保所有的思考和生成过程符合逻辑,严密且高质量。" "请确保所有的思考和生成过程符合逻辑,严密且高质量。"
) )
output_type = Union[ForSupervisoryNode, ForWorkflow, ForWorkflowEngine] output_type = Union[ForSupervisoryNode, ForWorkflow, ForWorkflowEngine]
provider: Provider = await global_state_machine.get_provider.remote( provider_title) provider: Provider = global_state_machine.provider_manager.remote("get_provider", provider_title)
agent_factory = AgentFactory() agent_factory = AgentFactory()
self.agent = agent_factory.create_agent(provider=provider, self.agent = agent_factory.create_agent(provider=provider,
model_id=model_id, model_id=model_id,

View File

@ -30,7 +30,7 @@ class ControlNode:
self.agent: Agent | None = None self.agent: Agent | None = None
async def create_agent(self, global_state_machine: GlobalStateMachine, provider_title: str, model_id: str) -> None: def create_agent(self, global_state_machine: GlobalStateMachine, provider_title: str, model_id: str) -> None:
""" """
create_agent方法将agent对象装配到Control的属性内 create_agent方法将agent对象装配到Control的属性内
该方法通过provider_title从global_state_machine中获取provider对象然后从provider对象中取出供应商形象装配为pydantic_ai的 该方法通过provider_title从global_state_machine中获取provider对象然后从provider对象中取出供应商形象装配为pydantic_ai的
@ -54,7 +54,7 @@ class ControlNode:
"请注意:你的输出应当具体、实用,直接提供任务所要求的结果,不要做过多无关的寒暄。" "请注意:你的输出应当具体、实用,直接提供任务所要求的结果,不要做过多无关的寒暄。"
) )
output_type = ForWorkflow output_type = ForWorkflow
provider: Provider = await global_state_machine.get_provider.remote( provider_title) provider: Provider = global_state_machine.provider_manager.remote("get_provider", provider_title)
agent_factory = AgentFactory() agent_factory = AgentFactory()
self.agent = agent_factory.create_agent(provider=provider, self.agent = agent_factory.create_agent(provider=provider,
model_id=model_id, model_id=model_id,

View File

@ -57,7 +57,7 @@ class SupervisoryNode:
"请保持冷静、专业,并严格遵循上述路由规则。" "请保持冷静、专业,并严格遵循上述路由规则。"
) )
output_type = Union[ForConsciousnessNode, ForUser] output_type = Union[ForConsciousnessNode, ForUser]
provider: Provider = await global_state_machine.get_provider.remote( provider_title) provider: Provider = await global_state_machine.provider_manager.remote("get_provider", provider_title)
agent_factory = AgentFactory() agent_factory = AgentFactory()
self.agent = agent_factory.create_agent(provider=provider, self.agent = agent_factory.create_agent(provider=provider,
model_id=model_id, model_id=model_id,
@ -157,8 +157,8 @@ class SupervisoryNode:
message = payload.message message = payload.message
time_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") time_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
try: try:
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine global_state_machine = ray_actor_hook("global_state_machine")
workflow_template_dict = await global_state_machine.get_all_workflow_templates.remote() workflow_template_dict = await global_state_machine.workflow_template_manager.remote("get_workflow_template_list")
available_templates_str = "\n".join([f"- 名称: {k}, 描述/内容: {v}" for k, v in available_templates_str = "\n".join([f"- 名称: {k}, 描述/内容: {v}" for k, v in
workflow_template_dict.items()]) if workflow_template_dict else "暂无注册的工作流模板" workflow_template_dict.items()]) if workflow_template_dict else "暂无注册的工作流模板"
deps = SupervisoryNodeDeps( deps = SupervisoryNodeDeps(

View File

@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from functools import lru_cache
from typing import Annotated from typing import Annotated
from fastapi import Depends, HTTPException from fastapi import Depends, HTTPException
from pretor.utils.access import Accessor, TokenData from pretor.utils.access import Accessor, TokenData
@ -18,24 +19,9 @@ from pretor.core.database.table.user import UserAuthority
from pretor.utils.ray_hook import ray_actor_hook from pretor.utils.ray_hook import ray_actor_hook
async def get_authority(user_id: str) -> UserAuthority: async def get_authority(user_id: str) -> UserAuthority:
from pretor.utils.error import UserNotExistError
postgres_database = ray_actor_hook("postgres_database").postgres_database postgres_database = ray_actor_hook("postgres_database").postgres_database
try: user_authority = await postgres_database.auth_database.remote("get_user_authority", user_id=user_id)
user_authority = await postgres_database.get_user_authority.remote(user_id=user_id) return user_authority
return user_authority
except UserNotExistError:
raise HTTPException(
status_code=401,
detail="用户不存在或已被删除,请重新登录"
)
except Exception as e:
# Check if it's a RayTaskError wrapping UserNotExistError
if "UserNotExistError" in str(e):
raise HTTPException(
status_code=401,
detail="用户不存在或已被删除,请重新登录"
)
raise
class RoleChecker: class RoleChecker:
def __init__(self, **kwargs): def __init__(self, **kwargs):

View File

@ -12,15 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
class RetryableError(Exception): class DemandError(Exception):
"""基类:所有可重试错误(如网络断开、抖动等临时性故障)"""
pass
class NonRetryableError(Exception):
"""基类:所有不可重试错误(如数据验证失败、类型错误等业务逻辑故障)"""
pass
class DemandError(NonRetryableError):
pass pass
class ModelNotExistError(Exception): class ModelNotExistError(Exception):

View File

@ -51,7 +51,7 @@ def del_tool_cache(tool_name: str) -> None:
@lru_cache(maxsize=1) @lru_cache(maxsize=1)
async def get_tool(agent_name: str) -> List[Callable]: async def get_tool(agent_name: str) -> List[Callable]:
global_state_machine = ray_actor_hook("global_state_machine") global_state_machine = ray_actor_hook("global_state_machine")
_tool_list = await global_state_machine.get_tool_list.remote( agent_name) _tool_list = await global_state_machine.tool_manager.remote("get_tool_list", agent_name)
tool_list = [] tool_list = []
for tool_name in _tool_list.keys(): for tool_name in _tool_list.keys():
tool_func = _get_tool_func(tool_name) tool_func = _get_tool_func(tool_name)

View File

@ -1,31 +0,0 @@
import asyncio
from functools import wraps
from pretor.utils.error import RetryableError
def retry_on_retryable_error(max_retries=3, base_delay=1):
def decorator(func):
if asyncio.iscoroutinefunction(func):
@wraps(func)
async def async_wrapper(*args, **kwargs):
for attempt in range(max_retries):
try:
return await func(*args, **kwargs)
except RetryableError:
if attempt == max_retries - 1:
raise
await asyncio.sleep(base_delay * (2 ** attempt))
return async_wrapper
else:
@wraps(func)
def sync_wrapper(*args, **kwargs):
import time
for attempt in range(max_retries):
try:
return func(*args, **kwargs)
except RetryableError:
if attempt == max_retries - 1:
raise
time.sleep(base_delay * (2 ** attempt))
return sync_wrapper
return decorator

View File

@ -52,7 +52,7 @@ class WorkerCluster:
return self._active_workers[agent_id] return self._active_workers[agent_id]
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
agent_config = await global_state_machine.get_individual.remote( agent_id) agent_config = await global_state_machine.individual_manager.remote("get_individual", agent_id)
if not agent_config: if not agent_config:
raise ValueError(f"无法唤醒 Agent {agent_id}:数据库中不存在该档案") raise ValueError(f"无法唤醒 Agent {agent_id}:数据库中不存在该档案")

View File

@ -51,7 +51,7 @@ class BaseIndividual:
provider_title = self.agent_config.get("provider_title", "openai") # default fallback provider_title = self.agent_config.get("provider_title", "openai") # default fallback
model_id = self.agent_config.get("model_id", "gpt-4o") # default fallback model_id = self.agent_config.get("model_id", "gpt-4o") # default fallback
provider: Provider = await global_state_machine.get_provider.remote( provider_title) provider: Provider = await global_state_machine.provider_manager.remote("get_provider", provider_title)
agent_factory = AgentFactory() agent_factory = AgentFactory()
self.agent = agent_factory.create_agent( self.agent = agent_factory.create_agent(
provider=provider, provider=provider,

View File

@ -8,8 +8,8 @@ def test_create_agent_success_real():
mock_provider = MagicMock() mock_provider = MagicMock()
mock_provider.provider_type = "openai" mock_provider.provider_type = "openai"
mock_provider.provider_models = ["gpt-4"] mock_provider.provider_models = ["gpt-4"]
mock_provider.provider_apikey = "key" mock_provider.api_key = "key"
mock_provider.provider_url = "url" mock_provider.url = "url"
with patch("pretor.adapter.model_adapter.agent_factory.Agent") as mock_agent_cls: with patch("pretor.adapter.model_adapter.agent_factory.Agent") as mock_agent_cls:
with patch("pretor.adapter.model_adapter.agent_factory.OpenAIChatModel") as mock_model_cls: with patch("pretor.adapter.model_adapter.agent_factory.OpenAIChatModel") as mock_model_cls:
@ -23,8 +23,8 @@ def test_create_agent_success_real():
deps_type=dict, deps_type=dict,
agent_name="myagent" agent_name="myagent"
) )
mock_provider_cls.assert_called_once_with(api_key="key", base_url="url") mock_provider_cls.assert_called_once_with(api_key="key", url="url")
mock_model_cls.assert_called_once_with("gpt-4", provider=mock_provider_cls.return_value) mock_model_cls.assert_called_once_with("gpt-4", mock_provider_cls.return_value)
mock_agent_cls.assert_called_once_with( mock_agent_cls.assert_called_once_with(
model=mock_model_cls.return_value, model=mock_model_cls.return_value,
name="myagent", name="myagent",

View File

@ -10,12 +10,8 @@ def mock_import(name, globals=None, locals=None, fromlist=(), level=0):
if name == 'ray': if name == 'ray':
mock_ray = MagicMock() mock_ray = MagicMock()
def mock_remote(*args, **kwargs): def mock_remote(cls):
if len(args) == 1 and callable(args[0]): return cls
return args[0]
def decorator(cls):
return cls
return decorator
mock_ray.remote = mock_remote mock_ray.remote = mock_remote
return mock_ray return mock_ray
@ -70,9 +66,8 @@ async def test_postgres_database(mock_env_get, mock_provider_db, mock_auth_db, m
mock_auth_db.assert_called_once() mock_auth_db.assert_called_once()
mock_provider_db.assert_called_once() mock_provider_db.assert_called_once()
mock_auth_db.return_value.get_user_authority = AsyncMock(return_value="test_auth") mock_auth_db.return_value.get_user_authority = AsyncMock(return_value="test_auth")
assert await db.auth_database("get_user_authority", user_id="123") == "test_auth"
with patch("pretor.core.database.postgres.SQLModel.metadata.create_all") as mock_create_all: with patch("pretor.core.database.postgres.SQLModel.metadata.create_all") as mock_create_all:
await db.init_db() await db.init_db()
mock_conn.run_sync.assert_called_once_with(mock_create_all) mock_conn.run_sync.assert_called_once_with(mock_create_all)
assert await db.get_user_authority(user_id="123") == "test_auth"

View File

@ -12,12 +12,8 @@ def mock_import(name, globals=None, locals=None, fromlist=(), level=0):
if name == 'ray': if name == 'ray':
mock_ray = MagicMock() mock_ray = MagicMock()
def mock_remote(*args, **kwargs): def mock_remote(cls):
if len(args) == 1 and callable(args[0]): return cls
return args[0]
def decorator(cls):
return cls
return decorator
mock_ray.remote = mock_remote mock_ray.remote = mock_remote
return mock_ray return mock_ray
@ -104,20 +100,19 @@ async def test_add_provider_success(gsm, mock_postgres):
mock_provider.provider_apikey = "key" mock_provider.provider_apikey = "key"
mock_provider.provider_models = ["model"] mock_provider.provider_models = ["model"]
mock_provider.provider_type = "openai" mock_provider.provider_type = "openai"
mock_provider_class.create_provider.return_value = mock_provider mock_provider_class.create_model.return_value = mock_provider
gsm._global_provider_manager.provider_mapper = {"openai": mock_provider_class} gsm._global_provider_manager.provider_mapper = {"openai": mock_provider_class}
gsm._global_provider_manager.provider_register = {} gsm._global_provider_manager.provider_register = {}
mock_add_provider = AsyncMock() mock_add_provider = AsyncMock()
mock_postgres.add_provider_db = MagicMock() mock_postgres.provider_database.remote = mock_add_provider
mock_postgres.add_provider_db.remote = mock_add_provider
await gsm.add_provider_wrap("openai", "title", "url", "key", "1") await gsm.add_provider_wrap("openai", "title", "url", "key", 1)
assert gsm._global_provider_manager.provider_register["title"] == mock_provider assert gsm._global_provider_manager.provider_register["title"] == mock_provider
mock_add_provider.assert_called_once() mock_add_provider.assert_called_once()
assert mock_provider.provider_owner == "1" assert mock_provider.provider_owner == 1
@pytest.mark.asyncio @pytest.mark.asyncio
@ -126,7 +121,7 @@ async def test_add_provider_unsupported(gsm):
with patch("pretor.utils.logger.global_logger.bind") as mock_bind: with patch("pretor.utils.logger.global_logger.bind") as mock_bind:
mock_logger = MagicMock() mock_logger = MagicMock()
mock_bind.return_value = mock_logger mock_bind.return_value = mock_logger
await gsm.add_provider_wrap("magic", "title", "url", "key", "1") await gsm.add_provider_wrap("magic", "title", "url", "key", 1)
mock_logger.warning.assert_called_with("Provider type magic is not supported.") mock_logger.warning.assert_called_with("Provider type magic is not supported.")
@ -134,29 +129,27 @@ async def test_add_provider_unsupported(gsm):
async def test_add_provider_request_error(gsm): async def test_add_provider_request_error(gsm):
from httpx import RequestError from httpx import RequestError
mock_provider_class = AsyncMock() mock_provider_class = AsyncMock()
mock_provider_class.create_provider.side_effect = RequestError("Network Error", request=MagicMock()) mock_provider_class.create_model.side_effect = RequestError("Network Error", request=MagicMock())
gsm._global_provider_manager.provider_mapper = {"openai": mock_provider_class} gsm._global_provider_manager.provider_mapper = {"openai": mock_provider_class}
with patch("pretor.utils.logger.global_logger.bind") as mock_bind: with patch("pretor.utils.logger.global_logger.bind") as mock_bind:
from pretor.utils.error import RetryableError
import pytest
mock_logger = MagicMock() mock_logger = MagicMock()
mock_bind.return_value = mock_logger mock_bind.return_value = mock_logger
with pytest.raises(RetryableError): await gsm.add_provider_wrap("openai", "title", "url", "key", 1)
await gsm.add_provider_wrap("openai", "title", "url", "key", "1") mock_logger.warning.assert_called_once()
mock_logger.warning.assert_called() assert "网络请求异常" in mock_logger.warning.call_args[0][0]
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_add_provider_generic_error(gsm): async def test_add_provider_generic_error(gsm):
mock_provider_class = AsyncMock() mock_provider_class = AsyncMock()
mock_provider_class.create_provider.side_effect = ValueError("Some Error") mock_provider_class.create_model.side_effect = ValueError("Some Error")
gsm._global_provider_manager.provider_mapper = {"openai": mock_provider_class} gsm._global_provider_manager.provider_mapper = {"openai": mock_provider_class}
with patch("pretor.utils.logger.global_logger.bind") as mock_bind: with patch("pretor.utils.logger.global_logger.bind") as mock_bind:
mock_logger = MagicMock() mock_logger = MagicMock()
mock_bind.return_value = mock_logger mock_bind.return_value = mock_logger
await gsm.add_provider_wrap("openai", "title", "url", "key", "1") await gsm.add_provider_wrap("openai", "title", "url", "key", 1)
mock_logger.warning.assert_called_once() mock_logger.warning.assert_called_once()
assert "解析模型列表时发生错误" in mock_logger.warning.call_args[0][0] assert "解析模型列表时发生错误" in mock_logger.warning.call_args[0][0]

View File

@ -9,7 +9,7 @@ def provider_args():
provider_title="TestClaude", provider_title="TestClaude",
provider_url="https://api.anthropic.com", provider_url="https://api.anthropic.com",
provider_apikey="testkey", provider_apikey="testkey",
provider_owner="1" provider_owner=1
) )

View File

@ -9,7 +9,7 @@ def provider_args():
provider_title="TestGemini", provider_title="TestGemini",
provider_url="https://generativelanguage.googleapis.com", provider_url="https://generativelanguage.googleapis.com",
provider_apikey="testkey", provider_apikey="testkey",
provider_owner="1" provider_owner=1
) )

View File

@ -9,7 +9,7 @@ def provider_args():
provider_title="TestOpenAI", provider_title="TestOpenAI",
provider_url="https://api.openai.com/v1", provider_url="https://api.openai.com/v1",
provider_apikey="testkey", provider_apikey="testkey",
provider_owner="1" provider_owner=1
) )
@ -19,7 +19,7 @@ def provider_args_no_v1():
provider_title="TestOpenAI", provider_title="TestOpenAI",
provider_url="https://api.openai.com", provider_url="https://api.openai.com",
provider_apikey="testkey", provider_apikey="testkey",
provider_owner="1" provider_owner=1
) )
@ -85,10 +85,8 @@ async def test_load_models_request_error(mock_client, provider_args):
mock_client_instance.get.side_effect = httpx.RequestError("network error", request=MagicMock()) mock_client_instance.get.side_effect = httpx.RequestError("network error", request=MagicMock())
mock_client.return_value.__aenter__.return_value = mock_client_instance mock_client.return_value.__aenter__.return_value = mock_client_instance
import pytest models = await OpenAIProvider._load_models(provider_args)
from pretor.utils.error import RetryableError assert models == []
with pytest.raises(RetryableError):
await OpenAIProvider._load_models(provider_args)
@pytest.mark.asyncio @pytest.mark.asyncio

View File

@ -13,8 +13,7 @@ async def test_provider_manager_init():
mock_provider2 = MagicMock() mock_provider2 = MagicMock()
mock_provider2.provider_title = "title2" mock_provider2.provider_title = "title2"
mock_postgres.get_provider = MagicMock() mock_postgres.get_providers.remote = AsyncMock(return_value=[mock_provider1, mock_provider2])
mock_postgres.get_provider.remote = AsyncMock(return_value=[mock_provider1, mock_provider2])
manager = ProviderManager(mock_postgres) manager = ProviderManager(mock_postgres)
mock_postgres.provider_database = MagicMock() mock_postgres.provider_database = MagicMock()

View File

@ -12,12 +12,8 @@ def mock_import(name, globals=None, locals=None, fromlist=(), level=0):
if name == 'ray': if name == 'ray':
mock_ray = MagicMock() mock_ray = MagicMock()
def mock_remote(*args, **kwargs): def mock_remote(cls):
if len(args) == 1 and callable(args[0]): return cls
return args[0]
def decorator(cls):
return cls
return decorator
mock_ray.remote = mock_remote mock_ray.remote = mock_remote
return mock_ray return mock_ray