0x00 概念
模型微调(fine-tuning)是指:对于用于特征提取的、已冻结的模型基(卷积基),将其顶部的几层“解冻”,并将解冻的这几层与新增加的层(例如全连接分类器)一起联合训练。
0x01 导入模型
from keras.applications import VGG16

# Instantiate the VGG16 network with ImageNet weights, dropping its
# top classifier (include_top=False) and declaring 150x150 RGB inputs.
conv_base = VGG16(
    weights='imagenet',
    include_top=False,
    input_shape=(150, 150, 3),
)

# Print the layer-by-layer summary of the loaded model.
conv_base.summary()
0x02 冻结直到某一层之前的所有层(并解冻该层及其之后的层)
# Mark the base as trainable as a whole, then walk its layers in order:
# every layer before 'block5_conv1' is frozen, and 'block5_conv1' plus all
# layers after it are left trainable.
conv_base.trainable = True

set_trainable = False
for layer in conv_base.layers:
    # Compare against the bare layer name. The original compared against
    # 'block5_conv1 (Conv2D) ', which is how the row is rendered by
    # summary(), so the test never matched and every layer stayed frozen.
    if layer.name == 'block5_conv1':
        set_trainable = True
    layer.trainable = set_trainable
0x03 微调模型
import os

from keras import layers
from keras import models
from keras import optimizers
from keras.preprocessing.image import ImageDataGenerator

# --- Dataset locations -----------------------------------------------------
base_dir = 'D:\\Jupyter\\dogs-vs-cats\\'
train_dir = os.path.join(base_dir, 'train')
validation_dir = os.path.join(base_dir, 'validation')
test_dir = os.path.join(base_dir, 'test')

# --- Data generators -------------------------------------------------------
# Training images are rescaled to [0, 1] and randomly augmented.
train_datagen = ImageDataGenerator(
    rescale=1./255,
    rotation_range=40,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest',
)

# Validation images are only rescaled; no augmentation is applied.
test_datagen = ImageDataGenerator(rescale=1./255)

train_generator = train_datagen.flow_from_directory(
    train_dir,
    target_size=(150, 150),
    batch_size=20,
    # Binary labels are required because the loss is binary_crossentropy.
    class_mode='binary',
)

validation_generator = test_datagen.flow_from_directory(
    validation_dir,
    target_size=(150, 150),
    batch_size=20,
    class_mode='binary',
)

# --- Full model: convolutional base + new dense classifier ------------------
model = models.Sequential([
    conv_base,
    layers.Flatten(),
    layers.Dense(256, activation='relu'),
    layers.Dense(1, activation='sigmoid'),
])

# NOTE(review): `lr=` and `fit_generator` are the older Keras spellings;
# newer releases are understood to use `learning_rate=` and `fit` instead.
# Confirm against the installed Keras version before changing them.
model.compile(
    optimizer=optimizers.RMSprop(lr=2e-5),
    loss='binary_crossentropy',
    metrics=['acc'],
)

# --- Training ----------------------------------------------------------------
history = model.fit_generator(
    train_generator,
    steps_per_epoch=100,
    epochs=100,
    validation_data=validation_generator,
    validation_steps=50,
)
0x04 评估
import matplotlib.pyplot as plt

# Per-epoch metrics recorded during training.
acc = history.history['acc']
val_acc = history.history['val_acc']
loss = history.history['loss']
val_loss = history.history['val_loss']

# 1-based epoch numbers for the x axis.
epochs = range(1, len(acc) + 1)

# Figure 1: training vs. validation accuracy.
plt.plot(epochs, acc, 'bo', label='Training acc')
plt.plot(epochs, val_acc, 'b', label='Validation acc')
# Title typo fixed: it previously read 'accuary'.
plt.title('Training and Validation accuracy')
plt.legend()

# Figure 2: training vs. validation loss.
plt.figure()
plt.plot(epochs, loss, 'bo', label='Training loss')
plt.plot(epochs, val_loss, 'b', label='Validation loss')
plt.title('Training and Validation loss')
plt.legend()

plt.show()

# Optionally save the model afterwards:
# model.save('cats_and_dogs_VGG16_unfreeze.h5')