
Commit 6e979a3

Merge pull request #792 from transformerlab/add/audio-batched-generation
Add batched generation button and modal for audio
2 parents 1549ac8 + d3e30a9

File tree

4 files changed, +237 -6 lines

src/renderer/components/Experiment/Audio/Audio.tsx

Lines changed: 58 additions & 6 deletions
@@ -10,12 +10,9 @@ import {
   Box,
   Select,
   Option,
-  Textarea,
   Stack,
   Slider,
   FormLabel,
-  Switch,
-  Input,
   Modal,
   ModalDialog,
   ModalClose,
@@ -24,6 +21,7 @@ import {
   Card,
   Alert,
 } from '@mui/joy';
+import BatchedAudioModal from './BatchedAudioModal';
 import { useAPI } from '../../../lib/transformerlab-api-sdk';
 import AudioHistory from './AudioHistory';

@@ -175,6 +173,7 @@ export default function Audio() {
   const [topP, setTopP] = React.useState(1.0);
   const [selectedLanguage, setSelectedLanguage] = React.useState('');
   const [selectedVoice, setSelectedVoice] = React.useState('');
+  const [showBatchModal, setShowBatchModal] = React.useState(false);

   const [showSettingsModal, setShowSettingsModal] = React.useState(false);

@@ -194,7 +193,7 @@ export default function Audio() {
     setAudioUrl(null);
     setErrorMessage(null);

-    const result = await sendAndReceiveAudioPath(
+    const result: any = await sendAndReceiveAudioPath(
       experimentInfo?.id,
       currentModel,
       adaptor,
@@ -283,7 +282,14 @@ export default function Audio() {
        }}
      >
        <Typography level="h2">Text to Speech</Typography>
-        <Box sx={{ textAlign: 'right' }}>
+        <Box
+          sx={{
+            textAlign: 'right',
+            display: 'flex',
+            alignItems: 'center',
+            gap: 1,
+          }}
+        >
          <Typography level="body-sm">{currentModel}</Typography>
          {adaptor && (
            <Typography level="body-xs" color="neutral">
@@ -483,12 +489,13 @@ export default function Audio() {
            width: '100%',
          }}
        >
+          {/* Large text input area at the top */}
          {/* Large text input area at the top */}
          <FormControl sx={{ mt: 1 }}>
            <textarea
              value={text}
              onChange={(e) => setText(e.target.value)}
-              placeholder="Enter your text here for speech generation..."
+              placeholder={'Enter your text here for speech generation...'}
              style={{
                minHeight: '100px',
                padding: '16px',
@@ -519,6 +526,12 @@ export default function Audio() {
            >
              Generate Speech
            </Button>
+            <Button
+              variant="outlined"
+              onClick={() => setShowBatchModal(true)}
+            >
+              Create Prompt Batch
+            </Button>
          </Stack>

          {errorMessage && (
@@ -547,6 +560,45 @@ export default function Audio() {
        </ModalDialog>
      </Modal>

+      <BatchedAudioModal
+        open={showBatchModal}
+        onClose={() => setShowBatchModal(false)}
+        isLoading={isLoading}
+        onSubmit={async (lines: string[]) => {
+          setIsLoading(true);
+          setErrorMessage(null);
+          try {
+            const result = await chatAPI.sendBatchedAudio(
+              experimentInfo?.id,
+              currentModel,
+              adaptor,
+              lines,
+              filePrefix,
+              sampleRate,
+              temperature,
+              speed,
+              topP,
+              selectedVoice || undefined,
+              uploadedAudioPath || undefined,
+            );
+            const anyOk = Array.isArray(result)
+              ? result.some((r) => r && r.message)
+              : false;
+            if (!anyOk) {
+              setErrorMessage('Batched generation failed.');
+            }
+            setShowBatchModal(false);
+            handleClearUpload();
+            mutateHistory();
+            if (audioHistoryRef.current) {
+              audioHistoryRef.current.scrollTop = 0;
+            }
+          } finally {
+            setIsLoading(false);
+          }
+        }}
+      />
+
      {/* No Model Running Modal */}
      <Sheet
        sx={{
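
The success check in the new onSubmit handler (result.some((r) => r && r.message)) implies the batch endpoint is expected to return an array of per-prompt result objects, each carrying at least a message field on success. A minimal sketch of that assumed shape in TypeScript (the type name and helper are illustrative, not code from this PR):

// Hypothetical type inferred from the success check above; the backend's
// actual response schema is not defined in this diff.
type BatchedAudioResult = {
  message?: string; // present when a prompt was synthesized successfully
  [key: string]: unknown; // any other fields the backend may return
};

// Mirrors the anyOk check used in Audio.tsx's onSubmit handler.
function anyPromptSucceeded(result: unknown): boolean {
  return (
    Array.isArray(result) &&
    result.some((r) => Boolean(r && (r as BatchedAudioResult).message))
  );
}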

src/renderer/components/Experiment/Audio/BatchedAudioModal.tsx

Lines changed: 128 additions & 0 deletions
@@ -0,0 +1,128 @@
+import * as React from 'react';
+import {
+  Modal,
+  ModalDialog,
+  ModalClose,
+  DialogTitle,
+  Divider,
+  Stack,
+  Button,
+  Typography,
+  FormControl,
+} from '@mui/joy';
+
+type BatchedAudioModalProps = {
+  open: boolean;
+  onClose: () => void;
+  isLoading?: boolean;
+  onSubmit: (prompts: string[]) => Promise<void> | void;
+};
+
+export default function BatchedAudioModal({
+  open,
+  onClose,
+  isLoading = false,
+  onSubmit,
+}: BatchedAudioModalProps) {
+  const [prompts, setPrompts] = React.useState<string[]>(['']);
+
+  function updatePrompt(index: number, value: string) {
+    const next = [...prompts];
+    next[index] = value;
+    setPrompts(next);
+  }
+
+  function addPrompt() {
+    setPrompts((prev) => [...prev, '']);
+  }
+
+  function removePrompt(index: number) {
+    setPrompts((prev) => prev.filter((_, i) => i !== index));
+  }
+
+  function resetPrompts() {
+    setPrompts(['']);
+  }
+
+  async function handleSubmit() {
+    const cleaned = prompts.map((p) => p.trim()).filter((p) => p.length > 0);
+    if (cleaned.length === 0) return;
+    await onSubmit(cleaned);
+  }
+
+  return (
+    <Modal open={open} onClose={onClose}>
+      <ModalDialog variant="outlined" sx={{ minWidth: 600, maxWidth: 900 }}>
+        <ModalClose />
+        <DialogTitle>Send Batched Prompts</DialogTitle>
+        <Divider />
+        <Stack spacing={2} sx={{ mt: 1 }}>
+          <Typography level="body-sm" color="neutral">
+            Add one or more prompts. Each prompt can be multi-line. A separate
+            audio file will be generated for each prompt.
+          </Typography>
+
+          <Stack spacing={1} sx={{ maxHeight: 360, overflowY: 'auto' }}>
+            {prompts.map((value, idx) => (
+              <FormControl key={idx} sx={{ gap: 0.5 }}>
+                <textarea
+                  value={value}
+                  onChange={(e) => updatePrompt(idx, e.target.value)}
+                  placeholder={`Prompt ${idx + 1}`}
+                  style={{
+                    minHeight: '100px',
+                    padding: '12px',
+                    borderRadius: '8px',
+                    fontSize: '14px',
+                    lineHeight: '1.5',
+                    overflowY: 'auto',
+                    width: '100%',
+                  }}
+                />
+                <Stack
+                  direction="row"
+                  spacing={1}
+                  sx={{ alignSelf: 'flex-end' }}
+                >
+                  {prompts.length > 1 && (
+                    <Button
+                      size="sm"
+                      variant="plain"
+                      color="danger"
+                      onClick={() => removePrompt(idx)}
+                    >
+                      Remove
+                    </Button>
+                  )}
+                </Stack>
+              </FormControl>
+            ))}
+          </Stack>
+
+          <Stack direction="row" spacing={1}>
+            <Button variant="outlined" onClick={addPrompt}>
+              + Add Prompt
+            </Button>
+          </Stack>
+
+          <Stack direction="row" spacing={1} justifyContent="flex-end">
+            <Button variant="plain" onClick={onClose}>
+              Cancel
+            </Button>
+            <Button variant="outlined" onClick={resetPrompts}>
+              Reset
+            </Button>
+            <Button
+              variant="solid"
+              disabled={isLoading}
+              loading={isLoading}
+              onClick={handleSubmit}
+            >
+              Send Batch
+            </Button>
+          </Stack>
+        </Stack>
+      </ModalDialog>
+    </Modal>
+  );
+}
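
The modal is fully controlled by its parent: it only collects and trims prompt strings, then hands the non-empty ones to onSubmit, so the caller decides how to run the batch. A minimal standalone usage sketch, assuming only the props defined above (the handleBatch handler and its body are illustrative placeholders, not code from this PR):

import * as React from 'react';
import { Button } from '@mui/joy';
import BatchedAudioModal from './BatchedAudioModal';

export function BatchedAudioExample() {
  const [open, setOpen] = React.useState(false);

  // Illustrative handler: the real app forwards prompts to sendBatchedAudio.
  const handleBatch = async (prompts: string[]) => {
    console.log('Prompts to synthesize:', prompts);
    setOpen(false);
  };

  return (
    <>
      <Button variant="outlined" onClick={() => setOpen(true)}>
        Create Prompt Batch
      </Button>
      <BatchedAudioModal
        open={open}
        onClose={() => setOpen(false)}
        onSubmit={handleBatch}
      />
    </>
  );
}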

src/renderer/lib/api-client/chat.ts

Lines changed: 50 additions & 0 deletions
@@ -764,6 +764,56 @@ export async function sendBatchedChat(
   return results;
 }

+// Batched Text-to-Speech: send multiple texts to generate multiple audios
+export async function sendBatchedAudio(
+  experimentId: number,
+  currentModel: string,
+  adaptor: string,
+  texts: string[],
+  filePrefix: string,
+  sampleRate: number,
+  temperature: number,
+  speed: number,
+  topP: number,
+  voice?: string,
+  audioPath?: string,
+  batchSize: number = 64,
+): Promise<any[] | null> {
+  const data: any = {
+    experiment_id: experimentId,
+    model: currentModel,
+    adaptor: adaptor,
+    texts: texts,
+    file_prefix: filePrefix,
+    sample_rate: sampleRate,
+    temperature: temperature,
+    speed: speed,
+    top_p: topP,
+    batch_size: batchSize,
+    inference_url: `${INFERENCE_SERVER_URL()}v1/audio/speech`,
+  };
+
+  if (voice) data.voice = voice;
+  if (audioPath) data.audio_path = audioPath;
+
+  try {
+    const response = await fetch(`${API_URL()}batch/audio/speech`, {
+      method: 'POST',
+      headers: {
+        'Content-Type': 'application/json',
+        accept: 'application/json',
+      },
+      body: JSON.stringify(data),
+    });
+    if (!response.ok) return null;
+    const results = await response.json();
+    return results;
+  } catch (err) {
+    console.log('Error in sendBatchedAudio:', err);
+    return null;
+  }
+}
+
 export async function callTool(
   function_name: String,
   function_args: Object = {},
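
For a sense of how the new client helper is called and what it posts to the batch/audio/speech route, here is a hedged usage sketch; the argument values are invented for the example, and the response handling follows only what the code above guarantees (a parsed JSON array on success, null on any HTTP or network error):

import { sendBatchedAudio } from './chat';

// Illustrative values only; in the app these come from the Audio screen's state.
async function demoBatch() {
  const results = await sendBatchedAudio(
    1, // experimentId
    'some-tts-model', // currentModel (hypothetical model name)
    '', // adaptor
    ['Hello there.', 'A second prompt.'], // texts: one audio file per entry
    'batch_demo', // filePrefix
    24000, // sampleRate
    0.7, // temperature
    1.0, // speed
    1.0, // topP
    // voice and audioPath omitted; batchSize defaults to 64
  );

  // sendBatchedAudio never throws: it resolves to null on failure.
  if (results === null) {
    console.log('Batched audio request failed');
  }
}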

src/renderer/lib/transformerlab-api-sdk.ts

Lines changed: 1 addition & 0 deletions
@@ -13,6 +13,7 @@ export {
   sendCompletionReactWay,
   sendBatchedCompletion,
   sendBatchedChat,
+  sendBatchedAudio,
   callTool,
   getToolsForCompletions,
   getEmbeddings,
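
With this re-export, UI code can import the new helper from the SDK entry point instead of reaching into chat.ts directly. Audio.tsx calls it as chatAPI.sendBatchedAudio, which presumably corresponds to a namespace import along these lines (the relative path and alias are assumptions based on the other imports shown above, not confirmed by this diff):

// From a component under src/renderer/components/Experiment/Audio/
import * as chatAPI from '../../../lib/transformerlab-api-sdk';

// ...later, inside an async handler:
// await chatAPI.sendBatchedAudio(experimentId, model, adaptor, texts, prefix, 24000, 0.7, 1.0, 1.0);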
