Вариационные автоэнкодеры

Опубликовано: 4 Января, 2022

Вариационный автокодер был предложен в 2013 году Книгмой и Веллингом в Google и Qualcomm. Вариационный автокодер (VAE) обеспечивает вероятностный способ описания наблюдения в скрытом пространстве. Таким образом, вместо того, чтобы создавать кодировщик, который выводит одно значение для описания каждого атрибута скрытого состояния, мы сформулируем наш кодировщик для описания распределения вероятностей для каждого скрытого атрибута.

Он имеет множество приложений, таких как сжатие данных, создание синтетических данных и т. Д.

Архитектура:

Автоэнкодеры - это тип нейронной сети, которая изучает кодировки данных из набора данных неконтролируемым образом. Он в основном состоит из двух частей: первая - это кодировщик, который похож на сверточную нейронную сеть, за исключением последнего слоя. Цель кодировщика - изучить эффективное кодирование данных из набора данных и передать его в архитектуру узкого места. Другая часть автоэнкодера - это декодер, который использует скрытое пространство в слое узких мест для регенерации изображений, аналогичных набору данных. Эти результаты передаются от нейронной сети в виде функции потерь.

Вариационный автоэнкодер отличается от автоэнкодера тем, что он обеспечивает статистический способ описания выборок набора данных в скрытом пространстве. Следовательно, в вариационном автоэнкодере кодер выводит распределение вероятностей в слое узких мест вместо единственного выходного значения.

Математика вариационного автоэнкодера:

Вариационный автоэнкодер использует KL-дивергенцию в качестве функции потерь, цель этого - минимизировать разницу между предполагаемым распределением и исходным распределением набора данных.

Предположим, у нас есть распределение z, и мы хотим сгенерировать из него наблюдение x. Другими словами, мы хотим вычислить

Сделать это можно следующим образом:

Но вычисление p (x) может быть довольно трудным.

Обычно это делает его трудноразрешимым распределением. Следовательно, нам нужно приблизить p (z | x) к q (z | x), чтобы сделать это распределение управляемым. Чтобы лучше приблизить p (z | x) к q (z | x), мы минимизируем потерю KL-дивергенции, которая вычисляет, насколько похожи два распределения:

При упрощении указанная выше задача минимизации эквивалентна следующей задаче максимизации:

Первый член представляет собой вероятность реконструкции, а второй член гарантирует, что наше изученное распределение q похоже на истинное априорное распределение p.

Таким образом, наши общие потери состоят из двух членов, одно из которых представляет собой ошибку реконструкции, а другое - потерю KL-дивергенции:

Реализация:

В этой реализации мы будем использовать набор данных Fashion-MNIST, этот набор данных уже доступен в API keras.datasets, поэтому нам не нужно добавлять или загружать вручную.

  • Во-первых, нам нужно импортировать необходимые пакеты в нашу среду Python. мы будем использовать пакет Keras с тензорным потоком в качестве бэкэнда.

Код:

# code
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import Input , Model
from tensorflow.keras.layers import Layer, Conv2D, Flatten, Dense, Reshape, Conv2DTranspose
import matplotlib.pyplot as plt
  • Для вариационных автоэнкодеров нам нужно определить архитектуру двух частей: кодера и декодера, но сначала мы определим уровень узких мест в архитектуре, уровень выборки.

Код:

# this sampling layer is the bottleneck layer of variational autoencoder,
# it uses the output from two dense layers z_mean and z_log_var as input,
# convert them into normal distribution and pass them to the decoder layer
class Sampling(Layer):
call( def self , inputs):
z_mean, z_log_var = inputs
batch = tf.shape(z_mean)[ 0 ]
dim = tf.shape(z_mean)[ 1 ]
epsilon = tf.keras.backend.random_normal(shape = (batch, dim))
return z_mean + tf.exp( 0.5 * z_log_var) * epsilon
  • Теперь мы определяем архитектуру кодировочной части нашего автокодировщика, эта часть принимает изображения в качестве входных данных и кодирует их представление в слое выборки.

Код:

