機械学習モデルの汎化を妨げる過学習とは? - 旗揚げ画像の二値分類を例に

目的

この記事では機械学習に関連する以下のトピックについて解説し、機械学習を学んだことのない方が、汎化と過学習に関する観念的な理解を深めることを目指します。

  • 予測モデルの汎化能力
  • 汎化能力の向上を妨げる過学習という現象

観念的なわかりやすさを重視しているため、細部に誤りや不正確な点があるかもしれません。
不適切な点がございましたら、ご教示頂ければ幸いです。

以下、この記事では画像の二値分類を例に取ります。
ですが、汎化と過学習に関する知識の適用範囲は画像認識の分野にとどまりません。

  • 予測モデルにおいては、データ誤差を抑えることが本質的な課題であり、過学習はその目標達成を妨げる
  • 実際に機械学習を行う際、適切なデータサイズやエポック数(学習回数)を決めるために過学習に関する知識が必要である

以上のことから、モデルの汎化能力と過学習に関する理解は機械学習について学ぶ上で十分に汎用的であり、初学者が理解を深める段階から実用レベルに至るまで、幅広く役に立ちます。

旗揚げ画像の二値分類

以降の議論では、画像の二値分類をテーマとし、

  • 汎化精度を上げるとはどういうことか?
  • 過学習とはどんな現象のことか?

を説明していきます。

ここでは二値分類の対象として、以下のような小さなサイズの「旗揚げ画像」を取り上げます。

図1: 入力例

受け付ける入力は8*8ピクセルで、情報がRGB形式で与えられる任意のものとします。 (よってこの記事に示す入力例よりも、遥かに多様な入力が考えられることに注意してください。)

図1のような画像に「赤を揚げている」「白を揚げている」というラベルを付与したものを訓練データとし、教師つき学習を行うこととします。

図2: いろいろな入力例

学習の目的は、未知の入力画像に対し、赤と白のどちらを揚げている画像かを適切に判定出来るようになることです。

補足:二値分類とは

二値分類とは、与えられたデータを、事前に与えられた二つのラベルに振り分けることを指します。

画像認識においては、以下のような例が有名です。

  • 男性か女性かを見分ける
  • 犬か猫かを見分ける

画像認識以外の二値分類として、以下のような例が有名です。

  • メールをスパムと非スパムに区別する
  • ある患者のデータから、特定の疾病を持つかどうかを区別する

二値分類における主流のモデル - 事後確率の計算

ここで取り上げている二値分類や、より多くのクラスに分ける他クラス分類においては、入力に対し「事後確率」を計算するのが主流の方法です。

図3: 各入力に対する事後確率の例

ここで「事後確率」とは「データが与えられた後の確率」という意味であり、「ある特定のデータが、赤であるか白であるかについて、モデルによって推定された結果」であると解釈して頂いて問題ありません。学習途中においても、事後確率を算出することが可能です。

ここで機械学習モデルの中身についてはブラックボックスとしていますが、畳み込みニューラルネットワーク(CNN)が古典的なモデルとして有名です。

Convolutional Neural Networkとは何なのか

学習の枠組み - 最尤推定

各入力データに対して事後確率を求めるモデルでは、学習の各段階において尤度(likelihood)を求めることが出来ます。

尤度は、それぞれの入力データが持つ「正解のラベル」に対し、その時点のモデルが算出した事後確率の積を取ったものと考えられます。

図4: 尤度の計算

ここでは二つの入力データに対する尤度を計算しており、0.72 = 72%という値が算出されています。

事後確率を求める機械学習モデルは、学習を繰り返すことで尤度を増加させるよう設計されています。

ここでは訓練データに対して尤度を高める過程を適合(fitting)といい、下に示す図5は学習の結果、訓練データに極めて適合した状態を模式的に表しています。

図5: 訓練データへの適合度が高い状態

各入力に対し、正解のラベルに1に近い事後確率を割り振っていますね。

図5はいわば、学習によって、訓練データに対して百発百中の精度を持つようになった状態です。

これが望ましいことかどうかを、以下の段落で見ていきます。

未知の入力に対する推定 - 汎化能力

図4、5だけを見ると、訓練データへの適合度が高い図5の状態の方が望ましいように見えます。

しかし、このモデルの目的は「未知の入力画像に対し、赤と白のどちらを揚げている画像かを適切に判定出来るようになること」なのでした。

