Flower Classification Using a Pretrained Model

Introduction

In this post, we use the dataset available at https://www.kaggle.com/alxmamaev/flowers-recognition. It contains 4242 images of five kinds of flowers: rose, sunflower, dandelion, daisy, and tulip. The dataset authors collected the images from Flickr, Google Images, and Yandex. The images are divided into five classes of roughly 800 images each. Most images are around 320x320 pixels, but the sizes are not uniform.

Implementation

After extraction, the data has the following layout:

data_dir/classname1/*.*
data_dir/classname2/*.*
...
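To make the layout concrete, here is a minimal sketch of how class labels can be read from such a tree, where each sub-directory name is treated as a label. The helper name `list_classes` and the throwaway folders are illustrative assumptions, not part of the original code.

```python
import tempfile
from pathlib import Path


def list_classes(data_dir):
    """Map each class sub-directory name to the sorted list of files inside it."""
    root = Path(data_dir)
    classes = {}
    for class_dir in sorted(p for p in root.iterdir() if p.is_dir()):
        classes[class_dir.name] = sorted(f for f in class_dir.iterdir() if f.is_file())
    return classes


# Build a tiny temporary tree that mimics data_dir/classname/*.*
with tempfile.TemporaryDirectory() as tmp:
    for name in ("daisy", "rose"):
        (Path(tmp) / name).mkdir()
        (Path(tmp) / name / "example.jpg").touch()
    for label, files in list_classes(tmp).items():
        print(label, len(files))
```

Sorting the directory names gives every class a stable index, which is the same idea the loader below relies on when it sorts its labels.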

This storage layout already matches what our loading code expects, so we leave the structure unchanged and move on to writing the code.

First, we load the dataset and transform it so it can be fed into training.

import os
from collections import defaultdict

import numpy as np
# NOTE: scipy.misc.imread and scipy.misc.imresize (used below) exist only in
# older SciPy releases; they have been removed from newer ones.
import scipy.misc


def preprocess_input(x0):
    # Scale pixel values from [0, 255] to [-1, 1].
    x = x0 / 255.
    x -= 0.5
    x *= 2.
    return x


def reverse_preprocess_input(x0):
    # Inverse of preprocess_input: map [-1, 1] back to [0, 255].
    x = x0 / 2.0
    x += 0.5
    x *= 255.
    return x


def dataset(base_dir, n):
    # Load every image under base_dir, center-crop it to a square,
    # resize it to n x n, and return (X, y, tags).
    print("base dir: "+base_dir)
    print("n: "+str(n))
    n = int(n)
    d = defaultdict(list)
    for root, subdirs, files in os.walk(base_dir):
        for filename in files:
            file_path = os.path.join(root, filename)
            assert file_path.startswith(base_dir)

            # The first path component below base_dir is the class label.
            suffix = file_path[len(base_dir):]

            suffix = suffix.lstrip("/")
            suffix = suffix.lstrip("\\")
            if(suffix.find('/')>-1): # Linux
                label = suffix.split("/")[0]
            else: # Windows
                label = suffix.split("\\")[0]
            d[label].append(file_path)
    print("walk directory complete")
    tags = sorted(d.keys())

    processed_image_count = 0
    useful_image_count = 0

    X = []
    y = []

    for class_index, class_name in enumerate(tags):
        filenames = d[class_name]
        for filename in filenames:
            processed_image_count += 1
            if processed_image_count%100 ==0:
                print(class_name+"\tprocess: "+str(processed_image_count)+"\t"+str(len(d[class_name])))
            img = scipy.misc.imread(filename)
            height, width, chan = img.shape
            assert chan == 3
            # Skip images that are too elongated.
            aspect_ratio = float(max((height, width))) / min((height, width))
            if aspect_ratio > 2:
                continue
            # We pick the largest center square.
            centery = height // 2
            centerx = width // 2
            radius = min((centerx, centery))
            img = img[centery-radius:centery+radius, centerx-radius:centerx+radius]
            img = scipy.misc.imresize(img, size=(n, n), interp='bilinear')
            X.append(img)
            y.append(class_index)
            useful_image_count += 1
    print("processed %d, used %d" % (processed_image_count, useful_image_count))

    X = np.array(X).astype(np.float32)
    #X = X.transpose((0, 3, 1, 2))
    X = preprocess_input(X)
    y = np.array(y)

    # Shuffle images and labels with the same permutation.
    perm = np.random.permutation(len(y))
    X = X[perm]
    y = y[perm]

    print("classes:",end=" ")
    for class_index, class_name in enumerate(tags):
        print(class_name, sum(y==class_index),end=" ")
    print("X shape: ",X.shape)

    return X, y, tags

The code above is fairly simple and easy to follow. Note that any image whose longer side is more than twice its shorter side is dropped from the dataset.
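As a small worked example of the two rules used by the loader, the sketch below computes the centered square crop box (or rejects the image when the aspect ratio exceeds 2) and applies the same pixel scaling to single values. The function names are hypothetical and chosen for illustration; the arithmetic mirrors the loader code.

```python
def center_square_box(height, width, max_aspect_ratio=2.0):
    """Return (top, bottom, left, right) of the largest centered square,
    or None when the image is too elongated to keep."""
    aspect_ratio = max(height, width) / min(height, width)
    if aspect_ratio > max_aspect_ratio:
        return None
    center_y, center_x = height // 2, width // 2
    radius = min(center_x, center_y)
    return (center_y - radius, center_y + radius,
            center_x - radius, center_x + radius)


def scale_pixel(value):
    """Map a pixel value from [0, 255] to [-1, 1]."""
    return (value / 255.0 - 0.5) * 2.0


def unscale_pixel(value):
    """Inverse mapping, from [-1, 1] back to [0, 255]."""
    return (value / 2.0 + 0.5) * 255.0


print(center_square_box(240, 320))       # (0, 240, 40, 280)
print(center_square_box(100, 300))       # None, aspect ratio is 3
print(scale_pixel(0), scale_pixel(255))  # -1.0 1.0
```

A 240x320 image keeps its full height and loses 40 columns on each side, while a 100x300 image is skipped entirely.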

Next, we build our model on top of the ResNet50 model that ships with Keras. Because we are using a pretrained model, the existing ResNet50 layers are not trained; we reuse their weights, already trained on ImageNet, to extract features. We only need to add a fully connected layer and a softmax layer to classify the flower types, so the remaining work is to learn the weights of these newly added top layers instead of retraining the whole model.

from keras.applications.resnet50 import ResNet50
from keras.layers import Dense, GlobalAveragePooling2D
from keras.models import Model


# create the base pre-trained model
def build_model(nb_classes):
    base_model = ResNet50(weights='imagenet', include_top=False)

    # add a global spatial average pooling layer
    x = base_model.output
    x = GlobalAveragePooling2D()(x)
    # let's add a fully-connected layer
    x = Dense(1024, activation='relu')(x)
    # and a logistic layer
    predictions = Dense(nb_classes, activation='softmax')(x)

    # this is the model we will train
    model = Model(inputs=base_model.input, outputs=predictions)

    # first: train only the top layers (which were randomly initialized)
    # i.e. freeze all convolutional ResNet50 layers
    for layer in base_model.layers:
        layer.trainable = False

    return model
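To get a feel for how little is actually trained, the sketch below counts the parameters of the two new Dense layers. It assumes the ResNet50 feature map has 2048 channels after global average pooling and that there are five classes; `dense_params` is an illustrative helper, not a Keras function.

```python
def dense_params(n_inputs, n_units):
    """Number of weights plus biases in a fully connected layer."""
    return n_inputs * n_units + n_units


RESNET50_FEATURES = 2048  # assumed channel count of ResNet50's final feature map
HIDDEN_UNITS = 1024       # Dense(1024, activation='relu')
NB_CLASSES = 5            # rose, sunflower, dandelion, daisy, tulip

hidden = dense_params(RESNET50_FEATURES, HIDDEN_UNITS)
output = dense_params(HIDDEN_UNITS, NB_CLASSES)
print(hidden, output, hidden + output)  # 2098176 5125 2103301
```

Global average pooling has no parameters of its own, so under these assumptions roughly 2.1 million weights are learned while the frozen convolutional layers stay fixed.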

Here is a quick look at the ResNet50 architecture we are using.

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to
==================================================================================================
input_1 (InputLayer)            (None, None, None, 3 0
__________________________________________________________________________________________________
conv1_pad (ZeroPadding2D)       (None, None, None, 3 0           input_1[0][0]
__________________________________________________________________________________________________
conv1 (Conv2D)                  (None, None, None, 6 9472        conv1_pad[0][0]
__________________________________________________________________________________________________
bn_conv1 (BatchNormalization)   (None, None, None, 6 256         conv1[0][0]
__________________________________________________________________________________________________
activation_1 (Activation)       (None, None, None, 6 0           bn_conv1[0][0]
__________________________________________________________________________________________________
pool1_pad (ZeroPadding2D)       (None, None, None, 6 0           activation_1[0][0]
__________________________________________________________________________________________________
max_pooling2d_1 (MaxPooling2D)  (None, None, None, 6 0           pool1_pad[0][0]
__________________________________________________________________________________________________
res2a_branch2a (Conv2D)         (None, None, None, 6 4160        max_pooling2d_1[0][0]
__________________________________________________________________________________________________
bn2a_branch2a (BatchNormalizati (None, None, None, 6 256         res2a_branch2a[0][0]
__________________________________________________________________________________________________
activation_2 (Activation)       (None, None, None, 6 0           bn2a_branch2a[0][0]
__________________________________________________________________________________________________
res2a_branch2b (Conv2D)         (None, None, None, 6 36928       activation_2[0][0]
__________________________________________________________________________________________________
bn2a_branch2b (BatchNormalizati (None, None, None, 6 256         res2a_branch2b[0][0]
__________________________________________________________________________________________________
activation_3 (Activation)       (None, None, None, 6 0           bn2a_branch2b[0][0]
__________________________________________________________________________________________________
res2a_branch2c (Conv2D)         (None, None, None, 2 16640       activation_3[0][0]
__________________________________________________________________________________________________
res2a_branch1 (Conv2D)          (None, None, None, 2 16640       max_pooling2d_1[0][0]
__________________________________________________________________________________________________
bn2a_branch2c (BatchNormalizati (None, None, None, 2 1024        res2a_branch2c[0][0]
__________________________________________________________________________________________________
bn2a_branch1 (BatchNormalizatio (None, None, None, 2 1024        res2a_branch1[0][0]
__________________________________________________________________________________________________
add_1 (Add)                     (None, None, None, 2 0           bn2a_branch2c[0][0]
                                                                 bn2a_branch1[0][0]
__________________________________________________________________________________________________
activation_4 (Activation)       (None, None, None, 2 0           add_1[0][0]
__________________________________________________________________________________________________
res2b_branch2a (Conv2D)         (None, None, None, 6 16448       activation_4[0][0]
__________________________________________________________________________________________________
bn2b_branch2a (BatchNormalizati (None, None, None, 6 256         res2b_branch2a[0][0]
__________________________________________________________________________________________________
activation_5 (Activation)       (None, None, None, 6 0           bn2b_branch2a[0][0]
__________________________________________________________________________________________________
res2b_branch2b (Conv2D)         (None, None, None, 6 36928       activation_5[0][0]
__________________________________________________________________________________________________
bn2b_branch2b (BatchNormalizati (None, None, None, 6 256         res2b_branch2b[0][0]
__________________________________________________________________________________________________
activation_6 (Activation)       (None, None, None, 6 0           bn2b_branch2b[0][0]
__________________________________________________________________________________________________
res2b_branch2c (Conv2D)         (None, None, None, 2 16640       activation_6[0][0]
__________________________________________________________________________________________________
bn2b_branch2c (BatchNormalizati (None, None, None, 2 1024        res2b_branch2c[0][0]
__________________________________________________________________________________________________
add_2 (Add)                     (None, None, None, 2 0           bn2b_branch2c[0][0]
                                                                 activation_4[0][0]
__________________________________________________________________________________________________
activation_7 (Activation)       (None, None, None, 2 0           add_2[0][0]
__________________________________________________________________________________________________
res2c_branch2a (Conv2D)         (None, None, None, 6 16448       activation_7[0][0]
__________________________________________________________________________________________________
bn2c_branch2a (BatchNormalizati (None, None, None, 6 256         res2c_branch2a[0][0]
__________________________________________________________________________________________________
activation_8 (Activation)       (None, None, None, 6 0           bn2c_branch2a[0][0]
__________________________________________________________________________________________________
res2c_branch2b (Conv2D)         (None, None, None, 6 36928       activation_8[0][0]
__________________________________________________________________________________________________
bn2c_branch2b (BatchNormalizati (None, None, None, 6 256         res2c_branch2b[0][0]
__________________________________________________________________________________________________
activation_9 (Activation)       (None, None, None, 6 0           bn2c_branch2b[0][0]
__________________________________________________________________________________________________
res2c_branch2c (Conv2D)         (None, None, None, 2 16640       activation_9[0][0]
__________________________________________________________________________________________________
bn2c_branch2c (BatchNormalizati (None, None, None, 2 1024        res2c_branch2c[0][0]
__________________________________________________________________________________________________
add_3 (Add)                     (None, None, None, 2 0           bn2c_branch2c[0][0]
                                                                 activation_7[0][0]
__________________________________________________________________________________________________
activation_10 (Activation)      (None, None, None, 2 0           add_3[0][0]
__________________________________________________________________________________________________
res3a_branch2a (Conv2D)         (None, None, None, 1 32896       activation_10[0][0]
__________________________________________________________________________________________________
bn3a_branch2a (BatchNormalizati (None, None, None, 1 512         res3a_branch2a[0][0]
__________________________________________________________________________________________________
activation_11 (Activation)      (None, None, None, 1 0           bn3a_branch2a[0][0]
__________________________________________________________________________________________________
res3a_branch2b (Conv2D)         (None, None, None, 1 147584      activation_11[0][0]
__________________________________________________________________________________________________
bn3a_branch2b (BatchNormalizati (None, None, None, 1 512         res3a_branch2b[0][0]
__________________________________________________________________________________________________
activation_12 (Activation)      (None, None, None, 1 0           bn3a_branch2b[0][0]
__________________________________________________________________________________________________
res3a_branch2c (Conv2D)         (None, None, None, 5 66048       activation_12[0][0]
__________________________________________________________________________________________________
res3a_branch1 (Conv2D)          (None, None, None, 5 131584      activation_10[0][0]
__________________________________________________________________________________________________
bn3a_branch2c (BatchNormalizati (None, None, None, 5 2048        res3a_branch2c[0][0]
__________________________________________________________________________________________________
bn3a_branch1 (BatchNormalizatio (None, None, None, 5 2048        res3a_branch1[0][0]
__________________________________________________________________________________________________
add_4 (Add)                     (None, None, None, 5 0           bn3a_branch2c[0][0]
                                                                 bn3a_branch1[0][0]
__________________________________________________________________________________________________
activation_13 (Activation)      (None, None, None, 5 0           add_4[0][0]
__________________________________________________________________________________________________
res3b_branch2a (Conv2D)         (None, None, None, 1 65664       activation_13[0][0]
__________________________________________________________________________________________________
bn3b_branch2a (BatchNormalizati (None, None, None, 1 512         res3b_branch2a[0][0]
__________________________________________________________________________________________________
activation_14 (Activation)      (None, None, None, 1 0           bn3b_branch2a[0][0]
__________________________________________________________________________________________________
res3b_branch2b (Conv2D)         (None, None, None, 1 147584      activation_14[0][0]
__________________________________________________________________________________________________
bn3b_branch2b (BatchNormalizati (None, None, None, 1 512         res3b_branch2b[0][0]
__________________________________________________________________________________________________
activation_15 (Activation)      (None, None, None, 1 0           bn3b_branch2b[0][0]
__________________________________________________________________________________________________
res3b_branch2c (Conv2D)         (None, None, None, 5 66048       activation_15[0][0]
__________________________________________________________________________________________________
bn3b_branch2c (BatchNormalizati (None, None, None, 5 2048        res3b_branch2c[0][0]
__________________________________________________________________________________________________
add_5 (Add)                     (None, None, None, 5 0           bn3b_branch2c[0][0]
                                                                 activation_13[0][0]
__________________________________________________________________________________________________
activation_16 (Activation)      (None, None, None, 5 0           add_5[0][0]
__________________________________________________________________________________________________
res3c_branch2a (Conv2D)         (None, None, None, 1 65664       activation_16[0][0]
__________________________________________________________________________________________________
bn3c_branch2a (BatchNormalizati (None, None, None, 1 512         res3c_branch2a[0][0]
__________________________________________________________________________________________________
activation_17 (Activation)      (None, None, None, 1 0           bn3c_branch2a[0][0]
__________________________________________________________________________________________________
res3c_branch2b (Conv2D)         (None, None, None, 1 147584      activation_17[0][0]
__________________________________________________________________________________________________
bn3c_branch2b (BatchNormalizati (None, None, None, 1 512         res3c_branch2b[0][0]
__________________________________________________________________________________________________
activation_18 (Activation)      (None, None, None, 1 0           bn3c_branch2b[0][0]
__________________________________________________________________________________________________
res3c_branch2c (Conv2D)         (None, None, None, 5 66048       activation_18[0][0]
__________________________________________________________________________________________________
bn3c_branch2c (BatchNormalizati (None, None, None, 5 2048        res3c_branch2c[0][0]
__________________________________________________________________________________________________
add_6 (Add)                     (None, None, None, 5 0           bn3c_branch2c[0][0]
                                                                 activation_16[0][0]
__________________________________________________________________________________________________
activation_19 (Activation)      (None, None, None, 5 0           add_6[0][0]
__________________________________________________________________________________________________
res3d_branch2a (Conv2D)         (None, None, None, 1 65664       activation_19[0][0]
__________________________________________________________________________________________________
bn3d_branch2a (BatchNormalizati (None, None, None, 1 512         res3d_branch2a[0][0]
__________________________________________________________________________________________________
activation_20 (Activation)      (None, None, None, 1 0           bn3d_branch2a[0][0]
__________________________________________________________________________________________________
res3d_branch2b (Conv2D)         (None, None, None, 1 147584      activation_20[0][0]
__________________________________________________________________________________________________
bn3d_branch2b (BatchNormalizati (None, None, None, 1 512         res3d_branch2b[0][0]
__________________________________________________________________________________________________
activation_21 (Activation)      (None, None, None, 1 0           bn3d_branch2b[0][0]
__________________________________________________________________________________________________
res3d_branch2c (Conv2D)         (None, None, None, 5 66048       activation_21[0][0]
__________________________________________________________________________________________________
bn3d_branch2c (BatchNormalizati (None, None, None, 5 2048        res3d_branch2c[0][0]
__________________________________________________________________________________________________
add_7 (Add)                     (None, None, None, 5 0           bn3d_branch2c[0][0]
                                                                 activation_19[0][0]
__________________________________________________________________________________________________
activation_22 (Activation)      (None, None, None, 5 0           add_7[0][0]
__________________________________________________________________________________________________
res4a_branch2a (Conv2D)         (None, None, None, 2 131328      activation_22[0][0]
__________________________________________________________________________________________________
bn4a_branch2a (BatchNormalizati (None, None, None, 2 1024        res4a_branch2a[0][0]
__________________________________________________________________________________________________
activation_23 (Activation)      (None, None, None, 2 0           bn4a_branch2a[0][0]
__________________________________________________________________________________________________
res4a_branch2b (Conv2D)         (None, None, None, 2 590080      activation_23[0][0]
__________________________________________________________________________________________________
bn4a_branch2b (BatchNormalizati (None, None, None, 2 1024        res4a_branch2b[0][0]
__________________________________________________________________________________________________
activation_24 (Activation)      (None, None, None, 2 0           bn4a_branch2b[0][0]
__________________________________________________________________________________________________
res4a_branch2c (Conv2D)         (None, None, None, 1 263168      activation_24[0][0]
__________________________________________________________________________________________________
res4a_branch1 (Conv2D)          (None, None, None, 1 525312      activation_22[0][0]
__________________________________________________________________________________________________
bn4a_branch2c (BatchNormalizati (None, None, None, 1 4096        res4a_branch2c[0][0]
__________________________________________________________________________________________________
bn4a_branch1 (BatchNormalizatio (None, None, None, 1 4096        res4a_branch1[0][0]
__________________________________________________________________________________________________
add_8 (Add)                     (None, None, None, 1 0           bn4a_branch2c[0][0]
                                                                 bn4a_branch1[0][0]
__________________________________________________________________________________________________
activation_25 (Activation)      (None, None, None, 1 0           add_8[0][0]
__________________________________________________________________________________________________
res4b_branch2a (Conv2D)         (None, None, None, 2 262400      activation_25[0][0]
__________________________________________________________________________________________________
bn4b_branch2a (BatchNormalizati (None, None, None, 2 1024        res4b_branch2a[0][0]
__________________________________________________________________________________________________
activation_26 (Activation)      (None, None, None, 2 0           bn4b_branch2a[0][0]
__________________________________________________________________________________________________
res4b_branch2b (Conv2D)         (None, None, None, 2 590080      activation_26[0][0]
__________________________________________________________________________________________________
bn4b_branch2b (BatchNormalizati (None, None, None, 2 1024        res4b_branch2b[0][0]
__________________________________________________________________________________________________
activation_27 (Activation)      (None, None, None, 2 0           bn4b_branch2b[0][0]
__________________________________________________________________________________________________
res4b_branch2c (Conv2D)         (None, None, None, 1 263168      activation_27[0][0]
__________________________________________________________________________________________________
bn4b_branch2c (BatchNormalizati (None, None, None, 1 4096        res4b_branch2c[0][0]
__________________________________________________________________________________________________
add_9 (Add)                     (None, None, None, 1 0           bn4b_branch2c[0][0]
                                                                 activation_25[0][0]
__________________________________________________________________________________________________
activation_28 (Activation)      (None, None, None, 1 0           add_9[0][0]
__________________________________________________________________________________________________
res4c_branch2a (Conv2D)         (None, None, None, 2 262400      activation_28[0][0]
__________________________________________________________________________________________________
bn4c_branch2a (BatchNormalizati (None, None, None, 2 1024        res4c_branch2a[0][0]
__________________________________________________________________________________________________
activation_29 (Activation)      (None, None, None, 2 0           bn4c_branch2a[0][0]
__________________________________________________________________________________________________
res4c_branch2b (Conv2D)         (None, None, None, 2 590080      activation_29[0][0]
__________________________________________________________________________________________________
bn4c_branch2b (BatchNormalizati (None, None, None, 2 1024        res4c_branch2b[0][0]
__________________________________________________________________________________________________
activation_30 (Activation)      (None, None, None, 2 0           bn4c_branch2b[0][0]
__________________________________________________________________________________________________
res4c_branch2c (Conv2D)         (None, None, None, 1 263168      activation_30[0][0]
__________________________________________________________________________________________________
bn4c_branch2c (BatchNormalizati (None, None, None, 1 4096        res4c_branch2c[0][0]
__________________________________________________________________________________________________
add_10 (Add)                    (None, None, None, 1 0           bn4c_branch2c[0][0]
                                                                 activation_28[0][0]
__________________________________________________________________________________________________
activation_31 (Activation)      (None, None, None, 1 0           add_10[0][0]
__________________________________________________________________________________________________
res4d_branch2a (Conv2D)         (None, None, None, 2 262400      activation_31[0][0]
__________________________________________________________________________________________________
bn4d_branch2a (BatchNormalizati (None, None, None, 2 1024        res4d_branch2a[0][0]
__________________________________________________________________________________________________
activation_32 (Activation)      (None, None, None, 2 0           bn4d_branch2a[0][0]
__________________________________________________________________________________________________
res4d_branch2b (Conv2D)         (None, None, None, 2 590080      activation_32[0][0]
__________________________________________________________________________________________________
bn4d_branch2b (BatchNormalizati (None, None, None, 2 1024        res4d_branch2b[0][0]
__________________________________________________________________________________________________
activation_33 (Activation)      (None, None, None, 2 0           bn4d_branch2b[0][0]
__________________________________________________________________________________________________
res4d_branch2c (Conv2D)         (None, None, None, 1 263168      activation_33[0][0]
__________________________________________________________________________________________________
bn4d_branch2c (BatchNormalizati (None, None, None, 1 4096        res4d_branch2c[0][0]
__________________________________________________________________________________________________
add_11 (Add)                    (None, None, None, 1 0           bn4d_branch2c[0][0]
                                                                 activation_31[0][0]
__________________________________________________________________________________________________
activation_34 (Activation)      (None, None, None, 1 0           add_11[0][0]
__________________________________________________________________________________________________
res4e_branch2a (Conv2D)         (None, None, None, 2 262400      activation_34[0][0]
__________________________________________________________________________________________________
bn4e_branch2a (BatchNormalizati (None, None, None, 2 1024        res4e_branch2a[0][0]
__________________________________________________________________________________________________
activation_35 (Activation)      (None, None, None, 2 0           bn4e_branch2a[0][0]
__________________________________________________________________________________________________
res4e_branch2b (Conv2D)         (None, None, None, 2 590080      activation_35[0][0]
__________________________________________________________________________________________________
bn4e_branch2b (BatchNormalizati (None, None, None, 2 1024        res4e_branch2b[0][0]
__________________________________________________________________________________________________
activation_36 (Activation)      (None, None, None, 2 0           bn4e_branch2b[0][0]
__________________________________________________________________________________________________
res4e_branch2c (Conv2D)         (None, None, None, 1 263168      activation_36[0][0]
__________________________________________________________________________________________________
bn4e_branch2c (BatchNormalizati (None, None, None, 1 4096        res4e_branch2c[0][0]
__________________________________________________________________________________________________
add_12 (Add)                    (None, None, None, 1 0           bn4e_branch2c[0][0]
                                                                 activation_34[0][0]
__________________________________________________________________________________________________
activation_37 (Activation)      (None, None, None, 1 0           add_12[0][0]
__________________________________________________________________________________________________
res4f_branch2a (Conv2D)         (None, None, None, 2 262400      activation_37[0][0]
__________________________________________________________________________________________________
bn4f_branch2a (BatchNormalizati (None, None, None, 2 1024        res4f_branch2a[0][0]
__________________________________________________________________________________________________
activation_38 (Activation)      (None, None, None, 2 0           bn4f_branch2a[0][0]
__________________________________________________________________________________________________
res4f_branch2b (Conv2D)         (None, None, None, 2 590080      activation_38[0][0]
__________________________________________________________________________________________________
bn4f_branch2b (BatchNormalizati (None, None, None, 2 1024        res4f_branch2b[0][0]
__________________________________________________________________________________________________
activation_39 (Activation)      (None, None, None, 2 0           bn4f_branch2b[0][0]
__________________________________________________________________________________________________
res4f_branch2c (Conv2D)         (None, None, None, 1 263168      activation_39[0][0]
__________________________________________________________________________________________________
bn4f_branch2c (BatchNormalizati (None, None, None, 1 4096        res4f_branch2c[0][0]
__________________________________________________________________________________________________
add_13 (Add)                    (None, None, None, 1 0           bn4f_branch2c[0][0]
                                                                 activation_37[0][0]
__________________________________________________________________________________________________
activation_40 (Activation)      (None, None, None, 1 0           add_13[0][0]
__________________________________________________________________________________________________
res5a_branch2a (Conv2D)         (None, None, None, 5 524800      activation_40[0][0]
__________________________________________________________________________________________________
bn5a_branch2a (BatchNormalizati (None, None, None, 5 2048        res5a_branch2a[0][0]
__________________________________________________________________________________________________
activation_41 (Activation)      (None, None, None, 5 0           bn5a_branch2a[0][0]
__________________________________________________________________________________________________
res5a_branch2b (Conv2D)         (None, None, None, 5 2359808     activation_41[0][0]
310__________________________________________________________________________________________________
311bn5a_branch2b (BatchNormalizati (None, None, None, 5 2048        res5a_branch2b[0][0]
312__________________________________________________________________________________________________
313activation_42 (Activation)      (None, None, None, 5 0           bn5a_branch2b[0][0]
314__________________________________________________________________________________________________
315res5a_branch2c (Conv2D)         (None, None, None, 2 1050624     activation_42[0][0]
316__________________________________________________________________________________________________
317res5a_branch1 (Conv2D)          (None, None, None, 2 2099200     activation_40[0][0]
318__________________________________________________________________________________________________
319bn5a_branch2c (BatchNormalizati (None, None, None, 2 8192        res5a_branch2c[0][0]
320__________________________________________________________________________________________________
321bn5a_branch1 (BatchNormalizatio (None, None, None, 2 8192        res5a_branch1[0][0]
322__________________________________________________________________________________________________
323add_14 (Add)                    (None, None, None, 2 0           bn5a_branch2c[0][0]
324                                                                 bn5a_branch1[0][0]
325__________________________________________________________________________________________________
326activation_43 (Activation)      (None, None, None, 2 0           add_14[0][0]
327__________________________________________________________________________________________________
328res5b_branch2a (Conv2D)         (None, None, None, 5 1049088     activation_43[0][0]
329__________________________________________________________________________________________________
330bn5b_branch2a (BatchNormalizati (None, None, None, 5 2048        res5b_branch2a[0][0]
331__________________________________________________________________________________________________
332activation_44 (Activation)      (None, None, None, 5 0           bn5b_branch2a[0][0]
333__________________________________________________________________________________________________
334res5b_branch2b (Conv2D)         (None, None, None, 5 2359808     activation_44[0][0]
335__________________________________________________________________________________________________
336bn5b_branch2b (BatchNormalizati (None, None, None, 5 2048        res5b_branch2b[0][0]
337__________________________________________________________________________________________________
338activation_45 (Activation)      (None, None, None, 5 0           bn5b_branch2b[0][0]
339__________________________________________________________________________________________________
340res5b_branch2c (Conv2D)         (None, None, None, 2 1050624     activation_45[0][0]
341__________________________________________________________________________________________________
342bn5b_branch2c (BatchNormalizati (None, None, None, 2 8192        res5b_branch2c[0][0]
343__________________________________________________________________________________________________
344add_15 (Add)                    (None, None, None, 2 0           bn5b_branch2c[0][0]
345                                                                 activation_43[0][0]
346__________________________________________________________________________________________________
347activation_46 (Activation)      (None, None, None, 2 0           add_15[0][0]
348__________________________________________________________________________________________________
349res5c_branch2a (Conv2D)         (None, None, None, 5 1049088     activation_46[0][0]
350__________________________________________________________________________________________________
351bn5c_branch2a (BatchNormalizati (None, None, None, 5 2048        res5c_branch2a[0][0]
352__________________________________________________________________________________________________
353activation_47 (Activation)      (None, None, None, 5 0           bn5c_branch2a[0][0]
354__________________________________________________________________________________________________
355res5c_branch2b (Conv2D)         (None, None, None, 5 2359808     activation_47[0][0]
356__________________________________________________________________________________________________
357bn5c_branch2b (BatchNormalizati (None, None, None, 5 2048        res5c_branch2b[0][0]
358__________________________________________________________________________________________________
359activation_48 (Activation)      (None, None, None, 5 0           bn5c_branch2b[0][0]
360__________________________________________________________________________________________________
361res5c_branch2c (Conv2D)         (None, None, None, 2 1050624     activation_48[0][0]
362__________________________________________________________________________________________________
363bn5c_branch2c (BatchNormalizati (None, None, None, 2 8192        res5c_branch2c[0][0]
364__________________________________________________________________________________________________
365add_16 (Add)                    (None, None, None, 2 0           bn5c_branch2c[0][0]
366                                                                 activation_46[0][0]
367__________________________________________________________________________________________________
368activation_49 (Activation)      (None, None, None, 2 0           add_16[0][0]
369__________________________________________________________________________________________________
370global_average_pooling2d_1 (Glo (None, 2048)         0           activation_49[0][0]
371__________________________________________________________________________________________________
372dense_1 (Dense)                 (None, 1024)         2098176     global_average_pooling2d_1[0][0]
373__________________________________________________________________________________________________
374dense_2 (Dense)                 (None, 5)            5125        dense_1[0][0]
375==================================================================================================
376Total params: 25,691,013
377Trainable params: 2,103,301
378Non-trainable params: 23,587,712
379__________________________________________________________________________________________________

The part we retrain has just over 2 million parameters (the two new dense layers, 2,103,301 in total), while the frozen layers before it hold about 23.6 million parameters that are not trained.
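As a quick sanity check on the summary, the trainable count can be reproduced by hand. The sketch below assumes only the standard rule that a dense layer has one weight per input-unit pair plus one bias per unit; the helper name `dense_params` is made up for this illustration.

```python
def dense_params(n_inputs, n_units):
    """Parameter count of a fully connected layer: weights plus biases."""
    return n_inputs * n_units + n_units


# Shapes taken from the last two rows of the summary table.
dense_1 = dense_params(2048, 1024)  # pooled ResNet features -> 1024 units
dense_2 = dense_params(1024, 5)     # 1024 units -> 5 flower classes

trainable = dense_1 + dense_2
total = 25691013                    # "Total params" line of the summary
frozen = total - trainable

print(dense_1, dense_2, trainable, frozen)
```

The numbers match the `dense_1`, `dense_2`, `Trainable params` and `Non-trainable params` lines of the summary.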

Split the dataset into 5 parts: 4 parts form the training set and 1 part forms the validation set.

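import dataset                    # the dataset module shown earlier
from keras.utils import np_utils  # legacy Keras helper for one-hot encoding

# Load the images (X), the labels (y) and the sorted class names (tags).
X, y, tags = dataset.dataset(data_directory, n)
nb_classes = len(tags)

# The first 4/5 of the samples are used for training, the rest for validation.
sample_count = len(y)
train_size = sample_count * 4 // 5
X_train = X[:train_size]
y_train = y[:train_size]
Y_train = np_utils.to_categorical(y_train, nb_classes)  # one-hot labels
X_test  = X[train_size:]
y_test  = y[train_size:]
Y_test = np_utils.to_categorical(y_test, nb_classes)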
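A slice-based split like this one only works well if the samples are already in random order. If they were grouped by class, the validation part would contain mostly the last class. The following is a minimal NumPy-only sketch of a shuffled 4:1 split with one-hot labels. The names `shuffled_split` and `one_hot` and the fixed seed are assumptions made for this illustration, not part of the original code.

```python
import numpy as np


def one_hot(labels, nb_classes):
    """Turn integer labels into one-hot rows (same idea as to_categorical)."""
    return np.eye(nb_classes, dtype=np.float32)[labels]


def shuffled_split(X, y, seed=0):
    """Shuffle X and y together, then use 4/5 for training and 1/5 for validation."""
    order = np.random.RandomState(seed).permutation(len(y))
    X, y = X[order], y[order]
    train_size = len(y) * 4 // 5
    return X[:train_size], y[:train_size], X[train_size:], y[train_size:]


# Toy data: 10 "images" with one feature each, grouped by class on purpose.
X_demo = np.arange(10).reshape(10, 1)
y_demo = np.array([0] * 5 + [1] * 5)

X_tr, y_tr, X_te, y_te = shuffled_split(X_demo, y_demo)
Y_tr = one_hot(y_tr, 2)
print(X_tr.shape, X_te.shape, Y_tr.shape)
```

Because the same permutation is applied to `X` and `y`, every image keeps its own label after shuffling.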

Next, we use ImageDataGenerator to augment the data, which gives us more training samples and helps against overfitting. Keras already ships this class:

from keras.preprocessing.image import ImageDataGenerator

datagen = ImageDataGenerator(
        featurewise_center=False,
        samplewise_center=False,
        featurewise_std_normalization=False,
        samplewise_std_normalization=False,
        zca_whitening=False,
        rotation_range=45,
        width_shift_range=0.25,
        height_shift_range=0.25,
        horizontal_flip=True,
        vertical_flip=False,
        channel_shift_range=0.5,
        zoom_range=[0.5, 1.5],
        brightness_range=[0.5, 1.5],
        fill_mode='reflect')

# fit() only computes statistics for featurewise centering/normalization and
# ZCA whitening. All of those are disabled here, so this call is optional.
datagen.fit(X_train)
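To get a feeling for what two of these options do, here is a simplified NumPy illustration of a horizontal flip and of a shift whose gap is filled by mirroring the edge (the idea behind `fill_mode='reflect'`). This is not how Keras implements augmentation internally; the function names and the tiny one-row image are made up for the example.

```python
import numpy as np


def horizontal_flip(img):
    """Mirror an image of shape (height, width, channels) left to right."""
    return img[:, ::-1, :]


def shift_right_reflect(img, pixels):
    """Shift an image right by `pixels` columns and fill the gap on the left
    with a mirror image of the left edge."""
    if pixels == 0:
        return img.copy()
    # NumPy's "symmetric" padding repeats the edge pixel while mirroring.
    padded = np.pad(img, ((0, 0), (pixels, 0), (0, 0)), mode="symmetric")
    return padded[:, :img.shape[1], :]


# A one-row "image" whose four columns hold the values 0, 1, 2, 3.
img = np.arange(4, dtype=np.float32).reshape(1, 4, 1)

flipped = horizontal_flip(img)
shifted = shift_right_reflect(img, 1)
print(flipped.ravel(), shifted.ravel())
```

In the real generator the flip is applied at random and the shift is drawn at random from the configured range, so each epoch sees slightly different versions of the same images.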

Finally, we build the model, train it, and save it. This step takes quite a while.

model = net.build_model(nb_classes)
model.compile(optimizer='rmsprop', loss='categorical_crossentropy', metrics=["accuracy"])

# train the model on the new data for a few epochs
print("training the newly added dense layers")

# Both values count batches, not samples.
steps_per_epoch = X_train.shape[0] // batch_size
validation_steps = X_test.shape[0] // batch_size

# samples_per_epoch is the old Keras 1 name for the same setting,
# so only steps_per_epoch is passed here.
model.fit_generator(datagen.flow(X_train, Y_train, batch_size=batch_size, shuffle=True),
            epochs=nb_epoch,
            steps_per_epoch=steps_per_epoch,
            validation_data=datagen.flow(X_test, Y_test, batch_size=batch_size),
            validation_steps=validation_steps,
            )

net.save(model, tags, model_file_prefix)
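When training from a generator, the step arguments count batches rather than samples. The sketch below works through that arithmetic. It assumes all 4242 images of the dataset were loaded and uses a batch size of 32; both numbers are assumptions for illustration, since the real values depend on how many images the loader keeps and on the `batch_size` you choose.

```python
def steps_for(sample_count, batch_size):
    """Number of full batches that fit into sample_count samples."""
    return sample_count // batch_size


total_images = 4242   # assumed: every image in the dataset was usable
batch_size = 32       # assumed value, pick what fits your GPU memory

train_size = total_images * 4 // 5       # same 4:1 split as in the article
test_size = total_images - train_size

steps_per_epoch = steps_for(train_size, batch_size)
validation_steps = steps_for(test_size, batch_size)
print(train_size, test_size, steps_per_epoch, validation_steps)
```

Passing a sample count where a step count is expected would make each validation pass run `batch_size` times longer than intended.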

Let's download a few images from the web and test the model on them.
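The model outputs one probability per class, so a prediction has to be mapped back to a class name. Here is a minimal sketch of that last step. It assumes `tags` holds the sorted folder names of the dataset (assumed here to be daisy, dandelion, rose, sunflower, tulip) and uses a made-up probability vector in place of a real `model.predict` result; the helper `top_predictions` is hypothetical.

```python
import numpy as np

# Assumed class order: the sorted folder names of the flower dataset.
tags = ["daisy", "dandelion", "rose", "sunflower", "tulip"]


def top_predictions(probabilities, tags, k=3):
    """Return the k most likely (class name, probability) pairs, best first."""
    order = np.argsort(probabilities)[::-1][:k]
    return [(tags[i], float(probabilities[i])) for i in order]


# Made-up softmax output for a single image, for illustration only.
probs = np.array([0.02, 0.05, 0.80, 0.03, 0.10])

best = top_predictions(probs, tags, k=2)
print(best)
```

Before calling the real model, the image would also need the same resizing and `preprocess_input` scaling that was applied to the training data.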

![Image](flower-classifition_demo.jpg)

The results look pretty good, don't they?

Thank you for reading. See you in the next posts.
