“If any one faculty of our nature may be called more wonderful than the rest, I do think it is memory.” (Jane Austen)

In this post, we’ll consider the sequential prediction problem of predicting the next observation $x_t$ given a sequence of past observations $x_1, x_2,\dots, x_{t-1}$, and we’ll study it from the point of view of storing information about past observations in memory and referencing it to predict future observations. We’ll start with a basic setting and then generalize it. This is based on joint work with Sham Kakade, Percy Liang and Greg Valiant.

Consider a simple scenario where the sequence of observations is just a string of $n$ bits that keeps repeating (see Fig. 1), and this $n$-bit string is even known in advance. Suppose you get a sequence of observations from this model, and the task is to predict the output at the next time step given only the previous $\ell$ outputs. Clearly, if $\ell \ge n$ the task is trivial, because the outputs are periodic with period $n$. Is a shorter window sufficient? If the $n$-bit string is chosen uniformly at random, then with high probability all of its substrings of length $c\log n$ (for a suitable constant $c$) are distinct, so a window of $O(\log n)$ observations uniquely identifies the current position in the string. Hence $\ell=O(\log n)$ suffices to predict the outputs accurately.
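
To make the window-length claim concrete, here is a minimal sketch that finds the shortest window $\ell$ for which the previous $\ell$ bits of a repeating string always determine the next bit. The function name `min_predictive_window` is hypothetical (it is not from the paper); the sketch treats the string as a cycle and checks that the map from each window to the bit that follows it is consistent.

```python
import math
import random


def min_predictive_window(bits: str) -> int:
    """Return the smallest ell such that, in the infinitely repeating string
    `bits`, the previous ell bits always determine the next bit."""
    n = len(bits)
    doubled = bits + bits  # cyclic windows become ordinary slices
    for ell in range(1, n + 1):
        next_bit = {}  # window -> the bit observed right after it
        consistent = True
        for start in range(n):
            window = doubled[start:start + ell]
            nxt = doubled[start + ell]
            if next_bit.setdefault(window, nxt) != nxt:
                consistent = False  # same window, different next bit
                break
        if consistent:
            return ell
    return n  # not reached for n >= 1: ell = n is always consistent


if __name__ == "__main__":
    rng = random.Random(0)
    n = 1024
    random_bits = "".join(rng.choice("01") for _ in range(n))
    # For a random string this is typically a small multiple of log2(n).
    print(min_predictive_window(random_bits), math.log2(n))
    # A single 1 among 63 zeros needs a window of n - 1 = 63 bits.
    print(min_predictive_window("1" + "0" * 63))
```

The two cases match the contrast drawn in the text: a random string needs only a logarithmic window, while the string with a single 1 needs a window of $n-1$ bits.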

Fig. 1: A simple sequential model of $n$ bits which keep repeating.

Is there any hope of making good predictions with $\ell=O(\log n)$ for worst-case $n$-bit strings? Note that observations from this model can certainly have dependencies spanning on the order of $n$ time steps for some strings. For example, consider the string with a single 1 and 0s in all other positions: here we need to know the previous $(n-1)$ outputs to determine whether the next output will be a 1. Despite these long-range dependencies, it turns out that windows of $O(\log n)$ observations are sufficient to make accurate predictions on average across time, both in this model of a repeating length-$n$ string and much more generally. Before describing these results and the intuition behind them, we provide some context for this basic question of how to consolidate and reference memories about the past in order to effectively predict the future.

Such questions about the importance of memory and how humans form memories have interested thinkers since at least the time of Plato. Our featured picture is from Christopher Nolan’s Memento, whose protagonist suffers from short-term memory loss roughly every five minutes; the plot masterfully explores how our memories in some sense shape our reality. There has been considerable interest in the neuroscience community in understanding how humans and animals create and retrieve memories to make accurate predictions about their environment (for example, see [1] and [2]). Closer to home, one can ask how algorithms can consolidate and reference memories about the past in order to effectively predict the future.

This question of how to use memory (both how to figure out what to remember, and how to usefully query those memories) has recently been one of the most significant challenges that practical ML and NLP researchers have been grappling with. These efforts have led to a variety of neural network architectures with an explicit notion of memory, including recurrent neural networks (RNNs), neural Turing machines, memory networks and Long Short-Term Memory (LSTM) networks (for a very nice introduction to these, see this). These advances have achieved some practical success, but they seem largely unable to consistently learn long-range dependencies, which are crucial in many settings, including language. One amusing example of this is the recent sci-fi short film Sunspring, whose script was automatically generated by an LSTM network. Locally, each sentence of the dialogue (mostly) makes sense, but there is no cohesion over longer time frames and no overarching plot trajectory (despite the brilliant acting). It is an interesting watch.

Many fundamental questions in this setting seem ripe for theoretical investigation: i) How much memory is necessary to accurately predict future observations, and what properties of the underlying sequence determine this requirement? ii) Must one remember significant information about the distant past, or is a short-term memory sufficient? iii) What is the computational complexity of accurate prediction? iv) How do the answers to the above questions depend on the metric used to evaluate prediction accuracy? We believe that answers to these questions could both guide the development of practical prediction systems and help us understand how prediction and learning take place in nature.

In our recent work, we make progress on the first three questions. Perhaps surprisingly, we show that for a broad class of sequences, the “naive” algorithm that bases its predictions only on the few most recent observations, together with a set of simple summary statistics of the past observations, predicts nearly as well as possible on average. Here the average error is the error at each time step averaged over a sufficiently long time window (average error is the most natural error metric, and is the metric ubiquitously used in practice). One concrete special case of our more general result concerns sequences generated by a Hidden Markov Model (HMM) with at most $n$ hidden states (note that the model in Fig. 1 corresponds to a very simple HMM with $n$ hidden states). We show that the naive prediction algorithm based on the empirical frequencies of length-$\ell$ windows of observations, with $\ell=O(\log n/\epsilon)$, achieves average $\ell_1$ error at most $\sqrt{\epsilon}$ greater than the average error of the optimal predictor, which knows the entire history of the sequence and the parameters of the underlying HMM. For this naive empirical model to achieve this error, however, the observed sequence must be quite long: $d^{\Omega(\log n/\epsilon)}$, where $d$ is the size of the observation alphabet.

This naive prediction algorithm is simply the $\ell$-th order Markov model, which predicts the distribution of the next observation using its conditional distribution given the previous $(\ell-1)$ observations, where this conditional distribution is estimated from the empirical frequencies observed so far. Note that this result is independent of the mixing time of the underlying Markov chain, and holds even when the chain does not mix.
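
A minimal sketch of this empirical predictor is below, assuming a finite alphabet and no smoothing; contexts that have never been seen get a uniform prediction, which is our own choice rather than the paper's. The names `EmpiricalMarkovPredictor` and `average_l1_error` are hypothetical. Following the convention above, a model with parameter $\ell$ conditions on the previous $\ell-1$ observations. The error measured here is the $\ell_1$ distance between the prediction and a point mass on the symbol that actually occurs; for a deterministic repeating string this coincides with the distance to the optimal prediction once the current position is known.

```python
from collections import Counter, defaultdict, deque


class EmpiricalMarkovPredictor:
    """Estimates P(x_t | previous ell - 1 symbols) from the frequencies of
    length-ell windows observed so far."""

    def __init__(self, ell, alphabet):
        self.context_len = ell - 1
        self.alphabet = list(alphabet)
        self.counts = defaultdict(Counter)  # context tuple -> next-symbol counts
        self.history = deque(maxlen=self.context_len)  # last ell - 1 symbols

    def predict(self):
        """Return a dict mapping each symbol to its predicted probability."""
        ctx_counts = self.counts.get(tuple(self.history))
        if not ctx_counts:
            # Unseen (or not yet full) context: uniform fallback, no smoothing.
            p = 1.0 / len(self.alphabet)
            return {a: p for a in self.alphabet}
        total = sum(ctx_counts.values())
        return {a: ctx_counts[a] / total for a in self.alphabet}

    def update(self, x):
        """Record that x followed the current context, then slide the window."""
        if len(self.history) == self.context_len:
            self.counts[tuple(self.history)][x] += 1
        self.history.append(x)


def average_l1_error(sequence, ell, alphabet):
    """Time-averaged l1 distance between the online prediction and a point
    mass on the symbol that actually occurred."""
    model = EmpiricalMarkovPredictor(ell, alphabet)
    total = 0.0
    for x in sequence:
        pred = model.predict()
        total += sum(abs(pred[a] - (1.0 if a == x else 0.0)) for a in model.alphabet)
        model.update(x)
    return total / len(sequence)


if __name__ == "__main__":
    periodic = "0110100" * 200  # an n = 7 bit string, repeated
    print(average_l1_error(periodic, ell=8, alphabet="01"))
    print(average_l1_error(periodic, ell=2, alphabet="01"))
```

With `ell=8` the context covers a full period, so only the first two periods (while contexts are still unseen) contribute any error; with `ell=2` a single previous bit cannot pin down the position, so the error stays bounded away from zero.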

Interestingly, this result shows that accurate prediction is possible even if the algorithm does not explicitly capture any long-range dependencies. Note that an $n$-state HMM can certainly represent dependencies of length $n$ (as in the model in Fig. 1); nevertheless, the predictor that only uses the most recent $O(\log n)$ observations can achieve nearly optimal prediction error. One interpretation of these results is that, from the perspective of prediction, all that matters is the amount of “dependencies”, not whether they are long-range or short-range. Supporting this interpretation, we also show the following general result: for any distribution over sequences of observations (not necessarily generated by an HMM) for which the mutual information between the entire past $\dots, x_{t-2}, x_{t-1}$ and the future $x_t, x_{t+1}, \dots$ is bounded by $\mathcal{I}$, the best $\ell$-th order Markov model obtains average KL error $\mathcal{I}/\ell$, or $\ell_1$ error $\sqrt{\mathcal{I}/\ell}$, with respect to the optimal predictions. (For an HMM, $\mathcal{I}$ is at most $\log n$, since $\log n$ bits suffice to specify the hidden state, and the hidden state encapsulates all the information about the past that is relevant to the future.) We also show that it is information-theoretically impossible to achieve a smaller error than this using only the previous $\ell$ observations.
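
As a worked instance, the HMM statement quoted earlier can be recovered from the general bound, up to constant factors. The notation $h_t$ for the hidden state at time $t$ is ours. The first line uses the data-processing inequality (the past and the future are conditionally independent given $h_t$); the last line assumes Pinsker's inequality, $\|P-Q\|_1 \le \sqrt{2\,\mathrm{KL}(P\|Q)}$, applied at each time step and then averaged using the concavity of the square root, which is where the constant $\sqrt{2}$ comes from.

$$
\begin{aligned}
\mathcal{I} &= I(\dots, x_{t-2}, x_{t-1};\; x_t, x_{t+1}, \dots) \;\le\; I(\dots, x_{t-2}, x_{t-1};\; h_t) \;\le\; H(h_t) \;\le\; \log n,\\
\text{average KL error} &\le \frac{\mathcal{I}}{\ell} \;\le\; \frac{\log n}{\ell} \;=\; \epsilon \qquad \text{for } \ell = \frac{\log n}{\epsilon},\\
\text{average } \ell_1 \text{ error} &\le \sqrt{2 \cdot \text{average KL error}} \;\le\; \sqrt{2\epsilon}.
\end{aligned}
$$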

The idea behind these results is most intuitive for a sequence generated by an HMM with at most $n$ hidden states. In this case, at each time step $t$, either we predict accurately (and are unsurprised when $x_t$ is revealed to us), or we predict poorly and are surprised by the value of $x_t$, in which case (in a sense that can be made rigorous) $x_t$ must contain a significant amount of information about the true hidden state. Because the hidden state can be specified with $\log n$ bits, this bounds the number of errors one expects to make.
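
This "surprise budget" can be checked numerically on the model of Fig. 1. The sketch below is an illustration of our own choosing, not the paper's proof: a hypothetical Bayesian predictor knows the $n$-bit string but not the current position (the phase), starts from a uniform prior over the $n$ phases, and is charged $-\log_2 p$ bits whenever the observed bit had predicted probability $p$. Each step's probability is the fraction of still-consistent phases, so the total telescopes to $\log_2(n/k)$, where $k$ is the number of phases still consistent at the end; it therefore never exceeds $\log_2 n$ bits, however long the sequence runs.

```python
import math
import random


def cumulative_surprise(bits: str, phase: int, steps: int) -> float:
    """Total surprise, in bits, of a Bayesian predictor that knows the repeating
    string `bits` but starts with a uniform prior over its n possible phases."""
    n = len(bits)
    candidates = list(range(n))  # phases consistent with everything seen so far
    total = 0.0
    for t in range(steps):
        x = bits[(phase + t) % n]  # the observation actually revealed
        consistent = [c for c in candidates if bits[(c + t) % n] == x]
        # Predictive probability of x = fraction of candidate phases emitting x.
        total += -math.log2(len(consistent) / len(candidates))
        candidates = consistent  # never empty: the true phase always survives
    return total


if __name__ == "__main__":
    n = 64
    single_one = "1" + "0" * (n - 1)
    rng = random.Random(1)
    random_bits = "".join(rng.choice("01") for _ in range(n))
    for name, bits in [("single 1", single_one), ("random", random_bits)]:
        print(name, cumulative_surprise(bits, phase=10, steps=10_000), math.log2(n))
```

Even though the single-1 string has dependencies of length $n-1$, its total surprise is $\log_2 n = 6$ bits (up to floating-point rounding); once the phase is identified, every later prediction is certain.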

These observations shed light on the striking power of the simple Markov model: it can obtain good predictions on average for any data-generating distribution, provided its order scales with the mutual information of the sequence. These Markov models, with proper smoothing (e.g., Kneser-Ney smoothing), were essentially the state of the art for natural language generation until a few years ago, and this result perhaps explains some of that success. It also strongly suggests that the widely used metric of average error may not be the right metric for training our algorithms if we want them to learn long-term dependencies, since the trivial Markov model can do quite well on this metric even though it is hampered by short-term memory loss (not unlike Memento’s protagonist!). More speculatively, some recent studies have claimed that many animals have very poor short-term memory; could it be that nature also opts for the Markov model when starved for resources?
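
For readers unfamiliar with Kneser-Ney smoothing, here is a minimal sketch of the textbook interpolated Kneser-Ney estimate for a bigram (first-order) model; higher orders apply the same idea recursively. The function name `kneser_ney_bigram` and the discount $D = 0.75$ are assumptions for illustration, not the configuration of any particular system. The key idea is that the back-off distribution uses how many distinct contexts a symbol follows (its continuation count) rather than its raw frequency.

```python
from collections import Counter, defaultdict


def kneser_ney_bigram(tokens, discount=0.75):
    """Return prob(word, prev), an interpolated Kneser-Ney bigram estimate
    built from the token list `tokens` (assumes at least two tokens)."""
    bigrams = Counter(zip(tokens, tokens[1:]))
    context_count = Counter()        # c(prev): occurrences of prev as a context
    followers = defaultdict(set)     # prev -> distinct words seen after it
    predecessors = defaultdict(set)  # word -> distinct words seen before it
    for (prev, word), c in bigrams.items():
        context_count[prev] += c
        followers[prev].add(word)
        predecessors[word].add(prev)
    num_bigram_types = len(bigrams)

    def continuation(word):
        # Fraction of distinct bigram types that end in `word`.
        return len(predecessors[word]) / num_bigram_types

    def prob(word, prev):
        c_prev = context_count[prev]
        if c_prev == 0:
            return continuation(word)  # unseen context: back off entirely
        discounted = max(bigrams[(prev, word)] - discount, 0.0) / c_prev
        backoff_weight = discount * len(followers[prev]) / c_prev
        return discounted + backoff_weight * continuation(word)

    return prob


if __name__ == "__main__":
    toks = "the cat sat on the mat the cat ran".split()
    p = kneser_ney_bigram(toks)
    print(p("cat", "the"), p("mat", "cat"))
```

With $0 < D \le 1$ and integer counts, the mass removed from a seen context by discounting equals the back-off weight, so each conditional distribution still sums to one, and unseen bigrams such as ("cat", "mat") receive nonzero probability.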

As mentioned above, the data required to estimate an $\ell$-th order Markov model over an observation alphabet of size $d$ is at least $d^\ell$, as most sequences of $\ell$ observations might need to be observed. This prompts the question of whether it is possible to learn a successful predictor from significantly less data. Without any additional assumptions on the structure of the sequence in question, the answer seems to be “no”. As we show, even for sequences generated by an $n$-state HMM, any computationally efficient algorithm that learns an $\epsilon$-accurate predictor requires $d^{\Omega(\log n/\epsilon)}$ observations, assuming hardness of strongly refuting a certain class of CSPs. Read our full paper to find out more about these results!
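
To get a feel for how quickly $d^\ell$ grows, the hypothetical helper below counts the possible length-$\ell$ windows when $\ell = \lceil \log_2 n / \epsilon \rceil$, taking the constant hidden in the $O(\cdot)$ to be 1 (an assumption made purely for illustration).

```python
import math


def windows_to_cover(d: int, n: int, eps: float) -> int:
    """Number of distinct length-ell windows over a size-d alphabet when
    ell = ceil(log2(n) / eps); the O(.) constant is assumed to be 1."""
    ell = math.ceil(math.log2(n) / eps)
    return d ** ell


if __name__ == "__main__":
    for d, n, eps in [(2, 1024, 0.5), (2, 1024, 0.1), (26, 1024, 0.5)]:
        print(d, n, eps, windows_to_cover(d, n, eps))
```

Even for binary observations and a modest $n = 1024$, halving the error budget from $\epsilon = 0.5$ doubles the exponent, which is why the sample requirement, rather than the window length, is the real bottleneck.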