コンテンツにスキップするには Enter キーを押してください

chainerのAlexnetを用いてFine Tuningをする

はじめに

以前、alexnetを参考にして顔認識をやってみたのですが


実際に何かを判定する場合には、すでに学習済みのモデルを使い最終い最終レイヤーだけ自分で学習させるというファインチューニング(Fine Tuning)をすることが一般的なようです。
例えば、今回の様に「顔認識をしたい」となったら最終層のみを認識したい顔を学習させることで可能なようです。
ファインチューニングをする目的は、学習データが少なくても良かったりするらしい。
普通はGoogLeNetをファインチューニングするらしいのですが、今回は簡単にAlexnetでやってみようと思います。
GoogLeNetはまた今度挑戦…




 

大体の流れ


ちなみに、今回は以前書いた記事で使用しているソースを元に書いています。
< github >

 

学習済みのcaffeモデルをchainerモデルにして使えるようにする

学習済みのcaffeモデルはこのページから保存できます。
chainerには学習済みcaffeモデルを読み込むための機能があります。
読み込んだcaffeモデルをchainerで使える様に保存するのですが、serializersではなくpickleを使います。

#読み込むcaffeモデルとpklファイルを保存するパス
loadpath = "bvlc_alexnet.caffemodel"
savepath = "./chainermodels/alexnet.pkl"

from chainer.links.caffe import CaffeFunction
alexnet = CaffeFunction(loadpath)

import _pickle as pickle
pickle.dump(alexnet, open(savepath, 'wb'))

 

モデルパラメータのコピー

はじめにも記述しましたが、学習済みのパラメータを使って自分で学習させたい層だけコピーせずに学習させます。特定の層のみ学習させると言っても、実際にはすべての層を学習させてると思います。
モデルが大きくなり大量の画像を学習させると深い層のパラメータは簡単には更新されないそうです。
それはつまり、どの画像においても重用な低次元の特徴が低層のパラメータに反映されているので、自分で判別したい画像群を学習させたい時は高層のみが自然と学習されます。

パラメータのコピーのコードは参考にさせていただいたページで使われているものを使用させていただきます。このコードは同じレイヤー名であった場合はパラメータをコピーします。
つまり、自分で学習させる層はコピー元(学習済alexnet)に無い名前をつけましょう。

import chainer

def copy_model(src, dst):
    assert isinstance(src, chainer.Chain)
    assert isinstance(dst, chainer.Chain)
    for child in src.children():
        if child.name not in dst.__dict__: continue
        dst_child = dst[child.name]
        if type(child) != type(dst_child): continue
        if isinstance(child, chainer.Chain):
            copy_model(child, dst_child)
            if isinstance(child, chainer.Link):
                match = True
                for a, b in zip(child.namedparams(), dst_child.namedparams()):
                    if a[0] != b[0]:
                        match = False
                        break
                    if a[1].data.shape != b[1].data.shape:
                        match = False
                        break
                    if not match:
                        print('Ignore %s because of parameter mismatch' % child.name)
                        continue
                    for a, b in zip(child.namedparams(), dst_child.namedparams()):
                        b[1].data = a[1].data
                        print('Copy %s' % child.name)

次に、モデルの定義です。
AlexLikeとFromCaffeAlexNetという二つのクラスがありますが、前者が以前自分で作ったもので、後者が今回作ったものです。
今回はファインチューニングをするので、モデルはほぼAlexnetです。
具体的にどの様にファインチューニングするかというと、全結合層をすべて自分で学習させます。
学習させる画像のサイズが異なるので、パラメータのコピーをするとノード数などのズレが出てきてしまうからです。
畳込み層のパラメータはフィルタのパラメータとバイアスのみなので画像サイズが異なっても問題ありません。
コピーしない層にはmy_をつけています。

import numpy as np

import chainer
import chainer.functions as F
from chainer import initializers
import chainer.links as L

