새소식

반응형
Vision AI/미분류

CIFAR-10 데이터 분석하기

  • -
반응형

CIFAR-10 데이터 분석을 해보겠습니다.

 

CIFAR-10은 대표적인 이미지 분류의 벤치마크 데이터셋입니다.

 

저는 CIFAR-10을 깃헙에 올려주신 github.com/YoongiKim/CIFAR-10-images 이 분의 레포지토리에서 데이터를 다운받아 작성하였습니다.

 

from glob import glob import os import cv2 import numpy as np import matplotlib.pyplot as plt # base_dir은 현재 CIFAR-10이 들어있는 폴더로 지정해줍니다. base_dir = './' train_data = glob(os.path.join(base_dir, 'train\\*\\*.jpg')) test_data = glob(os.path.join(base_dir, 'test\\*\\*.jpg'))

 

# 데이터 수 출력 print(f'train_data 수 : {len(train_data)}') print(f'test_data 수 : {len(test_data)}\n') train_data 수 : 50000 test_data 수 : 10000

train, test 셋의 수를 파악합니다. train셋에서 얼만큼 train, validation셋으로 분할할지 생각할 수 있습니다. 물론 이 데이터는 잘 정제된 데이터라 이상적으로 배치되어 있겠지만, 실무에서는 그렇지 않기 때문입니다.

 

# 클래스 출력 classes = os.listdir(os.path.join(base_dir, 'train')) print(f'클래스 이름 : {classes}\n') 클래스 이름 : ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

이진분류인지 다중분류인지를 보기 위해 데이터의 클래스를 파악합니다.

 

# 클래스별 train, test 데이터 수 for class_name in classes: class_dir = glob(os.path.join(base_dir, 'train', class_name, '*.jpg')) print(f'train_{class_name} 수 : {len(class_dir)}') for class_name in classes: class_dir = glob(os.path.join(base_dir, 'test', class_name, '*.jpg')) print(f'test_{class_name} 수 : {len(class_dir)}')

데이터의 불균형이 있는지 확인하기 위해 클래스별 데이터 수를 봅니다.

 

 

# 클래스별 예시 n_examples = 10 def get_image_sample(directory, n_sample=7): image_paths = glob(os.path.join(directory, '*.jpg')) random_samples = np.random.choice(image_paths, n_sample) result = [] result.append(image_paths[0].split('\\')[-2]) result.append([cv2.imread(x) for x in random_samples]) return result fig, axarr = plt.subplots(len(classes), n_examples, figsize=(15,15)) for idx, class_name in enumerate(classes): image_samples = get_image_sample(os.path.join(base_dir, 'train', class_name), n_examples) for i in range(n_examples): axarr[idx, i].imshow(image_samples[1][i]) axarr[idx, i].axis('off') plt.suptitle(f'Training {n_examples} examples of each class in CIFAR-10') plt.show() plt.close()

클래스별 데이터의 생김새를 파악합니다. 저는 샘플로 클래스당 10개정도 플로팅해보았습니다.

 

CIFAR-10에서 클래스별 10장씩의 학습이미지를 플로팅
CIFAR-10에서 클래스별 10장씩의 테스트이미지를 플로팅

 

두 번째 행의 자동차 이미지를 관찰해보겠습니다.

 

자동차의 색이 다르거나, 찍힌 방향이 다르고 크기도 제각각입니다. 그러므로 테스트 이미지를 잘 분류하기 위해서는 어떤 데이터 증강법을 사용해야 하는지를 정의할 수 있습니다.

 

또한, 클래스별 데이터 수를 플로팅하여 클래스별 데이터 수가 불균형적인지 균형적인지도 파악하여야 합니다. 불균형적이라면 class weight, loss weight 등 추가적인 조치가 필요할 수 있습니다.

 

반응형
Contents

포스팅 주소를 복사했습니다

이 글이 도움이 되었다면 공감과 광고 클릭 부탁드립니다~