Compare commits
3 Commits
bfdb0db933
...
e6bf9e2ce4
| Author | SHA1 | Date |
|---|---|---|
|
|
e6bf9e2ce4 | |
|
|
0b08288656 | |
|
|
5f4754718f |
|
|
@ -1,9 +1,9 @@
|
||||||
import axios from 'axios';
|
import axios from 'axios';
|
||||||
|
|
||||||
// The base URL should typically come from an environment variable in a real app,
|
// The base URL should typically come from an environment variable in a real app.
|
||||||
// but for development we can default to localhost.
|
// If missing, defaulting to '' means requests will be relative to the current browser origin.
|
||||||
export const apiClient = axios.create({
|
export const apiClient = axios.create({
|
||||||
baseURL: import.meta.env.VITE_API_BASE_URL || 'http://localhost:8000',
|
baseURL: import.meta.env.VITE_API_BASE_URL || '',
|
||||||
timeout: 10000,
|
timeout: 10000,
|
||||||
headers: {
|
headers: {
|
||||||
'Content-Type': 'application/json',
|
'Content-Type': 'application/json',
|
||||||
|
|
|
||||||
|
|
@ -1,285 +1,331 @@
|
||||||
import { useState, useEffect } from 'react';
|
import { useState, useEffect } from 'react';
|
||||||
import apiClient from '../../api/client';
|
import apiClient from '../../api/client';
|
||||||
import { Bot, Save } from 'lucide-react';
|
import { Save, Plus, Edit2, Trash2, X } from 'lucide-react';
|
||||||
import type { Provider } from '../../types';
|
import type { Provider } from '../../types';
|
||||||
|
|
||||||
function WorkerIndividualForm({ providers }: { providers: Provider[] }) {
|
interface WorkerIndividual {
|
||||||
const [formData, setFormData] = useState({
|
agent_id: string;
|
||||||
agent_name: '',
|
agent_name: string;
|
||||||
agent_type: 'OrdinaryIndividual',
|
agent_type: string;
|
||||||
description: '',
|
description?: string;
|
||||||
provider_title: providers.length > 0 ? providers[0].provider_title : '',
|
provider_title: string;
|
||||||
model_id: '',
|
model_id: string;
|
||||||
system_prompt: '',
|
system_prompt?: string;
|
||||||
output_template: '{}',
|
output_template?: string; // Change to string for the form state
|
||||||
bound_skill: '{}',
|
bound_skill?: string; // Change to string for the form state
|
||||||
workspace: '[]'
|
workspace?: string; // Change to string for the form state
|
||||||
});
|
|
||||||
const [loading, setLoading] = useState(false);
|
|
||||||
const [message, setMessage] = useState('');
|
|
||||||
|
|
||||||
// Update initial provider_title when providers load
|
|
||||||
useEffect(() => {
|
|
||||||
if (providers.length > 0 && !formData.provider_title) {
|
|
||||||
setFormData(prev => ({ ...prev, provider_title: providers[0].provider_title }));
|
|
||||||
}
|
|
||||||
}, [providers, formData.provider_title]);
|
|
||||||
|
|
||||||
const handleChange = (e: React.ChangeEvent<HTMLInputElement | HTMLSelectElement | HTMLTextAreaElement>) => {
|
|
||||||
setFormData({ ...formData, [e.target.name]: e.target.value });
|
|
||||||
};
|
|
||||||
|
|
||||||
const handleSubmit = async (e: React.FormEvent) => {
|
|
||||||
e.preventDefault();
|
|
||||||
setLoading(true);
|
|
||||||
setMessage('');
|
|
||||||
try {
|
|
||||||
const payload = {
|
|
||||||
...formData,
|
|
||||||
output_template: JSON.parse(formData.output_template),
|
|
||||||
bound_skill: JSON.parse(formData.bound_skill),
|
|
||||||
workspace: JSON.parse(formData.workspace)
|
|
||||||
};
|
|
||||||
await apiClient.post('/api/v1/agent/worker', payload);
|
|
||||||
setMessage('Successfully created worker individual');
|
|
||||||
setFormData({
|
|
||||||
agent_name: '',
|
|
||||||
agent_type: 'OrdinaryIndividual',
|
|
||||||
description: '',
|
|
||||||
provider_title: '',
|
|
||||||
model_id: '',
|
|
||||||
system_prompt: '',
|
|
||||||
output_template: '{}',
|
|
||||||
bound_skill: '{}',
|
|
||||||
workspace: '[]'
|
|
||||||
});
|
|
||||||
} catch (err: any) {
|
|
||||||
console.error(err);
|
|
||||||
setMessage(err.response?.data?.detail || 'Failed to create worker individual. Ensure JSON fields are valid.');
|
|
||||||
} finally {
|
|
||||||
setLoading(false);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
return (
|
|
||||||
<form onSubmit={handleSubmit} className="space-y-4">
|
|
||||||
<div className="grid grid-cols-1 md:grid-cols-2 gap-4">
|
|
||||||
<div>
|
|
||||||
<label className="block text-sm font-medium text-slate-700 mb-1">Agent Name</label>
|
|
||||||
<input required type="text" name="agent_name" value={formData.agent_name} onChange={handleChange} className="w-full px-4 py-2 border border-slate-200 rounded-lg focus:outline-none focus:ring-2 focus:ring-indigo-500" />
|
|
||||||
</div>
|
|
||||||
<div>
|
|
||||||
<label className="block text-sm font-medium text-slate-700 mb-1">Agent Type</label>
|
|
||||||
<select name="agent_type" value={formData.agent_type} onChange={handleChange} className="w-full px-4 py-2 border border-slate-200 rounded-lg focus:outline-none focus:ring-2 focus:ring-indigo-500">
|
|
||||||
<option value="OrdinaryIndividual">OrdinaryIndividual</option>
|
|
||||||
<option value="SkillIndividual">SkillIndividual</option>
|
|
||||||
<option value="SpecialIndividual">SpecialIndividual</option>
|
|
||||||
</select>
|
|
||||||
</div>
|
|
||||||
<div>
|
|
||||||
<label className="block text-sm font-medium text-slate-700 mb-1">Description</label>
|
|
||||||
<input required type="text" name="description" value={formData.description} onChange={handleChange} className="w-full px-4 py-2 border border-slate-200 rounded-lg focus:outline-none focus:ring-2 focus:ring-indigo-500" />
|
|
||||||
</div>
|
|
||||||
<div>
|
|
||||||
<label className="block text-sm font-medium text-slate-700 mb-1">Provider Title</label>
|
|
||||||
<select
|
|
||||||
required
|
|
||||||
name="provider_title"
|
|
||||||
value={formData.provider_title}
|
|
||||||
onChange={handleChange}
|
|
||||||
className="w-full px-4 py-2 border border-slate-200 rounded-lg focus:outline-none focus:ring-2 focus:ring-indigo-500"
|
|
||||||
>
|
|
||||||
{providers.length === 0 ? (
|
|
||||||
<option value="" disabled>No providers available. Create one first.</option>
|
|
||||||
) : (
|
|
||||||
providers.map((p) => (
|
|
||||||
<option key={p.provider_title} value={p.provider_title}>
|
|
||||||
{p.provider_title}
|
|
||||||
</option>
|
|
||||||
))
|
|
||||||
)}
|
|
||||||
</select>
|
|
||||||
</div>
|
|
||||||
<div>
|
|
||||||
<label className="block text-sm font-medium text-slate-700 mb-1">Model ID</label>
|
|
||||||
<input required type="text" name="model_id" value={formData.model_id} onChange={handleChange} className="w-full px-4 py-2 border border-slate-200 rounded-lg focus:outline-none focus:ring-2 focus:ring-indigo-500" />
|
|
||||||
</div>
|
|
||||||
<div className="md:col-span-2">
|
|
||||||
<label className="block text-sm font-medium text-slate-700 mb-1">System Prompt</label>
|
|
||||||
<textarea required name="system_prompt" value={formData.system_prompt} onChange={handleChange} rows={3} className="w-full px-4 py-2 border border-slate-200 rounded-lg focus:outline-none focus:ring-2 focus:ring-indigo-500" />
|
|
||||||
</div>
|
|
||||||
<div>
|
|
||||||
<label className="block text-sm font-medium text-slate-700 mb-1">Output Template (JSON dict)</label>
|
|
||||||
<input required type="text" name="output_template" value={formData.output_template} onChange={handleChange} className="w-full px-4 py-2 border border-slate-200 rounded-lg focus:outline-none focus:ring-2 focus:ring-indigo-500 font-mono text-sm" />
|
|
||||||
</div>
|
|
||||||
<div>
|
|
||||||
<label className="block text-sm font-medium text-slate-700 mb-1">Bound Skill (JSON dict)</label>
|
|
||||||
<input required type="text" name="bound_skill" value={formData.bound_skill} onChange={handleChange} className="w-full px-4 py-2 border border-slate-200 rounded-lg focus:outline-none focus:ring-2 focus:ring-indigo-500 font-mono text-sm" />
|
|
||||||
</div>
|
|
||||||
<div>
|
|
||||||
<label className="block text-sm font-medium text-slate-700 mb-1">Workspace (JSON list)</label>
|
|
||||||
<input required type="text" name="workspace" value={formData.workspace} onChange={handleChange} className="w-full px-4 py-2 border border-slate-200 rounded-lg focus:outline-none focus:ring-2 focus:ring-indigo-500 font-mono text-sm" />
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
{message && (
|
|
||||||
<div className={`p-3 rounded-lg text-sm ${message.includes('Success') ? 'bg-green-50 text-green-700' : 'bg-red-50 text-red-700'}`}>
|
|
||||||
{message}
|
|
||||||
</div>
|
|
||||||
)}
|
|
||||||
|
|
||||||
<div className="flex justify-end pt-2">
|
|
||||||
<button type="submit" disabled={loading} className="flex items-center space-x-2 px-4 py-2 bg-indigo-600 text-white rounded-lg hover:bg-indigo-700 transition-colors disabled:opacity-50">
|
|
||||||
<Save size={16} />
|
|
||||||
<span>{loading ? 'Creating...' : 'Create Worker'}</span>
|
|
||||||
</button>
|
|
||||||
</div>
|
|
||||||
</form>
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
export function WorkerIndividualSettings() {
|
export function WorkerIndividualSettings() {
|
||||||
const [nodeType, setNodeType] = useState('supervisory_node');
|
|
||||||
const [providerTitle, setProviderTitle] = useState('');
|
|
||||||
const [modelId, setModelId] = useState('');
|
|
||||||
const [loading, setLoading] = useState(false);
|
|
||||||
const [message, setMessage] = useState('');
|
|
||||||
const [providers, setProviders] = useState<Provider[]>([]);
|
const [providers, setProviders] = useState<Provider[]>([]);
|
||||||
|
const [workers, setWorkers] = useState<WorkerIndividual[]>([]);
|
||||||
|
const [loading, setLoading] = useState(true);
|
||||||
|
const [error, setError] = useState('');
|
||||||
|
|
||||||
useEffect(() => {
|
const [isEditing, setIsEditing] = useState(false);
|
||||||
const fetchProviders = async () => {
|
const [editData, setEditData] = useState<Partial<WorkerIndividual>>({});
|
||||||
try {
|
const [isNew, setIsNew] = useState(false);
|
||||||
const response = await apiClient.get('/api/v1/provider/list');
|
|
||||||
const data = response.data.provider_list || {};
|
|
||||||
const providerArray: Provider[] = Object.values(data);
|
|
||||||
setProviders(providerArray);
|
|
||||||
if (providerArray.length > 0) {
|
|
||||||
setProviderTitle(providerArray[0].provider_title);
|
|
||||||
}
|
|
||||||
} catch (error) {
|
|
||||||
console.error("Failed to fetch providers", error);
|
|
||||||
setProviders([]);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
fetchProviders();
|
|
||||||
}, []);
|
|
||||||
|
|
||||||
const handleCreateNode = async (e: React.FormEvent) => {
|
const [modalMessage, setModalMessage] = useState('');
|
||||||
e.preventDefault();
|
|
||||||
|
const fetchData = async () => {
|
||||||
setLoading(true);
|
setLoading(true);
|
||||||
setMessage('');
|
|
||||||
try {
|
try {
|
||||||
await apiClient.post('/api/v1/agent', {
|
const [provRes, workRes] = await Promise.all([
|
||||||
provider_title: providerTitle,
|
apiClient.get('/api/v1/provider/list'),
|
||||||
model_id: modelId,
|
apiClient.get('/api/v1/agent/worker')
|
||||||
individual_name: nodeType
|
]);
|
||||||
});
|
setProviders(Object.values(provRes.data.provider_list || {}));
|
||||||
setMessage(`Successfully loaded ${nodeType}`);
|
setWorkers(workRes.data.workers || []);
|
||||||
setProviderTitle('');
|
|
||||||
setModelId('');
|
|
||||||
} catch (err: any) {
|
} catch (err: any) {
|
||||||
console.error(err);
|
console.error(err);
|
||||||
setMessage(err.response?.data?.detail || 'Failed to load agent node');
|
setError('Failed to load data');
|
||||||
} finally {
|
} finally {
|
||||||
setLoading(false);
|
setLoading(false);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
fetchData();
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
const handleEdit = (worker: any) => { // Accept the backend object which might have objects instead of strings
|
||||||
|
setEditData({
|
||||||
|
...worker,
|
||||||
|
output_template: typeof worker.output_template === 'string' ? worker.output_template : JSON.stringify(worker.output_template || {}),
|
||||||
|
bound_skill: typeof worker.bound_skill === 'string' ? worker.bound_skill : JSON.stringify(worker.bound_skill || {}),
|
||||||
|
workspace: typeof worker.workspace === 'string' ? worker.workspace : JSON.stringify(worker.workspace || [])
|
||||||
|
});
|
||||||
|
setIsNew(false);
|
||||||
|
setIsEditing(true);
|
||||||
|
setModalMessage('');
|
||||||
|
};
|
||||||
|
|
||||||
|
const handleAddNew = () => {
|
||||||
|
setEditData({
|
||||||
|
agent_name: '',
|
||||||
|
agent_type: 'OrdinaryIndividual',
|
||||||
|
description: '',
|
||||||
|
provider_title: providers.length > 0 ? providers[0].provider_title : '',
|
||||||
|
model_id: '',
|
||||||
|
system_prompt: '',
|
||||||
|
output_template: '{}',
|
||||||
|
bound_skill: '{}',
|
||||||
|
workspace: '[]'
|
||||||
|
});
|
||||||
|
setIsNew(true);
|
||||||
|
setIsEditing(true);
|
||||||
|
setModalMessage('');
|
||||||
|
};
|
||||||
|
|
||||||
|
const handleDelete = async (agent_id: string) => {
|
||||||
|
if (!confirm('Are you sure you want to delete this agent?')) return;
|
||||||
|
try {
|
||||||
|
await apiClient.delete(`/api/v1/agent/worker/${agent_id}`);
|
||||||
|
fetchData();
|
||||||
|
} catch (err: any) {
|
||||||
|
console.error(err);
|
||||||
|
alert('Failed to delete agent');
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
const handleModalSave = async (e: React.FormEvent) => {
|
||||||
|
e.preventDefault();
|
||||||
|
setModalMessage('');
|
||||||
|
try {
|
||||||
|
const payload = {
|
||||||
|
...editData,
|
||||||
|
output_template: JSON.parse(editData.output_template || '{}'),
|
||||||
|
bound_skill: JSON.parse(editData.bound_skill || '{}'),
|
||||||
|
workspace: JSON.parse(editData.workspace || '[]')
|
||||||
|
};
|
||||||
|
|
||||||
|
if (isNew) {
|
||||||
|
await apiClient.post('/api/v1/agent/worker', payload);
|
||||||
|
} else {
|
||||||
|
await apiClient.put(`/api/v1/agent/worker/${editData.agent_id}`, payload);
|
||||||
|
}
|
||||||
|
|
||||||
|
setIsEditing(false);
|
||||||
|
fetchData();
|
||||||
|
} catch (err: any) {
|
||||||
|
console.error(err);
|
||||||
|
setModalMessage(err.response?.data?.detail || err.message || 'Failed to save');
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div className="max-w-4xl space-y-6">
|
<div className="max-w-5xl space-y-6 relative">
|
||||||
<div className="mb-8">
|
<div className="mb-8 flex justify-between items-end">
|
||||||
<h1 className="text-2xl font-bold text-slate-800">Worker Individual Settings</h1>
|
<div>
|
||||||
<p className="text-slate-500 mt-1">Configure your system agents and custom workers.</p>
|
<h1 className="text-2xl font-bold text-slate-800">Worker Individuals</h1>
|
||||||
|
<p className="text-slate-500 mt-1">Manage all system nodes and custom workers.</p>
|
||||||
|
</div>
|
||||||
|
<button
|
||||||
|
onClick={handleAddNew}
|
||||||
|
className="flex items-center px-4 py-2 bg-indigo-600 text-white rounded-lg hover:bg-indigo-700 transition-colors"
|
||||||
|
>
|
||||||
|
<Plus size={16} className="mr-2" />
|
||||||
|
Add Worker
|
||||||
|
</button>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
{error && <div className="text-red-600">{error}</div>}
|
||||||
|
|
||||||
<div className="bg-white rounded-xl shadow-sm border border-slate-200 overflow-hidden">
|
<div className="bg-white rounded-xl shadow-sm border border-slate-200 overflow-hidden">
|
||||||
<div className="p-6 border-b border-slate-100 flex items-center justify-between">
|
<div className="p-0">
|
||||||
<div className="flex items-center space-x-3">
|
{loading ? (
|
||||||
<div className="w-10 h-10 bg-indigo-50 text-indigo-600 rounded-lg flex items-center justify-center">
|
<div className="p-6 text-slate-500">Loading...</div>
|
||||||
<Bot size={20} />
|
) : workers.length === 0 ? (
|
||||||
</div>
|
<div className="p-6 text-slate-500">No workers found.</div>
|
||||||
<div>
|
) : (
|
||||||
<h2 className="text-lg font-semibold text-slate-800">System Nodes</h2>
|
<table className="w-full text-left border-collapse">
|
||||||
<p className="text-sm text-slate-500">Initialize core system agents</p>
|
<thead>
|
||||||
</div>
|
<tr className="bg-slate-50 border-b border-slate-200 text-slate-600 text-sm">
|
||||||
</div>
|
<th className="p-4 font-semibold">Name</th>
|
||||||
|
<th className="p-4 font-semibold">Type</th>
|
||||||
|
<th className="p-4 font-semibold">Provider / Model ID</th>
|
||||||
|
<th className="p-4 font-semibold text-right">Actions</th>
|
||||||
|
</tr>
|
||||||
|
</thead>
|
||||||
|
<tbody>
|
||||||
|
{workers.map((w) => (
|
||||||
|
<tr key={w.agent_id} className="border-b border-slate-100 hover:bg-slate-50 transition-colors">
|
||||||
|
<td className="p-4 font-medium text-slate-800">{w.agent_name}</td>
|
||||||
|
<td className="p-4 text-slate-600">
|
||||||
|
<span className="px-2 py-1 bg-slate-100 rounded text-xs">{w.agent_type}</span>
|
||||||
|
</td>
|
||||||
|
<td className="p-4 text-slate-600 text-sm">
|
||||||
|
{w.provider_title} <span className="text-slate-400">/</span> {w.model_id}
|
||||||
|
</td>
|
||||||
|
<td className="p-4 text-right space-x-2">
|
||||||
|
<button onClick={() => handleEdit(w)} className="p-2 text-indigo-600 hover:bg-indigo-50 rounded-lg transition-colors" title="Edit">
|
||||||
|
<Edit2 size={16} />
|
||||||
|
</button>
|
||||||
|
<button onClick={() => handleDelete(w.agent_id)} className="p-2 text-red-600 hover:bg-red-50 rounded-lg transition-colors" title="Delete">
|
||||||
|
<Trash2 size={16} />
|
||||||
|
</button>
|
||||||
|
</td>
|
||||||
|
</tr>
|
||||||
|
))}
|
||||||
|
</tbody>
|
||||||
|
</table>
|
||||||
|
)}
|
||||||
</div>
|
</div>
|
||||||
<div className="p-6">
|
</div>
|
||||||
<form onSubmit={handleCreateNode} className="space-y-4">
|
|
||||||
<div className="grid grid-cols-1 md:grid-cols-3 gap-4">
|
|
||||||
<div>
|
|
||||||
<label className="block text-sm font-medium text-slate-700 mb-1">Node Type</label>
|
|
||||||
<select
|
|
||||||
value={nodeType}
|
|
||||||
onChange={(e) => setNodeType(e.target.value)}
|
|
||||||
className="w-full px-4 py-2 border border-slate-200 rounded-lg focus:outline-none focus:ring-2 focus:ring-indigo-500"
|
|
||||||
>
|
|
||||||
<option value="supervisory_node">Supervisory Node</option>
|
|
||||||
<option value="consciousness_node">Consciousness Node</option>
|
|
||||||
<option value="control_node">Control Node</option>
|
|
||||||
</select>
|
|
||||||
</div>
|
|
||||||
<div>
|
|
||||||
<label className="block text-sm font-medium text-slate-700 mb-1">Provider Title</label>
|
|
||||||
<select
|
|
||||||
value={providerTitle}
|
|
||||||
onChange={(e) => setProviderTitle(e.target.value)}
|
|
||||||
required
|
|
||||||
className="w-full px-4 py-2 border border-slate-200 rounded-lg focus:outline-none focus:ring-2 focus:ring-indigo-500"
|
|
||||||
>
|
|
||||||
{providers.length === 0 ? (
|
|
||||||
<option value="" disabled>No providers available. Create one first.</option>
|
|
||||||
) : (
|
|
||||||
providers.map((p) => (
|
|
||||||
<option key={p.provider_title} value={p.provider_title}>
|
|
||||||
{p.provider_title}
|
|
||||||
</option>
|
|
||||||
))
|
|
||||||
)}
|
|
||||||
</select>
|
|
||||||
</div>
|
|
||||||
<div>
|
|
||||||
<label className="block text-sm font-medium text-slate-700 mb-1">Model ID</label>
|
|
||||||
<input
|
|
||||||
type="text"
|
|
||||||
value={modelId}
|
|
||||||
onChange={(e) => setModelId(e.target.value)}
|
|
||||||
placeholder="e.g. gpt-4"
|
|
||||||
required
|
|
||||||
className="w-full px-4 py-2 border border-slate-200 rounded-lg focus:outline-none focus:ring-2 focus:ring-indigo-500"
|
|
||||||
/>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
{message && (
|
{/* Edit/Create Modal */}
|
||||||
<div className={`p-3 rounded-lg text-sm ${message.includes('Success') ? 'bg-green-50 text-green-700' : 'bg-red-50 text-red-700'}`}>
|
{isEditing && (
|
||||||
{message}
|
<div className="fixed inset-0 bg-black/50 z-50 flex items-center justify-center p-4">
|
||||||
</div>
|
<div className="bg-white rounded-xl shadow-xl w-full max-w-2xl max-h-[90vh] overflow-y-auto">
|
||||||
)}
|
<div className="flex justify-between items-center p-6 border-b border-slate-100 sticky top-0 bg-white z-10">
|
||||||
|
<h2 className="text-xl font-bold text-slate-800">{isNew ? 'Create Worker' : 'Edit Worker'}</h2>
|
||||||
<div className="flex justify-end">
|
<button onClick={() => setIsEditing(false)} className="text-slate-400 hover:text-slate-600">
|
||||||
<button
|
<X size={24} />
|
||||||
type="submit"
|
|
||||||
disabled={loading}
|
|
||||||
className="flex items-center space-x-2 px-4 py-2 bg-indigo-600 text-white rounded-lg hover:bg-indigo-700 transition-colors disabled:opacity-50"
|
|
||||||
>
|
|
||||||
<Save size={16} />
|
|
||||||
<span>{loading ? 'Saving...' : 'Load Node'}</span>
|
|
||||||
</button>
|
</button>
|
||||||
</div>
|
</div>
|
||||||
</form>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<div className="bg-white rounded-xl shadow-sm border border-slate-200 overflow-hidden">
|
<form onSubmit={handleModalSave} className="p-6 space-y-4">
|
||||||
<div className="p-6 border-b border-slate-100">
|
<div className="grid grid-cols-1 md:grid-cols-2 gap-4">
|
||||||
<h2 className="text-lg font-semibold text-slate-800">Create Worker Individual</h2>
|
<div>
|
||||||
<p className="text-sm text-slate-500">Add a new custom worker to the system.</p>
|
<label className="block text-sm font-medium text-slate-700 mb-1">Agent Name</label>
|
||||||
|
<input
|
||||||
|
type="text"
|
||||||
|
required
|
||||||
|
value={editData.agent_name || ''}
|
||||||
|
onChange={(e) => setEditData({...editData, agent_name: e.target.value})}
|
||||||
|
className="w-full px-4 py-2 border border-slate-200 rounded-lg focus:ring-2 focus:ring-indigo-500"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
<div>
|
||||||
|
<label className="block text-sm font-medium text-slate-700 mb-1">Agent Type</label>
|
||||||
|
<select
|
||||||
|
value={editData.agent_type || 'OrdinaryIndividual'}
|
||||||
|
onChange={(e) => setEditData({...editData, agent_type: e.target.value})}
|
||||||
|
className="w-full px-4 py-2 border border-slate-200 rounded-lg focus:ring-2 focus:ring-indigo-500"
|
||||||
|
>
|
||||||
|
<option value="supervisory_node">Supervisory Node</option>
|
||||||
|
<option value="consciousness_node">Consciousness Node</option>
|
||||||
|
<option value="control_node">Control Node</option>
|
||||||
|
<option value="OrdinaryIndividual">Ordinary Individual</option>
|
||||||
|
<option value="SkillIndividual">Skill Individual</option>
|
||||||
|
<option value="SpecialIndividual">Special Individual</option>
|
||||||
|
</select>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div className="grid grid-cols-1 md:grid-cols-2 gap-4">
|
||||||
|
<div>
|
||||||
|
<label className="block text-sm font-medium text-slate-700 mb-1">Provider Title</label>
|
||||||
|
<select
|
||||||
|
value={editData.provider_title || ''}
|
||||||
|
onChange={(e) => setEditData({...editData, provider_title: e.target.value, model_id: ''})}
|
||||||
|
required
|
||||||
|
className="w-full px-4 py-2 border border-slate-200 rounded-lg focus:ring-2 focus:ring-indigo-500"
|
||||||
|
>
|
||||||
|
<option value="" disabled>Select Provider</option>
|
||||||
|
{providers.map((p) => (
|
||||||
|
<option key={p.provider_title} value={p.provider_title}>{p.provider_title}</option>
|
||||||
|
))}
|
||||||
|
</select>
|
||||||
|
</div>
|
||||||
|
<div>
|
||||||
|
<label className="block text-sm font-medium text-slate-700 mb-1">Model ID</label>
|
||||||
|
{(() => {
|
||||||
|
const selectedProvider = providers.find(p => p.provider_title === editData.provider_title);
|
||||||
|
const models = selectedProvider?.provider_models || [];
|
||||||
|
return (
|
||||||
|
<select
|
||||||
|
value={editData.model_id || ''}
|
||||||
|
onChange={(e) => setEditData({...editData, model_id: e.target.value})}
|
||||||
|
required
|
||||||
|
className="w-full px-4 py-2 border border-slate-200 rounded-lg focus:ring-2 focus:ring-indigo-500"
|
||||||
|
>
|
||||||
|
<option value="" disabled>Select a model</option>
|
||||||
|
{models.map(m => <option key={m} value={m}>{m}</option>)}
|
||||||
|
</select>
|
||||||
|
);
|
||||||
|
})()}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div>
|
||||||
|
<label className="block text-sm font-medium text-slate-700 mb-1">Description</label>
|
||||||
|
<textarea
|
||||||
|
value={editData.description || ''}
|
||||||
|
onChange={(e) => setEditData({...editData, description: e.target.value})}
|
||||||
|
rows={2}
|
||||||
|
className="w-full px-4 py-2 border border-slate-200 rounded-lg focus:ring-2 focus:ring-indigo-500"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div>
|
||||||
|
<label className="block text-sm font-medium text-slate-700 mb-1">System Prompt</label>
|
||||||
|
<textarea
|
||||||
|
value={editData.system_prompt || ''}
|
||||||
|
onChange={(e) => setEditData({...editData, system_prompt: e.target.value})}
|
||||||
|
rows={3}
|
||||||
|
className="w-full px-4 py-2 border border-slate-200 rounded-lg focus:ring-2 focus:ring-indigo-500 font-mono text-sm"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div className="grid grid-cols-1 md:grid-cols-2 gap-4">
|
||||||
|
<div>
|
||||||
|
<label className="block text-sm font-medium text-slate-700 mb-1">Output Template (JSON)</label>
|
||||||
|
<textarea
|
||||||
|
value={editData.output_template || '{}'}
|
||||||
|
onChange={(e) => setEditData({...editData, output_template: e.target.value})}
|
||||||
|
rows={3}
|
||||||
|
className="w-full px-4 py-2 border border-slate-200 rounded-lg focus:ring-2 focus:ring-indigo-500 font-mono text-sm"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
<div>
|
||||||
|
<label className="block text-sm font-medium text-slate-700 mb-1">Bound Skill (JSON)</label>
|
||||||
|
<textarea
|
||||||
|
value={editData.bound_skill || '{}'}
|
||||||
|
onChange={(e) => setEditData({...editData, bound_skill: e.target.value})}
|
||||||
|
rows={3}
|
||||||
|
className="w-full px-4 py-2 border border-slate-200 rounded-lg focus:ring-2 focus:ring-indigo-500 font-mono text-sm"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div>
|
||||||
|
<label className="block text-sm font-medium text-slate-700 mb-1">Workspace (JSON Array)</label>
|
||||||
|
<textarea
|
||||||
|
value={editData.workspace || '[]'}
|
||||||
|
onChange={(e) => setEditData({...editData, workspace: e.target.value})}
|
||||||
|
rows={2}
|
||||||
|
className="w-full px-4 py-2 border border-slate-200 rounded-lg focus:ring-2 focus:ring-indigo-500 font-mono text-sm"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{modalMessage && (
|
||||||
|
<div className="p-3 bg-red-50 text-red-700 text-sm rounded-lg">
|
||||||
|
{modalMessage}
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
|
||||||
|
<div className="pt-4 flex justify-end space-x-3 border-t border-slate-100">
|
||||||
|
<button
|
||||||
|
type="button"
|
||||||
|
onClick={() => setIsEditing(false)}
|
||||||
|
className="px-4 py-2 text-slate-600 hover:bg-slate-100 rounded-lg transition-colors"
|
||||||
|
>
|
||||||
|
Cancel
|
||||||
|
</button>
|
||||||
|
<button
|
||||||
|
type="submit"
|
||||||
|
className="flex items-center px-4 py-2 bg-indigo-600 text-white rounded-lg hover:bg-indigo-700 transition-colors"
|
||||||
|
>
|
||||||
|
<Save size={16} className="mr-2" />
|
||||||
|
Save Worker
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
</form>
|
||||||
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div className="p-6">
|
)}
|
||||||
<WorkerIndividualForm providers={providers} />
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
</div>
|
</div>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -16,33 +16,57 @@ export function RightPanel({ selectedWorkflow }: RightPanelProps) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const wsBase = import.meta.env.VITE_API_BASE_URL
|
let ws: WebSocket | null = null;
|
||||||
? import.meta.env.VITE_API_BASE_URL.replace('http', 'ws')
|
let reconnectTimeout: ReturnType<typeof setTimeout>;
|
||||||
: `ws://localhost:8000`;
|
let retryCount = 0;
|
||||||
|
const maxRetryCount = 10;
|
||||||
|
const baseDelay = 1000;
|
||||||
|
|
||||||
// Using the workflow router WS endpoint
|
const connect = () => {
|
||||||
const ws = new WebSocket(`${wsBase}/api/v1/workflow/ws/${selectedWorkflow}`);
|
const protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
|
||||||
|
const host = window.location.host;
|
||||||
|
|
||||||
ws.onopen = () => {
|
const wsBase = import.meta.env.VITE_API_BASE_URL
|
||||||
setIsConnected(true);
|
? import.meta.env.VITE_API_BASE_URL.replace(/^http/, 'ws')
|
||||||
|
: `${protocol}//${host}`;
|
||||||
|
|
||||||
setMessages([]); // clear previous traces
|
// Using the workflow router WS endpoint
|
||||||
|
ws = new WebSocket(`${wsBase}/api/v1/workflow/ws/${selectedWorkflow}`);
|
||||||
|
|
||||||
|
ws.onopen = () => {
|
||||||
|
setIsConnected(true);
|
||||||
|
retryCount = 0; // reset on successful connection
|
||||||
|
setMessages([]); // clear previous traces
|
||||||
|
};
|
||||||
|
|
||||||
|
ws.onmessage = (event) => {
|
||||||
|
try {
|
||||||
|
setMessages(prev => [...prev, event.data]);
|
||||||
|
} catch (e) {
|
||||||
|
console.error("Error receiving workflow websocket message", e);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
ws.onclose = () => {
|
||||||
|
setIsConnected(false);
|
||||||
|
if (retryCount < maxRetryCount) {
|
||||||
|
const delay = baseDelay * Math.pow(2, retryCount);
|
||||||
|
retryCount++;
|
||||||
|
console.log(`WebSocket closed. Reconnecting in ${delay}ms... (Attempt ${retryCount})`);
|
||||||
|
reconnectTimeout = setTimeout(connect, delay);
|
||||||
|
} else {
|
||||||
|
console.error("Max WebSocket reconnect attempts reached.");
|
||||||
|
}
|
||||||
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
ws.onmessage = (event) => {
|
connect();
|
||||||
try {
|
|
||||||
setMessages(prev => [...prev, event.data]);
|
|
||||||
} catch (e) {
|
|
||||||
console.error("Error receiving workflow websocket message", e);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
ws.onclose = () => {
|
|
||||||
setIsConnected(false);
|
|
||||||
};
|
|
||||||
|
|
||||||
return () => {
|
return () => {
|
||||||
ws.close();
|
clearTimeout(reconnectTimeout);
|
||||||
|
if (ws) {
|
||||||
|
ws.close();
|
||||||
|
}
|
||||||
};
|
};
|
||||||
}, [selectedWorkflow]);
|
}, [selectedWorkflow]);
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -15,7 +15,14 @@ export function SkillSettings() {
|
||||||
setLoading(true);
|
setLoading(true);
|
||||||
try {
|
try {
|
||||||
const response = await apiClient.get('/api/v1/resource/skill');
|
const response = await apiClient.get('/api/v1/resource/skill');
|
||||||
setSkills(response.data.skills || []);
|
const skillsData = response.data.skills || {};
|
||||||
|
// skillsData might be an object mapping skill names to their details, or it might be an array in some versions.
|
||||||
|
// We ensure it is an array of strings (skill names)
|
||||||
|
if (Array.isArray(skillsData)) {
|
||||||
|
setSkills(skillsData);
|
||||||
|
} else {
|
||||||
|
setSkills(Object.keys(skillsData));
|
||||||
|
}
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
console.error('Failed to fetch skills:', err);
|
console.error('Failed to fetch skills:', err);
|
||||||
} finally {
|
} finally {
|
||||||
|
|
|
||||||
|
|
@ -2,19 +2,27 @@ import { useState, useEffect } from 'react';
|
||||||
import apiClient from '../../api/client';
|
import apiClient from '../../api/client';
|
||||||
import { FileCode, Trash2, Plus, LayoutTemplate } from 'lucide-react';
|
import { FileCode, Trash2, Plus, LayoutTemplate } from 'lucide-react';
|
||||||
|
|
||||||
interface WorkflowTemplate {
|
import type { WorkflowTemplate as ParsedWorkflowTemplate } from '../../types';
|
||||||
name: string;
|
|
||||||
[key: string]: any;
|
|
||||||
}
|
|
||||||
|
|
||||||
export function WorkflowTemplateSettings() {
|
export function WorkflowTemplateSettings() {
|
||||||
const [templates, setTemplates] = useState<Record<string, WorkflowTemplate>>({});
|
const [templates, setTemplates] = useState<Record<string, ParsedWorkflowTemplate>>({});
|
||||||
const [loading, setLoading] = useState(true);
|
const [loading, setLoading] = useState(true);
|
||||||
const [templateJson, setTemplateJson] = useState('{\n "name": "my_template"\n}');
|
const [templateJson, setTemplateJson] = useState('{\n "name": "my_template",\n "steps": [\n {\n "name": "step1",\n "actor": "actor_name"\n }\n ]\n}');
|
||||||
const [creating, setCreating] = useState(false);
|
const [creating, setCreating] = useState(false);
|
||||||
const [message, setMessage] = useState('');
|
const [message, setMessage] = useState('');
|
||||||
const [error, setError] = useState('');
|
const [error, setError] = useState('');
|
||||||
|
|
||||||
|
const validateTemplate = (data: any): data is ParsedWorkflowTemplate => {
|
||||||
|
if (!data || typeof data !== 'object') return false;
|
||||||
|
if (typeof data.name !== 'string') return false;
|
||||||
|
if (!Array.isArray(data.steps)) return false;
|
||||||
|
for (const step of data.steps) {
|
||||||
|
if (typeof step.name !== 'string') return false;
|
||||||
|
if (typeof step.actor !== 'string') return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
};
|
||||||
|
|
||||||
const fetchTemplates = async () => {
|
const fetchTemplates = async () => {
|
||||||
setLoading(true);
|
setLoading(true);
|
||||||
try {
|
try {
|
||||||
|
|
@ -39,16 +47,21 @@ export function WorkflowTemplateSettings() {
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const parsedJson = JSON.parse(templateJson);
|
const parsedJson = JSON.parse(templateJson);
|
||||||
|
|
||||||
|
if (!validateTemplate(parsedJson)) {
|
||||||
|
throw new Error('JSON structure does not match WorkflowTemplate schema (requires name and steps array with name and actor).');
|
||||||
|
}
|
||||||
|
|
||||||
await apiClient.post('/api/v1/resource/workflow_template', parsedJson);
|
await apiClient.post('/api/v1/resource/workflow_template', parsedJson);
|
||||||
setMessage('Workflow template created successfully');
|
setMessage('Workflow template created successfully');
|
||||||
setTemplateJson('{\n "name": "my_template"\n}');
|
setTemplateJson('{\n "name": "my_template",\n "steps": []\n}');
|
||||||
fetchTemplates();
|
fetchTemplates();
|
||||||
} catch (err: any) {
|
} catch (err: any) {
|
||||||
console.error(err);
|
console.error(err);
|
||||||
if (err instanceof SyntaxError) {
|
if (err instanceof SyntaxError) {
|
||||||
setError('Invalid JSON format');
|
setError('Invalid JSON format');
|
||||||
} else {
|
} else {
|
||||||
setError(err.response?.data?.message || 'Failed to create workflow template');
|
setError(err.message || err.response?.data?.message || 'Failed to create workflow template');
|
||||||
}
|
}
|
||||||
} finally {
|
} finally {
|
||||||
setCreating(false);
|
setCreating(false);
|
||||||
|
|
|
||||||
|
|
@ -6,34 +6,59 @@ export function useClusterState() {
|
||||||
const [isConnected, setIsConnected] = useState(false);
|
const [isConnected, setIsConnected] = useState(false);
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
// Determine WS URL based on API base URL or window location
|
let ws: WebSocket | null = null;
|
||||||
const wsBase = import.meta.env.VITE_API_BASE_URL
|
let reconnectTimeout: ReturnType<typeof setTimeout>;
|
||||||
? import.meta.env.VITE_API_BASE_URL.replace('http', 'ws')
|
let retryCount = 0;
|
||||||
: `ws://localhost:8000`;
|
const maxRetryCount = 10;
|
||||||
|
const baseDelay = 1000;
|
||||||
|
|
||||||
const ws = new WebSocket(`${wsBase}/api/v1/cluster/ws/state`);
|
const connect = () => {
|
||||||
|
// Determine WS URL based on API base URL or window location
|
||||||
|
const protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
|
||||||
|
const host = window.location.host;
|
||||||
|
|
||||||
ws.onopen = () => {
|
const wsBase = import.meta.env.VITE_API_BASE_URL
|
||||||
setIsConnected(true);
|
? import.meta.env.VITE_API_BASE_URL.replace(/^http/, 'ws')
|
||||||
};
|
: `${protocol}//${host}`;
|
||||||
|
|
||||||
ws.onmessage = (event) => {
|
ws = new WebSocket(`${wsBase}/api/v1/cluster/ws/state`);
|
||||||
try {
|
|
||||||
const data = JSON.parse(event.data);
|
ws.onopen = () => {
|
||||||
if (Array.isArray(data)) {
|
setIsConnected(true);
|
||||||
setNodes(data);
|
retryCount = 0; // Reset retry count on successful connection
|
||||||
|
};
|
||||||
|
|
||||||
|
ws.onmessage = (event) => {
|
||||||
|
try {
|
||||||
|
const data = JSON.parse(event.data);
|
||||||
|
if (Array.isArray(data)) {
|
||||||
|
setNodes(data);
|
||||||
|
}
|
||||||
|
} catch (e) {
|
||||||
|
console.error("Error parsing cluster state websocket message", e);
|
||||||
}
|
}
|
||||||
} catch (e) {
|
};
|
||||||
console.error("Error parsing cluster state websocket message", e);
|
|
||||||
}
|
ws.onclose = () => {
|
||||||
|
setIsConnected(false);
|
||||||
|
if (retryCount < maxRetryCount) {
|
||||||
|
const delay = baseDelay * Math.pow(2, retryCount);
|
||||||
|
retryCount++;
|
||||||
|
console.log(`WebSocket closed. Reconnecting in ${delay}ms... (Attempt ${retryCount})`);
|
||||||
|
reconnectTimeout = setTimeout(connect, delay);
|
||||||
|
} else {
|
||||||
|
console.error("Max WebSocket reconnect attempts reached.");
|
||||||
|
}
|
||||||
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
ws.onclose = () => {
|
connect();
|
||||||
setIsConnected(false);
|
|
||||||
};
|
|
||||||
|
|
||||||
return () => {
|
return () => {
|
||||||
ws.close();
|
clearTimeout(reconnectTimeout);
|
||||||
|
if (ws) {
|
||||||
|
ws.close();
|
||||||
|
}
|
||||||
};
|
};
|
||||||
}, []);
|
}, []);
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -18,6 +18,7 @@ export interface Provider {
|
||||||
provider_title: string;
|
provider_title: string;
|
||||||
provider_url?: string;
|
provider_url?: string;
|
||||||
provider_owner?: string;
|
provider_owner?: string;
|
||||||
|
provider_models?: string[];
|
||||||
// Based on your UI needs we might infer some local status fields
|
// Based on your UI needs we might infer some local status fields
|
||||||
status?: string;
|
status?: string;
|
||||||
model?: string;
|
model?: string;
|
||||||
|
|
@ -65,3 +66,18 @@ export interface Workflow {
|
||||||
workflow_title: string;
|
workflow_title: string;
|
||||||
status?: string;
|
status?: string;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Workflow Template Validation
|
||||||
|
export interface WorkStep {
|
||||||
|
name: string;
|
||||||
|
desc?: string;
|
||||||
|
actor: string; // the name of the worker individual
|
||||||
|
inputs?: string[];
|
||||||
|
outputs?: string[];
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface WorkflowTemplate {
|
||||||
|
name: string;
|
||||||
|
desc?: string;
|
||||||
|
steps: WorkStep[];
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -53,7 +53,7 @@ class AgentFactory:
|
||||||
if provider.provider_type not in self._models_mapping:
|
if provider.provider_type not in self._models_mapping:
|
||||||
raise ValueError(f"不支持的协议类型: {provider.provider_type}")
|
raise ValueError(f"不支持的协议类型: {provider.provider_type}")
|
||||||
model_class, provider_class = self._models_mapping[provider.provider_type]
|
model_class, provider_class = self._models_mapping[provider.provider_type]
|
||||||
model = model_class(model_id, provider_class(api_key=provider.api_key, url=provider.url))
|
model = model_class(model_id, provider=provider_class(api_key=provider.provider_apikey, base_url=provider.provider_url))
|
||||||
agent = Agent(model=model,
|
agent = Agent(model=model,
|
||||||
name=agent_name,
|
name=agent_name,
|
||||||
system_prompt=system_prompt,
|
system_prompt=system_prompt,
|
||||||
|
|
|
||||||
|
|
@ -43,18 +43,21 @@ async def load_agent(agent_register: Union[AgentRegister, AgentLocalRegister],
|
||||||
pass
|
pass
|
||||||
|
|
||||||
elif isinstance(agent_register, AgentRegister):
|
elif isinstance(agent_register, AgentRegister):
|
||||||
match agent_register.individual_name:
|
try:
|
||||||
case "supervisory_node":
|
match agent_register.individual_name:
|
||||||
node = ray_actor_hook("supervisory_node").supervisory_node
|
case "supervisory_node":
|
||||||
node.create_agent.remote(global_state_machine,agent_register.provider_title,agent_register.model_id)
|
node = ray_actor_hook("supervisory_node").supervisory_node
|
||||||
case "consciousness_node":
|
await node.create_agent.remote(global_state_machine,agent_register.provider_title,agent_register.model_id)
|
||||||
node = ray_actor_hook("consciousness_node").consciousness_node
|
case "consciousness_node":
|
||||||
node.create_agent.remote(global_state_machine,agent_register.provider_title,agent_register.model_id)
|
node = ray_actor_hook("consciousness_node").consciousness_node
|
||||||
case "control_node":
|
await node.create_agent.remote(global_state_machine,agent_register.provider_title,agent_register.model_id)
|
||||||
node = ray_actor_hook("control_node").control_node
|
case "control_node":
|
||||||
node.create_agent.remote(global_state_machine,agent_register.provider_title,agent_register.model_id)
|
node = ray_actor_hook("control_node").control_node
|
||||||
case _:
|
await node.create_agent.remote(global_state_machine,agent_register.provider_title,agent_register.model_id)
|
||||||
pass
|
case _:
|
||||||
|
pass
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=500, detail=f"加载节点失败: {str(e)}")
|
||||||
return {"message": "创建成功"}
|
return {"message": "创建成功"}
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -85,25 +88,25 @@ class WorkerIndividualUpdate(BaseModel):
|
||||||
@agent_router.post("/worker")
|
@agent_router.post("/worker")
|
||||||
async def create_worker_individual(worker_data: WorkerIndividualCreate,
|
async def create_worker_individual(worker_data: WorkerIndividualCreate,
|
||||||
token_data: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER))):
|
token_data: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER))):
|
||||||
postgres_database = ray_actor_hook("postgres_database")
|
postgres_database = ray_actor_hook("postgres_database").postgres_database
|
||||||
data_dict = worker_data.model_dump()
|
data_dict = worker_data.model_dump()
|
||||||
data_dict["owner_id"] = token_data.user_id
|
data_dict["owner_id"] = token_data.user_id
|
||||||
worker = await postgres_database.individual_database.remote("add_worker_individual", **data_dict)
|
worker = await postgres_database.add_worker_individual.remote( **data_dict)
|
||||||
return {"message": "success", "agent_id": worker.agent_id}
|
return {"message": "success", "agent_id": worker.agent_id}
|
||||||
|
|
||||||
|
|
||||||
@agent_router.get("/worker")
|
@agent_router.get("/worker")
|
||||||
async def get_worker_individual_list(token_data: TokenData = Depends(Accessor.get_current_user)):
|
async def get_worker_individual_list(token_data: TokenData = Depends(Accessor.get_current_user)):
|
||||||
postgres_database = ray_actor_hook("postgres_database")
|
postgres_database = ray_actor_hook("postgres_database").postgres_database
|
||||||
workers = await postgres_database.individual_database.remote("get_worker_individual_list", owner_id=token_data.user_id)
|
workers = await postgres_database.get_worker_individual_list.remote( owner_id=token_data.user_id)
|
||||||
return {"workers": workers}
|
return {"workers": workers}
|
||||||
|
|
||||||
|
|
||||||
@agent_router.get("/worker/{agent_id}")
|
@agent_router.get("/worker/{agent_id}")
|
||||||
async def get_worker_individual(agent_id: str,
|
async def get_worker_individual(agent_id: str,
|
||||||
token_data: TokenData = Depends(Accessor.get_current_user)):
|
token_data: TokenData = Depends(Accessor.get_current_user)):
|
||||||
postgres_database = ray_actor_hook("postgres_database")
|
postgres_database = ray_actor_hook("postgres_database").postgres_database
|
||||||
worker = await postgres_database.individual_database.remote("get_worker_individual", agent_id=agent_id)
|
worker = await postgres_database.get_worker_individual.remote( agent_id=agent_id)
|
||||||
if not worker:
|
if not worker:
|
||||||
raise HTTPException(status_code=404, detail="Agent not found")
|
raise HTTPException(status_code=404, detail="Agent not found")
|
||||||
if worker.owner_id != token_data.user_id:
|
if worker.owner_id != token_data.user_id:
|
||||||
|
|
@ -115,26 +118,26 @@ async def get_worker_individual(agent_id: str,
|
||||||
async def update_worker_individual(agent_id: str,
|
async def update_worker_individual(agent_id: str,
|
||||||
worker_data: WorkerIndividualUpdate,
|
worker_data: WorkerIndividualUpdate,
|
||||||
token_data: TokenData = Depends(Accessor.get_current_user)):
|
token_data: TokenData = Depends(Accessor.get_current_user)):
|
||||||
postgres_database = ray_actor_hook("postgres_database")
|
postgres_database = ray_actor_hook("postgres_database").postgres_database
|
||||||
worker = await postgres_database.individual_database.remote("get_worker_individual", agent_id=agent_id)
|
worker = await postgres_database.get_worker_individual.remote( agent_id=agent_id)
|
||||||
if not worker:
|
if not worker:
|
||||||
raise HTTPException(status_code=404, detail="Agent not found")
|
raise HTTPException(status_code=404, detail="Agent not found")
|
||||||
if worker.owner_id != token_data.user_id:
|
if worker.owner_id != token_data.user_id:
|
||||||
raise HTTPException(status_code=403, detail="Forbidden: You do not own this agent")
|
raise HTTPException(status_code=403, detail="Forbidden: You do not own this agent")
|
||||||
|
|
||||||
update_data = worker_data.model_dump(exclude_unset=True)
|
update_data = worker_data.model_dump(exclude_unset=True)
|
||||||
updated_worker = await postgres_database.individual_database.remote("update_worker_individual", agent_id=agent_id, **update_data)
|
updated_worker = await postgres_database.update_worker_individual.remote( agent_id=agent_id, **update_data)
|
||||||
return {"message": "success", "worker": updated_worker}
|
return {"message": "success", "worker": updated_worker}
|
||||||
|
|
||||||
|
|
||||||
@agent_router.delete("/worker/{agent_id}")
|
@agent_router.delete("/worker/{agent_id}")
|
||||||
async def delete_worker_individual(agent_id: str,
|
async def delete_worker_individual(agent_id: str,
|
||||||
token_data: TokenData = Depends(Accessor.get_current_user)):
|
token_data: TokenData = Depends(Accessor.get_current_user)):
|
||||||
postgres_database = ray_actor_hook("postgres_database")
|
postgres_database = ray_actor_hook("postgres_database").postgres_database
|
||||||
worker = await postgres_database.individual_database.remote("get_worker_individual", agent_id=agent_id)
|
worker = await postgres_database.get_worker_individual.remote( agent_id=agent_id)
|
||||||
if not worker:
|
if not worker:
|
||||||
raise HTTPException(status_code=404, detail="Agent not found")
|
raise HTTPException(status_code=404, detail="Agent not found")
|
||||||
if worker.owner_id != token_data.user_id:
|
if worker.owner_id != token_data.user_id:
|
||||||
raise HTTPException(status_code=403, detail="Forbidden: You do not own this agent")
|
raise HTTPException(status_code=403, detail="Forbidden: You do not own this agent")
|
||||||
await postgres_database.individual_database.remote("delete_worker_individual", agent_id=agent_id)
|
await postgres_database.delete_worker_individual.remote( agent_id=agent_id)
|
||||||
return {"message": "success"}
|
return {"message": "success"}
|
||||||
|
|
@ -32,7 +32,7 @@ class UserRegister(BaseModel):
|
||||||
async def create_user(user_register: UserRegister):
|
async def create_user(user_register: UserRegister):
|
||||||
postgres_database = ray_actor_hook("postgres_database").postgres_database
|
postgres_database = ray_actor_hook("postgres_database").postgres_database
|
||||||
hashed_password = await run_in_threadpool(Accessor.hash_password, user_register.password)
|
hashed_password = await run_in_threadpool(Accessor.hash_password, user_register.password)
|
||||||
user = await postgres_database.auth_database.remote("add_user", user_register.user_name, hashed_password)
|
user = await postgres_database.add_user.remote( user_register.user_name, hashed_password)
|
||||||
return {"message": "success", "user_id": user.user_id}
|
return {"message": "success", "user_id": user.user_id}
|
||||||
|
|
||||||
class UserLogin(BaseModel):
|
class UserLogin(BaseModel):
|
||||||
|
|
@ -42,7 +42,7 @@ class UserLogin(BaseModel):
|
||||||
@auth_router.post("/login")
|
@auth_router.post("/login")
|
||||||
async def login_user(user_login: UserLogin):
|
async def login_user(user_login: UserLogin):
|
||||||
postgres_database = ray_actor_hook("postgres_database").postgres_database
|
postgres_database = ray_actor_hook("postgres_database").postgres_database
|
||||||
user = await postgres_database.auth_database.remote("login_user", user_login.user_name)
|
user = await postgres_database.login_user.remote( user_login.user_name)
|
||||||
if not user:
|
if not user:
|
||||||
raise UserNotExistError()
|
raise UserNotExistError()
|
||||||
token = await run_in_threadpool(Accessor.login_hashed_password, user, user_login.password)
|
token = await run_in_threadpool(Accessor.login_hashed_password, user, user_login.password)
|
||||||
|
|
@ -61,7 +61,7 @@ async def change_authority(
|
||||||
Update a user's authority level. Only accessible by SUPER_ADMINISTRATOR.
|
Update a user's authority level. Only accessible by SUPER_ADMINISTRATOR.
|
||||||
"""
|
"""
|
||||||
postgres_database = ray_actor_hook("postgres_database").postgres_database
|
postgres_database = ray_actor_hook("postgres_database").postgres_database
|
||||||
user = await postgres_database.auth_database.remote("change_user_authority", user_id=request.user_id, new_authority=request.new_authority)
|
user = await postgres_database.change_user_authority.remote( user_id=request.user_id, new_authority=request.new_authority)
|
||||||
return {"message": "success", "user_id": user.user_id, "new_authority": user.user_authority}
|
return {"message": "success", "user_id": user.user_id, "new_authority": user.user_authority}
|
||||||
|
|
||||||
@auth_router.get("/list")
|
@auth_router.get("/list")
|
||||||
|
|
@ -72,7 +72,7 @@ async def get_user_list(
|
||||||
Get a list of all users. Only accessible by SUPER_ADMINISTRATOR.
|
Get a list of all users. Only accessible by SUPER_ADMINISTRATOR.
|
||||||
"""
|
"""
|
||||||
postgres_database = ray_actor_hook("postgres_database").postgres_database
|
postgres_database = ray_actor_hook("postgres_database").postgres_database
|
||||||
users = await postgres_database.auth_database.remote("get_all_users")
|
users = await postgres_database.get_all_users.remote()
|
||||||
return {"users": [{"user_id": u.user_id, "user_name": u.user_name, "role": u.user_authority} for u in users]}
|
return {"users": [{"user_id": u.user_id, "user_name": u.user_name, "role": u.user_authority} for u in users]}
|
||||||
|
|
||||||
@auth_router.delete("/{user_id}")
|
@auth_router.delete("/{user_id}")
|
||||||
|
|
@ -84,5 +84,5 @@ async def delete_user(
|
||||||
Delete a user. Only accessible by SUPER_ADMINISTRATOR.
|
Delete a user. Only accessible by SUPER_ADMINISTRATOR.
|
||||||
"""
|
"""
|
||||||
postgres_database = ray_actor_hook("postgres_database").postgres_database
|
postgres_database = ray_actor_hook("postgres_database").postgres_database
|
||||||
await postgres_database.auth_database.remote("delete_user_by_id", user_id=user_id)
|
await postgres_database.delete_user_by_id.remote( user_id=user_id)
|
||||||
return {"message": "success"}
|
return {"message": "success"}
|
||||||
|
|
@ -36,5 +36,10 @@ async def update_cluster_state(websocket: WebSocket):
|
||||||
]
|
]
|
||||||
await websocket.send_json(payload)
|
await websocket.send_json(payload)
|
||||||
await asyncio.sleep(10)
|
await asyncio.sleep(10)
|
||||||
except (WebSocketDisconnect, RuntimeError):
|
except WebSocketDisconnect:
|
||||||
|
pass
|
||||||
|
except RuntimeError as e:
|
||||||
|
if "closed" not in str(e) and "GeneratorExit" not in str(e):
|
||||||
|
raise
|
||||||
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
@ -34,7 +34,7 @@ async def create_message(message: Message,
|
||||||
user_id=str(token_data.user_id),
|
user_id=str(token_data.user_id),
|
||||||
user_name=token_data.username,
|
user_name=token_data.username,
|
||||||
message=message.message)
|
message=message.message)
|
||||||
supervisory_node = ray_actor_hook("supervisor_node")
|
supervisory_node = ray_actor_hook("supervisory_node").supervisory_node
|
||||||
message = await supervisory_node.working.remote(event)
|
message = await supervisory_node.working.remote(event)
|
||||||
if message == "任务已创建":
|
if message == "任务已创建":
|
||||||
return {"message": event.trace_id}
|
return {"message": event.trace_id}
|
||||||
|
|
|
||||||
|
|
@ -44,13 +44,12 @@ async def create_provider(provider_register: ProviderRegister,
|
||||||
@provider_router.get("/list")
|
@provider_router.get("/list")
|
||||||
async def get_provider_list(_: TokenData = Depends(Accessor.get_current_user)) -> Dict[str, Dict[str, Provider]]:
|
async def get_provider_list(_: TokenData = Depends(Accessor.get_current_user)) -> Dict[str, Dict[str, Provider]]:
|
||||||
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
|
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
|
||||||
provider_list: Dict[str, Provider] = await global_state_machine.provider_manager.remote("get_provider_list")
|
provider_list: Dict[str, Provider] = await global_state_machine.get_provider_list.remote()
|
||||||
providers = list(provider_list.values()) if provider_list else []
|
|
||||||
return {"provider_list": provider_list}
|
return {"provider_list": provider_list}
|
||||||
|
|
||||||
@provider_router.delete("/{provider_title}")
|
@provider_router.delete("/{provider_title}")
|
||||||
async def delete_provider(provider_title: str, _: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.SUPER_ADMINISTRATOR))) -> dict:
|
async def delete_provider(provider_title: str, _: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.SUPER_ADMINISTRATOR))) -> dict:
|
||||||
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
|
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
|
||||||
postgres_database = ray_actor_hook("postgres_database").postgres_database
|
postgres_database = ray_actor_hook("postgres_database").postgres_database
|
||||||
await global_state_machine.provider_manager.remote("delete_provider", provider_title=provider_title, postgres_database=postgres_database)
|
await global_state_machine.delete_provider.remote( provider_title=provider_title, postgres_database=postgres_database)
|
||||||
return {"message": "success"}
|
return {"message": "success"}
|
||||||
|
|
@ -27,19 +27,19 @@ resource_router = APIRouter(prefix="/api/v1/resource")
|
||||||
async def create_workflow_template(workflow_template: WorkflowTemplate,
|
async def create_workflow_template(workflow_template: WorkflowTemplate,
|
||||||
_: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER))):
|
_: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER))):
|
||||||
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
|
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
|
||||||
await global_state_machine.workflow_template_manager.remote("add_workflow_template", workflow_template.name, workflow_template)
|
await global_state_machine.add_workflow_template.remote( workflow_template.name, workflow_template)
|
||||||
return {"message": "创建成功"}
|
return {"message": "创建成功"}
|
||||||
|
|
||||||
@resource_router.get("/workflow_template")
|
@resource_router.get("/workflow_template")
|
||||||
async def get_workflow_templates(_: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER))):
|
async def get_workflow_templates(_: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER))):
|
||||||
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
|
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
|
||||||
templates = await global_state_machine.workflow_template_manager.remote("get_all_workflow_templates")
|
templates = await global_state_machine.get_all_workflow_templates.remote()
|
||||||
return {"templates": templates}
|
return {"templates": templates}
|
||||||
|
|
||||||
@resource_router.delete("/workflow_template/{template_name}")
|
@resource_router.delete("/workflow_template/{template_name}")
|
||||||
async def delete_workflow_template(template_name: str, _: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.SUPER_ADMINISTRATOR))):
|
async def delete_workflow_template(template_name: str, _: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.SUPER_ADMINISTRATOR))):
|
||||||
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
|
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
|
||||||
await global_state_machine.workflow_template_manager.remote("delete_workflow_template", template_name)
|
await global_state_machine.delete_workflow_template.remote( template_name)
|
||||||
return {"message": "success"}
|
return {"message": "success"}
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -60,18 +60,18 @@ async def install_skill(skill: Skill,
|
||||||
skill_name = skill.path.split("/")[-1]
|
skill_name = skill.path.split("/")[-1]
|
||||||
else:
|
else:
|
||||||
skill_name = skill.repo_url.split("/")[-1]
|
skill_name = skill.repo_url.split("/")[-1]
|
||||||
await global_state_machine.skill_manager.remote("add_skill", skill_name)
|
await global_state_machine.add_skill.remote( skill_name)
|
||||||
return {"message": "创建成功"}
|
return {"message": "创建成功"}
|
||||||
|
|
||||||
@resource_router.get("/skill")
|
@resource_router.get("/skill")
|
||||||
async def get_skills(_: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER))):
|
async def get_skills(_: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER))):
|
||||||
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
|
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
|
||||||
skills = await global_state_machine.skill_manager.remote("get_skill_list")
|
skills = await global_state_machine.get_skill_list.remote()
|
||||||
return {"skills": skills}
|
return {"skills": skills}
|
||||||
|
|
||||||
@resource_router.delete("/skill/{skill_name}")
|
@resource_router.delete("/skill/{skill_name}")
|
||||||
async def delete_skill(skill_name: str, _: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.SUPER_ADMINISTRATOR))):
|
async def delete_skill(skill_name: str, _: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.SUPER_ADMINISTRATOR))):
|
||||||
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
|
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
|
||||||
# Note: this only removes it from the state machine manager.
|
# Note: this only removes it from the state machine manager.
|
||||||
await global_state_machine.skill_manager.remote("remove_skill", skill_name)
|
await global_state_machine.remove_skill.remote( skill_name)
|
||||||
return {"message": "success"}
|
return {"message": "success"}
|
||||||
|
|
|
||||||
|
|
@ -29,6 +29,11 @@ async def get_workflow(websocket: WebSocket, event_id: str):
|
||||||
await websocket.send_text(await global_state_machine.get_pending.remote(event_id))
|
await websocket.send_text(await global_state_machine.get_pending.remote(event_id))
|
||||||
response = await websocket.receive_text()
|
response = await websocket.receive_text()
|
||||||
await global_state_machine.put_received(event_id, response)
|
await global_state_machine.put_received(event_id, response)
|
||||||
except (WebSocketDisconnect, RuntimeError):
|
except WebSocketDisconnect:
|
||||||
|
pass
|
||||||
|
except RuntimeError as e:
|
||||||
|
if "closed" not in str(e) and "GeneratorExit" not in str(e):
|
||||||
|
raise
|
||||||
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -53,17 +53,77 @@ class PostgresDatabase:
|
||||||
finally:
|
finally:
|
||||||
self.ready_event.set()
|
self.ready_event.set()
|
||||||
|
|
||||||
async def auth_database(self, method_name: str, *args, **kwargs):
|
# Auth Database Methods
|
||||||
|
async def add_user(self, user_name: str, hashed_password: str):
|
||||||
await self.ready_event.wait()
|
await self.ready_event.wait()
|
||||||
method = getattr(self._auth_database, method_name)
|
return await self._auth_database.add_user(user_name, hashed_password)
|
||||||
return await method(*args, **kwargs)
|
|
||||||
|
|
||||||
async def provider_database(self, method_name: str, *args, **kwargs):
|
async def change_password(self, user_name, old_password, new_password):
|
||||||
await self.ready_event.wait()
|
await self.ready_event.wait()
|
||||||
method = getattr(self._provider_database, method_name)
|
return await self._auth_database.change_password(user_name, old_password, new_password)
|
||||||
return await method(*args, **kwargs)
|
|
||||||
|
|
||||||
async def individual_database(self, method_name: str, *args, **kwargs):
|
async def delete_user(self, user_name: str):
|
||||||
await self.ready_event.wait()
|
await self.ready_event.wait()
|
||||||
method = getattr(self._individual_database, method_name)
|
return await self._auth_database.delete_user(user_name)
|
||||||
return await method(*args, **kwargs)
|
|
||||||
|
async def delete_user_by_id(self, user_id: str):
|
||||||
|
await self.ready_event.wait()
|
||||||
|
return await self._auth_database.delete_user_by_id(user_id)
|
||||||
|
|
||||||
|
async def login_user(self, user_name: str):
|
||||||
|
await self.ready_event.wait()
|
||||||
|
return await self._auth_database.login_user(user_name)
|
||||||
|
|
||||||
|
async def get_all_users(self):
|
||||||
|
await self.ready_event.wait()
|
||||||
|
return await self._auth_database.get_all_users()
|
||||||
|
|
||||||
|
async def get_user_authority(self, user_id: str):
|
||||||
|
await self.ready_event.wait()
|
||||||
|
return await self._auth_database.get_user_authority(user_id)
|
||||||
|
|
||||||
|
async def change_user_authority(self, user_id: str, new_authority):
|
||||||
|
await self.ready_event.wait()
|
||||||
|
return await self._auth_database.change_user_authority(user_id, new_authority)
|
||||||
|
|
||||||
|
# Provider Database Methods
|
||||||
|
async def get_provider(self):
|
||||||
|
await self.ready_event.wait()
|
||||||
|
return await self._provider_database.get_provider()
|
||||||
|
|
||||||
|
async def add_provider_db(self, **kwargs):
|
||||||
|
await self.ready_event.wait()
|
||||||
|
return await self._provider_database.add_provider(**kwargs)
|
||||||
|
|
||||||
|
async def delete_provider_db(self, provider_id: str):
|
||||||
|
await self.ready_event.wait()
|
||||||
|
return await self._provider_database.delete_provider(provider_id)
|
||||||
|
|
||||||
|
async def update_provider_db(self, provider_id: str, **kwargs):
|
||||||
|
await self.ready_event.wait()
|
||||||
|
return await self._provider_database.update_provider(provider_id, **kwargs)
|
||||||
|
|
||||||
|
# Individual Database Methods
|
||||||
|
async def add_worker_individual(self, **kwargs):
|
||||||
|
await self.ready_event.wait()
|
||||||
|
return await self._individual_database.add_worker_individual(**kwargs)
|
||||||
|
|
||||||
|
async def get_worker_individual(self, agent_id: str):
|
||||||
|
await self.ready_event.wait()
|
||||||
|
return await self._individual_database.get_worker_individual(agent_id)
|
||||||
|
|
||||||
|
async def get_worker_individual_list(self, owner_id: str):
|
||||||
|
await self.ready_event.wait()
|
||||||
|
return await self._individual_database.get_worker_individual_list(owner_id)
|
||||||
|
|
||||||
|
async def update_worker_individual(self, agent_id: str, **kwargs):
|
||||||
|
await self.ready_event.wait()
|
||||||
|
return await self._individual_database.update_worker_individual(agent_id, **kwargs)
|
||||||
|
|
||||||
|
async def delete_worker_individual(self, agent_id: str):
|
||||||
|
await self.ready_event.wait()
|
||||||
|
return await self._individual_database.delete_worker_individual(agent_id)
|
||||||
|
|
||||||
|
async def get_all_worker_individual(self):
|
||||||
|
await self.ready_event.wait()
|
||||||
|
return await self._individual_database.get_all_worker_individual()
|
||||||
|
|
@ -53,35 +53,64 @@ class GlobalStateMachine:
|
||||||
postgres_database=self.postgres_database
|
postgres_database=self.postgres_database
|
||||||
)
|
)
|
||||||
|
|
||||||
async def provider_manager(self, method_name: str, *args, **kwargs):
|
# Provider Manager Methods
|
||||||
method = getattr(self._global_provider_manager, method_name)
|
def get_provider_list(self):
|
||||||
if asyncio.iscoroutinefunction(method):
|
return self._global_provider_manager.get_provider_list()
|
||||||
return await method(*args, **kwargs)
|
|
||||||
return method(*args, **kwargs)
|
|
||||||
|
|
||||||
async def tool_manager(self, method_name: str, *args, **kwargs):
|
def get_provider(self, provider_title):
|
||||||
method = getattr(self._global_tool_manager, method_name)
|
return self._global_provider_manager.get_provider(provider_title)
|
||||||
if asyncio.iscoroutinefunction(method):
|
|
||||||
return await method(*args, **kwargs)
|
|
||||||
return method(*args, **kwargs)
|
|
||||||
|
|
||||||
async def workflow_template_manager(self, method_name: str, *args, **kwargs):
|
async def delete_provider(self, provider_title: str):
|
||||||
method = getattr(self._global_workflow_template_manager, method_name)
|
return await self._global_provider_manager.delete_provider(provider_title, self.postgres_database)
|
||||||
if asyncio.iscoroutinefunction(method):
|
|
||||||
return await method(*args, **kwargs)
|
|
||||||
return method(*args, **kwargs)
|
|
||||||
|
|
||||||
async def skill_manager(self, method_name: str, *args, **kwargs):
|
# Tool Manager Methods
|
||||||
method = getattr(self._global_skill_manager, method_name)
|
def get_tool_mapper(self):
|
||||||
if asyncio.iscoroutinefunction(method):
|
return self._global_tool_manager.tool_mapper
|
||||||
return await method(*args, **kwargs)
|
|
||||||
return method(*args, **kwargs)
|
|
||||||
|
|
||||||
async def individual_manager(self, method_name: str, *args, **kwargs):
|
def get_tool_list(self, agent_name: str):
|
||||||
method = getattr(self._global_individual_manager, method_name)
|
# get_tool_list didn't actually exist on tool_manager, let's implement it to return the tools
|
||||||
if asyncio.iscoroutinefunction(method):
|
# for a specific agent name (or scope)
|
||||||
return await method(*args, **kwargs)
|
tools = self._global_tool_manager.tool_mapper.get(agent_name, {})
|
||||||
return method(*args, **kwargs)
|
# also include default tools
|
||||||
|
default_tools = self._global_tool_manager.tool_mapper.get("default", {})
|
||||||
|
merged_tools = {**default_tools, **tools}
|
||||||
|
return merged_tools
|
||||||
|
|
||||||
|
# Workflow Template Manager Methods
|
||||||
|
def get_all_workflow_templates(self):
|
||||||
|
return self._global_workflow_template_manager.get_all_workflow_templates()
|
||||||
|
|
||||||
|
def add_workflow_template(self, template_name: str, workflow_template):
|
||||||
|
return self._global_workflow_template_manager.add_workflow_template(template_name, workflow_template)
|
||||||
|
|
||||||
|
def delete_workflow_template(self, template_name: str):
|
||||||
|
return self._global_workflow_template_manager.delete_workflow_template(template_name)
|
||||||
|
|
||||||
|
def generate_workflow_template(self, workflow_template):
|
||||||
|
return self._global_workflow_template_manager.generate_workflow_template(workflow_template)
|
||||||
|
|
||||||
|
# Skill Manager Methods
|
||||||
|
def add_skill(self, skill_name: str):
|
||||||
|
return self._global_skill_manager.add_skill(skill_name)
|
||||||
|
|
||||||
|
def get_skill_list(self):
|
||||||
|
return self._global_skill_manager.get_skill_list()
|
||||||
|
|
||||||
|
def remove_skill(self, skill_name: str):
|
||||||
|
return self._global_skill_manager.remove_skill(skill_name)
|
||||||
|
|
||||||
|
# Individual Manager Methods
|
||||||
|
def add_individual(self, agent_id: str, config):
|
||||||
|
return self._global_individual_manager.add_individual(agent_id, config)
|
||||||
|
|
||||||
|
def get_individual(self, agent_id: str):
|
||||||
|
return self._global_individual_manager.get_individual(agent_id)
|
||||||
|
|
||||||
|
def remove_individual(self, agent_id: str):
|
||||||
|
return self._global_individual_manager.remove_individual(agent_id)
|
||||||
|
|
||||||
|
def list_individuals(self):
|
||||||
|
return self._global_individual_manager.list_individuals()
|
||||||
|
|
||||||
###以下方法为event_dict方法
|
###以下方法为event_dict方法
|
||||||
def add_event(self, event: PretorEvent) -> None:
|
def add_event(self, event: PretorEvent) -> None:
|
||||||
|
|
|
||||||
|
|
@ -23,7 +23,7 @@ class GlobalIndividualManager:
|
||||||
async def init_individual_register(self, postgres) -> None:
|
async def init_individual_register(self, postgres) -> None:
|
||||||
try:
|
try:
|
||||||
try:
|
try:
|
||||||
individuals = await postgres.individual_database.remote("get_all_worker_individual")
|
individuals = await postgres.get_all_worker_individual.remote()
|
||||||
for ind in individuals:
|
for ind in individuals:
|
||||||
agent_id = getattr(ind, 'agent_id', None)
|
agent_id = getattr(ind, 'agent_id', None)
|
||||||
if agent_id:
|
if agent_id:
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,4 @@
|
||||||
|
from pretor.utils.retry import retry_on_retryable_error
|
||||||
# Copyright 2026 zhaoxi826
|
# Copyright 2026 zhaoxi826
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
|
@ -24,6 +25,7 @@ class ClaudeProvider(BaseProvider):
|
||||||
return provider
|
return provider
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@retry_on_retryable_error()
|
||||||
async def _load_models(provider_args: ProviderArgs) -> List[str]:
|
async def _load_models(provider_args: ProviderArgs) -> List[str]:
|
||||||
# Anthropic 官方需要 version 头
|
# Anthropic 官方需要 version 头
|
||||||
headers = {
|
headers = {
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,4 @@
|
||||||
|
from pretor.utils.retry import retry_on_retryable_error
|
||||||
# Copyright 2026 zhaoxi826
|
# Copyright 2026 zhaoxi826
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
|
@ -24,6 +25,7 @@ class GeminiProvider(BaseProvider):
|
||||||
return provider
|
return provider
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@retry_on_retryable_error()
|
||||||
async def _load_models(provider_args: ProviderArgs) -> List[str]:
|
async def _load_models(provider_args: ProviderArgs) -> List[str]:
|
||||||
# Google Gemini 原生鉴权通常使用 x-goog-api-key 或 query parameter
|
# Google Gemini 原生鉴权通常使用 x-goog-api-key 或 query parameter
|
||||||
headers = {
|
headers = {
|
||||||
|
|
@ -46,6 +48,10 @@ class GeminiProvider(BaseProvider):
|
||||||
model_ids = [m["name"].split("/")[-1] for m in raw_models if
|
model_ids = [m["name"].split("/")[-1] for m in raw_models if
|
||||||
"generateContent" in m.get("supportedGenerationMethods", [])]
|
"generateContent" in m.get("supportedGenerationMethods", [])]
|
||||||
return sorted(list(set(model_ids)))
|
return sorted(list(set(model_ids)))
|
||||||
|
except httpx.RequestError as e:
|
||||||
|
from pretor.utils.error import RetryableError
|
||||||
|
print(f"[{provider_args.provider_title}] 网络请求异常: {e}")
|
||||||
|
raise RetryableError(f"[{provider_args.provider_title}] 网络请求异常: {e}") from e
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"[{provider_args.provider_title}] 获取 Gemini 模型列表错误: {e}")
|
print(f"[{provider_args.provider_title}] 获取 Gemini 模型列表错误: {e}")
|
||||||
return []
|
return []
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,4 @@
|
||||||
|
from pretor.utils.retry import retry_on_retryable_error
|
||||||
# Copyright 2026 zhaoxi826
|
# Copyright 2026 zhaoxi826
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
|
@ -24,6 +25,7 @@ class OpenAIProvider(BaseProvider):
|
||||||
return provider
|
return provider
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@retry_on_retryable_error()
|
||||||
async def _load_models(provider_args: ProviderArgs) -> List[str]:
|
async def _load_models(provider_args: ProviderArgs) -> List[str]:
|
||||||
headers = {
|
headers = {
|
||||||
"Authorization": f"Bearer {provider_args.provider_apikey}",
|
"Authorization": f"Bearer {provider_args.provider_apikey}",
|
||||||
|
|
@ -41,8 +43,9 @@ class OpenAIProvider(BaseProvider):
|
||||||
model_ids = [m["id"] for m in raw_models]
|
model_ids = [m["id"] for m in raw_models]
|
||||||
return sorted(model_ids)
|
return sorted(model_ids)
|
||||||
except httpx.RequestError as e:
|
except httpx.RequestError as e:
|
||||||
|
from pretor.utils.error import RetryableError
|
||||||
print(f"[{provider_args.provider_title}] 网络请求异常: {e}")
|
print(f"[{provider_args.provider_title}] 网络请求异常: {e}")
|
||||||
return []
|
raise RetryableError(f"[{provider_args.provider_title}] 网络请求异常: {e}") from e
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"[{provider_args.provider_title}] 解析模型列表时发生错误: {e}")
|
print(f"[{provider_args.provider_title}] 解析模型列表时发生错误: {e}")
|
||||||
return []
|
return []
|
||||||
|
|
|
||||||
|
|
@ -33,7 +33,7 @@ class ProviderManager:
|
||||||
self.provider_register = {}
|
self.provider_register = {}
|
||||||
|
|
||||||
async def init_provider_register(self, postgres) -> None:
|
async def init_provider_register(self, postgres) -> None:
|
||||||
providers = await postgres.provider_database.remote("get_provider")
|
providers = await postgres.get_provider.remote()
|
||||||
for provider in providers:
|
for provider in providers:
|
||||||
self.provider_register[provider.provider_title] = provider
|
self.provider_register[provider.provider_title] = provider
|
||||||
|
|
||||||
|
|
@ -48,14 +48,17 @@ class ProviderManager:
|
||||||
provider_apikey=provider_apikey,
|
provider_apikey=provider_apikey,
|
||||||
provider_owner=provider_owner)
|
provider_owner=provider_owner)
|
||||||
try:
|
try:
|
||||||
|
import ulid
|
||||||
provider_class = self.provider_mapper.get(provider_type, None)
|
provider_class = self.provider_mapper.get(provider_type, None)
|
||||||
if provider_class is None:
|
if provider_class is None:
|
||||||
logger.warning(f"Provider type {provider_type} is not supported.")
|
logger.warning(f"Provider type {provider_type} is not supported.")
|
||||||
return None
|
return None
|
||||||
provider: Provider = await provider_class.create_model(provider_args)
|
provider: Provider = await provider_class.create_provider(provider_args)
|
||||||
provider.provider_owner = provider_owner
|
provider.provider_owner = provider_owner
|
||||||
self.provider_register[provider_title] = provider
|
self.provider_register[provider_title] = provider
|
||||||
await postgres_database.provider_database.remote("add_provider", provider_title=provider.provider_title,
|
await postgres_database.add_provider_db.remote(
|
||||||
|
provider_id=str(ulid.ULID()),
|
||||||
|
provider_title=provider.provider_title,
|
||||||
provider_url=provider.provider_url,
|
provider_url=provider.provider_url,
|
||||||
provider_apikey=provider.provider_apikey,
|
provider_apikey=provider.provider_apikey,
|
||||||
provider_models=provider.provider_models,
|
provider_models=provider.provider_models,
|
||||||
|
|
@ -64,7 +67,9 @@ class ProviderManager:
|
||||||
|
|
||||||
logger.info(f"已添加适配器{provider_title}")
|
logger.info(f"已添加适配器{provider_title}")
|
||||||
except httpx.RequestError as e:
|
except httpx.RequestError as e:
|
||||||
|
from pretor.utils.error import RetryableError
|
||||||
logger.warning(f"[{provider_args.provider_title}] 网络请求异常: {e}")
|
logger.warning(f"[{provider_args.provider_title}] 网络请求异常: {e}")
|
||||||
|
raise RetryableError(f"[{provider_args.provider_title}] 网络请求异常: {e}") from e
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"[{provider_args.provider_title}] 解析模型列表时发生错误: {e}")
|
logger.warning(f"[{provider_args.provider_title}] 解析模型列表时发生错误: {e}")
|
||||||
|
|
||||||
|
|
@ -77,5 +82,5 @@ class ProviderManager:
|
||||||
async def delete_provider(self, provider_title: str, postgres_database) -> None:
|
async def delete_provider(self, provider_title: str, postgres_database) -> None:
|
||||||
if provider_title in self.provider_register:
|
if provider_title in self.provider_register:
|
||||||
provider = self.provider_register[provider_title]
|
provider = self.provider_register[provider_title]
|
||||||
await postgres_database.provider_database.remote("delete_provider", provider_id=provider.provider_id)
|
await postgres_database.delete_provider_db.remote( provider_id=provider.provider_id)
|
||||||
del self.provider_register[provider_title]
|
del self.provider_register[provider_title]
|
||||||
|
|
@ -33,7 +33,7 @@ class ConsciousnessNode:
|
||||||
self.agent: None | Agent = None
|
self.agent: None | Agent = None
|
||||||
|
|
||||||
|
|
||||||
def create_agent(self, global_state_machine: GlobalStateMachine, provider_title: str, model_id: str) -> None:
|
async def create_agent(self, global_state_machine: GlobalStateMachine, provider_title: str, model_id: str) -> None:
|
||||||
"""
|
"""
|
||||||
create_agent方法,将agent对象装配到ConsciousnessNode的属性内
|
create_agent方法,将agent对象装配到ConsciousnessNode的属性内
|
||||||
该方法通过provider_title从global_state_machine中获取provider对象,然后从provider对象中取出供应商形象,装配为pydantic_ai的
|
该方法通过provider_title从global_state_machine中获取provider对象,然后从provider对象中取出供应商形象,装配为pydantic_ai的
|
||||||
|
|
@ -57,7 +57,7 @@ class ConsciousnessNode:
|
||||||
"请确保所有的思考和生成过程符合逻辑,严密且高质量。"
|
"请确保所有的思考和生成过程符合逻辑,严密且高质量。"
|
||||||
)
|
)
|
||||||
output_type = Union[ForSupervisoryNode, ForWorkflow, ForWorkflowEngine]
|
output_type = Union[ForSupervisoryNode, ForWorkflow, ForWorkflowEngine]
|
||||||
provider: Provider = global_state_machine.provider_manager.remote("get_provider", provider_title)
|
provider: Provider = await global_state_machine.get_provider.remote( provider_title)
|
||||||
agent_factory = AgentFactory()
|
agent_factory = AgentFactory()
|
||||||
self.agent = agent_factory.create_agent(provider=provider,
|
self.agent = agent_factory.create_agent(provider=provider,
|
||||||
model_id=model_id,
|
model_id=model_id,
|
||||||
|
|
|
||||||
|
|
@ -30,7 +30,7 @@ class ControlNode:
|
||||||
self.agent: Agent | None = None
|
self.agent: Agent | None = None
|
||||||
|
|
||||||
|
|
||||||
def create_agent(self, global_state_machine: GlobalStateMachine, provider_title: str, model_id: str) -> None:
|
async def create_agent(self, global_state_machine: GlobalStateMachine, provider_title: str, model_id: str) -> None:
|
||||||
"""
|
"""
|
||||||
create_agent方法,将agent对象装配到Control的属性内
|
create_agent方法,将agent对象装配到Control的属性内
|
||||||
该方法通过provider_title从global_state_machine中获取provider对象,然后从provider对象中取出供应商形象,装配为pydantic_ai的
|
该方法通过provider_title从global_state_machine中获取provider对象,然后从provider对象中取出供应商形象,装配为pydantic_ai的
|
||||||
|
|
@ -54,7 +54,7 @@ class ControlNode:
|
||||||
"请注意:你的输出应当具体、实用,直接提供任务所要求的结果,不要做过多无关的寒暄。"
|
"请注意:你的输出应当具体、实用,直接提供任务所要求的结果,不要做过多无关的寒暄。"
|
||||||
)
|
)
|
||||||
output_type = ForWorkflow
|
output_type = ForWorkflow
|
||||||
provider: Provider = global_state_machine.provider_manager.remote("get_provider", provider_title)
|
provider: Provider = await global_state_machine.get_provider.remote( provider_title)
|
||||||
agent_factory = AgentFactory()
|
agent_factory = AgentFactory()
|
||||||
self.agent = agent_factory.create_agent(provider=provider,
|
self.agent = agent_factory.create_agent(provider=provider,
|
||||||
model_id=model_id,
|
model_id=model_id,
|
||||||
|
|
|
||||||
|
|
@ -57,7 +57,7 @@ class SupervisoryNode:
|
||||||
"请保持冷静、专业,并严格遵循上述路由规则。"
|
"请保持冷静、专业,并严格遵循上述路由规则。"
|
||||||
)
|
)
|
||||||
output_type = Union[ForConsciousnessNode, ForUser]
|
output_type = Union[ForConsciousnessNode, ForUser]
|
||||||
provider: Provider = await global_state_machine.provider_manager.remote("get_provider", provider_title)
|
provider: Provider = await global_state_machine.get_provider.remote( provider_title)
|
||||||
agent_factory = AgentFactory()
|
agent_factory = AgentFactory()
|
||||||
self.agent = agent_factory.create_agent(provider=provider,
|
self.agent = agent_factory.create_agent(provider=provider,
|
||||||
model_id=model_id,
|
model_id=model_id,
|
||||||
|
|
@ -157,8 +157,8 @@ class SupervisoryNode:
|
||||||
message = payload.message
|
message = payload.message
|
||||||
time_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
time_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||||
try:
|
try:
|
||||||
global_state_machine = ray_actor_hook("global_state_machine")
|
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
|
||||||
workflow_template_dict = await global_state_machine.workflow_template_manager.remote("get_workflow_template_list")
|
workflow_template_dict = await global_state_machine.get_all_workflow_templates.remote()
|
||||||
available_templates_str = "\n".join([f"- 名称: {k}, 描述/内容: {v}" for k, v in
|
available_templates_str = "\n".join([f"- 名称: {k}, 描述/内容: {v}" for k, v in
|
||||||
workflow_template_dict.items()]) if workflow_template_dict else "暂无注册的工作流模板"
|
workflow_template_dict.items()]) if workflow_template_dict else "暂无注册的工作流模板"
|
||||||
deps = SupervisoryNodeDeps(
|
deps = SupervisoryNodeDeps(
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,6 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from functools import lru_cache
|
|
||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
from fastapi import Depends, HTTPException
|
from fastapi import Depends, HTTPException
|
||||||
from pretor.utils.access import Accessor, TokenData
|
from pretor.utils.access import Accessor, TokenData
|
||||||
|
|
@ -19,9 +18,24 @@ from pretor.core.database.table.user import UserAuthority
|
||||||
from pretor.utils.ray_hook import ray_actor_hook
|
from pretor.utils.ray_hook import ray_actor_hook
|
||||||
|
|
||||||
async def get_authority(user_id: str) -> UserAuthority:
|
async def get_authority(user_id: str) -> UserAuthority:
|
||||||
|
from pretor.utils.error import UserNotExistError
|
||||||
postgres_database = ray_actor_hook("postgres_database").postgres_database
|
postgres_database = ray_actor_hook("postgres_database").postgres_database
|
||||||
user_authority = await postgres_database.auth_database.remote("get_user_authority", user_id=user_id)
|
try:
|
||||||
return user_authority
|
user_authority = await postgres_database.get_user_authority.remote(user_id=user_id)
|
||||||
|
return user_authority
|
||||||
|
except UserNotExistError:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=401,
|
||||||
|
detail="用户不存在或已被删除,请重新登录"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
# Check if it's a RayTaskError wrapping UserNotExistError
|
||||||
|
if "UserNotExistError" in str(e):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=401,
|
||||||
|
detail="用户不存在或已被删除,请重新登录"
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
class RoleChecker:
|
class RoleChecker:
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
|
|
|
||||||
|
|
@ -12,7 +12,15 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
class DemandError(Exception):
|
class RetryableError(Exception):
|
||||||
|
"""基类:所有可重试错误(如网络断开、抖动等临时性故障)"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
class NonRetryableError(Exception):
|
||||||
|
"""基类:所有不可重试错误(如数据验证失败、类型错误等业务逻辑故障)"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
class DemandError(NonRetryableError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
class ModelNotExistError(Exception):
|
class ModelNotExistError(Exception):
|
||||||
|
|
|
||||||
|
|
@ -51,7 +51,7 @@ def del_tool_cache(tool_name: str) -> None:
|
||||||
@lru_cache(maxsize=1)
|
@lru_cache(maxsize=1)
|
||||||
async def get_tool(agent_name: str) -> List[Callable]:
|
async def get_tool(agent_name: str) -> List[Callable]:
|
||||||
global_state_machine = ray_actor_hook("global_state_machine")
|
global_state_machine = ray_actor_hook("global_state_machine")
|
||||||
_tool_list = await global_state_machine.tool_manager.remote("get_tool_list", agent_name)
|
_tool_list = await global_state_machine.get_tool_list.remote( agent_name)
|
||||||
tool_list = []
|
tool_list = []
|
||||||
for tool_name in _tool_list.keys():
|
for tool_name in _tool_list.keys():
|
||||||
tool_func = _get_tool_func(tool_name)
|
tool_func = _get_tool_func(tool_name)
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,31 @@
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from functools import wraps
|
||||||
|
from pretor.utils.error import RetryableError
|
||||||
|
|
||||||
|
def retry_on_retryable_error(max_retries=3, base_delay=1):
|
||||||
|
def decorator(func):
|
||||||
|
if asyncio.iscoroutinefunction(func):
|
||||||
|
@wraps(func)
|
||||||
|
async def async_wrapper(*args, **kwargs):
|
||||||
|
for attempt in range(max_retries):
|
||||||
|
try:
|
||||||
|
return await func(*args, **kwargs)
|
||||||
|
except RetryableError:
|
||||||
|
if attempt == max_retries - 1:
|
||||||
|
raise
|
||||||
|
await asyncio.sleep(base_delay * (2 ** attempt))
|
||||||
|
return async_wrapper
|
||||||
|
else:
|
||||||
|
@wraps(func)
|
||||||
|
def sync_wrapper(*args, **kwargs):
|
||||||
|
import time
|
||||||
|
for attempt in range(max_retries):
|
||||||
|
try:
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
except RetryableError:
|
||||||
|
if attempt == max_retries - 1:
|
||||||
|
raise
|
||||||
|
time.sleep(base_delay * (2 ** attempt))
|
||||||
|
return sync_wrapper
|
||||||
|
return decorator
|
||||||
|
|
@ -52,7 +52,7 @@ class WorkerCluster:
|
||||||
return self._active_workers[agent_id]
|
return self._active_workers[agent_id]
|
||||||
|
|
||||||
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
|
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
|
||||||
agent_config = await global_state_machine.individual_manager.remote("get_individual", agent_id)
|
agent_config = await global_state_machine.get_individual.remote( agent_id)
|
||||||
|
|
||||||
if not agent_config:
|
if not agent_config:
|
||||||
raise ValueError(f"无法唤醒 Agent {agent_id}:数据库中不存在该档案")
|
raise ValueError(f"无法唤醒 Agent {agent_id}:数据库中不存在该档案")
|
||||||
|
|
|
||||||
|
|
@ -51,7 +51,7 @@ class BaseIndividual:
|
||||||
provider_title = self.agent_config.get("provider_title", "openai") # default fallback
|
provider_title = self.agent_config.get("provider_title", "openai") # default fallback
|
||||||
model_id = self.agent_config.get("model_id", "gpt-4o") # default fallback
|
model_id = self.agent_config.get("model_id", "gpt-4o") # default fallback
|
||||||
|
|
||||||
provider: Provider = await global_state_machine.provider_manager.remote("get_provider", provider_title)
|
provider: Provider = await global_state_machine.get_provider.remote( provider_title)
|
||||||
agent_factory = AgentFactory()
|
agent_factory = AgentFactory()
|
||||||
self.agent = agent_factory.create_agent(
|
self.agent = agent_factory.create_agent(
|
||||||
provider=provider,
|
provider=provider,
|
||||||
|
|
|
||||||
|
|
@ -8,8 +8,8 @@ def test_create_agent_success_real():
|
||||||
mock_provider = MagicMock()
|
mock_provider = MagicMock()
|
||||||
mock_provider.provider_type = "openai"
|
mock_provider.provider_type = "openai"
|
||||||
mock_provider.provider_models = ["gpt-4"]
|
mock_provider.provider_models = ["gpt-4"]
|
||||||
mock_provider.api_key = "key"
|
mock_provider.provider_apikey = "key"
|
||||||
mock_provider.url = "url"
|
mock_provider.provider_url = "url"
|
||||||
|
|
||||||
with patch("pretor.adapter.model_adapter.agent_factory.Agent") as mock_agent_cls:
|
with patch("pretor.adapter.model_adapter.agent_factory.Agent") as mock_agent_cls:
|
||||||
with patch("pretor.adapter.model_adapter.agent_factory.OpenAIChatModel") as mock_model_cls:
|
with patch("pretor.adapter.model_adapter.agent_factory.OpenAIChatModel") as mock_model_cls:
|
||||||
|
|
@ -23,8 +23,8 @@ def test_create_agent_success_real():
|
||||||
deps_type=dict,
|
deps_type=dict,
|
||||||
agent_name="myagent"
|
agent_name="myagent"
|
||||||
)
|
)
|
||||||
mock_provider_cls.assert_called_once_with(api_key="key", url="url")
|
mock_provider_cls.assert_called_once_with(api_key="key", base_url="url")
|
||||||
mock_model_cls.assert_called_once_with("gpt-4", mock_provider_cls.return_value)
|
mock_model_cls.assert_called_once_with("gpt-4", provider=mock_provider_cls.return_value)
|
||||||
mock_agent_cls.assert_called_once_with(
|
mock_agent_cls.assert_called_once_with(
|
||||||
model=mock_model_cls.return_value,
|
model=mock_model_cls.return_value,
|
||||||
name="myagent",
|
name="myagent",
|
||||||
|
|
|
||||||
|
|
@ -10,8 +10,12 @@ def mock_import(name, globals=None, locals=None, fromlist=(), level=0):
|
||||||
if name == 'ray':
|
if name == 'ray':
|
||||||
mock_ray = MagicMock()
|
mock_ray = MagicMock()
|
||||||
|
|
||||||
def mock_remote(cls):
|
def mock_remote(*args, **kwargs):
|
||||||
return cls
|
if len(args) == 1 and callable(args[0]):
|
||||||
|
return args[0]
|
||||||
|
def decorator(cls):
|
||||||
|
return cls
|
||||||
|
return decorator
|
||||||
|
|
||||||
mock_ray.remote = mock_remote
|
mock_ray.remote = mock_remote
|
||||||
return mock_ray
|
return mock_ray
|
||||||
|
|
@ -66,8 +70,9 @@ async def test_postgres_database(mock_env_get, mock_provider_db, mock_auth_db, m
|
||||||
mock_auth_db.assert_called_once()
|
mock_auth_db.assert_called_once()
|
||||||
mock_provider_db.assert_called_once()
|
mock_provider_db.assert_called_once()
|
||||||
mock_auth_db.return_value.get_user_authority = AsyncMock(return_value="test_auth")
|
mock_auth_db.return_value.get_user_authority = AsyncMock(return_value="test_auth")
|
||||||
assert await db.auth_database("get_user_authority", user_id="123") == "test_auth"
|
|
||||||
|
|
||||||
with patch("pretor.core.database.postgres.SQLModel.metadata.create_all") as mock_create_all:
|
with patch("pretor.core.database.postgres.SQLModel.metadata.create_all") as mock_create_all:
|
||||||
await db.init_db()
|
await db.init_db()
|
||||||
mock_conn.run_sync.assert_called_once_with(mock_create_all)
|
mock_conn.run_sync.assert_called_once_with(mock_create_all)
|
||||||
|
|
||||||
|
assert await db.get_user_authority(user_id="123") == "test_auth"
|
||||||
|
|
|
||||||
|
|
@ -12,8 +12,12 @@ def mock_import(name, globals=None, locals=None, fromlist=(), level=0):
|
||||||
if name == 'ray':
|
if name == 'ray':
|
||||||
mock_ray = MagicMock()
|
mock_ray = MagicMock()
|
||||||
|
|
||||||
def mock_remote(cls):
|
def mock_remote(*args, **kwargs):
|
||||||
return cls
|
if len(args) == 1 and callable(args[0]):
|
||||||
|
return args[0]
|
||||||
|
def decorator(cls):
|
||||||
|
return cls
|
||||||
|
return decorator
|
||||||
|
|
||||||
mock_ray.remote = mock_remote
|
mock_ray.remote = mock_remote
|
||||||
return mock_ray
|
return mock_ray
|
||||||
|
|
@ -100,19 +104,20 @@ async def test_add_provider_success(gsm, mock_postgres):
|
||||||
mock_provider.provider_apikey = "key"
|
mock_provider.provider_apikey = "key"
|
||||||
mock_provider.provider_models = ["model"]
|
mock_provider.provider_models = ["model"]
|
||||||
mock_provider.provider_type = "openai"
|
mock_provider.provider_type = "openai"
|
||||||
mock_provider_class.create_model.return_value = mock_provider
|
mock_provider_class.create_provider.return_value = mock_provider
|
||||||
|
|
||||||
gsm._global_provider_manager.provider_mapper = {"openai": mock_provider_class}
|
gsm._global_provider_manager.provider_mapper = {"openai": mock_provider_class}
|
||||||
gsm._global_provider_manager.provider_register = {}
|
gsm._global_provider_manager.provider_register = {}
|
||||||
|
|
||||||
mock_add_provider = AsyncMock()
|
mock_add_provider = AsyncMock()
|
||||||
mock_postgres.provider_database.remote = mock_add_provider
|
mock_postgres.add_provider_db = MagicMock()
|
||||||
|
mock_postgres.add_provider_db.remote = mock_add_provider
|
||||||
|
|
||||||
await gsm.add_provider_wrap("openai", "title", "url", "key", 1)
|
await gsm.add_provider_wrap("openai", "title", "url", "key", "1")
|
||||||
|
|
||||||
assert gsm._global_provider_manager.provider_register["title"] == mock_provider
|
assert gsm._global_provider_manager.provider_register["title"] == mock_provider
|
||||||
mock_add_provider.assert_called_once()
|
mock_add_provider.assert_called_once()
|
||||||
assert mock_provider.provider_owner == 1
|
assert mock_provider.provider_owner == "1"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
|
@ -121,7 +126,7 @@ async def test_add_provider_unsupported(gsm):
|
||||||
with patch("pretor.utils.logger.global_logger.bind") as mock_bind:
|
with patch("pretor.utils.logger.global_logger.bind") as mock_bind:
|
||||||
mock_logger = MagicMock()
|
mock_logger = MagicMock()
|
||||||
mock_bind.return_value = mock_logger
|
mock_bind.return_value = mock_logger
|
||||||
await gsm.add_provider_wrap("magic", "title", "url", "key", 1)
|
await gsm.add_provider_wrap("magic", "title", "url", "key", "1")
|
||||||
mock_logger.warning.assert_called_with("Provider type magic is not supported.")
|
mock_logger.warning.assert_called_with("Provider type magic is not supported.")
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -129,27 +134,29 @@ async def test_add_provider_unsupported(gsm):
|
||||||
async def test_add_provider_request_error(gsm):
|
async def test_add_provider_request_error(gsm):
|
||||||
from httpx import RequestError
|
from httpx import RequestError
|
||||||
mock_provider_class = AsyncMock()
|
mock_provider_class = AsyncMock()
|
||||||
mock_provider_class.create_model.side_effect = RequestError("Network Error", request=MagicMock())
|
mock_provider_class.create_provider.side_effect = RequestError("Network Error", request=MagicMock())
|
||||||
gsm._global_provider_manager.provider_mapper = {"openai": mock_provider_class}
|
gsm._global_provider_manager.provider_mapper = {"openai": mock_provider_class}
|
||||||
|
|
||||||
with patch("pretor.utils.logger.global_logger.bind") as mock_bind:
|
with patch("pretor.utils.logger.global_logger.bind") as mock_bind:
|
||||||
|
from pretor.utils.error import RetryableError
|
||||||
|
import pytest
|
||||||
mock_logger = MagicMock()
|
mock_logger = MagicMock()
|
||||||
mock_bind.return_value = mock_logger
|
mock_bind.return_value = mock_logger
|
||||||
await gsm.add_provider_wrap("openai", "title", "url", "key", 1)
|
with pytest.raises(RetryableError):
|
||||||
mock_logger.warning.assert_called_once()
|
await gsm.add_provider_wrap("openai", "title", "url", "key", "1")
|
||||||
assert "网络请求异常" in mock_logger.warning.call_args[0][0]
|
mock_logger.warning.assert_called()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_add_provider_generic_error(gsm):
|
async def test_add_provider_generic_error(gsm):
|
||||||
mock_provider_class = AsyncMock()
|
mock_provider_class = AsyncMock()
|
||||||
mock_provider_class.create_model.side_effect = ValueError("Some Error")
|
mock_provider_class.create_provider.side_effect = ValueError("Some Error")
|
||||||
gsm._global_provider_manager.provider_mapper = {"openai": mock_provider_class}
|
gsm._global_provider_manager.provider_mapper = {"openai": mock_provider_class}
|
||||||
|
|
||||||
with patch("pretor.utils.logger.global_logger.bind") as mock_bind:
|
with patch("pretor.utils.logger.global_logger.bind") as mock_bind:
|
||||||
mock_logger = MagicMock()
|
mock_logger = MagicMock()
|
||||||
mock_bind.return_value = mock_logger
|
mock_bind.return_value = mock_logger
|
||||||
await gsm.add_provider_wrap("openai", "title", "url", "key", 1)
|
await gsm.add_provider_wrap("openai", "title", "url", "key", "1")
|
||||||
mock_logger.warning.assert_called_once()
|
mock_logger.warning.assert_called_once()
|
||||||
assert "解析模型列表时发生错误" in mock_logger.warning.call_args[0][0]
|
assert "解析模型列表时发生错误" in mock_logger.warning.call_args[0][0]
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -9,7 +9,7 @@ def provider_args():
|
||||||
provider_title="TestClaude",
|
provider_title="TestClaude",
|
||||||
provider_url="https://api.anthropic.com",
|
provider_url="https://api.anthropic.com",
|
||||||
provider_apikey="testkey",
|
provider_apikey="testkey",
|
||||||
provider_owner=1
|
provider_owner="1"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -9,7 +9,7 @@ def provider_args():
|
||||||
provider_title="TestGemini",
|
provider_title="TestGemini",
|
||||||
provider_url="https://generativelanguage.googleapis.com",
|
provider_url="https://generativelanguage.googleapis.com",
|
||||||
provider_apikey="testkey",
|
provider_apikey="testkey",
|
||||||
provider_owner=1
|
provider_owner="1"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -9,7 +9,7 @@ def provider_args():
|
||||||
provider_title="TestOpenAI",
|
provider_title="TestOpenAI",
|
||||||
provider_url="https://api.openai.com/v1",
|
provider_url="https://api.openai.com/v1",
|
||||||
provider_apikey="testkey",
|
provider_apikey="testkey",
|
||||||
provider_owner=1
|
provider_owner="1"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -19,7 +19,7 @@ def provider_args_no_v1():
|
||||||
provider_title="TestOpenAI",
|
provider_title="TestOpenAI",
|
||||||
provider_url="https://api.openai.com",
|
provider_url="https://api.openai.com",
|
||||||
provider_apikey="testkey",
|
provider_apikey="testkey",
|
||||||
provider_owner=1
|
provider_owner="1"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -85,8 +85,10 @@ async def test_load_models_request_error(mock_client, provider_args):
|
||||||
mock_client_instance.get.side_effect = httpx.RequestError("network error", request=MagicMock())
|
mock_client_instance.get.side_effect = httpx.RequestError("network error", request=MagicMock())
|
||||||
mock_client.return_value.__aenter__.return_value = mock_client_instance
|
mock_client.return_value.__aenter__.return_value = mock_client_instance
|
||||||
|
|
||||||
models = await OpenAIProvider._load_models(provider_args)
|
import pytest
|
||||||
assert models == []
|
from pretor.utils.error import RetryableError
|
||||||
|
with pytest.raises(RetryableError):
|
||||||
|
await OpenAIProvider._load_models(provider_args)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
|
|
||||||
|
|
@ -13,7 +13,8 @@ async def test_provider_manager_init():
|
||||||
mock_provider2 = MagicMock()
|
mock_provider2 = MagicMock()
|
||||||
mock_provider2.provider_title = "title2"
|
mock_provider2.provider_title = "title2"
|
||||||
|
|
||||||
mock_postgres.get_providers.remote = AsyncMock(return_value=[mock_provider1, mock_provider2])
|
mock_postgres.get_provider = MagicMock()
|
||||||
|
mock_postgres.get_provider.remote = AsyncMock(return_value=[mock_provider1, mock_provider2])
|
||||||
|
|
||||||
manager = ProviderManager(mock_postgres)
|
manager = ProviderManager(mock_postgres)
|
||||||
mock_postgres.provider_database = MagicMock()
|
mock_postgres.provider_database = MagicMock()
|
||||||
|
|
|
||||||
|
|
@ -12,8 +12,12 @@ def mock_import(name, globals=None, locals=None, fromlist=(), level=0):
|
||||||
if name == 'ray':
|
if name == 'ray':
|
||||||
mock_ray = MagicMock()
|
mock_ray = MagicMock()
|
||||||
|
|
||||||
def mock_remote(cls):
|
def mock_remote(*args, **kwargs):
|
||||||
return cls
|
if len(args) == 1 and callable(args[0]):
|
||||||
|
return args[0]
|
||||||
|
def decorator(cls):
|
||||||
|
return cls
|
||||||
|
return decorator
|
||||||
|
|
||||||
mock_ray.remote = mock_remote
|
mock_ray.remote = mock_remote
|
||||||
return mock_ray
|
return mock_ray
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue