Передача нейронного стиля с помощью TensorFlow
Перенос нейронного стиля - это метод оптимизации, используемый для получения двух изображений, изображения содержимого и эталонного изображения стиля (например, работы известного художника), и смешивания их вместе, чтобы выходное изображение выглядело как изображение содержимого, но было «нарисовано» в стиле эталонного изображения стиля. Этот метод используется многими популярными приложениями для Android iOS, такими как Prisma , DreamScope , PicsArt .
Архитектура :
Бумага для переноса в нейронном стиле использует карты характеристик, сгенерированные промежуточными слоями сети VGG-19, для генерации выходного изображения. Эта архитектура принимает изображения стиля и содержимого в качестве входных данных и сохраняет функции, извлеченные сверточными слоями сети VGG.
Потеря контента:
Для расчета стоимости контента мы применяем среднеквадратичную разницу между матрицами, сгенерированными слоем контента, когда мы передаем сгенерированное изображение и исходное изображение. Пусть p и x - исходное изображение и изображение, которое сгенерировано, а P и F - их соответствующие представления признаков в слое l . Затем мы определяем потерю квадратичной ошибки между двумя представлениями функций.
Потеря стиля :
Для расчета стоимости стиля сначала рассчитаем матрицу граммов. Вычисление матриц грамма включает вычисление внутреннего продукта между векторизованными картами признаков определенного слоя . Здесь G ij (l) представляет собой внутренний продукт между векторизованными объектами i, j слоя l.
Теперь, чтобы вычислить потери от конкретного, мы найдем среднеквадратическую разницу матриц грамма, вычисленную из векторов признаков изображения стиля и сгенерированного изображения. Затем это взвешено с коэффициентом взвешивания слоя.
Пусть a и x - исходное изображение и сгенерированное изображение, а Al и Gl - их соответствующее представление стиля (матрицы граммов) в слое l. Вклад слоя l в общие потери тогда равен:
Таким образом, общая потеря стиля будет:
Общая потеря
Полная потеря - это линейная комбинация стиля и потери контента, которую мы определили выше:
Где α и β - весовые коэффициенты для реконструкции контента и стиля соответственно.
Реализация в Tensorflow:
- Сначала импортируем необходимый модуль. В этом посте мы используем TensorFlow v2 с Keras. Мы также импортируем модель VGG-19 из tf.keras API.
Код:
# import numpy, tensorflow and matplotlib import tensorflow as tf import numpy as np import matplotlib.pyplot as plt # import VGG 19 model and keras Model API from tensorflow.python.keras.applications.vgg19 import VGG19, preprocess_input from tensorflow.python.keras.preprocessing.image import load_img, img_to_array from tensorflow.python.keras.models import Model |
- Теперь мы импортируем контент и изображения стиля и сохраняем их в нашем рабочем каталоге.
Код:
# Image Credits: Tensorflow Doc content_path = tf.keras.utils.get_file( 'content.jpg' , ' https://storage.googleapis.com/download.tensorflow.org/example_images/YellowLabradorLooking_new.jpg ' ) style_path = tf.keras.utils.get_file( 'style.jpg' , |
- Теперь мы инициализируем модель VGG с весами ImageNet, мы также удалим верхние слои и сделаем ее не обучаемой.
Код:
# code # this function download the VGG model and initiliase it model = VGG19( include_top = False , weights = 'imagenet' ) # set training to False model.trainable = False # Print details of different layers model.summary() |
Выход:
Загрузка данных с https://storage.googleapis.com/tensorflow/keras-applications/vgg19/vgg19_weights_tf_dim_ordering_tf_kernels_notop.h5 80142336/80134624 [==============================] - 1 с 0 мкс / шаг Модель: "vgg19" _________________________________________________________________ Слой (тип) Параметр формы вывода # ================================================== =============== input_1 (InputLayer) [(Нет, Нет, Нет, 3)] 0 _________________________________________________________________ block1_conv1 (Conv2D) (Нет, Нет, Нет, 64) 1792 _________________________________________________________________ block1_conv2 (Conv2D) (Нет, Нет, Нет, 64) 36928 _________________________________________________________________ block1_pool (MaxPooling2D) (Нет, Нет, Нет, 64) 0 _________________________________________________________________ block2_conv1 (Conv2D) (Нет, Нет, Нет, 128) 73856 _________________________________________________________________ block2_conv2 (Conv2D) (Нет, Нет, Нет, 128) 147584 _________________________________________________________________ block2_pool (MaxPooling2D) (Нет, Нет, Нет, 128) 0 _________________________________________________________________ block3_conv1 (Conv2D) (Нет, Нет, Нет, 256) 295168 _________________________________________________________________ block3_conv2 (Conv2D) (Нет, Нет, Нет, 256) 590080 _________________________________________________________________ block3_conv3 (Conv2D) (Нет, Нет, Нет, 256) 590080 _________________________________________________________________ block3_conv4 (Conv2D) (Нет, Нет, Нет, 256) 590080 _________________________________________________________________ block3_pool (MaxPooling2D) (Нет, Нет, Нет, 256) 0 _________________________________________________________________ block4_conv1 (Conv2D) (Нет, Нет, Нет, 512) 1180160 _________________________________________________________________ block4_conv2 (Conv2D) (Нет, Нет, Нет, 512) 2359808 _________________________________________________________________ block4_conv3 (Conv2D) (Нет, Нет, Нет, 512) 2359808 _________________________________________________________________ block4_conv4 (Conv2D) (Нет, Нет, Нет, 512) 2359808 _________________________________________________________________ block4_pool (MaxPooling2D) (Нет, Нет, Нет, 512) 0 _________________________________________________________________ block5_conv1 (Conv2D) (Нет, Нет, Нет, 512) 2359808 _________________________________________________________________ block5_conv2 (Conv2D) (Нет, Нет, Нет, 512) 2359808 _________________________________________________________________ block5_conv3 (Conv2D) (Нет, Нет, Нет, 512) 2359808 _________________________________________________________________ block5_conv4 (Conv2D) (Нет, Нет, Нет, 512) 2359808 _________________________________________________________________ block5_pool (MaxPooling2D) (Нет, Нет, Нет, 512) 0 ================================================== =============== Всего параметров: 20 024 384 Обучаемые параметры: 0 Необучаемые параметры: 20 024 384 ________________________________________________________________
- Теперь мы загружаем и обрабатываем изображение с помощью входных данных препроцессора Keras в VGG 19. Функция expand_dims добавляет измерение для представления ряда изображений во входных данных. Эта функция preprocess_input (используемая в VGG 19) преобразует входные изображения RGB в изображения BGR и центрирует эти значения вокруг 0 в соответствии с данными ImageNet (без масштабирования).
Код:
# code to load and process image def load_and_process_image(image_path): img = load_img(image_path) # convert image to array img = img_to_array(img) img = preprocess_input(img) img = np.expand_dims(img, axis = 0 ) return img |
- Теперь мы определяем функцию депроцесса, которая принимает входное изображение и выполняет обратную функцию preprocess_input, которую мы импортировали выше. Для отображения необработанного изображения мы также определяем функцию отображения.
Код:
# code def deprocess(img): # perform the inverse of the pre processing step img[:, :, 0 ] + = 103.939 img[:, :, 1 ] + = 116.779 img[:, :, 2 ] + = 123.68 # convert RGB to BGR img = img[:, :, :: - 1 ] img = np.clip(img, 0 , 255 ).astype( 'uint8' ) return img def display_image(image): # remove one dimension if image has 4 dimension if len (image.shape) = = 4 : img = np.squeeze(image, axis = 0 ) img = deprocess(img) plt.grid( False ) plt.xticks([]) plt.yticks([]) plt.imshow(img) return |
- Теперь мы используем указанную выше функцию для отображения изображений стиля и содержимого.
Код:
# load content image content_img = load_and_process_image(content_path) display_image(content_img) # load style image style_img = load_and_process_image(style_path) display_image(style_img) |
Выход:
- Теперь мы определяем модель содержимого и стиля с помощью API Keras.Model. Модель содержимого принимает изображение в качестве входных данных и выводит карту функций из «block5_conv1» из вышеупомянутой модели VGG.
Код:
# define content model content_layer = 'block5_conv2' content_model = Model( inputs = model. input , outputs = model.get_layer(content_layer).output ) content_model.summary() |
Выход:
Модель: "функционал_9" _________________________________________________________________ Слой (тип) Параметр формы вывода # ================================================== =============== input_1 (InputLayer) [(Нет, Нет, Нет, 3)] 0 _________________________________________________________________ block1_conv1 (Conv2D) (Нет, Нет, Нет, 64) 1792 _________________________________________________________________ block1_conv2 (Conv2D) (Нет, Нет, Нет, 64) 36928 _________________________________________________________________ block1_pool (MaxPooling2D) (Нет, Нет, Нет, 64) 0 _________________________________________________________________ block2_conv1 (Conv2D) (Нет, Нет, Нет, 128) 73856 _________________________________________________________________ block2_conv2 (Conv2D) (Нет, Нет, Нет, 128) 147584 _________________________________________________________________ block2_pool (MaxPooling2D) (Нет, Нет, Нет, 128) 0 _________________________________________________________________ block3_conv1 (Conv2D) (Нет, Нет, Нет, 256) 295168 _________________________________________________________________ block3_conv2 (Conv2D) (Нет, Нет, Нет, 256) 590080 _________________________________________________________________ block3_conv3 (Conv2D) (Нет, Нет, Нет, 256) 590080 _________________________________________________________________ block3_conv4 (Conv2D) (Нет, Нет, Нет, 256) 590080 _________________________________________________________________ block3_pool (MaxPooling2D) (Нет, Нет, Нет, 256) 0 _________________________________________________________________ block4_conv1 (Conv2D) (Нет, Нет, Нет, 512) 1180160 _________________________________________________________________ block4_conv2 (Conv2D) (Нет, Нет, Нет, 512) 2359808 _________________________________________________________________ block4_conv3 (Conv2D) (Нет, Нет, Нет, 512) 2359808 _________________________________________________________________ block4_conv4 (Conv2D) (Нет, Нет, Нет, 512) 2359808 _________________________________________________________________ block4_pool (MaxPooling2D) (Нет, Нет, Нет, 512) 0 _________________________________________________________________ block5_conv1 (Conv2D) (Нет, Нет, Нет, 512) 2359808 _________________________________________________________________ block5_conv2 (Conv2D) (Нет, Нет, Нет, 512) 2359808 ================================================== =============== Всего параметров: 15 304 768 Обучаемые параметры: 0 Необучаемые параметры: 15 304 768 _________________________________________________________________
- Теперь мы определяем модель содержимого и стиля с помощью API Keras.Model. Модель стиля принимает изображение в качестве входных данных и выводит карту функций из «block1_conv1, block3_conv1 и block5_conv2» из вышеупомянутой модели VGG.
Код:
# define style model style_layers = [ 'block1_conv1' , 'block3_conv1' , 'block5_conv1' ] style_models = [Model(inputs = model. input , outputs = model.get_layer(layer).output) for layer in style_layers] |
- Теперь мы определяем функцию потери контента, она берет карту характеристик сгенерированного и реального изображений и вычисляет среднеквадратичную разницу между ними.
Код:
# Content loss def content_loss(content, generated): a_C = content_model(content) loss = tf.reduce_mean(tf.square(a_C - a_G)) return loss |
- Теперь мы определяем матрицу грамма и функцию потери стиля. Эта функция также принимает реальные и сгенерированные изображения в качестве входных данных модели и вычисляет их граммовые матрицы перед вычислением потерь стиля, взвешенных по разным слоям.
Код:
# gram matrix def gram_matrix(A): channels = int (A.shape[ - 1 ]) a = tf.reshape(A, [ - 1 , channels]) n = tf.shape(a)[ 0 ] gram = tf.matmul(a, a, transpose_a = True ) return gram / tf.cast(n, tf.float32) weight_of_layer = 1. / len (style_models) # style loss def style_cost(style, generated): J_style = 0 for style_model in style_models: a_S = style_model(style) a_G = style_model(generated) GS = gram_matrix(a_S) GG = gram_matrix(a_G) current_cost = tf.reduce_mean(tf.square(GS - GG)) J_style + = current_cost * weight_of_layer return J_style |
- Теперь, когда мы определяем нашу обучающую функцию, мы обучим нашу модель до 50 итераций. Эта модель принимает входные изображения, количество итераций в качестве аргумента.
# training function generated_images = [] def training_loop(content_path, style_path, iterations = 50 , a = 10 , b = 1000 ): # load content and style images from their repsective path content = load_and_process_image(content_path) style = load_and_process_image(style_path) generated = tf.Variable(content, dtype = tf.float32) opt = tf.keras.optimizers.Adam(learning_rate = 7 ) best_cost = Inf best_image = None for i in range (iterations): % % time with tf.GradientTape() as tape: J_content = content_cost(content, generated) J_style = style_cost(style, generated) J_total = a * J_content + b * J_style grads = tape.gradient(J_total, generated) opt.apply_gradients([(grads, generated)]) if J_total < best_cost: best_cost = J_total best_image = generated.numpy() print ( "Iteration :{}" . format (i)) print ( 'Total Loss {:e}.' . format (J_total)) generated_images.append(generated.numpy()) return best_image |
- Теперь мы обучаем нашу модель, используя функцию обучения, которую мы определили выше.
Код:
# Train the model and get best image final_img = training(content_path, style_path) |
Выход:
Время ЦП: пользовательское 2 мкс, системное: 1e + 03 нс, всего: 3 мкс Время на стене: 6,2 мкс Итерация: 0 Общий убыток 5.133922e + 11. Время ЦП: пользовательское 2 мкс, sys: 1e + 03 нс, всего: 3 мкс Время на стене: 5,72 мкс Итерация: 1 Общий убыток 3.510511e + 11. Время ЦП: пользовательское 2 мкс, системное: 1e + 03 нс, всего: 3 мкс Время на стене: 6,68 мкс Итерация: 2 Общий убыток 2.069992e + 11. Время ЦП: пользовательское 3 мкс, sys: 1e + 03 нс, всего: 4 мкс Время на стене: 6,2 мкс Итерация: 3 Общий убыток 1.669609e + 11. Время ЦП: пользовательское 2 мкс, sys: 1e + 03 нс, всего: 3 мкс Время на стене: 6,44 мкс Итерация: 4 Общий убыток 1.575840e + 11. Время ЦП: пользовательское 2 мкс, sys: 1e + 03 нс, всего: 3 мкс Время на стене: 5,96 мкс Итерация: 5 Общий убыток 1.200623e + 11. Время ЦП: пользовательское 2 мкс, системное: 1e + 03 нс, всего: 3 мкс Время на стене: 5,96 мкс Итерация: 6 Общий убыток 8.824594e + 10. Время ЦП: пользовательское 2 мкс, sys: 1e + 03 нс, всего: 3 мкс Время на стене: 5,72 мкс Итерация: 7 Общий убыток 7.168546e + 10. Время ЦП: пользовательское 2 мкс, sys: 1e + 03 нс, всего: 3 мкс Время на стене: 5,48 мкс Итерация: 8 Общий убыток 6.207320e + 10. Время ЦП: пользовательское 3 мкс, sys: 1e + 03 нс, всего: 4 мкс Время на стене: 8,34 мкс Итерация: 9 Общий убыток 5,390836e + 10. Время ЦП: пользовательское 2 мкс, sys: 1e + 03 нс, всего: 3 мкс Время на стене: 6,2 мкс Итерация: 10 Общий убыток 4,735992e + 10. Время ЦП: пользовательское 2 мкс, sys: 1e + 03 нс, всего: 3 мкс Время на стене: 5,96 мкс Итерация: 11 Общий убыток 4.301782e + 10. Время ЦП: пользовательское 2 мкс, системное: 1e + 03 нс, всего: 3 мкс Время на стене: 6,2 мкс Итерация: 12 Общий убыток 3.912694e + 10. Время ЦП: пользовательское 2 мкс, системное: 1e + 03 нс, всего: 3 мкс Время на стене: 6,68 мкс Итерация: 13 Общий убыток 3.445185e + 10. Время ЦП: пользовательский 0 нс, системный: 3 мкс, всего: 3 мкс Время на стене: 6,2 мкс Итерация: 14 Общий убыток 2.975165e + 10. Время ЦП: пользовательское 2 мкс, системное: 0 нс, всего: 2 мкс Время на стене: 5,96 мкс Итерация: 15 Общий убыток 2.590984e + 10. Время ЦП: пользовательское 2 мкс, sys: 1e + 03 нс, всего: 3 мкс Время на стене: 20 мкс Итерация: 16 Общий убыток 2.302116e + 10. Время ЦП: пользовательское 2 мкс, системное: 1e + 03 нс, всего: 3 мкс Время на стене: 5,72 мкс Итерация: 17 Общий убыток 2.082643e + 10. Время ЦП: пользовательское 4 мкс, системное: 1e + 03 нс, всего: 5 мкс Время на стене: 8,34 мкс Итерация: 18 Общий убыток 1.906701e + 10. Время ЦП: пользовательское 2 мкс, sys: 1e + 03 нс, всего: 3 мкс Время на стене: 5,25 мкс Итерация: 19 Общий убыток 1.759801e + 10. Время ЦП: пользовательское 3 мкс, sys: 1e + 03 нс, всего: 4 мкс Время на стене: 6,2 мкс Итерация: 20 Общий убыток 1.635128e + 10. Время ЦП: пользовательское 2 мкс, sys: 1e + 03 нс, всего: 3 мкс Время на стене: 6,2 мкс Итерация: 21 Общий убыток 1.525327e + 10. Время ЦП: пользовательское 3 мкс, sys: 1e + 03 нс, всего: 4 мкс Время на стене: 5,96 мкс Итерация: 22 Общий убыток 1.418364e + 10. Время ЦП: пользовательское 4 мкс, системное: 1 мкс, всего: 5 мкс Время на стене: 9,06 мкс Итерация: 23 Общий убыток 1.306596e + 10. Время ЦП: пользовательское 2 мкс, системное: 1e + 03 нс, всего: 3 мкс Время на стене: 5,25 мкс Итерация: 24 Общий убыток 1.196509e + 10. Время ЦП: пользовательское 2 мкс, системное: 1e + 03 нс, всего: 3 мкс Время на стене: 5,96 мкс Итерация: 25 Общий убыток 1.102290e + 10. Время ЦП: пользовательское 2 мкс, sys: 1e + 03 нс, всего: 3 мкс Время на стене: 5,96 мкс Итерация: 26 Общий убыток 1.025539e + 10. Время ЦП: пользовательское 7 мкс, системное: 3 мкс, всего: 10 мкс Время на стене: 12,6 мкс Итерация: 27 Общий убыток 9.570500e + 09. Время ЦП: пользовательское 2 мкс, системное: 1e + 03 нс, всего: 3 мкс Время на стене: 5,72 мкс Итерация: 28 Общий убыток 8.917115e + 09. Время ЦП: пользовательское 2 мкс, sys: 1e + 03 нс, всего: 3 мкс Время на стене: 5,96 мкс Итерация: 29 Общий убыток 8,328761e + 09. Время ЦП: пользовательское 3 мкс, sys: 1e + 03 нс, всего: 4 мкс Время на стене: 9,54 мкс Итерация: 30 Общий убыток 7.840127e + 09. Время ЦП: пользовательское 2 мкс, системное: 1e + 03 нс, всего: 3 мкс Время на стене: 6,44 мкс Итерация: 31 Общий убыток 7,406647e + 09. Время ЦП: пользовательское 2 мкс, sys: 1e + 03 нс, всего: 3 мкс Время на стене: 8,34 мкс Итерация: 32 Общий убыток 6.967848e + 09. Время ЦП: пользовательское 2 мкс, системное: 1e + 03 нс, всего: 3 мкс Время на стене: 5,72 мкс Итерация: 33 Общий убыток 6.531650e + 09. Время ЦП: пользовательское 2 мкс, sys: 1e + 03 нс, всего: 3 мкс Время на стене: 5,72 мкс Итерация: 34 Общий убыток 6.136975e + 09. Время ЦП: пользовательское 2 мкс, системное: 1 мкс, всего: 3 мкс Время на стене: 5,96 мкс Итерация: 35 Общий убыток 5.788804e + 09. Время ЦП: пользовательское 2 мкс, системное: 1e + 03 нс, всего: 3 мкс Время на стене: 5,72 мкс Итерация: 36 Общий убыток 5.476942e + 09. Время ЦП: пользовательское 2 мкс, системное: 1e + 03 нс, всего: 3 мкс Время на стене: 6,2 мкс Итерация: 37 Общий убыток 5.204070e + 09. Время ЦП: пользовательское 3 мкс, системное: 1 мкс, всего: 4 мкс Время на стене: 6,2 мкс Итерация: 38 Общий убыток 4.954049e + 09. Время ЦП: пользовательское 2 мкс, системное: 1e + 03 нс, всего: 3 мкс Время на стене: 5,96 мкс Итерация: 39 Общий убыток 4.708641e + 09. Время ЦП: пользовательское 3 мкс, системное: 2 мкс, всего: 5 мкс Время на стене: 6,2 мкс Итерация: 40 Общий убыток 4.487677e + 09. Время ЦП: пользовательское 2 мкс, системное: 1e + 03 нс, всего: 3 мкс Время на стене: 5,96 мкс Итерация: 41 Общий убыток 4.296946e + 09. Время ЦП: пользовательское 2 мкс, системное: 1e + 03 нс, всего: 3 мкс Время на стене: 5,96 мкс Итерация: 42 Общий убыток 4.107909e + 09. Время ЦП: пользовательское 3 мкс, sys: 1e + 03 нс, всего: 4 мкс Время на стене: 6,44 мкс Итерация: 43 Общий убыток 3.918156e + 09. Время ЦП: пользовательское 3 мкс, sys: 1e + 03 нс, всего: 4 мкс Время на стене: 6,2 мкс Итерация: 44 Общий убыток 3.747263e + 09. Время ЦП: пользовательское 3 мкс, sys: 1e + 03 нс, всего: 4 мкс Время на стене: 8,34 мкс Итерация: 45 Общий убыток 3.595638e + 09. Время ЦП: пользовательское 2 мкс, sys: 1e + 03 нс, всего: 3 мкс Время на стене: 5,72 мкс Итерация: 46 Общий убыток 3.458928e + 09. Время ЦП: пользовательское 2 мкс, системное: 1e + 03 нс, всего: 3 мкс Время на стене: 6,2 мкс Итерация: 47 Общий убыток 3.331772e + 09. Время ЦП: пользовательское 4 мкс, системное: 1e + 03 нс, всего: 5 мкс Время на стене: 9,3 мкс Итерация: 48 Общий убыток 3.205911e + 09. Время ЦП: пользовательское 3 мкс, sys: 1e + 03 нс, всего: 4 мкс Время на стене: 5,96 мкс Итерация: 49 Общий убыток 3.089630e + 09.
- На заключительном этапе мы наносим окончательный и промежуточный результаты.
Код:
# code to display best generted image and last 10 intermediate results plt.figure(figsize = ( 12 , 12 )) for i in range ( 10 ): plt.subplot( 4 , 3 , i + 1 ) display_image(generated_images[i + 39 ]) plt.show() # plot best result display_image(final_img) |
Выход:
Использованная литература:
- Учебник Tensorflow по передаче нейронного стиля
- Бумага для переноса стилей
Внимание компьютерщик! Укрепите свои основы с помощью базового курса программирования Python и изучите основы.
Для начала подготовьтесь к собеседованию. Расширьте свои концепции структур данных с помощью курса Python DS. А чтобы начать свое путешествие по машинному обучению, присоединяйтесь к курсу Машинное обучение - базовый уровень.