Smile Engineering Blog

ジェイエスピーからTipsや技術特集、プロジェクト物語を発信します

PyTorchで学ぶ『平均』と『分散』と『BatchNormalization』

ところでBatchNormalizationってなんだっけ?

深層学習でモデル構築をしていると、よくお見かけするのがBatchNormalizationと呼ばれるもの。

・・・はい。よく見かけますよね?

・・・・・・・・・。

ところで、結局それってなんでしたっけ?

「平均を0、分散を1」という説明はよく聞きますけど、それは結局どういうことなのか?
というのが今日のお題です。

ちなみにですけど、この記事を書いてる人間は、本ブログで度々統計のことについて語っていらっしゃる統計学スペシャリストとは別の人間です。どちらかというと統計学は素人に近い(?)ので、『え、そこから???』などと言わず、温かい目で読んでいただけたらと思います。
(使用しているのがNumpyではなくPyTorchという辺りがお察しですね・・・)

とりあえずテスト結果を例に。。

このクラスには、6人の生徒がいて、国、数、英、社、理の5科目の試験を受けました。
その試験結果は、こんな具合です。

A 99 10 85 90 1
B 99 25 75 90 1
C 99 30 21 90 1
D 99 15 80 90 1
E 99 12 84 91 1
F 99 20 85 91 99

・・・はい、確実に文系クラスですね!(という話は置いときまして)

  • 国語は全員が99点
  • 社会は4人が90点、2人が91点
  • 理科は5人が1点、1人が99点

まずはここから、平均と分散を求めたいと思います。

平均

まずは平均から。平均については、説明の必要がないと思います。
PyTorchで平均を求める場合は、torch.mean() というAPIを使用します。

# 点数定義             国  数   英  社  理
x_in = torch.tensor([[99, 10, 85, 90, 1],
                     [99, 25, 75, 90, 1],
                     [99, 30, 21, 90, 1],
                     [99, 15, 80, 90, 1],
                     [99, 12, 84, 91, 1],
                     [99, 20, 85, 91, 99]],
                    dtype=torch.float)

print("平均: ", torch.mean(x_in, 0))

結果:

平均:  tensor([99.0000, 18.6667, 71.6667, 90.3333, 17.3333])

ちゃんと平均が算出されてることがわかるかと思います。

分散

続いて、分散を求めたいと思います。
分散というのは、平均値からのばらつき具合です。
分散の値が小さければ、データの値が平均値付近に集中していることになります。

PyTorchでは torch.var() というAPIを使用します。

print("分散: ", torch.var(x_in, 0))

結果:

分散:  tensor([0.0000e+00, 6.0667e+01, 6.3107e+02, 2.6667e-01, 1.6007e+03])

表記が少しわかりにくいかもしれませんが(値が大きすぎですね…)、国語に関して言うと
全員が99点だったため、分散=0
となっています。平均が99点で、全員がその平均点を取ったため、分散が0という具合ですね。

その他の科目を見ると、点数のばらつきの小さい数学と社会は、分散の値も小さくなっています。逆に、理科はばらつきが大きいため、分散の値も大きな値となっていますね。

え、分散の値がおかしい!??

実はPyTorchで求める分散は、正確には不偏分散と呼ばれるものです。
分散といった場合、多くのソフトウェアでは不偏分散のことを差すことの方が多いようですが、Numpyはデフォルトが母分散と呼ばれるものになっており、結果としてPyTorchで求めた分散の結果と、値が異なってしまいます。

Numpyを使用して不偏分散を求める場合は、下記の手法で求めることができます。

# PyTorchのnumpy()を利用してNumpy型へ変換
x = x_in.numpy()
# "ddof=1" とすることで、不偏分散を求める ※デフォルトは0 = 母分散
print(x.var(0, ddof=1))

Batch Normalization

それでは最後に、Batch Normalizationを見ていきます。

Batch Normalization は上述の通り、「平均を0、分散を1」としたものです。
これは2015年に、Sergey Ioffe と Christian Szegedy によって提案された深層学習の手法です。

arxiv.org

Batch Normalization を用いることによって、勾配消失や発散をさせることなく、学習を安定させることができるとされています。特にGANなどではその威力が絶大で、必ず用いるべき!などと言われていたりしますね。

それでは早速PyTorchで実装してみましょう。PyTorchでは torch.nn.BatchNorm1d() などを使用します。(PyTorchだと1dや2dで分かれていたりするんですね…)

# BatchNormalization 式の定義 ※5科目
m = nn.BatchNorm1d(5)

print("-----")
print("BatchNormalization: ")
print(m(x_in))

結果:

-----
BatchNormalization: 
tensor([[ 0.0000, -1.2189,  0.5814, -0.7071, -0.4472],
        [ 0.0000,  0.8907,  0.1454, -0.7071, -0.4472],
        [ 0.0000,  1.5939, -2.2094, -0.7071, -0.4472],
        [ 0.0000, -0.5157,  0.3634, -0.7071, -0.4472],
        [ 0.0000, -0.9376,  0.5378,  1.4142, -0.4472],
        [ 0.0000,  0.1875,  0.5814,  1.4142,  2.2361]],
       grad_fn=<NativeBatchNormBackward>)

