A Memory Hack That Makes AI Smarter With Sequences

1 Apr 2025

Authors:

(1) Hung Le, Applied AI Institute, Deakin University, Geelong, Australia;

(2) Dung Nguyen, Applied AI Institute, Deakin University, Geelong, Australia;

(3) Kien Do, Applied AI Institute, Deakin University, Geelong, Australia;

(4) Svetha Venkatesh, Applied AI Institute, Deakin University, Geelong, Australia;

(5) Truyen Tran, Applied AI Institute, Deakin University, Geelong, Australia.

Abstract & Introduction

Methods

Methods Part 2

Experimental Results

Experimental Results Part 2

Related Works, Discussion, & References

Appendix A, B, & C

Appendix D

D. Experimental Details

All the datasets and public codebases use Apache or MIT licenses. We trained all models on a single Tesla V100-SXM2 GPU. PANM's running time depends on the Encoder and the task. Overall, with 2 Mode-1 pointers and 1 Mode-2 pointer, PANM runs at roughly 70-80% of the backbone model's speed. For example, in the Copy task, PANM runs at 15 iterations/s while LSTM runs at 20 iterations/s. If PANM uses a Transformer Encoder, it runs at 77 iterations/s while the Transformer runs at 90 iterations/s.

Figure 4: PANM as a plug-and-play architecture. The encoder and decoder can be any model (LSTM, Transformer, or BERT). The PANM Controller can be used as the last layer of the decoder to access the memory during decoding. To reduce the number of parameters of the augmented architecture, the decoder's number of layers can be decreased.

Baseline Choice Although our baselines are classic, they remain very strong baselines on our studied tasks. For example, in our algorithmic reasoning tasks, LSTM with attention and Pointer Networks are still dominant baselines, outperforming the more recent Transformer. In Dyck recognition, stack-based models are still SOTA because their inductive bias suits the task. The experiments in Sec. 3 adopt (Universal) Transformer+RPE, a recent and strong Transformer variant focusing on generalization. There are also other sophisticated methods focusing on generalization [Webb et al., 2020].

In our experiments, PANM is ensured to have a similar model size to the baselines and is often built on top of the same backbones for fair comparison. We believe it is still important to improve fundamental baselines such as Transformers or LSTMs because they are the building blocks of many practical applications, including recent Large Language Models (LLMs). In this paper, we demonstrate improvements to these fundamental building blocks; in future work, we will extend our ideas to more advanced backbones such as LLMs.

D.1 Algorithmic Reasoning

We first give the details of the content-based tasks below.

In Dynamic Recall, an arbitrary input token is chosen as the query and appended to the end of the input sequence. Depending on whether the input length is odd or even, the first target token is the token to the left or to the right of the query's occurrence in the input, followed by its succeeding tokens. This task requires both content matching (finding the query token in the input sequence) and position-based access (shifting left or right).
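For concreteness, the following is a minimal data-generation sketch of Dynamic Recall. The mapping of odd/even length to the left/right neighbour and the exact target convention are assumptions for illustration only; the paper does not spell them out at this level of detail.

```python
import random

def make_dynamic_recall(vocab_size=10, length=5):
    # A query token from the input is appended to the end of the sequence.
    seq = [random.randrange(vocab_size) for _ in range(length)]
    q_pos = random.randrange(1, length - 1)   # keep a neighbour on both sides
    query = seq[q_pos]
    inputs = seq + [query]
    # Assumed convention: odd length -> left neighbour, even length -> right.
    start = q_pos - 1 if length % 2 == 1 else q_pos + 1
    target = seq[start:]                      # neighbour followed by its succeeding tokens
    return inputs, target
```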

In Priority Sort, each input token is associated with a priority score sampled from the standard normal distribution. The target output consists of the input tokens sorted in ascending order by their scores. This task can be solved in many ways and likely requires complicated symbol processing, such as looping through the items in the sequence and comparing the tokens' scores.

Finally, in ID Sort, each input token is augmented with an id feature vector sampled from standard multivariate normal distribution such that every 2 tokens share one id. For example, with input x1, x2, x3, x4, x1 and x4 may share one id while x2 and x3 shares another id. The pairing is chosen randomly. The output token at position i-th will be the input token that share id with the i-th input token. The correct output for the earlier example is x4, x3, x2, x1. This task is specifically designed to test the ability to learn Mode 2 pointer-based memory access.
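The following is a minimal sketch of how Priority Sort and ID Sort examples could be generated as described above. The function names, the id dimension, and the way the id features are concatenated to the tokens are illustrative assumptions, not the authors' exact implementation.

```python
import numpy as np

def make_priority_sort(vocab_size=10, length=6):
    """Priority Sort: the target is the input sorted ascending by a Gaussian score."""
    tokens = np.random.randint(0, vocab_size, size=length)
    scores = np.random.randn(length)              # priority ~ standard normal
    target = tokens[np.argsort(scores)]           # ascending by score
    return tokens, scores, target

def make_id_sort(vocab_size=10, length=6, id_dim=8):
    """ID Sort: tokens are paired at random and each pair shares an id vector drawn
    from a standard multivariate normal. The i-th output is the partner of the
    i-th input token (so pairs (x1, x4) and (x2, x3) yield x4, x3, x2, x1)."""
    assert length % 2 == 0
    tokens = np.random.randint(0, vocab_size, size=length)
    perm = np.random.permutation(length)
    ids = np.zeros((length, id_dim))
    partner = np.zeros(length, dtype=int)
    for a, b in perm.reshape(-1, 2):              # random, non-overlapping pairs
        shared_id = np.random.randn(id_dim)
        ids[a], ids[b] = shared_id, shared_id
        partner[a], partner[b] = b, a
    inputs = np.concatenate([tokens[:, None], ids], axis=1)   # token value + id feature
    target = tokens[partner]
    return inputs, target
```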

For these tasks, we implement the baselines (LSTM, attention models, and Transformer) using the Pytorch library. The hidden state dimension for these models is set to 512, which results in around 1-3 million parameters. We tuned the number of encoder/decoder layers for these baselines on the Copy task and found that 1 layer gave the best performance. For NTM and DNC, we use public repositories[1] with the default controller hidden state of 256 dimensions and a 128-slot external memory, which results in around 1.2 million parameters. We use the ESBN authors' codebase[2] with the default parameter setting, resulting in ≈1.2 million parameters. For PtrNet, since we do not use token indices as training labels, we produce the predicted token by taking a weighted sum of the input tokens using PtrNet's attention weights (sketched below). PtrNet's hyperparameters are the same as those of the attention models. We could not find the authors' code for Neural Stack and NRAM, so we implemented them and tuned their hyperparameters on the Copy task at length L such that the model sizes are about 1.1 million parameters. In this task, PANM uses an LSTM with a hidden state of 256 as the Encoder and does not stack the Controller on any decoder model, resulting in ≈1.1 million parameters.
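A minimal sketch of the PtrNet output trick mentioned above: the pointer attention weights mix the input token embeddings, and the mixture is projected to the vocabulary so the model can be trained with a token-level loss. The tensor shapes and the final projection layer are assumptions for illustration.

```python
import torch
import torch.nn.functional as F

def ptrnet_predict_tokens(attn_scores, input_embeddings, output_proj):
    """Turn Pointer Network attention over input positions into token predictions.

    attn_scores:      (batch, tgt_len, src_len) unnormalised pointer scores
    input_embeddings: (batch, src_len, dim) embeddings of the input tokens
    output_proj:      torch.nn.Linear(dim, vocab_size) mapping the mixture to logits
    """
    weights = F.softmax(attn_scores, dim=-1)          # pointer distribution over inputs
    mixed = torch.bmm(weights, input_embeddings)      # weighted sum of input tokens
    return output_proj(mixed)                         # (batch, tgt_len, vocab_size)
```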

Overall, PANM shows significant improvements ranging from 10-20% on each task. We note that, compared with individual baselines, the improvement is much higher. Consider Copy as an example (Fig. 2a): PANM outperforms the worst baseline, Transformer, by around 60% at 2(L + 1) and 30% at 4(L + 1), respectively. As stated earlier, our tasks are challenging; thus, originally strong baselines such as NTM, DNC, and Neural Stack do not generalize well at extreme lengths, especially in ID Sort. ID Sort is trickier than the content-free tasks, causing some baselines to fail even at length L, which is in the training data. The best other model in this case is Content Attention, which clearly underperforms PANM by a few percent to 50% (Fig. 2b). Without curriculum learning and under the 10-class prediction setting, methods that use implicit pointers, such as PtrNet, NRAM, and ESBN, show mediocre performance on average compared to PANM. Furthermore, PANM also outperforms the baselines on length-dependent tasks (Mix, D. Recall), indicating that it can track the sequence length in extrapolation. We hypothesize that PANM's content-free pointer generation mechanism, which simulates list iteration, makes this possible.

Figure 5: Dyck (Left): mean ± std. accuracy over 5 runs with different testing lengths. bAbI QA (Right): mean ± std. testing accuracy and cross-entropy loss across 100 training epochs over 5 runs.

In Copy, only Mode-1 access is needed. As the decoding step t increases, the Pointer Unit generates p_t^a following the increment of the addresses, as expected. That said, for several steps the address attention is not sharp, with other addresses also receiving mass from the pointer, which is not surprising since we use soft attention and it is hard for a neural network to learn the exact rule p_{t+1}^a = p_t^a + 1. This problem gets worse as the test length increases because the error accumulates, especially when the same token appears many times, which confuses the model. This explains why PANM's performance drops clearly in the hardest case, 8(L + 1). Yet, it is still significantly better than the others, whose results are near random prediction.
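The accumulation effect can be illustrated numerically: if each soft "+1" step leaks even a small amount of attention mass to neighbouring addresses, the pointer's mode still drifts right as intended, but the distribution becomes increasingly diffuse over many decoding steps. The shift kernel below is hand-set purely for illustration; in PANM the step is learned by the Pointer Unit.

```python
import numpy as np

def soft_increment(pointer, shift_kernel):
    """One soft '+1' step: mix shifted copies of the address distribution, with most
    mass on 'move right by one' and a little leaking to the neighbouring shifts."""
    offsets = np.arange(len(shift_kernel)) - len(shift_kernel) // 2 + 1  # centred on +1
    out = np.zeros_like(pointer)
    for w, off in zip(shift_kernel, offsets):
        out += w * np.roll(pointer, off)
    return out / out.sum()

n_addresses = 16
pointer = np.eye(n_addresses)[0]            # starts sharply at address 0
kernel = np.array([0.05, 0.9, 0.05])        # mostly +1, slight leakage (assumed)
for t in range(10):
    pointer = soft_increment(pointer, kernel)
print(pointer.argmax(), pointer.max())      # mode is at address 10, but mass has spread
```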

D.2 Dyck Language Recognition

D.3 Compositional Learning

SCAN The training size is 16,990 and the test size is 3,920. SCAN is a well-known, standard benchmark for testing compositional learning and generalization in sequential models. One property of this dataset is that a new length often contains new rules that must be captured by the model to ensure generalization; thus, if the model fails to learn a hidden rule, its performance may drop significantly from one length split to another. Fig. 6 illustrates PANM's testing accuracy curves for L = 22, 24, 25, 26. The learning curves for L > 26 look similar to L = 26, where PANM easily solves the task perfectly.

Figure 6: SCAN: PANM’s exemplar learning curves.

Mathematical Problems Table 13 reports the accuracy with mean and standard deviation. Here, we augment TRM and TRM+RPE with PANM. Both show improvements, especially TRM+RPE, indicating that PANM is compatible with other methods designed to improve generalization in Transformers.

D.4 Other NLP Tasks

The bAbI dataset consists of 20 synthetic tasks that evaluate various reasoning skills. To prepare the data for each task, we combine train/valid/test into a single set, sort it by length, and split it into training and testing sets, as described in the main text. We train the models jointly on all 20 tasks and measure the accuracy of their answers, which are considered correct only if they match the ground-truth answer perfectly. The training/evaluation follows exactly the standard protocol presented in [Le et al., 2020b]. The Transformer used here has 8 heads, 3 encoder layers, 3 decoder layers, and a hidden dimension of 512. PANM uses the same Transformer backbone except that its decoder has 2 layers to keep the model size equivalent. We run each model 5 times to report the mean and standard deviation in Fig. 5 (right). Table 3 reports the detailed numbers.
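A generic sketch of the pool-sort-split protocol described above. The 0.8 train fraction and the example data structure are assumptions for illustration; the actual cut is specified in the main text.

```python
def length_split(examples, train_fraction=0.8, length_fn=len):
    """Pool all examples, sort by length, and put the shortest fraction in training,
    so that testing requires extrapolating to longer sequences."""
    pooled = sorted(examples, key=length_fn)          # shortest first
    cut = int(len(pooled) * train_fraction)
    return pooled[:cut], pooled[cut:]                 # train (short), test (long)

# Usage sketch: stories pooled from the bAbI train/valid/test files as token lists.
# train_set, test_set = length_split(all_stories, length_fn=lambda s: len(s["tokens"]))
```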

The SQUAD dataset contains more than 100K realistic context/question-answer pairs. Again, we combine train/test into a single set, sort it by length, and split it into new train/test sets. Following Kenton and Toutanova [2019], we use the BERT model (https://huggingface.co/bert-base-uncased) to predict the start and end locations of the answer and finetune the model with the same settings (e.g., 3 epochs with a learning rate of 5e-5), except that our batch size is 16 to fit our GPU memory. PANM appends the Controller to BERT to predict the start and end. Both BERT and PANM have around 114 million parameters. Table 4 reports the detailed numbers.
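A minimal sketch of the plain BERT baseline setup with the stated hyperparameters (3 epochs, learning rate 5e-5, batch size 16). The data loader is assumed to exist, and the PANM Controller attachment is omitted since its implementation is not detailed here.

```python
import torch
from transformers import BertForQuestionAnswering, BertTokenizerFast

model = BertForQuestionAnswering.from_pretrained("bert-base-uncased")
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)       # settings from the text

def train_epoch(loader):
    """One epoch over a loader that yields tokenized batches with gold span positions."""
    model.train()
    for batch in loader:
        outputs = model(input_ids=batch["input_ids"],
                        attention_mask=batch["attention_mask"],
                        start_positions=batch["start_positions"],
                        end_positions=batch["end_positions"])
        outputs.loss.backward()                                  # start/end span loss
        optimizer.step()
        optimizer.zero_grad()

# for _ in range(3):              # 3 epochs, batch size 16, as stated above
#     train_epoch(train_loader)   # train_loader is assumed to exist
```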

D.5 Additional Experiments

Pointer Hyperparameters In this section, we confirm the logic presented in Appendix C by performing experiments that involve varying the number and type of pointers.

Figure 7: Testing accuracy (mean ± std.) at 2(L+1) length over training steps. Different configurations of Mode-1 pointers are trained and evaluated 5 times.

Figure 8: Testing accuracy (mean ± std.) at 2(L+1) length over training steps. Different configurations of Mode-2 pointers are trained and evaluated 5 times.

Table 6: Failure of Chat-GPT on algorithmic reasoning test cases of length 2L. Token-level accuracy is reported. We do not test Chat-GPT on Priority Sort and ID Sort because they have complicated token representations. PANM results cannot be directly compared and are shown for reference only.

D.6 Failures of Chat-GPT in Our Tasks

Large Language Models (LLMs), especially Chat-GPT, have shown remarkable results in reasoning and generalization. Directly comparing Chat-GPT with the other models used in our experiments would be unfair because Chat-GPT was not trained directly on our datasets and has many more parameters than our model. Therefore, in this section, we merely use Chat-GPT as a tool to verify that our chosen tasks, despite being simple, are non-trivial. The evaluated tasks are algorithmic reasoning and SCAN. We do not examine Dyck recognition because its output encoding is complicated to represent in text. The other datasets are more common and likely to have been used for training Chat-GPT, and are thus unsuitable for a generalization test. For example, in the Mathematics task, if we ask Chat-GPT the question from the data, "What is the hundreds digit of 31253?", it provides the correct answer (2). However, slightly modifying the question so that it does not appear in the training and testing sets successfully fools Chat-GPT (the correct answers below are 5 and 2, respectively):

• Example 1:

– Prompt: What is the hundreds digit of 312537?

– Chat-GPT answer: The hundreds digit of the number 312537 is 2.

• Example 2:

– Prompt: What is the hundreds digit of 319253?

– Chat-GPT answer: The hundreds digit of the number 319253 is 9.

We use OpenAI's Chat-GPT 3.5 (September version) and evaluate the model on our data using few-shot example prompts.

Algorithmic Reasoning To ensure that Chat-GPT does not memorize the output answer from its vast training data, we use the non-digit symbols ~!@#$%^&*( as the 10 tokens of the datasets. For each task, we sample 20 training examples of length L = 5 to build the in-context examples and test on 1 longer sequence of length 2L = 10. We conduct 20 trials and report the average test accuracy. Table 6 summarizes the evaluation results. Overall, except for the Copy task, where Chat-GPT shows excellent generalization, the tasks are very hard for Chat-GPT, indicating that the length extrapolation problem still poses a big challenge for today's AI techniques.
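A sketch of how such an evaluation can be set up for the Copy task: build a few-shot prompt from the symbol vocabulary and score the model's reply with token-level accuracy. The prompt wording is an assumption; only the symbols, shot count, and lengths follow the text.

```python
import random

SYMBOLS = list("~!@#$%^&*(")                 # the 10 non-digit tokens from the text

def copy_example(length):
    seq = [random.choice(SYMBOLS) for _ in range(length)]
    return seq, seq                          # Copy task: the output equals the input

def build_prompt(n_shots=20, train_len=5, test_len=10, make=copy_example):
    """Few-shot prompt with n_shots in-context examples and one longer test query."""
    lines = []
    for _ in range(n_shots):
        x, y = make(train_len)
        lines.append(f"Input: {' '.join(x)} Output: {' '.join(y)}")
    x_test, y_test = make(test_len)
    lines.append(f"Input: {' '.join(x_test)} Output:")
    return "\n".join(lines), y_test

def token_accuracy(predicted, target):
    """Token-level accuracy: fraction of positions where the prediction matches."""
    hits = sum(p == t for p, t in zip(predicted, target))
    return hits / max(len(target), 1)
```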

Table 7: Copy: accuracy (mean ± std. over 5 runs)

Table 8: Reverse: accuracy (mean ± std. over 5 runs)

SCAN In this task, we sample 20 examples from the L-cutoff=40 split (the easiest) as in-context learning examples and evaluate on 10 unseen sequences. Chat-GPT fails completely on this task. When testing on sequences of similar or longer length than the examples, Chat-GPT cannot produce any exact-match result (exact-match accuracy = 0). Below are some failure examples:

• IN: walk and turn opposite right OUT:

– Chat-GPT output: I TURN RIGHT I TURN RIGHT I WALK

– True output: I WALK I TURN RIGHT I TURN RIGHT

• IN: run around left twice and run around right OUT:

– Chat-GPT output: I RUN I TURN LEFT I RUN I TURN LEFT I RUN I TURN RIGHT I RUN

– True output: I TURN LEFT I RUN I TURN LEFT I RUN I TURN LEFT I RUN I TURN LEFT I RUN I TURN LEFT I RUN I TURN LEFT I RUN I TURN LEFT I RUN I TURN LEFT I RUN I TURN RIGHT I RUN I TURN RIGHT I RUN I TURN RIGHT I RUN I TURN RIGHT I RUN

Table 9: Mix: accuracy (mean ± std. over 5 runs)

Table 10: Drecall: accuracy (mean ± std. over 5 runs)

Table 11: PSort: accuracy (mean ± std. over 5 runs)

Table 12: ID Sort: accuracy (mean ± std. over 5 runs)

Table 13: Mathematics: mean ± std. accuracy over 5 runs. ♣ marks numbers from Csordás et al. [2021]. ♦ marks our rerun to confirm the results, which, in some cases, could not match the reported numbers. - indicates a training crash reported in the original paper. We run PANM using the authors' codebase.

This paper is available on arxiv under CC BY 4.0 DEED license.


[1] https://github.com/thaihungle/SAM

[2] https://github.com/taylorwwebb/emergent_symbol

[3] https://github.com/suzgunmirac/marnns

[4] https://github.com/RobertCsordas/transformer_generalization