Как можно использовать Tensorflow с набором данных цветов для компиляции и подбора модели?

Опубликовано: 19 Февраля, 2023

В этой статье мы узнаем, как мы можем скомпилировать модель и подогнать к ней набор данных цветов. ЧТОБЫ подогнать набор данных к модели, нам нужно сначала создать конвейер данных, создать архитектуру модели с помощью высокоуровневого API TensorFlow, а затем, прежде чем подгонять модель к данным с помощью конвейеров данных, нам нужно скомпилировать модель с соответствующей функцией потерь. и оптимизатор и метрика, чтобы понять, развивается ли модель эпоха за эпохой или нет.

Импорт библиотек

Библиотеки Python позволяют нам очень легко обрабатывать данные и выполнять типичные и сложные задачи с помощью одной строки кода.

  • Pandas — эта библиотека помогает загружать фрейм данных в формате 2D-массива и имеет несколько функций для выполнения задач анализа за один раз.
  • Numpy — массивы Numpy работают очень быстро и могут выполнять большие вычисления за очень короткое время.
  • Matplotlib — эта библиотека используется для рисования визуализаций.
  • Sklearn — этот модуль содержит несколько библиотек с предварительно реализованными функциями для выполнения задач от предварительной обработки данных до разработки и оценки моделей.
  • Tensorflow — это библиотека с открытым исходным кодом, которая используется для машинного обучения и искусственного интеллекта и предоставляет ряд функций для реализации сложных функций с помощью одной строки кода.

Python3




import numpy as np
import pandas as pd
import seaborn as sb
import matplotlib.pyplot as plt
  
from glob import glob
from PIL import Image
from sklearn.model_selection import train_test_split
from skimage.feature import local_binary_pattern
  
import tensorflow as tf
from tensorflow import keras
from keras import layers
  
AUTO = tf.data.experimental.AUTOTUNE
import warnings
warnings.filterwarnings("ignore")

Теперь давайте проверим общее количество изображений, которые у нас есть для всех классов цветов. Ссылка на набор данных здесь https://www.kaggle.com/datasets/alxmamaev/flowers-recognition.

Python3




images = glob("flowers/*/*.jpg")
len(images)

Выход:

4317

Python3




df = pd.DataFrame({"filepath": images})
df["label"] = df["filepath"].str.split("/", expand=True)[1]
df.head()

Выход:

Python3




from sklearn.preprocessing import LabelEncoder
le = LabelEncoder()
df["encoded"] = le.fit_transform(df["label"])
df.head()

Выход:

Давайте проверим различные классы, присутствующие в обучающих данных, и какой класс был назначен какому целому числу.

Python3




classes = le.classes_
classes

Выход:

array(["daisy", "dandelion", "rose", "sunflower", "tulip"], dtype=object)

Визуализация данных

В этом разделе мы попытаемся понять и визуализировать некоторые изображения, которые были предоставлены нам для создания классификатора для каждого класса. Кроме того, мы проверим проблему дисбаланса.

Python3




x = df["label"].value_counts()
plt.pie(x.values,
        labels=x.index,
        autopct="%1.1f%%")
plt.show()

Выход:

Из приведенного выше графика мы можем сказать, что в данном наборе данных есть небольшая проблема с дисбалансом данных. Но обработка баланса данных не является целью этой статьи.

Python3




for cat in df["label"].unique():
    temp = df[df["label"] == cat]
  
    index_list = temp.index
    fig, ax = plt.subplots(1, 4, figsize=(15, 5))
    fig.suptitle(f"Images for {cat} category . . . .", fontsize=20)
    for i in range(4):
        index = np.random.randint(0, len(index_list))
        index = index_list[index]
        data = df.iloc[index]
  
        image_path = data[0]
  
        img = Image.open(image_path).resize((256, 256))
        img = np.array(img)
        ax[i].imshow(img)
        ax[i].axis("off")
plt.tight_layout()
plt.show()

Выход:

Python3




features = df["filepath"]
target = df["encoded"]
  
X_train, X_val,
 Y_train, Y_val = train_test_split(features, target,
                                   test_size=0.15,
                                   random_state=10)
   
X_train.shape, X_val.shape

Выход:

((3669,), (648,))

Теперь, используя вышеуказанную функцию, мы будем реализовывать наш конвейер ввода обучающих данных и конвейер данных проверки.

Python3




