TensorFlowで訓練したパラメータをChainerのモデルにrestoreする
何がどうなってか深層学習〜的なものに触れる機会が増えたので,何かそれっぽい話を
「他のフレームワークに比べてなんだか学習済み(pre-trained)モデルが公開されてないような?」
Chainerを使ってみた深層学習マンはきっとこのお気持ちになったことがあるんじゃないだろうか.ユーザ数の差だろうか?数は力だよ兄貴!Caffeのモデルファイル(.caffemodel)であれば,ものによってはchainer.links.caffe.CaffeFunctionでロードできるので一応使えはする.global pooling 等で非対応がありロードできないこともあるが,できない部分は飛ばしてロードしてスキマを自分で書いて〜とかできなくはないので,スンナリとロードはできなくともガンバリでロードはできる気がする.
一方,なんだかんだTensorFlowのpre-trainedモデルがckptファイルで公開されていることは多い.
こいつをいざChainerから利用したいとなってもあんまり記事とかが見当たらない.まぁ逆についても無いというか,そもフレームワークを越えてどうにかする話があんまりない.どうにかしたいという雰囲気や動きは散見されるけど,全体からみたらまだ二の次案件のように見える.よく訓練された深層学習マンにはこんなこと呼吸に等しいタスクだからなのか,それともあんまりこういったマネはしないからなのか,もしくは,そんなことせんでも計算資源がありあまってて新規に学習しちゃえばいいだろということなのか.
いずれにせよ,あのモデルをChainerで書いてみたちょっとちゃんと動くか試したい…けど学習済みのものは無くて〜程度のことで,最近あの東京大学でさえ節約していると噂の貴重な貴重な電力(と時間)を消費して新規に学習し始めるというのも心苦しい.Gentoo使いならおさらだ.なので,やっぱりクロスフレームワークでも再利用したいというのはあるんじゃないかなと.
というわけでckptファイル(群)のパラメータをChainerのモデルにrestoreする方法は,
- ckptファイル(pre-trainedなモデル)を入手
- そのckptファイルを学習したTensorFlowのモデルを入手
- TensorFlowのモデルを眺めて各パラメータに付けられた名称を調べる
- TensorFlowのモデルと同じモデルをChainerで書く
- TensorFlowのCheckpointReaderでckptファイルを開く
- ckptファイルから各名称のパラメータを引っ張り出す
- TensorFlowとChainerではweightのdimの順番が違うので必要に応じて転置
- chainer.Linkの対応するパラメータに代入していく
みたいな流れになる.
class CKPT: def __init__(self, path): # ckptファイルを開く self.ckpt = tf.train.NewCheckpointReader(path) # ndim見て決め打ってるけど,そのパラが何のものかはモデルからわかってるので, # 本当は個別にget_conv2d_weightとかget_fc_weightとかを用意したほうがいい def get(self, name): arr = self.ckpt.get_tensor(name) nd = np.ndim(arr) # 必要に応じて転置 if nd == 4: # おそらく 2D Convolution だろうと return arr.transpose(3,2,0,1).copy() # TensorFlow -> Chainer if nd == 2: # おそらく Fully Connected だろうと return arr.transpose(1,0).copy() # TensorFlow -> Chainer if nd == 1: # biasやBatchNormのパラメータだろうと return arr else: pass # unknown weight type # TODO: raise Exception (snip.) # Chainer版のモデル class YourModel(chainer.Chain): def __init__(self): super(YourModel, self).__init__( c0 = L.Convolution2D(3 , 16, 7, 3, 3, nobias=True), bn = L.BatchNormalization(16, 0.9997, 0.001) c1 = L.Convolution2D(None, 32, 3, 1, 1), (snip.) ) def __call__(self, x): (snip.) def restore_from_ckpt(self, path): ckpt = CKPT(path) # 各名称のパラメータを引っ張り出す self.c0.W.data = ckpt.get('PreTrainedModel/Conv2d_0/weights') self.bn.beta.data = ckpt.get('PreTrainedModel/BatchNorm/beta') self.bn.gamma.data = ckpt.get('PreTrainedModel/BatchNorm/gamma') self.bn.avg_mean.data = ckpt.get('PreTrainedModel/BatchNorm/moving_mean') self.bn.avg_var.data = ckpt.get('PreTrainedModel/BatchNorm/moving_variance') self.c1.W.data = ckpt.get('PreTrainedModel/Conv2d_1/weights') self.c1.b.data = ckpt.get('PreTrainedModel/Conv2d_1/bias') (snip.) if __name__ == '__main__': model = YourModel() model.restore_from_ckpt('pre-trained.ckpt') (snip.)
もちろん,これで万事うまくいくというわけではないだろうが.
名前調べるのめんどくさいなぁという場合,CheckpointReaderのget_variable_to_shape_mapでshapeと共に一覧できるので,それ見て判断とかでもできなくはないこともある.