Computing the gradient

Computation of the gradient is often a non-trivial task in modern optimisation problems.

Background

Recall the gradient descent algorithm for minimising an objective function \(f:\mathbb{R}^d\to\mathbb{R}\): \[ x_{n+1} = x_n - \alpha \nabla f(x_n).\] To implement the algorithm, we need to compute the gradient \(\nabla f(x_n)\) every time we take a step \(\alpha \nabla f(x_n)\) from \(x_n\) to \(x_{n+1}\).
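
As a quick illustration, the update rule translates directly into a short R loop. The objective \(f(x) = x_1^2 + 2x_2^2\), the step size \(\alpha\) and the number of steps \(N\) below are illustrative choices only.

```r
# A minimal sketch of the update x_{n+1} = x_n - alpha * grad_f(x_n).
# Illustrative objective: f(x) = x1^2 + 2 * x2^2, whose gradient is (2*x1, 4*x2).
grad_f <- function(x) c(2 * x[1], 4 * x[2])

x     <- c(3, -2)   # initial point x_0
alpha <- 0.1        # step size
N     <- 100        # number of steps
for (n in 1:N) {
  x <- x - alpha * grad_f(x)
}
x                   # close to the minimiser (0, 0)
```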

Mathematically, each component of the gradient is simply a partial derivative \(\frac{\partial f}{\partial x_i}(x_n)\) for \(i\in\{1,\dots,d\}\). As long as we have these partial derivative functions implemented in R for all \(i\), we might just call each of the \(d\) functions at every step \(n\in\{1,2,\dots,N\}\), and we reach a local minimum after \(N\) steps. In practice, this naïve idea can easily become infeasible.
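
Here is a sketch of this naïve idea for the same toy objective: one R function per partial derivative, with the gradient assembled by calling all \(d\) of them at the current point.

```r
# The naive idea for d = 2: one R function per partial derivative,
# assembled into the gradient by calling all d of them at the current point.
partials <- list(
  function(x) 2 * x[1],   # partial derivative of f with respect to x1
  function(x) 4 * x[2]    # partial derivative of f with respect to x2
)
grad_f <- function(x) vapply(partials, function(p) p(x), numeric(1))
grad_f(c(3, -2))          # gradient at (3, -2), i.e. (6, -8)
```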

First, the required number of steps \(N\) depends on the model, the step size \(\alpha\), and the quality of the initial point \(x_0\). Even in simple 2-D problems \((d=2)\), \(N\) can be in the thousands.

Second, in modern optimisation problems, \(d\) is often quite large, and \(\nabla f\) comprises \(d\) complicated component functions. Even a single evaluation of such a \(\nabla f\) may take non-trivial time.

To get a sense of large-scale problems, try repeating a simple operation such as sum <- sum + (-1)^i a billion times (i in 1:1e9) and note how long it takes.
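
A direct R loop along those lines is sketched below; the actual timing depends on your machine, so you may want to start with far fewer iterations.

```r
# A billion trivial updates in a plain R loop; try 1e7 first, since the full
# 1e9 iterations may take minutes of wall-clock time depending on the machine.
sum <- 0
system.time(
  for (i in 1:1e9) {
    sum <- sum + (-1)^i
  }
)
```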

In statistics and machine learning, optimisation problems often arise from the estimation of model parameters. For example, in deep learning, it is common to have billions of parameters \((d>10^9)\) defining a deep neural network, and estimating them can take days or even weeks.

In a regression problem, we often have the following minimisation problem: \[\min_{w\in\mathbb{R}^d} \mathcal{C}(w) := \min_{w\in\mathbb{R}^d} \sum_{j=1}^m(y_j-f(x_j;w))^2,\] where \(f\) is a statistical model parameterised by \(w\), and \(\{(x_j,y_j)\}_{j=1}^m\) is data to which we fit \(f\) by solving the above minimisation problem.
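
In R, the cost might be written as in the sketch below, assuming the model is available as a function f(x, w) and the data are stored as rows of a matrix X with a response vector y (names chosen purely for illustration).

```r
# The least-squares cost in R, assuming the model is available as a function
# f(x, w) and the data are rows of X with responses y (illustrative names).
C <- function(w, f, X, y) {
  predictions <- apply(X, 1, f, w = w)   # f(x_j; w) for each data point x_j
  sum((y - predictions)^2)
}
```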

To use gradient descent \[ w_{n+1} = w_n - \alpha \nabla \mathcal{C}(w_n),\] we need to compute \(\nabla \mathcal{C}(w_n)\) at every step \(n\in\{1,2,\dots,N\}\).
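
As a minimal sketch, consider the special case of a linear model \(f(x;w) = x^\top w\), for which the gradient has the closed form \(\nabla \mathcal{C}(w) = -2\sum_{j=1}^m (y_j - x_j^\top w)\,x_j\); the simulated data, step size and number of steps below are illustrative choices.

```r
# Gradient descent on C(w) for a linear model f(x; w) = x %*% w, where the
# gradient has the closed form -2 * t(X) %*% (y - X %*% w).
# The simulated data, step size alpha and step count are illustrative choices.
set.seed(1)
m <- 100; d <- 3
X <- matrix(rnorm(m * d), m, d)          # data points x_j as rows
w_true <- c(2, -1, 0.5)
y <- X %*% w_true + rnorm(m, sd = 0.1)   # responses y_j

grad_C <- function(w) -2 * t(X) %*% (y - X %*% w)

w     <- rep(0, d)   # initial point w_0
alpha <- 1e-3        # step size
for (n in 1:1000) {
  w <- w - alpha * grad_C(w)
}
drop(w)              # close to w_true = (2, -1, 0.5)
```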

Neural network

Let’s use a simple neural network to understand the common challenges of computing the gradient, and a common workaround. In standard regression problems, a neural network is a real-valued function \(f:\mathbb{R}^k\to\mathbb{R}\) composed in a special way. The composition is typically illustrated with a network diagram like the one below.