ですので、訓練データに含まれていない画像、すなわち「テストデータ」に対して推定を行ってみるべきでしょう。

図6: 未知の入力への適用

図6は、図4のモデルに対して、 $ x' $ という未知のテストデータを入力した結果を示しています(もちろん、実際のテストでは複数のテストデータを入力します)。

正解のラベル(白)に対して、赤よりも高い事後確率を割り振っていることがわかります。

このように、訓練データ以外のデータに対しても、正しく推定を行うことの出来る能力を汎化能力と呼びます。

この言葉を使えば、「機械学習モデルの精度を上げること=汎化能力を高めること」、と言うことが出来ます。

一方、訓練データに対しては百発百中だった図5のモデルはどうでしょうか。

図7: 過学習が疑われるケース

図7は、図5のモデルに対してテストデータ $ x' $ を入力したら、正解のラベル(白)よりも不正解のラベル(赤)に高い事後確率が割り振られてしまった状態を示しています。

これは機械学習が汎化能力をもつことの妨げとなる、過学習(overfitting)の問題を端的に表したものとなっています。

過学習(overfitting)とは

過学習が起こる原因は様々ありますが、その一つはデータサイズが不十分であり、データの持つ非本質的な「癖」まで学習してしまうことです。

例えばこれまでの図において、訓練データ $ x_2 $ として登場していた旗揚げモデルの男の子を「Aくん」としてみましょう。

図8: Aくんに着目

ここでAくんは複数の訓練データにおいて登場しているとします。

図9: Aくんを映した様々なデータ

Aくんは赤色が好きなので、ついつい赤を揚げてしまうという「癖」があります。当然、これらの訓練データには赤のラベルが付与されています。

一方で、Aくんを映した訓練データには他にも特徴があります。

図10: 本質的でないデータの特徴

Aくんは背が小さいので、画像の上のほうに空白が空いてしまうのです。(ここで空白とは、ピクセル間のRGB値の変化が穏やかな領域のことと考え、またAくんの他に背の低い旗揚げモデルを映したデータは存在しないものとします。)

以上の議論をまとめると、今回使用した訓練データには、

「上のほうに空白が開いている画像は、赤を上げている」

という非本質的な「癖」があるといえます。

実は、図7のような出力がなされた理由は、機械学習モデルが入力データを入念に学習しすぎることよって、この特徴を覚え込んでしまったことだったのです。

図7: 過学習が疑われるケース

その結果として、Aくんと同様に「背が低い」未知の入力 $ x' $ に対して、赤を揚げているという誤った推論をしてしまったのですね。

以上は、データの偏りによる過学習について述べました。他には、実際の振る舞いに比べてモデルの自由度が高すぎること等が、過学習の原因として考えられます。

過学習を防ぐための代表的な手法

まず、十分なデータ数を用意することが重要です。

データを多様化しておくことで、先のように非本質的な癖を持つことを防ぐことが出来ます。

また、訓練データに対する誤差(二値分類においては対数尤度の総和)を訓練誤差と呼びますが、訓練誤差ばかりが小さくなり、テスト誤差が小さくならない(または大きくなっている)場合は、学習の効果が出ていないことになります。

図11: 学習曲線

上図は学習曲線(learning curve)と呼ばれるグラフで、エポック数(訓練データに対する学習回数)と誤差の関係を表しています。

(二値分類においては、対数尤度の総和にマイナスを付けた値を誤差として用いれば、尤度が1に近づくほど誤差が0に近づき、尤度が低くなると無限大に発散するような誤差関数を実現出来ます。)

上図で重要なのは、エポック数が増えるほど(訓練をすればするほど)訓練データに対する誤差は小さくなっているが、テストデータに対する誤差が次第に剥離し、やがて増大し始めているという点です。

このような時は学習を早期打ち切りすることによって、過学習を防ぐことが出来ます。

他に、モデルの自由度が高すぎることによる過学習を防ぐために、特にニューラルネットワークにおいて有用とされている手法として正則化(regularization)やドロップアウトがあります。これについては今後の記事で紹介していきます。

最近の記事タグ

関連記事

\(^▽^*) 私たちと一緒に働いてみませんか? (*^▽^)/

少しでも興味をお持ちいただけたら、お気軽に、お問い合わせください。

採用応募受付へ

(採用応募じゃなく、ただ、会ってみたいという方も、大歓迎です。)