コンテンツにスキップするには Enter キーを押してください

chainer メモ(その2)fizzbuzzメモ

はじめに

今日は自分でデータを用意して、それを学習に使う。

学習用に画像データを用意したりするのはめんどくさいので、適当に数字を用意してその答えが出てくるような、CPUでもすぐ終わる簡単なものをやります。




fizzbuzzという3で割り切れるときには「fizz」を返し、5で割り切れるときには「buzz」を返し、15で割り切れるときには「fizzbuzz」を返すというゲームのようなものです。

コードは ここ にあげてます。

chainer1.11.0以降の話については 以前の記事 に簡単にまとめたので、ここでは具体的にどんな感じで使うのかや、入力層のデータをどのように作るのかを書きます。

 

fizzbuzzの入力層と出力層

入力層は一つだけというのはありえないので、ここでは2進数にします。
また、学習するときに、例えば1~100までを学習させるとすると、必要な入力層の数はMAXの値を2進数にした時の必要な数になります。
100の場合は1100100になるので7桁必要です。
この場合の3は0000011が入力層になります。

出力層は「15で割り切れるか」「5で割り切れるか」「3で割り切れるか」「それ以外」の4つなので、それぞれを0,1,2,3と割り当てます。

つまり入力層に90を入れるとすると、15で割り切れるので
([1,0,1,1,0,1,0], 0)
(入力層, 出力層)のような形になります。

nums = list(range(1, args.maxnum)) # 1~指定したmaxnumまでの数字をnumsに入れる
random.shuffle(nums)	# 学習がちゃんとするようにshuffle(意味があるかよくわからない)
nptrain_data = np.ones( (args.maxnum, binary_size) )	#先に大きさを指定して定義(nptrain_data = [] とかにしてあとでappendでもよい)
nptrain_label_data = np.ones(args.maxnum) # こたえの配列
i = 0
for n in nums: # for文を回して学習用のデータを作る。
    nptrain_data[i] = to_input_binary(n, binary_size) # to_input_binary()は2進数にするやつ
    nptrain_label_data[i] = to_output_binary(n) # to_output_binary()は0~3のどれかにするやつ
    i = i + 1

nptrain_data = nptrain_data.astype(np.float32) # 入力層はnp.float32型
nptrain_label_data = nptrain_label_data.astype(np.int32) # 出力層はnp.int32型

train = tuple_dataset.TupleDataset(nptrain_data, nptrain_label_data) # chainerの学習用に([入力層], 出力層)となるような変数に変換
# print(train[0])の結果
# (array([ 0.,  1.,  0.,  0.,  0.,  1.,  0.], dtype=float32), 3) のようになる

trainと同じようにtestも作って、それぞれiteratorを作って、trainerに渡して実行する。

> $ python chainer1.15_fizzbuzz_test.py -e 50
GPU: -1
# unit: 1000
# Minibatch-size: 100
# epoch: 50

epoch       main/loss   validation/main/loss  main/accuracy  validation/main/accuracy
1           1.40144     1.67249               0.25           0.53
2           1.67249     1.31182               0.53           0.53
3           1.31182     1.36088               0.53           0.29
4           1.36088     1.17424               0.29           0.44
5           1.17424     1.12141               0.44           0.58
6           1.12141     1.1066                0.58           0.58
7           1.1066      1.06614               0.58           0.54
8           1.06614     1.02257               0.54           0.54
9           1.02257     0.998361              0.54           0.55
10          0.998361    0.997881              0.55           0.58
11          0.997881    0.991423              0.58           0.64
12          0.991423    0.958057              0.64           0.65
13          0.958057    0.911688              0.65           0.63
14          0.911689    0.877736              0.63           0.64
15          0.877736    0.862639              0.64           0.6
16          0.862639    0.852316              0.6            0.59
17          0.852316    0.83436               0.59           0.62
18          0.83436     0.807892              0.62           0.66
19          0.807892    0.777973              0.66           0.68
20          0.777973    0.751283              0.68           0.72
21          0.751283    0.728998              0.72           0.74
22          0.728998    0.708458              0.74           0.73
23          0.708458    0.687423              0.73           0.74
24          0.687423    0.664107              0.74           0.72
25          0.664107    0.638694              0.72           0.78
26          0.638694    0.61366               0.78           0.83
27          0.61366     0.591711              0.83           0.83
28          0.591711    0.570574              0.83           0.87
29          0.570574    0.54719               0.87           0.89
30          0.54719     0.523066              0.89           0.89
31          0.523066    0.50052               0.89           0.89
32          0.50052     0.478227              0.89           0.93
33          0.478228    0.455212              0.93           0.94
34          0.455212    0.433174              0.94           0.97
35          0.433173    0.412282              0.97           0.97
36          0.412282    0.391695              0.97           0.97
37          0.391695    0.37082               0.97           0.98
38          0.37082     0.35046               0.98           0.98
39          0.350461    0.331028              0.98           0.98
40          0.331028    0.312263              0.98           0.98
41          0.312263    0.293945              0.98           0.99
42          0.293944    0.276363              0.99           0.99
43          0.276363    0.259548              0.99           0.99
44          0.259548    0.243475              0.99           0.99
45          0.243475    0.228037              0.99           1
46          0.228037    0.213486              1              1
47          0.213486    0.199517              1              1
48          0.199517    0.186174              1              1
49          0.186174    0.173812              1              1
50          0.173812    0.162189              1              1

epochが50回程度で正解率は1とかになる。

最後に完成したモデルを保存する。

chainer.serializers.save_npz('my.model', model)
chainer.serializers.save_npz('my.state', optimizer)

これを使って予測するのはchainerメモ(その3)で
 




コメントを残す

メールアドレスが公開されることはありません。 * が付いている欄は必須項目です