Abstract
There is growing interest in predictive coding as a model of how the brain learns through predictions and prediction errors. Predictive coding models have traditionally focused on sensory coding and perception. Here we introduce active predictive coding (APC) as a unifying model for perception, action, and cognition. The APC model addresses important open problems in cognitive science and AI, including (1) how we learn compositional representations (e.g., part-whole hierarchies for equivariant vision) and (2) how we solve large-scale planning problems, which are hard for traditional reinforcement learning, by composing complex state dynamics and abstract actions from simpler dynamics and primitive actions. By using hypernetworks, self-supervised learning, and reinforcement learning, APC learns hierarchical world models by combining task-invariant state transition networks and task-dependent policy networks at multiple abstraction levels. We illustrate the applicability of the APC model to active visual perception and hierarchical planning. Our results represent, to our knowledge, the first proof-of-concept demonstration of a unified approach to addressing the part-whole learning problem in vision, the nested reference frames learning problem in cognition, and the integrated state-action hierarchy learning problem in reinforcement learning.
1 Introduction
Predictive coding (Rao & Ballard, 1997, 1999; Rao, 1999; Keller & Mrsic-Flogel, 2018; Jiang & Rao, 2022b) has received increasing attention in recent years as a model of how the brain learns models of the world through prediction and self-supervised learning. In predictive coding, feedback connections from a higher to a lower level of a cortical neural network (e.g., the visual cortex) convey predictions of lower-level responses, and the prediction errors are conveyed via feedforward connections to correct the higher-level estimates, completing a prediction-error-correction cycle (see also Mumford, 1992). Such a model has provided explanations for a wide variety of neural and cognitive phenomena (Keller & Mrsic-Flogel, 2018; Jiang & Rao, 2022b). The layered architecture of the cortex is remarkably similar across cortical areas (Mountcastle, 1978), hinting at a common computational principle, with superficial layers receiving and processing sensory information and deeper layers conveying outputs to motor centers (Sherman & Guillery, 2013). The traditional predictive coding model focused on learning visual hierarchical representations and did not acknowledge the important role of actions in learning world models.
In this article, we introduce active predictive coding (APC), a new model of predictive coding that combines state and action networks at different abstraction levels to learn hierarchical internal models. The model provides a unified framework for addressing several important but seemingly unrelated problems in perception, action, and cognition as described below.
2 Related Work and Contributions
2.1 Part-Whole Learning Problem
Hinton and colleagues have posed the problem of how neural networks can learn to parse visual scenes into part-whole hierarchies by dynamically allocating nodes in a parse tree. They have explored networks that use a group of neurons to represent not only the presence of an object but also parameters such as position and orientation (Sabour et al., 2017; Kosiorek et al., 2019; Hinton et al., 2018; Hinton, 2021). Such equivariant models seek to overcome the inability of deep convolutional neural networks (CNNs) (Krizhevsky et al., 2012) to explain the images they classify in the way humans do, in terms of objects, parts and their locations.
Other approaches such as the ones presented in Burgess et al. (2019) or Greff et al. (2019) represent a scene by segmenting it into distinct objects and modeling them separately. Seitzer et al. (2023) demonstrate such an approach on more realistic data sets. Oquab et al. (2023) apply principal-component analysis (Pearson, 1901) to image-patch features learned from a self-supervised, transformer-based model. They show that the first few principal components capture information that can be used to segment an object from its background or identify different parts of the object regardless of pose or style.
2.2 Reference Frames Problem
In a parallel line of research, Hawkins and colleagues (Hawkins, 2021; Lewis et al., 2019; see also George & Hawkins, 2009; George et al., 2017; Guntupalli et al., 2023) have taken inspiration from the cortex and “grid cells” to propose that the brain uses object-centered reference frames (or “schemas”) to represent objects, spatial environments, and even abstract concepts. There is some evidence from hippocampal and cortical studies in rodents and humans for spatial (and more abstract) reference frames being used for solving problems such as navigation and abstract reasoning (O’Keefe & Dostrovsky, 1971; Moser et al., 2017; Constantinescu et al., 2016). The question of how such reference frames can be learned and used in a nested manner for hierarchical recognition and reasoning has remained open.
2.3 Integrated State-Action Hierarchy Learning Problem
A considerable literature exists on hierarchical reinforcement learning (see Hutsebaut-Buysse et al., 2022, for a recent survey), where the goal is to make traditional reinforcement learning (RL) algorithms more efficient through state and/or action abstraction. While one class of approaches relies on identifying particular states as “subgoals” and achieving these subgoals to solve a main task (Hafner et al., 2022), another class of approaches uses options (Sutton et al., 1999; Bacon et al., 2016), which are abstract actions that can be selected in particular states (in the option’s “initiation set”) and whose execution results in a sequence of primitive actions as prescribed by the option’s lower-level policy. The problem of simultaneously learning state and action abstraction hierarchies has remained relatively less explored.
2.4 Hypernetworks
A hypernetwork (Ha et al., 2017) is an artificial neural network that generates parameters for another neural network, called the primary network. The primary network is essentially a placeholder into which the generated weights are plugged, to be used for solving a downstream task. In the simple case of a fully connected primary network with $L$ layers, these parameters consist of the weight matrices $W_1, \ldots, W_L$ and bias vectors $b_1, \ldots, b_L$. Hypernetworks are typically trained end-to-end.
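As a concrete illustration, the following is a minimal PyTorch sketch (the layer sizes and names are ours, not from any of the cited works) of a hypernetwork that maps an embedding vector to the weights and biases of a two-layer fully connected primary network:

```python
import torch
import torch.nn as nn

class HyperNetwork(nn.Module):
    """Maps an embedding vector to the parameters of a small primary network."""
    def __init__(self, embed_dim, in_dim, hidden_dim, out_dim):
        super().__init__()
        self.shapes = [(hidden_dim, in_dim), (hidden_dim,),
                       (out_dim, hidden_dim), (out_dim,)]
        n_params = sum(torch.Size(s).numel() for s in self.shapes)
        self.generator = nn.Sequential(
            nn.Linear(embed_dim, 256), nn.ReLU(), nn.Linear(256, n_params))

    def forward(self, z):
        # Generate a flat parameter vector, then split it into weights and biases.
        flat = self.generator(z)
        params, offset = [], 0
        for shape in self.shapes:
            n = torch.Size(shape).numel()
            params.append(flat[offset:offset + n].view(shape))
            offset += n
        return params  # [W1, b1, W2, b2] for the primary network

def primary_forward(x, params):
    """The primary network: a two-layer MLP whose weights come from the hypernetwork."""
    W1, b1, W2, b2 = params
    h = torch.relu(x @ W1.T + b1)
    return h @ W2.T + b2

hyper = HyperNetwork(embed_dim=16, in_dim=8, hidden_dim=32, out_dim=4)
z = torch.randn(16)                                  # embedding (e.g., a context vector)
y = primary_forward(torch.randn(8), hyper(z))        # run the generated primary network
```

In such a setup, the hypernetwork and the primary network are trained together end-to-end by backpropagating the downstream task loss through `primary_forward` into `generator`.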
Hypernetworks in their most general form (as defined above) are biologically implausible: neural networks in the brain cannot generate other neural networks ex nihilo for each input. Although we do not pursue biological plausibility in this article, we note that there are at least two ways to approximate hypernetworks in a biologically plausible manner. First, one can use a specialized type of hypernetwork: instead of generating a new network, a hypernetwork could modulate an existing “primary” neural network $f$, for example, by generating gain factors that multiply $f$’s outputs and, by doing so, change the function being computed by $f$. Such “gain modulation” (Zipser & Andersen, 1988; Salinas & Abbott, 1996; Salinas & Sejnowski, 2001; Ferguson & Cardin, 2020; Shine et al., 2021; Stroud et al., 2018) appears to be common in the cortex, observed, for example, in the multiplicative modulation of tuning curves of visual cortical neurons during attention (McAdams & Maunsell, 1999) and in the changes in the input-output function of neurons in deep layers of the cortex due to top-down modulatory inputs to their apical dendrites (Larkum et al., 2004). It is important to note that when gain modulation is mediated by neuromodulators such as dopamine, it is typically diffuse and not synapse- or neuron-specific, as needed to implement a full-fledged hypernetwork. However, even diffuse coarse-grained modulation of a recurrent primary network can in some cases achieve a performance similar to that of neuron-specific modulation, as shown by Stroud et al. (2018). Second, instead of generating parameters for a network, the hypernetwork can generate a top-down contextual input for the primary network $f$. $f$’s input is augmented with the contextual input, thereby allowing the hypernetwork to modulate the function being computed by $f$ (Yang et al., 2019; Eliasmith et al., 2012). Such an approach, known as the embedding approach in AI, can approximate the computational function of hypernetworks (Galanti & Wolf, 2020) in a biologically plausible manner (see Rao, 2022, for more details).
Although hypernetworks have been used extensively in AI, the problem of using hypernetworks for hierarchical abstraction of state and action functions has remained open.
2.5 Contributions of the Article
The APC model addresses the problems above in a unified manner using state/action embeddings and hypernetworks to dynamically generate and generalize over state and action networks at multiple hierarchical levels. The APC model contributes to a number of lines of research that have not previously been connected:
Perception, predictive coding, and reference frame learning: APC extends predictive coding and related neuroscience models of brain function (Rao & Ballard, 1999; Friston & Kiebel, 2009; Jiang et al., 2021; Jiang & Rao, 2022a) to hierarchical sensory-motor inference and learning, and connects these to learning nested reference frames (Hawkins, 2021) for perception and cognition.
Attention models: APC extends previous hard attention models such as the recurrent attention model (RAM) (Mnih et al., 2014) and attend-infer-repeat (AIR) (Eslami et al., 2016). As an active visual perception technique, it learns structured hierarchical strategies for sampling key parts of the visual scene.
Hierarchical planning and reinforcement learning: APC contributes to hierarchical planning and reinforcement learning (Hutsebaut-Buysse et al., 2022; Botvinick et al., 2009) by proposing a new way of simultaneously learning abstract macro-actions or options (Sutton et al., 1999) and abstract states.
2.6 General Applicability of the Model
When applied to vision, the APC model learns to hierarchically represent and parse images into parts and locations. When applied to RL problems, the model can exploit hypernetworks to (1) define a state hierarchy not merely through state aggregation, but by abstracting transition dynamics at multiple levels, and (2) potentially generalize learned hierarchical states and actions (options) to novel scenarios via interpolation and extrapolation in the input embedding space of the hypernetworks. Our approach brings us closer to solving an important challenge in both AI and cognitive science (Lake et al., 2017): How can neural networks learn hierarchical compositional representations that allow new concepts to be created, recognized, and learned?
3 Active Predictive Coding
The APC model implements a hierarchical version of the traditional partially observable Markov decision process (POMDP) (Kaelbling et al., 1998; Rao, 2010). Figure 1a shows the canonical APC generative module. The module is self-similar and is separated into two distinct systems: the state system and the action system. Following the POMDP formulation, the state system captures the transition dynamics of the environment. The action system determines which action (actual or abstract) the agent will take toward solving the downstream task. The state system maintains historical context via a recurrent state vector $s$, and the action system via an action vector $a$. We denote the recurrent neural networks (RNNs) that generate these vectors by $f_s$ and $f_a$, respectively. We can stack multiple such modules by allowing the current state and action vectors at any given level to generate entire new state and action networks, respectively, at the level below using hypernetworks (Ha et al., 2017).
Specifically, let $s^{(i)}$ and $a^{(i)}$ denote the state and action vectors at the higher level $i$. The state embedding vector $s^{(i)}$, together with the hypernetwork $H_s$, generates a lower-level state transition function $f_s^{(i-1)}$ (Figure 1a, left). Similarly, the action embedding vector $a^{(i)}$, together with the hypernetwork $H_a$, generates a lower-level option/policy function $f_a^{(i-1)}$ (see Figure 1a, right). Both $f_s^{(i-1)}$ and $f_a^{(i-1)}$ are implemented as RNNs. The state and action systems exchange information horizontally within each level as shown in Figure 1b for a three-level APC model. During inference (see below), the APC architecture employs a feedback mechanism via which the lower level can inform the higher level of its “findings” via prediction errors (see Figure 1b, red arrows).
The novel idea behind the APC approach is to imbue structure by abstracting lower-level state and action transitions via RNN (subprogram) generation and restricting the scale or extent of a subprogram in terms of temporal steps or the afforded action space. In our current implementation, the lower-level RNNs (subprograms) execute for a fixed number of time steps before returning control back to the higher level.1 Although not the focus of this paper, we note that the components of the APC model described above could potentially be implemented with biologically plausible mechanisms within the laminar structure of the neocortex (Rao, 2022). Specifically, as mentioned in section 2.4 and discussed in more detail in Jiang & Rao (2022a) and Rao (2022), hypernetworks could be implemented via the neural mechanism of top-down gain modulation in cortical neurons (Larkum et al., 2004; Ferguson & Cardin, 2020).
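The following sketch (PyTorch; the vanilla RNN cells, sizes, and names are illustrative assumptions, not the exact architecture used in our experiments) shows the core stacking idea: a higher-level state/action pair is mapped by two hypernetworks to the parameters of lower-level state and action RNNs, which then run for a fixed number of microsteps before control returns to the higher level:

```python
import torch
import torch.nn as nn

def make_hypernet(embed_dim, n_params):
    # Small MLP mapping a higher-level embedding vector to a flat parameter vector.
    return nn.Sequential(nn.Linear(embed_dim, 256), nn.ReLU(),
                         nn.Linear(256, n_params))

def generated_rnn_step(h, x, flat, hid, inp):
    """One step of a vanilla RNN cell whose parameters were produced by a hypernet."""
    W_h = flat[:hid * hid].view(hid, hid)
    W_x = flat[hid * hid:hid * hid + hid * inp].view(hid, inp)
    b = flat[-hid:]
    return torch.tanh(h @ W_h.T + x @ W_x.T + b)

hid, inp, embed = 32, 16, 64
H_s = make_hypernet(embed, hid * hid + hid * inp + hid)  # state hypernetwork
H_a = make_hypernet(embed, hid * hid + hid * hid + hid)  # action hypernetwork

S = torch.randn(embed)                     # higher-level state embedding vector
A = torch.randn(embed)                     # higher-level action embedding vector
theta_s, theta_a = H_s(S), H_a(A)          # generated lower-level RNN parameters

s, a = torch.zeros(hid), torch.zeros(hid)  # lower-level state and action vectors
for tau in range(4):                       # microsteps before control returns upward
    x = torch.randn(inp)                   # e.g., a new input or prediction error
    s = generated_rnn_step(s, x, theta_s, hid, inp)  # lower-level state RNN
    a = generated_rnn_step(a, s, theta_a, hid, hid)  # lower-level action/policy RNN
```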
3.1 Inference in the Active Predictive Coding Model
Inference involves estimating the state and action vectors at multiple levels based on the sequence of inputs produced by interacting with the environment in the context of a particular task or goal. For the rest of this article, we assume a two-level APC model (see Figure 1b without the third level), which is sufficient to illustrate the basic capability of the proposed framework (we leave the exploration of deeper models to future work; see the discussion in section 6).
Given a two-level APC model, we assume the top level runs for $T$ steps (referred to as macrosteps). For each macrostep, the bottom level runs for $T'$ “microsteps.” To improve readability, instead of $f_s^{(2)}, f_a^{(2)}$ and $f_s^{(1)}, f_a^{(1)}$, we use the notations $F_s, F_a$ and $f_s, f_a$ to denote the top-level and bottom-level state and action functions, respectively (all functions are implemented by RNNs). Similarly, instead of $s^{(2)}, a^{(2)}$ and $s^{(1)}, a^{(1)}$, we use the notation $S, A$ and $s, a$ to denote the top-level and bottom-level state and action embedding vectors, respectively (we omit boldface notation for vectors for the rest of the article as well). These state and action vectors are estimated by the recurrent activity vectors of the respective state and action networks. We use the notation $f(\theta)$ to denote a network $f$ parameterized by $\theta$, the weight matrices and biases for all of its layers. Thus, the bottom-level state and action RNNs are denoted by $f_s(\theta_s)$ and $f_a(\theta_a)$, while their activity vectors are denoted by $s_{t,\tau}$ and $a_{t,\tau}$, respectively ($t$ ranges over macrosteps, $\tau$ over microsteps).
At each macrostep $t$, the top-level state RNN $F_s$ produces a new state embedding vector $S_t$ based on the previous state and action embedding vectors. This higher-level state $S_t$ and the action $A_{t-1}$ from the previous macrostep are fed to the action/policy RNN $F_a$ (which is determined by the current task or goal) to produce the next action-embedding vector $A_t$ (a macro-action/option/subgoal). The embedding vector $A_t$ is used as input to a nonlinear function, implemented by the hypernetwork $H_a$, to dynamically generate the parameters $\theta_a(A_t)$ for the lower-level action RNN, which implements a policy to generate primitive actions suitable for achieving the subgoal associated with $A_t$.
The higher-level state $S_t$ also defines a new reference frame for the lower level to operate over as follows: $S_t$ (and any other state-relevant information) is fed as input to the state hypernetwork $H_s$ to generate the lower-level parameters $\theta_s(S_t)$ specifying a dynamically generated bottom-level state RNN that characterizes the state transition dynamics locally, for example, local parts and their transformations in vision (see application I in section 4) and navigation dynamics in a local region of a building (see application II in section 5).
Each microstep proceeds in a manner similar to a macrostep. The bottom-level action RNN $f_a$ produces the next action $a_{t,\tau}$ based on the current lower-level state and previous action (see Figure 1b, lower right). This action (e.g., a sensor/body movement or a lower-level abstract action) results in a new input being generated by the environment for the bottom (and possibly higher) state network.
To predict the input at microstep $\tau + 1$, the lower-level state vector $s_{t,\tau}$ is fed to a generic decoder network to generate the input prediction $\hat{x}_{t,\tau+1}$. This predicted input is compared to the actual input $x_{t,\tau+1}$ to generate a prediction error $\epsilon_{t,\tau+1} = x_{t,\tau+1} - \hat{x}_{t,\tau+1}$. Following the predictive coding model (Rao & Ballard, 1999), the prediction error is used to update the state vector via the state network: $s_{t,\tau+1} = f_s(s_{t,\tau}, \epsilon_{t,\tau+1})$. Additionally, for inference of the top-level state, the top-level RNN activity vector is updated using information from the lower-level state vectors, and the process continues.
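To make the bottom-level update concrete, here is a minimal sketch of a sequence of microsteps; for readability, a fixed GRU cell and a linear decoder stand in for the hypernetwork-generated state RNN and decoder, and all dimensions are placeholders:

```python
import torch
import torch.nn as nn

decoder = nn.Linear(32, 16)       # generic decoder: state vector -> predicted input
f_s = nn.GRUCell(16 + 16, 32)     # bottom-level state RNN (hypernet-generated in APC)

s = torch.zeros(1, 32)            # lower-level state vector
a = torch.zeros(1, 16)            # lower-level action vector (from the action RNN)
for tau in range(5):              # microsteps within one macrostep
    x_pred = decoder(s)           # prediction of the next input
    x = torch.randn(1, 16)        # actual input returned by the environment
    err = x - x_pred              # prediction error
    s = f_s(torch.cat([err, a], dim=-1), s)  # error-driven state update (predictive coding)
```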
3.2 Training the Active Predictive Coding Model
Since the state networks are task-agnostic and geared toward capturing the dynamics of the world, they are trained using self-supervised learning by minimizing prediction errors. For the results in this article, we used backpropagation, but other biologically plausible mechanisms for minimizing prediction errors may also be used (Rao & Ballard, 1999; Whittington & Bogacz, 2017; Lillicrap et al., 2020); such mechanisms have also been shown to emerge as a consequence of minimizing energy consumption (Ali et al., 2022). The action networks are trained to integrate the information provided by the state vectors toward a downstream task by minimizing the total expected task loss: this can be done using either reinforcement learning or planning with the help of the state networks. In application I, we illustrate the use of reinforcement learning,2 while in application II, we illustrate the use of planning, but the APC framework is flexible and allows either approach for estimating actions. Algorithm 1 summarizes the APC training process, with further details for each application provided in the sections that follow.
4 Application I: Visual Perception
A long-standing problem in vision and cognitive science (Hinton, 2021) is: How can neural networks learn intrinsic reference frames for objects and concepts and parse inputs (e.g., images) into part-whole hierarchies? Human vision provides an important clue. Unlike convolutional nets, which need to process an entire scene, human vision is an active sensory-motor process, sampling the scene via eye movements to move the high-resolution fovea to task-relevant locations, accumulating evidence for or against competing visual hypotheses (Wedel et al., 2022). The APC model is well suited to emulating the sensory-motor nature of human vision, given its integrated state and action networks.
For visual perception and part-whole learning, the actions in the APC model emulate eye movements (or attention) by moving a “glimpse sensor” (Mnih et al., 2014), which extracts high-resolution information about a small part of a larger input image. Ideally we would like our model to exhibit spatial convergence as we go up the representational hierarchy, capturing the inductive bias that an entity has a larger spatial extent than its constituent parts. The APC vision model implements this concept by using recursive object-centered reference frames. The top level of an APC architecture spans the entire image. At each step, the network chooses a subregion of the image to focus on (see Figure 2a). It then generates a lower-level image parser (comprising state-action subnetworks) and assigns this image subregion as the input to the lower level. The bottom-most level has direct access to the image via small-sized glimpses. The APC model performs a type of depth-first exploration of the representational graph, where each layer descends deeper into the graph with a new object-centered reference frame. These stacks of reference frames can be composed to derive the absolute location of any sampled glimpse within the image. Figure 2b shows an example of recursive reference frame traversal down a two-level hierarchy.
Interactions with an image (of size $N \times N$ pixels) are carried out through a glimpse sensor $G$. This sensor takes in a location $l$ and a fixed scale fraction $w$ and extracts a square glimpse or patch $G(l, w)$ centered around $l$ and of size $wN \times wN$. Since $l$ is continuous, the sensor is implemented using a subdifferentiable bilinear interpolation module as introduced in Jaderberg et al. (2015). The image dimensions are normalized so that $l \in [-1, 1]^2$; $w$ is hard-coded for each layer. Other transformations such as rotation and shear are ignored in the current version of this model, but incorporating them represents an obvious direction for future research.
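A minimal sketch of such a glimpse sensor, using PyTorch's differentiable grid sampling for the bilinear interpolation (the exact coordinate convention and sizes here are our assumptions):

```python
import torch
import torch.nn.functional as F

def glimpse(image, loc, w):
    """Extract a (wN x wN) patch centered at loc from an N x N image.

    image: tensor of shape (1, 1, N, N); loc: tensor of shape (2,) in [-1, 1];
    w: scale fraction in (0, 1]. Bilinear sampling keeps the operation
    subdifferentiable with respect to loc.
    """
    N = image.shape[-1]
    out = max(int(w * N), 1)
    # Build an out x out sampling grid centered on loc (normalized coordinates).
    lin = torch.linspace(-w, w, out)
    gy, gx = torch.meshgrid(lin, lin, indexing="ij")
    grid = torch.stack([gx + loc[0], gy + loc[1]], dim=-1).unsqueeze(0)
    return F.grid_sample(image, grid, mode="bilinear", align_corners=True)

img = torch.rand(1, 1, 28, 28)                        # e.g., an MNIST-sized image
patch = glimpse(img, torch.tensor([0.2, -0.3]), 0.5)  # patch of shape (1, 1, 14, 14)
```

Because `grid_sample` is (sub)differentiable with respect to the sampling grid, the glimpse operation does not block gradient flow.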
Note that the APC model’s use of a glimpse mechanism offers an additional degree of biological plausibility when compared to approaches that assume access to the whole image at once. Additionally, using a foveation emulator to intelligently select image locations for further processing potentially uses fewer neurons in the network and lowers energy expenditure compared to models that process all image locations such as CNNs (Krizhevsky et al., 2012) or vision transformers (ViTs; Dosovitskiy et al., 2021); this is because the latter scale quadratically with input size, while models using foveation such as APC have constant size regardless of input size.
4.1 Higher-Level Operation
The APC model we implemented for vision uses continuous feedback from the lower-level states. We use $S_{t,\tau}$ and $A_{t,\tau}$ to denote the top-level vectors at macrostep $t$ and microstep $\tau$. Although the top-level states receive continuous feedback, the lower-level parameters $\theta_s$ and $\theta_a$ are generated at the beginning of each macrostep, using only $S_{t,0}$ and $A_{t,0}$, and stay fixed throughout the macrostep. The initial lower-level state vector $s_{t,0}$ is initialized through a small neural network, using $S_{t,0}$. This state vector is then used, together with $A_{t,0}$, to initialize $a_{t,0}$.
At each macrostep $t$, the top-level action RNN updates its activity vector and generates two values via two networks: a location $L_t$ and a macro-action (or option) $M_t$ (see Figure 3). The location $L_t$ is used to restrict the bottom level to a subregion corresponding to a new frame of reference of scale $w$, centered around $L_t$ (see Figure 2a). The option $M_t$ is used as an embedding vector input to a nonlinear function, implemented by a hypernetwork $H_a$, to dynamically generate the parameters $\theta_a$ of the lower-level action RNN. For exploration during reinforcement learning, we treat the location network output $\mu_t$ of the top-level action RNN as a mean value and add gaussian noise with fixed variance to sample an actual location: $L_t = \mu_t + \epsilon_t$, where $\epsilon_t \sim \mathcal{N}(0, \sigma^2 I)$.
The top-level state vector $S_{t,0}$ and the location $L_t$ are fed as inputs to the state hypernetwork $H_s$ to generate the parameters $\theta_s$ specifying a dynamically generated bottom-level state RNN for the current frame of reference. Figure 3 illustrates this top-down generation process.
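A sketch of the top-level action RNN with its two output heads and the gaussian exploration step (the use of a GRU cell, the tanh squashing of the mean location, and all sizes are illustrative assumptions):

```python
import torch
import torch.nn as nn

class TopLevelAction(nn.Module):
    """Top-level action RNN with a location head and an option (macro-action) head."""
    def __init__(self, state_dim=64, hidden_dim=64, option_dim=32, sigma=0.1):
        super().__init__()
        self.rnn = nn.GRUCell(state_dim, hidden_dim)          # top-level action RNN
        self.loc_head = nn.Linear(hidden_dim, 2)              # location network (mean)
        self.option_head = nn.Linear(hidden_dim, option_dim)  # option embedding for the hypernet
        self.sigma = sigma

    def forward(self, S, A_prev):
        A = self.rnn(S, A_prev)                    # update the activity vector
        loc_mean = torch.tanh(self.loc_head(A))    # mean location in [-1, 1]
        loc = loc_mean + self.sigma * torch.randn_like(loc_mean)  # gaussian exploration
        option = self.option_head(A)               # embedding fed to the action hypernetwork
        return A, loc, loc_mean, option

top_action = TopLevelAction()
S = torch.zeros(1, 64)                             # top-level state vector
A = torch.zeros(1, 64)                             # previous top-level action vector
A, loc, loc_mean, option = top_action(S, A)
```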
4.2 Lower-Level Operation
At the beginning of each microstep, the higher-level state is used to initialize the bottom-level state vector via a small feedforward network. Each microstep proceeds in a manner similar to a macrostep. The bottom-level action RNN produces the action vector $a_{t,\tau}$ based on the current state and past action, and a location $l_{t,\tau}$ is chosen as a function of $a_{t,\tau}$ (see Figure 3, lower right). This results in a glimpse image of scale $w$, centered around $l_{t,\tau}$, within the image subregion specified by the higher level. Figures 2b and 3 show the frames of reference and the corresponding image subregions across the two levels and how they relate to the operation of the two levels and state-action systems.
To predict the next glimpse image at the location specified by the action network, the lower-level state vector $s_{t,\tau}$, along with the locations $L_t$ and $l_{t,\tau}$, is fed to a generic decoder network to generate the predicted glimpse $\hat{g}_{t,\tau+1}$. This predicted glimpse is compared to the actual glimpse image $g_{t,\tau+1}$ to generate a prediction error $\epsilon_{t,\tau+1} = g_{t,\tau+1} - \hat{g}_{t,\tau+1}$ (see Figure 4). Following the predictive coding model (Rao & Ballard, 1999), the prediction error is used to update the state vector via the state network: $s_{t,\tau+1} = f_s(s_{t,\tau}, \epsilon_{t,\tau+1})$. For exploration during reinforcement learning, we follow the same gaussian noise-based exploration strategy as the top level.
4.3 Training the Active Predictive Coding Network
The state and action systems are trained separately via different loss functions. The state system is trained to minimize prediction errors via backpropagation, while the action system is trained to minimize the total expected task loss via the reinforcement learning algorithm REINFORCE (Williams, 1992) together with backpropagation. In the implementation of the training procedure, whenever the state vectors at any given level are passed as input to that level’s action network, the gradients for backpropagation are cut off. The goal of the state prediction system is to predict the next state, and it is task-agnostic. The goal of the action system is to choose effective actions given past states and actions so that the task loss is minimized.
4.3.1 Training the State System
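As described in section 3.2, the state networks are trained by minimizing prediction errors via backpropagation. A minimal sketch of such a loss, assuming a mean squared error over the predicted glimpses accumulated across macro- and microsteps:

```python
import torch

def state_prediction_loss(predicted_glimpses, actual_glimpses):
    """Sum of mean squared prediction errors over all macro- and microsteps."""
    loss = torch.tensor(0.0)
    for g_hat, g in zip(predicted_glimpses, actual_glimpses):
        loss = loss + ((g - g_hat) ** 2).mean()
    return loss
```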
4.3.2 Training the Action System
The top level is trained using the cumulative reward $R_t$ over all future macrosteps, while the bottom level is trained using the cumulative reward $R_{t,\tau}$ over all future microsteps.
We use an adjusted version of the baseline-based variance-reduction technique introduced in Sutton et al. (2000) and used in Mnih et al. (2014). We learn two separate baselines, $b_t$ and $b_{t,\tau}$, and use the baseline-subtracted cumulative rewards $R_t - b_t$ and $R_{t,\tau} - b_{t,\tau}$ for training.
As mentioned earlier, to allow exploration during training with REINFORCE, the locations at each macro- or microstep were the location network’s output plus gaussian noise. Therefore, the logarithmic probability terms above reduce to the squared Euclidean distances between the mean and the sampled locations.
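A sketch of the resulting REINFORCE loss for the location outputs at one level, with a learned baseline; the gaussian log-probability reduces to a squared Euclidean distance up to constants (the function and variable names are ours):

```python
import torch

def reinforce_location_loss(loc_means, loc_samples, returns, baselines, sigma=0.1):
    """REINFORCE loss for a gaussian location policy at one level, with a baseline.

    loc_means:   mean locations output by the location network (require gradients)
    loc_samples: the actually executed locations (mean + gaussian noise)
    returns:     cumulative future reward from each step onward
    baselines:   learned baseline estimates (0-dim tensors) for variance reduction
    """
    loss = torch.tensor(0.0)
    for mu, l, R, b in zip(loc_means, loc_samples, returns, baselines):
        R = torch.as_tensor(R, dtype=mu.dtype)
        # log N(l; mu, sigma^2 I) reduces to a squared distance (up to constants)
        logp = -((l.detach() - mu) ** 2).sum() / (2 * sigma ** 2)
        loss = loss - (R - b).detach() * logp   # policy-gradient term
        loss = loss + (R - b) ** 2              # regression term that trains the baseline
    return loss
```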
4.4 Results
We first tested the APC model on the task of sequential part/location prediction and image reconstruction of objects in the following data sets3:
MNIST: Original MNIST data set of 10 classes of handwritten digits.
Fashion-MNIST (FMNIST): Instead of digits, the data set consists of 10 classes of clothing items.
Omniglot: 1623 hand-written characters from 50 alphabets, with 20 samples per character.
affNIST: MNIST digits embedded in a 40 × 40 pixel frame and transformed via random affine transformations.
4.4.1 Parsing Images and Perceptual Stability
Figure 6 shows an example of a parsing strategy learned by a two-level APC model for an MNIST digit. The top level learned to cover the input image sufficiently, while the bottom level learned to parse subparts inside the reference frame computed by the higher level. Figure 6 also suggests an explanation for why human perception can appear stable despite dramatic changes in our retinal images as our eyes move to sample a scene: the last row of the figure shows how the model maintains a visual hypothesis that is gradually refined and does not exhibit the kind of rapid changes seen in the sampled images (“Actual Glimpses” in Figure 6).
4.4.2 Prediction of Parts and Pattern Completion
To investigate the predictive and generative ability of the model, we had the model “hallucinate” different parts of an object by setting the prediction error input to the lower-level network to zero. This disconnects the model from the input, forcing it to predict the next sequence of parts and “complete” the object. Figure 9a shows that the model has learned to generate plausible predictions of parts given an initial glimpse.
4.4.3 Transfer Learning
We tested transfer learning for reconstruction of unseen character classes for the Omniglot data set. We trained a two-level APC model to reconstruct examples from a subset of the classes from each Omniglot alphabet. The rest of the classes were used to test transfer: the trained model had to generate new “programs” (via the state and action hypernets) to predict parts for new character classes for each alphabet. The model successfully performed this task (see Table 1 and Figure 9b).
Table 1: Reconstruction errors for the different models and data sets.

| Model | MNIST | FMNIST | Om-Tst | Om-Trn | affNIST |
|---|---|---|---|---|---|
| RB | 0.0097 | 0.0138 | 0.0222 | 0.0220 | 0.0116 |
| APC-1 | 0.0072 | 0.0107 | 0.0205 | 0.0207 | 0.0055 |
| APC-2 | 0.0070 | 0.0124 | 0.0191 | 0.0193 | 0.0063 |
Notes: See text for details. FMNIST, Om-Tst, and Om-Trn denote Fashion-MNIST, the Omniglot test and transfer data sets, respectively.
4.4.4 Ablation Studies
To test the utility of having two levels of abstraction, we compared the reconstruction performance of the two-level APC model (APC-2) to a one-level model (APC-1) and a randomized baseline model (RB), which samples $K$ glimpses (of the same size as those of APC-1 and APC-2) from i.i.d. random locations ($K$ is the same value as for APC-1 and APC-2, that is, 9 for MNIST/FMNIST and 12 for Omniglot), extracts an average feature vector, and feeds this to a feedforward network to reconstruct the image. As shown in Table 1, both APC models clearly outperform RB, indicating that they learn to sample glimpses in an intelligent manner. APC-2 and APC-1 have comparable performance on MNIST. APC-2 clearly outperforms APC-1 on the standard and transfer Omniglot tasks, demonstrating the advantage of dynamically generating “programs” to parse novel characters. APC-1 outperforms APC-2 and RB by a wide margin on the FMNIST task; for this data set, sampling the borders of the image is a very effective strategy. APC-1 also performs better on affNIST, a data set for which an effective strategy requires initial exploration to locate the digit. Note that APC-2 is more restricted in terms of sampled glimpse locations, since for each macrostep, these are confined within the respective top-level frame of reference. However, employing nested reference frames endows APC-2 with the power of hierarchical abstraction and compositionality, which pays off in transfer learning tasks, as seen in our Omniglot examples.
Table 2 shows the number of learnable parameters, which is comparable across models and data sets. Note that a comparison with models such as Attend, Infer, Repeat (AIR; Eslami et al., 2016) or variational autoencoders (VAEs; Kingma & Welling, 2014), also tasked with image reconstruction, would not be a fair comparison since these models have access to the entire image at once while the APC model has to learn how to intelligently sample only a subset of the input.
Table 2: Number of learnable parameters for the different models and data sets.

| Model | MNIST | FMNIST | Om-Tst | Om-Trn | affNIST |
|---|---|---|---|---|---|
| RB | 2.09M | 2.09M | 2.15M | 2.15M | 2.30M |
| APC-1 | 2.06M | 2.06M | 2.13M | 2.13M | 2.27M |
| APC-2 | 2.22M | 2.22M | 2.73M | 2.73M | 2.43M |
Notes: See text for details. FMNIST, Om-Tst, and Om-Trn denote Fashion-MNIST, the Omniglot test and transfer data sets, respectively.
5 Application II: Hierarchical Planning
We now show that the same APC model we used above for learning part-whole hierarchies can also be used for a very different problem: learning hierarchical world models for efficient planning. We introduce a new compositional, scalable “multirooms” navigation task to illustrate this.
Consider the problem of navigating from any starting location to any goal location on a given floor of a large building such as the one in Figure 10A (gray: walls; blue circle: agent; green square: current goal). In the traditional (nonhierarchical) reinforcement learning (RL) approach for solving such a problem, the states are the discrete locations in the grid, and the actions are going north (N), east (E), south (S), or west (W). A large reward (+10) is received at the goal location, with a small negative reward (-0.1) for each action to encourage shortest paths.
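A minimal sketch of one primitive step in such a grid world, with the reward structure described above (the wall encoding and function signature are our assumptions):

```python
import numpy as np

ACTIONS = {"N": (-1, 0), "E": (0, 1), "S": (1, 0), "W": (0, -1)}

def step(grid, pos, goal, action):
    """Move one cell if not blocked by a wall; return (new position, reward, done)."""
    dr, dc = ACTIONS[action]
    nr, nc = pos[0] + dr, pos[1] + dc
    if 0 <= nr < grid.shape[0] and 0 <= nc < grid.shape[1] and grid[nr, nc] == 0:
        pos = (nr, nc)              # grid[r, c] == 1 marks a wall
    if pos == goal:
        return pos, 10.0, True      # large reward at the goal location
    return pos, -0.1, False         # small penalty per action encourages short paths
```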
5.1 Problems with Traditional Reinforcement Learning
The traditional RL approach suffers from the following problems: The first is sample inefficiency. As the environment gets larger, the number of interactions with the environment required to learn the value function becomes impractically large. The second is risk of catastrophic consequences: taking actual actions in the real world to estimate the value function might have catastrophic consequences (injury or death). The third is inflexibility: for every new goal or task, a new value function needs to be learned. Hierarchical reinforcement learning and safe reinforcement learning have been proposed as ways of addressing some of these problems (Sutton et al., 1999; Botvinick et al., 2009; García & Fernández, 2015; Kulkarni et al., 2016; Nachum et al., 2019; Hafner et al., 2022). The APC model differs from previous approaches in asserting the need for abstracting not only actions (policies) but also state transition functions in a hierarchical manner, as depicted in Figure 1.
5.2 How the APC Model Solves These Problems
Just as an object (e.g., an MNIST digit) consists of the same parts (e.g., strokes) occurring at different locations, the multi-rooms environment in Figure 10A is also made up of the same two components (“Room types” R1 and R2), shown in Figure 10B, occurring at different locations (some example locations highlighted by yellow and red boxes in Figure 10C).
These components form part of the higher-level states in the APC model and are defined by state embedding vectors (say, $s_{R1}$ and $s_{R2}$), which can be trained to generate, via the hypernet $H_s$ (see Figure 1), the lower-level transition functions for room types R1 and R2, respectively. Next, similar to how the APC model was able to reconstruct an image using top-level action embedding vectors to generate policies and actions (locations) to compose parts using strokes, the APC model can compute top-level action embedding vectors (option vectors) $a_g$ for the multi-rooms world that generate, via the hypernet $H_a$ (see Figure 1), bottom-level policies that produce primitive actions (N, E, S, W) to reach a goal encoded by $a_g$ (note that we use the subscript $g$ here and in the next two sections to denote a particular goal rather than time).
5.3 Local Reference Frames Allow Policy Reuse and Transfer
Figure 10D illustrates the bottom-level policies for three such action-embedding vectors $a_1$, $a_2$, and $a_3$, which generate policies for reaching goal locations 1, 2, and 3, respectively. Note that the $a_g$ are defined with respect to the higher-level state $s_{R1}$ or $s_{R2}$ corresponding to room type R1 or R2. Defining these policies to operate within the local reference frame of the higher-level state $s_{R1}$ or $s_{R2}$ (regardless of global location in the building) confers enormous flexibility on the APC model because the same policy can be reused at multiple locations to solve local tasks (here, reaching subgoals within R1 or R2).
For example, to solve the navigation problem in Figure 10C, the APC model only needs to plan and execute a sequence of three higher-level actions or options, compared to planning a sequence of 12 lower-level actions to reach the same goal. Finally, since the embedding space of options is continuous, the APC model offers an unprecedented opportunity to exploit properties of this embedding space (such as smoothness) to interpolate or extrapolate to create and explore new options for transfer learning; this possibility will be explored in future work.
5.4 Results
For simplicity, we assume the higher-level states capture local reference frames and are defined by an embedding vector generating the transition function for “room type” R1 or R2, along with the location of this local reference frame in the global frame of the building. The lower-level action network is trained to map a higher-level action embedding vector to a lower-level policy that navigates the agent to a particular goal location within R1 or R2. For the current example, eight embedding vectors were trained using REINFORCE-based RL (Williams, 1992) to generate, via the hypernet $H_a$, eight lower-level policies to navigate to each of the four corners of room types R1 and R2. The higher-level state network was trained to predict the next higher-level state (decoded as an image of room type R1 or R2, plus its location) given the current higher-level state and higher-level action.
The trained higher-level state network was used to plan, at each step, a sequence of four higher-level actions using “random-sampling shooting” model-predictive control (MPC) (Richards, 2004): random state-action trajectories of length 4 were generated using the higher-level state network by starting from the current state and picking one of the four higher-level actions at random for each next state; the action sequence with the highest total reward was selected, and its first action was executed. Figures 11A and 11B show an example of this high-level planning and MPC process using the trained APC model. Such an approach to planning is closely related to the ideas of planning by inference (Attias, 2003; Verma & Rao, 2005, 2006; Botvinick & Toussaint, 2012; Levine, 2018) and active inference (Friston et al., 2011, 2017; Fountas et al., 2020).
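A minimal sketch of this random-sampling shooting planner (the `model`, `reward_fn`, and `actions` arguments are stand-ins for the trained higher-level state network, the task reward, and the set of higher-level actions):

```python
import random

def plan_step(state, actions, model, reward_fn, horizon=4, n_samples=100):
    """Random-sampling shooting MPC: sample action sequences, roll them out with the
    learned higher-level model, and execute the first action of the best sequence."""
    best_seq, best_return = None, float("-inf")
    for _ in range(n_samples):
        seq, s, total = [], state, 0.0
        for _ in range(horizon):
            a = random.choice(actions)      # pick one of the higher-level actions
            s = model(s, a)                 # predicted next higher-level state
            total += reward_fn(s, a)        # accumulate predicted reward
            seq.append(a)
        if total > best_return:
            best_seq, best_return = seq, total
    return best_seq[0]                      # execute only the first action (MPC)
```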
We compared the two-level APC model with both a heuristic lower-level-only planning algorithm and a REINFORCE-based RL algorithm using primitive states and actions. The task involved navigating to a randomly selected goal location in a building environment (as in Figure 10A), with the goal location changing after some number of episodes.
Figure 11C shows how the APC model, after an initial period spent learning the hypernet to generate the lower-level options, is able to cope with goal changes and successfully navigate to each new goal by sequencing high-level actions (+10 reward at the goal; -0.1 per primitive action). The RL algorithm experiences a drop in performance after a goal change and does not recover even after 500 episodes.
Figure 11D demonstrates the efficacy of APC’s higher-level planning compared to lower-level planning (MPC using random sequences of four primitive future actions; Euclidean distance heuristic): the average number of planning steps to reach the goal increases dramatically for larger distances from the goal for lower-level compared to higher-level planning.
Additional results and details regarding the application of the APC model to hierarchical planning can be found in the Supplementary Information section online at https://doi.org/10.1162/neco_a_01627.
6 Discussion and Conclusion
Our results represent a first proof-of-concept demonstration of how the APC model can offer a unified approach to modeling a diverse set of problems that have previously required very different approaches, for example, active vision and part-whole learning (Hinton, 2021), learning reference frames (Hawkins, 2021), and planning using state-action hierarchies (Hutsebaut-Buysse et al., 2022).
The APC model for active vision employs eye movements for visual sampling, and performs end-to-end learning and parsing of part-whole hierarchies from images. The incremental parsing of a scene by the APC model using state-action recurrent networks also suggests an explanation for why our visual perception remains stable despite the staccato nature of our eye movements. We showed how the same framework can also be used to solve a complex navigation task by modeling a large environment as being composed of simpler components and using hierarchical states and actions for efficient planning.
Our results relied on assumptions such as a fixed number of time steps at each level, a two-level hierarchy, hard-coded glimpse operators for active vision, and preidentified higher-level states and actions for planning. Future work will involve relaxing these assumptions, comparing the APC model to other part-whole learning approaches using different metrics, and employing more sophisticated planning and RL methods to scale the model to more complex tasks.
We expect adding more levels to the model to boost its capability by allowing it to learn even more abstract state-action representations; this capability is not limited by sensor resolution or action space dimensionality, though the benefits of adding additional levels will depend on the complexity and compositional nature of the environment and problem. Adding more levels would potentially allow the model to handle more complex environmental dynamics and solve more complex tasks via a hierarchical divide-and-conquer strategy, breaking down the dynamics into compositions of simpler transition functions and breaking down long and complex sequences of low-level actions into simpler compositions of macro-actions. For example, adding a third level to our network for the navigation problem would allow learning of third-level abstract actions composed of sequences of macro-actions; these could be used, for example, for hierarchical planning between buildings in a larger campus environment. More broadly, within the domain of reinforcement learning problems, adding a third level to the APC model could enable learning sequences of macro-actions to solve complex problems composed of multiple subproblems, each solved by a macro-action. In the image parsing problem, adding a third level could, for example, allow frames of reference of larger spatial extent, transformations of the agent’s point of view in three-dimensional space relative to the object (such that each such transformation would result in a different two-dimensional view of the object), and other higher-level abstractions of object and scene properties.
Models based on transformers (Vaswani et al., 2017) have recently demonstrated tremendous promise in various vision-related tasks (Dosovitskiy et al., 2021). Soft attention models such as transformers and hard attention models such as ours represent two approaches with significant conceptual differences: the former assumes access to the whole input at once, while the latter uses a policy to sequentially and intelligently sample only parts of the input. The plain version of transformers scales quadratically with the input size, while sequential models such as APC or RAM (Mnih et al., 2014) can be interpreted as having constant size regardless of the input size. While numerous recent approaches (Dao et al., 2022) aim at reducing the computational demands of transformers, an alternative approach (and a promising future research direction) would be to combine the two into a hybrid model. Such a model would dynamically combine information (soft attention) from a restricted subset of the input (hard attention).
Another potentially fruitful direction for future research is to formulate a fully probabilistic version of the APC model and derive methods for approximate inference of posterior probability distributions over states and actions (see Levine, 2018 and Friston et al., 2017). Such a probabilistic version of the APC model would allow uncertainty-based reasoning and action selection within a hierarchical POMDP framework.
We began this article with the goal of identifying a unifying computational framework inspired by the theoretical problems posed by various researchers in section 2. We regard APC as one such framework, capable of tackling a diverse set of problems that previously were studied in different subareas of AI and cognitive science (vision, reinforcement learning, navigation, and compositional learning). Thus, rather than comparing performance on a single benchmark or problem, we provided proof-of-concept results demonstrating the diverse computational abilities of the same framework across multiple domains.
The APC model is inspired by the growing interest in predictive coding as a model for understanding cortical computation (Rao & Ballard, 1999; Keller & Mrsic-Flogel, 2018; Jiang & Rao, 2022b). The APC model augments the traditional predictive coding framework with state-action hierarchies and shows how these can be learned using hypernetworks. More broadly, our results showing diverse applicability of the same APC framework across multiple domains lend strong computational support to the long-cherished and much-discussed hypothesis that there may be a common computational principle operating across the cortex (Creutzfeldt, 1977; Mountcastle, 1978; Hawkins, 2021; Rao, 2022).
Acknowledgments
We thank the reviewers for their constructive suggestions. We also thank Ares Fisher, Preston Jiang, and Prashant Rangarajan for helpful discussions during the course of this work. This research was supported by the National Science Foundation EFRI grant 2223495, the Defense Advanced Research Projects Agency under contract HR001120C0021, a Weill Neurohub Investigator grant, a UW Amazon Science Hub grant, and a Frameworks grant from the Templeton World Charity Foundation. The opinions expressed in this article are our own and do not necessarily reflect the views of the funders.
Notes
For simplicity, we use the classic REINFORCE algorithm (Williams, 1992) with backpropagation, but more efficient algorithms such as proximal policy optimization (PPO) may also be used.
Code for the APC model is available at https://github.com/gklezd/apc-vision.