Adaptive Aggregation Networks: Don’t Forget What You Learned
Written by Charles Yuan. A discussion of the paper titled “Adaptive Aggregation Networks for Class-Incremental Learning”.
When you really think about it, conventionally trained deep learning algorithms are fundamentally restricted from achieving true artificial general intelligence. If you have ever constructed and trained a neural network, you are likely familiar with the term batch gradient descent and, by extension, mini-batch gradient descent and stochastic gradient descent. Essentially, the conventional approach is to train a neural network all at once, on either all of your data or successive batches of it. Either way, once the set number of training epochs has been reached, training terminates and the resulting model is treated as finished; continuing to learn from new data is not part of the process. This is known as static training, and it is the general practice when one has a fixed dataset. However, this is clearly not conducive to creating an artificial intelligence that can truly learn like humans. After all, our learning never stops!
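For readers who have not written a training loop before, here is a minimal sketch of static training in plain Python. It fits a hypothetical one-parameter model y = w·x to a fixed toy dataset: setting `batch_size` to the dataset size gives batch gradient descent, 1 gives stochastic gradient descent, and anything in between gives mini-batch gradient descent. The function name and hyperparameters are illustrative choices, not taken from the paper.

```python
import random

def train_static(data, epochs=50, batch_size=2, lr=0.05, seed=0):
    """Fit y = w * x by gradient descent on a fixed dataset for a fixed number of epochs.

    batch_size == len(data) -> batch gradient descent
    batch_size == 1         -> stochastic gradient descent
    anything in between     -> mini-batch gradient descent
    """
    rng = random.Random(seed)
    data = list(data)          # copy so shuffling does not modify the caller's list
    w = 0.0
    for _ in range(epochs):
        rng.shuffle(data)
        for start in range(0, len(data), batch_size):
            batch = data[start:start + batch_size]
            # Gradient of the mean squared error (w * x - y) ** 2 with respect to w.
            grad = sum(2 * (w * x - y) * x for x, y in batch) / len(batch)
            w -= lr * grad
    return w                   # the fixed number of epochs is done; training simply stops

data = [(x, 3.0 * x) for x in (0.5, 1.0, 1.5, 2.0)]
for bs in (len(data), 2, 1):
    print(f"batch_size={bs}: w = {train_static(data, batch_size=bs):.4f}")
```

Whatever the batch size, the loop runs for a fixed number of epochs over a fixed dataset and then returns; nothing in it accommodates data that arrives later.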
If you are familiar with reinforcement learning, you may have heard of the concept of continual learning. When you graduate from one grade in school and start learning new concepts in the next, you don’t suddenly forget all the concepts you learned previously. Or at least, I hope not! This is the problem associated with static training. Catastrophic forgetting occurs when a previously trained machine learning model is exposed to new data and, in turn, forgets what it learned before. If you continue training a neural network only on new data, it tends to overwrite much of what it learned from the old data. Even if you saved the weights or used a method like transfer learning, there is still no mechanism in place to prevent the model from over-prioritizing the new data and forgetting the old. This is the fundamental flaw of conventional deep learning algorithms: they cannot learn continually.
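Catastrophic forgetting can be reproduced even without a neural network. The sketch below uses a toy, assumed setup: a two-weight linear model trained by full-batch gradient descent on "task A", then trained further on "task B" alone. The model has enough capacity to fit both tasks at once, as the jointly trained version shows, yet sequential training erases task A. All names and numbers are illustrative.

```python
def train(w, data, steps=500, lr=0.1):
    """Full-batch gradient descent on mean squared error for y = w[0]*x[0] + w[1]*x[1]."""
    w = list(w)
    for _ in range(steps):
        grad = [0.0, 0.0]
        for x, y in data:
            err = w[0] * x[0] + w[1] * x[1] - y
            grad[0] += 2 * err * x[0] / len(data)
            grad[1] += 2 * err * x[1] / len(data)
        w = [w[0] - lr * grad[0], w[1] - lr * grad[1]]
    return w

def mse(w, data):
    return sum((w[0] * x[0] + w[1] * x[1] - y) ** 2 for x, y in data) / len(data)

task_a = [((1.0, 1.0), 2.0)]   # "old" task
task_b = [((1.0, 0.0), 3.0)]   # "new" task; w = (3, -1) fits both tasks

w_a = train([0.0, 0.0], task_a)                 # learn task A
w_seq = train(w_a, task_b)                      # then train on task B only
w_joint = train([0.0, 0.0], task_a + task_b)    # train on both at once

print("task A error after A:        ", mse(w_a, task_a))
print("task A error after A, then B:", mse(w_seq, task_a))
print("task A error, joint training:", mse(w_joint, task_a))
```

Here the task B gradient only moves the first weight, so sequential training lands on roughly (3, 1), which fits B but not A; joint training instead finds (3, -1), which fits both.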
The task of continual learning for deep learning algorithms can be classified into two categories:
- Training on new examples of existing classes, known as online learning.
- Training on new classes, known as class-incremental learning (CIL).
While both are important, this article focuses on the acquisition of new classes and, ultimately, on whether a neural network can learn new things while remembering the old. To keep things simple, let’s focus on a task you’ve probably heard of before: image classification.
“As the field of computer vision moves closer towards artificial intelligence it becomes apparent that more flexible strategies are required to handle the large-scale and dynamic properties of real-world object categorization situations.” This line is from what is arguably the first paper on class-incremental learning, “iCaRL: Incremental Classifier and Representation Learning” by Rebuffi et al. Using a ResNet-32 backbone and a novel classification method known as nearest-mean-of-exemplars, iCaRL relies on a set of exemplar images that it selects from the data stream. To keep memory and computation bounded, the total number of stored exemplars never exceeds a fixed budget; instead, the number of exemplars kept per class is reduced whenever new classes are introduced. The network is then trained with two objectives: a classification loss, so that it outputs the correct class indicator for the new class(es), and a distillation loss, so that its outputs for the old classes stay close to those of the previous network. This training scheme was one of the first shown to mitigate catastrophic forgetting.
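The three iCaRL ingredients described above (a fixed exemplar budget, nearest-mean-of-exemplars classification, and a classification-plus-distillation loss) can be sketched in plain Python. This is a simplified illustration, not the authors' code: exemplars are represented directly by already-extracted 2-D feature vectors, and both the herding procedure iCaRL uses to pick and order exemplars and its feature normalization are omitted. All function names are hypothetical. The loss follows iCaRL's sigmoid/binary cross-entropy formulation, with one-hot targets for new classes and the previous network's outputs as targets for old classes.

```python
import math

def per_class_budget(memory_size, num_classes_seen):
    # A fixed total of `memory_size` exemplars, split evenly over all classes seen so far.
    return memory_size // num_classes_seen

def shrink(exemplar_sets, m):
    # Keep the first m exemplars of each class; iCaRL stores them in priority
    # order, so truncation drops the least important ones.
    return {c: ex[:m] for c, ex in exemplar_sets.items()}

def mean_of(points):
    dim = len(points[0])
    return [sum(p[d] for p in points) / len(points) for d in range(dim)]

def nme_classify(feature, exemplar_sets):
    # Nearest-mean-of-exemplars: predict the class whose exemplar mean is closest.
    means = {c: mean_of(ex) for c, ex in exemplar_sets.items()}
    return min(means, key=lambda c: math.dist(feature, means[c]))

def icarl_loss(logits, label, old_probs, new_classes):
    # Sigmoid + binary cross-entropy over every output unit.
    total = 0.0
    for c, z in enumerate(logits):
        p = 1.0 / (1.0 + math.exp(-z))
        if c in new_classes:
            target = 1.0 if c == label else 0.0   # classification term
        else:
            target = old_probs[c]                 # distillation term: previous network's output
        total -= target * math.log(p) + (1.0 - target) * math.log(1.0 - p)
    return total

# Memory budget of 6; exemplars are given directly as 2-D feature vectors.
memory = {
    0: [(0.0, 0.1), (0.2, 0.0), (0.1, 0.2)],
    1: [(1.0, 1.1), (0.9, 1.0), (1.2, 0.9)],
}
m = per_class_budget(6, num_classes_seen=3)   # class 2 arrives
memory = shrink(memory, m)
memory[2] = [(2.0, 0.0), (2.1, 0.1)]          # exemplars for the new class
print(m, {c: len(ex) for c, ex in memory.items()})
print(nme_classify((1.9, 0.2), memory))
print(icarl_loss([0.0, 0.0, 0.0], label=2, old_probs={0: 0.9, 1: 0.1}, new_classes={2}))
```

When class 2 arrives, the budget of 6 drops to 2 exemplars per class, and the query point is assigned to class 2 because its exemplar mean is the nearest.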
Adaptive Aggregation Networks (AANets)
So why did we talk about iCaRL first? Because the main paper we are reviewing today builds on much of the previous work in class-incremental learning! Specifically, it addresses one of the core problems of CIL: maintaining a balance between retaining knowledge of old classes and learning new ones. This is the stability-plasticity dilemma: a learning algorithm must balance stability, the ability to not forget previous knowledge, against plasticity, the ability to acquire new knowledge, and improving one tends to come at the expense of the other. While past CIL algorithms had their own ways of tackling this dilemma, they still suffered a significant drop in overall accuracy across classes as training went on.
Adaptive Aggregation Networks (AANets) are based on a ResNet-type architecture in which each residual level is composed of two types of blocks. The stable block aims to maintain stability, while the plastic block aims to increase plasticity. In the diagram shown, the orange (plastic) blocks have their parameters fully adapted to the new class data, while the blue (stable) blocks have theirs partially fixed to preserve the knowledge learned from old classes. The primary difference between the two blocks is their level of learnability: the stable block has fewer learnable parameters than the plastic block. At each level, the outputs of the two blocks are combined as a weighted sum using the aggregation weights, which dynamically balance the influence of each block. Together, the two types of blocks and the trainable aggregation weights make up AANets’ solution to the stability-plasticity dilemma.
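To make the aggregation concrete, here is a toy, pure-Python sketch of a single AANets-style level in which each "block" is just a small linear map. How the stable block is made partially fixed is an assumption of this sketch: its base weights are frozen and only a per-output scaling factor is learnable, which is one simple way to give it fewer learnable parameters than the fully learnable plastic block. The level's output is the weighted sum of the two block outputs, with one aggregation weight per block. The class names are hypothetical, and real blocks are stacks of convolutional layers with residual connections.

```python
class StableBlock:
    """Partially fixed block: frozen base weights, learnable per-output scales (assumed design)."""
    def __init__(self, base_weight):
        self.base_weight = [row[:] for row in base_weight]   # frozen, never updated
        self.scale = [1.0] * len(base_weight)                # the only learnable values

    def __call__(self, x):
        return [s * sum(w * xi for w, xi in zip(row, x))
                for s, row in zip(self.scale, self.base_weight)]

    def num_learnable(self):
        return len(self.scale)


class PlasticBlock:
    """Fully learnable block."""
    def __init__(self, weight):
        self.weight = [row[:] for row in weight]

    def __call__(self, x):
        return [sum(w * xi for w, xi in zip(row, x)) for row in self.weight]

    def num_learnable(self):
        return sum(len(row) for row in self.weight)


class AggregatedLevel:
    """One level: weighted sum of the stable and plastic block outputs."""
    def __init__(self, stable, plastic, alpha_stable=0.5, alpha_plastic=0.5):
        self.stable, self.plastic = stable, plastic
        self.alpha = [alpha_stable, alpha_plastic]   # aggregation weights for this level

    def __call__(self, x):
        s_out, p_out = self.stable(x), self.plastic(x)
        return [self.alpha[0] * s + self.alpha[1] * p for s, p in zip(s_out, p_out)]


level = AggregatedLevel(StableBlock([[1.0, 0.0], [0.0, 1.0]]),
                        PlasticBlock([[0.0, 2.0], [2.0, 0.0]]))
print(level([1.0, 3.0]))
print(level.stable.num_learnable(), level.plastic.num_learnable())
```

With these numbers the stable block contributes [1, 3] and the plastic block [6, 2], so equal aggregation weights give [3.5, 2.5]; the stable block has 2 learnable values against the plastic block's 4.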
The Math and the Algorithm
Of course, understanding the mathematical and algorithmic side of the paper is just as important as understanding its concepts. For the sake of brevity, though, only the main parts are covered here. To train both the network parameters and the aggregation weights, AANets treats the task as a bilevel optimization problem (BOP), in which two nested optimization problems are solved in tandem by alternating between them.
Let the network parameters be [ϕ, η] and the aggregation weights be α. CIL usually assumes N+1 training phases in total: one initial phase plus N incremental phases in which the number of classes gradually increases. In the i-th phase, the ideal BOP minimizes the classification loss through two nested problems.
The upper-level problem (1) updates the aggregation weights, and the lower-level problem (2) improves the network parameters. Each is solved with its own gradient-descent update rule, and the two updates alternate.
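For reference, here is a schematic version of the two problems and their update rules. The notation is my own shorthand and may differ from the paper's: D_{0:i} denotes data from all classes seen up to phase i, D_i the new-class data of phase i, E_{0:i-1} the exemplars stored from earlier phases, B_i whatever data is used in practice to update α, and γ1, γ2 are learning rates. Because CIL no longer has full access to old-class data, the upper level can only be approximated in practice, which is why this form is described as the ideal one.

```latex
\begin{aligned}
&\min_{\alpha_i}\; \mathcal{L}\big(\alpha_i, \phi_i^{*}, \eta_i^{*};\ \mathcal{D}_{0:i}\big) && (1)\\
&\text{s.t.}\ [\phi_i^{*}, \eta_i^{*}] = \arg\min_{[\phi_i, \eta_i]} \mathcal{L}\big(\alpha_i, \phi_i, \eta_i;\ \mathcal{E}_{0:i-1} \cup \mathcal{D}_i\big) && (2)\\[4pt]
&[\phi_i, \eta_i] \leftarrow [\phi_i, \eta_i] - \gamma_1 \nabla_{[\phi_i, \eta_i]} \mathcal{L}\big(\alpha_i, \phi_i, \eta_i;\ \mathcal{E}_{0:i-1} \cup \mathcal{D}_i\big)\\
&\alpha_i \leftarrow \alpha_i - \gamma_2 \nabla_{\alpha_i} \mathcal{L}\big(\alpha_i, \phi_i, \eta_i;\ \mathcal{B}_i\big)
\end{aligned}
```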
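I know the paper’s equations look complicated, but they are basically a modified version of ordinary loss minimization and the classic gradient-descent weight update, applied to two groups of variables in turn. The full algorithm, given as pseudocode in the paper, repeats these alternating updates throughout each incremental phase.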
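The alternating scheme can be illustrated on a toy problem. The sketch below is a first-order simplification under stated assumptions, not the paper's algorithm: a one-dimensional "network" has a stable block (a frozen weight `BASE` times a learnable scale `mu`) and a plastic block (a learnable weight `w`), mixed by aggregation weights `alpha`. Each iteration takes one gradient step on (`mu`, `w`) using training data (the lower level), then one on `alpha` using held-out data (the upper level), treating the other group as fixed. All names, data, and learning rates are hypothetical.

```python
BASE = 1.0  # hypothetical frozen stable-block weight from an earlier phase

def predict(x, alpha, params):
    a_s, a_p = alpha
    mu, w = params  # mu: learnable scale on the frozen weight; w: plastic weight
    return a_s * mu * BASE * x + a_p * w * x

def loss_and_grads(data, alpha, params):
    """Mean squared error plus its gradients w.r.t. params (mu, w) and alpha (a_s, a_p)."""
    a_s, a_p = alpha
    mu, w = params
    n = len(data)
    loss, g_mu, g_w, g_as, g_ap = 0.0, 0.0, 0.0, 0.0, 0.0
    for x, y in data:
        r = predict(x, alpha, params) - y
        loss += r * r / n
        g_mu += 2 * r * a_s * BASE * x / n
        g_w += 2 * r * a_p * x / n
        g_as += 2 * r * mu * BASE * x / n
        g_ap += 2 * r * w * x / n
    return loss, (g_mu, g_w), (g_as, g_ap)

def alternate(train, val, alpha, params, steps=200, lr_params=0.1, lr_alpha=0.1):
    for _ in range(steps):
        # Lower level: one gradient step on the network parameters (training data).
        _, g_params, _ = loss_and_grads(train, alpha, params)
        params = tuple(p - lr_params * g for p, g in zip(params, g_params))
        # Upper level: one gradient step on the aggregation weights (held-out data).
        _, _, g_alpha = loss_and_grads(val, alpha, params)
        alpha = tuple(a - lr_alpha * g for a, g in zip(alpha, g_alpha))
    return alpha, params

train = [(x, 1.5 * x) for x in (0.5, 1.0, 1.5)]
val = [(x, 1.5 * x) for x in (0.25, 0.75, 1.25)]
alpha0, params0 = (0.5, 0.5), (1.0, 0.0)
alpha, params = alternate(train, val, alpha0, params0)
print("held-out loss before:", loss_and_grads(val, alpha0, params0)[0])
print("held-out loss after: ", loss_and_grads(val, alpha, params)[0])
```

Ignoring how the lower-level solution depends on `alpha` keeps each step as simple as an ordinary gradient update, at the cost of only approximating the true bilevel gradient.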
Experiments and Results
The experiments were conducted on two CIL benchmark datasets: CIFAR-100 and ImageNet. CIFAR-100 contains 60,000 32x32 colour images across 100 classes, while ImageNet contains about 1.3 million training images across 1000 classes, typically resized and cropped to 224x224. As seen in the quantitative comparison table earlier in this review, every previous state-of-the-art method that was implemented on AANets improved. However, rather than going over all the numbers, let’s discuss the qualitative results below, which illustrate the stability-plasticity dilemma during training. In the early phases, the model makes correct predictions using the stable block, which retains knowledge of old samples while the plastic block forgets. After a few phases of training, however, the plastic block performs better thanks to its ability to adapt to new data. The takeaway from these experiments is that, regardless of which block performs better at a given phase, AANets extracts the most informative features by combining both.
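To make the N+1-phase setup concrete on these two benchmarks, here is a small helper that splits class indices into phases. It assumes a protocol commonly used in CIL papers: half of the classes are learned in the initial phase and the rest are divided evenly among N incremental phases. The function name and the particular values of N below are illustrative.

```python
def phase_schedule(num_classes, base_classes, num_phases):
    """Split class ids into one base phase plus `num_phases` equal incremental phases."""
    remaining = num_classes - base_classes
    if remaining % num_phases != 0:
        raise ValueError("remaining classes must divide evenly into phases")
    step = remaining // num_phases
    phases = [list(range(base_classes))]          # phase 0
    for k in range(num_phases):                   # phases 1..N
        start = base_classes + k * step
        phases.append(list(range(start, start + step)))
    return phases

cifar = phase_schedule(100, base_classes=50, num_phases=5)
print(len(cifar), [len(p) for p in cifar])
imagenet = phase_schedule(1000, base_classes=500, num_phases=10)
print(len(imagenet), len(imagenet[1]))
```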
We are still quite far away from creating an artificial general intelligence. With most deep learning applications limited in scope, we must address the fundamental problem of extending narrow AI to general AI. At the end of the day, regardless of how long you train a model or how large your dataset is, if it can’t learn continually, it can never achieve true intelligence. The paper discussed today is but one step in the right direction, albeit for one of the most straightforward tasks in deep learning (image classification). Nevertheless, if these types of continual learning algorithms can be extended to models as large and elaborate as GPT-3, perhaps we can truly achieve some form of AGI one day.
1. iCaRL: the one that started it all; implemented on AANets’ GitHub.
2. LUCIR: also implemented on AANets’ GitHub.
3. PODNet: also referenced in the paper, though its implementation is a work in progress.
4. Mnemonics Training: by the same author!