Data Analyst KIM

[Deep Learning] 이미지 증강 : MRI 뇌 사진을 통한 치매 환자 예측 본문

데이터 분석/ML | DL | NLP

[Deep Learning] 이미지 증강 : MRI 뇌 사진을 통한 치매 환자 예측

김두연 2023. 10. 27. 09:48
반응형

현실에서 데이터 셋이 부족하여 학습의 정확도가 높지 않은 경우가 많다.

이럴 경우 이미지를 증강하여 데이터 셋을 늘리는 방법을 사용할 수 있다.

 

이미지 증강은 지도 학습인데 지도 학습의 성능을 향상시키기 위해서는 수 많은 정답지가 필요하다.

따라서 이미지 증강을 통해 현실에서 적은 데이터의 양을 늘려서 학습을 시켜보자.

 

이미지 증강 기법은 원본 이미지를 회전시키거나, 뒤집거나, 자르는 등의 방법을 통해 새로운 이미지를 생성한다.

이미지를 자르거나 섞는 방법으로 만들어진 이미지는 모델의 과적ㅇ합을 막아주는 중요한 역할을 한다.

단, 너무 많은 증강 기법을 사용하면 학습 시간이 늘어날 수 있기 때문에 불필요하게 많이 생성하면 안된다.

주어진 데이터의 특성을 잘 파악해서 사용하는 것이 효과적이며 학습 데이터 셋에만 사용하자.

 

 

이미지 증강 옵션

옵션 설명
rescale 주어진 이미지의 크기를 바꾸어 줌
horizontal_flip, vertical_flip 주어진 이미지를 수평 또는 수직으로 뒤집음
zoom_range 정해진 범위 안에서 축소 또는 확대함
width_shift, height_shift 정해진 범위 안에서 그림을 수평 또는 수직으로 랜덤하게 평행 이동시킴
rotation_rang 정해진 각도만큼 이미지를 회전시킴
shear_range 좌표 하나를 고정시키고다른 몇 개의 좌표를 이동시키는변환을 함
fill_mode 이미지를 축소 또는 회전하거나이동할 때 생기는빈 공간을 어떻게 채울지결정함
(nearest 옵션을선택하면 가장 비슷한 색으로 채워짐)

 

이미지 증강 예시


MRI 뇌 사진을 통한 치매 환자 예측

  • MRI 사진을 통해서 치매 환자의 뇌 or 일반인의 뇌인지를 CNN을 통해 예측해보자.
  • 이미지 데이터가 부족하기 때문에 이미지 증강 기법을 사용

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Activation, Dropout, Flatten, Conv2D, MaxPooling2D
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras import optimizers

import numpy as np
import matplotlib.pyplot as plt

# 깃허브에 준비된 데이터를 가져옵니다.
!git clone https://github.com/taehojo/data-ch20.git

# 학습셋의 변형을 설정하는 부분입니다.
train_datagen = ImageDataGenerator(rescale=1./255,          # 주어진 이미지의 크기를 설정합니다.
                                  horizontal_flip=True,     # 수평 대칭 이미지를 50% 확률로 만들어 추가합니다.
                                  width_shift_range=0.1,    # 전체 크기의 15% 범위에서 좌우로 이동합니다.
                                  height_shift_range=0.1,   # 마찬가지로 위, 아래로 이동합니다.
                                  #rotation_range=5,        # 정해진 각도만큼 회전시킵니다.
                                  #shear_range=0.7,         # 좌표 하나를 고정시키고 나머지를 이동시킵니다.
                                  #zoom_range=1.2,          # 확대 또는 축소시킵니다.
                                  #vertical_flip=True,      # 수직 대칭 이미지를 만듭니다.
                                  #fill_mode='nearest'      # 빈 공간을 채우는 방법입니다. nearest 옵션은 가장 비슷한 색으로 채우게 됩니다.
                                  )

train_generator = train_datagen.flow_from_directory(
       './data-ch20/train',   # 학습셋이 있는 폴더의 위치입니다.
       target_size=(150, 150),
       batch_size=5,
       class_mode='binary')

# 테스트셋은 이미지 부풀리기 과정을 진행하지 않습니다.
test_datagen = ImageDataGenerator(rescale=1./255)

test_generator = test_datagen.flow_from_directory(
       './data-ch20/test',   # 테스트셋이 있는 폴더의 위치입니다.
       target_size=(150, 150),
       batch_size=5,
       class_mode='binary')


# 앞서 배운 CNN 모델을 만들어 적용해 보겠습니다.
model = Sequential()
model.add(Conv2D(32, (3, 3), input_shape=(150,150,3)))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))

model.add(Conv2D(32, (3, 3)))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))

model.add(Conv2D(64, (3, 3)))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))

model.add(Flatten())
model.add(Dense(64))
model.add(Activation('relu'))
model.add(Dropout(0.5))
model.add(Dense(1))
model.add(Activation('sigmoid'))
model.summary()
# 모델 실행의 옵션을 설정합니다.
model.compile(loss='binary_crossentropy', optimizer=optimizers.Adam(learning_rate=0.0002), metrics=['accuracy'])

# 학습의 조기 중단을 설정합니다.
early_stopping_callback = EarlyStopping(monitor='val_loss', patience=5)

# 모델을 실행합니다
history = model.fit(
       train_generator,
       epochs=100,
       validation_data=test_generator,
       validation_steps=10,
       callbacks=[early_stopping_callback])

 accuracy와 val_accuracy가 모두 96%이상

# 검증셋과 학습셋의 오차를 저장합니다.
y_vloss = history.history['val_loss']
y_loss = history.history['loss']

# 그래프로 표현해 봅니다.
x_len = np.arange(len(y_loss))
plt.plot(x_len, y_vloss, marker='.', c="red", label='Testset_loss')
plt.plot(x_len, y_loss, marker='.', c="blue", label='Trainset_loss')

# 그래프에 그리드를 주고 레이블을 표시하겠습니다.
plt.legend(loc='upper right')
plt.grid()
plt.xlabel('epoch')
plt.ylabel('loss')
plt.show()

 

반응형