国語はそもそも全員が平均点だったため、Batch Normalizationの結果も全員0となります。

  • 平均点に近い = Batch Normalizationの結果も0付近
  • 平均点から遠い = 遠い分だけプラス or マイナス方向へシフト

そう考えるとおよそこの結果も良いように思え・・・ないですね、はい。(ぇ

上の結果、何かがおかしい!?

深層学習を常日頃から触っている人であれば、その問題点はすぐに気づくかと思いますが、実際上記の結果はかなり問題が生じています。最も顕著なのは、社会の Batch Normalization の結果ですね。

なぜ90点と91点という1点差でしかないのに、ここまで値が変わるのか!?と。

理由としてはこんなのが考えられます。。
そもそも入力となる値が大きすぎます!!

注:そもそも6人分の社会の平均点、分散しか見ていないからで、値としては問題なしですね。そもそも課題設定さえ不明!というのもありますけど、ここではあくまで説明用としておきます。

深層学習の入力で、ふた桁のint型の値を入力として使用することはほぼありません。(いや実際はあるのかもしれませんが私はその手法を存じません・・・
なので、Batch Normalizationを使用する前に、データの正規化を行う必要がありそうです。

入力を1/10000にしてみる

print("BatchNormalization(1/10000): ")
x = x_in / 10000
print(m(x))

結果:

BatchNormalization(1/10000): 
tensor([[ 0.0000, -0.2674,  0.3413, -0.0105, -0.3381],
        [ 0.0000,  0.1954,  0.0853, -0.0105, -0.3381],
        [ 0.0000,  0.3497, -1.2971, -0.0105, -0.3381],
        [ 0.0000, -0.1131,  0.2133, -0.0105, -0.3381],
        [ 0.0000, -0.2057,  0.3157,  0.0211, -0.3381],
        [ 0.0000,  0.0411,  0.3413,  0.0211,  1.6905]],
       grad_fn=<NativeBatchNormBackward>)

例として、入力データを全て 1/10000 にしてみました。(例:99点→0.0099点)
※ここもあくまで説明用。実際は log をとるなど、そのような手法を使用するケースが多いと思います。

いかがでしょうか?
社会の点数のばらつき具合を考慮すると、多少は改善されたように感じます。
その他も、概ね・・・・・本当に大丈夫なのかこれ(汗)

実際のところはもっとまともな正規化手法がいくらでもあります。そもそも Batch Normalization も "Normalization" というその名の通り、正規化手法のひとつです。
ただ、今回は極端な実験をしてみたかったため、さすがにデータのばらつきが酷すぎたかもしれません。

(おまけ)入力を1/100にしてみる

ちなみに、これを試す前に入力データを 1/100 とか 1/1000 などにして試していたのですが、ほぼほぼ効果が出ず。。。(元の値とほとんど変わらなかった)

print("BatchNormalization(1/100): ")
x = x_in / 100
print(m(x))

結果(元の結果と何も変わってない!?):

BatchNormalization(1/100): 
tensor([[ 0.0000, -1.2177,  0.5814, -0.5872, -0.4472],
        [ 0.0000,  0.8899,  0.1453, -0.5872, -0.4472],
        [ 0.0000,  1.5924, -2.2092, -0.5872, -0.4472],
        [ 0.0000, -0.5152,  0.3634, -0.5872, -0.4472],
        [ 0.0000, -0.9367,  0.5378,  1.1744, -0.4472],
        [ 0.0000,  0.1873,  0.5814,  1.1744,  2.2360]],
       grad_fn=<NativeBatchNormBackward>)

この辺り、正解が見えにくいところが、深層学習の難しいところですよね、はい。

まとめ

深層学習で用いられる正規化手法の一つ、Batch Normalization について見ていきました。
いかがでしたでしょうか。
ざっくりとした説明となってしまいましたが、イメージだけでも掴んでいただけると幸いです。

ソースコード

最後に、今回紹介したソースコード全文を載せておきます。

import torch
import torch.nn as nn

# 点数定義             国  数   英  社  理
x_in = torch.tensor([[99, 10, 85, 90, 1],
                     [99, 25, 75, 90, 1],
                     [99, 30, 21, 90, 1],
                     [99, 15, 80, 90, 1],
                     [99, 12, 84, 91, 1],
                     [99, 20, 85, 91, 99]],
                    dtype=torch.float)

print("平均: ", torch.mean(x_in, 0))
print("分散: ", torch.var(x_in, 0))

# BatchNormalization 式の定義 ※5科目
m = nn.BatchNorm1d(5)

print("-----")
print("BatchNormalization: ")
print(m(x_in))

print("-----")
print("BatchNormalization(1/10000): ")
x = x_in / 10000
print(m(x))