Вариационные автоэнкодеры
Вариационный автокодер был предложен в 2013 году Книгмой и Веллингом в Google и Qualcomm. Вариационный автокодер (VAE) обеспечивает вероятностный способ описания наблюдения в скрытом пространстве. Таким образом, вместо того, чтобы создавать кодировщик, который выводит одно значение для описания каждого атрибута скрытого состояния, мы сформулируем наш кодировщик для описания распределения вероятностей для каждого скрытого атрибута.
Он имеет множество приложений, таких как сжатие данных, создание синтетических данных и т. Д.
Архитектура:
Автоэнкодеры - это тип нейронной сети, которая изучает кодировки данных из набора данных неконтролируемым образом. Он в основном состоит из двух частей: первая - это кодировщик, который похож на сверточную нейронную сеть, за исключением последнего слоя. Цель кодировщика - изучить эффективное кодирование данных из набора данных и передать его в архитектуру узкого места. Другая часть автоэнкодера - это декодер, который использует скрытое пространство в слое узких мест для регенерации изображений, аналогичных набору данных. Эти результаты передаются от нейронной сети в виде функции потерь.
Вариационный автоэнкодер отличается от автоэнкодера тем, что он обеспечивает статистический способ описания выборок набора данных в скрытом пространстве. Следовательно, в вариационном автоэнкодере кодер выводит распределение вероятностей в слое узких мест вместо единственного выходного значения.
Математика вариационного автоэнкодера:
Вариационный автоэнкодер использует KL-дивергенцию в качестве функции потерь, цель этого - минимизировать разницу между предполагаемым распределением и исходным распределением набора данных.
Предположим, у нас есть распределение z, и мы хотим сгенерировать из него наблюдение x. Другими словами, мы хотим вычислить
Сделать это можно следующим образом:
Но вычисление p (x) может быть довольно трудным.
Обычно это делает его трудноразрешимым распределением. Следовательно, нам нужно приблизить p (z | x) к q (z | x), чтобы сделать это распределение управляемым. Чтобы лучше приблизить p (z | x) к q (z | x), мы минимизируем потерю KL-дивергенции, которая вычисляет, насколько похожи два распределения:
При упрощении указанная выше задача минимизации эквивалентна следующей задаче максимизации:
Первый член представляет собой вероятность реконструкции, а второй член гарантирует, что наше изученное распределение q похоже на истинное априорное распределение p.
Таким образом, наши общие потери состоят из двух членов, одно из которых представляет собой ошибку реконструкции, а другое - потерю KL-дивергенции:
Реализация:
В этой реализации мы будем использовать набор данных Fashion-MNIST, этот набор данных уже доступен в API keras.datasets, поэтому нам не нужно добавлять или загружать вручную.
- Во-первых, нам нужно импортировать необходимые пакеты в нашу среду Python. мы будем использовать пакет Keras с тензорным потоком в качестве бэкэнда.
Код:
# code import numpy as np import tensorflow as tf from tensorflow import keras from tensorflow.keras import Input , Model from tensorflow.keras.layers import Layer, Conv2D, Flatten, Dense, Reshape, Conv2DTranspose import matplotlib.pyplot as plt |
- Для вариационных автоэнкодеров нам нужно определить архитектуру двух частей: кодера и декодера, но сначала мы определим уровень узких мест в архитектуре, уровень выборки.
Код:
# this sampling layer is the bottleneck layer of variational autoencoder, # it uses the output from two dense layers z_mean and z_log_var as input, # convert them into normal distribution and pass them to the decoder layer class Sampling(Layer): call( def self , inputs): z_mean, z_log_var = inputs batch = tf.shape(z_mean)[ 0 ] dim = tf.shape(z_mean)[ 1 ] epsilon = tf.keras.backend.random_normal(shape = (batch, dim)) return z_mean + tf.exp( 0.5 * z_log_var) * epsilon |
- Теперь мы определяем архитектуру кодировочной части нашего автокодировщика, эта часть принимает изображения в качестве входных данных и кодирует их представление в слое выборки.
Код:
# Define Encoder Model latent_dim = 2 encoder_inputs = Input (shape = ( 28 , 28 , 1 )) x = Conv2D( 32 , 3 , activation = "relu" , strides = 2 , padding = "same" )(encoder_inputs) x = Conv2D( 64 , 3 , activation = "relu" , strides = 2 , padding = "same" )(x) x = Flatten()(x) x = Dense( 16 , activation = "relu" )(x) z_mean = Dense(latent_dim, name = "z_mean" )(x) z_log_var = Dense(latent_dim, name = "z_log_var" )(x) z = Sampling()([z_mean, z_log_var]) encoder = Model(encoder_inputs, [z_mean, z_log_var, z], name = "encoder" ) encoder.summary() |
Модель: «энкодер» __________________________________________________________________________________________________ Слой (тип) Параметр формы вывода # Подключен к ================================================== ================================================ input_3 (InputLayer) [(Нет, 28, 28, 1)] 0 __________________________________________________________________________________________________ conv2d_2 (Conv2D) (Нет, 14, 14, 32) 320 input_3 [0] [0] __________________________________________________________________________________________________ conv2d_3 (Conv2D) (Нет, 7, 7, 64) 18496 conv2d_2 [0] [0] __________________________________________________________________________________________________ flatten_1 (Flatten) (Нет, 3136) 0 conv2d_3 [0] [0] __________________________________________________________________________________________________ плотный_2 (плотный) (нет, 16) 50192 flatten_1 [0] [0] __________________________________________________________________________________________________ z_mean (Плотный) (Нет, 2) 34 Плотный_2 [0] [0] __________________________________________________________________________________________________ z_log_var (Плотный) (Нет, 2) 34 Плотный_2 [0] [0] __________________________________________________________________________________________________ sampling_1 (Выборка) (Нет, 2) 0 z_mean [0] [0] z_log_var [0] [0] ================================================== ================================================ Всего параметров: 69, 076 Обучаемые параметры: 69, 076 Необучаемые параметры: 0 __________________________________________________________________________________________________
- Теперь мы определяем архитектуру декодирующей части нашего автоэнкодера, эта часть принимает выходные данные слоя выборки в качестве входных и выводит изображение размера (28, 28, 1).
Код:
# Define Decoder Architecture latent_inputs = keras. Input (shape = (latent_dim, )) x = Dense( 7 * 7 * 64 , activation = "relu" )(latent_inputs) x = Reshape(( 7 , 7 , 64 ))(x) x = Conv2DTranspose( 64 , 3 , activation = "relu" , strides = 2 , padding = "same" )(x) x = Conv2DTranspose( 32 , 3 , activation = "relu" , strides = 2 , padding = "same" )(x) decoder_outputs = Conv2DTranspose( 1 , 3 , activation = "sigmoid" , padding = "same" )(x) decoder = Model(latent_inputs, decoder_outputs, name = "decoder" ) decoder.summary() |
Модель: «декодер» _________________________________________________________________ Слой (тип) Параметр формы вывода # ================================================== =============== input_4 (InputLayer) [(Нет, 2)] 0 _________________________________________________________________ плотный_3 (Плотный) (Нет, 3136) 9408 _________________________________________________________________ reshape_1 (Изменить форму) (Нет, 7, 7, 64) 0 _________________________________________________________________ conv2d_transpose_3 (Conv2DTr (Нет, 14, 14, 64) 36928 _________________________________________________________________ conv2d_transpose_4 (Conv2DTr (Нет, 28, 28, 32) 18464 _________________________________________________________________ conv2d_transpose_5 (Conv2DTr (Нет, 28, 28, 1) 289 ================================================== =============== Всего параметров: 65, 089 Обучаемые параметры: 65, 089 Необучаемые параметры: 0 _________________________________________________________________
- На этом этапе мы объединяем модель и определяем процедуру обучения с функциями потерь.
Код:
# this class takes encoder and decoder models and # define the complete variational autoencoder architecture class VAE(keras.Model): def __init__( self , encoder, decoder, * * kwargs): super (VAE, self ).__init__( * * kwargs) self .encoder = encoder self .decoder = decoder def train_step( self , data): if isinstance (data, tuple ): data = data[ 0 ] with tf.GradientTape() as tape: z_mean, z_log_var, z = encoder(data) reconstruction = decoder(z) reconstruction_loss = tf.reduce_mean( keras.losses.binary_crossentropy(data, reconstruction) ) reconstruction_loss * = 28 * 28 kl_loss = 1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var) kl_loss = tf.reduce_mean(kl_loss) kl_loss * = - 0.5 total_loss = reconstruction_loss + kl_loss grads = tape.gradient(total_loss, self .trainable_weights) self .optimizer.apply_gradients( zip (grads, self .trainable_weights)) return { "loss" : total_loss, "reconstruction_loss" : reconstruction_loss, "kl_loss" : kl_loss, } |
- Теперь самое время обучить нашу вариационную модель автоэнкодера, мы будем обучать ее 100 эпох. Но сначала нам нужно импортировать набор данных Fashion MNIST.
Код:
# load fashion mnist dataset from keras.dataset API (x_train, _), (x_test, _) = keras.datasets.fashion_mnist.load_data() fmnist_images = np.concatenate([x_train, x_test], axis = 0 ) # expand dimension to add a color map dimension fmnist_images = np.expand_dims(fmnist_images, - 1 ).astype( "float32" ) / 255 # compile and train the model vae = VAE(encoder, decoder) vae. compile (optimizer = 'rmsprop' ) vae.fit(fmnist_images, epochs = 100 , batch_size = 64 ) |
Эпоха 1/100 1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 301.9441 - реконструкция_потеря: 298.3138 - kl_loss : 3.6303 Эпоха 2/100 1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 273,5940 - потеря восстановления: 270,0484 - kl_loss : 3.5456 Эпоха 3/100 1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 269.3337 - восстановление_потер: 265.9077 - kl_loss : 3.4260 Эпоха 4/100 1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 266,8168 - восстановление_потер: 263,4100 - kl_loss : 3,4068 Эпоха 5/100 1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 264,9917 - потеря_конструкции: 261,5603 - kl_loss : 3.4314 Эпоха 6/100 1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 263,5237 - восстановление_потер: 260,0712 - kl_loss : 3,4525 Эпоха 7/100 1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 262,3414 - восстановление_потер: 258,8548 - kl_loss : 3.4865 Эпоха 8/100 1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 261.4241 - реконструкция_потеря: 257.9104 - kl_loss : 3.5137 Эпоха 9/100 1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 260.6090 - реконструкция_потеря: 257.0662 - kl_loss : 3.5428 Эпоха 10/100 1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 259.9735 - Восстановление_потеря: 256.4075 - kl_loss : 3.5660 Эпоха 11/100 1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 259.4184 - реконструкция_потеря: 255.8348 - kl_loss : 3.5836 Эпоха 12/100 1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 258.9688 - реконструкция_потеря: 255.3724 - kl_loss : 3.5964 Эпоха 13/100 1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 258,5413 - реконструкция_потеря: 254.9356 - kl_loss : 3.6057 Эпоха 14/100 1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 258.2400 - реконструкция_потеря: 254.6236 - kl_loss : 3.6163 Эпоха 15/100 1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 257.9335 - реконструкция_потеря: 254.3038 - kl_loss : 3.6298 Эпоха 16/100 1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 257.6331 - восстановление_потер: 253.9993 - kl_loss : 3.6339 Эпоха 17/100 1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 257.4199 - реконструкция_потеря: 253.7707 - kl_loss : 3.6492 Эпоха 18/100 1094/1094 [==============================] - 6 с 6 мс / шаг - потеря: 257.1951 - реконструкция_потеря: 253.5309 - kl_loss : 3.6643 Эпоха 19/100 1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 256.9326 - Восстановление_потеря: 253.2723 - kl_loss : 3.6604 Эпоха 20/100 1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 256.7551 - реконструкция_потеря: 253.0836 - kl_loss : 3.6715 Эпоха 21/100 1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 256,5663 - реконструкция_потеря: 252,8877 - kl_loss : 3.6786 Эпоха 22/100 1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 256.4068 - реконструкция_потеря: 252.7112 - kl_loss : 3.6956 Эпоха 23/100 1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 256,2588 - потеря_конструкции: 252,5588 - kl_loss : 3,7000 Эпоха 24/100 1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 256.0853 - реконструкция_потеря: 252.3794 - kl_loss : 3.7059 Эпоха 25/100 1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 255.9321 - Recovery_loss: 252.2201 - kl_loss : 3.7120 Эпоха 26/100 1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 255.7962 - реконструкция_потеря: 252.0814 - kl_loss : 3.7148 Эпоха 27/100 1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 255.6953 - реконструкция_потеря: 251.9673 - kl_loss : 3.7280 Эпоха 28/100 1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 255.5534 - реконструкция_потеря: 251.8248 - kl_loss : 3,7287 Эпоха 29/100 1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 255,4437 - восстановление_потер: 251,7134 - kl_loss : 3.7303 Эпоха 30/100 1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 255.3439 - реконструкция_потеря: 251.6064 - kl_loss : 3.7375 Эпоха 31/100 1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 255.2326 - реконструкция_потеря: 251.5018 - kl_loss : 3.7308 Эпоха 32/100 1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 255.1356 - реконструкция_потеря: 251.3933 - kl_loss : 3.7423 Эпоха 33/100 1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 255.0660 - реконструкция_потеря: 251.3224 - kl_loss : 3.7436 Эпоха 34/100 1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 254,9977 - Reconstruction_loss: 251.2449 - kl_loss : 3.7528 Эпоха 35/100 1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 254,8857 - восстановление_потер: 251,1363 - kl_loss : 3.7494 Эпоха 36/100 1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 254.7980 - реконструкция_потеря: 251.0481 - kl_loss : 3.7499 Эпоха 37/100 1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 254,7485 - восстановление_потер: 250,9851 - kl_loss : 3.7634 Эпоха 38/100 1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 254,6701 - потеря_ реконструкции: 250,9049 - kl_loss : 3.7652 Эпоха 39/100 1094/1094 [==============================] - 6 с 6 мс / шаг - потеря: 254,6105 - потеря_ реконструкции: 250,8389 - kl_loss : 3.7716 Эпоха 40/100 1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 254.4979 - реконструкция_потеря: 250.7333 - kl_loss : 3,7646 Эпоха 41/100 1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 254,4734 - потеря_конструкции: 250,7037 - kl_loss : 3,7697 Эпоха 42/100 1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 254.4408 - реконструкция_потеря: 250.6576 - kl_loss : 3.7831 Эпоха 43/100 1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 254,3272 - восстановление_потер: 250,5562 - kl_loss : 3.7711 Эпоха 44/100 1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 254,3110 - потеря_конструкции: 250,5354 - kl_loss : 3.7755 Эпоха 45/100 1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 254.1982 - реконструкция_потеря: 250.4256 - kl_loss : 3.7726 Эпоха 46/100 1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 254,1655 - восстановление_потер: 250,3795 - kl_loss : 3.7860 Эпоха 47/100 1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 254,0979 - восстановление_потер: 250,3105 - kl_loss : 3.7875 Эпоха 48/100 1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 254.0801 - реконструкция_потеря: 250.2973 - kl_loss : 3.7828 Эпоха 49/100 1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 254.0101 - реконструкция_потеря: 250.2270 - kl_loss : 3.7831 Эпоха 50/100 1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 253.9512 - реконструкция_потеря: 250.1681 - kl_loss : 3.7831 Эпоха 51/100 1094/1094 [==============================] - 7 с 7 мс / шаг - потеря: 253.9307 - реконструкция_потеря: 250.1408 - kl_loss : 3.7899 Эпоха 52/100 1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 253,8858 - восстановление_потер: 250,1059 - kl_loss : 3.7800 Эпоха 53/100 1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 253,8118 - потеря_ реконструкции: 250,0236 - kl_loss : 3,7882 Эпоха 54/100 1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 253,8171 - потеря_ реконструкции: 250,0325 - kl_loss : 3.7845 Эпоха 55/100 1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 253,7622 - реконструкция_потеря: 249.9735 - kl_loss : 3.7887 Эпоха 56/100 1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 253,7338 - Reconstruction_loss: 249.9380 - kl_loss : 3,7959 Эпоха 57/100 1094/1094 [==============================] - 6 с 6 мс / шаг - потеря: 253,6761 - реконструкция_потеря: 249,8792 - kl_loss : 3.7969 Эпоха 58/100 1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 253,6236 - восстановление_потер: 249,8283 - kl_loss : 3,7954 Эпоха 59/100 1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 253,6181 - восстановление_потер: 249,8236 - kl_loss : 3.7945 Эпоха 60/100 1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 253,5509 - Reconstruction_loss: 249.7587 - kl_loss : 3.7921 Эпоха 61/100 1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 253.5124 - реконструкция_потеря: 249.7126 - kl_loss : 3,7998 Эпоха 62/100 1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 253,4739 - Reconstruction_loss: 249.6683 - kl_loss : 3.8056 Эпоха 63/100 1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 253.4609 - реконструкция_потеря: 249.6567 - kl_loss : 3.8042 Эпоха 64/100 1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 253,4066 - реконструкция_потеря: 249.6020 - kl_loss : 3.8045 Эпоха 65/100 1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 253,3578 - восстановление_потер: 249,5580 - kl_loss : 3,7998 Эпоха 66/100 1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 253,3728 - реконструкция_потеря: 249,5609 - kl_loss : 3.8118 Эпоха 67/100 1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 253,3523 - реконструкция_потеря: 249,5351 - kl_loss : 3.8171 Эпоха 68/100 1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 253.2646 - реконструкция_потеря: 249.4452 - kl_loss : 3.8194 Эпоха 69/100 1094/1094 [==============================] - 6 с 6 мс / шаг - потеря: 253.2642 - реконструкция_потеря: 249.4603 - kl_loss : 3.8040 Эпоха 70/100 1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 253.2227 - реконструкция_потеря: 249.4159 - kl_loss : 3.8068 Эпоха 71/100 1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 253,1848 - восстановление_потер: 249,3755 - kl_loss : 3.8094 Эпоха 72/100 1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 253.1812 - реконструкция_потеря: 249.3737 - kl_loss : 3.8074 Эпоха 73/100 1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 253.1803 - реконструкция_потеря: 249.3743 - kl_loss : 3.8059 Эпоха 74/100 1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 253,1295 - восстановление_потер: 249,3114 - kl_loss : 3.8181 Эпоха 75/100 1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 253.0516 - реконструкция_потеря: 249.2391 - kl_loss : 3.8125 Эпоха 76/100 1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 253,0736 - реконструкция_потеря: 249,2582 - kl_loss : 3.8154 Эпоха 77/100 1094/1094 [==============================] - 6 с 6 мс / шаг - потеря: 253.0331 - реконструкция_потеря: 249.2200 - kl_loss : 3.8131 Эпоха 78/100 1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 253.0479 - Recovery_loss: 249.2272 - kl_loss : 3.8207 Эпоха 79/100 1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 252.9317 - реконструкция_потеря: 249.1137 - kl_loss : 3.8179 Эпоха 80/100 1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 252.9578 - Восстановление_потер: 249.1483 - kl_loss : 3.8095 Эпоха 81/100 1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 252.9072 - реконструкция_потеря: 249.0963 - kl_loss : 3.8109 Эпоха 82/100 1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 252,8793 - Reconstruction_loss: 249.0646 - kl_loss : 3.8147 Эпоха 83/100 1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 252,8914 - восстановление_потер: 249,0676 - kl_loss : 3.8238 Эпоха 84/100 1094/1094 [==============================] - 6 с 6 мс / шаг - потеря: 252.8365 - восстановление_потер: 249.0121 - kl_loss : 3.8244 Эпоха 85/100 1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 252.8063 - восстановление_потер: 248.9844 - kl_loss : 3.8218 Эпоха 86/100 1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 252,7960 - реконструкция_потеря: 248.9777 - kl_loss : 3.8183 Эпоха 87/100 1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 252,7733 - Reconstruction_loss: 248.9529 - kl_loss : 3.8204 Эпоха 88/100 1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 252,7303 - восстановление_потер: 248,9055 - kl_loss : 3.8248 Эпоха 89/100 1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 252,7225 - восстановление_потер: 248,8902 - kl_loss : 3.8323 Эпоха 90/100 1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 252,6822 - восстановление_потери: 248,8549 - kl_loss : 3.8273 Эпоха 91/100 1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 252,6540 - потеря_ реконструкции: 248,8314 - kl_loss : 3.8227 Эпоха 92/100 1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 252,6540 - восстановление_потер: 248,8239 - kl_loss : 3.8300 Эпоха 93/100 1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 252,6213 - восстановление_потер: 248,7778 - kl_loss : 3.8435 Эпоха 94/100 1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 252.5990 - восстановление_потер: 248.7594 - kl_loss : 3.8397 Эпоха 95/100 1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 252,5786 - реконструкция_потеря: 248,7413 - kl_loss : 3.8373 Эпоха 96/100 1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 252,5839 - потеря_ реконструкции: 248,7411 - kl_loss : 3.8427 Эпоха 97/100 1094/1094 [==============================] - 7 с 7 мс / шаг - потеря: 252,5364 - потеря восстановления: 248,6960 - kl_loss : 3.8404 Эпоха 98/100 1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 252,5347 - реконструкция_потеря: 248,6915 - kl_loss : 3.8431 Эпоха 99/100 1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 252.4996 - реконструкция_потеря: 248.6569 - kl_loss : 3.8428 Эпоха 100/100 1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 252,4938 - Reconstruction_loss: 248.6405 - kl_loss : 3.8533 <tensorflow.python.keras.callbacks.History в 0x7f5467c56be0>
- На этом этапе мы отображаем результаты обучения, мы будем отображать эти результаты в соответствии с их значениями в векторах скрытого пространства.
Код:
def plot_latent(encoder, decoder): # display an * n 2D manifold of imagess n = 10 img_dim = 28 scale = 2.0 figsize = 15 figure = np.zeros((img_dim * n, img_dim * n)) # linearly spaced coordinates corresponding to the 2D plot # of images classes in the latent space grid_x = np.linspace( - scale, scale, n) grid_y = np.linspace( - scale, scale, n)[:: - 1 ] for i, yi in enumerate (grid_y): for j, xi in enumerate (grid_x): z_sample = np.array([[xi, yi]]) x_decoded = decoder.predict(z_sample) images = x_decoded[ 0 ].reshape(img_dim, img_dim) figure[ i * img_dim : (i + 1 ) * img_dim, j * img_dim : (j + 1 ) * img_dim, ] = images plt.figure(figsize = (figsize, figsize)) start_range = img_dim / / 2 end_range = n * img_dim + start_range + 1 pixel_range = np.arange(start_range, end_range, img_dim) sample_range_x = np. round (grid_x, 1 ) sample_range_y = np. round (grid_y, 1 ) plt.xticks(pixel_range, sample_range_x) plt.yticks(pixel_range, sample_range_y) plt.xlabel( "z[0]" ) plt.ylabel( "z[1]" ) plt.imshow(figure, cmap = "Greys_r" ) plt.show() plot_latent(encoder, decoder) |
- Чтобы получить более четкое представление о наших репрезентативных значениях скрытых векторов, мы будем строить диаграмму разброса обучающих данных на основе их значений соответствующих скрытых размеров, созданных кодировщиком.
Код:
def plot_label_clusters(encoder, decoder, data, test_lab): z_mean, _, _ = encoder.predict(data) plt.figure(figsize = ( 12 , 10 )) sc = plt.scatter(z_mean[:, 0 ], z_mean[:, 1 ], c = test_lab) cbar = plt.colorbar(sc, ticks = range ( 10 )) cbar.ax.set_yticklabels([labels.get(i) for i in range ( 10 )]) plt.xlabel( "z[0]" ) plt.ylabel( "z[1]" ) plt.show() labels = { 0 : "T-shirt / top" , 1 : "Trouser" , 2 : "Pullover" , 3 : "Dress" , 4 : "Coat" , 5 : "Sandal" , 6 : "Shirt" , 7 : "Sneaker" , 8 : "Bag" , 9 : "Ankle boot" } (x_train, y_train), _ = keras.datasets.fashion_mnist.load_data() x_train = np.expand_dims(x_train, - 1 ).astype( "float32" ) / 255 plot_label_clusters(encoder, decoder, x_train, y_train) |
Использованная литература:
- Бумага для вариационного автоэнкодера
- Вариационный автоэнкодер Keras