ロジスティック回帰における交差エントロピー誤差が凸関数であることの証明

観測値(xi,yi)...(xN,yN)(x_i,y_i)...(x_N,y_N)R×{1,1}\in\mathbb{R}\times\{-1,1\} から、尤度i=1N11+eyi(β0+xiβ)\prod_{i=1}^{N}\frac{1}{1+e^{-y_i(\beta_0+x_i\beta)}}を最大にする、もしくは、その対数のマイナスをとった値

L(β0,β)=i=1Nln(1+vi),vi=eyi(β0+xiβ)L(\beta_0,\beta)=\sum_{i=1}^{N}\ln(1+v_i), v_i=e^{-y_i(\beta_0+x_i\beta)}

が凸関数であることを示す。

簡単にするため、

パラメータをww(1+vi)(1+v_i)ϕ(wTxi)\phi(w^Tx_i)とおく。(ただしyi=1y_i=1)。また、yi=1y_i=-1の場合は、1ϕ(wTxi)1-\phi(w^Tx_i)とおくことができる。

L(y=tiw)=i=1Nlnϕ(wTxi)ti{1ϕ(wTxi)}1ti=i=1Ntilnϕ(wTxi)+(1ti)ln(1ϕ(wTxi))\begin{aligned} L(y=t_i|w)&=\sum_{i=1}^{N}\ln\phi(w^Tx_i)^{t_i}\{1-\phi(w^Tx_i)\}^{1-t_i} \\ &=\sum_{i=1}^{N}{t_i}\ln\phi(w^Tx_i)+(1-t_i)\ln(1-\phi(w^Tx_i)) \end{aligned}

ここでtを導入しているのは場合分けによって前半の項と後半の項のどちらかが残るようにしているからです。コスト関数を足し算で計算できるようになりました。

式を見やすくするために、ρi=ϕ(wTxi)\rho_i= \phi(w^Tx_i) とおく

L(w)w=L(w)yyρρw=i=1N[tiρiρi(1ρi)xi+1ti1ρi{ρi(1ρi)}xi]=i=1N{ti(1ρi)xi(1ti)ρixi}=i=1N(tixntnρixiρixi+tiρixi)=i=1N(tiρi)xi\begin{aligned} \frac{\partial L(w)}{\partial w}&=\frac{\partial L(w)}{\partial y}\frac{\partial y}{\partial \rho}\frac{\partial \rho}{\partial w} \\ &=\sum_{i=1}^{N} \left[\frac{t_i}{\rho_i}{\rho_i}(1-{\rho_i})x_i + \frac{1-t_i}{1-\rho_i}\{{-\rho_i(1-\rho_i)}\}x_i \right] \\ &=\sum_{i=1}^{N}\{t_i(1-\rho_i)x_i-(1-t_i){\rho_i}{x_i}\} \\ &=\sum_{i=1}^{N}(t_ix_n-t_n\rho_ix_i-\rho_ix_i+t_i\rho_ix_i) \\ &=\sum_{i=1}^{N}(t_i-\rho_i)x_i \end{aligned}

(3)にて合成微分を利用している。

(4)の計算にはロジスティック回帰の微分公式を利用している。

2階微分をすることでヘッセ行列を求める。

2L(w)wwT=L(w)yyρρw=yyρρwi=1N(ρiti)xij=i=1Nxiρi(1ρi)xij=i=1Nρi(1ρi)xijxikT\begin{aligned} \frac{\partial^2{L(w)}}{\partial w\partial{w^T}}&=\frac{\partial L(w)}{\partial{y}}\frac{\partial y}{\partial \rho}\frac{\partial \rho}{\partial w} \\ &=\frac{\partial}{\partial y}\frac{\partial y}{\partial \rho}\frac{\partial \rho}{\partial w}\sum_{i=1}^{N}(\rho_i-t_i)x_{ij} \\ &=\sum_{i=1}^{N}x_i\rho_i(1-\rho_i)x_{ij} \\ &=\sum_{i=1}^{N}\rho_i(1-\rho_i)x_{ij}x_{ik}^T \end{aligned}

2階偏微分であるヘッセ行列の定義より

my image

つまり、以下のように展開される

2L(w)wwT=(ρi(1ρi)xi1xi1ρi(1ρi)xi1xiDρi(1ρi)xiDxi1ρi(1ρi)xiDxiD)\frac{\partial^2 L(w)}{\partial w\partial w^T}= \begin{pmatrix} \sum \rho_i(1-\rho_i)x_{i1}x_{i1} & \cdots & \sum \rho_i(1-\rho_i)x_{i1}x_{iD}\\ \vdots & \ddots & \vdots \\ \sum \rho_i(1-\rho_i)x_{iD}x_{i1} & \cdots & \sum \rho_i(1-\rho_i)x_{iD}x_{iD} \end{pmatrix}

ここで、

W=(ρ1(1ρ1)00ρN(1ρN)),X=(x11x1DxN1xND)W= \begin{pmatrix} \rho_1(1-\rho_1) & \cdots & 0\\ \vdots & \ddots & \vdots \\ 0 & \cdots & \rho_N(1-\rho_N) \end{pmatrix} , X= \begin{pmatrix} x_{11} & \cdots & x_{1D}\\ \vdots & \ddots & \vdots \\ x_{N1} & \cdots & x_{ND} \end{pmatrix}

と定義すると、

H=2L(w)wwT=XTWXH=\frac{\partial^2 L(w)}{\partial w\partial w^T}=X^TWX

とヘッセ行列を表現することができる。

ヘッセ行列における凸性の判定条件は、「任意の点 x ∈ O でヘッセ行列 ∇2f(x) が正定値行列であること」

つまり、任意のベクトルuに対して、uTHu>0u^THu>0を示せば良い。

uTHu=uTXTWXu=(W12uX)T(W12uX)=(W12uX)2>0u^THu=u^TX^TWXu=(W^{\frac{1}{2}}uX)^T(W^{\frac{1}{2}}uX)= ||(W^{\frac{1}{2}}uX)^2||>0

以上より、ロジスティック回帰における交差エントロピー誤差が凸関数であることを証明した