LSTMs are not dead yet
The world of Natural Language Processing is dominated by ever newer variants of BERT. Recently Sebastian Ruder tweeted an XKCD-inspired graphic which he called "Types of ML / NLP papers" – one frequently seen sub-species is this:
Sebastian is a respected member of the ML community and co-author of the 2018 paper Universal Language Model Fine-tuning for Text Classification. In it, he and Jeremy Howard propose several enhancements to a language model trained with a Long Short-Term Memory (LSTM) network. But why should anyone care about such "obsolete" techniques today? It turns out that three years after the transformer revolution, recurrent models are still relevant. Even though they cannot beat BERT on accuracy scores, chasing the highest numbers on an academic benchmark is often beside the point in practical applications. Unless you have a cloud of TPUs at your disposal, as a practitioner you will often have to strike a compromise between F1 scores and training efficiency or inference times. This is the primary reason why we still think research on RNNs is worthwhile.
Apart from the joy of rediscovering the efficiency of a classical architecture, two more reasons to study recurrent networks when everyone else is bent on transformers are:
- There are very exciting RNN variants that incorporate convolutions. This effectively means that instead of training on word vectors, you are training on n-gram representations. In particular, Quasi-Recurrent Neural Networks (QRNNs) blend convolutions with the forget and output gates known from LSTMs. In 2020 ULMFiT's authors, together with four other researchers (J.M. Eisenschlos, Piotr Czapla, Marcin Kardas and Sylvain Gugger), refreshed their original model by replacing the AWD-regularized LSTM cells with QRNN cells to produce a highly efficient model called MultiFiT.
- Recurrent networks are being rediscovered as student models in knowledge distillation. Earlier this year Google released a paper on distilling multilingual BERT into a bidirectional QRNN. With quantization they managed to shrink the model weights from 440 MB to just over 1 MB!
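The only sequential component of a QRNN is an element-wise pooling over gate values that the convolutions have already computed in parallel. A minimal NumPy sketch of that "fo-pooling" step (the function name is ours, and the gate values below are random stand-ins for the convolution outputs):

```python
import numpy as np

def qrnn_fo_pool(z, f, o, c0=None):
    """fo-pooling: c_t = f_t * c_{t-1} + (1 - f_t) * z_t ;  h_t = o_t * c_t.
    z, f, o have shape (timesteps, hidden); f and o lie in (0, 1)."""
    T, H = z.shape
    c = np.zeros(H) if c0 is None else c0
    hs = []
    for t in range(T):
        c = f[t] * c + (1.0 - f[t]) * z[t]  # gated running average of candidates
        hs.append(o[t] * c)                  # output gate exposes the cell state
    return np.stack(hs)

# Toy example: 4 timesteps, hidden size 2. In a real QRNN these come from
# masked convolutions over the token sequence, not from random numbers.
rng = np.random.default_rng(0)
z = np.tanh(rng.normal(size=(4, 2)))                 # candidate values
f = 1.0 / (1.0 + np.exp(-rng.normal(size=(4, 2))))   # forget gate (sigmoid)
o = 1.0 / (1.0 + np.exp(-rng.normal(size=(4, 2))))   # output gate (sigmoid)
h = qrnn_fo_pool(z, f, o)
print(h.shape)  # (4, 2)
```

Because only this cheap element-wise loop is sequential, QRNNs parallelize far better across timesteps than LSTMs do.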
What is ULMFiT
In a nutshell, it's a 3-layer LSTM neural network trained (and optionally fine-tuned) on the next-token prediction task, with some optimization tweaks.
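Stripped of its regularization tricks, the backbone can be sketched in Keras roughly as follows. The layer sizes follow the general shape described in the paper (400-dimensional embeddings, three stacked LSTMs), but treat them as illustrative rather than exact hyperparameters:

```python
import tensorflow as tf

VOCAB_SIZE, EMB_DIM, HIDDEN = 35_000, 400, 1150  # illustrative sizes

ulmfit_backbone = tf.keras.Sequential([
    tf.keras.layers.Embedding(VOCAB_SIZE, EMB_DIM),
    tf.keras.layers.LSTM(HIDDEN, return_sequences=True),
    tf.keras.layers.LSTM(HIDDEN, return_sequences=True),
    # The last LSTM projects back down to the embedding dimension.
    tf.keras.layers.LSTM(EMB_DIM, return_sequences=True),
    tf.keras.layers.Dense(VOCAB_SIZE),  # next-token logits at every position
])

# Next-token prediction: the target at position t is the token at t+1.
ulmfit_backbone.compile(
    optimizer="adam",
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
)
```

For fine-tuning on a downstream task, the `Dense` head is swapped for a task-specific one while the embedding and LSTM weights are kept.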
If the distribution of the dataset used for the target task (e.g. document classification in (c)) differs substantially from the distribution of the pretraining corpus (a), it might be helpful to fine-tune the language model (b) first. For instance, sentiment analysis on labelled Twitter data (target task) works better if the original language model pre-trained on Wikipedia is fine-tuned on some unlabelled social media posts first.
To keep this post short, we have opted not to assault you with the maths and mechanics of LSTMs (but see this well-known excellent tutorial if you need a refresher) and instead we’d like to focus on three of the optimization tweaks mentioned earlier:
Several dropout techniques
Dropout provides regularization, which prevents the model from overfitting. In addition to plain dropout on embeddings, ULMFiT uses weight dropout (the "WD" in AWD-LSTM), i.e. dropout applied to the recurrent (hidden-to-hidden) weight matrices rather than to activations.
This is different from e.g. the recurrent_dropout implementation in Tensorflow's LSTMCell, which is applied to activations (not weights). There is an excellent write-up of all the dropout techniques used in ULMFiT, together with examples, in this Medium blog post.
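To make the distinction concrete, here is a simplified sketch of how dropping recurrent weights could be wired up as a Keras callback. The class name is ours and several details (dropout rescaling, the exact interaction between the mask and the optimizer update) are glossed over, so read it as an illustration of the idea rather than the implementation from any particular library:

```python
import tensorflow as tf

class WeightDropCallback(tf.keras.callbacks.Callback):
    """Sketch of AWD-style weight dropout: before each batch, zero out a
    random subset of each LSTM's hidden-to-hidden weights; afterwards,
    restore the dropped entries while keeping updates to the survivors."""

    def __init__(self, lstm_layers, drop_rate=0.5):
        super().__init__()
        self.lstm_layers = lstm_layers
        self.drop_rate = drop_rate
        self.saved = []

    def on_train_batch_begin(self, batch, logs=None):
        self.saved = []
        for layer in self.lstm_layers:
            kernel = layer.cell.recurrent_kernel
            mask = tf.cast(
                tf.random.uniform(tf.shape(kernel)) >= self.drop_rate,
                kernel.dtype,
            )
            self.saved.append((tf.identity(kernel), mask))
            kernel.assign(kernel * mask)  # drop a random subset of h->h weights

    def on_train_batch_end(self, batch, logs=None):
        for layer, (raw, mask) in zip(self.lstm_layers, self.saved):
            kernel = layer.cell.recurrent_kernel
            # Keep the gradient updates on surviving weights, put back the rest.
            kernel.assign(kernel * mask + raw * (1.0 - mask))
```

Note that the mask hits the weight matrix itself, so the same connections stay dropped across all timesteps of the batch, which is the whole point of weight dropout versus activation dropout.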
Variable learning rate schedules
Learning rate schedules allow us to start with a slow learning rate, increase it (fairly quickly) to a peak value and decay it towards the end of training. There are two schedules commonly used with ULMFiT and described in depth in Sylvain Gugger's posts: the One-Cycle Policy and the Slanted Triangular Learning Rate used in the original ULMFiT paper. They both combine the benefits of simulated annealing (exploring a large solution space at peak rates) with stability (warmup and cool-off phases) and nudge the Adam optimizer towards a good optimum relatively quickly. A variable LR schedule also significantly reduces training time compared with a recurrent network trained using plain SGD and a fixed learning rate.
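As an illustration, the slanted triangular schedule can be expressed as a Keras `LearningRateSchedule`. The defaults below (cut_frac=0.1, ratio=32) follow the values suggested in the ULMFiT paper; the class itself is our sketch, not a library API:

```python
import tensorflow as tf

class SlantedTriangular(tf.keras.optimizers.schedules.LearningRateSchedule):
    """Slanted triangular learning rate: a short linear warmup over the first
    cut_frac of training, then a long linear decay down to max_lr / ratio."""

    def __init__(self, total_steps, max_lr=0.01, cut_frac=0.1, ratio=32):
        self.total_steps = float(total_steps)
        self.max_lr = max_lr
        self.cut_frac = cut_frac
        self.ratio = ratio

    def __call__(self, step):
        step = tf.cast(step, tf.float32)
        cut = tf.math.floor(self.total_steps * self.cut_frac)
        # p rises linearly from 0 to 1 during warmup, then falls back to 0.
        p = tf.where(
            step < cut,
            step / cut,
            1.0 - (step - cut) / (cut * (1.0 / self.cut_frac - 1.0)),
        )
        return self.max_lr * (1.0 + p * (self.ratio - 1.0)) / self.ratio

# A schedule object can be passed directly as the optimizer's learning rate:
optimizer = tf.keras.optimizers.Adam(
    learning_rate=SlantedTriangular(total_steps=10_000)
)
```

At the cut point the schedule reaches `max_lr`; at the very start and the very end it sits at `max_lr / ratio`, giving the "slanted triangle" its short steep side and long shallow side.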
Gradual unfreezing
Rather than training all layers at once, ULMFiT builds on the intuition that the top layers of a sequential network learn "faster" than those below. Consequently, when training a three-layer network, we freeze the first two layers for a couple of epochs and train only the top layer. We then unfreeze the middle layer, train for a couple more epochs, and only in the final epochs train all the weights.
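A minimal sketch of that freezing schedule in Keras (function names are ours; a faithful reproduction would also use discriminative learning rates per layer, which we omit here):

```python
import tensorflow as tf

def set_trainable_top(model, k):
    """Make only the top k layers of a Sequential-style model trainable."""
    n = len(model.layers)
    for i, layer in enumerate(model.layers):
        layer.trainable = i >= n - k

def fit_with_gradual_unfreezing(model, train_ds, epochs_per_stage=2):
    """Train the top layer first, then unfreeze one more layer per stage
    until the whole network is being updated."""
    for k in range(1, len(model.layers) + 1):
        set_trainable_top(model, k)
        # Re-compile so the new set of trainable weights takes effect.
        model.compile(
            optimizer="adam",
            loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        )
        model.fit(train_ds, epochs=epochs_per_stage, verbose=0)
```

Re-compiling after each change is important: Keras snapshots the trainable-weights list at compile time, so flipping `trainable` alone is not enough.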
Finally in Tensorflow
Until now, if you wanted to experiment with ULMFiT, your only choice was the (rather niche) FastAI framework developed primarily by ULMFiT's authors. At https://bitbucket.org/edroneteam/tf2_ulmfit/ we are releasing code that lets you work with models ported to Tensorflow.
These are by no means the first pre-trained ULMFiT models available. The FastAI framework comes with an English encoder. To the best of our knowledge the first Polish model was released for the PolEval 2018 (task 3) competition by Piotr Czapla et al. (repo).
The README.md file contains detailed instructions on how to build a document classifier, a regressor and a sequence tagger using Keras layers. There are also links to pretrained models for English and Polish. In addition, we will soon be submitting these models to TF Hub. We hope our effort makes this architecture more accessible to the NLP community.
The code in our repo provides – among other things – a Keras callback for applying AWD, the model's most important regularization technique, after each batch. We also replicated, as faithfully as we could, all the other dropouts the paper applies to the embeddings and the subsequent recurrent layers. Additionally, our examples show how to use Tensorflow implementations of the Slanted Triangular Learning Rate, the One-Cycle Policy scheduler and the learning rate finder.