Flower classification using a pretrained model

- Phạm Duy Tùng
In this post, we will use a pretrained Keras model to classify five kinds of flowers: rose, sunflower, dandelion, daisy, and tulip.

Introduction

In this post, we use the dataset available at https://www.kaggle.com/alxmamaev/flowers-recognition. It contains 4242 images of five kinds of flowers: rose, sunflower, dandelion, daisy, and tulip. The dataset authors collected the images from Flickr, Google Images, and Yandex. The images are divided into 5 classes of about 800 images each. They are roughly 320x320 pixels, but their sizes are not uniform.

Implementation

After extraction, the data is laid out as follows:

data_dir/classname1/*.*
data_dir/classname2/*.*
...

This directory layout already matches what we need, so we do not have to change the structure and can start writing code.
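Before loading any images, it can be useful to confirm that the extracted folder really has one subdirectory per class. The following is a minimal sketch, not part of the original code: `count_images_per_class` is a hypothetical helper name, and it simply assumes every file under a class folder is an image.

```python
import os
from collections import Counter


def count_images_per_class(data_dir):
    """Count files in each immediate subdirectory of data_dir.

    Hypothetical helper for illustration: each subdirectory name is
    treated as a class label, and every file below it is counted.
    """
    counts = Counter()
    for entry in sorted(os.listdir(data_dir)):
        class_dir = os.path.join(data_dir, entry)
        if not os.path.isdir(class_dir):
            continue  # skip stray files at the top level
        for _root, _dirs, files in os.walk(class_dir):
            counts[entry] += len(files)
    return dict(counts)
```

Calling it on the extracted dataset folder should return a dictionary with five keys, one per flower class.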

First, we load the dataset and transform it so that it can be fed into training.

import sys
import os
from collections import defaultdict
import numpy as np
import scipy.misc


def preprocess_input(x0):
    x = x0 / 255.
    x -= 0.5
    x *= 2.
    return x


def reverse_preprocess_input(x0):
    x = x0 / 2.0
    x += 0.5
    x *= 255.
    return x


def dataset(base_dir, n):
    print("base dir: "+base_dir)
    print("n: "+str(n))
    n = int(n)
    d = defaultdict(list)
    for root, subdirs, files in os.walk(base_dir):
        for filename in files:
            file_path = os.path.join(root, filename)
            assert file_path.startswith(base_dir)
            
            suffix = file_path[len(base_dir):]
            
            suffix = suffix.lstrip("/")
            suffix = suffix.lstrip("\\")
            if suffix.find('/') > -1:  # Linux-style path separator
                label = suffix.split("/")[0]
            else:  # Windows-style path separator
                label = suffix.split("\\")[0]
            d[label].append(file_path)
    print("walk directory complete")
    tags = sorted(d.keys())

    processed_image_count = 0
    useful_image_count = 0

    X = []
    y = []

    for class_index, class_name in enumerate(tags):
        filenames = d[class_name]
        for filename in filenames:
            processed_image_count += 1
            if processed_image_count%100 ==0:
                print(class_name+"\tprocess: "+str(processed_image_count)+"\t"+str(len(d[class_name])))
            img = scipy.misc.imread(filename)
            height, width, chan = img.shape
            assert chan == 3
            aspect_ratio = float(max((height, width))) / min((height, width))
            if aspect_ratio > 2:
                continue
            # We pick the largest center square.
            centery = height // 2
            centerx = width // 2
            radius = min((centerx, centery))
            img = img[centery-radius:centery+radius, centerx-radius:centerx+radius]
            img = scipy.misc.imresize(img, size=(n, n), interp='bilinear')
            X.append(img)
            y.append(class_index)
            useful_image_count += 1
    print("processed %d, used %d" % (processed_image_count, useful_image_count))

    X = np.array(X).astype(np.float32)
    #X = X.transpose((0, 3, 1, 2))
    X = preprocess_input(X)
    y = np.array(y)

    perm = np.random.permutation(len(y))
    X = X[perm]
    y = y[perm]

    print("classes:",end=" ")
    for class_index, class_name in enumerate(tags):
        print(class_name, sum(y==class_index),end=" ")
    print("X shape: ",X.shape)

    return X, y, tags

The code above is fairly simple and easy to follow. Note that images whose aspect ratio (longer side divided by shorter side) is greater than 2 are dropped from the dataset.
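The loader above relies on `scipy.misc.imread` and `scipy.misc.imresize`, which are no longer available in newer SciPy releases. As a rough sketch of the same crop-and-resize step, here is a version based on Pillow. The function name `center_square_resize` and the `max_aspect_ratio` parameter are assumptions made for this example. Unlike the original, it converts every image to RGB instead of asserting that there are exactly 3 channels.

```python
import numpy as np
from PIL import Image


def center_square_resize(img, n, max_aspect_ratio=2.0):
    """Crop the largest centered square from a PIL image and resize it to n x n.

    Returns a uint8 array of shape (n, n, 3), or None when the aspect
    ratio exceeds max_aspect_ratio (the same filter as the loader above).
    """
    img = img.convert("RGB")  # force 3 channels (handles grayscale / RGBA)
    width, height = img.size
    if max(width, height) / min(width, height) > max_aspect_ratio:
        return None
    centerx, centery = width // 2, height // 2
    radius = min(centerx, centery)
    # PIL crop box is (left, upper, right, lower)
    img = img.crop((centerx - radius, centery - radius,
                    centerx + radius, centery + radius))
    img = img.resize((n, n), Image.BILINEAR)
    return np.asarray(img)
```

Inside the loading loop, this could be used as `img = center_square_resize(Image.open(filename), n)`, skipping the file when the result is `None`.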

Next, we build our model on top of the ResNet50 model that ships with Keras. Because we use a pretrained model, the ResNet50 layers are not trained: we reuse their weights, already trained on ImageNet, to extract features. We only need to add fully connected layers ending in a softmax to classify the flowers, so our task is now to learn the weights of these added layers instead of retraining the entire model.


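from keras.applications.resnet50 import ResNet50
from keras.layers import Dense, GlobalAveragePooling2D
from keras.models import Model


# create the base pre-trained model
def build_model(nb_classes):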
    base_model = ResNet50(weights='imagenet', include_top=False)

    # add a global spatial average pooling layer
    x = base_model.output
    x = GlobalAveragePooling2D()(x)
    # let's add a fully-connected layer
    x = Dense(1024, activation='relu')(x)
    # and a logistic layer
    predictions = Dense(nb_classes, activation='softmax')(x)

    # this is the model we will train
    model = Model(inputs=base_model.input, outputs=predictions)

    # first: train only the top layers (which were randomly initialized)
    # i.e. freeze all convolutional ResNet50 layers
    for layer in base_model.layers:
        layer.trainable = False

    return model
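To get a feel for how little is actually trained in this setup, the short calculation below counts the parameters of the two added Dense layers. It is only an illustrative sketch: `dense_params` is a hypothetical helper, and the figure of 2048 input features assumes the standard ResNet50, whose last convolutional block outputs 2048 channels that `GlobalAveragePooling2D` averages into a 2048-dimensional vector.

```python
def dense_params(n_inputs, n_units):
    """Parameters of a fully connected layer: weights plus one bias per unit."""
    return n_inputs * n_units + n_units


resnet50_features = 2048   # channels of ResNet50's last block (include_top=False)
hidden_units = 1024        # Dense(1024, activation='relu')
nb_classes = 5             # rose, sunflower, dandelion, daisy, tulip

hidden = dense_params(resnet50_features, hidden_units)
output = dense_params(hidden_units, nb_classes)
print(hidden, output, hidden + output)  # 2098176 5125 2103301
```

So only about 2.1 million parameters are trainable, while the roughly 23.6 million parameters of the frozen ResNet50 base stay fixed.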
    

Let's take a quick look at the ResNet50 architecture we are using.

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to
==================================================================================================
input_1 (InputLayer)            (None, None, None, 3 0
__________________________________________________________________________________________________
conv1_pad (ZeroPadding2D)       (None, None, None, 3 0           input_1[0][0]
__________________________________________________________________________________________________
conv1 (Conv2D)                  (None, None, None, 6 9472        conv1_pad[0][0]
__________________________________________________________________________________________________
bn_conv1 (BatchNormalization)   (None, None, None, 6 256         conv1[0][0]
__________________________________________________________________________________________________
activation_1 (Activation)       (None, None, None, 6 0           bn_conv1[0][0]
__________________________________________________________________________________________________
pool1_pad (ZeroPadding2D)       (None, None, None, 6 0           activation_1[0][0]
__________________________________________________________________________________________________
max_pooling2d_1 (MaxPooling2D)  (None, None, None, 6 0           pool1_pad[0][0]
__________________________________________________________________________________________________
res2a_branch2a (Conv2D)         (None, None, None, 6 4160        max_pooling2d_1[0][0]
__________________________________________________________________________________________________
bn2a_branch2a (BatchNormalizati (None, None, None, 6 256         res2a_branch2a[0][0]
__________________________________________________________________________________________________
activation_2 (Activation)       (None, None, None, 6 0           bn2a_branch2a[0][0]
__________________________________________________________________________________________________
res2a_branch2b (Conv2D)         (None, None, None, 6 36928       activation_2[0][0]
__________________________________________________________________________________________________
bn2a_branch2b (BatchNormalizati (None, None, None, 6 256         res2a_branch2b[0][0]
__________________________________________________________________________________________________
activation_3 (Activation)       (None, None, None, 6 0           bn2a_branch2b[0][0]
__________________________________________________________________________________________________
res2a_branch2c (Conv2D)         (None, None, None, 2 16640       activation_3[0][0]
__________________________________________________________________________________________________
res2a_branch1 (Conv2D)          (None, None, None, 2 16640       max_pooling2d_1[0][0]
__________________________________________________________________________________________________
bn2a_branch2c (BatchNormalizati (None, None, None, 2 1024        res2a_branch2c[0][0]
__________________________________________________________________________________________________
bn2a_branch1 (BatchNormalizatio (None, None, None, 2 1024        res2a_branch1[0][0]
__________________________________________________________________________________________________
add_1 (Add)                     (None, None, None, 2 0           bn2a_branch2c[0][0]
                                                                 bn2a_branch1[0][0]
__________________________________________________________________________________________________
activation_4 (Activation)       (None, None, None, 2 0           add_1[0][0]
__________________________________________________________________________________________________
res2b_branch2a (Conv2D)         (None, None, None, 6 16448       activation_4[0][0]
__________________________________________________________________________________________________
bn2b_branch2a (BatchNormalizati (None, None, None, 6 256         res2b_branch2a[0][0]
__________________________________________________________________________________________________
activation_5 (Activation)       (None, None, None, 6 0           bn2b_branch2a[0][0]
__________________________________________________________________________________________________
res2b_branch2b (Conv2D)         (None, None, None, 6 36928       activation_5[0][0]
__________________________________________________________________________________________________
bn2b_branch2b (BatchNormalizati (None, None, None, 6 256         res2b_branch2b[0][0]
__________________________________________________________________________________________________
activation_6 (Activation)       (None, None, None, 6 0           bn2b_branch2b[0][0]
__________________________________________________________________________________________________
res2b_branch2c (Conv2D)         (None, None, None, 2 16640       activation_6[0][0]
__________________________________________________________________________________________________
bn2b_branch2c (BatchNormalizati (None, None, None, 2 1024        res2b_branch2c[0][0]
__________________________________________________________________________________________________
add_2 (Add)                     (None, None, None, 2 0           bn2b_branch2c[0][0]
                                                                 activation_4[0][0]
__________________________________________________________________________________________________
activation_7 (Activation)       (None, None, None, 2 0           add_2[0][0]
__________________________________________________________________________________________________
res2c_branch2a (Conv2D)         (None, None, None, 6 16448       activation_7[0][0]
__________________________________________________________________________________________________
bn2c_branch2a (BatchNormalizati (None, None, None, 6 256         res2c_branch2a[0][0]
__________________________________________________________________________________________________
activation_8 (Activation)       (None, None, None, 6 0           bn2c_branch2a[0][0]
__________________________________________________________________________________________________
res2c_branch2b (Conv2D)         (None, None, None, 6 36928       activation_8[0][0]
__________________________________________________________________________________________________
bn2c_branch2b (BatchNormalizati (None, None, None, 6 256         res2c_branch2b[0][0]
__________________________________________________________________________________________________
activation_9 (Activation)       (None, None, None, 6 0           bn2c_branch2b[0][0]
__________________________________________________________________________________________________
res2c_branch2c (Conv2D)         (None, None, None, 2 16640       activation_9[0][0]
__________________________________________________________________________________________________
bn2c_branch2c (BatchNormalizati (None, None, None, 2 1024        res2c_branch2c[0][0]
__________________________________________________________________________________________________
add_3 (Add)                     (None, None, None, 2 0           bn2c_branch2c[0][0]
                                                                 activation_7[0][0]
__________________________________________________________________________________________________
activation_10 (Activation)      (None, None, None, 2 0           add_3[0][0]
__________________________________________________________________________________________________
res3a_branch2a (Conv2D)         (None, None, None, 1 32896       activation_10[0][0]
__________________________________________________________________________________________________
bn3a_branch2a (BatchNormalizati (None, None, None, 1 512         res3a_branch2a[0][0]
__________________________________________________________________________________________________
activation_11 (Activation)      (None, None, None, 1 0           bn3a_branch2a[0][0]
__________________________________________________________________________________________________
res3a_branch2b (Conv2D)         (None, None, None, 1 147584      activation_11[0][0]
__________________________________________________________________________________________________
bn3a_branch2b (BatchNormalizati (None, None, None, 1 512         res3a_branch2b[0][0]
__________________________________________________________________________________________________
activation_12 (Activation)      (None, None, None, 1 0           bn3a_branch2b[0][0]
__________________________________________________________________________________________________
res3a_branch2c (Conv2D)         (None, None, None, 5 66048       activation_12[0][0]
__________________________________________________________________________________________________
res3a_branch1 (Conv2D)          (None, None, None, 5 131584      activation_10[0][0]
__________________________________________________________________________________________________
bn3a_branch2c (BatchNormalizati (None, None, None, 5 2048        res3a_branch2c[0][0]
__________________________________________________________________________________________________
bn3a_branch1 (BatchNormalizatio (None, None, None, 5 2048        res3a_branch1[0][0]
__________________________________________________________________________________________________
add_4 (Add)                     (None, None, None, 5 0           bn3a_branch2c[0][0]
                                                                 bn3a_branch1[0][0]
__________________________________________________________________________________________________
activation_13 (Activation)      (None, None, None, 5 0           add_4[0][0]
__________________________________________________________________________________________________
res3b_branch2a (Conv2D)         (None, None, None, 1 65664       activation_13[0][0]
__________________________________________________________________________________________________
bn3b_branch2a (BatchNormalizati (None, None, None, 1 512         res3b_branch2a[0][0]
__________________________________________________________________________________________________
activation_14 (Activation)      (None, None, None, 1 0           bn3b_branch2a[0][0]
__________________________________________________________________________________________________
res3b_branch2b (Conv2D)         (None, None, None, 1 147584      activation_14[0][0]
__________________________________________________________________________________________________
bn3b_branch2b (BatchNormalizati (None, None, None, 1 512         res3b_branch2b[0][0]
__________________________________________________________________________________________________
activation_15 (Activation)      (None, None, None, 1 0           bn3b_branch2b[0][0]
__________________________________________________________________________________________________
res3b_branch2c (Conv2D)         (None, None, None, 5 66048       activation_15[0][0]
__________________________________________________________________________________________________
bn3b_branch2c (BatchNormalizati (None, None, None, 5 2048        res3b_branch2c[0][0]
__________________________________________________________________________________________________
add_5 (Add)                     (None, None, None, 5 0           bn3b_branch2c[0][0]
                                                                 activation_13[0][0]
__________________________________________________________________________________________________
activation_16 (Activation)      (None, None, None, 5 0           add_5[0][0]
__________________________________________________________________________________________________
res3c_branch2a (Conv2D)         (None, None, None, 1 65664       activation_16[0][0]
__________________________________________________________________________________________________
bn3c_branch2a (BatchNormalizati (None, None, None, 1 512         res3c_branch2a[0][0]
__________________________________________________________________________________________________
activation_17 (Activation)      (None, None, None, 1 0           bn3c_branch2a[0][0]
__________________________________________________________________________________________________
res3c_branch2b (Conv2D)         (None, None, None, 1 147584      activation_17[0][0]
__________________________________________________________________________________________________
bn3c_branch2b (BatchNormalizati (None, None, None, 1 512         res3c_branch2b[0][0]
__________________________________________________________________________________________________
activation_18 (Activation)      (None, None, None, 1 0           bn3c_branch2b[0][0]
__________________________________________________________________________________________________
res3c_branch2c (Conv2D)         (None, None, None, 5 66048       activation_18[0][0]
__________________________________________________________________________________________________
bn3c_branch2c (BatchNormalizati (None, None, None, 5 2048        res3c_branch2c[0][0]
__________________________________________________________________________________________________
add_6 (Add)                     (None, None, None, 5 0           bn3c_branch2c[0][0]
                                                                 activation_16[0][0]
__________________________________________________________________________________________________
activation_19 (Activation)      (None, None, None, 5 0           add_6[0][0]
__________________________________________________________________________________________________
res3d_branch2a (Conv2D)         (None, None, None, 1 65664       activation_19[0][0]
__________________________________________________________________________________________________
bn3d_branch2a (BatchNormalizati (None, None, None, 1 512         res3d_branch2a[0][0]
__________________________________________________________________________________________________
activation_20 (Activation)      (None, None, None, 1 0           bn3d_branch2a[0][0]
__________________________________________________________________________________________________
res3d_branch2b (Conv2D)         (None, None, None, 1 147584      activation_20[0][0]
__________________________________________________________________________________________________
bn3d_branch2b (BatchNormalizati (None, None, None, 1 512         res3d_branch2b[0][0]
__________________________________________________________________________________________________
activation_21 (Activation)      (None, None, None, 1 0           bn3d_branch2b[0][0]
__________________________________________________________________________________________________
res3d_branch2c (Conv2D)         (None, None, None, 5 66048       activation_21[0][0]
__________________________________________________________________________________________________
bn3d_branch2c (BatchNormalizati (None, None, None, 5 2048        res3d_branch2c[0][0]
__________________________________________________________________________________________________
add_7 (Add)                     (None, None, None, 5 0           bn3d_branch2c[0][0]
                                                                 activation_19[0][0]
__________________________________________________________________________________________________
activation_22 (Activation)      (None, None, None, 5 0           add_7[0][0]
__________________________________________________________________________________________________
res4a_branch2a (Conv2D)         (None, None, None, 2 131328      activation_22[0][0]
__________________________________________________________________________________________________
bn4a_branch2a (BatchNormalizati (None, None, None, 2 1024        res4a_branch2a[0][0]
__________________________________________________________________________________________________
activation_23 (Activation)      (None, None, None, 2 0           bn4a_branch2a[0][0]
__________________________________________________________________________________________________
res4a_branch2b (Conv2D)         (None, None, None, 2 590080      activation_23[0][0]
__________________________________________________________________________________________________
bn4a_branch2b (BatchNormalizati (None, None, None, 2 1024        res4a_branch2b[0][0]
__________________________________________________________________________________________________
activation_24 (Activation)      (None, None, None, 2 0           bn4a_branch2b[0][0]
__________________________________________________________________________________________________
res4a_branch2c (Conv2D)         (None, None, None, 1 263168      activation_24[0][0]
__________________________________________________________________________________________________
res4a_branch1 (Conv2D)          (None, None, None, 1 525312      activation_22[0][0]
__________________________________________________________________________________________________
bn4a_branch2c (BatchNormalizati (None, None, None, 1 4096        res4a_branch2c[0][0]
__________________________________________________________________________________________________
bn4a_branch1 (BatchNormalizatio (None, None, None, 1 4096        res4a_branch1[0][0]
__________________________________________________________________________________________________
add_8 (Add)                     (None, None, None, 1 0           bn4a_branch2c[0][0]
                                                                 bn4a_branch1[0][0]
__________________________________________________________________________________________________
activation_25 (Activation)      (None, None, None, 1 0           add_8[0][0]
__________________________________________________________________________________________________
res4b_branch2a (Conv2D)         (None, None, None, 2 262400      activation_25[0][0]
__________________________________________________________________________________________________
bn4b_branch2a (BatchNormalizati (None, None, None, 2 1024        res4b_branch2a[0][0]
__________________________________________________________________________________________________
activation_26 (Activation)      (None, None, None, 2 0           bn4b_branch2a[0][0]
__________________________________________________________________________________________________
res4b_branch2b (Conv2D)         (None, None, None, 2 590080      activation_26[0][0]
__________________________________________________________________________________________________
bn4b_branch2b (BatchNormalizati (None, None, None, 2 1024        res4b_branch2b[0][0]
__________________________________________________________________________________________________
activation_27 (Activation)      (None, None, None, 2 0           bn4b_branch2b[0][0]
__________________________________________________________________________________________________
res4b_branch2c (Conv2D)         (None, None, None, 1 263168      activation_27[0][0]
__________________________________________________________________________________________________
bn4b_branch2c (BatchNormalizati (None, None, None, 1 4096        res4b_branch2c[0][0]
__________________________________________________________________________________________________
add_9 (Add)                     (None, None, None, 1 0           bn4b_branch2c[0][0]
                                                                 activation_25[0][0]
__________________________________________________________________________________________________
activation_28 (Activation)      (None, None, None, 1 0           add_9[0][0]
__________________________________________________________________________________________________
res4c_branch2a (Conv2D)         (None, None, None, 2 262400      activation_28[0][0]
__________________________________________________________________________________________________
bn4c_branch2a (BatchNormalizati (None, None, None, 2 1024        res4c_branch2a[0][0]
__________________________________________________________________________________________________
activation_29 (Activation)      (None, None, None, 2 0           bn4c_branch2a[0][0]
__________________________________________________________________________________________________
res4c_branch2b (Conv2D)         (None, None, None, 2 590080      activation_29[0][0]
__________________________________________________________________________________________________
bn4c_branch2b (BatchNormalizati (None, None, None, 2 1024        res4c_branch2b[0][0]
__________________________________________________________________________________________________
activation_30 (Activation)      (None, None, None, 2 0           bn4c_branch2b[0][0]
__________________________________________________________________________________________________
res4c_branch2c (Conv2D)         (None, None, None, 1 263168      activation_30[0][0]
__________________________________________________________________________________________________
bn4c_branch2c (BatchNormalizati (None, None, None, 1 4096        res4c_branch2c[0][0]
__________________________________________________________________________________________________
add_10 (Add)                    (None, None, None, 1 0           bn4c_branch2c[0][0]
                                                                 activation_28[0][0]
__________________________________________________________________________________________________
activation_31 (Activation)      (None, None, None, 1 0           add_10[0][0]
__________________________________________________________________________________________________
res4d_branch2a (Conv2D)         (None, None, None, 2 262400      activation_31[0][0]
__________________________________________________________________________________________________
bn4d_branch2a (BatchNormalizati (None, None, None, 2 1024        res4d_branch2a[0][0]
__________________________________________________________________________________________________
activation_32 (Activation)      (None, None, None, 2 0           bn4d_branch2a[0][0]
__________________________________________________________________________________________________
res4d_branch2b (Conv2D)         (None, None, None, 2 590080      activation_32[0][0]
__________________________________________________________________________________________________
bn4d_branch2b (BatchNormalizati (None, None, None, 2 1024        res4d_branch2b[0][0]
__________________________________________________________________________________________________
activation_33 (Activation)      (None, None, None, 2 0           bn4d_branch2b[0][0]
__________________________________________________________________________________________________
res4d_branch2c (Conv2D)         (None, None, None, 1 263168      activation_33[0][0]
__________________________________________________________________________________________________
bn4d_branch2c (BatchNormalizati (None, None, None, 1 4096        res4d_branch2c[0][0]
__________________________________________________________________________________________________
add_11 (Add)                    (None, None, None, 1 0           bn4d_branch2c[0][0]
                                                                 activation_31[0][0]
__________________________________________________________________________________________________
activation_34 (Activation)      (None, None, None, 1 0           add_11[0][0]
__________________________________________________________________________________________________
res4e_branch2a (Conv2D)         (None, None, None, 2 262400      activation_34[0][0]
__________________________________________________________________________________________________
bn4e_branch2a (BatchNormalizati (None, None, None, 2 1024        res4e_branch2a[0][0]
__________________________________________________________________________________________________
activation_35 (Activation)      (None, None, None, 2 0           bn4e_branch2a[0][0]
__________________________________________________________________________________________________
res4e_branch2b (Conv2D)         (None, None, None, 2 590080      activation_35[0][0]
__________________________________________________________________________________________________
bn4e_branch2b (BatchNormalizati (None, None, None, 2 1024        res4e_branch2b[0][0]
__________________________________________________________________________________________________
activation_36 (Activation)      (None, None, None, 2 0           bn4e_branch2b[0][0]
__________________________________________________________________________________________________
res4e_branch2c (Conv2D)         (None, None, None, 1 263168      activation_36[0][0]
__________________________________________________________________________________________________
bn4e_branch2c (BatchNormalizati (None, None, None, 1 4096        res4e_branch2c[0][0]
__________________________________________________________________________________________________
add_12 (Add)                    (None, None, None, 1 0           bn4e_branch2c[0][0]
                                                                 activation_34[0][0]
__________________________________________________________________________________________________
activation_37 (Activation)      (None, None, None, 1 0           add_12[0][0]
__________________________________________________________________________________________________
res4f_branch2a (Conv2D)         (None, None, None, 2 262400      activation_37[0][0]
__________________________________________________________________________________________________
bn4f_branch2a (BatchNormalizati (None, None, None, 2 1024        res4f_branch2a[0][0]
__________________________________________________________________________________________________
activation_38 (Activation)      (None, None, None, 2 0           bn4f_branch2a[0][0]
__________________________________________________________________________________________________
res4f_branch2b (Conv2D)         (None, None, None, 2 590080      activation_38[0][0]
__________________________________________________________________________________________________
bn4f_branch2b (BatchNormalizati (None, None, None, 2 1024        res4f_branch2b[0][0]
__________________________________________________________________________________________________
activation_39 (Activation)      (None, None, None, 2 0           bn4f_branch2b[0][0]
__________________________________________________________________________________________________
res4f_branch2c (Conv2D)         (None, None, None, 1 263168      activation_39[0][0]
__________________________________________________________________________________________________
bn4f_branch2c (BatchNormalizati (None, None, None, 1 4096        res4f_branch2c[0][0]
__________________________________________________________________________________________________
add_13 (Add)                    (None, None, None, 1 0           bn4f_branch2c[0][0]
                                                                 activation_37[0][0]
__________________________________________________________________________________________________
activation_40 (Activation)      (None, None, None, 1 0           add_13[0][0]
__________________________________________________________________________________________________
res5a_branch2a (Conv2D)         (None, None, None, 5 524800      activation_40[0][0]
__________________________________________________________________________________________________
bn5a_branch2a (BatchNormalizati (None, None, None, 5 2048        res5a_branch2a[0][0]
__________________________________________________________________________________________________
activation_41 (Activation)      (None, None, None, 5 0           bn5a_branch2a[0][0]
__________________________________________________________________________________________________
res5a_branch2b (Conv2D)         (None, None, None, 5 2359808     activation_41[0][0]
__________________________________________________________________________________________________
bn5a_branch2b (BatchNormalizati (None, None, None, 5 2048        res5a_branch2b[0][0]
__________________________________________________________________________________________________
activation_42 (Activation)      (None, None, None, 5 0           bn5a_branch2b[0][0]
__________________________________________________________________________________________________
res5a_branch2c (Conv2D)         (None, None, None, 2 1050624     activation_42[0][0]
__________________________________________________________________________________________________
res5a_branch1 (Conv2D)          (None, None, None, 2 2099200     activation_40[0][0]
__________________________________________________________________________________________________
bn5a_branch2c (BatchNormalizati (None, None, None, 2 8192        res5a_branch2c[0][0]
__________________________________________________________________________________________________
bn5a_branch1 (BatchNormalizatio (None, None, None, 2 8192        res5a_branch1[0][0]
__________________________________________________________________________________________________
add_14 (Add)                    (None, None, None, 2 0           bn5a_branch2c[0][0]
                                                                 bn5a_branch1[0][0]
__________________________________________________________________________________________________
activation_43 (Activation)      (None, None, None, 2 0           add_14[0][0]
__________________________________________________________________________________________________
res5b_branch2a (Conv2D)         (None, None, None, 5 1049088     activation_43[0][0]
__________________________________________________________________________________________________
bn5b_branch2a (BatchNormalizati (None, None, None, 5 2048        res5b_branch2a[0][0]
__________________________________________________________________________________________________
activation_44 (Activation)      (None, None, None, 5 0           bn5b_branch2a[0][0]
__________________________________________________________________________________________________
res5b_branch2b (Conv2D)         (None, None, None, 5 2359808     activation_44[0][0]
__________________________________________________________________________________________________
bn5b_branch2b (BatchNormalizati (None, None, None, 5 2048        res5b_branch2b[0][0]
__________________________________________________________________________________________________
activation_45 (Activation)      (None, None, None, 5 0           bn5b_branch2b[0][0]
__________________________________________________________________________________________________
res5b_branch2c (Conv2D)         (None, None, None, 2 1050624     activation_45[0][0]
__________________________________________________________________________________________________
bn5b_branch2c (BatchNormalizati (None, None, None, 2 8192        res5b_branch2c[0][0]
__________________________________________________________________________________________________
add_15 (Add)                    (None, None, None, 2 0           bn5b_branch2c[0][0]
                                                                 activation_43[0][0]
__________________________________________________________________________________________________
activation_46 (Activation)      (None, None, None, 2 0           add_15[0][0]
__________________________________________________________________________________________________
res5c_branch2a (Conv2D)         (None, None, None, 5 1049088     activation_46[0][0]
__________________________________________________________________________________________________
bn5c_branch2a (BatchNormalizati (None, None, None, 5 2048        res5c_branch2a[0][0]
__________________________________________________________________________________________________
activation_47 (Activation)      (None, None, None, 5 0           bn5c_branch2a[0][0]
__________________________________________________________________________________________________
res5c_branch2b (Conv2D)         (None, None, None, 5 2359808     activation_47[0][0]
__________________________________________________________________________________________________
bn5c_branch2b (BatchNormalizati (None, None, None, 5 2048        res5c_branch2b[0][0]
__________________________________________________________________________________________________
activation_48 (Activation)      (None, None, None, 5 0           bn5c_branch2b[0][0]
__________________________________________________________________________________________________
res5c_branch2c (Conv2D)         (None, None, None, 2 1050624     activation_48[0][0]
__________________________________________________________________________________________________
bn5c_branch2c (BatchNormalizati (None, None, None, 2 8192        res5c_branch2c[0][0]
__________________________________________________________________________________________________
add_16 (Add)                    (None, None, None, 2 0           bn5c_branch2c[0][0]
                                                                 activation_46[0][0]
__________________________________________________________________________________________________
activation_49 (Activation)      (None, None, None, 2 0           add_16[0][0]
__________________________________________________________________________________________________
global_average_pooling2d_1 (Glo (None, 2048)         0           activation_49[0][0]
__________________________________________________________________________________________________
dense_1 (Dense)                 (None, 1024)         2098176     global_average_pooling2d_1[0][0]
__________________________________________________________________________________________________
dense_2 (Dense)                 (None, 5)            5125        dense_1[0][0]
==================================================================================================
Total params: 25,691,013
Trainable params: 2,103,301
Non-trainable params: 23,587,712
__________________________________________________________________________________________________

The part that gets retrained has just over 2 million parameters (the two new dense layers), while the earlier layers, which stay frozen, account for about 23.6 million parameters.
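As a quick sanity check, the trainable count in the summary can be reproduced by hand. The sketch below assumes the two dense layers are ordinary fully connected layers with a bias term, which matches the figures printed above; the helper name `dense_params` is only for illustration.

```python
def dense_params(n_in, n_out):
    """Weights plus biases of a fully connected layer."""
    return n_in * n_out + n_out


# Layer sizes read from the summary: 2048 -> 1024 -> 5
dense_1 = dense_params(2048, 1024)
dense_2 = dense_params(1024, 5)

trainable = dense_1 + dense_2
total = 25691013              # "Total params" from the summary
frozen = total - trainable    # should equal "Non-trainable params"

print(dense_1, dense_2, trainable, frozen)
```

The two dense layers together give exactly the trainable figure reported by Keras, and everything else belongs to the frozen base network.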

Split the dataset into 5 parts: 4 parts form the training set and 1 part forms the validation set.

X, y, tags = dataset.dataset(data_directory, n)
nb_classes = len(tags)


sample_count = len(y)
train_size = sample_count * 4 // 5
X_train = X[:train_size]
y_train = y[:train_size]
Y_train = np_utils.to_categorical(y_train, nb_classes)
X_test  = X[train_size:]
y_test  = y[train_size:]
Y_test = np_utils.to_categorical(y_test, nb_classes)
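The same 4/5 split can be sketched with NumPy alone. This is a minimal illustration, not the code used in this post: the function name `split_train_val` and the demo arrays are hypothetical, and the shuffle step is an added assumption that is only needed if the samples are not already in random order.

```python
import numpy as np


def split_train_val(X, y, nb_classes, seed=0):
    """Shuffle, then keep 4/5 for training and 1/5 for validation."""
    rng = np.random.RandomState(seed)
    order = rng.permutation(len(y))
    X, y = X[order], y[order]

    train_size = len(y) * 4 // 5
    one_hot = np.eye(nb_classes, dtype="float32")  # row i is the one-hot vector of class i

    train = (X[:train_size], one_hot[y[:train_size]])
    val = (X[train_size:], one_hot[y[train_size:]])
    return train, val


# Hypothetical toy data: 10 tiny "images", 5 classes
X_demo = np.zeros((10, 4, 4, 3), dtype="float32")
y_demo = np.array([0, 1, 2, 3, 4, 0, 1, 2, 3, 4])

(train_X, train_Y), (val_X, val_Y) = split_train_val(X_demo, y_demo, 5)
print(train_X.shape, train_Y.shape, val_X.shape, val_Y.shape)
```

Indexing an identity matrix with the label array produces the same one-hot encoding that `to_categorical` returns.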

Next, we use ImageDataGenerator to generate more training samples and reduce overfitting. Keras already provides this class.

datagen = ImageDataGenerator(
        featurewise_center=False,
        samplewise_center=False,
        featurewise_std_normalization=False,
        samplewise_std_normalization=False,
        zca_whitening=False,
        rotation_range=45,
        width_shift_range=0.25,
        height_shift_range=0.25,
        horizontal_flip=True,
        vertical_flip=False,
        channel_shift_range=0.5,
        zoom_range=[0.5, 1.5],
        brightness_range=[0.5, 1.5],
        fill_mode='reflect')
        
datagen.fit(X_train)
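To get a feel for what two of these options do, here is a small NumPy sketch of a horizontal flip and of a shift whose gap is filled by mirroring the edge. It only illustrates the idea and is not how Keras implements it internally; the helper names are hypothetical, and NumPy's `symmetric` padding is used because it produces the mirrored pattern that `fill_mode='reflect'` describes.

```python
import numpy as np


def horizontal_flip(img):
    """Mirror an (H, W, C) image left to right."""
    return img[:, ::-1, :]


def shift_right_reflect(img, pixels):
    """Shift an (H, W, C) image right by `pixels`, mirroring the left edge into the gap."""
    if pixels == 0:
        return img.copy()
    padded = np.pad(img, ((0, 0), (pixels, 0), (0, 0)), mode="symmetric")
    return padded[:, :img.shape[1], :]


# A one-row "image" with pixel values 0 1 2 3
img = np.arange(4, dtype="float32").reshape(1, 4, 1)

flipped = horizontal_flip(img)
shifted = shift_right_reflect(img, 2)

print(flipped[0, :, 0])
print(shifted[0, :, 0])
```

Each augmented copy shows the same flower in a slightly different pose, so the network sees more variety than the raw images alone provide.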

Finally, we build the model, train it, and save it. This step takes a fair amount of time.


model = net.build_model(nb_classes)
model.compile(optimizer='rmsprop', loss='categorical_crossentropy', metrics=["accuracy"])

# train the model on the new data for a few epochs

print("training the newly added dense layers")

# Both step counts are numbers of batches, not numbers of samples.
steps_per_epoch = X_train.shape[0] // batch_size
validation_steps = X_test.shape[0] // batch_size

# samples_per_epoch is the Keras 1 name for the same setting, so only steps_per_epoch is passed.
model.fit_generator(datagen.flow(X_train, Y_train, batch_size=batch_size, shuffle=True),
            epochs=nb_epoch,
            steps_per_epoch=steps_per_epoch,
            validation_data=datagen.flow(X_test, Y_test, batch_size=batch_size),
            validation_steps=validation_steps,
            )


net.save(model, tags, model_file_prefix)
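When training from a generator, an epoch is measured in batches rather than samples. The sketch below shows the arithmetic with made-up numbers; the sample counts, the batch size, and the helper `steps_for` are all hypothetical and are not taken from the actual run.

```python
def steps_for(n_samples, batch_size):
    """Number of full batches in n_samples (the remainder is dropped)."""
    return n_samples // batch_size


# Hypothetical sizes, for illustration only
n_train, n_val, batch = 800, 200, 32

train_steps = steps_for(n_train, batch)
val_steps = steps_for(n_val, batch)

print(train_steps, val_steps)
```

Passing a sample count where a step count is expected would make each validation pass roughly `batch_size` times longer than intended, since the generator keeps yielding batches indefinitely.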

Let's download a few images from the internet and test the model on them.

Image

The results are quite good, aren't they?

Thank you for following along. See you in the next posts.

