Backward pass for Batch Normalization

The computational graph for batch norm can be simplified to:

[Figure: simplified computational graph for batch norm]

The input vector $x$ is used three times: once to compute the mean, once to compute the standard deviation, and once to compute the normalized values. $\mu$ is used once to compute the standard deviation and $n$ times to compute the normalized values. $\sigma$ is used $n$ times to compute the normalized values. Going backwards, we need to add up the gradient contributions from all of these uses.
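To make the three uses of $x$ concrete, here is a minimal forward-pass sketch in plain Python. It assumes a single 1-D batch, the population standard deviation (dividing by $n$), and no $\epsilon$ or learned scale/shift; the function name `batchnorm_forward` is hypothetical.

```python
import math


def batchnorm_forward(x):
    """Hypothetical helper: normalize a 1-D list and return (x_hat, mu, sigma).

    Assumes population sigma (divide by n), no epsilon, no learned gamma/beta.
    """
    n = len(x)
    # Use 1 of x: x -> mu
    mu = sum(x) / n
    # Use 2 of x and use 1 of mu: (x, mu) -> sigma
    sigma = math.sqrt(sum((xi - mu) ** 2 for xi in x) / n)
    # Use 3 of x; mu and sigma are each used n times: (x, mu, sigma) -> x_hat
    x_hat = [(xi - mu) / sigma for xi in x]
    return x_hat, mu, sigma


x_hat, mu, sigma = batchnorm_forward([1.0, 2.0, 4.0, 7.0])
print(mu, sigma, x_hat)
```

For this input $\mu = 3.5$ and $\sigma = \sqrt{5.25}$. Caching `mu` and `sigma` in the forward pass is what lets a backward pass reuse them instead of recomputing them.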

So we need to compute:

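$$
\frac{\partial \frac{x-\mu}{\sigma}}{\partial x}
= \underbrace{\frac{\partial \frac{x-\mu}{\sigma}}{\partial \sigma}\frac{\partial \sigma}{\partial x}}_{\text{standard deviation}}
+ \underbrace{\left.\frac{\partial \frac{x-\mu}{\sigma}}{\partial x}\right|_{\mu,\sigma}}_{\text{inputs}}
+ \underbrace{\left(\frac{\partial \frac{x-\mu}{\sigma}}{\partial \mu} + \frac{\partial \frac{x-\mu}{\sigma}}{\partial \sigma}\frac{\partial \sigma}{\partial \mu}\right)\frac{\partial \mu}{\partial x}}_{\text{mean}}
\tag{1}
$$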

Also, keep in mind that what we ultimately want is $\partial L/\partial x$, so we need to multiply by the upstream gradient as per the chain rule:

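$$
\frac{\partial L}{\partial x} = \frac{\partial L}{\partial \frac{x-\mu}{\sigma}}\,\frac{\partial \frac{x-\mu}{\sigma}}{\partial x}
$$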

We'll compute the gradients in the order they appear, left to right, in (1).

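$$
\frac{\partial \frac{x-\mu}{\sigma}}{\partial \sigma} = \sum_i^n -\frac{x_i-\mu}{\sigma^2} \tag{2}
$$

$$
\frac{\partial \sigma}{\partial x}
= \frac{1}{2n\sigma}\frac{\partial \sum_i^n (x_i-\mu)^2}{\partial x}
= \frac{1}{2n\sigma}\frac{\partial \left(\sum_i^n x_i^2 + \sum_i^n \mu^2 - 2\sum_i^n x_i\mu\right)}{\partial x}
= \frac{1}{2n\sigma}\,2(x-\mu)
= \frac{x-\mu}{n\sigma} \tag{3}
$$

$$
\frac{\partial \frac{x-\mu}{\sigma}}{\partial x} = \begin{bmatrix} \frac{1}{\sigma} & \cdots & \frac{1}{\sigma} \end{bmatrix} \tag{4}
$$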
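Equation (3) is easy to sanity-check numerically. The sketch below (plain Python, population $\sigma$, hypothetical helper names) compares $(x_i-\mu)/(n\sigma)$ against central finite differences. The numeric version recomputes $\mu$ from the perturbed input, so the match also relies on $\partial\sigma/\partial\mu = 0$.

```python
import math


def sigma_of(x):
    """Population standard deviation: sqrt(sum((x_i - mu)^2) / n)."""
    n = len(x)
    mu = sum(x) / n
    return math.sqrt(sum((xi - mu) ** 2 for xi in x) / n)


def dsigma_dx_analytic(x):
    """Equation (3): d sigma / d x_i = (x_i - mu) / (n * sigma)."""
    n = len(x)
    mu = sum(x) / n
    sigma = sigma_of(x)
    return [(xi - mu) / (n * sigma) for xi in x]


def dsigma_dx_numeric(x, h=1e-6):
    """Central differences; mu is recomputed from the perturbed x."""
    grads = []
    for i in range(len(x)):
        x_plus = list(x)
        x_plus[i] += h
        x_minus = list(x)
        x_minus[i] -= h
        grads.append((sigma_of(x_plus) - sigma_of(x_minus)) / (2 * h))
    return grads


x = [1.0, 2.0, 4.0, 7.0]
analytic = dsigma_dx_analytic(x)
numeric = dsigma_dx_numeric(x)
print(max(abs(a - b) for a, b in zip(analytic, numeric)))
```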

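Thankfully, $\partial\sigma/\partial\mu$ evaluates to 0, because $\partial\sigma/\partial\mu = -\frac{1}{n\sigma}\sum_i^n (x_i-\mu)$ and deviations from the mean sum to zero. That saves us from applying the chain rule many times, so we only need 2 more derivatives: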

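$$
\frac{\partial \frac{x-\mu}{\sigma}}{\partial \mu} = \sum_i^n -\frac{1}{\sigma} \tag{5}
$$

$$
\frac{\partial \mu}{\partial x} = \begin{bmatrix} \frac{1}{n} & \cdots & \frac{1}{n} \end{bmatrix} \tag{6}
$$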

Now plug (2) through (6) back into (1), multiply by the upstream gradient, and factorize/simplify.
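A minimal sketch of that assembly in plain Python, assuming population $\sigma$, no $\epsilon$, and an upstream gradient $\partial L/\partial \hat{x}_j$ that arrives as a list. In code the upstream gradient sits inside the sums of (2) and (5), since each $\hat{x}_j$ carries its own $\partial L/\partial \hat{x}_j$. The function names and the test loss $L = \sum_j w_j \hat{x}_j$ are hypothetical, chosen only so that finite differences can check the result.

```python
import math


def batchnorm_backward_staged(x, dx_hat):
    """Hypothetical helper: dL/dx by summing the three paths of equation (1).

    dx_hat[j] is the upstream gradient dL/dx_hat_j.
    Assumes population sigma, no epsilon, no gamma/beta.
    """
    n = len(x)
    mu = sum(x) / n
    sigma = math.sqrt(sum((xi - mu) ** 2 for xi in x) / n)

    # Equation (2) weighted by the upstream gradient:
    # dL/dsigma = sum_j dL/dx_hat_j * (-(x_j - mu) / sigma^2)
    dsigma = sum(g * -(xj - mu) / sigma ** 2 for g, xj in zip(dx_hat, x))

    # Equation (5) weighted by the upstream gradient; the extra
    # dL/dsigma * dsigma/dmu term is omitted because dsigma/dmu = 0.
    dmu = sum(g * (-1.0 / sigma) for g in dx_hat)

    return [
        g / sigma                             # inputs path, (4)
        + dsigma * (xi - mu) / (n * sigma)    # standard-deviation path, (3)
        + dmu / n                             # mean path, (6)
        for g, xi in zip(dx_hat, x)
    ]


def loss(x, w):
    """Hypothetical test loss L = sum_j w_j * x_hat_j, so dL/dx_hat = w."""
    n = len(x)
    mu = sum(x) / n
    sigma = math.sqrt(sum((xi - mu) ** 2 for xi in x) / n)
    return sum(wj * (xj - mu) / sigma for wj, xj in zip(w, x))


x = [1.0, 2.0, 4.0, 7.0]
w = [0.3, -1.2, 0.5, 2.0]  # arbitrary upstream gradient dL/dx_hat
analytic = batchnorm_backward_staged(x, w)

# Central finite differences of the test loss, one input at a time
h = 1e-6
numeric = []
for i in range(len(x)):
    x_plus = list(x)
    x_plus[i] += h
    x_minus = list(x)
    x_minus[i] -= h
    numeric.append((loss(x_plus, w) - loss(x_minus, w)) / (2 * h))

print(analytic)
print(numeric)
```

The two printed lists should agree to roughly six decimal places, and the gradient entries sum to zero because shifting every $x_i$ by the same constant leaves $\hat{x}$ unchanged.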

Fool's way (T.T)

Say we expand the normalized vector like so:

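$$
\hat{x} = \begin{bmatrix}
\dfrac{x_0 - \frac{\sum_m^n x_m}{n}}{\sqrt{\frac{\sum_k^n \left(x_k - \frac{\sum_m^n x_m}{n}\right)^2}{n}}}
& \cdots &
\dfrac{x_n - \frac{\sum_m^n x_m}{n}}{\sqrt{\frac{\sum_k^n \left(x_k - \frac{\sum_m^n x_m}{n}\right)^2}{n}}}
\end{bmatrix}
$$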

Then we can compute the gradient as

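$$
\frac{\partial L}{\partial x_i} = \sum_j^n \frac{\partial \hat{x}_j}{\partial x_i}\frac{\partial L}{\partial \hat{x}_j} \tag{7}
$$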

We're going to use the quotient rule to compute $\partial \hat{x}_j/\partial x_i$. The derivative of the standard deviation is particularly bashy, so I'll include it here; the rest should be easy.

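$$
\begin{aligned}
\frac{\partial \sigma}{\partial x_i}
&= \frac{1}{2\sigma n}\frac{\partial \sum_k^n \left(x_k - \frac{\sum_m^n x_m}{n}\right)^2}{\partial x_i} \\
&= \frac{1}{2\sigma n}\left(\frac{\partial \sum_k^n x_k^2}{\partial x_i} + \frac{\partial \sum_k^n \left(\frac{\sum_m^n x_m}{n}\right)^2}{\partial x_i} - 2\frac{\partial \sum_k^n x_k \frac{\sum_m^n x_m}{n}}{\partial x_i}\right) \\
&= \frac{1}{2\sigma n}\left(2x_i + \frac{1}{n}\frac{\partial \left(\sum_m^n x_m\right)^2}{\partial x_i} - \frac{2}{n}\frac{\partial \sum_k^n x_k \sum_m^n x_m}{\partial x_i}\right) \\
&= \frac{1}{2\sigma n}\left(2x_i + \frac{1}{n}\frac{\partial \left(\sum_m^n x_m\right)^2}{\partial x_i} - \frac{2}{n}\left(\frac{\partial\, x_i(x_0 + \cdots + x_i + \cdots + x_m)}{\partial x_i} + \frac{\partial \sum_m^n x_m \sum_{k\neq i}^n x_k}{\partial x_i}\right)\right) \\
&= \frac{1}{2\sigma n}\left(2x_i + \frac{1}{n}\frac{\partial \left(\sum_m^n x_m\right)^2}{\partial x_i} - \frac{2}{n}\left(\frac{\partial \left(x_i^2 + \sum_{k\neq i}^n x_i x_k\right)}{\partial x_i} + \sum_{k\neq i}^n x_k\right)\right) \\
&= \frac{1}{2\sigma n}\left(2x_i + \frac{1}{n}\frac{\partial \left(\sum_m^n x_m\right)^2}{\partial x_i} - \frac{2}{n}\left(2x_i + \sum_{k\neq i}^n x_k + \sum_{k\neq i}^n x_k\right)\right) \\
&= \frac{1}{2\sigma n}\left(2x_i + \frac{1}{n}\frac{\partial \left(\sum_m^n x_m\right)^2}{\partial x_i} - \frac{4\sum_k^n x_k}{n}\right) \\
&= \frac{1}{2\sigma n}\left(2x_i + \frac{2\sum_k^n x_k}{n} - \frac{4\sum_k^n x_k}{n}\right) \\
&= \frac{x_i - \frac{\sum_k^n x_k}{n}}{\sigma n} = \frac{x_i - \mu}{\sigma n}
\end{aligned}
$$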
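As a cross-check, the same result follows more quickly by differentiating $\sigma^2$ instead of $\sigma$ and using the fact that deviations from the mean sum to zero. Here $\delta_{ki}$ is the Kronecker delta (1 if $k = i$, else 0), and $\partial\mu/\partial x_i = 1/n$:

$$
\begin{aligned}
\frac{\partial \sigma^2}{\partial x_i}
&= \frac{1}{n}\sum_k^n 2(x_k-\mu)\left(\delta_{ki} - \frac{1}{n}\right)
= \frac{2(x_i-\mu)}{n} - \frac{2}{n^2}\sum_k^n (x_k-\mu)
= \frac{2(x_i-\mu)}{n} \\
\frac{\partial \sigma}{\partial x_i}
&= \frac{1}{2\sigma}\frac{\partial \sigma^2}{\partial x_i}
= \frac{x_i-\mu}{\sigma n}
\end{aligned}
$$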

Completing the quotient rule, we end up with 2 cases

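$$
\frac{\partial \hat{x}_j}{\partial x_i} =
\begin{cases}
\dfrac{\sigma\left(-\frac{1}{n}\right)}{\sigma^2} - \dfrac{(x_j-\mu)\left(\frac{x_i-\mu}{\sigma n}\right)}{\sigma^2} & \text{if } i \neq j,\\[2ex]
\dfrac{\sigma\left(1-\frac{1}{n}\right)}{\sigma^2} - \dfrac{(x_i-\mu)\left(\frac{x_i-\mu}{\sigma n}\right)}{\sigma^2} & \text{if } i = j.
\end{cases}
$$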
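The two cases can be checked numerically. This plain-Python sketch (population $\sigma$, hypothetical helper names) builds the full Jacobian $J_{ji} = \partial\hat{x}_j/\partial x_i$ from the cases above, compares it with central finite differences, and applies (7) as a sum over $j$.

```python
import math


def stats(x):
    """Return n, mu and population sigma for a 1-D list."""
    n = len(x)
    mu = sum(x) / n
    sigma = math.sqrt(sum((xi - mu) ** 2 for xi in x) / n)
    return n, mu, sigma


def normalize(x):
    """x_hat = (x - mu) / sigma, recomputing mu and sigma from x."""
    _, mu, sigma = stats(x)
    return [(xi - mu) / sigma for xi in x]


def jacobian_analytic(x):
    """J[j][i] = d x_hat_j / d x_i from the two quotient-rule cases."""
    n, mu, sigma = stats(x)
    jac = [[0.0] * n for _ in range(n)]
    for j in range(n):
        for i in range(n):
            # derivative of the numerator (x_j - mu) with respect to x_i
            d_num = (1.0 - 1.0 / n) if i == j else (-1.0 / n)
            # derivative of the denominator sigma with respect to x_i
            d_sigma = (x[i] - mu) / (sigma * n)
            jac[j][i] = (sigma * d_num - (x[j] - mu) * d_sigma) / sigma ** 2
    return jac


def jacobian_numeric(x, h=1e-6):
    """Central finite differences of normalize(), one column per x_i."""
    n = len(x)
    jac = [[0.0] * n for _ in range(n)]
    for i in range(n):
        x_plus = list(x)
        x_plus[i] += h
        x_minus = list(x)
        x_minus[i] -= h
        y_plus, y_minus = normalize(x_plus), normalize(x_minus)
        for j in range(n):
            jac[j][i] = (y_plus[j] - y_minus[j]) / (2 * h)
    return jac


def grad_x(jac, dx_hat):
    """Equation (7): dL/dx_i = sum_j (d x_hat_j / d x_i) * dL/dx_hat_j."""
    n = len(dx_hat)
    return [sum(jac[j][i] * dx_hat[j] for j in range(n)) for i in range(n)]


x = [1.0, 2.0, 4.0, 7.0]
jac_a = jacobian_analytic(x)
jac_n = jacobian_numeric(x)
max_err = max(abs(jac_a[j][i] - jac_n[j][i]) for j in range(4) for i in range(4))
print(max_err)
```

With an all-ones upstream gradient, `grad_x` returns (numerically) zero, since $\sum_j \hat{x}_j = 0$ for every input and so its derivative with respect to any $x_i$ vanishes.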

Now we group the left-hand fractions (the $\sigma(\cdot)/\sigma^2$ terms) and the right-hand fractions separately and sum them as per (7).

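$$
\begin{aligned}
\text{LHS} &= \left(\sum_{j\neq i}^n -\frac{1}{n\sigma}\frac{\partial L}{\partial \hat{x}_j}\right) + \left(\frac{n-1}{n\sigma}\frac{\partial L}{\partial \hat{x}_i}\right) \\
&= \left(-\frac{1}{n\sigma}\sum_j^n \frac{\partial L}{\partial \hat{x}_j}\right) + \left(\frac{1}{n\sigma}\frac{\partial L}{\partial \hat{x}_i} + \frac{n-1}{n\sigma}\frac{\partial L}{\partial \hat{x}_i}\right) \\
&= \left(-\frac{1}{n\sigma}\sum_j^n \frac{\partial L}{\partial \hat{x}_j}\right) + \left(\frac{1}{\sigma}\frac{\partial L}{\partial \hat{x}_i}\right) \\
\text{RHS} &= \sum_j^n -\frac{(x_j-\mu)(x_i-\mu)}{n\sigma^3}\frac{\partial L}{\partial \hat{x}_j}
= -\frac{\hat{x}_i}{n\sigma}\sum_j^n \hat{x}_j\frac{\partial L}{\partial \hat{x}_j}
\end{aligned}
$$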

You can factor out $1/(n\sigma)$ and then generalize this to all elements of $x$.
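Carrying out that factorization and adding LHS and RHS gives, for every $i$,

$$
\frac{\partial L}{\partial x_i} = \frac{1}{n\sigma}\left(n\frac{\partial L}{\partial \hat{x}_i} - \sum_j^n \frac{\partial L}{\partial \hat{x}_j} - \hat{x}_i\sum_j^n \hat{x}_j\frac{\partial L}{\partial \hat{x}_j}\right)
$$

A minimal plain-Python sketch of this closed form, assuming population $\sigma$, no $\epsilon$, and no learned scale/shift (`batchnorm_backward` is a hypothetical name). The two sums are computed once and shared by every element, which is what makes the closed form cheaper than building the full Jacobian.

```python
import math


def batchnorm_backward(x, dx_hat):
    """Closed-form dL/dx for x_hat = (x - mu) / sigma.

    Hypothetical helper; assumes population sigma, no epsilon, no gamma/beta.
    """
    n = len(x)
    mu = sum(x) / n
    sigma = math.sqrt(sum((xi - mu) ** 2 for xi in x) / n)
    x_hat = [(xi - mu) / sigma for xi in x]
    sum_g = sum(dx_hat)                                   # sum_j dL/dx_hat_j
    sum_xg = sum(xh * g for xh, g in zip(x_hat, dx_hat))  # sum_j x_hat_j * dL/dx_hat_j
    return [(n * g - sum_g - xh * sum_xg) / (n * sigma)
            for xh, g in zip(x_hat, dx_hat)]


x = [1.0, 2.0, 4.0, 7.0]
w = [0.3, -1.2, 0.5, 2.0]  # arbitrary upstream gradient dL/dx_hat
dx = batchnorm_backward(x, w)
print(dx)
```

Two quick properties follow from the formula: the entries of `dx` sum to zero, and they are orthogonal to $\hat{x}$ (using $\sum_j \hat{x}_j = 0$ and $\sum_j \hat{x}_j^2 = n$). Both reflect that $\hat{x}$ does not change when $x$ is shifted or scaled about its mean.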

Another way

[Figure: alternative computational graph for batch norm]

Compute the gradient w.r.t. $p$ first, then $x$.