代表的な損失関数について

Python, ML22 January 2021

損失関数とは

NN(ニューラルネットワーク)においては、ある予測がどれくらい正しいかを判断するときには損失関数を用いる。
これは、大雑把にいうと正解に対して予測がどれだけ外れているか?を示す指標になっている。

なぜ、単純にどれだけ正解したか?といういわゆる精度を指標としないかという点は、
一般的に精度を指標として学習をしようとすると、多くの場合に勾配消失が起こり学習がストップしてしまうから。

MSE(平均二乗誤差)

Mean Square Errorの略。
以下の式で定義されます。

$$ E = \frac{1}{2} \sum_{k=1}^{n} (y_k - t_k)^2 $$

$E$: 損失
$y_k$: 予測
$t_k$: 正解

特に、回帰問題において最も一般的に使用される損失関数

$\frac{1}{2}$はこれがなくても、相対的な評価としては変わらないが、全体の計算が楽になるのでついている。

クロスエントロピー

以下の式で定義されます。

$$ E = -y \log(t) -(1-y) \log(1-t)$$

特に、分類問題において最も一般的に使用される損失関数

クロスエントロピーは一般に、2つの確率分布がどれだけ離れているか(異なっているか)を示すが、それを応用して使っている。
つまり、分類のソフトマックスの出力を確率分布とし、その正解との差を見ている

また、クラス分類が2つの2値分類の場合はバイナリクロスエントロピーと呼ばれる。

エントロピーの理論的な話はこちら

tags: Python, ML