class AlexLike(chainer.Chain):
    insize = 128
    def __init__(self, n_out):
        super(AlexLike, self).__init__(
            conv1=L.Convolution2D(None,  64, 8, stride=4),
            conv2=L.Convolution2D(None, 128,  5, pad=2),
            conv3=L.Convolution2D(None, 128,  3, pad=1),
            conv4=L.Convolution2D(None, 128,  3, pad=1),
            conv5=L.Convolution2D(None, 64,  3, pad=1),
            fc6=L.Linear(None, 1024),
            fc7=L.Linear(None, 1024),
            fc8=L.Linear(None, n_out),
        )
        self.train = True

    def __call__(self, x):
        h = F.max_pooling_2d(F.local_response_normalization(
            F.relu(self.conv1(x))), 3, stride=2)
        h = F.max_pooling_2d(F.local_response_normalization(
            F.relu(self.conv2(h))), 3, stride=2)
        h = F.relu(self.conv3(h))
        h = F.relu(self.conv4(h))
        h = F.max_pooling_2d(F.relu(self.conv5(h)), 3, stride=2)
        h = F.dropout(F.relu(self.fc6(h)), train=self.train)
        h = F.dropout(F.relu(self.fc7(h)), train=self.train)
        h = self.fc8(h)
        return h

class FromCaffeAlexnet(chainer.Chain):
    insize = 128
    def __init__(self, n_out):
        super(FromCaffeAlexnet, self).__init__(
            conv1=L.Convolution2D(None, 96, 11, stride=2),
            conv2=L.Convolution2D(None, 256, 5, pad=2),
            conv3=L.Convolution2D(None, 384, 3, pad=1),
            conv4=L.Convolution2D(None, 384, 3, pad=1),
            conv5=L.Convolution2D(None, 256, 3, pad=1),
            my_fc6=L.Linear(None, 4096),
            my_fc7=L.Linear(None, 1024),
            my_fc8=L.Linear(None, n_out),
        )
        self.train = True

    def __call__(self, x):
        h = F.max_pooling_2d(F.local_response_normalization(
            F.relu(self.conv1(x))), 3, stride=2)
        h = F.max_pooling_2d(F.local_response_normalization(
            F.relu(self.conv2(h))), 3, stride=2)
        h = F.relu(self.conv3(h))
        h = F.relu(self.conv4(h))
        h = F.max_pooling_2d(F.relu(self.conv5(h)), 3, stride=2)
        h = F.dropout(F.relu(self.my_fc6(h)), train=self.train)
        h = F.dropout(F.relu(self.my_fc7(h)), train=self.train)
        h = self.my_fc8(h)
        return h

これであとはコピーします。

model = L.Classifier(alexLike.FromCaffeAlexnet(len(pathsAndLabels)) )
original_model = pickle.load(open("./chainermodels/alexnet.pkl", "rb"))
copy_model(original_model, model)

 

以前使っていたコードを改変する

改変と言ってもほぼほぼ終わっています。
以前は、

model = L.Classifier(alexLike.AlexLike(len(pathsAndLabels)))

の箇所を上のコピーするところに書いた3行を書くだけです。
あとは適宜copy_modelなど自分で書いたものをimportするだけです。

 

学習!

以前のものはフィルタ数やパラメータが異なるので比較になりません。
なので以前のものも同じ構造にして、コピーしたものとしていないもので比較をしてみます。

非ファインチューニング

-> % python facePredictionTraining.py -g0 -p ./images/ -e 100
GPU: 0
# unit: 1000
# Minibatch-size: 100
# epoch: 100

