Использование предикаторной функции для построения ансамбля нейросетей
Тип научной работы: M1P
Автор: Уденеев Александр Владимирович
Научный руководитель: Бахтеев Олег
Научный консультант (при наличии): Бабкин Петр
The automated search for optimal neural network architectures (NAS) is a challenging computational problem, and Neural Ensemble Search (NES) is even more complex.
In this work, we propose a surrogate-based approach for ensemble creation. Neural architectures are represented as graphs, and their predictions on a dataset serve as training data for the surrogate function.
Using this function, we develop an efficient NES framework that enables the selection of diverse and high-performing architectures. The resulting ensemble achieves superior predictive accuracy on CIFAR-10 compared to other one-shot NES methods, demonstrating the effectiveness of our approach.
Keywords: NES, GCN, triplet loss, surrogate function
- To be added
- To be added
Этот проект реализует нейроэволюционный поиск архитектур (NAS) с акцентом на разнообразие моделей в ансамбле. Система состоит из 3 основных компонентов:
- Обучение моделей (
train_models.py
) - обучение ансамбля моделей и оценка их качества - Обучение суррогатных моделей (
train_surrogate.py
) - создание суррогатных моделей для предсказания качества и разнообразия архитектур - Поиск архитектур (
inference_surrogate.py
) - поиск оптимальных архитектур с помощью суррогатных моделей
{
"seed":42, // Сид для воспроизводимости
"num_workers": 4, // Количество процессов для загрузки данных
"device": "cpu", // Устройство для вычислений
"developer_mode": true, // Режим разработчика (в нем модели обучаются лишь на одном батче)
}
Скрипт для обучения ансамбля моделей и оценки их качества.
- Обучение множества моделей с разными архитектурами
- Оценка индивидуальной точности каждой модели
- Оценка точности и калибровки ансамбля моделей
- Сохранение результатов обучения
{
"n_models_to_evaluate": 100, // Количество моделей для генерации, если подготавливаем датасет, иначе ни на что не влияет
"evaluate_ensemble_flag": false, // Флаг оценки ансамбля (true/false). Если true, то оцениваем ансамбль, если false, то подготавливаем датасет
"prepared_dataset_path": "datasets/evaluated_dataset/", // По этому пути лежит подготовленный датасет
"best_models_save_path": "best_models/", // Путь к архитектурам моделей, из которых состоит ансамбля
"dataset_name": "CIFAR10", // Используемый датасет (CIFAR10/CIFAR100/FashionMNIST)
"final_dataset_path": "datasets/final_dataset/", // Путь к папке, куда будем скачивать датасеты для обучения моделей
"n_epochs_final": 1, // Количество эпох обучения
"lr_final": 0.025, // Скорость обучения
"batch_size_final": 96, // Размер батча
"width": 4, // Ширина слоев в DARTS
"num_cells": 3, // Количество ячеек в DARTS
"n_ece_bins": 15 // Количество бинов для расчета ECE
}
- Загрузка или генерация архитектур моделей
- Создание DataLoader'ов для выбранного датасета
- Обучение каждой модели:
- Инициализация архитектуры DARTS
- Обучение
- Сохранение результатов
- При флаге
evaluate_ensemble_flag
:- Оценка ансамбля на тестовых данных
- Расчет точности и ECE
- Сохранение результатов оценки
Скрипт для обучения суррогатных моделей, предсказывающих качество и разнообразие архитектур.
- Загрузка датасета с архитектурами и результатами
- Расчет матрицы разнообразия
- Преобразование архитектур в графы
- Обучение GAT-моделей для предсказания:
- Точности архитектуры
- Эмбеддингов разнообразия
{
"dataset_path": "datasets/third_dataset/", // Путь к датасету архитектур
"n_models": 300, // Количество используемых моделей
"diversity_matrix_metric": "overlap", // Метрика разнообразия (overlap/js)
"upper_margin": 0.75, // Верхний квантиль для дискретизации матрицы похожести
"lower_margin": 0.25, // Нижний квантиль для дискретизации матрицы похожести
"input_dim": 8, // Размерность признаков
"acc_num_epochs": 10, // Количество эпох обучения модели точности
"acc_lr": 1e-2, // LR для модели точности
"acc_dropout": 0.2, // Dropout для модели точности
"acc_n_heads": 16, // Количество голов в модели точности
"acc_final_lr": 1e-5, // eta_min для Cosine scheduler
"div_num_epochs": 5, // Количество эпох обучения модели разнообразия
"div_lr": 1e-3, // LR для модели разнообразия
"div_dropout": 0.1, // Dropout для модели разнообразия
"div_n_heads": 4, // Количество голов в модели разнообразия
"div_final_lr": 1e-6, // eta_min для Cosine scheduler
"margin": 1, // Отступ для triplet loss
"div_output_dim": 128, // Размерность эмбеддинга разнообразия
"surrogate_inference_path": "surrogate_models/", // Путь для сохранения моделей
"tr
8000
ain_size": 0.8, // Размер тренировочной выборки
"batch_size": 8, // Размер батча
"draw_fig_acc": false, // Отрисовывать ли график зависимости точности от эпохи (true/false)
"draw_fig_div": false, // Отрисовывать ли график зависимости triplet loss от эпохи (true/false)
}
- Загрузка датасета с архитектурами и результатами
- Расчет матрицы разнообразия между моделями
- Преобразование матрицы в дискретный вид
- Преобразование архитектур в графовые представления
- Создание датасетов для обучения:
- Для предсказания точности
- Для обучения эмбеддингов разнообразия (триплеты)
- Обучение двух GAT-моделей:
- Модель точности (регрессия)
- Модель разнообразия (эмбеддинги)
- Сохранение обученных моделей
Скрипт для поиска оптимальных архитектур с помощью обученных суррогатных моделей.
- Инициализация обученных суррогатных моделей
- Генерация новых архитектур
- Предсказание точности и эмбеддингов
- Отбор архитектур по точности и разнообразию
- Кластеризация архитектур и выбор представителей
- Визуализация результатов
- Сохранение лучших архитектур
{
"n_ensemble_models": 2, // Количество моделей в ансамбле
"n_models_in_pool": 128, // Размер пула кандидатов
"n_models_to_generate": 4096, // Количество генерируемых архитектур
"min_accuracy_for_pool": 0.01, // Минимальная точность для попадания в пул
"plot_tsne": false, // Флаг визуализации t-SNE
"best_models_save_path": "best_models/" // Путь для сохранения лучших архитектур
}
- Загрузка обученных суррогатных моделей
- Генерация архитектур:
- Генерация большого количества случайных архитектур
- Предсказание их точности и эмбеддингов разнообразия
- Фильтрация по минимальной точности
- Формирование пула кандидатов:
- Постепенное заполнение пула лучшими архитектурами
- Отбор по максимальному расстоянию в пространстве эмбеддингов
- Кластеризация:
- Кластеризация архитектур в пуле
- Выбор наиболее репрезентативных моделей из каждого кластера
- Визуализация (при включенном флаге):
- PCA + t-SNE для визуализации пространства эмбеддингов
- Отображение кластеров и выбранных моделей
- Сохранение лучших архитектур
# Перед запуском необходимо скачать выставить флаг "evaluate_ensemble_flag": true
./start_all.sh
- Подготовка датасета:
# Перед запуском необходимо выставить флаг "evaluate_ensemble_flag": false и указать количество моделей
# для оценки
python train_models.py --hyperparameters_json surrogate_hp.json
- Обучение суррогатных моделей:
python train_surrogate.py --hyperparameters_json surrogate_hp.json
- Поиск архитектур:
python inference_surrogate.py --hyperparameters_json surrogate_hp.json
- Оценка ансамбля:
# В файле конфигурации установить "evaluate_ensemble_flag": true
python train_models.py --hyperparameters_json surrogate_hp.json
- Python 3.8+
- PyTorch 1.10+
- torchvision
- scikit-learn
- tqdm
- NNI (Neural Network Intelligence)