Transformers are Multi-State RNNs

2401.06104

Published 6/19/2024 by Matanel Oren, Michael Hassid, Nir Yarden, Yossi Adi, Roy Schwartz

Abstract

Transformers are considered conceptually different from the previous generation of state-of-the-art NLP models - recurrent neural networks (RNNs). In this work, we demonstrate that decoder-only transformers can in fact be conceptualized as unbounded multi-state RNNs - an RNN variant with unlimited hidden state size. We further show that transformers can be converted into $\textit{bounded}$ multi-state RNNs by fixing the size of their hidden state, effectively compressing their key-value cache. We introduce a novel, training-free compression policy - $\textbf{T}$oken $\textbf{O}$mission $\textbf{V}$ia $\textbf{A}$ttention (TOVA). Our experiments with four long range tasks and several LLMs show that TOVA outperforms several baseline compression policies. Particularly, our results are nearly on par with the full model, using in some cases only $\frac{1}{8}$ of the original cache size, which translates to 4.8X higher throughput. Our results shed light on the connection between transformers and RNNs, and help mitigate one of LLMs' most painful computational bottlenecks - the size of their key-value cache. We publicly release our code at https://github.com/schwartz-lab-NLP/TOVA

Overview

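  • Examines the relationship between transformers and recurrent neural networks (RNNs)
  • Proposes that decoder-only transformers can be viewed as unbounded multi-state RNNs, an RNN variant with unlimited hidden state size
  • Shows that fixing the hidden state size yields a bounded multi-state RNN, which amounts to compressing the key-value cache
  • Introduces TOVA (Token Omission Via Attention), a training-free cache compression policy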

Plain English Explanation

Transformers are a type of deep learning model that has become very popular in recent years, particularly for tasks like language modeling and machine translation. At a high level, transformers rely on attention: they can "focus" on the most relevant parts of their input when generating an output.

This paper argues that decoder-only transformers can actually be thought of as a special type of recurrent neural network (RNN). RNNs are a class of models that process sequential data one element at a time, maintaining an internal "state" that gets updated as the sequence is processed. The authors suggest that a transformer can be viewed as a multi-state RNN: rather than one fixed-size state, it keeps a hidden state with no size limit, which grows as more tokens are processed. In practice this hidden state corresponds to the key-value cache, and capping its size turns the model into a bounded multi-state RNN.

This new perspective on transformers has some interesting implications. It may help us better understand the inner workings of transformer models and how they differ from traditional RNNs. It could also lead to new ways of designing and training transformer-based models, drawing on the rich history and techniques developed for RNNs.

Technical Explanation

The paper first provides background on RNNs and transformers. RNNs are a class of neural network models that process sequences one element at a time, maintaining an internal state that gets updated as the sequence is processed. Transformers, on the other hand, use an attention mechanism to dynamically focus on relevant parts of the input when generating an output.

The key insight of the paper is that decoder-only transformers can be viewed as unbounded multi-state RNNs. A traditional RNN maintains a single, fixed-size state. A transformer instead keeps a separate state for every token it has processed, so its hidden state has no size limit; at inference time, this multi-state is what the key-value cache stores.
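
The multi-state view can be made concrete with a small sketch. The NumPy code below is illustrative and is not taken from the paper's released code: it assumes a single attention head with no learned projections, and the names msrnn_step and causal_attention are hypothetical. It treats the stacked keys and values as the recurrent state, which gains one row per token.

```python
import numpy as np

def softmax(x):
    # Numerically stable softmax over a 1-D array of scores.
    e = np.exp(x - np.max(x))
    return e / e.sum()

def msrnn_step(state, q, k, v):
    """One recurrent step of a single attention head (illustrative only).

    state: tuple (K, V) of arrays with shape (t, d), the multi-state so far.
    q, k, v: query, key and value vectors of the current token, shape (d,).
    Returns the head output for this token and the updated state.
    """
    K, V = state
    K = np.vstack([K, k[None, :]])  # the state grows by one row per token
    V = np.vstack([V, v[None, :]])
    weights = softmax(K @ q / np.sqrt(q.shape[0]))
    return weights @ V, (K, V)

def causal_attention(Q, K, V):
    # Reference: standard causally masked attention over the whole sequence.
    n, d = Q.shape
    scores = Q @ K.T / np.sqrt(d)
    mask = np.tril(np.ones((n, n), dtype=bool))
    scores = np.where(mask, scores, -np.inf)
    scores = scores - scores.max(axis=1, keepdims=True)
    w = np.exp(scores)
    w = w / w.sum(axis=1, keepdims=True)
    return w @ V

rng = np.random.default_rng(0)
n, d = 6, 4
Q, K, V = (rng.normal(size=(n, d)) for _ in range(3))

# Process the sequence one token at a time, carrying the state forward.
state = (np.empty((0, d)), np.empty((0, d)))
outputs = []
for t in range(n):
    out, state = msrnn_step(state, Q[t], K[t], V[t])
    outputs.append(out)
recurrent = np.stack(outputs)
```

Under these assumptions, the step-by-step outputs in recurrent match causal_attention(Q, K, V) up to floating-point error, which is the sense in which causal attention behaves like an RNN whose state keeps growing.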

To support this claim, the authors analyze the mathematical structure of transformers and show how it can be expressed as a multi-state RNN. They then bound the hidden state with a training-free compression policy, TOVA (Token Omission Via Attention), and evaluate it on four long-range tasks with several LLMs. According to the abstract, TOVA outperforms several baseline compression policies and is nearly on par with the full model, in some cases using only 1/8 of the original cache size.
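
Bounding the state can be sketched in the same style. The abstract names the TOVA policy but does not spell out its rule, so the eviction step below is an assumption based on the policy's name: once the state exceeds a fixed size, the entry that receives the lowest attention weight from the current query is dropped. The function names, the max_states parameter, and the single-head setup without learned projections are also hypothetical.

```python
import numpy as np

def bounded_step(state, q, k, v, max_states):
    """One step of a single-head bounded multi-state RNN (illustrative only).

    state: tuple (K, V), each of shape (m, d) with m <= max_states.
    Returns the head output and a state holding at most max_states rows.
    """
    K, V = state
    K = np.vstack([K, k[None, :]])
    V = np.vstack([V, v[None, :]])
    scores = K @ q / np.sqrt(q.shape[0])
    weights = np.exp(scores - scores.max())
    weights = weights / weights.sum()
    out = weights @ V
    if K.shape[0] > max_states:
        # Assumed eviction rule: drop the entry the current query attends to least.
        drop = int(np.argmin(weights))
        K = np.delete(K, drop, axis=0)
        V = np.delete(V, drop, axis=0)
    return out, (K, V)

def run(Q, K, V, max_states):
    # Feed the sequence token by token and record the state size after each step.
    d = Q.shape[1]
    state = (np.empty((0, d)), np.empty((0, d)))
    outs, sizes = [], []
    for q, k, v in zip(Q, K, V):
        out, state = bounded_step(state, q, k, v, max_states)
        outs.append(out)
        sizes.append(state[0].shape[0])
    return np.stack(outs), sizes

rng = np.random.default_rng(0)
Q, K, V = (rng.normal(size=(16, 8)) for _ in range(3))
full_out, full_sizes = run(Q, K, V, max_states=16)   # never evicts
small_out, small_sizes = run(Q, K, V, max_states=4)  # state capped at 4 rows
```

With max_states equal to the sequence length nothing is evicted and the sketch reduces to ordinary causal attention. With a smaller limit the state stops growing once the cap is reached, and the outputs can diverge from the full model after the first eviction.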

Critical Analysis

The authors make a compelling case that transformers can be fruitfully viewed as a type of multi-state RNN. This perspective may help bridge the gap between transformer and RNN research, allowing insights and techniques from the well-established RNN literature to be applied to transformers.

However, the paper does not fully address the limitations of this analogy. Transformers have many unique architectural features, like the use of self-attention, that may not have direct analogues in traditional RNNs. The extent to which this analogy can be pushed, and what insights it can actually yield, remains an open question.

Additionally, the experimental evidence provided, while suggestive, is somewhat limited. More thorough investigations, perhaps comparing the performance and behaviors of transformers and multi-state RNNs on a wider range of tasks, would help strengthen the claims made in the paper.

Conclusion

This paper presents a novel perspective on transformer models, suggesting they can be viewed as a type of multi-state RNN. This insight could lead to new ways of understanding and designing transformer-based models, drawing on the rich history and techniques developed for RNNs.

While the paper makes a compelling case, there are still open questions and limitations to this analogy that warrant further exploration. Nonetheless, this work represents an important step in bridging the gap between transformers and other sequential modeling approaches, and could have significant implications for the future development of deep learning architectures.



This summary was produced with help from an AI and may contain inaccuracies - check out the links to read the original source documents!

Related Papers

Small-E: Small Language Model with Linear Attention for Efficient Speech Synthesis

Théodor Lemerle, Nicolas Obin, Axel Roebel

Recent advancements in text-to-speech (TTS) powered by language models have showcased remarkable capabilities in achieving naturalness and zero-shot voice cloning. Notably, the decoder-only transformer is the prominent architecture in this domain. However, transformers face challenges stemming from their quadratic complexity in sequence length, impeding training on lengthy sequences and resource-constrained hardware. Moreover they lack specific inductive bias with regards to the monotonic nature of TTS alignments. In response, we propose to replace transformers with emerging recurrent architectures and introduce specialized cross-attention mechanisms for reducing repeating and skipping issues. Consequently our architecture can be efficiently trained on long samples and achieve state-of-the-art zero-shot voice cloning against baselines of comparable size. Our implementation and demos are available at https://github.com/theodorblackbird/lina-speech.

6/12/2024

Does Transformer Interpretability Transfer to RNNs?

Gonçalo Paulo, Thomas Marshall, Nora Belrose

Recent advances in recurrent neural network architectures, such as Mamba and RWKV, have enabled RNNs to match or exceed the performance of equal-size transformers in terms of language modeling perplexity and downstream evaluations, suggesting that future systems may be built on completely new architectures. In this paper, we examine if selected interpretability methods originally designed for transformer language models will transfer to these up-and-coming recurrent architectures. Specifically, we focus on steering model outputs via contrastive activation addition, on eliciting latent predictions via the tuned lens, and eliciting latent knowledge from models fine-tuned to produce false outputs under certain conditions. Our results show that most of these techniques are effective when applied to RNNs, and we show that it is possible to improve some of them by taking advantage of RNNs' compressed state.

4/10/2024

On Limitation of Transformer for Learning HMMs

Jiachen Hu, Qinghua Liu, Chi Jin

Despite the remarkable success of Transformer-based architectures in various sequential modeling tasks, such as natural language processing, computer vision, and robotics, their ability to learn basic sequential models, like Hidden Markov Models (HMMs), is still unclear. This paper investigates the performance of Transformers in learning HMMs and their variants through extensive experimentation and compares them to Recurrent Neural Networks (RNNs). We show that Transformers consistently underperform RNNs in both training speed and testing accuracy across all tested HMM models. There are even challenging HMM instances where Transformers struggle to learn, while RNNs can successfully do so. Our experiments further reveal the relation between the depth of Transformers and the longest sequence length it can effectively learn, based on the types and the complexity of HMMs. To address the limitation of transformers in modeling HMMs, we demonstrate that a variant of the Chain-of-Thought (CoT), called $\textit{block CoT}$ in the training phase, can help transformers to reduce the evaluation error and to learn longer sequences at a cost of increasing the training time. Finally, we complement our empirical findings by theoretical results proving the expressiveness of transformers in approximating HMMs with logarithmic depth.

6/7/2024

Attention as an RNN

Leo Feng, Frederick Tung, Hossein Hajimirsadeghi, Mohamed Osama Ahmed, Yoshua Bengio, Greg Mori

The advent of Transformers marked a significant breakthrough in sequence modelling, providing a highly performant architecture capable of leveraging GPU parallelism. However, Transformers are computationally expensive at inference time, limiting their applications, particularly in low-resource settings (e.g., mobile and embedded devices). Addressing this, we (1) begin by showing that attention can be viewed as a special Recurrent Neural Network (RNN) with the ability to compute its many-to-one RNN output efficiently. We then (2) show that popular attention-based models such as Transformers can be viewed as RNN variants. However, unlike traditional RNNs (e.g., LSTMs), these models cannot be updated efficiently with new tokens, an important property in sequence modelling. Tackling this, we (3) introduce a new efficient method of computing attention's many-to-many RNN output based on the parallel prefix scan algorithm. Building on the new attention formulation, we (4) introduce Aaren, an attention-based module that can not only (i) be trained in parallel (like Transformers) but also (ii) be updated efficiently with new tokens, requiring only constant memory for inferences (like traditional RNNs). Empirically, we show Aarens achieve comparable performance to Transformers on $38$ datasets spread across four popular sequential problem settings: reinforcement learning, event forecasting, time series classification, and time series forecasting tasks while being more time and memory-efficient.

5/29/2024