It’s like a photo booth, but once the subject is captured, it can be synthesized wherever your dreams take you…
Оригинальная статья: DreamBooth: Fine Tuning Text-to-Image Diffusion Models for Subject-Driven Generation
Современные большие text-to-image модели могут достаточно точно и разнообразно генерировать изображения по текстовому запросу. Однако, возникает проблема, если мы хотим получить изображение с конкретным объектом, примеры которого у нас есть (3-5 изображений), но в измененном контексте, задаваемом промтом. Например, сгенерировать фотографию с собой/своим животным/любимой вещью в известном туристическом месте. Данную задачу позволяет решить подход DreamBooth.
Тривиальный способ научить диффузионную модель генерировать заданный объект — это присвоить ему уникальный идентификатор (предлагается брать наиболее редко используемые токены, ведь от их оригинальной смысловой нагрузки будет проще всего избавиться, в промте будем обозначать как 'a [V] [class noun]'), чтобы модель понимала, что в промте речь идет про конкретный объект, а затем ее зафайнтьюнить как обычную диффузию. Функция потерь запишется как
где
- Language drift. Изначально появился в языковых моделях, предобученных на больших корпусах текста. После файнтьюна под узкую задачу такие модели переставали понимать синтаксис и семантику языка. При переносе на диффузионные модели, получаем, что нейросеть забывает, как должны строиться изображения, располагаться объекты по отношению друг к другу. Более того, весь класс, которому принадлежит наш объект, может у модели теперь ассоциироваться с конкретным его экземпляром.
- Reduced output diversity. Text-to-image диффузионные модели выдают широкий спектр разнообразных изображений. Но после дообучения на небольшом наборе данных это разнообразие теряется, и мы больше не в состоянии получить фотографию желаемого объекта в другой позе или с другого ракурса.
Побороть обе проблемы можно, если дообучать не только на небольшом датасете с интересуемым объектом, но и на изображениях, генерируемых моделью до файнтьюна по промту, содержащему лишь класс объекта. Таким образом, мы и сохраняем разнообразие и не переобучаемся на маленькой выборке.
Лосс-функция теперь перепишется как
здесь все обозначения совпадают с предыдущей формулой, за исключением:
-
$x_{pr} = \hat{x}(z_{t_1}, c_{pr})$ — данные, сгенерированные предобученной диффузионной моделью с замороженными весами из шума$z_{t_1} \sim N(0, 1)$ -
$c_{pr} := \Gamma(f('\text{a [class noun]}'))$ — вектор на выходе энкодера текста -
$\lambda$ — параметр, контролирующий отношение слагаемых.
Первое слагаемое назовем Reconstruction Loss, второе — Class-Specific Prior Preservation Loss. Измененный процесс файнтьюна можно изобразить в виде схемы:
Как можно заметить, на изображениях входного датасета собака лежит на мягких поверхностях, и на сгенерированных изображениях без prior-preservation loss'а она тоже лежит на похожих поверхностях. А на изображениях, полученных с prior-preservation loss'ом, собака стоит и сидит на отличающихся поверхностях.
Также авторам удалось достичь успеха в следующих задачах синтеза изображений:
- Recontextualization — изменение окружения объекта
- Novel View Synthesis — синтез новых ракурсов
- Art Renditions — генерация в стиле картин великих художников
- Property Modification — изменение качеств объекта
В качестве домена я выбрал фотографии кота своего друга. Вот несколько из них:
Параметры дообучения:
- Количество шагов - 500
- Ранг LoRA - 16 (больше смысла не имеет, меньше может быть недостаточно, чтобы сдвинуть домен)
По итогу получилась модель, способная генерировать наш объект в различных контекстах.
Теперь посмотрим, как меняется результат на инференсе в зависимости от гиперпараметров. Их можно выделить 3: LORA_SCALE_UNET, LORA_SCALE_TEXT_ENCODER, GUIDANCE. Первые два отвечают за то, на какой коэффициент будет умножаться прибавка от LoRA, то есть параметру 0 соответствует предобученная модель, а параметру 1 — модель, полученная после файнтьюна. GUIDANCE — насколько близко к промту будет сгенерированное изображение. Чем больше GUIDANCE, тем более уникальные и разнообразные изображения будут получаться, однако, качество самих изображений будет ухудшаться.
Посмотрим, как влияют параметры LORA_SCALE*, значение GUIDANCE будет зафиксировано примерно посередине разумного диапазона числом 7.5.
Здесь сверху вниз растет LORA_SCALE_UNET, слева направо LORA_SCALE_TEXT_ENCODER. Как можно заметить, при генерации изображения по стандартному промту "A photo of a [V] cat", значения коэффициента энкодера не влияют. Но выберем оптимальные по качеству значения коэффициентов (0.8, 0.8). Теперь посмотрим на влияние GUIDANCE при уже заданных LORA_SCALE:
При значениях больше 9 становится хуже качество изображений, появляются артефакты. А при значениях меньше 6 изображения достаточно похожи на таковые из референса. Выберем оптимальным значение GUIDANCE равное 7. Теперь зная хорошие значения на инференсе, попробуем погенерировать изображения с разными промтами.
Чтобы убедиться, что наша модель не забыла, как выглядят коты в общем случае, запустим генерацию изображений несколько раз на промте "A photo of a cat". Вот полученные картинки:
Полученные коты отличаются от нашего заданного цветом шерсти и рисунком на ней.
Теперь посмотрим, как модель генерирует нашего кота в различных контекстах.
Попробуем генерацию изображений по промтам вида "A photo of a [V] cat in a [place]".
-
"A photo of a [V] cat in a bath"
-
"A photo of a [V] cat driving a car"
-
"Pic-A photo of a [V] cat on a moon surface.jpg"
-
"A photo of a [V] cat in a snow."
Попробуем повторить результаты статьи, где они скрещивали таргетный объект с другими животными. Вот несколько полученных изображений.
-
"A photo of a [V] cat crossed with a hippo"
-
"A photo of a [V] cat crossed with a panda"
-
"A photo of a [V] cat crossed with a koala"
-
"A photo of a [V] cat crossed with a lion"
-
"A photo of a depressed [V] cat"
-
"A photo of a sad [V] cat"
-
"A photo of a happy [V] cat"
-
"A photo of a screaming [V] cat"
Попробуем получить другие ракурсы нашего кота.
-
"A photo of a [V] cat seen from the back"
-
"A photo of a [V] cat seen from the bottom"
-
"A photo of a [V] cat seen from the side"
-
"A photo of a [V] cat seen from the top"
Полностью повторить результаты оригинальной статьи не получилось. Объяснить это можно тем, что мы обучали LoRA, а не дообучали всю модель. К тому же можно было генерировать больше классовых изображений и возможно от этого получить большее разнообразие. Однако, удалось связать "[V]" в промте с конкретными параметрами объекта и получить изображения, сильно отличающиеся от данных в reference set. При этом не был смещен весь домен класса. В поисках лучшего результата можно поэкспериментировать с количеством шагов на обучении и на инференсе. Код можно найти в ноутбуке. Также можно поварьировать коэффициент, меняющий соотношение между Reconstruction Loss и Class-Specific Prior Preservation Loss.