Skip to content

Commit c8552cf

Browse files
committed
Merge branch 'add/resume-training-checkpoint' of https://github.com/transformerlab/transformerlab-app into add/resume-training-checkpoint
2 parents 319bb06 + 77e4fc1 commit c8552cf

File tree

10 files changed

+905
-729
lines changed

10 files changed

+905
-729
lines changed

src/renderer/components/Experiment/Tasks/DirectoryUpload.tsx

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,10 @@ import { UploadIcon, FileIcon, XIcon } from 'lucide-react';
1313
import * as chatAPI from 'renderer/lib/transformerlab-api-sdk';
1414

1515
interface DirectoryUploadProps {
16-
onUploadComplete?: (uploadedDirPath: string) => void;
16+
onUploadComplete?: (
17+
uploadedDirPath: string,
18+
localStoragePath?: string,
19+
) => void;
1720
onUploadError?: (error: string) => void;
1821
disabled?: boolean;
1922
}
@@ -79,6 +82,7 @@ export default function DirectoryUpload({
7982
if (result.status === 'success') {
8083
const uploadedDirPathResult =
8184
result.data.uploaded_files.dir_files.uploaded_dir;
85+
const localStoragePath = result.local_storage_path;
8286
setUploadedDirPath(uploadedDirPathResult);
8387

8488
// Update uploaded files list
@@ -89,7 +93,7 @@ export default function DirectoryUpload({
8993
}));
9094
setUploadedFiles(filesList);
9195

92-
onUploadComplete(uploadedDirPathResult);
96+
onUploadComplete(uploadedDirPathResult, localStoragePath);
9397
} else {
9498
const errorMessage = result.message || 'Upload failed';
9599
setUploadError(errorMessage);

src/renderer/components/Experiment/Tasks/NewTaskModal.tsx

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ type NewTaskModalProps = {
3939
num_nodes?: number;
4040
setup?: string;
4141
uploaded_dir_path?: string;
42+
local_upload_copy?: string;
4243
}) => void;
4344
isSubmitting?: boolean;
4445
};
@@ -61,6 +62,7 @@ export default function NewTaskModal({
6162
const [numNodes, setNumNodes] = React.useState('');
6263
const [setup, setSetup] = React.useState('');
6364
const [uploadedDirPath, setUploadedDirPath] = React.useState('');
65+
const [localUploadCopy, setLocalUploadCopy] = React.useState('');
6466
// keep separate refs for the two Monaco editors
6567
const setupEditorRef = useRef<any>(null);
6668
const commandEditorRef = useRef<any>(null);
@@ -89,6 +91,7 @@ export default function NewTaskModal({
8991
num_nodes: numNodes ? parseInt(numNodes, 10) : undefined,
9092
setup: setupValue,
9193
uploaded_dir_path: uploadedDirPath || undefined,
94+
local_upload_copy: localUploadCopy || undefined,
9295
});
9396
// Reset all form fields
9497
setTitle('');
@@ -270,7 +273,12 @@ export default function NewTaskModal({
270273
</FormControl>
271274

272275
<DirectoryUpload
273-
onUploadComplete={(path) => setUploadedDirPath(path)}
276+
onUploadComplete={(path, localPath) => {
277+
setUploadedDirPath(path);
278+
if (localPath) {
279+
setLocalUploadCopy(localPath);
280+
}
281+
}}
274282
onUploadError={(error) => console.error('Upload error:', error)}
275283
disabled={isSubmitting}
276284
/>

src/renderer/components/Experiment/Tasks/Tasks.tsx

Lines changed: 73 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,7 @@ export default function Tasks() {
203203
num_nodes: data.num_nodes || undefined,
204204
setup: data.setup || undefined,
205205
uploaded_dir_path: data.uploaded_dir_path || undefined,
206+
local_upload_copy: data.local_upload_copy || undefined,
206207
},
207208
plugin: 'remote_orchestrator',
208209
outputs: {},
@@ -246,45 +247,93 @@ export default function Tasks() {
246247
}
247248
};
248249

250+
// Helper function to build FormData for remote job operations
251+
const buildRemoteJobFormData = (task: any, cfg: any, jobId?: string) => {
252+
const formData = new FormData();
253+
formData.append('experimentId', experimentInfo.id);
254+
255+
if (jobId) {
256+
formData.append('job_id', jobId);
257+
}
258+
259+
if (cfg.cluster_name) formData.append('cluster_name', cfg.cluster_name);
260+
if (cfg.command) formData.append('command', cfg.command);
261+
if (task.name) formData.append('task_name', task.name);
262+
if (cfg.cpus) formData.append('cpus', String(cfg.cpus));
263+
if (cfg.memory) formData.append('memory', String(cfg.memory));
264+
if (cfg.disk_space) formData.append('disk_space', String(cfg.disk_space));
265+
if (cfg.accelerators)
266+
formData.append('accelerators', String(cfg.accelerators));
267+
if (cfg.num_nodes) formData.append('num_nodes', String(cfg.num_nodes));
268+
if (cfg.setup) formData.append('setup', String(cfg.setup));
269+
if (cfg.uploaded_dir_path)
270+
formData.append('uploaded_dir_path', String(cfg.uploaded_dir_path));
271+
272+
return formData;
273+
};
274+
249275
const handleQueue = async (task: any) => {
250276
if (!experimentInfo?.id) return;
251277

278+
addNotification({
279+
type: 'success',
280+
message: 'Creating job...',
281+
});
282+
252283
try {
253284
const cfg =
254285
typeof task.config === 'string'
255286
? JSON.parse(task.config)
256287
: task.config || {};
257-
const formData = new FormData();
258-
formData.append('experimentId', experimentInfo.id);
259-
if (cfg.cluster_name) formData.append('cluster_name', cfg.cluster_name);
260-
if (cfg.command) formData.append('command', cfg.command);
261-
// Prefer the task name as job/task name
262-
if (task.name) formData.append('task_name', task.name);
263-
if (cfg.cpus) formData.append('cpus', String(cfg.cpus));
264-
if (cfg.memory) formData.append('memory', String(cfg.memory));
265-
if (cfg.disk_space) formData.append('disk_space', String(cfg.disk_space));
266-
if (cfg.accelerators)
267-
formData.append('accelerators', String(cfg.accelerators));
268-
if (cfg.num_nodes) formData.append('num_nodes', String(cfg.num_nodes));
269-
if (cfg.setup) formData.append('setup', String(cfg.setup));
270-
if (cfg.uploaded_dir_path)
271-
formData.append('uploaded_dir_path', String(cfg.uploaded_dir_path));
272-
273-
const resp = await chatAPI.authenticatedFetch(
274-
chatAPI.Endpoints.Jobs.LaunchRemote(experimentInfo.id),
275-
{ method: 'POST', body: formData },
288+
289+
// Create the actual remote job
290+
const createJobFormData = buildRemoteJobFormData(task, cfg);
291+
292+
const createJobResp = await chatAPI.authenticatedFetch(
293+
chatAPI.Endpoints.Jobs.CreateRemoteJob(experimentInfo.id),
294+
{ method: 'POST', body: createJobFormData },
276295
);
277-
const result = await resp.json();
278-
if (result.status === 'success') {
296+
const createJobResult = await createJobResp.json();
297+
298+
if (createJobResult.status === 'success') {
299+
// Keep placeholder visible and refresh jobs list
300+
// The placeholder will be replaced when the real job appears
301+
await jobsMutate();
302+
279303
addNotification({
280304
type: 'success',
281-
message: 'Task queued for remote launch.',
305+
message: 'Job created. Launching remotely...',
282306
});
283-
await Promise.all([jobsMutate(), tasksMutate()]);
307+
308+
// Then launch the remote job
309+
const launchFormData = buildRemoteJobFormData(
310+
task,
311+
cfg,
312+
createJobResult.job_id,
313+
);
314+
315+
const launchResp = await chatAPI.authenticatedFetch(
316+
chatAPI.Endpoints.Jobs.LaunchRemote(experimentInfo.id),
317+
{ method: 'POST', body: launchFormData },
318+
);
319+
const launchResult = await launchResp.json();
320+
321+
if (launchResult.status === 'success') {
322+
addNotification({
323+
type: 'success',
324+
message: 'Task launched remotely.',
325+
});
326+
await Promise.all([jobsMutate(), tasksMutate()]);
327+
} else {
328+
addNotification({
329+
type: 'danger',
330+
message: `Remote launch failed: ${launchResult.message}`,
331+
});
332+
}
284333
} else {
285334
addNotification({
286335
type: 'danger',
287-
message: `Remote launch failed: ${result.message}`,
336+
message: `Failed to create job: ${createJobResult.message}`,
288337
});
289338
}
290339
} catch (e) {

src/renderer/components/MainAppPanel.tsx

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ import Welcome from './Welcome/Welcome';
2121
import ModelZoo from './ModelZoo/ModelZoo';
2222
import Plugins from './Plugins/Plugins';
2323
import PluginDetails from './Plugins/PluginDetails';
24-
import TaskLibrary from './TaskLibrary/TaskLibrary';
24+
import TasksGallery from './TaskLibrary/TasksGallery';
2525

2626
import Computer from './Computer';
2727
import Eval from './Experiment/Eval/Eval';
@@ -45,7 +45,6 @@ import SelectEmbeddingModel from './Experiment/Foundation/SelectEmbeddingModel';
4545
import { useAnalytics } from './Shared/analytics/AnalyticsContext';
4646
import SafeJSONParse from './Shared/SafeJSONParse';
4747
import Tasks from './Experiment/Tasks/Tasks';
48-
import TaskLibrary from './TaskLibrary/TaskLibrary';
4948

5049
// // Define the app version
5150
// const APP_VERSION = '1.0.0';
@@ -399,7 +398,7 @@ export default function MainAppPanel({
399398
element={<Plugins setLogsDrawerOpen={setLogsDrawerOpen} />}
400399
/>
401400
<Route path="/plugins/:pluginName" element={<PluginDetails />} />
402-
<Route path="/task_library" element={<TaskLibrary />} />
401+
<Route path="/task_library" element={<TasksGallery />} />
403402
<Route path="/api" element={<Api />} />
404403
<Route path="/experiment/settings" element={<Settings />} />
405404
<Route
@@ -451,7 +450,7 @@ export default function MainAppPanel({
451450
path="/data"
452451
element={<Data gpuOrchestrationServer={gpuOrchestrationServer} />}
453452
/>
454-
<Route path="/task_library" element={<TaskLibrary />} />
453+
<Route path="/task_library" element={<TasksGallery />} />
455454
<Route path="/computer" element={<Computer />} />
456455
<Route path="/settings" element={<TransformerLabSettings />} />
457456
<Route path="/logs" element={<Logs />} />
Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
import React from 'react';
2+
import {
3+
Sheet,
4+
Stack,
5+
Typography,
6+
Button,
7+
Card,
8+
CardContent,
9+
LinearProgress,
10+
Chip,
11+
Box,
12+
} from '@mui/joy';
13+
import { DownloadIcon, ExternalLinkIcon } from 'lucide-react';
14+
15+
interface GalleryTask {
16+
id: string;
17+
name: string;
18+
description: string;
19+
tag: string;
20+
source: string;
21+
}
22+
23+
interface GalleryTasksListProps {
24+
tasks: GalleryTask[];
25+
isLoading: boolean;
26+
onInstall: (id: string) => void;
27+
installingTasks: Set<string>;
28+
localTasks: Array<{ task_dir: string; name: string }>;
29+
}
30+
31+
export default function GalleryTasksList({
32+
tasks,
33+
isLoading,
34+
onInstall,
35+
installingTasks,
36+
localTasks,
37+
}: GalleryTasksListProps) {
38+
if (isLoading) {
39+
return <LinearProgress />;
40+
}
41+
42+
if (tasks.length === 0) {
43+
return (
44+
<Sheet
45+
variant="soft"
46+
sx={{
47+
p: 4,
48+
textAlign: 'center',
49+
borderRadius: 'md',
50+
}}
51+
>
52+
<Typography level="body-md" color="neutral">
53+
No gallery tasks available. Check your internet connection and try
54+
again.
55+
</Typography>
56+
</Sheet>
57+
);
58+
}
59+
60+
return (
61+
<Stack spacing={2}>
62+
{tasks.map((task) => (
63+
<Card key={task.subdir} variant="outlined">
64+
<CardContent>
65+
<Stack
66+
direction="row"
67+
justifyContent="space-between"
68+
alignItems="flex-start"
69+
>
70+
<Box sx={{ flex: 1 }}>
71+
<Stack
72+
direction="row"
73+
alignItems="center"
74+
spacing={1}
75+
sx={{ mb: 1 }}
76+
>
77+
<ExternalLinkIcon size={16} />
78+
<Typography level="title-sm">{task.name}</Typography>
79+
<Chip size="sm" variant="soft" color="primary">
80+
{task.tag}
81+
</Chip>
82+
</Stack>
83+
<Typography level="body-sm" color="neutral" sx={{ mb: 2 }}>
84+
{task.description || 'No description available'}
85+
</Typography>
86+
<Stack direction="row" spacing={1}>
87+
<Chip size="sm" variant="outlined" color="success">
88+
Gallery
89+
</Chip>
90+
</Stack>
91+
</Box>
92+
{(() => {
93+
const isInstalling = installingTasks.has(task.id);
94+
const isInstalled = localTasks.some(
95+
(localTask) => localTask.task_dir === task.id,
96+
);
97+
98+
if (isInstalling) {
99+
return (
100+
<Button size="sm" loading disabled>
101+
Installing...
102+
</Button>
103+
);
104+
}
105+
106+
if (isInstalled) {
107+
return (
108+
<Button
109+
size="sm"
110+
variant="outlined"
111+
color="success"
112+
disabled
113+
>
114+
Installed
115+
</Button>
116+
);
117+
}
118+
119+
return (
120+
<Button
121+
size="sm"
122+
startDecorator={<DownloadIcon size={16} />}
123+
onClick={() => onInstall(task.id)}
124+
>
125+
Install
126+
</Button>
127+
);
128+
})()}
129+
</Stack>
130+
</CardContent>
131+
</Card>
132+
))}
133+
</Stack>
134+
);
135+
}

0 commit comments

Comments
 (0)