【AICで使う】KL divergence(カルバック-ライブラー情報量)をわかりやすく解説|python

こんにちは、今回はKL divergenceを解説します。

KL divergenceは、2つの確率分布間の相違を測定するために使用され、NLPにおける文書や単語の分布を比較する際に役立ちます。

レベル感としては、統計検定1級でAICの導出に使われる~という文脈で登場します。

【良いモデルとは】AIC(赤池情報量基準)について|R

KL Divergence(カルバック-ライブラー情報量)

Kullback-Leibler (KL) divergence(カルバック・ライブラー発散)は、確率論および情報理論において、二つの確率分布間の差異を定量化するために用いられる尺度です。

この概念は、二つの確率分布 PQ が与えられたとき、分布 Pがどの程度分布 Qから逸脱しているかを測定するために使用されます。

KL divergenceは、情報理論における相対エントロピーとしても知られています。

連続確率分布の場合は、以下のようになります。

DKL(P||Q)=p(x)log(p(x)q(x))dx

交差エントロピーから情報エントロピーを引くことで求められます。

いずれも期待値であり、交差エントロピーからシャノンエントロピーを引いたものは余分な情報量の平均です。

0であることが望ましいですね。

DKL(P||Q)=H(P,Q)H(P)=Ex P[logQ(x)]Ex P[logP(x)]

PとQは確率分布で、p(x)q(x)はそれぞれの確率密度関数ですね。

距離の尺度ですので、非負性があります。

もう少し厳密にいうと、対数をとっている分母分子をひっくり返してみるとわかります。

DKL(P||Q)=p(x)log(p(x)q(x))dx=p(x)(logq(x)p(x))

DKL(P||Q)=p(x)(logq(x)p(x))log(p(x)q(x)p(x))=0

logの凸性ですね。上の不等式はイェンセンの不等式と呼びます。

ただ、距離尺度とは言いつつ、以下のように

「分布Qを使用して分布Pを表現しようとした場合に生じる情報の損失量」と「分布Pを使用して分布Qを表現しようとした場合に生じる情報の損失量」は異なります。

DKL(P||Q)DKL(Q||P)

数学的な分布間の距離(ユークリッド距離とかマンハッタン距離)とは異なる概念であり、情報理論における一方の分布を使用して別の分布の事象をどれだけ「効率的に」説明できるか、という「情報損失」という方が正確です。

情報量|解釈と使い方

エントロピーや情報量がよくわからない方は、こちらから見てください。

よく「情報量が多い」などという言葉がありますが、情報理論において「ある確率変数Xが実現値xをとった時にどれほど、利得があるか」という意味です。

より稀な事象が起きた方が、情報量は大きいので、確率の逆数を使うというのは直感的にわかるはずです。

対数の底に2を選ぶのは慣習的なものですが、実はなんでも良いです。

I(x)=log21p(x)

そして、情報量の期待値をとると、シャノンエントロピーになります。

H(P)=EP(x)=P(x)log2P(x)

さて、シャノンエントロピーの一部を変えてみます。

H(P,Q)=P(x)log2Q(x)

上は交差エントロピーと呼ばれますが、シャノンエントロピーの差分を計算することで、ある分布Qを使用して別の分布Pを表現しようとした場合に生じる情報の損失量を見ることができるのです。

そしてこの差分を表す量こそがKL divergenceなのです。

最尤法との関わり

では、KL情報量が統計においてどう使われていくのかを深掘りしていきます。

込み入っているので、飛ばしていただいても構いません。(下は関連コンテンツです)

【尤度って?】尤度関数と最尤推定量の解説と例題

前提として、AICなどの情報量基準は、一般にモデル選択のための指標として使われます。

なので、先ほどのDKL(P||Q)=p(y)log(p(y)q(y))dyで、PやQといった分布は以下のように解釈することができます。

p(y):知りたい真の分布

q(y):作った数理モデル

DKL(P||Q)=p(y)log(p(y)q(y))dy=Eylogp(y)Eylogq(y)

ただ、真のモデルp(y)は未知なので、情報量を計算することはできません。

