こんにちは、えびかずきです!
ディープラーニングの活用において学習済みモデルの保存・読み込みは必須作業です。
せっかく時間をかけて重みを調整したモデルが消えてしまっては元も子もありません。
ということで今回は、
Kerasでディープラーニングモデルを保存・読み込みする方法について説明します。
開発環境
OS:MacOS Catalina 10.15.2
言語:Python3.5.4
IDE:jupyter notebook
フレームワーク:
・tensorflow 1.5.0(要install)
※Kerasはtensorflowの内部で使用
方法
保存
kerasでのモデル保存は非常に簡単で、モデルの学習が終わった後に、以下のsaveメソッドを実行するだけでOKです。
これを実行するとプログラムがあるディレクトリに「.h5」というHDF5拡張子(階層データ形式)で保存されます。
読み込み
保存したモデルを読み込むには、tensorflow.python.keras.modelsからload_modelをインポートして下のように使います。
このように保存・読み込みは、どちらも面倒な準備は必要なく簡単に実装できます。
では実際の使用例を見ていきましょう。
モデルの保存例
今回は、下記の記事で紹介した性別判定のための画像認識CNNモデルを保存してみます。
このモデルはVGG16をベースに作成した転移学習モデルになっていて、大まかな構造としては、入力層-隠れ層(21層)-出力層という構造になっています。
学習には総計800枚の顔写真を使用していて、モデルの最適化に約5時間もかかっています。(結構大変!)
学習が完了した後に、モデル変数(下例のmodel)に対して.saveメソッドを適用します。
# ファイル名をtestとしてモデルを保存
model.save('test.h5')
保存ディレクトリを確認してみると、赤枠に示すように「test.h5」というファイルが保存されています。
モデルの読み込み例
では次に読み込みをしてみましょう。
モデル変数(下例ではmodel)へ保存データを格納するという形式で実装します。
#load_modelをインポートする
from tensorflow.python.keras.models import load_model
#modelへ保存データを読み込み
model = load_model('test.h5')
.summary()で本当に読み込みができているか確認してみます。
model.summary()
以下の通り、隠れ層(21層)からなるCNNモデルが読み込まれています。
きちんとmodelが読み込めていることが確認できました。
.summaryでの出力結果:
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_2 (InputLayer) (None, 224, 224, 3) 0
_________________________________________________________________
block1_conv1 (Conv2D) (None, 224, 224, 64) 1792
_________________________________________________________________
block1_conv2 (Conv2D) (None, 224, 224, 64) 36928
_________________________________________________________________
block1_pool (MaxPooling2D) (None, 112, 112, 64) 0
_________________________________________________________________
block2_conv1 (Conv2D) (None, 112, 112, 128) 73856
_________________________________________________________________
block2_conv2 (Conv2D) (None, 112, 112, 128) 147584
_________________________________________________________________
block2_pool (MaxPooling2D) (None, 56, 56, 128) 0
_________________________________________________________________
block3_conv1 (Conv2D) (None, 56, 56, 256) 295168
_________________________________________________________________
block3_conv2 (Conv2D) (None, 56, 56, 256) 590080
_________________________________________________________________
block3_conv3 (Conv2D) (None, 56, 56, 256) 590080
_________________________________________________________________
block3_pool (MaxPooling2D) (None, 28, 28, 256) 0
_________________________________________________________________
block4_conv1 (Conv2D) (None, 28, 28, 512) 1180160
_________________________________________________________________
block4_conv2 (Conv2D) (None, 28, 28, 512) 2359808
_________________________________________________________________
block4_conv3 (Conv2D) (None, 28, 28, 512) 2359808
_________________________________________________________________
block4_pool (MaxPooling2D) (None, 14, 14, 512) 0
_________________________________________________________________
block5_conv1 (Conv2D) (None, 14, 14, 512) 2359808
_________________________________________________________________
block5_conv2 (Conv2D) (None, 14, 14, 512) 2359808
_________________________________________________________________
block5_conv3 (Conv2D) (None, 14, 14, 512) 2359808
_________________________________________________________________
block5_pool (MaxPooling2D) (None, 7, 7, 512) 0
_________________________________________________________________
flatten_2 (Flatten) (None, 25088) 0
_________________________________________________________________
dense_3 (Dense) (None, 256) 6422784
_________________________________________________________________
dropout_2 (Dropout) (None, 256) 0
_________________________________________________________________
dense_4 (Dense) (None, 1) 257
=================================================================
Total params: 21,137,729
Trainable params: 13,502,465
Non-trainable params: 7,635,264
_________________________________________________________________
まとめ
今回はKerasによる学習済みモデルの保存・読み込みについて説明しました。
ディープラーニングでは数時間、ときには数日もの長い時間をかけて、学習を行わなければならないケースが多々あります。しかしながら、そんな長時間の学習を毎回やるわけにはいきません。
モデルをそのまま使わないとしても修正して使ったり転移学習モデルとして活用するという可能性もあるので、
学習を実行したら必ず.saveメソッドで保存しておきましょう!
補足(Tensorflowの場合)
TensorflowでもKeras同様Sequentialを用いてモデルを実装した場合は、上記方法でよいですが、
KerasのModelクラスを継承して実装した場合には、上記の方法では保存・読み込みいずれも実装できないようです。
この場合は、モデルの重みを保存する以下の方法を使用してください。
保存(重みのみ)
読み込み(重みのみ)
※この方法ではモデルの重みのみが保存されている状態なので、
面倒ですがモデルインスタンスは再度生成する必要があります。
参考書籍
今回の記事は以下の書籍を参考にさせていただきました。
コメントを書く