Backpropagation ≠ Chain Rule

The chain rule is a fundamental result in calculus. Roughly speaking, it states that if a variable c is a differentiable function of intermediate variables b_1,\ldots,b_n, and each intermediate variable b_i is itself a differentiable function of a, then we can compute the derivative \frac{{\mathrm d} c}{{\mathrm d} a} as follows:

\begin{aligned}\frac{{\mathrm d} c}{{\mathrm d} a} = \frac{{\mathrm d}}{{\mathrm d} a}c(b_1(a),\ldots,b_n(a)) = \sum_{i=1}^n \frac{\partial c}{\partial b_i}\frac{{\mathrm d} b_i}{{\mathrm d} a}. && (1)\end{aligned}

Besides being a handy tool for computing derivatives in calculus homework, the chain rule is closely related to the backpropagation algorithm that is widely used for computing derivatives (gradients) in neural network training. This blog post by Boaz Barak is a beautiful tutorial on the chain rule and the backpropagation algorithm.

As in Barak’s post, the backpropagation algorithm is usually taught in machine learning classes as an application of the chain rule. This leads to a common belief that “backpropagation is just applying the chain rule repeatedly”. While this is in a sense true, we wish to point out in this blog post that this belief is an oversimplification and can lead to incorrect implementations of the backpropagation algorithm. Many blog posts have discussed how backpropagation is not just the chain rule; here we want to focus on one simple and basic difference.

Consider a small “neural network” consisting of an input variable a, an output variable c, and two intermediate variables b_1 and b_2, with edges from a to b_1 and b_2, from b_1 to b_2, and from both b_1 and b_2 to c. The network describes c as a function of a in the following way:

\begin{aligned}b_1 = 2a,\ b_2 = 3a + b_1,\ c = 4b_1 + 5b_2.\end{aligned}
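For concreteness, here is a minimal Python sketch (my own illustration, not code from the original post) that evaluates this toy network and estimates \frac{{\mathrm d} c}{{\mathrm d} a} with a central finite difference; the estimate should come out close to 33.

```python
def forward(a):
    # Evaluate the toy network: b1 = 2a, b2 = 3a + b1, c = 4*b1 + 5*b2.
    b1 = 2 * a
    b2 = 3 * a + b1
    c = 4 * b1 + 5 * b2
    return c

# Central finite-difference estimate of dc/da at an arbitrary point, e.g. a = 1.0.
eps = 1e-6
print((forward(1.0 + eps) - forward(1.0 - eps)) / (2 * eps))  # ~33.0
```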

Using the chain rule (1) twice, we can compute the derivative \frac{{\mathrm d} c}{{\mathrm d} a} as follows:

\begin{aligned} \frac{{\mathrm d} c}{{\mathrm d} a} & = \frac{\partial c}{\partial b_1}\frac{{\mathrm d} b_1}{{\mathrm d} a} + \frac{\partial c}{\partial b_2}\frac{{\mathrm d} b_2}{{\mathrm d} a} \\ & = \frac{\partial c}{\partial b_1}\frac{{\mathrm d} b_1}{{\mathrm d} a} + \frac{\partial c}{\partial b_2}\left(\frac{\partial b_2}{\partial a} \frac{{\mathrm d} a}{{\mathrm d} a} + \frac{\partial b_2}{\partial b_1} \frac{{\mathrm d} b_1}{{\mathrm d} a}\right) \\ & = 4 \times 2 + 5 \times (3\times 1 + 1\times 2) \\ & = 33. & (2) \end{aligned}
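The calculation in (2) can be mechanized by sweeping through the graph in topological order and carrying along the full derivative of each variable w.r.t. a (forward-mode accumulation). Below is a minimal sketch of that idea, with variable names of my own choosing; since the network is linear, the derivatives do not depend on the chosen point a.

```python
a, da = 1.0, 1.0                               # da = d a / d a = 1

b1, db1 = 2 * a, 2 * da                        # d b1 / d a = 2
b2, db2 = 3 * a + b1, 3 * da + 1 * db1         # d b2 / d a = 3*1 + 1*2 = 5
c, dc = 4 * b1 + 5 * b2, 4 * db1 + 5 * db2     # d c / d a  = 4*2 + 5*5 = 33

print(dc)  # 33.0
```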

Backpropagation computes the derivative \frac{{\mathrm d} c}{{\mathrm d} a} via a different route:

\begin{aligned}\frac{{\mathrm d} c}{{\mathrm d} a} & = \frac{{\mathrm d} c}{{\mathrm d} b_1}\frac{\partial b_1}{\partial a} + \frac{{\mathrm d} c}{{\mathrm d} b_2}\frac{\partial b_2}{\partial a} & (3) \\ & = \left(\frac{{\mathrm d} c}{{\mathrm d} b_2}\frac{\partial b_2}{\partial b_1} + \frac{{\mathrm d} c}{{\mathrm d} c}\frac{\partial c}{\partial b_1}\right)\frac{\partial b_1}{\partial a} + \frac{{\mathrm d} c}{{\mathrm d} b_2}\frac{\partial b_2}{\partial a} \\ & = (5\times 1+ 1\times 4)\times 2 + 5\times 3 \\ & = 33. \end{aligned}

In the calculations above, note the difference between partial and full derivatives. For example, the partial derivative \frac{\partial b_2}{\partial a} = 3, whereas the full derivative \frac{{\mathrm d} b_2}{{\mathrm d} a} = 3\times 1 + 1\times 2 = 5. Similarly, the partial derivative \frac{\partial c}{\partial b_1} = 4, whereas the full derivative \frac{{\mathrm d} c}{{\mathrm d} b_1} = 5\times 1 + 1\times 4 = 9. Intuitively, for every edge in the graph, we can compute the corresponding partial derivative locally (after the forward pass, to be precise), but a full derivative may require longer calculations.
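For comparison, the route in (3) can be traced by sweeping the graph in the reverse direction, assembling each full derivative \frac{{\mathrm d} c}{{\mathrm d} b_i} from the local partial derivatives attached to the edges. A minimal sketch for the toy example (again my own illustration):

```python
# Local partial derivatives, one per edge, read off after the forward pass:
# dc/db2 = 5, dc/db1 (direct edge) = 4, db2/db1 = 1, db2/da = 3, db1/da = 2.

dc_db2 = 5.0                       # full derivative: c depends on b2 through one edge only
dc_db1 = dc_db2 * 1 + 1 * 4        # full derivative via b2 plus the direct edge to c  -> 9
dc_da = dc_db1 * 2 + dc_db2 * 3    # equation (3)                                      -> 33

print(dc_db1, dc_da)  # 9.0 33.0
```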

Now it should be clear that equation (3) cannot be directly explained as the standard chain rule (1). The key difference is that in the chain rule, we need partial derivatives of a single variable w.r.t. multiple other variables, whereas in (3), we need partial derivatives of multiple variables w.r.t. the same variable.

Of course, one can prove the correctness of backpropagation using the chain rule in various ways, but the simple proof “backpropagation uses the standard chain rule at every step” is incomplete. It is also certainly possible to compute derivatives (gradients) of a neural network directly with the chain rule, as in (2); however, in neural network training one typically wants the derivatives of a single output variable w.r.t. a large number of input variables, and in that case backpropagation allows a more efficient implementation than applying the standard chain rule directly.

The real chain rule in actual backpropagation

When implementing the backpropagation algorithm, it is more convenient and efficient to add up the two terms (or, in general, any number of terms) on the right-hand side of (3) at different steps of the algorithm. This is described in Barak’s tutorial, which also has an actual Python implementation! Unlike the naive attempt to read equation (3) as the standard chain rule, this implementation can be explained using the chain rule directly.

To describe the implementation, suppose b_1,\ldots,b_n are all the variables, including input variables b_1,\ldots,b_m, intermediate variables b_{m+1},\ldots,b_{n-1}, and the only output variable c = b_n. Assume the variables are arranged in topological order: for every i = m+1,\ldots,n, variable b_i is locally a function of variables b_j with j < i. We can write this as b_i = b_i(b_1,\ldots,b_{i-1}). Note that some variables b_j with j < i may not have an edge to b_i, in which case \frac{\partial b_i}{\partial b_j} = 0.

For each variable b_i, backpropagation stores in {\mathsf {grad}}_i a “temporary derivative/gradient” w.r.t. b_i. Initially, {\mathsf {grad}}_n = 1 and {\mathsf {grad}}_i = 0 for i < n. The backpropagation algorithm iterates over i = n,n-1,\ldots,m+1 and performs the following updates in each iteration:

\begin{aligned}{\mathsf {grad}}_j \gets {\mathsf {grad}}_j + {\mathsf {grad}}_i \cdot \frac{\partial b_i}{\partial b_j} ,\ \textnormal{for all}\ j = 1,\ldots,i-1. && (4) \end{aligned}

Of course, it suffices to update {\mathsf {grad}}_j only when there is an edge from b_j to b_i, because otherwise \frac{\partial b_i}{\partial b_j} = 0 and the update does not change {\mathsf {grad}}_j.
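To make the update rule (4) concrete, here is a minimal Python sketch of the backward sweep (my own illustration, not the implementation from Barak’s tutorial). It assumes the forward pass has already recorded, for each variable b_i, its parents j < i together with the numerical partial derivative \frac{\partial b_i}{\partial b_j}, so that only edges that actually exist are visited:

```python
def backprop(n, m, parents):
    """parents[i] lists pairs (j, d) with j < i and d = (partial b_i / partial b_j),
    evaluated during the forward pass. Variables are 1-indexed and b_n is the output.
    Returns the derivatives of b_n w.r.t. the input variables b_1, ..., b_m."""
    grad = [0.0] * (n + 1)
    grad[n] = 1.0
    for i in range(n, m, -1):        # i = n, n-1, ..., m+1
        for j, d in parents[i]:
            grad[j] += grad[i] * d   # update rule (4)
    return grad[1 : m + 1]

# Toy example: variables ordered as b_1 = a, b_2 = b1, b_3 = b2, b_4 = c (m = 1, n = 4).
parents = {2: [(1, 2.0)],             # b1 = 2a
           3: [(1, 3.0), (2, 1.0)],   # b2 = 3a + b1
           4: [(2, 4.0), (3, 5.0)]}   # c  = 4*b1 + 5*b2
print(backprop(4, 1, parents))        # [33.0], matching (2) and (3)
```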

The update rule (4) can be explained as a direct application of the chain rule as follows. If we know the values of b_1,\ldots,b_i, we can evaluate the remaining variables b_{i+1},\ldots,b_n one by one, so in this sense, the output variable c = b_n is a function c_i of b_1,\ldots,b_i. Using the function b_i = b_i(b_1,\ldots,b_{i-1}), we can relate functions c_i and c_{i-1} as follows:

\begin{aligned}c_{i-1}(b_1,\ldots,b_{i-1}) = c_i(b_1,\ldots,b_{i-1}, b_i(b_1,\ldots,b_{i-1})).\end{aligned}

By the chain rule,

\begin{aligned}\frac{\partial c_{i-1}}{\partial b_j} = \frac{\partial c_i}{\partial b_j} + \frac{\partial c_i}{\partial b_i} \cdot \frac{\partial b_i}{\partial b_j},\ \text{for all}\ j = 1,\ldots, i-1. && (5) \end{aligned}

The correspondence between (4) and (5) completes the explanation: if {\mathsf {grad}}_j = \frac{\partial c_i}{\partial b_j} for every j \le i before the update, we have {\mathsf {grad}}_j = \frac{\partial c_{i-1}}{\partial b_j} for every j < i after the update. By induction, after the final iteration with i = m+1, {\mathsf {grad}}_j contains the desired value \frac{\partial c_m}{\partial b_j} for every j = 1,\ldots,m, where c = c_{m}(b_1,\ldots,b_m) is exactly the function whose derivatives we want to compute.
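As a small sanity check on this induction, and on the point that a single backward sweep yields the gradient w.r.t. all inputs at once, the same hypothetical backprop sketch from above can be run on a graph with two input variables, say c = (x + y) \cdot y evaluated at the point x = 2, y = 3:

```python
# Variables: b_1 = x, b_2 = y, b_3 = x + y, b_4 = c = b_3 * y, so m = 2 and n = 4.
# Partial derivatives evaluated at x = 2, y = 3 (where b_3 = 5):
parents = {3: [(1, 1.0), (2, 1.0)],   # partials of b_3 = x + y
           4: [(3, 3.0), (2, 5.0)]}   # partial c/partial b_3 = y = 3, partial c/partial y = b_3 = 5
print(backprop(4, 2, parents))        # [3.0, 8.0], i.e. dc/dx = y and dc/dy = x + 2y
```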

3 Comments on Backpropagation ≠ Chain Rule

  1. Thanks! I think in theory terms, the way to describe it is as follows: if you follow the chain rule in the standard “forward” direction as we learned in basic calculus, you will pay a price that scales with the *formula size* of the output.

    Backpropagation allows us to pay a price that only scales with the *circuit size*. This is what enables automatic differentiation since a computation graph is simply a circuit.

    I still maintain that it’s the (multivariate) chain rule, but applied in a clever way.


    • Thanks for the reply! Just to make a small clarification: if a circuit has only one input variable, there is a way to implement the “forward chain rule” algorithm efficiently (with “time” complexity proportional to the circuit size) via dynamic programming, no matter how many output variables we have. When there are m input variables, doing this for each input variable achieves complexity roughly m * circuit size, which can be much smaller than the formula size. By operating in the reversed direction, backpropagation replaces m by the number of *output* variables. This is a significant improvement because there is usually only one output variable but a huge number m of input variables. To achieve this, backpropagation applies a reversed version of the chain rule like in (3). The difference between the reversed version and the standard version is usually neglected, partly because they are the “same thing” if, for example, there is no edge from b_1 to b_2 in the example. Adding the edge highlights the difference.

      One purpose of this blog post is to show that backpropagation is an arguably non-trivial algorithm, despite its simplicity (which is a lesson that I failed to learn when I was an undergrad…). Thank you for your contribution in helping more people appreciate the ideas behind the backpropagation algorithm!