['./images/0_the_others', './images/nishino', './images/ikuta', './images/hashimoto', './images/akimoto', './images/ikoma', './images/shiraishi']
epoch       main/loss   validation/main/loss  main/accuracy  validation/main/accuracy
1           2.81115     1.74383               0.4384         0.425
2           1.58114     1.52161               0.462917       0.430682
3           1.4713      1.50598               0.4675         0.430682
4           1.4234      1.47006               0.480833       0.443182
5           1.40309     1.40085               0.4775         0.448182
6           1.32762     1.28482               0.509167       0.490227
7           1.24311     1.19394               0.538333       0.555455
8           1.12041     1.14435               0.594583       0.591136
9           1.11958     1.06941               0.58           0.636136
10          1.03119     1.06588               0.617917       0.624318
11          0.950916    1.0767                0.664167       0.638636
12          0.844859    1.03695               0.712917       0.630227
13          0.788513    0.913434              0.72           0.707955
14          0.693958    0.889949              0.769167       0.713636
15          0.591916    0.822851              0.799167       0.725
16          0.551598    0.909012              0.8125         0.749318
17          0.505234    0.874651              0.832917       0.751364
18          0.422546    0.923046              0.856667       0.748182
19          0.405247    0.991293              0.869583       0.755
20          0.350608    1.04474               0.888333       0.722955
21          0.312314    0.964968              0.90125        0.724318
22          0.287352    1.00368               0.90625        0.726591
23          0.2282      1.12567               0.925417       0.744773
24          0.184975    1.02929               0.93625        0.760682
25          0.205978    1.16097               0.935833       0.707273
26          0.190608    0.936228              0.9425         0.764318
27          0.164341    1.17843               0.946667       0.775
28          0.158653    1.24886               0.94375        0.742955
29          0.133191    1.17196               0.95125        0.741818
30          0.125333    1.30723               0.954167       0.721591
31          0.112244    1.20098               0.9675         0.7675
32          0.125869    1.12286               0.956667       0.761818
33          0.0869682   1.38304               0.97375        0.746136
34          0.0572585   1.28984               0.9828         0.775682
35          0.104041    1.40781               0.96375        0.7775
36          0.0904317   1.22378               0.970833       0.753636
37          0.0751891   1.25887               0.976666       0.788182
38          0.0653719   1.67261               0.975833       0.726136
39          0.118484    1.20526               0.962917       0.776364
40          0.0747643   1.30332               0.976667       0.775682
41          0.0402793   1.1529                0.9875         0.796364
42          0.0756816   1.34263               0.975833       0.783182
43          0.0557631   1.37093               0.985          0.775
44          0.038351    1.39616               0.987083       0.776136
45          0.0239319   1.8253                0.99125        0.757273
46          0.0644174   1.85465               0.98375        0.761818
47          0.0725174   1.32585               0.977917       0.775
48          0.0611989   1.33659               0.97875        0.771136
49          0.0526827   1.44711               0.980833       0.788182
50          0.0181338   1.8165                0.99375        0.790682
51          0.0268105   1.9287                0.992083       0.781818
52          0.0314947   2.28898               0.989583       0.746591
53          0.117547    1.45766               0.963333       0.746136
54          0.0524381   1.76803               0.98125        0.759091
55          0.0747688   1.54606               0.98           0.753864
56          0.101947    1.43049               0.969167       0.757273
57          0.123285    1.12438               0.965416       0.732273
58          0.0294317   1.28103               0.991667       0.769318
59          0.010548    1.88965               0.9975         0.771818
60          0.0242695   1.83785               0.993333       0.746591
61          0.018614    1.54702               0.994167       0.7825
62          0.0203313   1.89289               0.995417       0.746591
63          0.0654236   1.3633                0.980833       0.765682
64          0.0684283   1.41756               0.975          0.795682
65          0.0473127   1.41569               0.985          0.737273
66          0.0190186   1.70416               0.992917       0.781364
67          0.0258169   1.85392               0.9896         0.776818
68          0.0585173   1.46012               0.980833       0.790682
69          0.0575793   1.62392               0.98375        0.768864
70          0.0314874   1.54686               0.990417       0.745455
71          0.032756    1.74902               0.992083       0.753636
72          0.0304242   1.56112               0.990417       0.7825
73          0.00713755  1.71896               0.997917       0.795
74          0.0146182   1.54899               0.995          0.78
75          0.0494247   1.43268               0.982917       0.791818
76          0.0309536   1.54138               0.99125        0.784318
77          0.0244532   1.76728               0.993333       0.747955
78          0.0402335   1.80536               0.988333       0.746136
79          0.0700219   1.38528               0.979167       0.763636
80          0.0487884   1.28138               0.988333       0.807045
81          0.0186637   1.4186                0.994583       0.785
82          0.0221136   1.25561               0.993333       0.811364
83          0.0166552   1.74485               0.99375        0.78
84          0.0611952   1.175                 0.982083       0.780682
85          0.0420974   1.33954               0.990417       0.776136
86          0.0263487   1.21781               0.994583       0.7775
87          0.0152975   1.48428               0.995417       0.770455
88          0.0222285   1.58793               0.993333       0.803182
89          0.011478    1.48795               0.995417       0.798864
90          0.039703    2.1737                0.99125        0.730909
91          0.0482692   1.4288                0.985416       0.784318
92          0.0226002   1.71401               0.99375        0.780682
93          0.0260725   1.90249               0.9925         0.7725
94          0.0665291   1.80076               0.982083       0.792045
95          0.0662953   2.05899               0.984583       0.743409
96          0.0676098   1.45836               0.98           0.773636
97          0.0381842   1.27724               0.986667       0.817045
98          0.022721    1.65259               0.992917       0.7975
99          0.0333957   1.62624               0.990833       0.793182
100         0.0164402   1.57237               0.995          0.766818

