The computational graph for batch norm can be simplified to:
The input vector $x$ is used once to compute the mean $\mu$, once to compute the standard deviation $\sigma$, and once to compute the normalized values $\hat{x}$. $\mu$ is used once to compute the standard deviation and $N$ times to compute the normalized values. $\sigma$ is used $N$ times to compute the normalized values. To backpropagate, we need to add up all of these contributions.
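So we need to compute:

$$\frac{d\hat{x}_i}{dx_j} = \frac{\partial \hat{x}_i}{\partial x_j} + \frac{\partial \hat{x}_i}{\partial \mu}\frac{\partial \mu}{\partial x_j} + \frac{\partial \hat{x}_i}{\partial \sigma}\left(\frac{\partial \sigma}{\partial \mu}\frac{\partial \mu}{\partial x_j} + \frac{\partial \sigma}{\partial x_j}\right)$$

Here $\hat{x}_i = (x_i - \mu)/\sigma$ is treated as a function of $x_i$, $\mu$ and $\sigma$, with $\mu = \frac{1}{N}\sum_k x_k$ and $\sigma = \sqrt{\frac{1}{N}\sum_k (x_k - \mu)^2 + \epsilon}$.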
Also, keep in mind that we're ultimately looking for $\frac{\partial L}{\partial x}$, so we need to use the upstream gradient $\frac{\partial L}{\partial \hat{x}}$ as per the chain rule.
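We'll compute the gradients as they appear left to right in the expression above, for output index $i$ and input index $j$ ($\delta_{ij}$ is 1 when $i = j$ and 0 otherwise):

$$\frac{\partial \hat{x}_i}{\partial x_j} = \frac{\delta_{ij}}{\sigma}, \qquad \frac{\partial \hat{x}_i}{\partial \mu} = -\frac{1}{\sigma}, \qquad \frac{\partial \mu}{\partial x_j} = \frac{1}{N}$$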
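Thankfully, $\frac{\partial \sigma}{\partial \mu} = -\frac{1}{N\sigma}\sum_k (x_k - \mu)$ evaluates to $0$ (saving us from applying the chain rule many times!), so we only need 2 more derivatives:

$$\frac{\partial \hat{x}_i}{\partial \sigma} = -\frac{x_i - \mu}{\sigma^2}, \qquad \frac{\partial \sigma}{\partial x_j} = \frac{x_j - \mu}{N\sigma}$$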
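As a quick sanity check on the vanishing term, here is a minimal NumPy sketch. It assumes the derivative in question is $\partial\sigma/\partial\mu$ and that $\sigma = \sqrt{\frac{1}{N}\sum_k (x_k - \mu)^2 + \epsilon}$; the helper name `sigma` and the value of `eps` are illustrative choices, not taken from the text.

```python
import numpy as np

def sigma(x, mu, eps=1e-5):
    # Standard deviation with mu treated as an independent variable (assumed form).
    return np.sqrt(np.mean((x - mu) ** 2) + eps)

rng = np.random.default_rng(0)
x = rng.normal(size=8)
mu = x.mean()

# Central finite difference in mu, holding x fixed.
h = 1e-5
dsigma_dmu = (sigma(x, mu + h) - sigma(x, mu - h)) / (2 * h)

# Analytic form: -(1 / (N * sigma)) * sum(x - mu), which vanishes at mu = mean(x).
analytic = -np.sum(x - mu) / (x.size * sigma(x, mu))
```

Both `dsigma_dmu` and `analytic` should come out at floating-point noise level, because the deviations from the mean sum to zero.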
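Now you can plug these back into the expression above, multiply by the upstream gradient, sum over $i$, and factorize/simplify. Writing $\hat{x}_i = (x_i - \mu)/\sigma$, this gives:

$$\frac{\partial L}{\partial x_j} = \frac{1}{N\sigma}\left(N\frac{\partial L}{\partial \hat{x}_j} - \sum_{i=1}^{N}\frac{\partial L}{\partial \hat{x}_i} - \hat{x}_j\sum_{i=1}^{N}\frac{\partial L}{\partial \hat{x}_i}\hat{x}_i\right)$$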
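The simplified gradient can be checked numerically. The sketch below is one possible NumPy version for a single 1-D vector, assuming $\sigma = \sqrt{\mathrm{var}(x) + \epsilon}$ and leaving out the scale and shift parameters of a full batch norm layer. The names `normalize`, `normalize_backward` and `numerical_grad` are hypothetical helpers introduced for illustration.

```python
import numpy as np

def normalize(x, eps=1e-5):
    # Forward pass: x_hat = (x - mu) / sigma, with sigma = sqrt(var + eps).
    mu = x.mean()
    sigma = np.sqrt(x.var() + eps)  # x.var() divides by N (population variance)
    return (x - mu) / sigma, sigma

def normalize_backward(dx_hat, x_hat, sigma):
    # dL/dx_j = (N * g_j - sum(g) - x_hat_j * sum(g * x_hat)) / (N * sigma),
    # where g is the upstream gradient dL/dx_hat.
    n = x_hat.size
    return (n * dx_hat - dx_hat.sum() - x_hat * np.sum(dx_hat * x_hat)) / (n * sigma)

def numerical_grad(f, x, h=1e-6):
    # Central finite differences of a scalar function f at x.
    grad = np.zeros_like(x)
    for j in range(x.size):
        step = np.zeros_like(x)
        step[j] = h
        grad[j] = (f(x + step) - f(x - step)) / (2 * h)
    return grad

rng = np.random.default_rng(0)
x = rng.normal(size=6)
upstream = rng.normal(size=6)  # stands in for dL/dx_hat

x_hat, sigma = normalize(x)
dx = normalize_backward(upstream, x_hat, sigma)

# Use L = sum(upstream * x_hat) so that dL/dx_hat equals `upstream`.
dx_num = numerical_grad(lambda v: np.sum(upstream * normalize(v)[0]), x)
max_err = np.max(np.abs(dx - dx_num))
```

If the formula is right, `max_err` should be tiny compared with the gradient entries. A side effect worth noticing is that the entries of `dx` sum to roughly zero, since shifting every input by the same constant leaves the normalized values unchanged.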
Fool's way (T.T)
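Say we expand the normalized vector like so:

$$\hat{x}_i = \frac{x_i - \mu}{\sigma} = \frac{x_i - \frac{1}{N}\sum_{k=1}^{N} x_k}{\sqrt{\frac{1}{N}\sum_{k=1}^{N}(x_k - \mu)^2 + \epsilon}}$$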
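Then we can compute the gradient as

$$\frac{\partial L}{\partial x_j} = \sum_{i=1}^{N}\frac{\partial L}{\partial \hat{x}_i}\frac{\partial \hat{x}_i}{\partial x_j}$$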
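We're going to use the quotient rule to compute $\frac{\partial \hat{x}_i}{\partial x_j}$. The derivative of the standard deviation is particularly bashy, so I'll include it here. The rest should be easy.

$$\frac{\partial \sigma}{\partial x_j} = \frac{1}{2\sigma}\cdot\frac{2}{N}\sum_{k=1}^{N}(x_k - \mu)\left(\delta_{kj} - \frac{1}{N}\right) = \frac{x_j - \mu}{N\sigma}$$

Here $\delta_{kj}$ is 1 when $k = j$ and 0 otherwise, and the $\frac{1}{N}$ part drops out because $\sum_k (x_k - \mu) = 0$.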
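Completing the quotient rule, we end up with 2 cases:

$$\frac{\partial \hat{x}_i}{\partial x_j} = \begin{cases} \dfrac{\left(1 - \frac{1}{N}\right)\sigma - \frac{(x_i - \mu)^2}{N\sigma}}{\sigma^2} & \text{if } i = j \\[2ex] \dfrac{-\frac{1}{N}\sigma - \frac{(x_i - \mu)(x_j - \mu)}{N\sigma}}{\sigma^2} & \text{if } i \neq j \end{cases}$$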
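The two quotient-rule cases can be checked entry by entry against finite differences. This is a minimal NumPy sketch under the assumption that the numerator is $x_i - \mu$ and the denominator is $\sigma = \sqrt{\mathrm{var}(x) + \epsilon}$. The function names are hypothetical, and the double loop is deliberately naive to mirror the case analysis rather than to be fast.

```python
import numpy as np

def normalize(x, eps=1e-5):
    # x_hat = (x - mu) / sigma, with sigma = sqrt(var + eps).
    mu = x.mean()
    sigma = np.sqrt(x.var() + eps)
    return (x - mu) / sigma

def jacobian_two_cases(x, eps=1e-5):
    # J[i, j] = d x_hat_i / d x_j from the quotient rule, filled case by case.
    n = x.size
    mu = x.mean()
    sigma = np.sqrt(x.var() + eps)
    jac = np.empty((n, n))
    for i in range(n):
        for j in range(n):
            # d(x_i - mu)/dx_j is 1 - 1/N on the diagonal and -1/N off it.
            du = (1.0 if i == j else 0.0) - 1.0 / n
            # d(sigma)/dx_j = (x_j - mu) / (N * sigma)
            dsigma = (x[j] - mu) / (n * sigma)
            jac[i, j] = (du * sigma - (x[i] - mu) * dsigma) / sigma**2
    return jac

def jacobian_numeric(x, h=1e-6):
    # Column j holds the central finite difference of x_hat with respect to x_j.
    n = x.size
    jac = np.empty((n, n))
    for j in range(n):
        step = np.zeros(n)
        step[j] = h
        jac[:, j] = (normalize(x + step) - normalize(x - step)) / (2 * h)
    return jac

rng = np.random.default_rng(1)
x = rng.normal(size=5)
jac = jacobian_two_cases(x)
max_err = np.max(np.abs(jac - jacobian_numeric(x)))
```

Multiplying an upstream gradient vector `g` by this matrix (`g @ jac`) performs the sum over $i$, which is the step the next paragraph carries out by hand.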
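Now we group the left-hand and right-hand terms of the quotient-rule numerators separately and sum over $i$, as per the chain-rule sum for $\frac{\partial L}{\partial x_j}$:

$$\frac{\partial L}{\partial x_j} = \frac{1}{\sigma}\left(\frac{\partial L}{\partial \hat{x}_j} - \frac{1}{N}\sum_{i=1}^{N}\frac{\partial L}{\partial \hat{x}_i}\right) - \frac{x_j - \mu}{N\sigma^3}\sum_{i=1}^{N}\frac{\partial L}{\partial \hat{x}_i}(x_i - \mu)$$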
You can factor this, then generalize it to all elements of $x$.