Передача нейронного стиля с помощью TensorFlow

Опубликовано: 4 Января, 2022

Перенос нейронного стиля - это метод оптимизации, используемый для получения двух изображений, изображения содержимого и эталонного изображения стиля (например, работы известного художника), и смешивания их вместе, чтобы выходное изображение выглядело как изображение содержимого, но было «нарисовано» в стиле эталонного изображения стиля. Этот метод используется многими популярными приложениями для Android iOS, такими как Prisma , DreamScope , PicsArt .

Пример передачи стиля A - изображение содержимого, B выводится с изображением стиля в нижнем левом углу.

Архитектура :

Бумага для переноса в нейронном стиле использует карты характеристик, сгенерированные промежуточными слоями сети VGG-19, для генерации выходного изображения. Эта архитектура принимает изображения стиля и содержимого в качестве входных данных и сохраняет функции, извлеченные сверточными слоями сети VGG.

Архитектура ВГГ-19

Потеря контента:

Для расчета стоимости контента мы применяем среднеквадратичную разницу между матрицами, сгенерированными слоем контента, когда мы передаем сгенерированное изображение и исходное изображение. Пусть p и x - исходное изображение и изображение, которое сгенерировано, а P и F - их соответствующие представления признаков в слое l . Затем мы определяем потерю квадратичной ошибки между двумя представлениями функций.

Потеря стиля :

Для расчета стоимости стиля сначала рассчитаем матрицу граммов. Вычисление матриц грамма включает вычисление внутреннего продукта между векторизованными картами признаков определенного слоя . Здесь G ij (l) представляет собой внутренний продукт между векторизованными объектами i, j слоя l.

Теперь, чтобы вычислить потери от конкретного, мы найдем среднеквадратическую разницу матриц грамма, вычисленную из векторов признаков изображения стиля и сгенерированного изображения. Затем это взвешено с коэффициентом взвешивания слоя.

Пусть a и x - исходное изображение и сгенерированное изображение, а Al и Gl - их соответствующее представление стиля (матрицы граммов) в слое l. Вклад слоя l в общие потери тогда равен:

Таким образом, общая потеря стиля будет:

Общая потеря

Полная потеря - это линейная комбинация стиля и потери контента, которую мы определили выше:

Где α и β - весовые коэффициенты для реконструкции контента и стиля соответственно.

Реализация в Tensorflow:

  • Сначала импортируем необходимый модуль. В этом посте мы используем TensorFlow v2 с Keras. Мы также импортируем модель VGG-19 из tf.keras API.

Код:

# import numpy, tensorflow and matplotlib
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
# import VGG 19 model and keras Model API
from tensorflow.python.keras.applications.vgg19 import VGG19, preprocess_input
from tensorflow.python.keras.preprocessing.image import load_img, img_to_array
from tensorflow.python.keras.models import Model
  • Теперь мы импортируем контент и изображения стиля и сохраняем их в нашем рабочем каталоге.

Код:

