今天延續之前的主題,我們將使用EfficientNetB0的架構,但不使用預訓練權重,參考了Keras文檔的文章,我們將input_shape設定為EfficientNetB0預設的值224X224,等等會利用Keras提供的函數,幫我們載入並縮放這些圖片,由於驗證集的部分不希望使用資料增強,所以分成了兩個資料夾,使用兩個ImageDataGenerator去處理。
input_shape = (224, 224)
batch_size = 64
from keras.preprocessing.image import ImageDataGenerator
traindatagen = ImageDataGenerator(
width_shift_range = 0.1,
height_shift_range = 0.1,
zoom_range = 0.2,
shear_range = 0.1,
rotation_range = 25,
horizontal_flip = True,
rescale = 1/255.
)
validdatagen = ImageDataGenerator(
rescale = 1/255.
)
train = traindatagen.flow_from_directory(
img_directory + 'train',
target_size=input_shape,
color_mode="rgb",
class_mode="binary",
batch_size=batch_size,
shuffle=True,
interpolation="lanczos",
)
valid = validdatagen.flow_from_directory(
img_directory + 'validation',
target_size=input_shape,
color_mode="rgb",
class_mode="binary",
batch_size=batch_size,
shuffle=True,
interpolation="lanczos",
)
你應該可以看到它印出了兩行結果,表示完成了對路徑的掃瞄,數量比例也大致相符。接著引入我們的主角。
from tensorflow.keras.applications import EfficientNetB0
from tensorflow.keras.layers import Dense, Dropout
from tensorflow.keras.models import Model
model = EfficientNetB0(include_top=True, weights=None, input_shape=(*input_shape,3), classes=1, activation='sigmoid', pooling='avg')
model.compile(
optimizer = 'adam',
loss = 'binary_crossentropy',
metrics = ['accuracy']
)
再來設定一些回調函數。
from tensorflow.keras.callbacks import ModelCheckpoint, CSVLogger, TerminateOnNaN, EarlyStopping
mcp = ModelCheckpoint(filepath='EfficientNetB0-{epoch:02d}.h5', monitor='val_loss', verbose=0, save_best_only=True, save_weights_only=False, mode='auto', save_freq='epoch')
log = CSVLogger(filename='EfficientNetB0.csv', separator=',', append=False)
ton = TerminateOnNaN()
esl = EarlyStopping(monitor='val_loss', patience=10, mode='auto', restore_best_weights=True)
esa = EarlyStopping(monitor='val_accuracy', patience=10, mode='auto', restore_best_weights=True)
只要經過漫長的等待就可以收穫勝利的果實了。
hist = model.fit(
x = train,
steps_per_epoch = train.samples // batch_size,
epochs = 50,
validation_data = valid,
validation_steps = valid.samples // batch_size,
callbacks = [mcp, log, ton, esl, esa]
)