なので、左辺のEylogq(y)大小関係をモデルごとに比較してあげれば、情報量の相対比較はできる、というわけです。

Eylogq(y)は平均対数尤度ですが、これは直接計算できるのでしょうか?

答えはNOです。

なぜなら、真の分布p(y)で積分しているからです。

Eylogq(y)=logq(y)dp(y)

ここで使うのが、経験分布関数です。

ここのxiは各データであり、データごとに累積する関数です。

データが増えるほど、真の分布p(x)に寄っていくことが知られており、解析ではp(y)p^n(y)で置き換えてあげることが多いです。

p^n(y)=1nI(xi)

先ほどの式にp^n(y)を代入してあげると、対数尤度になります。

(データがある点で、1nずつq(xi)をかけて足し合わせたものになるので)

Eylogq(y)=nlogq(y)dp^n(y)=i=1nlogq(Xi)

そして、大数の法則により対数尤度は平均対数尤度に確率収束します。

1ni=1nlogq(Xi)Eylogq(y)

最適なパラメータを探索する際には、対数尤度を最大化します。

→これは平均対数尤度を最大化することに繋がり(最尤法)、それは結果的にKL情報量を最小にすることに繋がります。

つまり、最尤法は近似的にKL情報量を最小化していることになります。

NLP(自然言語処理)での使い方|CODE

生成AI、特に自然言語処理(NLP)や画像生成における機械学習モデルでもKL情報量は使われています。

最も簡単な例をコードで説明します。

テキストデータの簡単なアプローチは、単語の出現頻度を使用することです。

例えば、単語の確率分布は以下のように定義できます:

P(w)=w

from collections import Counter
import numpy as np

def calculate_word_frequencies(document):
    word_count = Counter(document)
    total_words = len(document)
    return {word: count / total_words for word, count in word_count.items()}

def kl_divergence(dist_p, dist_q):
    divergence = 0
    for word in dist_p:
        if word in dist_q:
            divergence += dist_p[word] * np.log(dist_p[word] / dist_q[word])
    return divergence

# 簡単な例
doc1 = ["apple", "banana", "apple", "orange", "banana", "apple"]
doc2 = ["apple", "orange", "orange", "cherry", "cherry", "orange"]

dist_p = calculate_word_frequencies(doc1)
dist_q = calculate_word_frequencies(doc2)

# KL divergence
kl_div = kl_divergence(dist_p, dist_q)
kl_div

生成AIモデルの場合(特にGANやVAEなど)、潜在空間の分布も重要です。

これらのモデルはデータを低次元の潜在ベクトルにマッピングしてから、その潜在ベクトルの分布を分析します。

ただ、使い方は同じでモデルが生成したデータの分布が実際のデータの分布にどれだけ近いかを測定する(情報損失として)ために使用されます

では変分オートエンコーダについて少し解説します。

VAEでKL情報量はどう使われる?

VAEでは、潜在変数の事後分布を近似するために、KL情報量を最小化することを目指します。

標準的なVAEの文脈でのKL divergenceは、次のように表現されます

DKL(qφ(z|x)||p(z))

qφ(z|x)はエンコーダによって提供される潜在変数zの条件付き分布です。

p(z)は潜在変数の事前分布ですね。通常は標準正規分布が使用されます。

エンコーダが提供する事後分布が事前分布からどれだけ離れているのかを測定します。

強めの仮定ですが、事後分布が多変量正規分布になる場合を想定してみます。

DKL(qφ(z|x)||p(z))=12(tr()+μTμklogdet())

tr()は共分散行列の対角成分の和

μTμは平均ベクトルの二乗ノルム

kは潜在空間の次元数

logdet()は共分散行列の行列式の自然対数ですね。

KL情報量の項は、例えばVAEの変分下界の一部として使用されます。

近似推論分布と真の潜在変数の事前分布との間のカルバック・ライブラー発散(KL発散)を表し、近似推論の精度を示しています。

より詳しく知りたい方は以下のコンテンツをご覧ください。

このプロセスによって、データの低次元の潜在表現を効果的に抽出し、その分布を分析するのに役立つということですね。

FOLLOW ME !