終点枕崎

にゃーんo(^・x・^)o

バッチ正規化層でのgradの逆伝播

「バッチ正規化 逆伝播」で検索すると、計算グラフを用いた説明ばかり出てきます。数式での証明も欲しいものです。そこで、この記事は、数式を使って導出することを目標とします。一応Chainerの実装 batch_normalization.py で確認しましたが、間違いや誤植などあれば教えていただけると幸いです。

バッチ正規化は、レイヤー間を流れるデータの分布が、適度な広がりと偏りを持つように正規化することを目的とします。具体的には、各データ要素に対し、ミニバッチ全体の平均と分散を求め、データの平均が0、分散が1になるように調整します。

データ要素を表す添字は省略
\(N\) : バッチサイズ
\(x_i\) : 入力。iはミニバッチのi番目のデータであることを表す添字。
\(y_i\) : 出力。iはミニバッチのi番目のデータであることを表す添字。
\( \epsilon \) : 定数。小さい。

\begin{eqnarray}
\overline{x} & = & \frac{1}{N} \sum_{k = 1}^{N} x_k \\
s^2 & = & \frac{1}{N} \sum_{k = 1}^{N} (x_i - \overline{x})^2 \\
\hat{x_i} & = & \frac{x_i - \overline{x}}{ \sqrt{s^2 + \epsilon} } \\
y_i & = & \gamma \hat{x_i} + \beta
\end{eqnarray}

逆伝播では、\( \partial E / \partial {\bf y} \) を受け取って、最終的に\( \partial E / \partial {\bf x} \) を計算します。ただし、誤差関数を \(E\) としました。

バッチ正規化が学習するパラメータは、\( \gamma \) と \( \beta \) です。なので、まずは \( \gamma \) と \( \beta \) についての偏微分を求めます。\( \gamma \) と \( \beta \) はどちらも、変化すると任意の \( y_i \) が変化するため、連鎖律から

\begin{eqnarray}
\frac{\partial E}{\partial \gamma} & = & \sum_{k} \frac{\partial E}{\partial y_k} \frac{\partial y_k}{\partial \gamma} = \sum_{k} \frac{\partial E}{\partial y_k} \hat{x_k} \\
\frac{\partial E}{\partial \beta} & = & \sum_{k} \frac{\partial E}{\partial y_k} \frac{\partial y_k}{\partial \beta} = \sum_{k} \frac{\partial E}{\partial y_k}
\end{eqnarray}

と計算できます。
\( y_i \) と \( \hat{x_i} \) に関する偏微分の関係は、(4)から次のようになります。

\begin{equation}
\frac{\partial E}{\partial \hat{x_i}} = \gamma \frac{\partial E}{\partial y_i}
\end{equation}

あとは、 \( \hat{x_i} \) と \( x_i \) に関する偏微分の関係を求めればOKです。

\begin{equation}
\frac{\partial E}{\partial x_i} = \sum_{k} \frac{\partial E}{\partial \hat{x_k} } \frac{\partial \hat{x_k} }{\partial x_i} = \frac{\partial E}{\partial \hat{x_i} } \frac{\partial \hat{x_i} }{\partial x_i} + \sum_{k \neq i} \frac{\partial E}{\partial \hat{x_k} } \frac{\partial \hat{x_k} }{\partial x_i}
\end{equation}

ここで、\( \frac{\partial}{\partial x_i} (\sqrt{s^2 + \epsilon}) \) をあらかじめ計算しておきます。