# Define Encoder Model
latent_dim = 2
encoder_inputs = Input (shape = ( 28 , 28 , 1 ))
x = Conv2D( 32 , 3 , activation = "relu" , strides = 2 , padding = "same" )(encoder_inputs)
x = Conv2D( 64 , 3 , activation = "relu" , strides = 2 , padding = "same" )(x)
x = Flatten()(x)
x = Dense( 16 , activation = "relu" )(x)
z_mean = Dense(latent_dim, name = "z_mean" )(x)
z_log_var = Dense(latent_dim, name = "z_log_var" )(x)
z = Sampling()([z_mean, z_log_var])
encoder = Model(encoder_inputs, [z_mean, z_log_var, z], name = "encoder" )
encoder.summary()
 Модель: «энкодер»
__________________________________________________________________________________________________
Слой (тип) Параметр формы вывода # Подключен к                     
================================================== ================================================
input_3 (InputLayer) [(Нет, 28, 28, 1)] 0                                            
__________________________________________________________________________________________________
conv2d_2 (Conv2D) (Нет, 14, 14, 32) 320 input_3 [0] [0]                    
__________________________________________________________________________________________________
conv2d_3 (Conv2D) (Нет, 7, 7, 64) 18496 conv2d_2 [0] [0]                   
__________________________________________________________________________________________________
flatten_1 (Flatten) (Нет, 3136) 0 conv2d_3 [0] [0]                   
__________________________________________________________________________________________________
плотный_2 (плотный) (нет, 16) 50192 flatten_1 [0] [0]                  
__________________________________________________________________________________________________
z_mean (Плотный) (Нет, 2) 34 Плотный_2 [0] [0]                    
__________________________________________________________________________________________________
z_log_var (Плотный) (Нет, 2) 34 Плотный_2 [0] [0]                    
__________________________________________________________________________________________________
sampling_1 (Выборка) (Нет, 2) 0 z_mean [0] [0]                     
                                                                 z_log_var [0] [0]                  
================================================== ================================================
Всего параметров: 69, 076
Обучаемые параметры: 69, 076
Необучаемые параметры: 0
__________________________________________________________________________________________________


  • Теперь мы определяем архитектуру декодирующей части нашего автоэнкодера, эта часть принимает выходные данные слоя выборки в качестве входных и выводит изображение размера (28, 28, 1).

Код:

# Define Decoder Architecture
latent_inputs = keras. Input (shape = (latent_dim, ))
x = Dense( 7 * 7 * 64 , activation = "relu" )(latent_inputs)
x = Reshape(( 7 , 7 , 64 ))(x)
x = Conv2DTranspose( 64 , 3 , activation = "relu" , strides = 2 , padding = "same" )(x)
x = Conv2DTranspose( 32 , 3 , activation = "relu" , strides = 2 , padding = "same" )(x)
decoder_outputs = Conv2DTranspose( 1 , 3 , activation = "sigmoid" , padding = "same" )(x)
decoder = Model(latent_inputs, decoder_outputs, name = "decoder" )
decoder.summary()
 Модель: «декодер»
