Skip to content

Commit 19d882d

Browse files
hipsterusernamepsychedelicious
authored andcommitted
Address comments
1 parent ee4bc49 commit 19d882d

File tree

6 files changed

+234
-231
lines changed

6 files changed

+234
-231
lines changed
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
import { flatMap, negate, uniqWith } from 'lodash-es';
2+
import { useCallback, useMemo } from 'react';
3+
import { useTranslation } from 'react-i18next';
4+
import { useInstallModelMutation } from 'services/api/endpoints/models';
5+
import { toast } from 'features/toast/toast';
6+
import { flattenStarterModel, useBuildModelInstallArg } from './useBuildModelsToInstall';
7+
import type { StarterModel } from 'services/api/types';
8+
9+
export const useStarterBundleInstall = () => {
10+
const [installModel] = useInstallModelMutation();
11+
const { getIsInstalled, buildModelInstallArg } = useBuildModelInstallArg();
12+
const { t } = useTranslation();
13+
14+
const getModelsToInstall = useCallback(
15+
(bundle: StarterModel[]) => {
16+
// Flatten the models and remove duplicates, which is expected as models can have the same dependencies
17+
const flattenedModels = flatMap(bundle, flattenStarterModel);
18+
const uniqueModels = uniqWith(
19+
flattenedModels,
20+
(m1, m2) => m1.source === m2.source || (m1.name === m2.name && m1.base === m2.base && m1.type === m2.type)
21+
);
22+
// We want to install models that are not installed and skip models that are already installed
23+
const install = uniqueModels.filter(negate(getIsInstalled)).map(buildModelInstallArg);
24+
const skip = uniqueModels.filter(getIsInstalled).map(buildModelInstallArg);
25+
26+
return { install, skip };
27+
},
28+
[getIsInstalled, buildModelInstallArg]
29+
);
30+
31+
const installBundle = useCallback(
32+
(bundle: StarterModel[], bundleName?: string) => {
33+
const modelsToInstall = getModelsToInstall(bundle);
34+
35+
if (modelsToInstall.install.length === 0) {
36+
if (bundleName) {
37+
toast({
38+
status: 'info',
39+
title: t('modelManager.bundleAlreadyInstalled', { bundleName }),
40+
description: t('modelManager.allModelsAlreadyInstalled'),
41+
});
42+
}
43+
return;
44+
}
45+
46+
// Install all models in the bundle
47+
modelsToInstall.install.forEach(installModel);
48+
49+
let description = t('modelManager.installingXModels', { count: modelsToInstall.install.length });
50+
if (modelsToInstall.skip.length > 1) {
51+
description += t('modelManager.skippingXDuplicates', { count: modelsToInstall.skip.length - 1 });
52+
}
53+
54+
toast({
55+
status: 'info',
56+
title: t('modelManager.installingBundle'),
57+
description,
58+
});
59+
},
60+
[getModelsToInstall, installModel, t]
61+
);
62+
63+
return { installBundle, getModelsToInstall };
64+
};

invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/InstallModelForm.tsx

Lines changed: 15 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -55,23 +55,21 @@ export const InstallModelForm = memo(() => {
5555
return (
5656
<form onSubmit={handleSubmit(onSubmit)}>
5757
<Flex flexDir="column" gap={4}>
58-
<Flex gap={2} alignItems="flex-end" justifyContent="space-between">
59-
<FormControl orientation="vertical">
60-
<FormLabel>{t('modelManager.urlOrLocalPath')}</FormLabel>
61-
<Flex alignItems="center" gap={3} w="full">
62-
<Input placeholder={t('modelManager.simpleModelPlaceholder')} {...register('location')} />
63-
<Button
64-
onClick={handleSubmit(onSubmit)}
65-
isDisabled={!formState.dirtyFields.location}
66-
isLoading={isLoading}
67-
size="sm"
68-
>
69-
{t('modelManager.install')}
70-
</Button>
71-
</Flex>
72-
<FormHelperText>{t('modelManager.urlOrLocalPathHelper')}</FormHelperText>
73-
</FormControl>
74-
</Flex>
58+
<FormControl orientation="vertical">
59+
<FormLabel>{t('modelManager.urlOrLocalPath')}</FormLabel>
60+
<Flex alignItems="center" gap={3} w="full">
61+
<Input placeholder={t('modelManager.simpleModelPlaceholder')} {...register('location')} />
62+
<Button
63+
onClick={handleSubmit(onSubmit)}
64+
isDisabled={!formState.dirtyFields.location}
65+
isLoading={isLoading}
66+
size="sm"
67+
>
68+
{t('modelManager.install')}
69+
</Button>
70+
</Flex>
71+
<FormHelperText>{t('modelManager.urlOrLocalPathHelper')}</FormHelperText>
72+
</FormControl>
7573

7674
<FormControl>
7775
<Flex flexDir="column" gap={2}>

0 commit comments

Comments
 (0)