Это учебный проект, демонстрирующий создание и обучение нейросетевой модели для преобразования изображений математических формул в соответствующую им разметку LaTeX. Проект прошел через несколько итераций, включая классическую архитектуру Encoder-Decoder с Attention и современную архитектуру на основе Трансформера.
Основная цель проекта — не только получить работающую модель, но и изучить на практике ключевые концепции, проблемы и техники, возникающие при решении подобных задач Computer Vision и NLP.
- Архитектура модели
- Данные
- Ключевые этапы и изученные концепции
- Как запустить
- Результаты и выводы
- Возможные пути улучшения
Финальная версия модели построена на архитектуре Encoder-Decoder с использованием Трансформера.
- Задача: Извлечь из входного изображения набор визуальных признаков.
- Реализация:
- Используется предобученная сверточная нейросеть ResNet34 в качестве "хребта" для извлечения признаков (трансферное обучение).
- Верхние слои ResNet, отвечающие за классификацию, удалены.
- Добавлен
Conv2d 1x1
слой для приведения размерности признаков кd_model
, совместимой с Трансформером.
- Задача: Сгенерировать последовательность LaTeX-токенов на основе признаков от энкодера.
- Реализация:
- Используется декодер Трансформера, состоящий из нескольких слоев
nn.TransformerDecoderLayer
. - Positional Encoding добавляется к эмбеддингам токенов, чтобы модель знала о порядке в последовательности.
- Masked Self-Attention позволяет модели понимать контекст уже сгенерированной части формулы.
- Cross-Attention связывает текстовый контекст с визуальными признаками от энкодера.
- Используется декодер Трансформера, состоящий из нескольких слоев
- Источник: im2latex-100k — публичный датасет, содержащий ~100,000 пар "изображение-формула", сгенерированных из реальных научных статей.
- Предварительная обработка:
- Токенизация: LaTeX-строки разбиваются на отдельные токены (команды, символы).
- Построение словаря: Создается словарь всех уникальных токенов.
- Материализация данных: Для ускорения обучения все изображения и формулы заранее обрабатываются и кэшируются в виде тензоров с помощью
torch.save()
. Это позволяет избежать "бутылочного горлышка" при загрузке данных на лету.
В ходе проекта были изучены и решены следующие практические задачи:
-
Отладка производительности:
- Диагностика "узких мест" (bottlenecks) с помощью ручного профилирования и
torch.profiler
. - Решение проблем с медленной загрузкой данных: использование
num_workers
, RAM-диска и, наконец, полной предварительной обработки данных.
- Диагностика "узких мест" (bottlenecks) с помощью ручного профилирования и
-
Оптимизация обучения:
- Применение смешанной точности (AMP) с
torch.amp.autocast
иGradScaler
для ускорения вычислений на GPU. - Использование
torch.compile()
для JIT-компиляции модели и получения дополнительного прироста скорости.
- Применение смешанной точности (AMP) с
-
Улучшение процесса обучения:
- Реализация Ранней Остановки (Early Stopping) для предотвращения переобучения и экономии времени.
- Применение Файн-тюнинга (Fine-tuning) с уменьшенным
learning_rate
для "дожатия" качества модели после выхода на плато.
-
Оценка модели:
- Использование нескольких метрик:
Valid Loss
(для принятия решений),BLEU
(для оценки схожести) иExact Match
(для оценки абсолютной точности). - Реализация "редкой" оценки дорогих метрик для ускорения цикла валидации.
- Использование нескольких метрик:
-
Улучшение инференса:
- Сравнение "жадного" поиска и Beam Search. Анализ проблем (зацикливание, "галлюцинации") при "жадной" генерации.
- Реализация пакетной генерации (
predict_batch
) для ускорения валидации.
Проект реализован в виде Jupyter Notebook, предназначенного для среды Google Colab.
- Среда: Убедитесь, что выбран аппаратный ускоритель GPU (желательно A100 для высокой производительности).
- Подготовка (Ячейки 1-4):
- Запустите ячейки для установки зависимостей, определения конфигурации и скачивания данных.
- При первом запуске будет выполнен долгий процесс предварительной обработки и кэширования данных на ваш Google Drive. Это может занять до 40 минут.
- При всех последующих запусках данные будут быстро загружаться из кэша.
- Обучение (Ячейка 7):
- Запустите ячейку основного цикла обучения. Прогресс будет сохраняться в папку
im2latex_checkpoints
на вашем Google Drive.
- Запустите ячейку основного цикла обучения. Прогресс будет сохраняться в папку
- Тестирование (Ячейки 8 и 9):
- После завершения обучения запустите ячейки для тестирования лучшей сохраненной модели на примерах из тестового набора или на ваших собственных изображениях.
- Финальная модель на основе Трансформера демонстрирует стабильное обучение, а
Valid Loss
неуклонно снижается на протяжении ~30 эпох. - Модель отлично изучила синтаксис LaTeX, генерируя правдоподобные формулы.
- Основная сложность заключается в установлении точной связи между изображением и текстом, что требует длительного обучения.
- Оптимизация производительности, особенно предварительная обработка данных и
torch.compile
, критически важна для эффективного использования мощных GPU.
- Реализовать Beam Search для декодера Трансформера, чтобы значительно улучшить качество генерируемых формул.
- Применить аугментацию данных для изображений, чтобы повысить устойчивость модели.
- Использовать планировщик скорости обучения (
Learning Rate Scheduler
) для автоматизации процесса файн-тюнинга. - Исследовать более продвинутые архитектуры энкодера (например, Swin Transformer) или реализации внимания (например, Flash Attention).