Классификация изображений с использованием набора данных CIFAR-10 и CIFAR-100 в TensorFlow

Опубликовано: 21 Февраля, 2023

CIFAR10 и CIFAR100 — некоторые из известных эталонных наборов данных, которые используются для обучения CNN задаче компьютерного зрения.

В этой статье мы должны выполнить классификацию изображений для обоих этих наборов данных CIFAR10, а также CIFAR100, поэтому мы будем использовать трансферное обучение здесь.

Но как? Во-первых, мы создадим пользовательскую модель CNN и обучим ее на наших данных CIFAR100. А затем мы будем использовать обученные веса сверточных слоев для построения модели классификации для набора данных CIFAR10.

Импорт библиотек

Нам потребуются следующие библиотеки:

  • ТензорФлоу
  • Нампи
  • Матплотлиб

Python3




import tensorflow as tf
from tensorflow import keras
from keras import layers
  
import numpy as np
import matplotlib.pyplot as plt
  
import warnings
warnings.filterwarnings("ignore")

Импорт набора данных

Теперь давайте сначала загрузим набор данных CIFAR100 с помощью TensorFlow API. Полученные данные уже разделены на наборы данных для обучения и проверки.

Python3




# Load in the data
cifar100 = tf.keras.datasets.cifar100
  
# Distribute it to train and test set
(x_train, y_train), (x_val, y_val) = cifar100.load_data()
print(x_train.shape, y_train.shape, x_val.shape, y_val.shape)

Выход:

Downloading data from https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz
169001437/169001437 [==============================] - 13s 0us/step
(50000, 32, 32, 3) (50000, 1) (10000, 32, 32, 3) (10000, 1)

Это означает, что у нас есть 50 000 изображений формата RGB размеров (32 и 32) в обучающих данных. И около 10 000 изображений в наборе данных проверки.

Визуализация данных

Здесь, на этом этапе, мы визуализируем набор данных, это поможет нам улучшить модель.

Python3




def show_samples(data, labels):
    plt.subplots(figsize=(10, 10))
    for i in range(12):
        plt.subplot(3, 4, i+1)
        k = np.random.randint(0, data.shape[0])
        plt.title(labels[k])
        plt.imshow(data[k])
    plt.tight_layout()
    plt.show()
  
  
show_samples(x_train, y_train)

Выход:

Из-за очень маленького размера изображений их содержание не так ясно. Чтобы узнать, какое число представляет какой класс, можно обратиться к различным источникам, доступным в Интернете.

Разделение данных

Нам нужно разделить данные на обучение и проверку.

Python3




y_train = tf.one_hot(y_train,
                     depth=y_train.max() + 1,
                     dtype=tf.float64)
y_val = tf.one_hot(y_val,
                   depth=y_val.max() + 1,
                   dtype=tf.float64)
  
y_train = tf.squeeze(y_train)
y_val = tf.squeeze(y_val)

Архитектура модели

Теперь давайте определим архитектуру модели, которую мы будем использовать в качестве нашей CNN для классификации этих 100 классов.

Python3




model = tf.keras.models.Sequential([
    layers.Conv2D(16, (3, 3), activation="relu",
                  input_shape=(32, 32, 3), padding="same"),
    layers.Conv2D(32, (3, 3),
                  activation="relu",
                  padding="same"),
    layers.Conv2D(64, (3, 3),
                  activation="relu",
                  padding="same"),
    layers.MaxPooling2D(2, 2),
    layers.Conv2D(128, (3, 3),
                  activation="relu",
                  padding="same"),
  
  
    layers.Flatten(),
    layers.Dense(256, activation="relu"),
    layers.BatchNormalization(),
    layers.Dense(256, activation="relu"),
    layers.Dropout(0.3),
    layers.BatchNormalization(),
    layers.Dense(100, activation="softmax")
])
  
model.compile(
    loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
    optimizer="adam",
    metrics=["AUC", "accuracy"]
)

Какие изменения вносятся в изображения при углублении в сеть, которые можно визуализировать с помощью сводки модели? Он также показывает количество параметров, которые будут обучаться в этой модели.

Python3




model.summary()

Выход:

Примерка модели

Подгонку модели можно выполнить с помощью приведенного ниже кода.

Python3




hist = model.fit(x_train, y_train,
                 epochs=5,
                 batch_size=64,
                 verbose=1,
                 validation_data=(x_val, y_val))

Выход:

Использование модели для обучения CIFAR10

Теперь мы будем использовать слои свертки этой модели для построения нашего классификатора CIFAR10.

Python3




temp = model.get_layer("conv2d_3")
last_output = temp.output
last_output.shape

Выход:

TensorShape([None, 16, 16, 128])

Теперь реализуем функциональную модель, которая будет использовать выходные данные предыдущей модели и учиться на ее основе.

Python3




x = layers.Flatten()(last_output)
  
x = layers.Dense(256, activation="relu")(x)
x = layers.BatchNormalization()(x)
  
x = layers.Dense(256, activation="relu")(x)
x = layers.Dropout(0.3)(x)
x = layers.BatchNormalization()(x)
  
output = layers.Dense(10, activation="softmax")(x)
  
model_new = keras.Model(model.input, output)

Давайте проверим сводку этой новой модели.

Python3




model_new.summary()

Выход:

Теперь давайте скомпилируем модель с теми же потерями, оптимизатором и метриками, что и в классификаторе CIFAR100.

Python3




model_new.compile(
    loss="categorical_crossentropy",
    optimizer="adam",
    metrics=["AUC", "accuracy"]
)

Итак, мы подготовили новую модель, которая теперь готова к обучению на наборе данных CIFAR10. Давайте также загрузим эти данные с помощью TensorFlow API.

Python3




# Load in the data
cifar10 = tf.keras.datasets.cifar10
  
# Distribute it to train and test set
(x_train, y_train), (x_val, y_val) = cifar10.load_data()
print(x_train.shape, y_train.shape, x_val.shape, y_val.shape)

Выход:

Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
170498071/170498071 [==============================] - 14s 0us/step
(50000, 32, 32, 3) (50000, 1) (10000, 32, 32, 3) (10000, 1)

Python3




y_train = tf.one_hot(y_train, depth=10,
                     dtype=tf.float64)
y_val = tf.one_hot(y_val, depth=10,
                   dtype=tf.float64)
  
y_train = tf.squeeze(y_train)
y_val = tf.squeeze(y_val)

Теперь мы будем обучать нашу модель.

Python3




history = model_new.fit(x_train, y_train,
                        batch_size=64,
                        epochs=5,
                        verbose=1,
                        validation_data=(x_val, y_val))

Выход:

Вывод

Обучая модель для большего количества эпох, можно получить наилучшие результаты.