From 28b5dcaac02f0c47a7e5397704ce6f12d13e26e8 Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Wed, 2 Jul 2025 18:35:56 +0000 Subject: [PATCH 1/2] Reorder embedding options in PromptTriggerSelect component Co-authored-by: kent --- .../features/prompt/PromptTriggerSelect.tsx | 26 +++++++++---------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/invokeai/frontend/web/src/features/prompt/PromptTriggerSelect.tsx b/invokeai/frontend/web/src/features/prompt/PromptTriggerSelect.tsx index effe5769e91..da12a1930fb 100644 --- a/invokeai/frontend/web/src/features/prompt/PromptTriggerSelect.tsx +++ b/invokeai/frontend/web/src/features/prompt/PromptTriggerSelect.tsx @@ -42,19 +42,6 @@ export const PromptTriggerSelect = memo(({ onSelect, onClose }: PromptTriggerSel const options = useMemo(() => { const _options: GroupBase[] = []; - if (tiModels) { - const embeddingOptions = tiModels - .filter((ti) => ti.base === mainModelConfig?.base) - .map((model) => ({ label: model.name, value: `<${model.name}>` })); - - if (embeddingOptions.length > 0) { - _options.push({ - label: t('prompt.compatibleEmbeddings'), - options: embeddingOptions, - }); - } - } - if (loraModels) { const triggerPhraseOptions = loraModels .filter((lora) => map(addedLoRAs, (l) => l.model.key).includes(lora.key)) @@ -74,6 +61,19 @@ export const PromptTriggerSelect = memo(({ onSelect, onClose }: PromptTriggerSel } } + if (tiModels) { + const embeddingOptions = tiModels + .filter((ti) => ti.base === mainModelConfig?.base) + .map((model) => ({ label: model.name, value: `<${model.name}>` })); + + if (embeddingOptions.length > 0) { + _options.push({ + label: t('prompt.compatibleEmbeddings'), + options: embeddingOptions, + }); + } + } + if (mainModelConfig && isNonRefinerMainModelConfig(mainModelConfig) && mainModelConfig.trigger_phrases?.length) { _options.push({ label: t('modelManager.mainModelTriggerPhrases'), From 085db1bee1ff38ec607f1d465f2ebf52d28608ea Mon Sep 17 00:00:00 2001 From: Kent Keirsey <31807370+hipsterusername@users.noreply.github.com> Date: Wed, 2 Jul 2025 14:58:20 -0400 Subject: [PATCH 2/2] Added related model support --- .../features/prompt/PromptTriggerSelect.tsx | 63 +++++++++++++++++-- 1 file changed, 59 insertions(+), 4 deletions(-) diff --git a/invokeai/frontend/web/src/features/prompt/PromptTriggerSelect.tsx b/invokeai/frontend/web/src/features/prompt/PromptTriggerSelect.tsx index da12a1930fb..0633a4beb30 100644 --- a/invokeai/frontend/web/src/features/prompt/PromptTriggerSelect.tsx +++ b/invokeai/frontend/web/src/features/prompt/PromptTriggerSelect.tsx @@ -1,5 +1,5 @@ import type { ChakraProps, ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library'; -import { Combobox, FormControl } from '@invoke-ai/ui-library'; +import { Combobox, FormControl, Icon } from '@invoke-ai/ui-library'; import { skipToken } from '@reduxjs/toolkit/query'; import { useAppSelector } from 'app/store/storeHooks'; import type { GroupBase } from 'chakra-react-select'; @@ -10,12 +10,16 @@ import type { PromptTriggerSelectProps } from 'features/prompt/types'; import { t } from 'i18next'; import { memo, useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; +import { PiLinkSimple } from 'react-icons/pi'; +import { useGetRelatedModelIdsBatchQuery } from 'services/api/endpoints/modelRelationships'; import { useGetModelConfigQuery } from 'services/api/endpoints/models'; import { useEmbeddingModels, useLoRAModels } from 'services/api/hooks/modelsByType'; import { isNonRefinerMainModelConfig } from 'services/api/types'; const noOptionsMessage = () => t('prompt.noMatchingTriggers'); +type RelatedEmbedding = ComboboxOption & { starred?: boolean }; + export const PromptTriggerSelect = memo(({ onSelect, onClose }: PromptTriggerSelectProps) => { const { t } = useTranslation(); @@ -27,6 +31,27 @@ export const PromptTriggerSelect = memo(({ onSelect, onClose }: PromptTriggerSel const [loraModels, { isLoading: isLoadingLoRAs }] = useLoRAModels(); const [tiModels, { isLoading: isLoadingTIs }] = useEmbeddingModels(); + // Get related model keys for current selected models + const selectedModelKeys = useMemo(() => { + const keys: string[] = []; + if (mainModel) { + keys.push(mainModel.key); + } + for (const { model } of addedLoRAs) { + keys.push(model.key); + } + return keys; + }, [mainModel, addedLoRAs]); + + const { relatedModelKeys } = useGetRelatedModelIdsBatchQuery(selectedModelKeys, { + selectFromResult: ({ data }) => { + if (!data) { + return { relatedModelKeys: [] }; + } + return { relatedModelKeys: data }; + }, + }); + const _onChange = useCallback( (v) => { if (!v) { @@ -62,9 +87,25 @@ export const PromptTriggerSelect = memo(({ onSelect, onClose }: PromptTriggerSel } if (tiModels) { - const embeddingOptions = tiModels + // Create embedding options with starred property for related models + const embeddingOptions: RelatedEmbedding[] = tiModels .filter((ti) => ti.base === mainModelConfig?.base) - .map((model) => ({ label: model.name, value: `<${model.name}>` })); + .map((model) => ({ + label: model.name, + value: `<${model.name}>`, + starred: relatedModelKeys.includes(model.key), + })); + + // Sort so related embeddings come first + embeddingOptions.sort((a, b) => { + if (a.starred && !b.starred) { + return -1; + } + if (!a.starred && b.starred) { + return 1; + } + return 0; + }); if (embeddingOptions.length > 0) { _options.push({ @@ -85,7 +126,20 @@ export const PromptTriggerSelect = memo(({ onSelect, onClose }: PromptTriggerSel } return _options; - }, [tiModels, loraModels, mainModelConfig, t, addedLoRAs]); + }, [tiModels, loraModels, mainModelConfig, t, addedLoRAs, relatedModelKeys]); + + const formatOptionLabel = useCallback((option: ComboboxOption) => { + const embeddingOption = option as RelatedEmbedding; + if (embeddingOption.starred) { + return ( +
+ + {option.label} +
+ ); + } + return option.label; + }, []); return ( @@ -104,6 +158,7 @@ export const PromptTriggerSelect = memo(({ onSelect, onClose }: PromptTriggerSel onMenuClose={onClose} data-testid="add-prompt-trigger" sx={selectStyles} + formatOptionLabel={formatOptionLabel} /> );