train_ds = (
    tf.data.Dataset
    .from_tensor_slices((X_train, Y_train))
    .map(decode_image, num_parallel_calls=AUTO)
    .batch(32)
    .prefetch(AUTO)
)
  
val_ds = (
    tf.data.Dataset
    .from_tensor_slices((X_val, Y_val))
    .map(decode_image, num_parallel_calls=AUTO)
    .batch(32)
    .prefetch(AUTO)
)

Разработка модели

Мы будем использовать предварительно обученный вес для начальной сети, которая обучена на наборе данных imagenet. Этот набор данных содержит миллионы изображений примерно для 1000 классов изображений. Параметры модели, которую мы импортируем, уже обучены на миллионах изображений и в течение нескольких недель нам не нужно обучать их снова.

Python3




from tensorflow.keras.applications.resnet50 import ResNet50
  
pre_trained_model = ResNet50(
    input_shape = (224,224,3),
    weights = "imagenet",
    include_top = False
)
  
for layer in pre_trained_model.layers:
  layer.trainable = False

Выход:

94765736/94765736 [==============================] - 5s 0us/step

Архитектура модели

Мы реализуем модель с использованием функционального API Keras, которая будет содержать следующие части:

  • В данном случае базовой моделью является начальная модель.
  • Слой Flatten сглаживает выходные данные базовой модели.
  • Тогда у нас будет два полносвязных слоя, за которыми следует вывод сглаженного слоя.
  • Мы включили несколько слоев BatchNormalization, чтобы обеспечить стабильное и быстрое обучение, и слой Dropout перед последним слоем, чтобы избежать любой возможности переобучения.
  • Последний слой — это выходной слой, который выводит мягкие вероятности для трех классов.

Python3




from tensorflow.keras import Model
  
inputs = layers.Input(shape=(224, 224, 3))
x = layers.Flatten()(inputs)
  
x = layers.Dense(256,activation="relu")(x)
x = layers.BatchNormalization()(x)
x = layers.Dense(256,activation="relu")(x)
x = layers.Dropout(0.3)(x)
x = layers.BatchNormalization()(x)
outputs = layers.Dense(5, activation="softmax")(x)
  
model = Model(inputs, outputs)

При составлении модели мы предоставляем эти три основных параметра:

  • оптимизатор — это метод, который помогает оптимизировать функцию стоимости с помощью градиентного спуска.
  • потеря — функция потерь, с помощью которой мы отслеживаем, улучшается ли модель с обучением или нет.
  • Метрики — это помогает оценить модель, предсказывая данные обучения и проверки.

Python3




model.compile(
    loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
    optimizer="adam",
    metrics=["AUC"]
)

Теперь мы готовы обучить нашу модель.

Python3




history = model.fit(train_ds,
                    validation_data=val_ds,
                    epochs=5,
                    verbose=1)

Выход:

Epoch 1/5
115/115 [==============================] - 8s 60ms/step - loss: 1.5825 - auc: 0.7000 - val_loss: 1.6672 - val_auc: 0.7152
Epoch 2/5
115/115 [==============================] - 7s 59ms/step - loss: 1.3806 - auc: 0.7650 - val_loss: 1.4497 - val_auc: 0.7531
Epoch 3/5
115/115 [==============================] - 8s 68ms/step - loss: 1.2619 - auc: 0.7980 - val_loss: 1.3494 - val_auc: 0.7751
Epoch 4/5
115/115 [==============================] - 7s 58ms/step - loss: 1.1828 - auc: 0.8242 - val_loss: 1.3371 - val_auc: 0.7751
Epoch 5/5
115/115 [==============================] - 7s 60ms/step - loss: 1.0954 - auc: 0.8485 - val_loss: 1.8526 - val_auc: 0.7215

В приведенном ниже коде мы создадим фрейм данных из журнала, полученного в результате обучения модели.

Python3




hist_df=pd.DataFrame(history.history)
hist_df.head()

Выход:

Давайте визуализируем потери при обучении и потери данных при проверке.

Python3




hist_df["loss"].plot()
hist_df["val_loss"].plot()
plt.title("Loss v/s Validation Loss")
plt.legend()
plt.show()

Выход:

Давайте визуализируем AUC обучения и AUC проверки данных.

Python3




hist_df["auc"].plot()
hist_df["val_auc"].plot()
plt.title("AUC v/s Validation AUC")
plt.legend()
plt.show()

Выход: