Abstract
Catastrophic forgetting remains an outstanding challenge in continual learning. Recently, methods inspired by the brain, such as continual representation learning and memory replay, have been used to combat catastrophic forgetting. Associative learning (retaining associations between inputs and outputs, even after good representations are learned) plays an important function in the brain; however, its role in continual learning has not been carefully studied. Here, we identified a two-layer neural circuit in the fruit fly olfactory system that performs continual associative learning between odors and their associated valences. In the first layer, inputs (odors) are encoded using sparse, high-dimensional representations, which reduces memory interference by activating nonoverlapping populations of neurons for different odors. In the second layer, only the synapses between odor-activated neurons and the odor’s associated output neuron are modified during learning; the rest of the weights are frozen to prevent unrelated memories from being overwritten. We prove theoretically that these two perceptron-like layers help reduce catastrophic forgetting compared to the original perceptron algorithm, under continual learning. We then show empirically on benchmark data sets that this simple and lightweight architecture outperforms other popular neural-inspired algorithms when also using a two-layer feedforward architecture. Overall, fruit flies evolved an efficient continual associative learning algorithm, and circuit mechanisms from neuroscience can be translated to improve machine computation.
1 Introduction
Catastrophic forgetting, when neural networks inadvertently overwrite old memories with new memories, remains a long-standing problem in machine learning (Parisi et al., 2019). Here, we studied how fruit flies learn continuously to associate odors with behaviors and discovered a circuit motif capable of alleviating catastrophic forgetting.
While modern machine learning algorithms excel at learning complex and discriminating representations (LeCun et al., 2015), an equally challenging problem in continual learning is finding good ways to preserve associations between these representations and output classes. Indeed, the performance of deep artificial neural networks is considerably degraded when classes are learned sequentially (one at a time), as opposed to being randomly interleaved in the training data (Goodfellow et al., 2013). The effect of this simple change is profound and has warranted the search for new mechanisms that can preserve input-output associations over long periods of time. In addition, catastrophic forgetting has been shown to affect deeper layers of neural networks more than feature extraction layers (Ramasesh et al., 2021). This finding highlights the importance of preserving good associations for reducing catastrophic forgetting.
Since learning in the natural world often occurs sequentially, the past few years have witnessed an explosion of brain-inspired continual learning models. These models can be divided into three categories: (1) regularization models, where important weights (synaptic strengths) are identified and protected (Hinton & Plaut, 1987; Fusi et al., 2005; Benna & Fusi, 2016; Kirkpatrick et al., 2017; Zenke et al., 2017; Douillard et al., 2020; Peng et al., 2021); (2) experience replay models, which use external memory to store and reactivate old data (Lopez-Paz & Ranzato, 2017) or use generative models to generate new data from prior experience (van de Ven et al., 2020; Tadros et al., 2020, 2022; Shin et al., 2017); and (3) complementary learning systems (McClelland et al., 1995; Roxin & Fusi, 2013), which partition memory storage into multiple subnetworks, each subject to different learning rules and rates. Importantly, these models often take inspiration from mammalian memory systems, such as the hippocampus (Wilson & McNaughton, 1994; Rasch & Born, 2007) or the neocortex (Qin et al., 1997; Ji & Wilson, 2007), where detailed circuit anatomy and physiology are still lacking. Fortunately, continual learning is also faced by simpler organisms, such as insects, where supporting circuit mechanisms are understood at synaptic resolution (Takemura et al., 2017; Zheng et al., 2018; Li et al., 2020).
Most brain-inspired algorithms use backpropagation-based supervised learning, whose existence in the brain is still controversial. On the other hand, associative learning (a local learning scheme that strengthens connections between representation neurons and output neurons) plays an important role for learning in the brain, though its role toward reducing catastrophic forgetting in artificial neural networks has not been carefully analyzed.
Here, we developed an associative continual learning algorithm inspired by the fruit fly olfactory system. This algorithm tackles an important yet underappreciated subproblem within continual learning: after good representations are learned, how do you best preserve associations between representation neurons and output classes in a class-incremental learning framework? The algorithm we propose stitches together two well-known computational ideas—sparse coding (Maurer et al., 2013; Ruvolo & Eaton, 2013; Ororbia et al., 2019; Ahmad & Scheinkman, 2019; Rapp & Nawrot, 2020; Hitron et al., 2020) and perceptron-like associative learning (Hinton & Plaut, 1987; Fusi et al., 2005; Benna & Fusi, 2016; Kirkpatrick et al., 2017; Zenke et al., 2017; Minsky & Papert, 1988)—in a unique and effective way, which we show effectively reduces catastrophic forgetting under a simple feedforward network architecture. The fruit fly circuit uses a perceptron-like architecture, which we prove theoretically helps reduce catastrophic forgetting compared to the original perceptron algorithm. We also show empirically that the fruit fly circuit outperforms alternative perceptron-like circuits in design space (e.g., replacing sparse coding with dense coding, or associative learning (freezing synapses) with supervised learning (modifiable synapses)), which provides biological insight into the function of these evolved circuit motifs and how they operate together in the brain to sustain memories.
2 Methods
2.1 Circuit Mechanisms for Continual Learning in Fruit Flies
How do fruit flies associate odors (inputs) with behaviors (classes) such that behaviors for odors learned long ago are not erased by newly learned odors? We first review the basic anatomy and physiology of two layers of the olfactory system that are relevant to the exposition here (Modi et al., 2020).
The two-layer neural circuit we study takes as input an odor after a series of preprocessing steps have been applied, including gain control (Root et al., 2008; Gorur-Shandilya et al., 2017), noise reduction (Wilson, 2013), and normalization (Olsen et al., 2010; Stevens, 2015). After these steps, odors are represented by the firing rates of d = 50 types of projection neurons (PNs), which constitute the input to the two-layer network motif described next.
2.1.1 Sparse Coding
The goal of the first layer is to convert the dense input representation of the PNs into a sparse, high-dimensional representation (Cayco-Gajic & Silver, 2019) (see Figure 1A). This is accomplished by a set of about 2000 Kenyon cells (KCs), which receive input from the PNs. The matrix connecting PNs to KCs is sparse and approximately random (Caron et al., 2013); each KC randomly samples from about 6 of the 50 projection neurons and sums up their firing rates. Next, each KC provides feedforward excitation to a single inhibitory neuron, called APL. In return, APL sends feedback inhibition to each KC. The result of this loop is that approximately 95% of the lowest-firing KCs are shut off, and the top 5% remain firing in what is often referred to as a winner-take-all (WTA) computation (Turner et al., 2008; Lin et al., 2014; Stevens, 2015). Thus, an odor initially represented as a point in $\mathbb{R}^{50}$ is transformed, via a 40-fold dimensionality expansion followed by WTA thresholding, to a sparse point in $\mathbb{R}^{2000}$, where only approximately 100 of the 2000 KCs are active (i.e., nonzero) for any given odor.
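A minimal sketch of this first layer, using the numbers above (50 PNs, 2000 KCs, a fan-in of about 6, and a top-5% winner-take-all); the particular random projection and the toy "odor" are illustrative assumptions:

```python
import numpy as np

rng = np.random.default_rng(0)

D_PN, N_KC, FAN_IN, TOP_FRAC = 50, 2000, 6, 0.05   # numbers from the fly circuit

# Sparse random projection: each KC sums the firing rates of ~6 random PNs.
proj = np.zeros((N_KC, D_PN))
for kc in range(N_KC):
    proj[kc, rng.choice(D_PN, size=FAN_IN, replace=False)] = 1.0

def sparse_code(pn_rates):
    """Expand a 50-dim PN vector to 2000 KCs, then keep only the top 5% (WTA)."""
    kc = proj @ pn_rates                       # dense 2000-dim expansion
    k = int(TOP_FRAC * N_KC)                   # ~100 winners
    winners = np.argsort(kc)[-k:]              # indices of the highest-firing KCs
    out = np.zeros(N_KC)
    out[winners] = kc[winners]                 # all other KCs silenced (APL-like inhibition)
    return out

odor = rng.random(D_PN)                        # a toy odor in R^50
print(np.count_nonzero(sparse_code(odor)))     # -> 100 active KCs
```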
This transformation was previously studied in the context of similarity search (Dasgupta et al., 2017, 2018; Papadimitriou & Vempala, 2018; Ryali et al., 2020), compressed sensing (Stevens, 2015; Zhang & Sharpee, 2016), and pattern separation for subsequent learning (Babadi & Sompolinsky, 2014; Litwin-Kumar et al., 2017; Dasgupta & Tosh, 2020).
2.1.2 Associative Learning
The goal of the second layer is to associate odors (sparse points in high-dimensional space) with behaviors.
The main locus of associative learning lies at the synapses between KCs and a set of mushroom body output neurons called MBONs (Aso et al., 2014), which encode behaviorally relevant odor information important for decision making (see Figure 1).
During training, say the fly is presented with a naive odor (odor A) that is paired with a punishment (e.g., an electric shock). How does the fly learn to avoid odor A in the future? Initially, the synapses from KCs activated by odor A to both the “approach” MBON and the “avoid” MBON have equal weights. When odor A is paired with punishment, the KCs representing odor A are activated around the same time that a punishment-signaling dopamine neuron fires in response to the shock. The released dopamine causes the synaptic strength between odor A KCs and the approach MBON to decrease, resulting in a net increase in the avoidance MBON response.1 Eventually the synaptic weights between odor A KCs and the approach MBON are sufficiently reduced to reliably learn the avoidance association (Felsenberg et al., 2018).
Importantly, the only synapses that are modified in each associative learning trial are those from odor A KCs to the approach MBON. All synapses from odor A KCs to the avoid MBON are frozen (i.e., left unchanged), as are all weights from silent KCs to both MBONs. Thus, the vast majority of synapses are frozen during any single odor-association trial.
To summarize, associative learning in the fly is driven by dopamine signals that only affect the synapses of sparse odor-activated KCs and a target MBON that drives behavior.
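Schematically, the dopamine-gated update just described might be sketched as follows; the step size and the clamp at zero are our assumptions, and only the synapses from odor-activated KCs onto the approach MBON are touched:

```python
import numpy as np

n_kc = 2000
# KC -> MBON synapses start out equal for both output neurons.
w_approach = np.ones(n_kc)
w_avoid = np.ones(n_kc)

def pair_with_punishment(kc_activity, lr=0.1):
    """Dopamine-gated update: depress only the synapses from odor-activated KCs
    onto the approach MBON; every other weight is frozen."""
    active = kc_activity > 0
    w_approach[active] = np.maximum(w_approach[active] - lr, 0.0)
    # w_avoid and all synapses from silent KCs are left untouched.
```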
2.2 The FlyModel
We now introduce an associative continual learning algorithm based on the two-layer olfactory circuit above.
As input, we are given a d-dimensional vector, $x \in \mathbb{R}^d$. As in the fly circuit, we assume that x is preprocessed to remove noise and encode discriminative features. Biologically, this is often accomplished by peripheral sensory circuitry that is separate from learning-related circuitry; computationally, the representations extracted from models trained on other data sets, as often done in transfer learning, could serve a similar role. To emphasize, our goal here is not to study the complexities of learning good representations but rather to disentangle representation learning from associative learning and focus exclusively on the latter.
For computational convenience, a min-max normalization is applied to ϕ(x) so that each unit has a value between 0 and 1. The matrix Θ is fixed and not modified during learning; that is, there are no trainable parameters in the first layer. The winner-take-all competition could be implemented in alternative ways (Holca-Lamarre et al., 2017) besides the direct inhibition we show here, but direct inhibition has the benefit of easily specifying the number of neurons that remain active per input.
The second layer is an associative learning layer, which contains k output class units, y = {y1, y2, . . . , yk}. The first and second layers are connected with all-to-all synapses. If an input x is to be associated with target yj, the only weights that are modified are those between the active units in ϕ(x) and yj. No other weights, including those from the active units in ϕ(x) to the other k − 1 units in y, are modified. We refer to this as “partial freezing” of weights during learning.
Finally, biological synapses have physical bounds on their strength, and here we mimic these bounds by capping weights to [0, 1].
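Below is a minimal, self-contained sketch of the two layers just described, written in plain NumPy. The expansion factor, WTA fraction, learning rate, and random-projection construction are illustrative assumptions standing in for equations 2.1 to 2.3; the update potentiates only the weights from active units to the target class and caps them to [0, 1], as described above.

```python
import numpy as np

class FlyModel:
    """Sketch: fixed random expansion + winner-take-all, followed by an
    associative layer trained with partial freezing. Hyperparameters are illustrative."""

    def __init__(self, d_in, n_classes, expansion=40, top_frac=0.05, lr=0.1, seed=0):
        rng = np.random.default_rng(seed)
        self.m = d_in * expansion                                 # hidden (KC-like) units
        # Fixed, untrained sparse random projection: each hidden unit samples ~6 inputs.
        self.theta = (rng.random((self.m, d_in)) < 6 / d_in).astype(float)
        self.k = max(1, int(top_frac * self.m))                   # number of WTA winners
        self.lr = lr
        self.W = np.zeros((n_classes, self.m))                    # associative-layer weights

    def encode(self, x):
        h = self.theta @ x
        h = (h - h.min()) / (h.max() - h.min() + 1e-12)           # min-max normalization
        mask = np.zeros_like(h)
        mask[np.argsort(h)[-self.k:]] = 1.0                       # winner-take-all
        return h * mask

    def train_step(self, x, y):
        phi = self.encode(x)
        active = phi > 0
        # Partial freezing: only weights from active units to the target class change,
        # and weights are capped to [0, 1].
        self.W[y, active] = np.clip(self.W[y, active] + self.lr * phi[active], 0.0, 1.0)

    def predict(self, x):
        return int(np.argmax(self.W @ self.encode(x)))
```

Because no weight outside the target class's row is touched, training on a new class cannot disturb the weight vectors of classes learned earlier.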
3 Theoretical Results
3.1 Even the Perceptron with Linearly Separable Data Suffers from Catastrophic Forgetting
It is well known that if the data points (x, y) presented to the perceptron algorithm have linear margin γ > 0 and have lengths bounded by ‖x‖ ≤ R, then the algorithm will make at most $2kR^2/\gamma^2$ wrong predictions over its entire lifetime, where k is the number of classes. This is true regardless of the order in which the data are presented.
Nonetheless, catastrophic forgetting is possible. One way this can happen is if the data are not perfectly separable. In that case, the linear margin condition does not hold and the perceptron will not converge. But forgetting can occur even when the data are separable, as we now show.
Suppose we introduce one class at a time, or a couple of classes at a time, as is commonly done in continual learning. The guarantees of the perceptron need to be interpreted carefully in this setting. After new classes are introduced, the algorithm will, in general, need to see more examples of earlier classes and tweak their weight vectors further. It is thus never “done” with a particular class.
To make this more concrete, suppose that we introduce one class at a time, and the margin so far, after seeing the first j classes, is γj. That is, γj is the linear margin of classes {1, 2, . . . , j}, which can only get smaller as j grows: γ1 ≥ γ2 ≥ γ3 ≥ ⋯. Upon introducing the kth class, the mistake bound goes up by $2R^2\left(k/\gamma_k^2 - (k-1)/\gamma_{k-1}^2\right)$; we can think of this quantity as the additional number of mistakes we might make in accommodating the kth class. Crucially, in order to get convergence, some of these mistakes may need to be made on classes that have been seen earlier.
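To make the accounting explicit, here is the one-line calculation behind that increment, assuming (as above) that the mistake bound after j classes is $2jR^2/\gamma_j^2$:

```latex
\frac{2kR^2}{\gamma_k^2} - \frac{2(k-1)R^2}{\gamma_{k-1}^2}
  \;=\; 2R^2\left(\frac{k}{\gamma_k^2} - \frac{k-1}{\gamma_{k-1}^2}\right)
  \;\ge\; \frac{2R^2}{\gamma_k^2},
  \qquad \text{since } \gamma_k \le \gamma_{k-1}.
```

So each new class can force a nontrivial number of additional mistakes, and nothing prevents those mistakes from landing on previously learned classes.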
We now construct a concrete example of this phenomenon.
Pick any positive integer d that is a power of two. Then there exists a set of $d-1$ vectors $x_1, \ldots, x_{d-1} \in \{0,1\}^d$ with the following two properties: (1) each vector $x_i$ has exactly $d/2$ ones, and (2) any pair of distinct vectors has dot product exactly $d/4$.
Start with the d × d Hadamard matrix and remove the row that is all ones. In the resulting (d − 1) × d matrix, replace every −1 with a 0, and take x1, . . . , xd − 1 to be the rows of the matrix.
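This construction is easy to check numerically. A minimal sketch (using SciPy's Sylvester-type Hadamard matrix; the dimension d = 8 is arbitrary):

```python
import numpy as np
from scipy.linalg import hadamard

d = 8                                    # any power of two
H = hadamard(d)                          # Sylvester Hadamard matrix; row 0 is all ones
X = (H[1:] == 1).astype(int)             # drop the all-ones row, map -1 -> 0

# Property (1): every vector has exactly d/2 ones.
assert (X.sum(axis=1) == d // 2).all()

# Property (2): any pair of distinct vectors has dot product d/4.
G = X @ X.T
assert (G[~np.eye(d - 1, dtype=bool)] == d // 4).all()

# Preview of the interference argument below: once w1 becomes x1 - x2,
# every later prototype x_j (j > 2) is invisible to class 1.
w1 = X[0] - X[1]
print(w1 @ X[2:].T)                      # -> all zeros
```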
Suppose the classes are introduced one at a time, in order:
- Class 1 is introduced: $w_1$ is set to $x_1$, which is perfect.
- Class 2 is introduced: point $x_2$ is misclassified as coming from class 1. Thus, $w_2$ is set to $x_2$, which is perfect, but $w_1$ is changed to $x_1 - x_2$, which no longer correctly classifies class 1: now $w_1 \cdot x_1 = (x_1 - x_2) \cdot x_1 = d/2 - d/4 = d/4$, which is no larger than $w_2 \cdot x_1 = d/4$. Moreover, for any $j > 2$, we have $w_1 \cdot x_j = (x_1 - x_2) \cdot x_j = d/4 - d/4 = 0$. Thus, all such points $x_j$ will be classified as class 2.
- Class 3 is introduced: point $x_3$ is misclassified as coming from class 2. Thus, $w_3$ is set to $x_3$, but now $w_2$ becomes $x_2 - x_3$, suffering a similar fate to $w_1$.
Thus, even in the case where the perceptron algorithm is known to fare best (i.e., when the data are linearly separable), catastrophic forgetting occurs under continual learning.
3.2 The Partial-Freezing Perceptron Does Not Suffer from Catastrophic Forgetting
The partial-freezing algorithm in equation 2.3 is an associative version of the perceptron. We now show that even in a continual learning framework, this algorithm will provably learn to correctly distinguish classes if the classes satisfy a separation condition that says, roughly, that dot products between points within the same class are, on average, greater than between classes. We will then show that adding sparse coding enhances the separation of classes (Babadi & Sompolinsky, 2014), making associative learning easier.
3.3 Sparse Coding Provably Creates Favorable Separation for Continual Learning
The separation condition of definition 1 is quite strong and might not hold in the original data space. But we will show that subsequent sparse coding can nonetheless produce this condition, so that the partial freezing algorithm, when run on the sparse encodings, performs well.
We can show that for suitable ξ, the sparse representations of the prototypes—that is, $\phi(p_1), \ldots, \phi(p_N) \in \{0,1\}^m$—are then guaranteed to be separable, so that the partial freezing algorithm will converge to a perfect classifier.
Let $N_o = |C_j|$. Under the assumptions above, the sparse representation of the data set, $\phi(x_1), \ldots, \phi(x_n)$, is separated, for a suitable separation constant, in the sense of definition 1.
This is a consequence of theorem 9 in the supplement, a more general result that applies to a broader model in which observed data are noisy versions of the prototypes.
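The following toy experiment (not the theorem itself; the noise model, dimensions, and parameters are ours) illustrates the effect: within-class versus between-class similarity is compared in the raw input space and after a random expansion with winner-take-all.

```python
import numpy as np

rng = np.random.default_rng(1)
d, m, k_active = 50, 2000, 100                         # input dim, expansion dim, WTA winners
theta = (rng.random((m, d)) < 6 / d).astype(float)     # fixed sparse random projection

def phi(x):
    """Binary sparse code: expand, then keep the top k_active units."""
    h = theta @ x
    out = np.zeros(m)
    out[np.argsort(h)[-k_active:]] = 1.0
    return out

def mean_sims(points, labels):
    """Average cosine similarity within classes vs. between classes."""
    P = points / np.linalg.norm(points, axis=1, keepdims=True)
    S = P @ P.T
    same = labels[:, None] == labels[None, :]
    off = ~np.eye(len(labels), dtype=bool)
    return S[same & off].mean(), S[~same].mean()

prototypes = rng.random((5, d))                        # 5 class prototypes
X = np.vstack([p + 0.2 * rng.random(d) for p in prototypes for _ in range(20)])
y = np.repeat(np.arange(5), 20)

print("raw    (within, between):", mean_sims(X, y))
print("sparse (within, between):", mean_sims(np.array([phi(x) for x in X]), y))
```

In runs of this sketch, sparse coding tends to lower between-class similarity more than within-class similarity, which is the kind of separation definition 1 asks for.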
4 Experimental Evaluation
4.1 Testing Framework and Problem Setup
We tested each algorithm (see below) on two benchmark data sets using a class-incremental learning setup (Farquhar & Gal, 2019; van de Ven et al., 2020), in which the training data were ordered and split into sequential tasks. For the MNIST-20 data set (a combination of regular MNIST and Fashion MNIST; see the supplement), we used 10 nonoverlapping tasks, where each task is a classification problem between two classes. For example, the first task is to classify between digits 0 and 1, the second task is to classify digits 2 and 3, and so on. Similarly, the CIFAR-100 data set is divided into 25 nonoverlapping tasks, where each task is a classification problem among four classes.
Testing is performed after the completion of training of each task and is quantified using two measures. The first measure, the accuracy for classes trained so far, assesses how well classes from previous tasks remain correctly classified after a new task is learned. Specifically, after training task i, we report the accuracy of the model tested on classes from all tasks ≤ i. The second measure, memory loss, quantifies forgetting for each task separately. We define the memory loss of task i as the accuracy of the model when tested (on classes from task i only) immediately after training on task i minus the accuracy when tested (again, on classes from task i only) after training on all tasks, that is, at the end of the experiment. For example, say the immediate accuracy of task i is 0.80, and the accuracy of task i at the end of the experiment is 0.70. Then the memory loss of task i is 0.10. A memory loss of zero means memory of the task was perfectly preserved despite learning new tasks.
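Both measures can be computed from a table of per-task accuracies recorded after each training stage; here is a sketch of the memory-loss calculation (the matrix layout and variable names are ours):

```python
import numpy as np

def memory_loss(acc):
    """acc[i, j]: accuracy on task j's classes only, measured after training task i (i >= j).
    Memory loss of task j = accuracy right after learning it minus accuracy at the end."""
    T = acc.shape[0]
    return np.array([acc[j, j] - acc[T - 1, j] for j in range(T)])

# Example from the text: immediate accuracy 0.80, final accuracy 0.70 -> memory loss 0.10.
acc = np.array([[0.80, np.nan],     # nan: task 2 not yet trained or measured
                [0.70, 0.95]])
print(memory_loss(acc))             # -> [0.10, 0.00]
```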
4.1.1 Comparison to Other Algorithms
There are, of course, many heavy-duty continual learning algorithms in the literature, and our intention here is not to perform an exhaustive comparison to them. Instead, we compared the FlyModel with three neurally plausible methods that are popular in the literature and represent the broad strategies of brain-inspired continual learning (e.g., synapse protection, memory replay) outlined in section 1. All of these methods use backpropagation-based supervised learning as opposed to associative learning. Moreover, none of these methods have provable convergence guarantees, as ours does:
Elastic weight consolidation (EWC; Kirkpatrick et al., 2017) uses the Fisher information criterion to identify weights that are important for previously learned tasks and then introduces a penalty if these weights are modified when learning a new task.
Gradient episodic memory (GEM; Lopez-Paz & Ranzato, 2017) uses a memory system that stores a subset of data from previously learned tasks. These data are used to assess how much the loss function on previous tasks increases when model parameters are updated for a new task.
Brain-inspired replay (BI-R; van de Ven et al., 2020) protects old memories by using a generative model to replay activity patterns related to previously learned tasks.
Vanilla is a standard fully connected neural network that does not have any explicit continual learning mechanism. This is used as a lower bound on performance.
Offline is a standard fully connected neural network, but instead of learning classes sequentially, for each task, it is retrained from scratch on all classes (current and previously seen) together, presented in a random order. This is used as an upper bound on performance.
See the supplement for full details on data sets, preprocessing, network architectures, and parameters.
4.2 The FlyModel Outperforms Existing Methods in Class-Incremental Learning
The FlyModel reduced catastrophic forgetting compared to all four continual learning methods tested. For example, on the MNIST-20 data set (see Figure 2A), after training on 5 tasks (10 classes), the accuracy of the FlyModel was 0.86 ± 0.0006 compared to 0.77 ± 0.02 for BI-R, 0.69 ± 0.02 for GEM, 0.58 ± 0.10 for EWC, and 0.19 ± 0.0003 for Vanilla. At the end of training (10 tasks, 20 classes trained), the test accuracy of the FlyModel was at least 0.19 higher than any other method and only 0.11 lower than the optimal offline model.
Next, we used the memory loss measure to quantify how well the “memory” of an old task is preserved after training new tasks (see Figures 2B and S3). As expected, the standard neural network (Vanilla) preserves almost no memory of previous tasks; it has a memory loss of nearly one for all tasks except the most recent task. While GEM, EWC, and BI-R perform better—memory losses of 0.24, 0.27, and 0.42, respectively, averaged across all tasks—the FlyModel has an average memory loss of only 0.07. This means that the accuracy of task i was degraded, on average, by only 0.07 (7 percentage points) at the end of training when using the FlyModel.
4.3 Sparse Coding and Partial Freezing Are Both Required for Continual Learning
An important challenge in theoretical neuroscience is to understand why circuits may be designed the way they are. Quantifying how evolved circuits fare against putative, alternative circuits in design space could provide insight into the biological function of observed network motifs. We explored this question in the context of the two core components in the FlyModel: sparse coding of representations in the first layer and partial freezing of synaptic weights in the associative learning layer. Are both of these components required, or can good performance be attained with only one or the other?
We explored, one piece at a time, the effects of replacing sparse coding with dense coding and of replacing partial freezing with a traditional single-layer neural network (i.e., logistic regression), where every weight can change for each input. This gave us four combinations to test. The dense code was calculated in the same way as the sparse code, minus the winner-take-all step. In other words, for each input x, we used ψ(x) (see equation 2.1, with min-max normalization) as its representation instead of ϕ(x) (see equation 2.2). For logistic regression, the associative layer was trained using backpropagation.
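For concreteness, here is a sketch of the two pieces that differ across the four variants, under the same illustrative assumptions as the FlyModel sketch in section 2.2: the dense code ψ(x) skips the winner-take-all step, and the logistic-regression head takes a cross-entropy gradient step in which every weight can change.

```python
import numpy as np

def psi(x, theta):
    """Dense code: normalized expansion without the winner-take-all step (cf. eq. 2.1)."""
    h = theta @ x
    return (h - h.min()) / (h.max() - h.min() + 1e-12)

def phi(x, theta, top_frac=0.05):
    """Sparse code: the same expansion followed by winner-take-all (cf. eq. 2.2)."""
    h = psi(x, theta)
    out = np.zeros_like(h)
    top = np.argsort(h)[-int(top_frac * len(h)):]
    out[top] = h[top]
    return out

def update_partial_freeze(W, h, y, lr=0.1):
    """FlyModel-style rule: only target-class weights of active units change."""
    active = h > 0
    W[y, active] = np.clip(W[y, active] + lr * h[active], 0.0, 1.0)

def update_logistic(W, h, y, lr=0.1):
    """Logistic-regression rule: a cross-entropy gradient step; every weight can change."""
    z = W @ h
    p = np.exp(z - z.max()); p /= p.sum()      # softmax over classes
    grad = np.outer(p, h)
    grad[y] -= h                               # gradient of cross-entropy w.r.t. W
    W -= lr * grad
```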
Both sparse coding variants (with partial freezing or with logistic regression) performed better than the two dense coding variants on both data sets (see Figures 3A and 3B). For example, on MNIST-20, at the end of training, the sparse coding models had an average accuracy of 0.64 versus 0.07 for the two dense coding models. Furthermore, sparse coding with partial freezing (i.e., the FlyModel) performed better than sparse coding with logistic regression: 0.75 versus 0.54 on MNIST-20 and 0.41 versus 0.21 on CIFAR-100.
Hence, on at least the two data sets used here, both sparse coding and partial freezing are needed to optimize continual learning performance.
4.4 Comparison of the FlyModel with the Perceptron
In our theoretical analysis, we highlighted two important differences between the perceptron's supervised learning algorithm and the FlyModel's associative learning algorithm. Next, we studied how all four combinations of these two design choices affect continual learning.
The first combination (Perceptron v1) is the classic perceptron learning algorithm, where weights are modified only if an incorrect prediction is made, by increasing weights to the correct class and decreasing weights to the incorrectly predicted class. The second combination (Perceptron v2) also learns only when a mistake is made, but it increases weights only to the correct class (i.e., it does not decrease weights to the incorrect class). The third combination (Perceptron v3) increases weights to the correct class regardless of whether a mistake is made, and it decreases weights to the incorrect class when a mistake is made. Finally, the fourth combination (Perceptron v4) is equivalent to the FlyModel; it simply increases weights to the correct class regardless of whether a mistake is made. All models start with the same sparse, high-dimensional input representations in the first layer.
Overall, we find a striking difference in continual learning across these two design choices, with the FlyModel performing significantly better than the other three models on both data sets (see Figures 4A and 4B). Specifically, learning regardless of whether a mistake is made (v3 and v4) works better than mistake-only learning (v1 and v2), and decreasing the weights to the incorrectly predicted class hurts performance (v4 compared to v3; there is no major difference between v2 and v1).
Perceptron v1 (Original Perceptron):
  for x in data do
    if predict ≠ target then
      weight[target] += βx
      weight[predict] −= βx
    end if
  end for

Perceptron v2:
  for x in data do
    if predict ≠ target then
      weight[target] += βx
    end if
  end for

Perceptron v3:
  for x in data do
    if predict ≠ target then
      weight[target] += βx
      weight[predict] −= βx
    else
      weight[target] += βx
    end if
  end for

Perceptron v4 (FlyModel):
  for x in data do
    if predict ≠ target then
      weight[target] += βx
    else
      weight[target] += βx
    end if
  end for
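The four listings can be collapsed into a single parameterized routine; a sketch (inputs are assumed to be already sparse-coded, and β is the learning rate from the listings):

```python
import numpy as np

def train(X, y, n_classes, variant, beta=1.0):
    """Perceptron variants v1-v4 on pre-encoded (e.g., sparse-coded) inputs X.
    v1: on mistake, increment target and decrement predicted class.
    v2: on mistake, increment target only.
    v3: always increment target; on mistake, also decrement predicted class.
    v4 (FlyModel): always increment target, never decrement."""
    W = np.zeros((n_classes, X.shape[1]))
    for x, target in zip(X, y):
        predict = int(np.argmax(W @ x))
        mistake = predict != target
        if mistake or variant in ("v3", "v4"):
            W[target] += beta * x              # strengthen target-class weights
        if mistake and variant in ("v1", "v3"):
            W[predict] -= beta * x             # weaken predicted-class weights
    return W
```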
As we showed analytically, decreasing weights to the incorrect class (v1 and v3) leads to catastrophic forgetting when inputs from different classes overlap. While this feature of the perceptron algorithm is believed to help create a larger boundary (margin) between the predicted incorrect class and the correct class, it also allows shared weights to be hijacked by recently observed classes. This leads to more catastrophic forgetting, albeit with faster initial learning. The FlyModel, on the other hand, avoids this issue because the contributions of shared neurons are split between both classes and thus cancel each other out. As a result, each weight vector in the associative layer converges to the mean of its class inputs, scaled by a constant (see the supplement, lemmas 3 and 4 and theorems 5 and 8). See supplement Figures S4 and S5 for an empirical demonstration of this result.
5 Discussion
While learning mechanisms in the brain have been the source of inspiration for many continual learning algorithms, one commonly used neural learning mechanism (associative learning) has been largely overlooked. Here, we developed a simple and lightweight associative continual learning algorithm that reduces catastrophic forgetting, inspired by how fruit flies learn odor-behavior associations. The FlyModel outperformed three popular class-incremental continual learning algorithms on two benchmark data sets (MNIST-20 and CIFAR-100), despite not using external memory, generative replay, or backpropagation. The fly’s associative learning algorithm is strikingly similar to the classic perceptron algorithm but for two modifications that we show are critical for retaining old memories. Indeed, alternative circuits in design space suffered more catastrophic forgetting than the FlyModel, potentially shedding new light on the biological function and conservation of this circuit motif. Finally, we grounded these ideas theoretically by proving that associative layer weight vectors in the FlyModel converge to the mean representation of their classes and that sparse coding further reduces memory interference by better separating classes compared to the conventional perceptron algorithm, which we proved suffers under the continual learning scenario even when classes are linearly separable.
Given the same architecture, the FlyModel requires less memory than alternative methods and has comparable training efficiency (see the supplement, Figures S6 and S7). If the input layer is N-dimensional and the hidden layer undergoes a 40-fold dimensionality expansion, then given m inputs and t tasks, the number of parameters (weights) the FlyModel needs to store is $O(N^2 + tN)$ across the two layers, and the total computational complexity for training is $O(mN)$, since for each input, we only update a few weights in the second layer. In addition, since the FlyModel makes no distinction between tasks, the computational complexity is independent of t. On the other hand, EWC and GEM require additional storage of weights or data from previous tasks.
The two main features of the FlyModel—sparse coding (Kanerva, 1988, 2009; Babadi & Sompolinsky, 2014) and associative learning (i.e., partial synaptic freezing; Kirkpatrick et al., 2017; Zenke et al., 2017)—are well appreciated in both neuroscience and machine learning. For example, sparse, high-dimensional representations have long been recognized as central to neural encoding (Kanerva, 1988), hyperdimensional computing (Kanerva, 2009), and classification and recognition tasks (Babadi & Sompolinsky, 2014). Similarly, the notion of freezing certain weights during learning has been used in both classic perceptrons and modern deep networks (Kirkpatrick et al., 2017; Zenke et al., 2017), but these methods are still subject to interference caused by dense representations. However, the benefits of such features toward continual learning have not been well quantified. Indeed, the fly circuit evolved a unique combination of common computational ingredients that work effectively in practice.
The FlyModel performs associative rather than supervised learning. In associative learning, the same learning rule is applied regardless of whether the model makes a mistake. In supervised learning, weights are changed only when the model makes a mistake, and the changes are applied to the weights for both the correct and the incorrectly predicted class labels. In other words, the FlyModel learns each class independently, unlike supervised methods, and hence is flexible about the total number of classes to be learned; the network is easily expandable to more classes if necessary. Supervised methods focus on discriminating between multiple classes at a time, which we showed is particularly susceptible to interference, especially when class representations overlap. Thus, our results suggest that some traditional benefits of supervised classification may not carry over to continual learning (Hand, 2006) and that association-like models may better preserve memories when classes are learned sequentially.
However, associative learning alone without good representations is not sufficient to achieve good continual learning performance. While conventional backpropagation-based continual learning algorithms try to solve both representation learning and association learning at the same time, it could be argued that the brain takes a different approach by separating these two into different network layers. The associative learning component of continual learning has not been well studied in the literature, even though, as we showed, this seemingly simple problem can have important consequences on reducing catastrophic forgetting.
Previous studies share conceptual similarities with some features of the FlyModel. For example, PackNet (Mallya & Lazebnik, 2018) uses weight pruning to free up redundant weights while keeping important weights fixed. This approach is similar to partial freezing, but instead of pruning less important weights, partial freezing only modifies relevant weights during learning and requires no retrospective computation to determine the importance of weights. Partial freezing also resembles another well-known continual learning method, iCaRL (Rebuffi et al., 2017). iCaRL selects a few prototypes per class, stores them, and then, at prediction time, averages the prototypes for each class (using the current representation) and chooses the class nearest the query vector. The FlyModel instead maintains one linear function per class and, at prediction time, takes the one with the highest value.
There are four additional features of the fruit fly mushroom body (MB) that remain underexplored computationally. First, instead of using one output neuron (MBON) per behavior, the mushroom body contains multiple output neurons per behavior, with each output neuron learning at a different rate (Hige et al., 2015; Aso & Rubin, 2016). This simultaneously provides fast learning with poor retention (large learning rates) and slow learning with longer retention (small learning rates), which is reminiscent of complementary learning systems (Parisi et al., 2019). Second, the MB contains mechanisms for memory extinction (Felsenberg et al., 2018) and reversal learning (Felsenberg et al., 2017; Felsenberg, 2021), which are used to update inaccurate memories. Third, there is evidence of memory replay in the MB, which is required for memory consolidation (Yu et al., 2005; Haynes et al., 2015; Cognigni et al., 2018). Fourth, there exists feedback from the MB to the input layer that could tune representations during learning (Hu et al., 2010). We hope our model can be used as a stepping stone as circuit mechanisms controlling these computations are discovered. Moreover, although we only evaluated continual learning performance using simple architectures, follow-up work has already successfully implemented some variants of FlyModel (Robinson et al., 2023; Bricken et al., 2023), suggesting that better performance can indeed be achieved using more sophisticated architectures.
Finally, a motif similar to that of the fruit fly olfactory system also appears in the mouse olfactory system, where sparse representations in the piriform cortex project to other learning-related areas of the brain (Komiyama & Luo, 2006; Wang et al., 2020). In addition, the visual system uses many successive layers to extract discriminative features (Riesenhuber & Poggio, 1999; Tacchetti et al., 2018), which are then projected to the hippocampus, where a similar sparse, high-dimensional representation is used for memory storage (Olshausen & Field, 2004; Wixted et al., 2014; Lodge & Bischofberger, 2019). Thus, the principles of learning studied here may help illuminate how continual learning is implemented in other brain regions and species.
In all, our work exemplifies how understanding detailed neural anatomy and physiology in a tractable model system can be translated into efficient architectures for use in artificial neural networks.
Note
1. Curiously, approach behaviors are learned by decreasing the avoid MBON response, as opposed to increasing the approach MBON response, as may be more intuitive.