Shortcut models - метод обучения диффузионных моделей, который позволяет генерировать изображения высокого качества за один или несколько шагов.
В основе shortcut models - идея обучать сеть с учетом не только текущего уровня шума, но и желаемого размера шага. Это позволяет модели "перепрыгивать" через этапы генерации.
Ключевым преимуществом данного подхода является его простота: shortcut models обучаются за один этап, используя одну сеть, в отличие от других методов ускорения выборки, которые полагаются на сложные схемы обучения с несколькими фазами, сетями или точной настройкой шедулера.
В процессе обучения shortcut models используются два типа целей loss function:
Совместная оптимизация этих целей дает возможность модели научиться создавать изображения, сохраняя согласованность при любом размере шага, включая генерацию за один шаг.
Метод применим к flow-matching и transformer-based типам моделей и RNN/LSTM-сетям.
Эксперименты, проведенные с DiT на наборах данных CelebA-HQ и ImageNet-256, подтверждают эффективность метода.
Shortcut models превосходят методы "end-to-end" обучения одношаговых генеративных моделей и конкурируют с двухэтапными методами дистилляции.
Практическая реализация shortcut models написана на JAX. Для локального запуска следует установить зависимости conda из файлов environment.yml и requirements.txt репозитория.
⚠️ Код поддерживает
--model.sharding fsdp
для полностью сегментированного параллелизма данных, если обучение проводится на multi-GPU или TPU.⚠️ Чекпоинты и FID для тестовых датасетов CelebA и Imagenet доступны на Google-диске.
python train.py --model.hidden_size 768 --model.patch_size 2 --model.depth 12 --model.num_heads 12 --model.mlp_ratio 4
--dataset_name celebahq256 --fid_stats data/celeba256_fidstats_ours.npz --model.cfg_scale 0 --model.class_dropout_prob 1 --model.num_classes 1 --batch_size 64 --max_steps 410_000 --model.train_type shortcut
@ai_machinelearning_big_data
#AI #ML #ShortcutModels #Training
Please open Telegram to view this post
VIEW IN TELEGRAM
Please open Telegram to view this post
VIEW IN TELEGRAM
❤17👍13🔥6