Motivation

This idea was inspired by a talk that Albert Gu gave at the Simons Institute workshop Transformers as a Computational Model (link to the talk). While listening to this talk, I began to think about how we might design a new memory model that combines features from both State Space Models and transformers.

State Space Models (Gu et al., 2021) compress the entire history into a fixed-size state. By contrast, transformers (Vaswani et al., 2017) keep the full context around. Intuitively, transformers function like a database with $O(n)$ memory growth, while SSMs compress the most critical information from the context into $O(1)$ memory.
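To make the contrast concrete, here is a toy sketch: a generic linear recurrence stands in for the SSM state update, and a growing list stands in for a transformer's key/value cache. The matrices and sizes are illustrative, not taken from any actual model.

import numpy as np

d, n = 4, 10                          # state size and sequence length (toy values)
A, B = 0.9 * np.eye(d), np.eye(d)     # illustrative fixed transition matrices

ssm_state = np.zeros(d)               # SSM: one fixed-size state, O(1) memory
kv_cache = []                         # transformer: keep every token, O(n) memory

for t in range(n):
    x_t = np.random.randn(d)
    ssm_state = A @ ssm_state + B @ x_t   # compress history into the same vector
    kv_cache.append(x_t)                  # append to the ever-growing context

print(ssm_state.shape, len(kv_cache))     # (4,) 10 -- constant vs. linear growth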

Can we use the best of both worlds and design a graph-like memory model that dynamically adds new nodes (like transformers, but with sublinear growth) while also learning to compress information hierarchically into rich latent representations (like SSMs)?

Human memory is associative; we remember concepts by their relationships to each other. Hopfield networks provide one of the earliest examples of associative memory in artificial neural networks. In Hopfield Networks is All You Need (Ramsauer et al., 2020), Hopfield layers are introduced and their convergence and storage-capacity properties are analyzed.
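For intuition, the retrieval step of such an associative memory can be sketched as repeatedly pulling a query toward the stored pattern it most resembles, via the update $\xi \leftarrow X\,\mathrm{softmax}(\beta X^T \xi)$. The pattern matrix and inverse temperature below are illustrative values; this is a toy, not the full Hopfield layer.

import numpy as np

def hopfield_retrieve(xi, X, beta=1.0, steps=3):
    # One associative update per step: xi <- X softmax(beta * X^T xi)
    for _ in range(steps):
        logits = beta * (X.T @ xi)
        weights = np.exp(logits - logits.max())
        xi = X @ (weights / weights.sum())
    return xi

X = np.random.randn(16, 5)                        # 5 stored patterns (columns), dimension 16
noisy_query = X[:, 2] + 0.1 * np.random.randn(16)
recalled = hopfield_retrieve(noisy_query, X, beta=8.0)   # should land near X[:, 2]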

In Memory Networks (Weston et al., 2015), a memory design with four components is introduced:

  • feature map $I$: encodes the input data in the latent space
  • generalization component $G$: updates the memory given the new data
  • output feature map $O$: produces an output feature given the query and the current memory
  • response component $R$: essentially the decoder, turning the output feature into a response

The output feature map selects the best-matching memory element: $$O(x, m) = \argmax_{i=1,\dots,N} s_O(x, m_i)$$ where $m_i$ are the memory elements, $x$ is the featurized query, and $s_O$ is a score function. The score function they propose is: $$ s_O(x, y) = \Phi_x(x)^T U^T U \Phi_y(y) $$
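Assuming identity feature maps for $\Phi_x$ and $\Phi_y$ (purely for illustration), the output step amounts to scoring every memory slot with this bilinear form and taking the best match:

import numpy as np

def s_O(x, y, U):
    # s_O(x, y) = Phi_x(x)^T U^T U Phi_y(y), with identity feature maps for simplicity
    return (U @ x) @ (U @ y)

def output_feature_map(x, memories, U):
    # O(x, m) = argmax_i s_O(x, m_i): return the best-matching memory element
    scores = [s_O(x, m_i, U) for m_i in memories]
    return memories[int(np.argmax(scores))]

d = 16
U = np.random.randn(d, d)                        # stand-in for the learned embedding matrix
memories = [np.random.randn(d) for _ in range(100)]
query = np.random.randn(d)
best_memory = output_feature_map(query, memories, U)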

Interestingly, this score function is conceptually similar to the attention mechanism, which assigns the score $x^T \frac{Q^T K}{\sqrt{d_k}} y$ to a feature pair $x, y$. In this analogy, $U$ plays the role of a factor of the bilinear score matrix (recoverable via a Cholesky decomposition when that matrix is symmetric positive definite), and the softmax in attention replaces the argmax in this formulation.
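A quick numerical check of the factorization point, under the assumption that the bilinear score matrix is symmetric positive definite (which the learned $\frac{Q^T K}{\sqrt{d_k}}$ need not be in general):

import numpy as np

d = 8
M = np.random.randn(d, d)
W = M @ M.T + d * np.eye(d)           # a symmetric positive-definite score matrix
U = np.linalg.cholesky(W).T           # cholesky returns L with W = L @ L.T, so U = L^T gives W = U^T U

x, y = np.random.randn(d), np.random.randn(d)
print(np.allclose(x @ W @ y, (U @ x) @ (U @ y)))   # True: x^T W y == (U x)^T (U y)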

Dynamic Neural Memory

Can we combine the best of both worlds: the compression mechanism of SSMs and the dynamic, database-like nature of transformers, to design a memory mechanism that grows sublinearly, perhaps $O(\log n)$ with respect to input size? This would not only address transformers’ context length limitation but also provide a balance between understanding/compressing and memorizing. Essentially, this is a tradeoff between memory and runtime search.

It’s very natural to think of memory as a graph of concepts where each concept has its own embedding.

Adding new concepts involves traversing this graph, adding new nodes and edges, and enriching the latent representations of existing nodes. Retrieving information would involve a similar traversal, using the encoded query in the latent space and aggregating features from the visited nodes.
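As a rough sketch of what such a structure might look like: the node class, the greedy similarity traversal, and the mean aggregation below are all illustrative choices, not a proposal of a specific architecture.

import numpy as np

class ConceptNode:
    """A memory node: a latent embedding plus edges to related concepts."""
    def __init__(self, embedding):
        self.embedding = embedding
        self.neighbors = []

def retrieve(query, node, depth=3):
    # Greedy traversal: follow the most query-similar neighbor, aggregating visited embeddings
    visited = [node.embedding]
    for _ in range(depth):
        if not node.neighbors:
            break
        node = max(node.neighbors, key=lambda nb: nb.embedding @ query)
        visited.append(node.embedding)
    return np.mean(visited, axis=0)       # aggregated feature for the query

root, child = ConceptNode(np.random.randn(8)), ConceptNode(np.random.randn(8))
root.neighbors.append(child)
feature = retrieve(np.random.randn(8), root)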

A simple first step could be to design a neural binary search tree. In this design, each node has three learnable, parameterized functions: update, stop, and router. Adding new data could look like this:

def add_concept(memory_node, concept):
    # Enrich the current node's latent representation with the new concept
    update(memory_node, concept)
    # Learned stopping criterion: has the concept been fully absorbed at this depth?
    if stop(memory_node, concept):
        return
    # Learned routing: choose which child subtree should receive the concept
    next_memory_node = router(memory_node, concept)
    add_concept(next_memory_node, concept)

new_info = encode(information)
add_concept(memory_node=memory.root, concept=new_info)

In a graph-based structure, the router could be implemented as an attention block. Rather than routing to a single node, it could compute a probability distribution over adjacent nodes and propagate probability mass down the graph.
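A minimal sketch of that router, assuming a single attention head with hypothetical projection matrices W_q and W_k and no learned temperature:

import numpy as np

def route_probabilities(concept, neighbor_embeddings, W_q, W_k):
    # Attention-style router: a distribution over a node's neighbors for the incoming concept
    q = W_q @ concept                            # query from the new concept
    keys = neighbor_embeddings @ W_k.T           # one key per adjacent node
    logits = keys @ q / np.sqrt(q.shape[0])      # scaled dot-product scores
    p = np.exp(logits - logits.max())
    return p / p.sum()

d, k = 32, 4                                     # embedding size, number of neighbors
W_q, W_k = np.random.randn(d, d), np.random.randn(d, d)
probs = route_probabilities(np.random.randn(d), np.random.randn(k, d), W_q, W_k)
print(probs.sum())                               # sums to 1 -- a soft routing decision

Hard routing would pick the argmax of this distribution; soft routing would pass the whole distribution down to the neighbors.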

Potential Issues

  • Learnability: How do we define a learning mechanism that ensures convergence? In particular, designing an objective function for training may be complex due to the non-linear, high-dimensional updates involved in memory storage and retrieval.
  • Scalability: The process described above is sequential. How can we limit the number of sequential steps required for memory updates, or find a balance between traversal depth (i.e., “thinking time”) and accuracy?

Final Thoughts

The attention mechanism, at its core, functions as a read-memory operation. It became a breakthrough for deep learning due to its generality and scalability. However, attention-based memory operates more like a database and less like natural memory, which is hierarchically organized and capable of efficient retrieval based on relationships between concepts.

An improved memory design could offer hierarchical data storage, enhance generalization, improve data efficiency, and overcome the context length limitations in transformers. These advancements could significantly impact how we approach sequential problem-solving in NLP and beyond.

One motivating implication for me is in-context learning. Think of the memory model in world models (Ha & Schmidhuber, 2018). As an agent interacts with its environment, it receives a continuous stream of data, which it uses to build a representation of the environment. The effectiveness of the memory model determines how well the agent learns from past experiences and assigns credit to rewards. Integrating memory models into reinforcement learning has been explored in works such as Reinforcement Learning as One Big Sequence Modeling Problem (Janner et al., 2021) and Decision Transformer: Reinforcement Learning via Sequence Modeling (Chen et al., 2021).

Citation

Cited as:

Shayan Pardis. (Oct 2024). Dynamic Neural Memory for In-Context Learning: SSMs or Transformers?. https://shayanp.me/ideas/neural-memory-for-in-context-learning/.

Or
@article{ shayan-pardis2024dynamic-neural-memory-for-in-context-learning-ssms-or-transformers,
  title   = "Dynamic Neural Memory for In-Context Learning: SSMs or Transformers?",
  author  = "Shayan Pardis",
  year    = "2024",
  month   = "Oct",
  url     = "https://shayanp.me/ideas/neural-memory-for-in-context-learning/"
}

References

[1] Vaswani, A., Shazeer, N., Parmar, N., et al. (2017). Attention is all you need. CoRR, abs/1706.03762. Retrieved from http://arxiv.org/abs/1706.03762
[2] Ramsauer, H., Schäfl, B., Lehner, J., et al. (2020). Hopfield networks is all you need. CoRR, abs/2008.02217. Retrieved from https://arxiv.org/abs/2008.02217
[3] Janner, M., Li, Q., & Levine, S. (2021). Reinforcement learning as one big sequence modeling problem. CoRR, abs/2106.02039. Retrieved from https://arxiv.org/abs/2106.02039
[4] Gu, A., Goel, K., & Ré, C. (2021). Efficiently modeling long sequences with structured state spaces. CoRR, abs/2111.00396. Retrieved from https://arxiv.org/abs/2111.00396
[5] Chen, L., Lu, K., Rajeswaran, A., et al. (2021). Decision transformer: Reinforcement learning via sequence modeling. Retrieved from https://arxiv.org/abs/2106.01345
[6] Ha, D., & Schmidhuber, J. (2018). World models. https://doi.org/10.5281/ZENODO.1207631
[7] Weston, J., Chopra, S., & Bordes, A. (2015). Memory networks. Retrieved from https://arxiv.org/abs/1410.3916