새소식

반응형
Vision AI/미분류

tf.keras로 CIFAR-10 데이터 학습하기

  • -
반응형

CIFAR-10 데이터 학습하는 방법을 조금 자세히 보겠습니다.

 

 

비교적 간단한 방법을 포스팅 한 것이며, 추후 다른 방법으로도 포스팅 하겠습니다.

 

 

먼저 제너레이터를 정의 하겠습니다.

 

 

제너레이터는 쉽게 설명하면 데이터를 한번에 메모리에 저장해놓고 한번에 학습하는 것이 아니라, 표현식을 정의만 해놓고 가용 메모리 범위 내에서 조금씩 반복하여 가져오는 아이입니다.

 

 

먹을 것으로 예를 들면, 사과 1000개를 먹어야한다고 가정할 때, 한번에 먹을 수 없겠죠?? 우리 위의 용량에는 너무 버겁습니다. 그래서 1개씩 1000일 나눠먹는 것과 같은 이치입니다.

 

 

또 써야하는 이유가 여러가지 있는데... 자세한 설명은 다음시간에 하도록 하겠씁니다.

 

 

from tensorflow.keras.preprocessing import image_dataset_from_directory


# 제너레이터 정의
def generator(train_dir, batch_size, image_size):
    train_set = image_dataset_from_directory(directory=train_dir,
                                             label_mode="categorical",
                                             color_mode="rgb",
                                             batch_size=batch_size,
                                             image_size=image_size,
                                             shuffle=True,
                                             seed=42,
                                             validation_split=0.2,
                                             interpolation="bilinear",
                                             subset="training"
                                             )

    val_set = image_dataset_from_directory(directory=train_dir,
                                           label_mode="categorical",
                                           color_mode="rgb",
                                           batch_size=batch_size,
                                           image_size=image_size,
                                           shuffle=False,
                                           seed=42,
                                           validation_split=0.2,
                                           interpolation="bilinear",
                                           subset="validation"
                                           )

    return train_set, val_set

 

image_dataset_from_directory는 제너레이터를 폴더구조로부터 생성하기 위한 함수입니다.

 

directory : 1개의 데이터셋으로 train, validation을 나누고 싶다면 1개의 데이터셋을 train_set, val_set에 지정하면 되고, 이미 train, validation 폴더로 나눠놓았으면 각각 지정하면 됩니다.

 

label_mode : 이진분류, 다중분류를 지정합니다.

 

validation_split : validation 비율을 지정합니다.

 

subset : train인지, val인지를 정의합니다.

 

 

다음으로, 모델을 정의합니다.

 

from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras import layers
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam


# 모델 정의
def build_model(input_shape, lr, n_classes):
    base_model = MobileNetV2(input_shape=(input_shape[0], input_shape[1], input_shape[2]),
                             weights="imagenet",
                             include_top=False)
    gap = layers.GlobalAveragePooling2D()(base_model.output)
    output = layers.Dense(256, activation='relu')(gap)
    output = layers.Dense(n_classes, activation='softmax')(output)

    model = Model(inputs=[base_model.input], outputs=[output])
    model.summary()
    model.compile(optimizer=Adam(lr=lr), loss='categorical_crossentropy', metrics=['categorical_accuracy'])
    return model

 

imagenet pretrained 가중치가 저장된 MobileNetV2를 불러서 GAP를 달아주고, 뒤에 dense를 달아주었습니다.

 

 

마지막단은 클래스별 소프트맥스 값입니다.

 

 

다음으로 정의한 함수들을 부른 후 학습시킵니다.

 

from os.path import join as jo
from datagen import generator
from models import build_model

data_dir = './CIFAR-10'

train_dir = jo(data_dir, 'train')

input_shape = (32, 32, 3)
batch_size = 256
classes = 10
lr = 1e-5
epoch = 10

train_set, val_set = generator(train_dir, batch_size=batch_size, image_size=(32, 32))

model = build_model(input_shape, lr, classes)
model.fit_generator(generator=train_set,
                    validation_data=val_set,
                    steps_per_epoch=len(train_set),
                    validation_steps=len(val_set),
                    epochs=epoch,
                    verbose=1)

 

사전에 파라미터들을 다 정의하고, fit_generator 함수를 사용하여 학습을 시켜주었습니다.

 

 

Keras는 high level API이기 때문에, 비교적 간단한 코드로 실험을 진행할 수 있습니다.

 

 

반응형

'Vision AI > 미분류' 카테고리의 다른 글

VGGnet의 구조 + keras code  (0) 2021.08.23
Convolution layer에 대한 고찰  (0) 2021.08.18
CIFAR-10 데이터 분석하기  (0) 2021.02.04
K-fold cross validation  (0) 2019.11.09
ROC curve 간단 정리  (0) 2019.05.17
Contents

포스팅 주소를 복사했습니다

이 글이 도움이 되었다면 공감과 광고 클릭 부탁드립니다~