I tried image classification with deep learning, using a convolutional neural network (CNN). The model consists of convolutional layers, pooling layers, a fully connected layer, and an output layer, and I trained it from scratch.
The classification covers five kinds of flowers: daisy, dandelion, rose, sunflower, and tulip. I used 500 training images and 200 validation images per class. The resulting accuracy is only about 50%. As a quick check, a prediction on a sunflower test image came out correct.
Although this model only reaches about 50% accuracy, transfer learning with VGG16 on the same training data scores much higher (over 80%).

 

Training a CNN to classify images

In [1]:
import os

import keras
from keras.preprocessing.image import ImageDataGenerator
from keras.models import Sequential, Model
from keras.layers import Input, Dense, Dropout, Activation, Flatten, MaxPooling2D, Conv2D
from keras import optimizers
Using TensorFlow backend.
In [2]:
keras.__version__
Out[2]:
'2.2.4'

Directories for the training, validation, and test images

In [3]:
# Classes to classify
classes = ['daisy', 'dandelion', 'rose', 'sunflower', 'tulip']
nb_classes = len(classes)
batch_size_for_data_generator = 20

base_dir = "."

train_dir = os.path.join(base_dir, 'train_images')
validation_dir = os.path.join(base_dir, 'validation_images')
test_dir = os.path.join(base_dir, 'test_images')

train_daisy_dir = os.path.join(train_dir, 'daisy')
train_dandelion_dir = os.path.join(train_dir, 'dandelion')
train_rose_dir = os.path.join(train_dir, 'rose')
train_sunflower_dir = os.path.join(train_dir, 'sunflower')
train_tulip_dir = os.path.join(train_dir, 'tulip')

validation_daisy_dir = os.path.join(validation_dir, 'daisy')
validation_dandelion_dir = os.path.join(validation_dir, 'dandelion')
validation_rose_dir = os.path.join(validation_dir, 'rose')
validation_sunflower_dir = os.path.join(validation_dir, 'sunflower')
validation_tulip_dir = os.path.join(validation_dir, 'tulip')

test_daisy_dir = os.path.join(test_dir, 'daisy')
test_dandelion_dir = os.path.join(test_dir, 'dandelion')
test_rose_dir = os.path.join(test_dir, 'rose')
test_sunflower_dir = os.path.join(test_dir, 'sunflower')
test_tulip_dir = os.path.join(test_dir, 'tulip')

# Image size
img_rows, img_cols = 200, 200

Checking the number of images

In [4]:
print('total training daisy images:', len(os.listdir(train_daisy_dir)),train_daisy_dir)
print('total training dandelion images:', len(os.listdir(train_dandelion_dir)),train_dandelion_dir)
print('total training rose images:', len(os.listdir(train_rose_dir)),train_rose_dir)
print('total training sunflower images:', len(os.listdir(train_sunflower_dir)),train_sunflower_dir)
print('total training tulip images:', len(os.listdir(train_tulip_dir)),train_tulip_dir)

print('total validation daisy images:', len(os.listdir(validation_daisy_dir)),validation_daisy_dir)
print('total validation dandelion images:', len(os.listdir(validation_dandelion_dir)),validation_dandelion_dir)
print('total validation rose images:', len(os.listdir(validation_rose_dir)),validation_rose_dir)
print('total validation sunflower images:', len(os.listdir(validation_sunflower_dir)),validation_sunflower_dir)
print('total validation tulip images:', len(os.listdir(validation_tulip_dir)),validation_tulip_dir)

print('total test daisy images:', len(os.listdir(test_daisy_dir)),test_daisy_dir)
print('total test dandelion images:', len(os.listdir(test_dandelion_dir)),test_dandelion_dir)
print('total test rose images:', len(os.listdir(test_rose_dir)),test_rose_dir)
print('total test sunflower images:', len(os.listdir(test_sunflower_dir)),test_sunflower_dir)
print('total test tulip images:', len(os.listdir(test_tulip_dir)),test_tulip_dir)
total training daisy images: 500 .\train_images\daisy
total training dandelion images: 500 .\train_images\dandelion
total training rose images: 500 .\train_images\rose
total training sunflower images: 500 .\train_images\sunflower
total training tulip images: 500 .\train_images\tulip
total validation daisy images: 200 .\validation_images\daisy
total validation dandelion images: 200 .\validation_images\dandelion
total validation rose images: 200 .\validation_images\rose
total validation sunflower images: 200 .\validation_images\sunflower
total validation tulip images: 200 .\validation_images\tulip
total test daisy images: 69 .\test_images\daisy
total test dandelion images: 352 .\test_images\dandelion
total test rose images: 84 .\test_images\rose
total test sunflower images: 34 .\test_images\sunflower
total test tulip images: 284 .\test_images\tulip

Augmenting the image data with ImageDataGenerator

In [5]:
train_datagen = ImageDataGenerator(
    rescale=1.0 / 255,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True)

train_generator = train_datagen.flow_from_directory(
    directory=train_dir,
    target_size=(img_rows, img_cols),
    color_mode='rgb',
    classes=classes,
    class_mode='categorical',
    batch_size=batch_size_for_data_generator,
    shuffle=True)
Found 2500 images belonging to 5 classes.
In [6]:
test_datagen = ImageDataGenerator(rescale=1.0 / 255)

validation_generator = test_datagen.flow_from_directory(
    directory=validation_dir,
    target_size=(img_rows, img_cols),
    color_mode='rgb',
    classes=classes,
    class_mode='categorical',
    batch_size=batch_size_for_data_generator,
    shuffle=True)
Found 1000 images belonging to 5 classes.
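
As a quick sanity check (my addition, not a cell from the original notebook), you can confirm the label order and batch shapes the generators produce; flow_from_directory exposes the class-to-index mapping via class_indices:

x_batch, y_batch = next(train_generator)
print(train_generator.class_indices)  # {'daisy': 0, 'dandelion': 1, 'rose': 2, 'sunflower': 3, 'tulip': 4}
print(x_batch.shape, y_batch.shape)   # (20, 200, 200, 3) (20, 5)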

The CNN model

In [7]:
model = Sequential()
model.add(Conv2D(32, (3, 3), activation='relu', input_shape=(img_rows, img_cols, 3)))
model.add(MaxPooling2D((2, 2)))
model.add(Conv2D(64, (3, 3), activation='relu'))
model.add(MaxPooling2D((2, 2)))
model.add(Conv2D(128, (3, 3), activation='relu'))
model.add(MaxPooling2D((2, 2)))
model.add(Conv2D(128, (3, 3), activation='relu'))
model.add(MaxPooling2D((2, 2)))
model.add(Conv2D(128, (3, 3), activation='relu'))
model.add(MaxPooling2D((2, 2)))
model.add(Flatten())
model.add(Dense(512, activation='relu'))
model.add(Dense(nb_classes, activation='softmax'))

model.summary()
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d_1 (Conv2D)            (None, 198, 198, 32)      896       
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 99, 99, 32)        0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 97, 97, 64)        18496     
_________________________________________________________________
max_pooling2d_2 (MaxPooling2 (None, 48, 48, 64)        0         
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 46, 46, 128)       73856     
_________________________________________________________________
max_pooling2d_3 (MaxPooling2 (None, 23, 23, 128)       0         
_________________________________________________________________
conv2d_4 (Conv2D)            (None, 21, 21, 128)       147584    
_________________________________________________________________
max_pooling2d_4 (MaxPooling2 (None, 10, 10, 128)       0         
_________________________________________________________________
conv2d_5 (Conv2D)            (None, 8, 8, 128)         147584    
_________________________________________________________________
max_pooling2d_5 (MaxPooling2 (None, 4, 4, 128)         0         
_________________________________________________________________
flatten_1 (Flatten)          (None, 2048)              0         
_________________________________________________________________
dense_1 (Dense)              (None, 512)               1049088   
_________________________________________________________________
dense_2 (Dense)              (None, 5)                 2565      
=================================================================
Total params: 1,440,069
Trainable params: 1,440,069
Non-trainable params: 0
_________________________________________________________________
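As a cross-check on the summary: the first convolutional layer's 896 parameters are (3 × 3 × 3 weights + 1 bias) × 32 filters, and the 2,048 values entering dense_1 are the flattened 4 × 4 × 128 feature map from the last pooling layer, which gives dense_1 its 2,048 × 512 + 512 = 1,049,088 parameters.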
In [8]:
model.compile(loss='categorical_crossentropy',
              optimizer=optimizers.RMSprop(lr=1e-4),
              metrics=['acc'])

Training

In [9]:
history = model.fit_generator(
    train_generator,
    steps_per_epoch=25,     # 25 steps x batch size 20 = 500 of the 2,500 training images per epoch
    epochs=50,
    validation_data=validation_generator,
    validation_steps=10,    # 10 steps x batch size 20 = 200 of the 1,000 validation images
    verbose=1)
Epoch 1/50
25/25 [==============================] - 5s 201ms/step - loss: 1.5583 - acc: 0.2760 - val_loss: 1.4675 - val_acc: 0.4200
Epoch 2/50
25/25 [==============================] - 4s 160ms/step - loss: 1.4632 - acc: 0.3700 - val_loss: 1.3868 - val_acc: 0.3800
Epoch 3/50
25/25 [==============================] - 4s 161ms/step - loss: 1.4359 - acc: 0.3720 - val_loss: 1.3752 - val_acc: 0.4600
Epoch 4/50
25/25 [==============================] - 4s 162ms/step - loss: 1.3319 - acc: 0.4480 - val_loss: 1.2934 - val_acc: 0.4350
Epoch 5/50
25/25 [==============================] - 4s 160ms/step - loss: 1.2935 - acc: 0.4500 - val_loss: 1.1782 - val_acc: 0.4850
Epoch 6/50
25/25 [==============================] - 4s 165ms/step - loss: 1.2595 - acc: 0.4720 - val_loss: 1.1379 - val_acc: 0.5200
Epoch 7/50
25/25 [==============================] - 4s 161ms/step - loss: 1.1867 - acc: 0.4980 - val_loss: 1.1652 - val_acc: 0.4900
Epoch 8/50
25/25 [==============================] - 4s 164ms/step - loss: 1.1645 - acc: 0.5280 - val_loss: 1.1589 - val_acc: 0.5650
Epoch 9/50
25/25 [==============================] - 4s 163ms/step - loss: 1.1807 - acc: 0.4800 - val_loss: 1.1266 - val_acc: 0.5450
Epoch 10/50
25/25 [==============================] - 4s 165ms/step - loss: 1.1370 - acc: 0.5240 - val_loss: 1.1999 - val_acc: 0.4250
Epoch 11/50
25/25 [==============================] - 4s 166ms/step - loss: 1.1475 - acc: 0.5400 - val_loss: 1.0494 - val_acc: 0.5950
Epoch 12/50
25/25 [==============================] - 4s 165ms/step - loss: 1.1475 - acc: 0.5180 - val_loss: 1.0844 - val_acc: 0.5950
Epoch 13/50
25/25 [==============================] - 4s 160ms/step - loss: 1.1083 - acc: 0.5140 - val_loss: 1.1558 - val_acc: 0.5350
Epoch 14/50
25/25 [==============================] - 4s 158ms/step - loss: 1.1278 - acc: 0.5280 - val_loss: 1.0092 - val_acc: 0.5900
Epoch 15/50
25/25 [==============================] - 4s 163ms/step - loss: 1.0067 - acc: 0.5760 - val_loss: 1.1988 - val_acc: 0.5150
Epoch 16/50
25/25 [==============================] - 4s 163ms/step - loss: 1.0782 - acc: 0.5680 - val_loss: 1.1315 - val_acc: 0.5050
Epoch 17/50
25/25 [==============================] - 4s 162ms/step - loss: 1.0361 - acc: 0.5800 - val_loss: 1.1039 - val_acc: 0.5550
Epoch 18/50
25/25 [==============================] - 4s 162ms/step - loss: 1.0494 - acc: 0.5600 - val_loss: 1.0331 - val_acc: 0.5850
Epoch 19/50
25/25 [==============================] - 4s 160ms/step - loss: 1.0974 - acc: 0.5600 - val_loss: 0.9825 - val_acc: 0.6100
Epoch 20/50
25/25 [==============================] - 4s 163ms/step - loss: 0.9924 - acc: 0.6160 - val_loss: 1.0714 - val_acc: 0.5400
Epoch 21/50
25/25 [==============================] - 4s 164ms/step - loss: 1.0482 - acc: 0.5920 - val_loss: 1.1485 - val_acc: 0.5450
Epoch 22/50
25/25 [==============================] - 4s 158ms/step - loss: 0.9660 - acc: 0.6060 - val_loss: 1.1102 - val_acc: 0.5650
Epoch 23/50
25/25 [==============================] - 4s 159ms/step - loss: 0.9943 - acc: 0.6040 - val_loss: 1.0671 - val_acc: 0.5600
Epoch 24/50
25/25 [==============================] - 4s 163ms/step - loss: 1.0023 - acc: 0.6120 - val_loss: 0.9693 - val_acc: 0.6300
Epoch 25/50
25/25 [==============================] - 4s 162ms/step - loss: 1.0322 - acc: 0.5720 - val_loss: 0.9546 - val_acc: 0.6250
Epoch 26/50
25/25 [==============================] - 4s 163ms/step - loss: 0.9558 - acc: 0.6340 - val_loss: 0.9550 - val_acc: 0.6150
Epoch 27/50
25/25 [==============================] - 4s 159ms/step - loss: 1.0270 - acc: 0.5920 - val_loss: 1.0458 - val_acc: 0.6200
Epoch 28/50
25/25 [==============================] - 4s 162ms/step - loss: 1.0035 - acc: 0.5840 - val_loss: 1.0653 - val_acc: 0.5550
Epoch 29/50
25/25 [==============================] - 4s 163ms/step - loss: 0.9557 - acc: 0.6240 - val_loss: 0.9648 - val_acc: 0.6450
Epoch 30/50
25/25 [==============================] - 4s 162ms/step - loss: 0.9079 - acc: 0.6220 - val_loss: 0.9516 - val_acc: 0.6250
Epoch 31/50
25/25 [==============================] - 4s 167ms/step - loss: 0.8651 - acc: 0.6720 - val_loss: 0.9837 - val_acc: 0.6350
Epoch 32/50
25/25 [==============================] - 4s 165ms/step - loss: 1.0389 - acc: 0.6160 - val_loss: 0.9686 - val_acc: 0.5750
Epoch 33/50
25/25 [==============================] - 4s 163ms/step - loss: 0.9761 - acc: 0.6040 - val_loss: 0.9374 - val_acc: 0.6250
Epoch 34/50
25/25 [==============================] - 4s 161ms/step - loss: 0.8965 - acc: 0.6240 - val_loss: 0.9805 - val_acc: 0.6750
Epoch 35/50
25/25 [==============================] - 4s 163ms/step - loss: 0.9441 - acc: 0.6420 - val_loss: 1.0402 - val_acc: 0.6200
Epoch 36/50
25/25 [==============================] - 4s 167ms/step - loss: 0.9396 - acc: 0.6240 - val_loss: 0.8549 - val_acc: 0.6750
Epoch 37/50
25/25 [==============================] - 4s 160ms/step - loss: 0.9022 - acc: 0.6720 - val_loss: 0.9711 - val_acc: 0.6300
Epoch 38/50
25/25 [==============================] - 4s 161ms/step - loss: 0.9180 - acc: 0.6360 - val_loss: 1.0083 - val_acc: 0.5500
Epoch 39/50
25/25 [==============================] - 4s 159ms/step - loss: 0.8988 - acc: 0.6360 - val_loss: 1.0903 - val_acc: 0.5750
Epoch 40/50
25/25 [==============================] - 4s 161ms/step - loss: 0.9224 - acc: 0.6320 - val_loss: 0.8654 - val_acc: 0.6950
Epoch 41/50
25/25 [==============================] - 4s 167ms/step - loss: 0.9675 - acc: 0.6100 - val_loss: 0.8410 - val_acc: 0.6850
Epoch 42/50
25/25 [==============================] - 4s 166ms/step - loss: 0.8883 - acc: 0.6500 - val_loss: 1.0312 - val_acc: 0.6050
Epoch 43/50
25/25 [==============================] - 4s 162ms/step - loss: 0.8802 - acc: 0.6500 - val_loss: 0.9148 - val_acc: 0.6200
Epoch 44/50
25/25 [==============================] - 4s 168ms/step - loss: 0.8266 - acc: 0.6800 - val_loss: 0.8887 - val_acc: 0.6550
Epoch 45/50
25/25 [==============================] - 4s 165ms/step - loss: 0.8742 - acc: 0.6540 - val_loss: 0.9059 - val_acc: 0.6250
Epoch 46/50
25/25 [==============================] - 4s 166ms/step - loss: 0.8581 - acc: 0.6620 - val_loss: 0.8741 - val_acc: 0.6300
Epoch 47/50
25/25 [==============================] - 4s 167ms/step - loss: 0.8614 - acc: 0.6600 - val_loss: 1.0364 - val_acc: 0.6050
Epoch 48/50
25/25 [==============================] - 4s 165ms/step - loss: 0.8650 - acc: 0.6640 - val_loss: 0.9213 - val_acc: 0.6400
Epoch 49/50
25/25 [==============================] - 4s 163ms/step - loss: 0.8872 - acc: 0.6520 - val_loss: 0.9023 - val_acc: 0.6650
Epoch 50/50
25/25 [==============================] - 4s 161ms/step - loss: 0.8376 - acc: 0.6680 - val_loss: 0.8574 - val_acc: 0.6400

Saving the trained weights

In [10]:
hdf5_file = os.path.join(base_dir, 'flower-model_cnn.hdf5')
model.save_weights(hdf5_file)
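
Note that save_weights stores only the parameters, not the architecture, so the Sequential model above has to be rebuilt before the weights can be restored. A minimal sketch (my addition; the file name flower-model_cnn_full.h5 is hypothetical):

# Rebuild the model definition from the cell above, then restore the weights.
model.load_weights(hdf5_file)

# Alternatively, model.save stores architecture, weights, and optimizer state
# in a single file that keras.models.load_model can restore directly:
# model.save(os.path.join(base_dir, 'flower-model_cnn_full.h5'))
# restored = keras.models.load_model(os.path.join(base_dir, 'flower-model_cnn_full.h5'))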

Plotting the training history

In [11]:
import matplotlib.pyplot as plt
In [12]:
%matplotlib inline
In [13]:
acc = history.history['acc']
val_acc = history.history['val_acc']
loss = history.history['loss']
val_loss = history.history['val_loss']

epochs = range(len(acc))

plt.plot(epochs, acc, 'bo', label='Training acc')
plt.plot(epochs, val_acc, 'b', label='Validation acc')
plt.title('Training and validation accuracy')
plt.legend()

plt.figure()

plt.plot(epochs, loss, 'bo', label='Training loss')
plt.plot(epochs, val_loss, 'b', label='Validation loss')
plt.title('Training and validation loss')
plt.legend()

plt.show()
[Plot: accuracy on the training and validation data]

[Plot: loss on the training and validation data]

Checking accuracy on the test images

In [14]:
test_generator = test_datagen.flow_from_directory(
    directory=test_dir,
    target_size=(img_rows, img_cols),
    color_mode='rgb',
    classes=classes,
    class_mode='categorical',
    batch_size=batch_size_for_data_generator)

test_loss, test_acc = model.evaluate_generator(test_generator, steps=50)
print('test acc:', test_acc)
Found 823 images belonging to 5 classes.
test acc: 0.5381485300170805
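
One refinement worth considering (my addition, not part of the original run): with 823 test images and a batch size of 20, ceil(823 / 20) = 42 steps covers each image roughly once, whereas steps=50 loops the generator and counts some batches twice:

import math

# Evaluate each test image about once instead of recycling batches.
steps = math.ceil(test_generator.samples / batch_size_for_data_generator)  # 42
test_loss, test_acc = model.evaluate_generator(test_generator, steps=steps)
print('test acc:', test_acc)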

Classifying an actual test image

In [15]:
import numpy as np
from keras.preprocessing.image import load_img, img_to_array
from keras.applications.vgg16 import preprocess_input
In [16]:
filename = os.path.join(test_dir, 'sunflower')
filename = os.path.join(filename, '3681233294_4f06cd8903.jpg')
filename
Out[16]:
'.\\test_images\\sunflower\\3681233294_4f06cd8903.jpg'
In [17]:
from PIL import Image
In [18]:
img = np.array(Image.open(filename))
plt.imshow(img)
Out[18]:
<matplotlib.image.AxesImage at 0x204b3d1ecf8>
[Image: the sunflower test image]

In [19]:
img = load_img(filename, target_size=(img_rows, img_cols))
x = img_to_array(img)
x = np.expand_dims(x, axis=0)

# Note: preprocess_input is VGG16's preprocessing (channel mean subtraction),
# which does not match the 1/255 rescaling used during training.
predict = model.predict(preprocess_input(x))
for pre in predict:
    y = pre.argmax()
    print("test result=", classes[y], pre)
test result= sunflower [0.0000000e+00 0.0000000e+00 0.0000000e+00 1.0000000e+00 1.1292886e-11]
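
As flagged in the comment above, preprocess_input does not match the 1/255 rescaling the model was trained with. A prediction preprocessed consistently with training would look like the sketch below (my addition; the printed result above comes from the original preprocess_input call):

# Rescale exactly as ImageDataGenerator(rescale=1.0 / 255) did during training.
predict = model.predict(x / 255.0)
for pre in predict:
    y = pre.argmax()
    print("test result=", classes[y], pre)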