0x00 概念
模型微调(fine-tuning)是指:对于用作特征提取、已被冻结的模型基(卷积基),将其顶部的几层“解冻”,然后把解冻的这几层与新增加的几层(例如全连接分类器)一起联合训练。
0x01 导入模型
from keras.applications import VGG16

# Load the VGG16 convolutional base with its ImageNet weights. The fully
# connected classifier on top is omitted (include_top=False) and the input
# is fixed to 150x150 images with 3 channels.
conv_base = VGG16(
    include_top=False,
    weights='imagenet',
    input_shape=(150, 150, 3),
)

# Print the layer-by-layer overview of the loaded base.
conv_base.summary()
0x02 解冻直到某一层的所有层
# Fine-tune only the top of the convolutional base: every layer that comes
# before 'block5_conv1' stays frozen, while that layer and all layers after
# it are made trainable.
conv_base.trainable = True  # enable training on the base before setting per-layer flags

set_trainable = False
for layer in conv_base.layers:
    # A layer's name is the bare identifier, e.g. 'block5_conv1'. The text
    # "block5_conv1 (Conv2D)" shown by summary() is the name followed by the
    # layer type; comparing against that text never matches and would leave
    # every layer frozen.
    if layer.name == 'block5_conv1':
        set_trainable = True
    # Once the marker layer has been reached, this flag stays True for all
    # remaining (higher) layers.
    layer.trainable = set_trainable
0x03 微调模型
# Build the ImageDataGenerator inputs for the training and validation sets.
import os

# Root folder of the dogs-vs-cats dataset; the train / validation / test
# splits are sub-directories directly beneath it.
base_dir = 'D:\\Jupyter\\dogs-vs-cats\\'
train_dir, validation_dir, test_dir = (
    os.path.join(base_dir, split_name)
    for split_name in ('train', 'validation', 'test')
)
from keras.preprocessing.image import ImageDataGenerator
from keras import optimizers

# Random augmentation applied to the training images only.
augmentation_settings = dict(
    rotation_range=40,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest',
)

# Both generators multiply pixel values by 1/255; the validation data is
# not augmented.
train_datagen = ImageDataGenerator(rescale=1./255, **augmentation_settings)
test_datagen = ImageDataGenerator(rescale=1./255)

# Settings shared by the directory iterators: images are resized to
# 150x150 and delivered in batches of 20. class_mode='binary' yields binary
# labels, which is what the binary_crossentropy loss needs.
flow_settings = dict(
    target_size=(150, 150),
    batch_size=20,
    class_mode='binary',
)

train_generator = train_datagen.flow_from_directory(train_dir, **flow_settings)
validation_generator = test_datagen.flow_from_directory(validation_dir, **flow_settings)
# Assemble the complete model: the VGG16 base followed by a small densely
# connected classifier ending in a single sigmoid unit.
from keras import models
from keras import layers

# NOTE(review): the Dense head starts from freshly initialised weights and
# is trained together with the unfrozen layers - confirm this is intended
# rather than training the head first with the base frozen.
model = models.Sequential([
    conv_base,
    layers.Flatten(),
    layers.Dense(256, activation='relu'),
    layers.Dense(1, activation='sigmoid'),
])

# RMSprop with a small learning rate (2e-5), presumably to limit how far the
# pre-trained weights move while fine-tuning.
fine_tune_optimizer = optimizers.RMSprop(lr=2e-5)
model.compile(
    loss='binary_crossentropy',
    optimizer=fine_tune_optimizer,
    metrics=['acc'],
)
# Train the model. Each of the 100 epochs draws 100 batches from the
# training generator and then evaluates on 50 batches from the validation
# generator.
training_options = {
    'steps_per_epoch': 100,
    'epochs': 100,
    'validation_data': validation_generator,
    'validation_steps': 50,
}
history = model.fit_generator(train_generator, **training_options)
0x04 评估
import matplotlib.pyplot as plt

# Per-epoch metrics recorded by the training run.
acc = history.history['acc']
val_acc = history.history['val_acc']
loss = history.history['loss']
val_loss = history.history['val_loss']

# Number the epochs from 1 for the x axis.
epochs = range(1, len(acc) + 1)

# Figure 1: training accuracy (dots) against validation accuracy (line).
plt.plot(epochs, acc, 'bo', label='Training acc')
plt.plot(epochs, val_acc, 'b', label='Validation acc')
plt.title('Training and Validation accuracy')  # fixed typo: was 'accuary'
plt.legend()

# Figure 2: training loss (dots) against validation loss (line).
plt.figure()
plt.plot(epochs, loss, 'bo', label='Training loss')
plt.plot(epochs, val_loss, 'b', label='Validation loss')
plt.title('Training and Validation loss')
plt.legend()

plt.show()

# Optionally save the fine-tuned model afterwards:
# model.save('cats_and_dogs_VGG16_unfreeze.h5')