\begin{equation}
\frac{\partial }{\partial x_i} (\sqrt{s^2 + \epsilon}) = \frac{s}{\sqrt{s^2 + \epsilon}} \frac{\partial s}{\partial x_i} \\
\end{equation}
(2)を微分して
\begin{eqnarray}
2s \frac{\partial s}{\partial x_i} & = & \frac{1}{N} \biggl[ 2(x_i - \overline{x}) \biggl(1 - \frac{1}{N} \biggr) + \sum_{k \neq i} 2(x_k - \overline{x}) \biggl(-\frac{1}{N} \biggr) \biggr] \nonumber \\
& = & \frac{1}{N} \bigg[ 2(x_i - \overline{x}) - 2\sum_k \frac{x_k - \overline{x}}{N} \bigg] \nonumber \\
& = & \frac{2}{N} (x_i - \overline{x}) \nonumber
\end{eqnarray}
より
$$
\frac{\partial s}{\partial x_i} = \frac{1}{N} \frac{x_i - \overline{x}}{s}
$$
(3)と(9)を使って
\begin{equation}
\frac{\partial}{\partial x_i} (\sqrt{s^2 + \epsilon}) = \frac{\hat{x_i}}{N}
\end{equation}
次に、(3)を微分し、(8)の右辺第1項の \( \partial \hat{x_i} / \partial x_i \)を求めます。途中で(10)を用います。
\begin{eqnarray}
\frac{\partial \hat{x_i}}{\partial x_i} & = & \frac{1}{s^2 + \epsilon} \bigg[ \bigg( 1 - \frac{1}{N} \bigg) \sqrt{s^2 + \epsilon} - (x_i - \overline{x}) \frac{\hat{x_i}}{N} \bigg] \nonumber \\
& = & \frac{1}{\sqrt{s^2 + \epsilon}} \bigg[ \bigg( 1 - \frac{1}{N} \bigg) - \frac{\hat{x_i}^2}{N} \bigg]
\end{eqnarray}
次に、\( k \neq i \) のもと、(3)を微分し、(8)の右辺第2項の \( \partial \hat{x_k} / \partial x_i \)を求めます。
\begin{eqnarray}
\frac{\partial \hat{x_k}}{\partial x_i} & = & \frac{1}{s^2 + \epsilon} \bigg[ - \frac{1}{N} \sqrt{s^2 + \epsilon} - (x_k - \overline{x}) \frac{\hat{x_i}}{N} \bigg] \nonumber \\
& = & \frac{1}{\sqrt{s^2 + \epsilon}} \bigg[ -\frac{1}{N} - \frac{\hat{x_i} \hat{x_k}}{N} \bigg]
\end{eqnarray}

(11)と(12)を(8)に代入し整理すると、
\begin{eqnarray}
\frac{\partial E}{\partial x_i} & = & \frac{1}{\sqrt{s^2 + \epsilon}} \Bigg[ \frac{\partial E}{\partial \hat{x_i}} \bigg[ \bigg( 1 - \frac{1}{N} \bigg) - \frac{\hat{x_i}^2}{N} \bigg]
+ \sum_{k \neq i} \frac{\partial E}{\partial \hat{x_k} } \bigg[ -\frac{1}{N} - \frac{\hat{x_i} \hat{x_k}}{N} \bigg] \Bigg] \nonumber \\
& = & \frac{1}{\sqrt{s^2 + \epsilon}} \Bigg[ \frac{\partial E}{\partial \hat{x_i}} - \frac{1}{N} \bigg[ \sum_k \frac{\partial E}{\partial \hat{x_k}} + \hat{x_i} \sum_k \frac{\partial E}{\partial \hat{x_k}} \hat{x_k} \bigg] \Bigg] \nonumber \\
& = & \frac{\gamma}{\sqrt{s^2 + \epsilon}} \Bigg[ \frac{\partial E}{\partial y_i} - \frac{1}{N} \bigg[ \sum_k \frac{\partial E}{\partial y_k} + \hat{x_i} \sum_k \frac{\partial E}{\partial y_k} \hat{x_k} \bigg] \Bigg] \nonumber
\end{eqnarray}
ただし途中で(7)を使いました。最後に、(5)と(6)を代入すると、
\begin{equation}
\frac{\partial E}{\partial x_i} = \frac{\gamma}{\sqrt{s^2 + \epsilon}} \Bigg[ \frac{\partial E}{\partial y_i} - \frac{1}{N} \bigg[ \frac{\partial E}{\partial \beta} + \hat{x_i} \frac{\partial E}{\partial \gamma} \bigg] \Bigg]
\end{equation}

と比較的簡単な式で表せました。実際、Chainerもこの式で導出しています。まとめると、バッチ正規化レイヤーでのbackwardの演算は、
1. (5), (6)式で \( \partial E / \partial \gamma \) と \( \partial E / \partial \beta \) を求める
2. (13)式で \( \partial E / \partial x_i \) を求める
という流れになります。実装は省略します。