The chain rule is a fundamental result in calculus. Roughly speaking, it states that if a variable $y$ is a differentiable function of intermediate variables $u_1, \ldots, u_n$, and each intermediate variable $u_i$ is itself a differentiable function of $x$, then we can compute the derivative $\frac{dy}{dx}$ as follows:

$$\frac{dy}{dx} = \sum_{i=1}^{n} \frac{\partial y}{\partial u_i} \cdot \frac{du_i}{dx}. \tag{1}$$
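As a quick sanity check, formula (1) can be verified numerically for a concrete choice of functions (the functions below are illustrative, not from the post):

```python
import math

# Illustrative example: y = u1 * u2 with u1 = sin(x) and u2 = x**2,
# so that y = x**2 * sin(x) as a direct function of x.

def dy_dx_chain_rule(x):
    u1, u2 = math.sin(x), x * x
    du1_dx, du2_dx = math.cos(x), 2 * x
    # Chain rule (1): dy/dx = (dy/du1)(du1/dx) + (dy/du2)(du2/dx),
    # with the partial derivatives dy/du1 = u2 and dy/du2 = u1.
    return u2 * du1_dx + u1 * du2_dx

def dy_dx_direct(x):
    # Differentiating y = x**2 * sin(x) directly by the product rule.
    return 2 * x * math.sin(x) + x * x * math.cos(x)
```

Both routes compute the same derivative, up to floating-point rounding.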
Besides being a handy tool for computing derivatives in calculus homework, the chain rule is closely related to the backpropagation algorithm that is widely used for computing derivatives (gradients) in neural network training. This blog post by Boaz Barak is a beautiful tutorial on the chain rule and the backpropagation algorithm.
As in Barak’s post, the backpropagation algorithm is usually taught as an application of the chain rule in machine learning classes. This leads to a common belief that “backpropagation is just applying the chain rule repeatedly”. While this is in a sense true, we wish to point out in this blog post that this belief is an oversimplification and can lead to incorrect implementations of the backpropagation algorithm. It has been discussed in many blog posts that backpropagation is not just the chain rule, but we want to focus on a simple and basic difference here.
Consider the “neural network” above: it consists of an input variable $x$, an output variable $y$, and two intermediate variables $u$ and $v$. The neural network describes $y$ as a function of $x$ in the following way:

$$u = f(x), \quad v = g(x, u), \quad y = h(u, v).$$
Using the chain rule (1) twice, we can compute the derivative $\frac{dy}{dx}$ as follows:

$$\frac{dy}{dx} = \frac{\partial y}{\partial u} \cdot \frac{du}{dx} + \frac{\partial y}{\partial v} \cdot \frac{dv}{dx} = \frac{\partial y}{\partial u} \cdot \frac{du}{dx} + \frac{\partial y}{\partial v} \cdot \left( \frac{\partial v}{\partial x} + \frac{\partial v}{\partial u} \cdot \frac{du}{dx} \right). \tag{2}$$
Backpropagation computes the derivative via a different route:

$$\frac{dy}{dx} = \frac{dy}{du} \cdot \frac{\partial u}{\partial x} + \frac{dy}{dv} \cdot \frac{\partial v}{\partial x}, \quad \text{where } \frac{dy}{dv} = \frac{\partial y}{\partial v} \text{ and } \frac{dy}{du} = \frac{\partial y}{\partial u} + \frac{dy}{dv} \cdot \frac{\partial v}{\partial u}. \tag{3}$$
In the calculations above, note the difference between partial and full derivatives. For example, the partial derivative $\frac{\partial v}{\partial x}$ captures only the direct dependence of $v$ on $x$, whereas the full derivative $\frac{dv}{dx} = \frac{\partial v}{\partial x} + \frac{\partial v}{\partial u} \cdot \frac{du}{dx}$ also accounts for the dependence through $u$. Similarly, the partial derivative $\frac{\partial y}{\partial u}$ captures only the direct dependence of $y$ on $u$, whereas the full derivative $\frac{dy}{du} = \frac{\partial y}{\partial u} + \frac{\partial y}{\partial v} \cdot \frac{\partial v}{\partial u}$. Intuitively, for every edge in the graph, we can compute the corresponding partial derivative locally (after the forward pass, to be precise), but a full derivative may require longer calculations.
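To make the comparison concrete, assume the dependency structure $u = f(x)$, $v = g(x, u)$, $y = h(u, v)$, with illustrative choices $f(x) = \sin x$, $g(x, u) = x \cdot u$, $h(u, v) = u + v^2$ (none of these concrete functions are from the post). Then the forward route (2) and the backward route (3) can be checked to agree:

```python
import math

x = 1.3
u = math.sin(x)            # u = f(x)
v = x * u                  # v = g(x, u)
# Local partial derivatives, one per edge of the graph:
du_dx = math.cos(x)        # du/dx (u depends on x directly only)
pv_px, pv_pu = u, x        # partial v / partial x, partial v / partial u
py_pu, py_pv = 1.0, 2 * v  # partial y / partial u, partial y / partial v

# Route (2): chain rule, working forward through the graph.
dv_dx = pv_px + pv_pu * du_dx            # full derivative dv/dx
forward_dy_dx = py_pu * du_dx + py_pv * dv_dx

# Route (3): backpropagation, working backward through the graph.
dy_dv = py_pv                            # full derivative dy/dv
dy_du = py_pu + dy_dv * pv_pu            # full derivative dy/du
backward_dy_dx = dy_du * du_dx + dy_dv * pv_px
```

The two routes combine the same edge-local partial derivatives in different orders, so they agree up to floating-point rounding.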
Now it should be clear that equation (3) cannot be directly explained as the standard chain rule (1). The key difference is that in the chain rule (1), we need partial derivatives of a single variable ($y$) w.r.t. multiple other variables ($u_1, \ldots, u_n$), whereas in (3), we need partial derivatives of multiple variables ($y$ and $v$) w.r.t. the same variable ($u$).
Of course, one can prove the correctness of backpropagation using the chain rule in various ways, but the simple proof “backpropagation uses the standard chain rule at every step” is incomplete. It is also certainly possible to compute derivatives (gradients) on a neural network directly using the chain rule, as in (2). However, in neural network training one typically wants to calculate the derivatives of a single output variable w.r.t. a large number of input variables, and in that setting backpropagation allows a more efficient implementation than using the standard chain rule directly.
The real chain rule in actual backpropagation
When implementing the backpropagation algorithm, it is more convenient and efficient to add in the two terms (or any number of terms in general) on the right-hand side of (3) at different steps of the algorithm. This is described in Barak’s tutorial, which also has an actual Python implementation! In contrast to the failure of naively explaining equation (3) as the chain rule, there is a way to explain this implementation using the chain rule directly.
To describe the implementation, suppose $x_1, \ldots, x_n$ are all the variables, including input variables $x_1, \ldots, x_k$, intermediate variables $x_{k+1}, \ldots, x_{n-1}$, and the only output variable $x_n$. Assume the variables are arranged in topological order: for every $i > k$, variable $x_i$ is locally a function of variables $x_j$ with $j < i$. We can write this as $x_i = f_i(x_1, \ldots, x_{i-1})$. Note that some variables $x_j$ with $j < i$ may not have an edge to $x_i$, in which case $\frac{\partial f_i}{\partial x_j} = 0$.
For each variable $x_i$, backpropagation stores in $g_i$ a “temporary derivative/gradient” w.r.t. $x_i$. Initially, $g_n = 1$ and $g_i = 0$ for every $i < n$. The backpropagation algorithm iterates over $i = n, n-1, \ldots, k+1$ and performs the following updates in each iteration:

$$g_j \leftarrow g_j + g_i \cdot \frac{\partial f_i}{\partial x_j} \quad \text{for every } j < i. \tag{4}$$
Of course, it suffices to update $g_j$ only when there is an edge from $x_j$ to $x_i$, because otherwise $\frac{\partial f_i}{\partial x_j} = 0$ and the update does not change $g_j$.
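Combining the update rule (4) with a forward pass gives a complete, if minimal, implementation. The sketch below assumes a simple list-of-nodes representation; the names `preds`, `f`, and `partials` are invented for this illustration and are not from Barak's code:

```python
import math

def backprop(nodes, inputs):
    """nodes[i] = (preds, f, partials): preds are the indices j < i with an
    edge x_j -> x_i; f evaluates x_i from its predecessors' values; partials
    returns the local partial derivatives of x_i w.r.t. each predecessor.
    Input nodes have preds == () and take their values from `inputs`.
    Returns g with g[j] = (derivative of the last node w.r.t. x_j)."""
    n = len(nodes)
    values = dict(inputs)
    # Forward pass: evaluate every non-input node in topological order.
    for i, (preds, f, _) in enumerate(nodes):
        if preds:
            values[i] = f(*(values[j] for j in preds))
    # Backward pass: g_n = 1 and g_j = 0 for j < n, then apply rule (4):
    # g_j += g_i * (partial derivative of f_i w.r.t. x_j).
    g = [0.0] * n
    g[n - 1] = 1.0
    for i in range(n - 1, -1, -1):
        preds, _, partials = nodes[i]
        if preds:
            local = partials(*(values[j] for j in preds))
            for j, d in zip(preds, local):
                g[j] += g[i] * d
    return g

# Example graph: x0 (input), x1 = sin(x0), x2 = x0 * x1, x3 = x1 + x2**2.
nodes = [
    ((), None, None),                                   # x0: input
    ((0,), math.sin, lambda x: (math.cos(x),)),         # x1 = sin(x0)
    ((0, 1), lambda x, u: x * u, lambda x, u: (u, x)),  # x2 = x0 * x1
    ((1, 2), lambda u, v: u + v * v,
     lambda u, v: (1.0, 2 * v)),                        # x3 = x1 + x2**2
]
g = backprop(nodes, {0: 1.3})  # g[0] is the derivative of x3 w.r.t. x0
```

Note that each iteration touches only edge-local partial derivatives, exactly as in the description above.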
The update rule (4) can be explained as a direct application of the chain rule as follows. If we know the values of $x_1, \ldots, x_i$, we can evaluate the remaining variables $x_{i+1}, \ldots, x_n$ one by one, so in this sense, the output variable $x_n$ is a function $F_i$ of $x_1, \ldots, x_i$. Using the function $f_i$, we can relate functions $F_{i-1}$ and $F_i$ as follows:

$$F_{i-1}(x_1, \ldots, x_{i-1}) = F_i\big(x_1, \ldots, x_{i-1}, f_i(x_1, \ldots, x_{i-1})\big).$$
By the chain rule,

$$\frac{\partial F_{i-1}}{\partial x_j} = \frac{\partial F_i}{\partial x_j} + \frac{\partial F_i}{\partial x_i} \cdot \frac{\partial f_i}{\partial x_j} \quad \text{for every } j < i. \tag{5}$$
The correspondence between (4) and (5) completes the explanation: if $g_j = \frac{\partial F_i}{\partial x_j}$ for every $j \le i$ before the update, we have $g_j = \frac{\partial F_{i-1}}{\partial x_j}$ for every $j \le i-1$ after the update. By induction, after the final iteration with $i = k+1$, $g_j$ contains the desired value $\frac{\partial F_k}{\partial x_j}$ for every $j \le k$, where $F_k$ is exactly the function whose derivatives we want to compute.
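The conclusion of the induction can be spot-checked numerically: on a small graph with two inputs, carrying out the updates (4) by hand should reproduce finite-difference derivatives of the end-to-end function. The graph and functions below are an invented example, not from the post:

```python
import math

# Graph: inputs x1, x2; x3 = x1 * x2; output x4 = sin(x3) + x1.
def forward(x1, x2):
    x3 = x1 * x2
    x4 = math.sin(x3) + x1
    return x3, x4

def backward(x1, x2):
    x3, _ = forward(x1, x2)
    g = {1: 0.0, 2: 0.0, 3: 0.0, 4: 1.0}  # g[output] = 1 initially
    # Iteration i = 4: x4 = sin(x3) + x1, local partials cos(x3) and 1.
    g[3] += g[4] * math.cos(x3)
    g[1] += g[4] * 1.0
    # Iteration i = 3: x3 = x1 * x2, local partials x2 and x1.
    g[1] += g[3] * x2
    g[2] += g[3] * x1
    return g[1], g[2]  # derivatives of x4 w.r.t. the two inputs

x1, x2, eps = 0.7, 1.9, 1e-6
d1, d2 = backward(x1, x2)
fd1 = (forward(x1 + eps, x2)[1] - forward(x1 - eps, x2)[1]) / (2 * eps)
fd2 = (forward(x1, x2 + eps)[1] - forward(x1, x2 - eps)[1]) / (2 * eps)
```

A single backward pass yields the derivatives w.r.t. both inputs at once, which is the efficiency advantage mentioned earlier.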