Skip to content

Commit 1549ac8

Browse files
authored
Merge pull request #788 from transformerlab/fix/diffusion-calls-expid
Fix diffusion calls to also include experiment id
2 parents 3f2947b + 8e843e2 commit 1549ac8

File tree

6 files changed

+67
-14
lines changed

6 files changed

+67
-14
lines changed

src/renderer/components/Experiment/Diffusion/ControlNetModal.tsx

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import {
1414
import { DownloadIcon, TrashIcon, X } from 'lucide-react';
1515
import * as chatAPI from 'renderer/lib/transformerlab-api-sdk';
1616
import { getAPIFullPath } from 'renderer/lib/transformerlab-api-sdk';
17+
import { useExperimentInfo } from 'renderer/lib/ExperimentInfoContext';
1718

1819
export default function ControlNetModal({
1920
open,
@@ -22,11 +23,12 @@ export default function ControlNetModal({
2223
onSelect,
2324
}) {
2425
const [controlNets, setControlNets] = useState<string[]>([]);
26+
const { experimentId } = useExperimentInfo();
2527

2628
const refresh = async () => {
2729
try {
2830
const response = await fetch(
29-
getAPIFullPath('diffusion', ['listControlnets'], {}),
31+
getAPIFullPath('diffusion', ['listControlnets'], { experimentId }),
3032
);
3133
const models = await response.json();
3234
const names = (models.controlnets || []).map(

src/renderer/components/Experiment/Diffusion/Diffusion.tsx

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ type DiffusionProps = {
5656
experimentInfo: any;
5757
};
5858
import { useExperimentInfo } from 'renderer/lib/ExperimentInfoContext';
59+
import { useNotification } from 'renderer/components/Shared/NotificationSystem';
5960

6061
// Helper component for labels with tooltips
6162
const LabelWithTooltip = ({
@@ -111,6 +112,7 @@ const samplePrompts = [
111112

112113
export default function Diffusion() {
113114
const { experimentInfo } = useExperimentInfo();
115+
const { addNotification } = useNotification();
114116
const analytics = useAnalytics();
115117
const { data: diffusionJobs } = useAPI(
116118
'jobs',

src/renderer/components/Experiment/Diffusion/History.tsx

Lines changed: 40 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ import {
3131
ChevronRightIcon,
3232
} from 'lucide-react';
3333
import { getAPIFullPath, useAPI } from 'renderer/lib/transformerlab-api-sdk';
34+
import { useExperimentInfo } from 'renderer/lib/ExperimentInfoContext';
3435
import HistoryCard from './HistoryCard';
3536
import HistoryImageViewModal from './HistoryImageViewModal';
3637
import { HistoryImage } from './types';
@@ -42,12 +43,17 @@ const History: React.FC<HistoryProps> = () => {
4243
const [currentPage, setCurrentPage] = useState(1);
4344
const pageSize = 12; // Number of items per page
4445
const offset = (currentPage - 1) * pageSize;
46+
const { experimentId } = useExperimentInfo();
4547

4648
const {
4749
data: historyData,
4850
isLoading: historyLoading,
4951
mutate: refreshHistory,
50-
} = useAPI('diffusion', ['getHistory'], { limit: pageSize, offset });
52+
} = useAPI('diffusion', ['getHistory'], {
53+
experimentId,
54+
limit: pageSize,
55+
offset,
56+
});
5157

5258
// Calculate pagination info
5359
const totalPages = historyData?.total
@@ -91,7 +97,10 @@ const History: React.FC<HistoryProps> = () => {
9197
const viewImage = async (imageId: string) => {
9298
try {
9399
const response = await fetch(
94-
getAPIFullPath('diffusion', ['getImageInfo'], { imageId }),
100+
getAPIFullPath('diffusion', ['getImageInfo'], {
101+
imageId,
102+
experimentId,
103+
}),
95104
);
96105
const data = await response.json();
97106
setSelectedImage(data);
@@ -104,9 +113,15 @@ const History: React.FC<HistoryProps> = () => {
104113
// Delete single image
105114
const deleteImage = async (imageId: string) => {
106115
try {
107-
await fetch(getAPIFullPath('diffusion', ['deleteImage'], { imageId }), {
108-
method: 'DELETE',
109-
});
116+
await fetch(
117+
getAPIFullPath('diffusion', ['deleteImage'], {
118+
imageId,
119+
experimentId,
120+
}),
121+
{
122+
method: 'DELETE',
123+
},
124+
);
110125
refreshHistory(); // Reload history
111126
setDeleteConfirmOpen(false);
112127
setImageToDelete(null);
@@ -128,9 +143,15 @@ const History: React.FC<HistoryProps> = () => {
128143
// Delete all selected images
129144
await Promise.all(
130145
Array.from(selectedImages).map((imageId) =>
131-
fetch(getAPIFullPath('diffusion', ['deleteImage'], { imageId }), {
132-
method: 'DELETE',
133-
}),
146+
fetch(
147+
getAPIFullPath('diffusion', ['deleteImage'], {
148+
imageId,
149+
experimentId,
150+
}),
151+
{
152+
method: 'DELETE',
153+
},
154+
),
134155
),
135156
);
136157
refreshHistory(); // Reload history
@@ -152,9 +173,14 @@ const History: React.FC<HistoryProps> = () => {
152173
// Clear all history
153174
const clearAllHistory = async () => {
154175
try {
155-
await fetch(getAPIFullPath('diffusion', ['clearHistory'], {}), {
156-
method: 'DELETE',
157-
});
176+
await fetch(
177+
getAPIFullPath('diffusion', ['clearHistory'], {
178+
experimentId,
179+
}),
180+
{
181+
method: 'DELETE',
182+
},
183+
);
158184
refreshHistory(); // Reload history
159185
} catch (e) {
160186
// Error clearing history
@@ -199,7 +225,9 @@ const History: React.FC<HistoryProps> = () => {
199225

200226
try {
201227
const response = await fetch(
202-
getAPIFullPath('diffusion', ['createDataset'], {}),
228+
getAPIFullPath('diffusion', ['createDataset'], {
229+
experimentId,
230+
}),
203231
{
204232
method: 'POST',
205233
headers: { 'Content-Type': 'application/json' },

src/renderer/components/Experiment/Diffusion/HistoryCard.tsx

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import {
1111
} from '@mui/joy';
1212
import { Trash2Icon } from 'lucide-react';
1313
import { getAPIFullPath } from 'renderer/lib/transformerlab-api-sdk';
14+
import { useExperimentInfo } from 'renderer/lib/ExperimentInfoContext';
1415
import { HistoryImage } from './types';
1516

1617
interface HistoryCardProps {
@@ -34,6 +35,7 @@ const HistoryCard: React.FC<HistoryCardProps> = ({
3435
}) => {
3536
const numImages = item.num_images || item.metadata?.num_images || 1;
3637
const hasMultipleImages = numImages > 1;
38+
const { experimentId } = useExperimentInfo();
3739

3840
// Function to render multiple images in a grid
3941
const renderImages = () => {
@@ -58,6 +60,7 @@ const HistoryCard: React.FC<HistoryCardProps> = ({
5860
src={getAPIFullPath('diffusion', ['getImage'], {
5961
imageId: item.id,
6062
index,
63+
experimentId,
6164
})}
6265
alt={`Generated image ${index + 1}`}
6366
style={{
@@ -96,6 +99,7 @@ const HistoryCard: React.FC<HistoryCardProps> = ({
9699
src={getAPIFullPath('diffusion', ['getImage'], {
97100
imageId: item.id,
98101
index: 0,
102+
experimentId,
99103
})}
100104
alt="generated"
101105
style={{

src/renderer/components/Experiment/Diffusion/HistoryImageSelector.tsx

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ import {
1616
CardOverflow,
1717
AspectRatio,
1818
} from '@mui/joy';
19+
import { useExperimentInfo } from 'renderer/lib/ExperimentInfoContext';
1920
import { ChevronLeftIcon, ChevronRightIcon, CheckIcon } from 'lucide-react';
2021
import { getAPIFullPath, useAPI } from 'renderer/lib/transformerlab-api-sdk';
2122
import { HistoryImage } from './types';
@@ -36,11 +37,12 @@ const HistoryImageSelector: React.FC<HistoryImageSelectorProps> = ({
3637
const [selectedImageIndex, setSelectedImageIndex] = useState<number>(0);
3738
const pageSize = 12;
3839
const offset = (currentPage - 1) * pageSize;
40+
const { experimentId } = useExperimentInfo();
3941

4042
const { data: historyData, isLoading: historyLoading } = useAPI(
4143
'diffusion',
4244
['getHistory'],
43-
{ limit: pageSize, offset },
45+
{ experimentId, limit: pageSize, offset },
4446
);
4547

4648
// Reset selection when changing pages
@@ -63,6 +65,7 @@ const HistoryImageSelector: React.FC<HistoryImageSelectorProps> = ({
6365
getAPIFullPath('diffusion', ['getImage'], {
6466
imageId: selectedImageId,
6567
index: selectedImageIndex,
68+
experimentId,
6669
}),
6770
);
6871

@@ -112,6 +115,7 @@ const HistoryImageSelector: React.FC<HistoryImageSelectorProps> = ({
112115
src={getAPIFullPath('diffusion', ['getImage'], {
113116
imageId: item.id,
114117
index: displayIndex,
118+
experimentId,
115119
})}
116120
alt={item.prompt}
117121
style={{ objectFit: 'cover' }}

src/renderer/components/Experiment/Diffusion/HistoryImageViewModal.tsx

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ import {
1919
} from 'lucide-react';
2020
import React, { useState, useEffect } from 'react';
2121
import { getAPIFullPath } from 'renderer/lib/transformerlab-api-sdk';
22+
import { useExperimentInfo } from 'renderer/lib/ExperimentInfoContext';
2223
import { HistoryImage } from './types';
2324

2425
export default function HistoryImageViewModal({
@@ -37,6 +38,7 @@ export default function HistoryImageViewModal({
3738
const [currentImageIndex, setCurrentImageIndex] = useState(0);
3839
const [imageUrls, setImageUrls] = useState<string[]>([]);
3940
const [numImages, setNumImages] = useState(1);
41+
const { experimentId } = useExperimentInfo();
4042
// const [hoveringMainImage, setHoveringMainImage] = useState(false);
4143

4244
// Load all images for the selected item when modal opens
@@ -52,6 +54,7 @@ export default function HistoryImageViewModal({
5254
getAPIFullPath('diffusion', ['getImage'], {
5355
imageId: selectedImage.id,
5456
index,
57+
experimentId,
5558
}),
5659
);
5760

@@ -61,6 +64,7 @@ export default function HistoryImageViewModal({
6164
urls.push(
6265
getAPIFullPath('diffusion', ['getInputImage'], {
6366
imageId: selectedImage.id,
67+
experimentId,
6468
}),
6569
);
6670
}
@@ -70,6 +74,7 @@ export default function HistoryImageViewModal({
7074
getAPIFullPath('diffusion', ['getProcessedImage'], {
7175
imageId: selectedImage.id,
7276
processed: true,
77+
experimentId,
7378
}),
7479
);
7580
}
@@ -95,6 +100,7 @@ export default function HistoryImageViewModal({
95100
const link = document.createElement('a');
96101
link.href = getAPIFullPath('diffusion', ['getAllImages'], {
97102
imageId: selectedImage.id,
103+
experimentId,
98104
});
99105

100106
// Generate filename with timestamp
@@ -208,6 +214,7 @@ export default function HistoryImageViewModal({
208214
getAPIFullPath('diffusion', ['getImage'], {
209215
imageId: selectedImage?.id,
210216
index: currentImageIndex,
217+
experimentId,
211218
})
212219
}
213220
alt="Generated"
@@ -231,6 +238,7 @@ export default function HistoryImageViewModal({
231238
<img
232239
src={getAPIFullPath('diffusion', ['getInputImage'], {
233240
imageId: selectedImage?.id,
241+
experimentId,
234242
})}
235243
alt="Input"
236244
style={{
@@ -341,6 +349,7 @@ export default function HistoryImageViewModal({
341349
['getInputImage'],
342350
{
343351
imageId: selectedImage?.id,
352+
experimentId,
344353
},
345354
)}
346355
alt="Input"
@@ -379,6 +388,7 @@ export default function HistoryImageViewModal({
379388
['getInputImage'],
380389
{
381390
imageId: selectedImage?.id,
391+
experimentId,
382392
},
383393
)}
384394
alt="Input"
@@ -414,6 +424,7 @@ export default function HistoryImageViewModal({
414424
<img
415425
src={getAPIFullPath('diffusion', ['getMaskImage'], {
416426
imageId: selectedImage?.id,
427+
experimentId,
417428
})}
418429
alt="Mask"
419430
style={{
@@ -443,6 +454,7 @@ export default function HistoryImageViewModal({
443454
<img
444455
src={getAPIFullPath('diffusion', ['getInputImage'], {
445456
imageId: selectedImage?.id,
457+
experimentId,
446458
})}
447459
alt="ControlNet Input"
448460
style={{
@@ -485,6 +497,7 @@ export default function HistoryImageViewModal({
485497
{
486498
imageId: selectedImage?.id,
487499
processed: true,
500+
experimentId,
488501
},
489502
)}
490503
alt="Preprocessed"

0 commit comments

Comments
 (0)