_________________________________________________________________
Слой (тип) Параметр формы вывода #   
================================================== ===============
input_4 (InputLayer) [(Нет, 2)] 0         
_________________________________________________________________
плотный_3 (Плотный) (Нет, 3136) 9408      
_________________________________________________________________
reshape_1 (Изменить форму) (Нет, 7, 7, 64) 0         
_________________________________________________________________
conv2d_transpose_3 (Conv2DTr (Нет, 14, 14, 64) 36928     
_________________________________________________________________
conv2d_transpose_4 (Conv2DTr (Нет, 28, 28, 32) 18464     
_________________________________________________________________
conv2d_transpose_5 (Conv2DTr (Нет, 28, 28, 1) 289       
================================================== ===============
Всего параметров: 65, 089
Обучаемые параметры: 65, 089
Необучаемые параметры: 0
_________________________________________________________________


  • На этом этапе мы объединяем модель и определяем процедуру обучения с функциями потерь.

Код:

# this class takes encoder and decoder models and
# define the complete variational autoencoder architecture
class VAE(keras.Model):
def __init__( self , encoder, decoder, * * kwargs):
super (VAE, self ).__init__( * * kwargs)
self .encoder = encoder
self .decoder = decoder
def train_step( self , data):
if isinstance (data, tuple ):
data = data[ 0 ]
with tf.GradientTape() as tape:
z_mean, z_log_var, z = encoder(data)
reconstruction = decoder(z)
reconstruction_loss = tf.reduce_mean(
keras.losses.binary_crossentropy(data, reconstruction)
)
reconstruction_loss * = 28 * 28
kl_loss = 1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var)
kl_loss = tf.reduce_mean(kl_loss)
kl_loss * = - 0.5
total_loss = reconstruction_loss + kl_loss
grads = tape.gradient(total_loss, self .trainable_weights)
self .optimizer.apply_gradients( zip (grads, self .trainable_weights))
return {
"loss" : total_loss,
"reconstruction_loss" : reconstruction_loss,
"kl_loss" : kl_loss,
}
  • Теперь самое время обучить нашу вариационную модель автоэнкодера, мы будем обучать ее 100 эпох. Но сначала нам нужно импортировать набор данных Fashion MNIST.

Код:

# load fashion mnist dataset from keras.dataset API
(x_train, _), (x_test, _) = keras.datasets.fashion_mnist.load_data()
fmnist_images = np.concatenate([x_train, x_test], axis = 0 )
# expand dimension to add a color map dimension
fmnist_images = np.expand_dims(fmnist_images, - 1 ).astype( "float32" ) / 255
# compile and train the model
vae = VAE(encoder, decoder)
vae. compile (optimizer = 'rmsprop' )
vae.fit(fmnist_images, epochs = 100 , batch_size = 64 )
 Эпоха 1/100
1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 301.9441 - реконструкция_потеря: 298.3138 - kl_loss : 3.6303
Эпоха 2/100
1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 273,5940 - потеря восстановления: 270,0484 - kl_loss : 3.5456
Эпоха 3/100
1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 269.3337 - восстановление_потер: 265.9077 - kl_loss : 3.4260
Эпоха 4/100
1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 266,8168 - восстановление_потер: 263,4100 - kl_loss : 3,4068
Эпоха 5/100
1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 264,9917 - потеря_конструкции: 261,5603 - kl_loss : 3.4314
Эпоха 6/100
1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 263,5237 - восстановление_потер: 260,0712 - kl_loss : 3,4525
Эпоха 7/100
1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 262,3414 - восстановление_потер: 258,8548 - kl_loss : 3.4865
Эпоха 8/100
1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 261.4241 - реконструкция_потеря: 257.9104 - kl_loss : 3.5137
Эпоха 9/100
1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 260.6090 - реконструкция_потеря: 257.0662 - kl_loss : 3.5428
Эпоха 10/100
1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 259.9735 - Восстановление_потеря: 256.4075 - kl_loss : 3.5660
Эпоха 11/100
1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 259.4184 - реконструкция_потеря: 255.8348 - kl_loss : 3.5836
Эпоха 12/100
1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 258.9688 - реконструкция_потеря: 255.3724 - kl_loss : 3.5964
Эпоха 13/100
1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 258,5413 - реконструкция_потеря: 254.9356 - kl_loss : 3.6057
Эпоха 14/100
1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 258.2400 - реконструкция_потеря: 254.6236 - kl_loss : 3.6163
Эпоха 15/100
1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 257.9335 - реконструкция_потеря: 254.3038 - kl_loss : 3.6298
Эпоха 16/100
1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 257.6331 - восстановление_потер: 253.9993 - kl_loss : 3.6339
Эпоха 17/100
1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 257.4199 - реконструкция_потеря: 253.7707 - kl_loss : 3.6492
Эпоха 18/100
1094/1094 [==============================] - 6 с 6 мс / шаг - потеря: 257.1951 - реконструкция_потеря: 253.5309 - kl_loss : 3.6643
Эпоха 19/100
1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 256.9326 - Восстановление_потеря: 253.2723 - kl_loss : 3.6604
Эпоха 20/100
1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 256.7551 - реконструкция_потеря: 253.0836 - kl_loss : 3.6715
Эпоха 21/100
1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 256,5663 - реконструкция_потеря: 252,8877 - kl_loss : 3.6786
Эпоха 22/100
1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 256.4068 - реконструкция_потеря: 252.7112 - kl_loss : 3.6956
Эпоха 23/100
1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 256,2588 - потеря_конструкции: 252,5588 - kl_loss : 3,7000
Эпоха 24/100
1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 256.0853 - реконструкция_потеря: 252.3794 - kl_loss : 3.7059
Эпоха 25/100
1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 255.9321 - Recovery_loss: 252.2201 - kl_loss : 3.7120
Эпоха 26/100
1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 255.7962 - реконструкция_потеря: 252.0814 - kl_loss : 3.7148
Эпоха 27/100
1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 255.6953 - реконструкция_потеря: 251.9673 - kl_loss : 3.7280
Эпоха 28/100
1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 255.5534 - реконструкция_потеря: 251.8248 - kl_loss : 3,7287
Эпоха 29/100
1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 255,4437 - восстановление_потер: 251,7134 - kl_loss : 3.7303
Эпоха 30/100
1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 255.3439 - реконструкция_потеря: 251.6064 - kl_loss : 3.7375
Эпоха 31/100
1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 255.2326 - реконструкция_потеря: 251.5018 - kl_loss : 3.7308
Эпоха 32/100
1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 255.1356 - реконструкция_потеря: 251.3933 - kl_loss : 3.7423
Эпоха 33/100
1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 255.0660 - реконструкция_потеря: 251.3224 - kl_loss : 3.7436
Эпоха 34/100
1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 254,9977 - Reconstruction_loss: 251.2449 - kl_loss : 3.7528
Эпоха 35/100
1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 254,8857 - восстановление_потер: 251,1363 - kl_loss : 3.7494
Эпоха 36/100
1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 254.7980 - реконструкция_потеря: 251.0481 - kl_loss : 3.7499
Эпоха 37/100
1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 254,7485 - восстановление_потер: 250,9851 - kl_loss : 3.7634
Эпоха 38/100
1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 254,6701 - потеря_ реконструкции: 250,9049 - kl_loss : 3.7652
Эпоха 39/100
1094/1094 [==============================] - 6 с 6 мс / шаг - потеря: 254,6105 - потеря_ реконструкции: 250,8389 - kl_loss : 3.7716
Эпоха 40/100
1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 254.4979 - реконструкция_потеря: 250.7333 - kl_loss : 3,7646
Эпоха 41/100
1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 254,4734 - потеря_конструкции: 250,7037 - kl_loss : 3,7697
Эпоха 42/100
1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 254.4408 - реконструкция_потеря: 250.6576 - kl_loss : 3.7831
Эпоха 43/100
1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 254,3272 - восстановление_потер: 250,5562 - kl_loss : 3.7711
Эпоха 44/100
1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 254,3110 - потеря_конструкции: 250,5354 - kl_loss : 3.7755
Эпоха 45/100
1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 254.1982 - реконструкция_потеря: 250.4256 - kl_loss : 3.7726
Эпоха 46/100
1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 254,1655 - восстановление_потер: 250,3795 - kl_loss : 3.7860
Эпоха 47/100
1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 254,0979 - восстановление_потер: 250,3105 - kl_loss : 3.7875
Эпоха 48/100
1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 254.0801 - реконструкция_потеря: 250.2973 - kl_loss : 3.7828
Эпоха 49/100
1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 254.0101 - реконструкция_потеря: 250.2270 - kl_loss : 3.7831
Эпоха 50/100
1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 253.9512 - реконструкция_потеря: 250.1681 - kl_loss : 3.7831
Эпоха 51/100
1094/1094 [==============================] - 7 с 7 мс / шаг - потеря: 253.9307 - реконструкция_потеря: 250.1408 - kl_loss : 3.7899
Эпоха 52/100
1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 253,8858 - восстановление_потер: 250,1059 - kl_loss : 3.7800
Эпоха 53/100
1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 253,8118 - потеря_ реконструкции: 250,0236 - kl_loss : 3,7882
Эпоха 54/100
1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 253,8171 - потеря_ реконструкции: 250,0325 - kl_loss : 3.7845
Эпоха 55/100
1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 253,7622 - реконструкция_потеря: 249.9735 - kl_loss : 3.7887
Эпоха 56/100
1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 253,7338 - Reconstruction_loss: 249.9380 - kl_loss : 3,7959
Эпоха 57/100
1094/1094 [==============================] - 6 с 6 мс / шаг - потеря: 253,6761 - реконструкция_потеря: 249,8792 - kl_loss : 3.7969
Эпоха 58/100
1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 253,6236 - восстановление_потер: 249,8283 - kl_loss : 3,7954
Эпоха 59/100
1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 253,6181 - восстановление_потер: 249,8236 - kl_loss : 3.7945
Эпоха 60/100
1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 253,5509 - Reconstruction_loss: 249.7587 - kl_loss : 3.7921
Эпоха 61/100
1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 253.5124 - реконструкция_потеря: 249.7126 - kl_loss : 3,7998
Эпоха 62/100
1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 253,4739 - Reconstruction_loss: 249.6683 - kl_loss : 3.8056
Эпоха 63/100
1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 253.4609 - реконструкция_потеря: 249.6567 - kl_loss : 3.8042
Эпоха 64/100
1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 253,4066 - реконструкция_потеря: 249.6020 - kl_loss : 3.8045
Эпоха 65/100
1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 253,3578 - восстановление_потер: 249,5580 - kl_loss : 3,7998
Эпоха 66/100
1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 253,3728 - реконструкция_потеря: 249,5609 - kl_loss : 3.8118
Эпоха 67/100
1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 253,3523 - реконструкция_потеря: 249,5351 - kl_loss : 3.8171
Эпоха 68/100
1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 253.2646 - реконструкция_потеря: 249.4452 - kl_loss : 3.8194
Эпоха 69/100
1094/1094 [==============================] - 6 с 6 мс / шаг - потеря: 253.2642 - реконструкция_потеря: 249.4603 - kl_loss : 3.8040
Эпоха 70/100
1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 253.2227 - реконструкция_потеря: 249.4159 - kl_loss : 3.8068
Эпоха 71/100
1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 253,1848 - восстановление_потер: 249,3755 - kl_loss : 3.8094
Эпоха 72/100
1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 253.1812 - реконструкция_потеря: 249.3737 - kl_loss : 3.8074
Эпоха 73/100
1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 253.1803 - реконструкция_потеря: 249.3743 - kl_loss : 3.8059
Эпоха 74/100
1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 253,1295 - восстановление_потер: 249,3114 - kl_loss : 3.8181
Эпоха 75/100
1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 253.0516 - реконструкция_потеря: 249.2391 - kl_loss : 3.8125
Эпоха 76/100
1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 253,0736 - реконструкция_потеря: 249,2582 - kl_loss : 3.8154
Эпоха 77/100
1094/1094 [==============================] - 6 с 6 мс / шаг - потеря: 253.0331 - реконструкция_потеря: 249.2200 - kl_loss : 3.8131
Эпоха 78/100
1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 253.0479 - Recovery_loss: 249.2272 - kl_loss : 3.8207
Эпоха 79/100
1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 252.9317 - реконструкция_потеря: 249.1137 - kl_loss : 3.8179
Эпоха 80/100
1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 252.9578 - Восстановление_потер: 249.1483 - kl_loss : 3.8095
Эпоха 81/100
1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 252.9072 - реконструкция_потеря: 249.0963 - kl_loss : 3.8109
Эпоха 82/100
1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 252,8793 - Reconstruction_loss: 249.0646 - kl_loss : 3.8147
Эпоха 83/100
1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 252,8914 - восстановление_потер: 249,0676 - kl_loss : 3.8238
Эпоха 84/100
1094/1094 [==============================] - 6 с 6 мс / шаг - потеря: 252.8365 - восстановление_потер: 249.0121 - kl_loss : 3.8244
Эпоха 85/100
1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 252.8063 - восстановление_потер: 248.9844 - kl_loss : 3.8218
Эпоха 86/100
1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 252,7960 - реконструкция_потеря: 248.9777 - kl_loss : 3.8183
Эпоха 87/100
1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 252,7733 - Reconstruction_loss: 248.9529 - kl_loss : 3.8204
Эпоха 88/100
1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 252,7303 - восстановление_потер: 248,9055 - kl_loss : 3.8248
Эпоха 89/100
1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 252,7225 - восстановление_потер: 248,8902 - kl_loss : 3.8323
Эпоха 90/100
1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 252,6822 - восстановление_потери: 248,8549 - kl_loss : 3.8273
Эпоха 91/100
1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 252,6540 - потеря_ реконструкции: 248,8314 - kl_loss : 3.8227
Эпоха 92/100
1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 252,6540 - восстановление_потер: 248,8239 - kl_loss : 3.8300
Эпоха 93/100
1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 252,6213 - восстановление_потер: 248,7778 - kl_loss : 3.8435
Эпоха 94/100
1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 252.5990 - восстановление_потер: 248.7594 - kl_loss : 3.8397
Эпоха 95/100
1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 252,5786 - реконструкция_потеря: 248,7413 - kl_loss : 3.8373
Эпоха 96/100
1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 252,5839 - потеря_ реконструкции: 248,7411 - kl_loss : 3.8427
Эпоха 97/100
1094/1094 [==============================] - 7 с 7 мс / шаг - потеря: 252,5364 - потеря восстановления: 248,6960 - kl_loss : 3.8404
Эпоха 98/100
1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 252,5347 - реконструкция_потеря: 248,6915 - kl_loss : 3.8431
Эпоха 99/100
1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 252.4996 - реконструкция_потеря: 248.6569 - kl_loss : 3.8428
Эпоха 100/100
1094/1094 [==============================] - 7 с 6 мс / шаг - потеря: 252,4938 - Reconstruction_loss: 248.6405 - kl_loss : 3.8533
<tensorflow.python.keras.callbacks.History в 0x7f5467c56be0>

  • На этом этапе мы отображаем результаты обучения, мы будем отображать эти результаты в соответствии с их значениями в векторах скрытого пространства.

Код:

def plot_latent(encoder, decoder):
# display an * n 2D manifold of imagess
n = 10
img_dim = 28
scale = 2.0
figsize = 15
figure = np.zeros((img_dim * n, img_dim * n))
# linearly spaced coordinates corresponding to the 2D plot
# of images classes in the latent space
grid_x = np.linspace( - scale, scale, n)
grid_y = np.linspace( - scale, scale, n)[:: - 1 ]
for i, yi in enumerate (grid_y):
for j, xi in enumerate (grid_x):
z_sample = np.array([[xi, yi]])
x_decoded = decoder.predict(z_sample)
images = x_decoded[ 0 ].reshape(img_dim, img_dim)
figure[
i * img_dim : (i + 1 ) * img_dim,
j * img_dim : (j + 1 ) * img_dim,
] = images
plt.figure(figsize = (figsize, figsize))
start_range = img_dim / / 2
end_range = n * img_dim + start_range + 1
pixel_range = np.arange(start_range, end_range, img_dim)
sample_range_x = np. round (grid_x, 1 )
sample_range_y = np. round (grid_y, 1 )
plt.xticks(pixel_range, sample_range_x)
plt.yticks(pixel_range, sample_range_y)
plt.xlabel( "z[0]" )
plt.ylabel( "z[1]" )
plt.imshow(figure, cmap = "Greys_r" )
plt.show()
plot_latent(encoder, decoder)

  • Чтобы получить более четкое представление о наших репрезентативных значениях скрытых векторов, мы будем строить диаграмму разброса обучающих данных на основе их значений соответствующих скрытых размеров, созданных кодировщиком.

Код:

def plot_label_clusters(encoder, decoder, data, test_lab):
z_mean, _, _ = encoder.predict(data)
plt.figure(figsize = ( 12 , 10 ))
sc = plt.scatter(z_mean[:, 0 ], z_mean[:, 1 ], c = test_lab)
cbar = plt.colorbar(sc, ticks = range ( 10 ))
cbar.ax.set_yticklabels([labels.get(i) for i in range ( 10 )])
plt.xlabel( "z[0]" )
plt.ylabel( "z[1]" )
plt.show()
labels = { 0 : "T-shirt / top" ,
1 : "Trouser" ,
2 : "Pullover" ,
3 : "Dress" ,
4 : "Coat" ,
5 : "Sandal" ,
6 : "Shirt" ,
7 : "Sneaker" ,
8 : "Bag" ,
9 : "Ankle boot" }
(x_train, y_train), _ = keras.datasets.fashion_mnist.load_data()
x_train = np.expand_dims(x_train, - 1 ).astype( "float32" ) / 255
plot_label_clusters(encoder, decoder, x_train, y_train)

Использованная литература:

  • Бумага для вариационного автоэнкодера
  • Вариационный автоэнкодер Keras