
Exploring xLSTM: The New Contender Against Transformers

xLSTM vs. Transformers: A New Era in AI

Transformers have dominated discussions in the AI field for a significant time. However, prior to their ascent, Long Short-Term Memory (LSTM) networks were the standard bearers. Initially created to address the vanishing gradient issue associated with Recurrent Neural Networks (RNNs), LSTMs paved the way for various advancements. Recently, the spotlight has shifted to a novel architecture called xLSTM, which not only matches the performance of Transformers but, in certain scenarios, surpasses them.

For further reading on Mamba, a state space model that, like LSTMs, processes sequences with a recurrent update, see my separate article on the topic.

Let’s delve into this fascinating new research without delay.

Table of Contents

  • Understanding RNN
  • An Intro to LSTM
  • What Does xLSTM Offer?
  • Decoding the sLSTM and mLSTM Blocks
  • xLSTM Architecture
  • Conclusion

Understanding RNN

Recurrent Neural Networks (RNNs) are a class of neural networks built to process sequences, which gives them remarkable capabilities despite their limitations. Traditional neural networks, including Convolutional Networks, have rigid structures: they take a fixed-size vector as input (such as an image) and produce a fixed-size vector as output (like class probabilities).

These models execute this mapping using a predetermined number of computational steps, defined by the model's layers. RNNs, however, offer a more flexible approach because they handle sequences of vectors, whether in the input, the output, or both. Concrete examples include one-to-many mappings such as image captioning, many-to-one mappings such as sentiment classification, and many-to-many mappings such as machine translation.

Operating over sequences makes RNNs significantly more powerful than static networks, which are inherently limited by their fixed number of computational steps, and it is what makes them appealing for applications like language modeling. At each step, an RNN merges the current input vector with its internal state vector using a learned function to generate a new state vector. This perspective lets us view RNNs as small programs capable of simulating arbitrary computations; under idealized assumptions they are even Turing-complete.
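As a rough, self-contained illustration (not code from the xLSTM paper), the core recurrence can be written in a few lines of NumPy; the weight names and the tanh nonlinearity are the usual textbook choices:

```python
import numpy as np

def rnn_step(x_t, h_prev, W_xh, W_hh, b_h):
    """One step of a vanilla RNN: merge the current input with the
    previous hidden state through a learned affine map and a tanh."""
    return np.tanh(W_xh @ x_t + W_hh @ h_prev + b_h)

# Toy usage: a sequence of 5 input vectors of size 3, hidden size 4.
rng = np.random.default_rng(0)
W_xh, W_hh, b_h = rng.normal(size=(4, 3)), rng.normal(size=(4, 4)), np.zeros(4)
h = np.zeros(4)
for x in rng.normal(size=(5, 3)):
    h = rnn_step(x, h, W_xh, W_hh, b_h)  # the hidden state carries context across steps
```

The same state vector h is reused at every step, which is exactly what allows an RNN to process inputs of arbitrary length with a fixed set of parameters.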

An Intro to LSTM

Reflecting on my experience with LSTMs, I recall grappling with their complexity. Understanding each component is one thing, but grasping the overall architecture is another challenge. Let's briefly outline the LSTM structure.

In the LSTM architecture, we identify several operations. The variable C_{t-1} represents the long-term memory of the network up to the previous time step t-1, the hidden state h_{t-1} serves as the short-term memory, and z_t denotes the cell input computed from the current element of the input sequence.

The LSTM comprises three gates: the Input Gate, Forget Gate, and Output Gate.

The cell state C_t is formed by combining the previous cell state, scaled by the Forget Gate f_t (which decides what to discard), with the new cell input z_t, scaled by the Input Gate i_t: C_t = f_t · C_{t-1} + i_t · z_t. This additive update is what allows gradients to flow across many time steps, and it is the classic remedy for the vanishing gradient problem that the xLSTM paper also revisits.
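To make the gate interactions concrete, here is a minimal NumPy sketch of a single LSTM step in the notation above; the dictionary-of-matrices parameterization is purely for readability and is not tied to any particular implementation:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_step(x_t, h_prev, c_prev, W, R, b):
    """One LSTM step. W, R, b are dicts holding the input weights, recurrent
    weights, and biases for the cell input (z), input gate (i),
    forget gate (f), and output gate (o)."""
    pre = {g: W[g] @ x_t + R[g] @ h_prev + b[g] for g in ("z", "i", "f", "o")}
    z_t = np.tanh(pre["z"])           # candidate cell input
    i_t = sigmoid(pre["i"])           # input gate: how much of z_t to write
    f_t = sigmoid(pre["f"])           # forget gate: how much of c_prev to keep
    o_t = sigmoid(pre["o"])           # output gate: how much of the cell to expose
    c_t = f_t * c_prev + i_t * z_t    # long-term memory (cell state) update
    h_t = o_t * np.tanh(c_t)          # short-term memory (hidden state)
    return h_t, c_t
```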

The vanishing gradient issue can be illustrated as follows:

For instance, given a short sentence such as "Tom went to the store to buy a drink," the model can easily answer that "Tom" is the one who went to the store. However, as more words separate the pronoun from the name, the model increasingly struggles to connect "He" back to "Tom."

For a deeper understanding of LSTMs, Stanford's well-known lecture on RNNs and LSTMs is worth watching.

What Does xLSTM Offer?

LSTMs face three primary challenges:

  1. Inflexibility in storage decisions: LSTMs cannot effectively revise a stored value when a new, more similar vector appears. xLSTM addresses this through exponential gating, which allows the model to overwrite earlier storage decisions when better candidates arrive.
  2. Storage limitations: Information must be compressed into scalar cell states, hindering LSTM performance, especially with rare tokens. xLSTM resolves this by utilizing matrix memory.
  3. Limited parallelization: The memory mixing in LSTMs, i.e., the hidden-state connections between consecutive time steps, forces sequential processing and restricts their ability to fully leverage GPU parallelization.

To tackle these shortcomings, Extended Long Short-Term Memory (xLSTM) introduces two key innovations: exponential gating and advanced memory structures. This leads to two new variants: sLSTM, which employs scalar memory, and mLSTM, which utilizes matrix memory with a covariance update rule, allowing for full parallelization.

Both sLSTM and mLSTM enhance LSTMs through exponential gating. The mLSTM abandons memory mixing entirely, which is what makes it fully parallelizable. Both variants can use multiple memory cells; sLSTM additionally supports multiple heads, with memory mixing across cells within a head but not across heads. This new form of memory mixing is what chiefly distinguishes the two architectures.

Decoding the sLSTM and mLSTM Blocks

The exponential function in gating mechanisms allows for more dynamic and responsive updates. Unlike the sigmoid function, which restricts values between 0 and 1, the exponential function accommodates a broader range of values, enhancing the model's ability to learn intricate patterns.

Using the exponential function instead of the sigmoid significantly increases the expressivity of the gate outputs.

Traditionally, sigmoid functions were favored to maintain values within a specific range; however, normalization techniques are now implemented in xLSTM to manage large values arising from the exponential function, ensuring stable operations.

By combining exponential gates with a stabilizer computed in log space, xLSTM keeps the resulting values under control: the exponential function lets the gates react strongly to large differences in their inputs, while the logarithm-based stabilizer state is subtracted inside the exponent to prevent numerical overflow.
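Below is a minimal sketch of how such stabilized exponential gating can look for a single scalar sLSTM cell. It follows my reading of the paper's equations (exponential input and forget gates, a normalizer state n_t, and a stabilizer state m_t tracked in log space) and is illustrative rather than the authors' code:

```python
import numpy as np

def stabilized_gates(i_pre, f_pre, m_prev):
    """Stabilize exponential gates by subtracting a running log-space maximum."""
    log_i = i_pre                       # input gate is exp(i_pre)
    log_f = f_pre                       # forget gate is exp(f_pre) here (a sigmoid is also allowed)
    m_t = max(log_f + m_prev, log_i)    # stabilizer state: running max in log space
    i_t = np.exp(log_i - m_t)           # stabilized input gate, never overflows
    f_t = np.exp(log_f + m_prev - m_t)  # stabilized forget gate
    return i_t, f_t, m_t

def slstm_cell(z_t, o_t, i_pre, f_pre, c_prev, n_prev, m_prev):
    """One scalar sLSTM update; start from c_0 = n_0 = m_0 = 0."""
    i_t, f_t, m_t = stabilized_gates(i_pre, f_pre, m_prev)
    c_t = f_t * c_prev + i_t * z_t      # cell state
    n_t = f_t * n_prev + i_t            # normalizer state accumulates gate mass
    h_t = o_t * (c_t / n_t)             # normalized hidden state
    return h_t, c_t, n_t, m_t
```

Because the same stabilizer scales both gates and cancels in the ratio c_t / n_t, the hidden state is unchanged by the stabilization; only the intermediate values are kept in a numerically safe range.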

mLSTMs

In mLSTMs, fully parallel computation becomes possible. While the traditional LSTM cell state is a scalar, the mLSTM cell state becomes a matrix, and the cell input is reformulated as a key-value pair that is written into this matrix memory and later retrieved with a query.

The key-value mechanism in mLSTM, inspired by Bidirectional Associative Memories (BAMs), allows for efficient information storage and retrieval, enabling the model to adaptively modify its memory based on new, relevant data. This dynamic updating enhances performance in tasks requiring real-time adaptations, such as language modeling and sequence prediction.
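The following sketch shows one mLSTM memory update and retrieval as I read the paper's covariance update rule (store a value under a scaled key, accumulate a normalizer, read out with a query); it is illustrative rather than the official implementation, and the gates are passed in as precomputed values:

```python
import numpy as np

def mlstm_step(q_t, k_t, v_t, i_t, f_t, o_t, C_prev, n_prev):
    """One mLSTM step with a d x d matrix memory C and normalizer vector n."""
    d = k_t.shape[0]
    k_t = k_t / np.sqrt(d)                          # scaled key, as in attention
    C_t = f_t * C_prev + i_t * np.outer(v_t, k_t)   # covariance update: store value under key
    n_t = f_t * n_prev + i_t * k_t                  # normalizer accumulates weighted keys
    h_tilde = C_t @ q_t / max(abs(n_t @ q_t), 1.0)  # retrieve with the query, normalized
    return o_t * h_tilde, C_t, n_t
```

Because the output here no longer depends on the previous hidden state (there is no memory mixing), all time steps of a sequence can be computed in parallel during training.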

xLSTM Architecture

Now, let’s examine the detailed architecture of the sLSTM block. This design uses a post-up-projection structure: the sequence is first mixed by the recurrent sLSTM layer and only afterwards projected up and back down, much like a Transformer block.

Embedded in a pre-LayerNorm residual structure, the input optionally passes through a causal convolution with window size 4 and a Swish activation, and the convolved signal feeds the input and forget gates, while the cell update and output gate receive the un-convolved input. The input, forget, and output gates, together with the cell update, go through block-diagonal linear layers with four diagonal blocks, or "heads". The resulting hidden state is processed by a GroupNorm layer and then up- and down-projected by a gated MLP with a GeLU activation.
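To make the data flow easier to follow, here is a deliberately simplified PyTorch sketch of this post-up-projection block. The recurrent sLSTM layer is left as a placeholder, and details such as the optional causal convolution feeding the gates and the block-diagonal recurrent matrices are omitted, so treat it as a structural outline rather than a faithful reimplementation:

```python
import torch.nn.functional as F
from torch import nn

class SLSTMBlockSketch(nn.Module):
    """Structural sketch: pre-LayerNorm residual block with an sLSTM layer,
    head-wise GroupNorm, and a GeLU-gated MLP (post-up projection)."""
    def __init__(self, d_model, heads=4, mlp_factor=4 / 3):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        self.group_norm = nn.GroupNorm(heads, d_model)  # normalizes each head separately
        d_ff = int(mlp_factor * d_model)
        self.up = nn.Linear(d_model, 2 * d_ff)          # gated MLP: value branch and gate branch
        self.down = nn.Linear(d_ff, d_model)

    def forward(self, x, slstm_layer=lambda h: h):      # placeholder for the real recurrent layer
        res = x                                         # (batch, time, d_model)
        h = slstm_layer(self.norm(x))                   # pre-LayerNorm, then the 4-head sLSTM
        h = self.group_norm(h.transpose(1, 2)).transpose(1, 2)
        a, gate = self.up(h).chunk(2, dim=-1)
        h = self.down(a * F.gelu(gate))                 # up- then down-projection with GeLU gating
        return res + h                                  # residual connection around the whole block
```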

Next, we analyze the architecture of the mLSTM block, which employs a pre-up-projection structure: the input is projected up into a higher-dimensional space before the mLSTM cell is applied, similar to State Space Model blocks.

Similar to the sLSTM block, the mLSTM block is structured within a pre-LayerNorm residual framework. The input is first up-projected by a factor of 2 and split into two branches: one acts as an external output gate, while the other feeds the mLSTM. On the mLSTM branch, the signal passes through a dimension-wise causal convolution before a learnable skip connection; queries and keys are obtained through block-diagonal projection matrices, whereas values bypass the convolution and are fed in directly. The mLSTM outputs, together with the skip connection, are normalized through GroupNorm, gated by the external output-gate branch, and finally down-projected.
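A similarly simplified PyTorch sketch of the pre-up-projection block is shown below. The mLSTM cell itself is again a placeholder, the learnable skip is approximated by a plain addition, and the block-diagonal q/k projections are replaced by ordinary linear layers, so this is only a rough outline of the structure described above:

```python
import torch.nn.functional as F
from torch import nn

class MLSTMBlockSketch(nn.Module):
    """Structural sketch: pre-LayerNorm residual block that up-projects by 2,
    runs a (placeholder) mLSTM on one branch, and gates with the other branch."""
    def __init__(self, d_model, factor=2, conv_window=4):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        d_inner = factor * d_model
        self.up = nn.Linear(d_model, 2 * d_inner)        # two branches: cell path and gate path
        self.conv = nn.Conv1d(d_inner, d_inner, conv_window,
                              padding=conv_window - 1, groups=d_inner)  # dimension-wise causal conv
        self.q = nn.Linear(d_inner, d_inner)             # stand-ins for the block-diagonal projections
        self.k = nn.Linear(d_inner, d_inner)
        self.v = nn.Linear(d_inner, d_inner)             # values bypass the convolution
        self.group_norm = nn.GroupNorm(4, d_inner)
        self.down = nn.Linear(d_inner, d_model)

    def forward(self, x, mlstm_cell=lambda q, k, v: v):  # placeholder for the real mLSTM cell
        res = x                                          # (batch, time, d_model)
        x = self.norm(x)                                 # pre-LayerNorm residual structure
        a, gate = self.up(x).chunk(2, dim=-1)            # up-project by factor 2, split branches
        conv = self.conv(a.transpose(1, 2))[..., : a.size(1)].transpose(1, 2)
        conv = F.silu(conv)                              # Swish activation after the causal conv
        h = mlstm_cell(self.q(conv), self.k(conv), self.v(a))
        h = self.group_norm(h.transpose(1, 2)).transpose(1, 2) + conv  # GroupNorm plus skip
        h = h * F.silu(gate)                             # gate with the second ("output gate") branch
        return res + self.down(h)                        # down-project and add the residual
```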

Due to the extensive results presented in the original paper, we will forgo discussing them here. For a comprehensive understanding, please refer to the original study at https://arxiv.org/pdf/2405.04517.

Conclusion

This paper addresses a fundamental question: How far can we push language modeling by scaling LSTM architectures to billions of parameters? The evidence suggests that xLSTM can reach at least the performance levels of current technologies such as Transformers or State Space Models. By implementing exponential gating and new memory structures, xLSTM demonstrates strong performance in language modeling, competing favorably against state-of-the-art methods. The paper's scaling experiments suggest that larger xLSTM models may become formidable contenders to existing Large Language Models built on Transformer technology. Additionally, xLSTM holds promise for significant contributions across various deep learning domains, including Reinforcement Learning, Time Series Prediction, and modeling physical systems.

Creating such articles demands substantial effort and time. Your support through claps and shares is greatly appreciated, and your engagement inspires me to continue writing about cutting-edge AI topics with clarity and simplicity. Don’t miss future insights; be sure to follow. Happy learning!

Don’t forget to subscribe to the AIGuys Digest Newsletter.