# Image Credits: Tensorflow Doc
content_path = tf.keras.utils.get_file( 'content.jpg' ,
style_path = tf.keras.utils.get_file( 'style.jpg' ,
  • Теперь мы инициализируем модель VGG с весами ImageNet, мы также удалим верхние слои и сделаем ее не обучаемой.

Код:

# code
# this function download the VGG model and initiliase it
model = VGG19(
include_top = False ,
weights = 'imagenet'
)
# set training to False
model.trainable = False
# Print details of different layers
model.summary()


Выход:

 Загрузка данных с https://storage.googleapis.com/tensorflow/keras-applications/vgg19/vgg19_weights_tf_dim_ordering_tf_kernels_notop.h5
80142336/80134624 [==============================] - 1 с 0 мкс / шаг
Модель: "vgg19"
_________________________________________________________________
Слой (тип) Параметр формы вывода #   
================================================== ===============
input_1 (InputLayer) [(Нет, Нет, Нет, 3)] 0         
_________________________________________________________________
block1_conv1 (Conv2D) (Нет, Нет, Нет, 64) 1792      
_________________________________________________________________
block1_conv2 (Conv2D) (Нет, Нет, Нет, 64) 36928     
_________________________________________________________________
block1_pool (MaxPooling2D) (Нет, Нет, Нет, 64) 0         
_________________________________________________________________
block2_conv1 (Conv2D) (Нет, Нет, Нет, 128) 73856     
_________________________________________________________________
block2_conv2 (Conv2D) (Нет, Нет, Нет, 128) 147584    
_________________________________________________________________
block2_pool (MaxPooling2D) (Нет, Нет, Нет, 128) 0         
_________________________________________________________________
block3_conv1 (Conv2D) (Нет, Нет, Нет, 256) 295168    
_________________________________________________________________
block3_conv2 (Conv2D) (Нет, Нет, Нет, 256) 590080    
_________________________________________________________________
block3_conv3 (Conv2D) (Нет, Нет, Нет, 256) 590080    
_________________________________________________________________
block3_conv4 (Conv2D) (Нет, Нет, Нет, 256) 590080    
_________________________________________________________________
block3_pool (MaxPooling2D) (Нет, Нет, Нет, 256) 0         
_________________________________________________________________
block4_conv1 (Conv2D) (Нет, Нет, Нет, 512) 1180160   
_________________________________________________________________
block4_conv2 (Conv2D) (Нет, Нет, Нет, 512) 2359808   
_________________________________________________________________
block4_conv3 (Conv2D) (Нет, Нет, Нет, 512) 2359808   
_________________________________________________________________
block4_conv4 (Conv2D) (Нет, Нет, Нет, 512) 2359808   
_________________________________________________________________
block4_pool (MaxPooling2D) (Нет, Нет, Нет, 512) 0         
_________________________________________________________________
block5_conv1 (Conv2D) (Нет, Нет, Нет, 512) 2359808   
_________________________________________________________________
block5_conv2 (Conv2D) (Нет, Нет, Нет, 512) 2359808   
_________________________________________________________________
block5_conv3 (Conv2D) (Нет, Нет, Нет, 512) 2359808   
_________________________________________________________________
block5_conv4 (Conv2D) (Нет, Нет, Нет, 512) 2359808   
_________________________________________________________________
block5_pool (MaxPooling2D) (Нет, Нет, Нет, 512) 0         
================================================== ===============
Всего параметров: 20 024 384
Обучаемые параметры: 0
Необучаемые параметры: 20 024 384
________________________________________________________________

  • Теперь мы загружаем и обрабатываем изображение с помощью входных данных препроцессора Keras в VGG 19. Функция expand_dims добавляет измерение для представления ряда изображений во входных данных. Эта функция preprocess_input (используемая в VGG 19) преобразует входные изображения RGB в изображения BGR и центрирует эти значения вокруг 0 в соответствии с данными ImageNet (без масштабирования).

Код:

# code to load and process image
def load_and_process_image(image_path):
img = load_img(image_path)
# convert image to array
img = img_to_array(img)
img = preprocess_input(img)
img = np.expand_dims(img, axis = 0 )
return img


  • Теперь мы определяем функцию депроцесса, которая принимает входное изображение и выполняет обратную функцию preprocess_input, которую мы импортировали выше. Для отображения необработанного изображения мы также определяем функцию отображения.

Код:

# code
def deprocess(img):
# perform the inverse of the pre processing step
img[:, :, 0 ] + = 103.939
img[:, :, 1 ] + = 116.779
img[:, :, 2 ] + = 123.68
# convert RGB to BGR
img = img[:, :, :: - 1 ]
img = np.clip(img, 0 , 255 ).astype( 'uint8' )
return img
def display_image(image):
# remove one dimension if image has 4 dimension
if len (image.shape) = = 4 :
img = np.squeeze(image, axis = 0 )
img = deprocess(img)
plt.grid( False )
plt.xticks([])
plt.yticks([])
plt.imshow(img)
return


  • Теперь мы используем указанную выше функцию для отображения изображений стиля и содержимого.

Код:

# load content image
content_img = load_and_process_image(content_path)
display_image(content_img)
# load style image
style_img = load_and_process_image(style_path)
display_image(style_img)


Выход:

Изображение содержимого

Стиль изображения

  • Теперь мы определяем модель содержимого и стиля с помощью API Keras.Model. Модель содержимого принимает изображение в качестве входных данных и выводит карту функций из «block5_conv1» из вышеупомянутой модели VGG.

Код:

# define content model
content_layer = 'block5_conv2'
content_model = Model(
inputs = model. input ,
outputs = model.get_layer(content_layer).output
)
content_model.summary()


Выход:

 Модель: "функционал_9"
_________________________________________________________________
Слой (тип) Параметр формы вывода #   
================================================== ===============
input_1 (InputLayer) [(Нет, Нет, Нет, 3)] 0         
_________________________________________________________________
block1_conv1 (Conv2D) (Нет, Нет, Нет, 64) 1792      
_________________________________________________________________
block1_conv2 (Conv2D) (Нет, Нет, Нет, 64) 36928     
_________________________________________________________________
block1_pool (MaxPooling2D) (Нет, Нет, Нет, 64) 0         
_________________________________________________________________
block2_conv1 (Conv2D) (Нет, Нет, Нет, 128) 73856     
_________________________________________________________________
block2_conv2 (Conv2D) (Нет, Нет, Нет, 128) 147584    
_________________________________________________________________
block2_pool (MaxPooling2D) (Нет, Нет, Нет, 128) 0         
_________________________________________________________________
block3_conv1 (Conv2D) (Нет, Нет, Нет, 256) 295168    
_________________________________________________________________
block3_conv2 (Conv2D) (Нет, Нет, Нет, 256) 590080    
_________________________________________________________________
block3_conv3 (Conv2D) (Нет, Нет, Нет, 256) 590080    
_________________________________________________________________
block3_conv4 (Conv2D) (Нет, Нет, Нет, 256) 590080    
_________________________________________________________________
block3_pool (MaxPooling2D) (Нет, Нет, Нет, 256) 0         
_________________________________________________________________
block4_conv1 (Conv2D) (Нет, Нет, Нет, 512) 1180160   
_________________________________________________________________
block4_conv2 (Conv2D) (Нет, Нет, Нет, 512) 2359808   
_________________________________________________________________
block4_conv3 (Conv2D) (Нет, Нет, Нет, 512) 2359808   
_________________________________________________________________
block4_conv4 (Conv2D) (Нет, Нет, Нет, 512) 2359808   
_________________________________________________________________
block4_pool (MaxPooling2D) (Нет, Нет, Нет, 512) 0         
_________________________________________________________________
block5_conv1 (Conv2D) (Нет, Нет, Нет, 512) 2359808   
_________________________________________________________________
block5_conv2 (Conv2D) (Нет, Нет, Нет, 512) 2359808   
================================================== ===============
Всего параметров: 15 304 768
Обучаемые параметры: 0
Необучаемые параметры: 15 304 768
_________________________________________________________________
  • Теперь мы определяем модель содержимого и стиля с помощью API Keras.Model. Модель стиля принимает изображение в качестве входных данных и выводит карту функций из «block1_conv1, block3_conv1 и block5_conv2» из вышеупомянутой модели VGG.

Код:

# define style model
style_layers = [
'block1_conv1' ,
'block3_conv1' ,
'block5_conv1'
]
style_models = [Model(inputs = model. input ,
outputs = model.get_layer(layer).output) for layer in style_layers]
  • Теперь мы определяем функцию потери контента, она берет карту характеристик сгенерированного и реального изображений и вычисляет среднеквадратичную разницу между ними.

Код:

# Content loss
def content_loss(content, generated):
a_C = content_model(content)
loss = tf.reduce_mean(tf.square(a_C - a_G))
return loss
  • Теперь мы определяем матрицу грамма и функцию потери стиля. Эта функция также принимает реальные и сгенерированные изображения в качестве входных данных модели и вычисляет их граммовые матрицы перед вычислением потерь стиля, взвешенных по разным слоям.

Код:

# gram matrix
def gram_matrix(A):
channels = int (A.shape[ - 1 ])
a = tf.reshape(A, [ - 1 , channels])
n = tf.shape(a)[ 0 ]
gram = tf.matmul(a, a, transpose_a = True )
return gram / tf.cast(n, tf.float32)
weight_of_layer = 1. / len (style_models)
# style loss
def style_cost(style, generated):
J_style = 0
for style_model in style_models:
a_S = style_model(style)
a_G = style_model(generated)
GS = gram_matrix(a_S)
GG = gram_matrix(a_G)
current_cost = tf.reduce_mean(tf.square(GS - GG))
J_style + = current_cost * weight_of_layer
return J_style


  • Теперь, когда мы определяем нашу обучающую функцию, мы обучим нашу модель до 50 итераций. Эта модель принимает входные изображения, количество итераций в качестве аргумента.
# training function
generated_images = []
def training_loop(content_path, style_path, iterations = 50 , a = 10 , b = 1000 ):
# load content and style images from their repsective path
content = load_and_process_image(content_path)
style = load_and_process_image(style_path)
generated = tf.Variable(content, dtype = tf.float32)
opt = tf.keras.optimizers.Adam(learning_rate = 7 )
best_cost = Inf
best_image = None
for i in range (iterations):
% % time
with tf.GradientTape() as tape:
J_content = content_cost(content, generated)
J_style = style_cost(style, generated)
J_total = a * J_content + b * J_style
grads = tape.gradient(J_total, generated)
opt.apply_gradients([(grads, generated)])
if J_total < best_cost:
best_cost = J_total
best_image = generated.numpy()
print ( "Iteration :{}" . format (i))
print ( 'Total Loss {:e}.' . format (J_total))
generated_images.append(generated.numpy())
return best_image
  • Теперь мы обучаем нашу модель, используя функцию обучения, которую мы определили выше.

Код:

# Train the model and get best image
final_img = training(content_path, style_path)

Выход:

Время ЦП: пользовательское 2 мкс, системное: 1e + 03 нс, всего: 3 мкс
Время на стене: 6,2 мкс
Итерация: 0
Общий убыток 5.133922e + 11.
Время ЦП: пользовательское 2 мкс, sys: 1e + 03 нс, всего: 3 мкс
Время на стене: 5,72 мкс
Итерация: 1
Общий убыток 3.510511e + 11.
Время ЦП: пользовательское 2 мкс, системное: 1e + 03 нс, всего: 3 мкс
Время на стене: 6,68 мкс
Итерация: 2
Общий убыток 2.069992e + 11.
Время ЦП: пользовательское 3 мкс, sys: 1e + 03 нс, всего: 4 мкс
Время на стене: 6,2 мкс
Итерация: 3
Общий убыток 1.669609e + 11.
Время ЦП: пользовательское 2 мкс, sys: 1e + 03 нс, всего: 3 мкс
Время на стене: 6,44 мкс
Итерация: 4
Общий убыток 1.575840e + 11.
Время ЦП: пользовательское 2 мкс, sys: 1e + 03 нс, всего: 3 мкс
Время на стене: 5,96 мкс
Итерация: 5
Общий убыток 1.200623e + 11.
Время ЦП: пользовательское 2 мкс, системное: 1e + 03 нс, всего: 3 мкс
Время на стене: 5,96 мкс
Итерация: 6
Общий убыток 8.824594e + 10.
Время ЦП: пользовательское 2 мкс, sys: 1e + 03 нс, всего: 3 мкс
Время на стене: 5,72 мкс
Итерация: 7
Общий убыток 7.168546e + 10.
Время ЦП: пользовательское 2 мкс, sys: 1e + 03 нс, всего: 3 мкс
Время на стене: 5,48 мкс
Итерация: 8
Общий убыток 6.207320e + 10.
Время ЦП: пользовательское 3 мкс, sys: 1e + 03 нс, всего: 4 мкс
Время на стене: 8,34 мкс
Итерация: 9
Общий убыток 5,390836e + 10.
Время ЦП: пользовательское 2 мкс, sys: 1e + 03 нс, всего: 3 мкс
Время на стене: 6,2 мкс
Итерация: 10
Общий убыток 4,735992e + 10.
Время ЦП: пользовательское 2 мкс, sys: 1e + 03 нс, всего: 3 мкс
Время на стене: 5,96 мкс
Итерация: 11
Общий убыток 4.301782e + 10.
Время ЦП: пользовательское 2 мкс, системное: 1e + 03 нс, всего: 3 мкс
Время на стене: 6,2 мкс
Итерация: 12
Общий убыток 3.912694e + 10.
Время ЦП: пользовательское 2 мкс, системное: 1e + 03 нс, всего: 3 мкс
Время на стене: 6,68 мкс
Итерация: 13
Общий убыток 3.445185e + 10.
Время ЦП: пользовательский 0 нс, системный: 3 мкс, всего: 3 мкс
Время на стене: 6,2 мкс
Итерация: 14
Общий убыток 2.975165e + 10.
Время ЦП: пользовательское 2 мкс, системное: 0 нс, всего: 2 мкс
Время на стене: 5,96 мкс
Итерация: 15
Общий убыток 2.590984e + 10.
Время ЦП: пользовательское 2 мкс, sys: 1e + 03 нс, всего: 3 мкс
Время на стене: 20 мкс
Итерация: 16
Общий убыток 2.302116e + 10.
Время ЦП: пользовательское 2 мкс, системное: 1e + 03 нс, всего: 3 мкс
Время на стене: 5,72 мкс
Итерация: 17
Общий убыток 2.082643e + 10.
Время ЦП: пользовательское 4 мкс, системное: 1e + 03 нс, всего: 5 мкс
Время на стене: 8,34 мкс
Итерация: 18
Общий убыток 1.906701e + 10.
Время ЦП: пользовательское 2 мкс, sys: 1e + 03 нс, всего: 3 мкс
Время на стене: 5,25 мкс
Итерация: 19
Общий убыток 1.759801e + 10.
Время ЦП: пользовательское 3 мкс, sys: 1e + 03 нс, всего: 4 мкс
Время на стене: 6,2 мкс
Итерация: 20
Общий убыток 1.635128e + 10.
Время ЦП: пользовательское 2 мкс, sys: 1e + 03 нс, всего: 3 мкс
Время на стене: 6,2 мкс
Итерация: 21
Общий убыток 1.525327e + 10.
Время ЦП: пользовательское 3 мкс, sys: 1e + 03 нс, всего: 4 мкс
Время на стене: 5,96 мкс
Итерация: 22
Общий убыток 1.418364e + 10.
Время ЦП: пользовательское 4 мкс, системное: 1 мкс, всего: 5 мкс
Время на стене: 9,06 мкс
Итерация: 23
Общий убыток 1.306596e + 10.
Время ЦП: пользовательское 2 мкс, системное: 1e + 03 нс, всего: 3 мкс
Время на стене: 5,25 мкс
Итерация: 24
Общий убыток 1.196509e + 10.
Время ЦП: пользовательское 2 мкс, системное: 1e + 03 нс, всего: 3 мкс
Время на стене: 5,96 мкс
Итерация: 25
Общий убыток 1.102290e + 10.
Время ЦП: пользовательское 2 мкс, sys: 1e + 03 нс, всего: 3 мкс
Время на стене: 5,96 мкс
Итерация: 26
Общий убыток 1.025539e + 10.
Время ЦП: пользовательское 7 мкс, системное: 3 мкс, всего: 10 мкс
Время на стене: 12,6 мкс
Итерация: 27
Общий убыток 9.570500e + 09.
Время ЦП: пользовательское 2 мкс, системное: 1e + 03 нс, всего: 3 мкс
Время на стене: 5,72 мкс
Итерация: 28
Общий убыток 8.917115e + 09.
Время ЦП: пользовательское 2 мкс, sys: 1e + 03 нс, всего: 3 мкс
Время на стене: 5,96 мкс
Итерация: 29
Общий убыток 8,328761e + 09.
Время ЦП: пользовательское 3 мкс, sys: 1e + 03 нс, всего: 4 мкс
Время на стене: 9,54 мкс
Итерация: 30
Общий убыток 7.840127e + 09.
Время ЦП: пользовательское 2 мкс, системное: 1e + 03 нс, всего: 3 мкс
Время на стене: 6,44 мкс
Итерация: 31
Общий убыток 7,406647e + 09.
Время ЦП: пользовательское 2 мкс, sys: 1e + 03 нс, всего: 3 мкс
Время на стене: 8,34 мкс
Итерация: 32
Общий убыток 6.967848e + 09.
Время ЦП: пользовательское 2 мкс, системное: 1e + 03 нс, всего: 3 мкс
Время на стене: 5,72 мкс
Итерация: 33
Общий убыток 6.531650e + 09.
Время ЦП: пользовательское 2 мкс, sys: 1e + 03 нс, всего: 3 мкс
Время на стене: 5,72 мкс
Итерация: 34
Общий убыток 6.136975e + 09.
Время ЦП: пользовательское 2 мкс, системное: 1 мкс, всего: 3 мкс
Время на стене: 5,96 мкс
Итерация: 35
Общий убыток 5.788804e + 09.
Время ЦП: пользовательское 2 мкс, системное: 1e + 03 нс, всего: 3 мкс
Время на стене: 5,72 мкс
Итерация: 36
Общий убыток 5.476942e + 09.
Время ЦП: пользовательское 2 мкс, системное: 1e + 03 нс, всего: 3 мкс
Время на стене: 6,2 мкс
Итерация: 37
Общий убыток 5.204070e + 09.
Время ЦП: пользовательское 3 мкс, системное: 1 мкс, всего: 4 мкс
Время на стене: 6,2 мкс
Итерация: 38
Общий убыток 4.954049e + 09.
Время ЦП: пользовательское 2 мкс, системное: 1e + 03 нс, всего: 3 мкс
Время на стене: 5,96 мкс
Итерация: 39
Общий убыток 4.708641e + 09.
Время ЦП: пользовательское 3 мкс, системное: 2 мкс, всего: 5 мкс
Время на стене: 6,2 мкс
Итерация: 40
Общий убыток 4.487677e + 09.
Время ЦП: пользовательское 2 мкс, системное: 1e + 03 нс, всего: 3 мкс
Время на стене: 5,96 мкс
Итерация: 41
Общий убыток 4.296946e + 09.
Время ЦП: пользовательское 2 мкс, системное: 1e + 03 нс, всего: 3 мкс
Время на стене: 5,96 мкс
Итерация: 42
Общий убыток 4.107909e + 09.
Время ЦП: пользовательское 3 мкс, sys: 1e + 03 нс, всего: 4 мкс
Время на стене: 6,44 мкс
Итерация: 43
Общий убыток 3.918156e + 09.
Время ЦП: пользовательское 3 мкс, sys: 1e + 03 нс, всего: 4 мкс
Время на стене: 6,2 мкс
Итерация: 44
Общий убыток 3.747263e + 09.
Время ЦП: пользовательское 3 мкс, sys: 1e + 03 нс, всего: 4 мкс
Время на стене: 8,34 мкс
Итерация: 45
Общий убыток 3.595638e + 09.
Время ЦП: пользовательское 2 мкс, sys: 1e + 03 нс, всего: 3 мкс
Время на стене: 5,72 мкс
Итерация: 46
Общий убыток 3.458928e + 09.
Время ЦП: пользовательское 2 мкс, системное: 1e + 03 нс, всего: 3 мкс
Время на стене: 6,2 мкс
Итерация: 47
Общий убыток 3.331772e + 09.
Время ЦП: пользовательское 4 мкс, системное: 1e + 03 нс, всего: 5 мкс
Время на стене: 9,3 мкс
Итерация: 48
Общий убыток 3.205911e + 09.
Время ЦП: пользовательское 3 мкс, sys: 1e + 03 нс, всего: 4 мкс
Время на стене: 5,96 мкс
Итерация: 49
Общий убыток 3.089630e + 09.

  • На заключительном этапе мы наносим окончательный и промежуточный результаты.

Код:

# code to display best generted image and last 10 intermediate results
plt.figure(figsize = ( 12 , 12 ))
for i in range ( 10 ):
plt.subplot( 4 , 3 , i + 1 )
display_image(generated_images[i + 39 ])
plt.show()
# plot best result
display_image(final_img)

Выход:

Последние 10 созданных изображений

Лучшее созданное изображение

Использованная литература:

  • Учебник Tensorflow по передаче нейронного стиля
  • Бумага для переноса стилей

Внимание компьютерщик! Укрепите свои основы с помощью базового курса программирования Python и изучите основы.

Для начала подготовьтесь к собеседованию. Расширьте свои концепции структур данных с помощью курса Python DS. А чтобы начать свое путешествие по машинному обучению, присоединяйтесь к курсу Машинное обучение - базовый уровень.