ファインチューニング

-> % python facePredictionTraining.py -g0 -p ./images/ -e 100 -cmp ./pkls/alexnet.pkl
GPU: 0
# unit: 1000
# Minibatch-size: 100
# epoch: 100

['./images/0_the_others', './images/nishino', './images/ikuta', './images/hashimoto', './images/akimoto', './images/ikoma', './images/shiraishi']
epoch       main/loss   validation/main/loss  main/accuracy  validation/main/accuracy
1           3.61595     1.6791                0.4092         0.456136
2           1.63221     1.48171               0.454167       0.476818
3           1.47838     1.46167               0.462083       0.484318
4           1.39424     1.36533               0.480417       0.474318
5           1.25943     1.22625               0.540833       0.551364
6           1.19304     1.1528                0.570833       0.571364
7           1.13556     1.10993               0.592917       0.611591
8           1.07141     1.04741               0.617917       0.637273
9           1.00369     1.04192               0.656667       0.614091
10          1.00062     1.05676               0.652917       0.640909
11          0.889444    0.84411               0.68375        0.697955
12          0.776894    0.815443              0.733333       0.721364
13          0.706739    0.740765              0.759583       0.74
14          0.609966    0.799358              0.787083       0.756364
15          0.591103    0.733467              0.7975         0.748864
16          0.515128    0.67837               0.820417       0.792045
17          0.440191    0.657561              0.845833       0.768864
18          0.415537    0.799981              0.86375        0.759545
19          0.386299    0.702737              0.869167       0.784545
20          0.300046    0.728946              0.8975         0.799545
21          0.277886    0.709152              0.902083       0.807727
22          0.210336    0.819891              0.927917       0.807727
23          0.215491    0.700536              0.925833       0.800227
24          0.187639    0.907534              0.937917       0.771364
25          0.157693    0.903744              0.945417       0.807273
26          0.1694      0.882888              0.940833       0.781364
27          0.169727    0.919885              0.942083       0.7725
28          0.144512    1.0064                0.953333       0.815227
29          0.102455    1.02722               0.964583       0.755682
30          0.138434    0.837951              0.951667       0.809545
31          0.0982915   0.934578              0.96375        0.817045
32          0.107138    0.999158              0.965          0.808182
33          0.144835    1.05502               0.9525         0.746364
34          0.124283    0.867925              0.9584         0.820227
35          0.118712    1.02738               0.96375        0.787727
36          0.0910045   0.996191              0.966667       0.840227
37          0.101193    1.01551               0.96875        0.812727
38          0.0820016   1.05829               0.974583       0.789545
39          0.0830705   0.946766              0.97375        0.825227
40          0.0572238   1.0376                0.982083       0.815909
41          0.0635522   1.17951               0.975833       0.807045
42          0.0491275   1.07587               0.984167       0.835227
43          0.0411055   1.04035               0.99           0.819545
44          0.024805    1.12053               0.99125        0.818182
45          0.0590315   1.13356               0.98           0.817727
46          0.0744615   1.1405                0.977083       0.789545
47          0.0607121   0.984897              0.982083       0.815227
48          0.0531898   1.08086               0.981667       0.810682
49          0.0455061   1.02838               0.9825         0.795909
50          0.0583834   0.973976              0.98125        0.818409
51          0.0355984   1.33466               0.987917       0.822727
52          0.0372748   1.19893               0.989167       0.817045
53          0.0238393   1.2076                0.990417       0.827727
54          0.0235404   1.28087               0.9925         0.835227
55          0.0576453   1.38844               0.981667       0.799545
56          0.0955091   1.15469               0.967916       0.815909
57          0.0563831   1.2052                0.98           0.832727
58          0.045446    1.13118               0.986667       0.822045
59          0.0834624   1.16511               0.975833       0.825227
60          0.0331868   1.17437               0.987083       0.842727
61          0.0322712   1.1474                0.990833       0.831591
62          0.0573194   1.27675               0.982083       0.806364
63          0.0342424   1.07925               0.9875         0.832045
64          0.0430821   1.38478               0.984167       0.831591
65          0.0555329   1.32085               0.9825         0.820227
66          0.0468738   1.33457               0.986667       0.817727
67          0.058449    1.29985               0.984          0.790682
68          0.0387702   1.37681               0.985          0.817045
69          0.0216443   1.20956               0.99125        0.822045
70          0.00983097  1.14439               0.995833       0.837727
71          0.0550079   1.14223               0.982083       0.815682
72          0.0786399   1.1515                0.977917       0.804545
73          0.0934996   1.23051               0.97375        0.823864
74          0.0568555   0.981688              0.980417       0.809545
75          0.0375977   1.15129               0.986667       0.813409
76          0.0314885   1.32367               0.99           0.827727
77          0.0345499   1.30919               0.989583       0.825227
78          0.0429338   1.29852               0.986667       0.822727
79          0.046201    1.27599               0.984583       0.819545
80          0.0543402   1.40908               0.98125        0.810909
81          0.0484369   1.20406               0.9825         0.817045
82          0.0469517   1.11406               0.986667       0.822727
83          0.0399603   1.14925               0.987917       0.784545
84          0.0520156   1.15001               0.98625        0.834545
85          0.0271345   1.20866               0.992083       0.801364
86          0.0502531   1.00551               0.985833       0.818409
87          0.0328203   1.16065               0.989167       0.824545
88          0.0631189   0.94623               0.984583       0.833409
89          0.0239701   1.18773               0.993333       0.822727
90          0.0140465   1.20417               0.995          0.846591
91          0.026571    1.28324               0.992083       0.830227
92          0.0466379   1.53863               0.9875         0.802727
93          0.0995185   1.12926               0.969583       0.824091
94          0.0966937   1.09435               0.97375        0.812727
95          0.0572787   1.1985                0.980833       0.835909
96          0.0183072   1.57023               0.994167       0.795227
97          0.0473443   1.30383               0.985833       0.827273
98          0.0417519   1.28281               0.985833       0.841591
99          0.0414065   1.36836               0.987916       0.822045
100         0.0398212   1.33762               0.989167       0.813409

ファインチューニングしたほうがvalidationが高くでています。

次は、GoogLeNetでやってみようと思います。

では〜

 

参考

Chainerでファインチューニングするときの個人的ベストプラクティス





コメントする

コメントを残す

メールアドレスが公開されることはありません。 * が付いている欄は必須項目です