새소식

반응형
Vision AI/미분류

CIFAR-10 데이터 분석하기

  • -
반응형

CIFAR-10 데이터 분석을 해보겠습니다.

 

CIFAR-10은 대표적인 이미지 분류의 벤치마크 데이터셋입니다.

 

저는 CIFAR-10을 깃헙에 올려주신 github.com/YoongiKim/CIFAR-10-images 이 분의 레포지토리에서 데이터를 다운받아 작성하였습니다.

 

from glob import glob
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt

# base_dir은 현재 CIFAR-10이 들어있는 폴더로 지정해줍니다.

base_dir = './'
train_data = glob(os.path.join(base_dir, 'train\\*\\*.jpg'))
test_data = glob(os.path.join(base_dir, 'test\\*\\*.jpg'))

 

# 데이터 수 출력
print(f'train_data 수 : {len(train_data)}')
print(f'test_data 수 : {len(test_data)}\n')

train_data 수 : 50000
test_data 수 : 10000

train, test 셋의 수를 파악합니다. train셋에서 얼만큼 train, validation셋으로 분할할지 생각할 수 있습니다. 물론 이 데이터는 잘 정제된 데이터라 이상적으로 배치되어 있겠지만, 실무에서는 그렇지 않기 때문입니다.

 

# 클래스 출력
classes = os.listdir(os.path.join(base_dir, 'train'))
print(f'클래스 이름 : {classes}\n')

클래스 이름 : ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

이진분류인지 다중분류인지를 보기 위해 데이터의 클래스를 파악합니다.

 

# 클래스별 train, test 데이터 수
for class_name in classes:
    class_dir = glob(os.path.join(base_dir, 'train', class_name, '*.jpg'))
    print(f'train_{class_name} 수 : {len(class_dir)}')

for class_name in classes:
    class_dir = glob(os.path.join(base_dir, 'test', class_name, '*.jpg'))
    print(f'test_{class_name} 수 : {len(class_dir)}')

데이터의 불균형이 있는지 확인하기 위해 클래스별 데이터 수를 봅니다.

 

 

# 클래스별 예시
n_examples = 10

def get_image_sample(directory, n_sample=7):
    image_paths = glob(os.path.join(directory, '*.jpg'))
    random_samples = np.random.choice(image_paths, n_sample)
    result = []
    result.append(image_paths[0].split('\\')[-2])
    result.append([cv2.imread(x) for x in random_samples])
    return result

fig, axarr = plt.subplots(len(classes), n_examples, figsize=(15,15))
for idx, class_name in enumerate(classes): 
    image_samples = get_image_sample(os.path.join(base_dir, 'train', class_name), n_examples)
    for i in range(n_examples):
        axarr[idx, i].imshow(image_samples[1][i])
        axarr[idx, i].axis('off')

plt.suptitle(f'Training {n_examples} examples of each class in CIFAR-10')
plt.show()
plt.close()

클래스별 데이터의 생김새를 파악합니다. 저는 샘플로 클래스당 10개정도 플로팅해보았습니다.

 

CIFAR-10에서 클래스별 10장씩의 학습이미지를 플로팅
CIFAR-10에서 클래스별 10장씩의 테스트이미지를 플로팅

 

두 번째 행의 자동차 이미지를 관찰해보겠습니다.

 

자동차의 색이 다르거나, 찍힌 방향이 다르고 크기도 제각각입니다. 그러므로 테스트 이미지를 잘 분류하기 위해서는 어떤 데이터 증강법을 사용해야 하는지를 정의할 수 있습니다.

 

또한, 클래스별 데이터 수를 플로팅하여 클래스별 데이터 수가 불균형적인지 균형적인지도 파악하여야 합니다. 불균형적이라면 class weight, loss weight 등 추가적인 조치가 필요할 수 있습니다.

 

반응형

'Vision AI > 미분류' 카테고리의 다른 글

Convolution layer에 대한 고찰  (0) 2021.08.18
tf.keras로 CIFAR-10 데이터 학습하기  (0) 2021.02.11
K-fold cross validation  (0) 2019.11.09
ROC curve 간단 정리  (0) 2019.05.17
Precision, Recall, Accuracy 간단 정리  (2) 2019.05.16
Contents

포스팅 주소를 복사했습니다

이 글이 도움이 되었다면 공감과 광고 클릭 부탁드립니다~