Abstract
Recent models for natural language understanding are inclined to exploit simple patterns in datasets, commonly known as shortcuts. These shortcuts hinge on spurious correlations between labels and latent features in the training data. At inference time, shortcut-dependent models are likely to make erroneous predictions under distribution shifts, particularly when some latent features are no longer correlated with the labels. To avoid this, previous studies have trained models to eliminate the reliance on shortcuts. In this study, we explore a different direction: pessimistically aggregating the predictions of a mixture-of-experts, assuming that each expert captures relatively different latent features. The experimental results demonstrate that our post-hoc control over the experts significantly enhances the model’s robustness to distribution shifts in shortcuts. Additionally, we show that our approach has some practical advantages. We also analyze our model and provide results that support the assumption.
1 Introduction
The datasets for natural language understanding (NLU) often contain simple patterns correlated with target labels, which are unintentionally introduced by annotators’ simple heuristics, preferences, etc. (Gururangan et al., 2018; Geva et al., 2019). More fundamentally, the compositional nature of natural language inherently introduces tokens that individually correlate with target labels (Gardner et al., 2021). For example, word overlap (McCoy et al., 2019; Zhang et al., 2019) and specific vocabulary, such as negations (Gururangan et al., 2018; Schuster et al., 2019), are known to have such correlations. However, these correlations are not guaranteed to hold in general and are therefore called spurious correlations (Feder et al., 2022). Because such simple patterns are easy to exploit, recent NLU models are inclined to take advantage of them. Both this exploitation and the exploited patterns themselves are called shortcuts (Makar et al., 2022; Feder et al., 2022; Du et al., 2021; Meissner et al., 2022).
At inference time, shortcuts often result in inaccurate predictions under relevant distribution shifts. Such shifts can occur, for example, when test data are collected from annotators with different heuristics or preferences (Geva et al., 2019; McCoy et al., 2019; Zhang et al., 2019; Schuster et al., 2019). Data drawn from the same distribution as the training data are referred to as in-distribution (ID) data, while data drawn from a distribution shifted relative to the training data are referred to as out-of-distribution (OOD) data. Figure 1 shows examples of ID and OOD data.
A simple solution to this problem is to eliminate the reliance on shortcuts, which is the mainstream approach, including in recent NLU studies (Clark et al., 2019; He et al., 2019; Mahabadi et al., 2020). Typically, these methods up-weight training instances for which known shortcuts fail to predict the correct labels and down-weight the others. A practical deficiency of this approach is a performance trade-off between ID and OOD data: eliminating shortcuts, which are valid features in ID data, pulls models away from the ID data. This trade-off causes another practical problem: hyperparameters have to be searched using OOD test or validation data, which has been noted as a limitation of previous work (Clark et al., 2019; Mahabadi et al., 2020; Clark et al., 2020a; Ghaddar et al., 2021; Liu et al., 2021; Creager et al., 2021; Yu et al., 2022; Yang et al., 2023). Even when OOD validation data are used, their distribution must match that of the test data. In other words, the approach requires knowing the test-time distribution to tune hyperparameters, which is impractical when testing OOD robustness.
In this paper, we opt not to pursue training to eliminate shortcuts. Instead, we propose to aggregate the predictions of a mixture model during inference. The problem with shifts in shortcuts is that some latent features in the training data are no longer associated with the labels. We hypothesize that this OOD situation can be addressed by effectively aggregating predictions, assuming that the predictions are based on relatively different latent features. We propose a mixture model and a training strategy that encourage such modeling of latent features. At inference time, we apply theoretically grounded risk-minimization strategies through post-hoc control over the predictions in anticipation of potential shifts in shortcuts.
The experimental results demonstrate that our method significantly enhances the model’s robustness when faced with shifts in shortcuts. Moreover, our method shows two other practical benefits that address the problems of previous methods. First, the mixture weights of our model can be used to detect shifts in latent features during inference. This opens up the possibility of adaptive post-hoc control to address the performance trade-off between ID and OOD data. Second, hyperparameters can be tuned with ID data only, removing the need to tune hyperparameters with OOD data. We also analyze our mixture model and provide results supporting the assumption of modeling latent features.
2 Background
This section first gives an overview of shortcuts. Then, we describe how previous approaches have addressed shortcuts and outline our approach. Below, 𝒳 and 𝒴 denote the input instance space and the entire set of target labels, respectively.
2.1 Shortcuts in Detail
Shortcuts or spurious correlations arise when (1) some feature a related to input x is predictive of label y in the training data, (2) but this association between a and y changes under relevant distribution shifts (Makar et al., 2022; Feder et al., 2022). Often, these features are latent, that is, difficult to identify a priori. Among such latent features, shortcuts refer to those that are easy to represent; sometimes, the term refers to the exploitation of such latent features (Makar et al., 2022; Feder et al., 2022; Du et al., 2021; Meissner et al., 2022). Following Makar et al. (2022), we emphasize that ease of modeling is an important characteristic of shortcuts: it enables models to capture and depend on these latent features, which poses a serious threat when the relevant distribution shifts.
Shortcuts are pervasive in NLU datasets due to the simple heuristics, preferences, etc., of annotators (Gururangan et al., 2018; Geva et al., 2019). Shortcut-dependent models suffer severe performance degradation on datasets collected with different heuristics and preferences (Geva et al., 2019; McCoy et al., 2019; Zhang et al., 2019; Schuster et al., 2019). Moreover, a more fundamental argument holds that the compositional nature of natural language produces many simple features (e.g., words and phrases) that can robustly predict labels when the entire context is considered but are only spuriously correlated with labels when considered individually (Gardner et al., 2021; Eisenstein, 2022). Figure 1 shows an illustrative example in which simple word-overlap and length features are associated with labels in training and ID data, but the association changes drastically in OOD data.
While pervasive, note that shortcuts are only part of the distribution-shift problem. For example, shortcuts can be viewed as a special case of domain shift (Feder et al., 2022) and can arise independently from the shift in label distribution p(y) (Yang et al., 2023). Also, the problem of shortcuts is one of the consequences of underspecification, where distinct solutions can solve the problem equivalently well (D’Amour et al., 2022). Following Makar et al. (2022), we address distribution shifts exclusively in terms of shortcuts. Consequently, the OOD data we address involve shifts in the association between a and y.
2.2 Overview of Previous Approaches
To improve the OOD performance, previous studies have tried to remove the reliance on shortcuts. In the study of NLU, a widely used approach is reweighting (Clark et al., 2019; He et al., 2019; Mahabadi et al., 2020). This approach reweights instances to reduce learning on shortcut-inducing instances and increase learning on the others. The weights are computed based on how accurately shortcuts predict labels. During training, instances where shortcuts are predictive are down-weighted, and the others are up-weighted.
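To make the reweighting idea concrete, the sketch below computes instance weights from a shortcut-only ("bias") model in one common form, w = 1 − p_bias(gold label | x), which we understand to correspond to the reweighting baseline of Clark et al. (2019). The function names are hypothetical, and the cited methods differ in details such as how the weights are normalized.

```python
import numpy as np

def shortcut_weights(bias_probs, labels):
    """Weight of each instance: 1 - p_bias(gold label | x).

    bias_probs: (N, C) class probabilities from a shortcut-only model.
    labels:     (N,) integer gold labels.
    """
    return 1.0 - bias_probs[np.arange(len(labels)), labels]

def reweighted_nll(main_probs, bias_probs, labels):
    """Mean negative log-likelihood of the main model, scaled per instance."""
    w = shortcut_weights(bias_probs, labels)
    nll = -np.log(main_probs[np.arange(len(labels)), labels])
    return float(np.mean(w * nll))

# The shortcut model is right and confident on instance 0, wrong on instance 1.
bias_probs = np.array([[0.9, 0.1], [0.2, 0.8]])
labels = np.array([0, 0])
print(shortcut_weights(bias_probs, labels))  # instance 0 is down-weighted
```

Instances that the shortcut alone already predicts correctly contribute little to the loss, which is exactly what pushes the main model away from the shortcut.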
In the machine learning (ML) literature, training data is first partitioned into groups (also called environments) based on the spuriously correlated features. The training data is assumed to be a mixture of the groups divided by the features. Previous approaches avoid relying on shortcuts, that is, the group-specific spurious correlations. There are two principal approaches in the ML literature. Invariant risk minimization (IRM; Arjovsky et al., 2019) trains a classifier that is simultaneously optimal for all groups. Group distributionally robust optimization (GroupDRO; Sagawa et al., 2020) learns to minimize the worst-group risk by up-weighting the loss of the worst-case group.
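As an illustration of the worst-group objective of GroupDRO, the sketch below implements the online exponentiated-gradient update of group weights described by Sagawa et al. (2020), assuming integer group ids are given for each instance. The step size eta and the function names are illustrative assumptions, and the update of the model parameters themselves is omitted.

```python
import numpy as np

def group_losses(per_example_loss, groups, num_groups):
    """Average loss within each group (groups are integer ids 0..G-1)."""
    sums = np.bincount(groups, weights=per_example_loss, minlength=num_groups)
    counts = np.bincount(groups, minlength=num_groups)
    return sums / np.maximum(counts, 1)

def groupdro_step(q, per_example_loss, groups, eta=0.01):
    """One exponentiated-gradient update of the group weights q."""
    g_loss = group_losses(per_example_loss, groups, len(q))
    q = q * np.exp(eta * g_loss)           # up-weight groups with high loss
    q = q / q.sum()
    robust_loss = float(np.dot(q, g_loss))  # objective the model then minimizes
    return q, robust_loss

q = np.full(2, 0.5)
losses = np.array([1.0, 1.0, 3.0])
groups = np.array([0, 0, 1])
q, robust = groupdro_step(q, losses, groups, eta=1.0)
print(q)  # group 1 (the worse group) now has the larger weight
```

As eta grows, q concentrates on the group with the highest loss, recovering the worst-group risk max_g ℓ_g.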
2.3 Problems of Previous Approaches
These approaches share one common idea: training models while minimizing reliance on shortcuts to achieve robust predictions. In practice, however, deliberately eliminating features that are predictive in the training data, and hence in ID data, pulls models away from the ID data, resulting in a performance trade-off between ID and OOD data. See, for example, the aforementioned work on reweighting, IRM, and GroupDRO for empirical results. This trade-off raises the following practical problems.
Overfitting to OOD Data.
The degraded performance on ID data is a direct consequence of this trade-off (Utama et al., 2020a). Evaluating worst-case performance or performance on adversarial OOD data is essential for assessing generalization, and this study also aims to improve on these evaluations. However, such extreme distribution shifts do not always occur after model deployment, so it is desirable for practical purposes to be able to deal with ID data as well.
Hyperparameter Tuning with OOD Data.
An indirect but more fundamental problem is the need for OOD test or validation data to tune hyperparameters. This problem arises because the trade-off makes it difficult to predict performance on OOD data simply by looking at performance on ID data. Obtaining OOD test data (in the context of ML, worst-group) requires pre-identification of shortcuts and their test-time distribution. Even when using OOD validation data, its distribution needs to be the same as test data’s. In testing OOD robustness, this requirement is clearly impractical.
Initial studies used training data in which shortcuts are pre-identified, in addition to OOD test or validation data (Clark et al., 2019; Arjovsky et al., 2019; Sagawa et al., 2020, inter alia). Pre-identifying shortcuts is costly because it requires careful analysis of the given data. Seeking more practical solutions, subsequent approaches removed the need to pre-identify shortcuts in training data (Clark et al., 2020a; Liu et al., 2021; Creager et al., 2021, inter alia). However, they still need OOD test or validation data related to pre-identified shortcuts to tune hyperparameters. This requirement has been discussed as a serious limitation for practical use (Clark et al., 2019; Mahabadi et al., 2020; Clark et al., 2020a; Ghaddar et al., 2021; Liu et al., 2021; Creager et al., 2021; Yu et al., 2022). Moreover, the performance of these approaches on OOD data is considerably lower without hyperparameter tuning on OOD data (Yang et al., 2023).
2.4 Overview of Our Approach
In this study, we explore a different direction from previous approaches: we aggregate the predictions of a mixture model. Our hypothesis is that effective aggregation of predictions can address potential shifts in shortcuts, provided that the predictions are based on relatively different latent features.
In addition, this approach allows us to address the problems described in Section 2.3 as follows. (a′) We have the flexibility to handle both ID and OOD scenarios by applying or omitting the post-hoc control as needed. This adaptability sets our approach apart from existing methods, which typically fit a single model exclusively for either the ID or the OOD case. (b′) Since our method focuses on fitting ID data during training, it does not require OOD data for training or for tuning hyperparameters.
3 Methods
The proposed method consists of two parts. The first part is a training method using a mixture model. The second part is a test-time operation that aggregates the mixture model’s predictions to make robust predictions when facing distribution shifts. Figure 2 in Appendix A shows an overview of our method.
3.1 Training Phase: Mixture Model to Capture Latent Features
To model latent features, we employ a mixture model, as seen in Eq. (3). A typical implementation of the mixture model is the mixture-of-experts (MoE) (Jacobs et al., 1991). Shazeer et al. (2017) showed that MoE improves performance and efficiency in large-scale deep learning models. Following the success of these studies, we employ a variant of MoE in this study.
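For orientation, here is a minimal sketch of a generic MoE classification head: a shared encoder feature h (e.g., a [CLS] vector) feeds a softmax router π(h) over K experts and K linear-softmax experts, and the prediction is the π-weighted mixture of the expert distributions. The class name, the linear parameterization, and the random initialization are illustrative assumptions and need not match the exact architecture defined by the equations of this section.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)  # numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

class MoEHead:
    """Hypothetical MoE classifier head: p(y|x) = sum_k pi_k(x) p_k(y|x)."""

    def __init__(self, dim, num_experts, num_classes, seed=0):
        rng = np.random.default_rng(seed)
        self.W_router = rng.normal(scale=0.02, size=(dim, num_experts))
        self.W_experts = rng.normal(scale=0.02, size=(num_experts, dim, num_classes))

    def forward(self, h):
        """h: (M, dim) encoder features -> (pi, expert_probs, mixture)."""
        pi = softmax(h @ self.W_router)                       # (M, K) mixture weights
        logits = np.einsum("md,kdc->mkc", h, self.W_experts)  # (M, K, C)
        expert_probs = softmax(logits, axis=-1)               # per-expert p_k(y|x)
        mixture = np.einsum("mk,mkc->mc", pi, expert_probs)   # (M, C)
        return pi, expert_probs, mixture

head = MoEHead(dim=16, num_experts=5, num_classes=3)
h = np.random.default_rng(1).normal(size=(4, 16))
pi, expert_probs, mixture = head.forward(h)
```

Keeping π and the per-expert distributions separate (rather than returning only the mixture) is what later allows π to be replaced at inference time.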
Our Implementation of MoE.
Penalty Term for π: Different Experts for Different Latent Features.
Comparing Eqs. (2) and (4), we see that different experts are expected to capture different latent features that predict labels. However, this expectation does not hold when the mixture weights are consistently uniform or dominated by the same few experts across all training instances. In those cases, all the experts, or the same few experts, capture the latent features indistinguishably. To encourage the mixture architecture to capture a mixture of latent features, we propose a penalty term that constrains the router π. Intuitively, it encourages the router to assign different inputs to different experts, assuming that different inputs differ in their latent features to some extent.
We use this penalty term with the following modifications. First, the penalty cannot be minimized to zero when the mini-batch size M exceeds the number of experts K. During joint minimization with the classification loss C, forcing down a penalty that can never reach zero could pull the model too far from the optimal solution for C. To avoid this, we regard the instances corresponding to the top-ℓ elements in the m-th row of Π⊤Π − I as sharing with x_m the latent features captured by its expert, and we exclude those elements to allow such multi-instance assignments. Concretely, we use a function that sets the elements with the top-ℓ values in each row to 0 (Algorithm 1) and minimize the penalty computed on the resulting matrix.
Second, the penalty varies greatly with the batch size M, because the Frobenius norm takes the sum, not the mean, of the squared matrix elements. To make the search for the weighting hyperparameter robust to changes in M, we normalize the penalty into [0, 1] by dividing it by the same norm evaluated on J ∈ ℝ^{M×M}, a matrix of ones; this bounds the penalty because every element of Π⊤Π − I lies in [−1, 1].
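The two modifications can be sketched as follows. This is an illustrative reconstruction under stated assumptions, not the authors' code: we assume the router outputs form an M × K matrix P (so Π = Pᵀ and Π⊤Π = PPᵀ), that the penalty is the squared Frobenius norm after row-wise top-ℓ dropout, and that normalization divides by ‖J‖²_F = M². The function names are hypothetical.

```python
import numpy as np

def drop_topl_per_row(A, l):
    """Set the l largest entries of each row of A to 0 (in the spirit of Algorithm 1)."""
    A = A.copy()
    if l > 0:
        idx = np.argpartition(-A, l - 1, axis=1)[:, :l]  # indices of the l largest values
        np.put_along_axis(A, idx, 0.0, axis=1)
    return A

def diversity_penalty(pi, l):
    """Normalized penalty on router outputs pi of shape (M, K).

    Assumed form: ||drop_l(P P^T - I)||_F^2 / ||J||_F^2 with ||J||_F^2 = M^2,
    which lies in [0, 1] because every entry of P P^T - I is in [-1, 1].
    """
    M = pi.shape[0]
    gram = pi @ pi.T  # (M, M) pairwise overlaps pi(x_m) . pi(x_n)
    D = drop_topl_per_row(gram - np.eye(M), l)
    return float(np.sum(D ** 2) / M ** 2)

# M = 8 instances, K = 2 experts: each group of 4 instances goes to its own expert.
pi_div = np.zeros((8, 2))
pi_div[:4, 0] = 1.0
pi_div[4:, 1] = 1.0
pi_uni = np.full((8, 2), 0.5)  # router that ignores the input

print(diversity_penalty(pi_div, 4))  # 0.0
print(diversity_penalty(pi_div, 0))  # 0.375
print(diversity_penalty(pi_uni, 4))  # 0.125
```

With ℓ = 4, the maximally diverse assignment reaches zero, whereas without the dropout (ℓ = 0) the same assignment is still penalized because M > K; this is the motivation for the first modification.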
Previous studies on MoE observed that assignment was concentrated on the same few experts and proposed penalty terms to balance the assignment among experts (Shazeer et al., 2017; Lepikhin et al., 2021; Fedus et al., 2022b, a). However, these penalty terms were proposed for text generation models and cannot be applied directly to the classification model in consideration. Notably, the penalty terms encourage balanced, uniform assignments but do not encourage diverse assignments that vary among groups of instances.
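The difference between balancing and diversifying can be seen on a toy batch. The sketch below compares an importance-based balance loss in the style of Shazeer et al. (2017) (squared coefficient of variation of per-expert importance) with a Gram-matrix penalty of the form ‖PPᵀ − I‖²_F / M², here without top-ℓ dropout since M = K; both functions are our own illustrative implementations.

```python
import numpy as np

def importance_cv2(pi):
    """Balance loss in the style of Shazeer et al. (2017): squared coefficient of
    variation of per-expert importance (router weights summed over the batch)."""
    importance = pi.sum(axis=0)
    return float(importance.var() / importance.mean() ** 2)

def gram_penalty(pi):
    """||P P^T - I||_F^2 / M^2 for router outputs P of shape (M, K), no dropout."""
    M = pi.shape[0]
    return float(np.sum((pi @ pi.T - np.eye(M)) ** 2) / M ** 2)

uniform = np.full((4, 4), 0.25)  # every input spread evenly over 4 experts
diverse = np.eye(4)              # each input routed to its own expert

print(importance_cv2(uniform), importance_cv2(diverse))  # 0.0 0.0 (both balanced)
print(gram_penalty(uniform), gram_penalty(diverse))      # 0.1875 0.0
```

The balance loss cannot tell the two routers apart, while the Gram-matrix penalty prefers the one that sends different inputs to different experts.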
3.2 Inference Phase: Post-Hoc Control for Risk Minimization under Uncertainty
The problem of shortcuts emerges upon the distribution shifts where some latent features are no longer associated with labels. In this subsection, we consider controlling the mixture weights to minimize the risk under the OOD circumstances where we do not know which latent features to rely on during inference. For this control, we suppose that different experts capture those latent features with small overlaps, as encouraged in the training (see Section 3.1: penalty term). However, note that this is not a strict requirement, and moderate differences may be sufficient. The point is that not all experts depend on the same latent features. We introduce two post-hoc operations on π to ensure predictions remain robust to such shifts. The operations replace the estimated π with π* according to the theory of risk minimization under uncertainty.
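The two operations can be sketched on the per-expert distributions p_k(y|x). The defining equations are not reproduced in this text, so the snippet is hedged: uniform weighting sets π* = 1/K, and for argmin weighting we assume one plausible reading, a pessimistic (maximin) rule that scores each class by the lowest probability any expert assigns to it. The function names are hypothetical.

```python
import numpy as np

def uniform_weighting(expert_probs):
    """Replace pi(x) by pi* = (1/K, ..., 1/K): average the expert distributions.

    expert_probs: (M, K, C) array of per-expert class probabilities p_k(y|x).
    """
    return expert_probs.mean(axis=1)

def argmin_weighting(expert_probs):
    """Pessimistic scores, our assumed reading of Argmin weighting:
    for each class y, use the expert assigning y the lowest probability."""
    return expert_probs.min(axis=1)  # unnormalized scores; argmax yields the label

# Expert 0 leans heavily on class 0 (e.g., via a shortcut); expert 1 disagrees.
p = np.array([[[0.85, 0.15, 0.00],
               [0.05, 0.50, 0.45]]])
print(uniform_weighting(p).argmax(axis=-1))  # [0]
print(argmin_weighting(p).argmax(axis=-1))   # [1]
```

Neither operation uses the estimated π(x), so neither depends on the router's (possibly shortcut-driven) preference for a particular expert.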
Uniform Weighting.
Argmin Weighting.
Derivation of the Operations.
The remainder of this subsection further explains the principles behind the prediction rules introduced earlier, with a focus on risk minimization. This perspective is rooted in the classical statistical decision-making framework (see Wald, 1950; Berger, 1985).
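As a worked illustration under our own simplifying assumptions (0-1 loss, and exactly one expert k whose latent features remain valid at test time, with k unknown), the two operations correspond to two classical decision criteria. This is a sketch of the reasoning, not the paper's formal derivation:

```latex
\begin{align*}
R_k(\hat{y}) &= 1 - p_k(\hat{y} \mid x)
  && \text{risk of predicting } \hat{y} \text{ if expert } k \text{ is the valid one} \\
\hat{y}_{\mathrm{Bayes}} &= \operatorname*{arg\,min}_{\hat{y}} \frac{1}{K} \sum_{k=1}^{K} R_k(\hat{y})
  = \operatorname*{arg\,max}_{\hat{y}} \frac{1}{K} \sum_{k=1}^{K} p_k(\hat{y} \mid x)
  && \text{uniform prior over } k \ (\text{Uniform}) \\
\hat{y}_{\mathrm{minimax}} &= \operatorname*{arg\,min}_{\hat{y}} \max_{k} R_k(\hat{y})
  = \operatorname*{arg\,max}_{\hat{y}} \min_{k} p_k(\hat{y} \mid x)
  && \text{Wald's worst-case criterion } (\text{Argmin})
\end{align*}
```

The second equality in each line follows because 1 − p_k is decreasing in p_k, so minimizing the (average or worst-case) risk is the same as maximizing the (average or worst-case) expert probability.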
4 Experiments
Our goal is to achieve predictions robust to distribution shifts related to shortcuts. In this section, we test whether the proposed post-hoc control improves performance on those OOD tests and analyze the mechanism based on our assumption.
4.1 Setup
This subsection describes the experimental setup. Please refer to Appendix B for further details.
Datasets.
In accordance with previous research on shortcut mitigation, we experimented with three NLU datasets. These popular datasets are all reported to induce shortcuts, and OOD test sets were later created that cannot be classified correctly using the shortcuts. Each dataset consists of training data, validation data drawn from the same distribution as the training data (ID dev), and test data in which the correlation between some latent features and labels is changed adversarially (OOD test). Following the previous studies we compare against, we report accuracy.
MNLI (Williams et al., 2018) is a dataset for natural language inference (NLI) across multiple genres. Given a pair of premise and hypothesis sentences, the task is to classify the relationship between the two sentences into one of three labels: entailment, contradiction, or neutral. In MNLI, a shortcut arises from a spurious correlation between the word overlap of input sentences and target labels. We used its matched development set as ID dev and HANS (McCoy et al., 2019) as OOD test. QQP is a dataset for paraphrase identification, where the task is to classify whether two sentences are paraphrases. Here, too, a shortcut arises from a spurious correlation involving the word overlap of input sentences. We used its development set as ID dev and PAWS (Zhang et al., 2019) as OOD test. FEVER (Thorne et al., 2018) is a dataset for fact verification. Given a claim sentence and an evidence sentence, the task is to classify the relation of the evidence to the claim as Supports, Refutes, or Not-enough-info. Some negative phrases in the claim sentences spuriously correlate with target labels, causing a shortcut that allows classification using only the claim sentences. We used its development set as ID dev and FEVER Symmetric v1 and v2 (Schuster et al., 2019) as OOD tests.
Baseline and Principal Methods.
We used BERT (bert-base-uncased) (Devlin et al., 2019) as both the baseline and the backbone, for a fair comparison with previous studies. We used the final-layer representation at the [CLS] position as h in Eq. (7).
To compare in a practical setting where only ID data are available for training and tuning (Yang et al., 2023), we reran principal methods in that setting using their publicly available code.
Conf-reg ♠self-debias (Utama et al., 2020b) and JTT (Liu et al., 2021) rely on the heuristic that weak models are likely to exploit shortcuts. Conf-reg ♠self-debias reweights the loss according to the predictions of a weak model while balancing the weights using the predictions of a teacher model. JTT up-weights the loss of training instances that a weak model misclassified. RISK (Wu and Gui, 2022) regards shortcuts as redundant features and applies feature reduction. EIIL (Creager et al., 2021) first estimates groups of training instances that share some shortcuts and then applies IRM (see Section 2.2) using the estimated groups. BAI (Yu et al., 2022) extends EIIL to estimate multiple levels of groups and applies IRM multiple times accordingly. GroupDROlabel-group (Sagawa et al., 2020) and ReWeightCRT (Kang et al., 2020) are reported to perform well on OOD data when the label distribution p(y) is imbalanced in ID data but uniform in OOD data (Yang et al., 2023), although they are not aimed at addressing shifts in shortcuts. GroupDROlabel-group minimizes the loss on the worst-case class label, with groups defined only by class labels, and ReWeightCRT reweights the loss according to the relative frequency of class labels.
Hyperparameters.
As Eq. (11) shows, the optimal model for the proposed method is one that can accurately classify x and output diverse π(x). Therefore, we define the optimal hyperparameters for the proposed method as those that minimize the sum of the two losses (C + R) on ID dev.
In the proposed method, the number of experts K in Eq. (4), the number of row-wise dropouts ℓ in Eq. (10), and the loss-weighting value λ in Eq. (11) are model-specific hyperparameters. We explored K ∈ {5, 10, 15} and λ ∈ {0.0, 0.5, 1.0}. For an efficient search, we conducted a two-stage search: in the first stage, we fixed λ = 0 and determined the K* that naturally fit the data; in the second stage, we searched for the optimal loss balance λ under K*. Table 1 shows the results of the hyperparameter search. Across settings, ℓ was set to the smallest power of two, 2^n, that satisfies the required condition on M and K. This ensures that R in each mini-batch of size M can be zero in all settings when π is maximally diverse, that is, when a different expert is allocated to every ℓ instances with probability one. We processed two mini-batches of M = 32 each in parallel and, given the selected K*, set ℓ = 8 to satisfy the condition. We trained for 10 epochs with a learning rate of 2e-5 on all datasets and selected the best epoch by ID dev score without applying post-hoc control.
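The two-stage search can be sketched as follows, assuming a hypothetical callback evaluate(K, lam) that trains the mixture model (with ℓ fixed as described) and returns its unweighted ID-dev loss C + R. Caching, seeds, and epoch selection are omitted for brevity.

```python
def two_stage_search(evaluate, ks=(5, 10, 15), lambdas=(0.0, 0.5, 1.0)):
    """Two-stage hyperparameter search on ID dev only (illustrative sketch).

    evaluate(K, lam) is a hypothetical callback returning the ID-dev loss C + R.
    """
    # Stage 1: fix lambda = 0 and choose the number of experts K*.
    k_star = min(ks, key=lambda k: evaluate(k, 0.0))
    # Stage 2: keep K* fixed and choose the loss weight lambda*.
    lam_star = min(lambdas, key=lambda lam: evaluate(k_star, lam))
    return k_star, lam_star

def toy_evaluate(k, lam):
    """Stand-in for training plus ID-dev evaluation, for demonstration only."""
    return abs(k - 10) + abs(lam - 0.5)

print(two_stage_search(toy_evaluate))  # (10, 0.5)
```

The search costs |K| + |λ| runs instead of |K| × |λ|, and at no point does it touch OOD data.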
When rerunning the comparison methods, we set all hyperparameters to the values specified in the papers or the official implementation, except for an annealing hyperparameter α of Conf-reg ♠self-debias, as it was tuned on OOD tests. We took the best epoch on ID dev for all the methods and the best α on ID dev for Conf-reg ♠self-debias.
4.2 Results
As the main results, we demonstrate that our post-hoc control over the experts achieves robust predictions on OOD test data. Table 2 shows the results in the setting where no shortcut is pre-identified. BERT is the baseline, + MoS is our mixture model, and Uniform / Argmin apply the post-hoc control to the mixture model. Since scores on the OOD tests have been reported to have high variance, all results are reported as the mean and standard deviation over five runs with different seeds, in accordance with previous studies. We observe that on all datasets, our post-hoc control significantly improves performance on the OOD tests over both the baseline and MoS.
The comparison methods do not improve performance on the OOD tests much when tuned solely with ID data, which is consistent with the observation of Yang et al. (2023). As an exception, GroupDROlabel-group and ReWeightCRT perform well on FEVER, where their difference from our method is marginal given the standard deviation. This is because the label distribution of FEVER shifts in the way these methods assume, which is also consistent with the observation of Yang et al. (2023). However, they do not improve OOD performance on MNLI and QQP, which have no such label distribution shift. In contrast, our method does not rely on assumptions about label distribution shifts, yet it consistently improves OOD performance across all the datasets.
4.3 Analyses
Now, we turn to the mechanism behind our method’s robust performance and analyze the mixture model based on our assumption.
Analysis 1: Penalty Term R in ID and OOD Data.
We first analyze the penalty term R, the essential statistic of our mixture model. Recall that R encourages the router to assign different inputs to different experts, assuming that different inputs differ to some extent in their latent features. In other words, R measures the router's sensitivity to differences in inputs. Accordingly, we expect the value of R to differ under shifts related to latent features, that is, the shifts between ID and OOD data that we address.
Table 3 shows the value of R on the ID and OOD datasets. In all the datasets, the values of R differ significantly between ID and OOD data, indicating that R is sensitive to the distribution shifts in these data.
From a practical perspective, this sensitivity may provide an advantage. We can compute R during inference because its computation requires no annotated labels, either in training or in inference. Therefore, during inference, we can decide which data to apply the post-hoc control to by checking how far R deviates from its value on ID data. While the post-hoc control decreases the ID dev scores, MoS performs on par with the baseline on ID dev despite being trained with the penalty term (Table 2). Thus, applying the post-hoc control adaptively enables handling both ID and OOD data. This adaptive use is an advantage over previous methods, which obtain only a single model fitted to either OOD or ID data. Note, however, that this is limited to cases involving a major shift in the distribution of latent features. Since we do not know precisely how large a difference should be regarded as a threatening shift, the decision may be difficult for data such as FEVER, where the difference is significant but relatively small.
Interestingly, FEVER differs from the others in how R changes between ID and OOD data. While the others have lower R on ID data and higher R on OOD data, the opposite holds for FEVER. This suggests that the mixture model does not model latent features well on FEVER; indeed, the performance improvement from the post-hoc control is relatively small on FEVER. Shortcuts in FEVER depend on very local patterns: particular phrases contained only in claim sentences. Our method uses the highly abstracted final-layer features of BERT and may not be good at isolating the effects of such local patterns. Since the features h in Eq. (7) can be chosen freely, we leave more effective encoding methods for future work.
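A minimal sketch of this adaptive use, assuming R has already been computed on a batch of unlabeled test inputs and on ID dev. The tolerance tol is a hypothetical knob (as noted above, no principled threshold is known), and uniform weighting stands in for either post-hoc operation; the function names are illustrative.

```python
import numpy as np

def choose_rule(r_test, r_id, tol):
    """'posthoc' if the test-time penalty R deviates from its ID value by more
    than a (hypothetical) tolerance; otherwise keep the estimated mixture."""
    return "posthoc" if abs(r_test - r_id) > tol else "mixture"

def adaptive_predict(pi, expert_probs, r_test, r_id, tol=0.05):
    """pi: (M, K) router weights; expert_probs: (M, K, C) per-expert probabilities.
    R needs no labels, so r_test can be computed on the unlabeled test batch."""
    if choose_rule(r_test, r_id, tol) == "posthoc":
        scores = expert_probs.mean(axis=1)                  # uniform weighting
    else:
        scores = np.einsum("mk,mkc->mc", pi, expert_probs)  # estimated pi(x)
    return scores.argmax(axis=-1)

pi = np.array([[1.0, 0.0]])                # router trusts expert 0 only
ep = np.array([[[0.6, 0.4], [0.0, 1.0]]])
print(adaptive_predict(pi, ep, r_test=0.10, r_id=0.10))  # [0]: no shift detected
print(adaptive_predict(pi, ep, r_test=0.50, r_id=0.10))  # [1]: shift, post-hoc control
```

Because the decision is made per batch, ID-like batches keep the full ID accuracy of the mixture model while shifted batches receive the robust prediction rule.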
Analysis 2: Captured Latent Features and their Interpretability.
Our post-hoc control supposes that different experts capture different latent features to some extent. The robust performance of our post-hoc control supports this assumption, but only indirectly. Toward direct validation, we analyze which experts a particular feature is assigned to.
To this end, we use data splits in which a specific feature known to be a shortcut is dominant. For MNLI and QQP, high word overlap is a dominant feature in HANS and PAWS, although their creation processes differ. The sentence pairs in HANS were created by replacing or partially deleting parts of premise sentences, while those in PAWS were created by word swapping, back translation, and human post-processing. FEVER, however, has no such split in which a single feature is dominant. To obtain such splits for FEVER, we extracted instances containing frequent bigrams that are reported to strongly correlate with labels: “at least”, “person who”, and “united states” strongly correlate with the Supports label, and “did not”, “does not”, and “to be” with the Refutes label (Schuster et al., 2019). We created six splits from FEVER ID dev, one per bigram. All the features above are known to be shortcuts.
Figure 2 shows the mixture weights averaged over each split of the datasets. Note that no post-hoc control is performed on the mixture weights here. Overall, a single expert or a few experts dominate the mixture weights in the splits where specific features are dominant: HANS, PAWS, and the newly created FEVER splits. We also observe that the dominant experts differ among the FEVER splits. Although the features under analysis are limited, this behavior of the mixture weights aligns with our assumption that different experts capture different latent features. As exceptions, no expert dominates the “united states” split, and the same expert dominates the “did not”, “does not”, and “to be” splits. However, note that the assumption does not need to hold completely (see Section 3.2), and such local bigram features may be exceptionally difficult to capture with the current encoding (see Analysis 1). Taking these points into account, we consider the results as a whole to support the assumption.
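The FEVER splits can be approximated with a simple filter. The sketch below assumes case-insensitive whole-word matching on claim sentences, which is our assumption rather than a documented detail; the bigram lists are those given above, and the function name is hypothetical.

```python
import re

# Bigrams reported by Schuster et al. (2019) to correlate with FEVER labels.
SUPPORTS_BIGRAMS = ["at least", "person who", "united states"]
REFUTES_BIGRAMS = ["did not", "does not", "to be"]

def bigram_splits(claims, bigrams=SUPPORTS_BIGRAMS + REFUTES_BIGRAMS):
    """Map each bigram to the indices of claims containing it
    (case-insensitive, whole-word match; the matching rule is an assumption)."""
    splits = {}
    for bg in bigrams:
        pattern = re.compile(r"\b" + re.escape(bg) + r"\b", re.IGNORECASE)
        splits[bg] = [i for i, claim in enumerate(claims) if pattern.search(claim)]
    return splits

claims = ["He did not win the award.", "The United States is large.", "Nothing here."]
print(bigram_splits(claims)["did not"])  # [0]
```

An instance whose claim contains several bigrams lands in several splits; whether the original splits were made disjoint is not stated.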
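The averaging behind Figure 2 amounts to a per-split mean of router outputs. The sketch below assumes that mixture weights are stored as an N × K array and splits as index lists; the dominance threshold is a hypothetical cut-off for flagging experts, not a criterion used in the text (lower it to surface a few co-dominant experts).

```python
import numpy as np

def mean_weights_per_split(pi, splits):
    """Average router weights pi (N, K) over each split (dict: name -> indices)."""
    return {name: pi[idx].mean(axis=0) for name, idx in splits.items() if len(idx) > 0}

def dominant_experts(mean_pi, threshold=0.5):
    """Experts whose average weight exceeds a hypothetical threshold."""
    return {name: np.flatnonzero(w > threshold).tolist() for name, w in mean_pi.items()}

pi = np.array([[0.9, 0.1], [0.8, 0.2], [0.1, 0.9]])
splits = {"a": [0, 1], "b": [2]}
mean_pi = mean_weights_per_split(pi, splits)
print(dominant_experts(mean_pi))  # {'a': [0], 'b': [1]}
```

Empty splits are skipped so that no mean over zero instances is taken.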
Figure 2 also suggests that the mixture weights provide some degree of interpretability: Instances assigned to the same expert are likely to have the same features in common. Although our mixture model does not specify captured features by itself, the suggested interpretability may allow us to discover new prominent features in data by analyzing the commonalities of the instances assigned to the same expert. The discovery of commonalities in each expert will serve to support the assumption further. We leave this direction as our future work.
Analysis 3: Ablation Study on Mixture Model.
We analyzed how the hyperparameters of our mixture model contribute to performance: the number of experts K, the number of row-wise dropouts ℓ, and the loss-weighting value λ. Table 4 shows the ablation study on MNLI, where one hyperparameter at a time is varied from the values determined to be optimal on ID dev. Performance on ID dev differs little, if at all, across values, but the OOD performance with post-hoc control is best for nearly all of the values determined to be best on ID dev. It is also worth noting that using our R and top-ℓ dropout consistently yields better OOD performance than omitting them (when λ or ℓ is zero). These results indicate the effectiveness of the proposed training strategy and hyperparameter search.
We also tested DeBERTav3-large (He et al., 2023) as the encoder gϕ. It has roughly three times as many parameters as the BERT model we used and outperforms BERTlarge, RoBERTalarge (Liu et al., 2019), XLNetlarge (Yang et al., 2019), ELECTRAlarge (Clark et al., 2020b), etc., on MNLI (He et al., 2023). We conducted the same hyperparameter search for DeBERTav3-large and found that the best hyperparameters were exactly the same as for BERT. The results show that the larger model significantly improves not only ID performance but also OOD performance. However, a gap remains between ID and OOD performance, and applying the proposed method further improves OOD performance. These results indicate that our method is effective in improving OOD performance even for large models.
Analysis 4: Identifiability of Finite Mixture.
The empirical results clearly demonstrate the effectiveness of our approach. Nevertheless, another limitation of this work is that we do not provide a theoretical guarantee that our mixture model captures latent features within the data. This issue has been studied in the statistical literature as the identification problem of finite mixtures; see Huang and Yao (2012), Compiani and Kitamura (2016), and Xiang et al. (2019) for recent developments in finite mixture models. As explained by Compiani and Kitamura (2016), among others, a finite mixture model is identified when predictors have a distinct influence on both the outcome prediction and the mixture weights. Consistent with this, our penalty term R is designed to ensure that the experts and the router play distinct roles in determining the conditional outcome probabilities and the mixture weights, respectively. This design allows our model to capture the significant variations found within the data. Our empirical Analyses 1 and 3 indicate that the penalty term R is indeed an important source of identification for the mixture weights. Since our main focus is the empirical performance of our approach in NLU applications, we leave the theoretical analysis of identification for future work.
5 Related Work
As discussed in Section 2.1, datasets for NLU tasks are known to contain multiple shortcuts arising from annotators' simple heuristics, preferences, etc. (Gururangan et al., 2018; Geva et al., 2019), or, more fundamentally, from the compositional nature of natural language (Gardner et al., 2021). Many studies have addressed the problem of shortcuts in NLU; they differ primarily in how much prior knowledge of shortcuts they assume.
Known Shortcut Setting.
This setting allows models to know the existence and details of shortcuts in advance. Previous studies used this prior knowledge to mitigate the identified shortcuts.
Reweighting is the basic strategy of these methods. They used shortcut-dependent models that take only shortcut features, e.g., word overlap, as input (Clark et al., 2019; He et al., 2019; Mahabadi et al., 2020). These shortcut-dependent models tell the main model which training instances cannot be predicted correctly via shortcuts and thus should be up-weighted. Xiong et al. (2021) showed that the performance of these methods was further enhanced by calibrating the uncertainty of the shortcut-dependent models. Utama et al. (2020a) additionally employed a teacher model to adjust the weights so that the main model would not deviate too much from the distribution of the training data. Belinkov et al. (2019) and Stacey et al. (2020) trained a main model adversarially against a shortcut-dependent classifier.
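The reweighting idea can be sketched as follows. This minimal illustration assumes that the shortcut-only model outputs a probability for each label and that each training instance is weighted by one minus the shortcut-only model's probability of the gold label, in the spirit of the example reweighting of Clark et al. (2019). The function name and the normalization by the total weight are our assumptions, not details taken from the cited papers.

```python
import math


def reweighted_loss(main_probs, shortcut_probs, gold):
    """Weighted cross-entropy for the main model.

    Each example gets weight 1 - p_shortcut(gold): instances the shortcut-only
    model already predicts confidently are down-weighted, while instances it
    cannot predict correctly are (relatively) up-weighted.
    """
    weights = [1.0 - ps[y] for ps, y in zip(shortcut_probs, gold)]
    losses = [-math.log(pm[y]) for pm, y in zip(main_probs, gold)]
    # Normalizing by the total weight is an assumption of this sketch.
    return sum(w * loss for w, loss in zip(weights, losses)) / sum(weights)


main_probs = [[0.9, 0.1], [0.6, 0.4]]
shortcut_probs = [[0.95, 0.05], [0.2, 0.8]]  # the shortcut fails on example 2
print(reweighted_loss(main_probs, shortcut_probs, gold=[0, 0]))
```

Here the second instance, which the shortcut-only model misclassifies, contributes sixteen times more weight than the first.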
Izmailov et al. (2022) and Kirichenko et al. (2023) first trained a model on ID data and then re-trained its last classification layer on a small amount of OOD data, showing that this small parameter update for the ID-fitted model is enough to improve OOD performance.
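A minimal sketch of last-layer re-training follows. The encoder is frozen, its features for a small held-out set are precomputed, and only a fresh linear classification head is fit on them. Plain binary logistic regression trained by full-batch gradient descent stands in for the classifier here. `retrain_last_layer`, `predict`, the learning rate, and the number of epochs are hypothetical choices, and the cited works' exact procedures (e.g., regularization or data balancing) may differ.

```python
import math


def retrain_last_layer(features, labels, lr=0.5, epochs=200):
    """Fit a new binary logistic-regression head on frozen features.

    features: feature vectors produced by the frozen, ID-trained encoder.
    labels: 0/1 labels of the small held-out set used for re-training.
    """
    dim, n = len(features[0]), len(features)
    w, b = [0.0] * dim, 0.0
    for _ in range(epochs):
        grad_w, grad_b = [0.0] * dim, 0.0
        for x, y in zip(features, labels):
            z = sum(wi * xi for wi, xi in zip(w, x)) + b
            err = 1.0 / (1.0 + math.exp(-z)) - y  # d(log loss)/d(logit)
            for i in range(dim):
                grad_w[i] += err * x[i]
            grad_b += err
        # Only the head's parameters are updated; the encoder is untouched.
        w = [wi - lr * gi / n for wi, gi in zip(w, grad_w)]
        b -= lr * grad_b / n
    return w, b


def predict(w, b, x):
    return int(sum(wi * xi for wi, xi in zip(w, x)) + b > 0)


feats = [[1.0, 0.0], [0.0, 1.0], [1.0, 0.1], [0.1, 1.0]]
labels = [1, 0, 1, 0]
w, b = retrain_last_layer(feats, labels)
print(predict(w, b, [1.0, 0.0]), predict(w, b, [0.0, 1.0]))  # 1 0
```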
Several approaches used the counterfactual framework of causal inference. To make counterfactual predictions unaffected by shortcuts, Tian et al. (2022) and Niu et al. (2021) combined the predictions of a main model and a shortcut-dependent model. Wang and Culotta (2020) classified features as genuine or spurious and used only the genuine features for prediction. Others used identified spurious features to train models whose predictions are invariant to interventions on those features (Veitch et al., 2021; Makar et al., 2022; Puli et al., 2022b).
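One simple way to realize this kind of combination is to remove the shortcut's contribution at inference time by subtracting the shortcut-only model's log-probabilities from the main model's. The sketch below is a simplified illustration under that assumption, not the exact formulation of Tian et al. (2022) or Niu et al. (2021). `counterfactual_predict` and the scaling factor `alpha` are hypothetical.

```python
import math


def counterfactual_predict(main_logprobs, shortcut_logprobs, alpha=1.0):
    """Subtract the (scaled) shortcut-only log-probabilities from the main
    model's log-probabilities and return the arg max label index."""
    scores = [m - alpha * s for m, s in zip(main_logprobs, shortcut_logprobs)]
    return max(range(len(scores)), key=scores.__getitem__)


main = [math.log(0.6), math.log(0.4)]      # main model leans towards label 0
shortcut = [math.log(0.9), math.log(0.1)]  # because the shortcut strongly suggests it
print(counterfactual_predict(main, shortcut))             # 1
print(counterfactual_predict(main, shortcut, alpha=0.0))  # 0 (no correction)
```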
The above methods effectively mitigate shortcuts but require costly, careful analysis to obtain prior knowledge of the shortcuts.
Unknown Shortcut Setting.
The existence and details of shortcuts are generally unknown. Another line of studies has sought a way to mitigate shortcuts without the cost of manual identification.
The basic strategy is reweighting, just as in the known shortcut setting. To estimate the weights, previous methods relied on the heuristic that weak models are likely to exploit shortcuts. The weak models include models with limited capacity (Clark et al., 2020a; Sanh et al., 2021) and models trained on a limited amount of data (Utama et al., 2020b), for a single epoch (Du et al., 2021), or with only shallow layers (Ghaddar et al., 2021; Wang et al., 2022). While these methods used continuous weights, Yaghoobzadeh et al. (2021) used binary weights that take the value of 1 only for training instances that a weak model misclassified. Other approaches applied reweighting when learning to prune fully trained models (Meissner et al., 2022; Du et al., 2023; Liu et al., 2022).
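The difference between continuous and binary weights can be illustrated with a weak model's outputs. In this sketch, we assume that the continuous weight of an instance is one minus the weak model's probability of the gold label and that the binary weight is 1 only where the weak model's arg max prediction is wrong. The function names and data are hypothetical, and individual methods differ in how they compute and apply the weights.

```python
def continuous_weights(weak_probs, gold):
    """Continuous weights: 1 - p_weak(gold); instances easy for the weak model get small weights."""
    return [1.0 - p[y] for p, y in zip(weak_probs, gold)]


def binary_weights(weak_probs, gold):
    """Binary weights: 1 only where the weak model's arg max is wrong, else 0."""
    weights = []
    for p, y in zip(weak_probs, gold):
        pred = max(range(len(p)), key=p.__getitem__)
        weights.append(1.0 if pred != y else 0.0)
    return weights


weak_probs = [[0.9, 0.1], [0.3, 0.7], [0.55, 0.45]]
gold = [0, 0, 1]
print(continuous_weights(weak_probs, gold))  # approximately [0.1, 0.7, 0.55]
print(binary_weights(weak_probs, gold))      # [0.0, 1.0, 1.0]
```

The third instance shows the contrast: the weak model misclassifies it only narrowly, so it receives a moderate continuous weight but the full binary weight.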
Some studies addressed shortcuts in the feature space by removing redundancy (Wu and Gui, 2022) or correlations (Dou et al., 2022; Gao et al., 2022) in that space. These studies reported their best scores on the ID validation data and on the OOD test data at different epochs, so their results are not directly comparable to those of other studies.
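As a generic illustration of removing correlations in the feature space, one can penalize the squared off-diagonal Pearson correlations between feature dimensions over a batch. This is our own simplified sketch, not the objective of Dou et al. (2022) or Gao et al. (2022). `decorrelation_penalty` is a hypothetical name, and the small constant in the denominator only guards against constant dimensions.

```python
import math


def decorrelation_penalty(batch):
    """Sum of squared off-diagonal Pearson correlations between feature dimensions.

    batch: list of feature vectors (rows = examples, columns = dimensions).
    """
    n, d = len(batch), len(batch[0])
    cols = [[row[j] for row in batch] for j in range(d)]
    means = [sum(c) / n for c in cols]
    stds = [math.sqrt(sum((v - m) ** 2 for v in c) / n) for c, m in zip(cols, means)]
    penalty = 0.0
    for i in range(d):
        for j in range(i + 1, d):
            cov = sum((a - means[i]) * (b - means[j])
                      for a, b in zip(cols[i], cols[j])) / n
            corr = cov / (stds[i] * stds[j] + 1e-12)
            penalty += 2 * corr ** 2  # count both the (i, j) and (j, i) entries
    return penalty


correlated = [[1.0, 2.0], [2.0, 4.0], [3.0, 6.0]]
uncorrelated = [[1.0, 1.0], [1.0, -1.0], [-1.0, 1.0], [-1.0, -1.0]]
print(decorrelation_penalty(correlated))    # close to 2.0
print(decorrelation_penalty(uncorrelated))  # 0.0
```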
The above methods still require OOD test data related to pre-identified shortcuts to tune hyperparameters, as described in Section 2.3. Our method differs from them in that both training and tuning can be conducted solely on ID data. In Section 4.2, we demonstrated that in the setting of fully unknown shortcuts, where only ID data are available, our method improves OOD performance significantly more than the previous methods.
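Tuning solely on ID data amounts to choosing every checkpoint and hyperparameter by its score on the ID dev set, never consulting OOD data. The sketch below shows checkpoint selection under this protocol. `select_checkpoint` and the scores are hypothetical, and the same rule can be applied to hyperparameter configurations.

```python
def select_checkpoint(id_dev_scores):
    """Pick the epoch with the highest ID dev score (earliest epoch on ties).

    id_dev_scores maps epoch -> score on the in-distribution dev set;
    no OOD data is used at any point.
    """
    return max(sorted(id_dev_scores), key=lambda epoch: id_dev_scores[epoch])


scores = {1: 0.80, 2: 0.84, 3: 0.84, 4: 0.83}
print(select_checkpoint(scores))  # 2
```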
Additional Data.
Other studies use additional data to mitigate shortcuts. Counterfactual data augmentation is one such approach: counterfactual data were generated through manual annotation (Kaushik et al., 2021), from known shortcuts (Wu et al., 2022), or with large language models (Wen et al., 2022; Chen et al., 2023). Other studies used human explanations (Stacey et al., 2022a, b) or human gaze signals (Ren and Xiong, 2023) as additional supervision to guide models during training. Although effective, collecting such external data is costly, and using it requires additional training.
Literature Outside of NLU.
Outside of NLU, shortcuts have been addressed in the ML literature as one of the broader OOD problems (Krueger et al., 2021; Yang et al., 2023). Still, many methods used in ML and NLU share the same underlying concepts, such as reweighting (Nam et al., 2020; Liu et al., 2021; Clark et al., 2020a; Utama et al., 2020b), IRM (Creager et al., 2021; Yu et al., 2022), counterfactual invariance (Veitch et al., 2021; Makar et al., 2022; Puli et al., 2022b), and data augmentation (Yao et al., 2022; Puli et al., 2022a; Wu et al., 2022). As described in Section 2.2, IRM (Arjovsky et al., 2019) and GroupDRO (Sagawa et al., 2020) are the two principal approaches. These approaches considered the known shortcut setting, and, as in the NLU literature, their follow-up approaches have sought to address shortcuts in the unknown shortcut setting (Nam et al., 2020; Liu et al., 2021; Creager et al., 2021; Yao et al., 2022; Puli et al., 2022a; Izmailov et al., 2022; Kirichenko et al., 2023). However, again as in the NLU literature, those follow-up approaches still require shortcuts to be pre-identified in validation sets (Yang et al., 2023).
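For reference, GroupDRO (Sagawa et al., 2020) minimizes the worst-group risk, and its online algorithm maintains a weight for each group that grows exponentially with the group's loss. The sketch below shows that weight update with fixed, hypothetical group losses. In actual training, the group losses are recomputed on each minibatch as the model changes, and the step size `eta` is a hyperparameter. The update needs a group label for every instance, which is why GroupDRO belongs to the known shortcut setting.

```python
import math


def group_dro_step(q, group_losses, eta=0.1):
    """One online GroupDRO update: q_g <- q_g * exp(eta * loss_g), then renormalize.

    Returns the new group weights and the reweighted loss sum_g q_g * loss_g,
    which the model would then minimize.
    """
    q = [qg * math.exp(eta * lg) for qg, lg in zip(q, group_losses)]
    total = sum(q)
    q = [qg / total for qg in q]
    robust_loss = sum(qg * lg for qg, lg in zip(q, group_losses))
    return q, robust_loss


q = [0.5, 0.5]             # start from uniform group weights
group_losses = [0.2, 1.0]  # hypothetical: group 1 is the hard group
for _ in range(50):
    q, loss = group_dro_step(q, group_losses)
print([round(v, 3) for v in q])  # [0.018, 0.982]
```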
6 Conclusion
This study proposed a conceptually novel approach to the shortcut problem: pessimistically aggregating the mixture model's predictions at inference time. We introduced an MoE-based model, a penalty term that encourages different experts to capture different latent features, and post-hoc control over the mixture weights that is theoretically grounded in risk minimization. The experimental results show that our method not only significantly enhances the model's robustness to shifts in shortcuts but also addresses two problems of previous methods: the performance trade-off between ID and OOD data, and the need for OOD test or validation data to tune hyperparameters.
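The intuition behind pessimistic aggregation can be conveyed with a simplified sketch. For a fixed label y, the mixture probability sum_k pi_k p_k(y | x) is linear in the mixture weights pi, so its minimum over the probability simplex is min_k p_k(y | x), attained by putting all weight on the least optimistic expert. The sketch takes this worst case separately for each label and renormalizes. This per-label treatment, the renormalization, and the function names are our assumptions for illustration; Section 3.2 gives the actual definition of our Argmin weighting.

```python
def pessimistic_aggregate(expert_probs):
    """expert_probs[k][y] is expert k's probability for label y.

    Each label keeps its smallest expert probability (the worst case over
    mixture weights), and the result is renormalized into a distribution.
    """
    num_labels = len(expert_probs[0])
    worst = [min(p[y] for p in expert_probs) for y in range(num_labels)]
    total = sum(worst)
    if total == 0.0:
        # Every label is ruled out by some expert; fall back to uniform.
        return [1.0 / num_labels] * num_labels
    return [w / total for w in worst]


def argmax(xs):
    return max(range(len(xs)), key=xs.__getitem__)


experts = [
    [0.70, 0.25, 0.05],  # one expert is very confident in label 0
    [0.20, 0.50, 0.30],  # another expert considers label 0 unlikely
]
average = [sum(col) / len(experts) for col in zip(*experts)]
print(argmax(average))                         # 0
print(argmax(pessimistic_aggregate(experts)))  # 1
```

Uniform averaging lets a single confident expert dominate, whereas the pessimistic rule favors the label that no expert strongly rejects.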
Our analyses provided results supporting the assumption that different experts capture different latent features to some extent. However, we also noted limitations in the encoding method (Analysis 1), the tested features and interpretability (Analysis 2), and the theoretical guarantee of identifiability (Analysis 4). Future work includes improving the encoding method to capture latent features more accurately, analyzing the instances assigned to the same expert to interpret what it captures and further support the assumption, and theoretically explaining how the penalty term enhances identifiability. While this study focuses on shortcuts, another future direction is extending our method to a broader range of OOD problems (see Section 2.1). We believe these are promising directions for future research building on this study.
Acknowledgments
We thank Jacob Eisenstein, who served as our TACL action editor, and the anonymous reviewers for their insightful comments.
Notes
The code is available at https://github.com/CyberAgentAILab/posthoc-control-moe.
Shortcuts are also called dataset bias. However, we avoid this term because it can be confused with social bias or the bias of an estimator. Similarly, we do not use the term debiasing in this paper.
We assume that it is reasonable to consider a finite number of latent features. Generally speaking, model predictions are likely to depend strongly on simple latent features associated with the labels rather than evenly on infinitely many possible latent features. In addition, it is empirically established that finite mixture models can approximate a wide variety of distributions, as long as a sufficient number of mixture components are included (Titterington et al., 1985; Walker and Ben-Akiva, 2011; Nguyen et al., 2020).
While this adaptive use of post-hoc control necessitates determining whether the test data falls under the ID or OOD category, the results presented in Table 3 suggest that changes in the mixture weights can effectively discern this distinction at inference time. We revisit this point in Section 4.3.
Conf-reg ♠self-debias reported the model from the last of an arbitrarily chosen number of epochs rather than from the epoch with the best ID dev score. We also reported the performance of the last-epoch models (Conf-reg ♠self-debias (last)), but we found that this practice did not work well when its annealing hyperparameter α (see Section 4.1) was tuned solely with ID data.
The label distribution of FEVER is approximately Supports:Refutes:Not-enough-info = 2:1:2 in the training data but 1:1:0 in both the ID dev and OOD test sets. Thus, simply assuming a flat label distribution over Supports and Refutes improves OOD performance even without addressing shortcuts.
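The effect described in this note can be checked with the standard Bayes-rule correction for a change in label priors, which rescales each predicted probability by the ratio of the test prior to the training prior and renormalizes. This textbook adjustment is shown only for illustration and is not part of our method. The model output below is hypothetical, and the priors follow the approximate ratios stated above.

```python
def adjust_for_label_shift(probs, train_prior, test_prior):
    """Bayes-rule correction for a change in label priors:
    p_test(y | x) is proportional to p_train(y | x) * p_test(y) / p_train(y)."""
    scores = [p * q / r for p, q, r in zip(probs, test_prior, train_prior)]
    total = sum(scores)
    return [s / total for s in scores]


labels = ["Supports", "Refutes", "Not-enough-info"]
train_prior = [2 / 5, 1 / 5, 2 / 5]  # approximately 2:1:2 in the training data
test_prior = [1 / 2, 1 / 2, 0.0]     # approximately 1:1:0 in ID dev and OOD test
probs = [0.5, 0.3, 0.2]              # hypothetical model output for one claim
adjusted = adjust_for_label_shift(probs, train_prior, test_prior)
print(labels[max(range(3), key=adjusted.__getitem__)])  # Refutes
```

Because Refutes is under-represented in training, the correction raises its probability relative to Supports, which can flip predictions without any change to the model.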
Although this is neither an inevitable consequence of the objective nor a requirement of our method, we analyzed the extent to which different experts output different predictions. Figure 5 in Appendix A shows a significant variance among the experts' predictions.
HANS is sorted by the type of shortcuts, so we shuffled the order before computing R. We did not observe this kind of sorted pattern in the other datasets.
We omitted the Not-Enough-Info label since the FEVER ID dev and OOD test sets have no instances with this gold label.
References
A Additional Figures
Figure 3 shows an overview of our method (Section 3), and Figure 4 illustrates an example of our Argmin weighting (Section 3.2).
Figure 5 shows the average prediction of each expert, calculated over the MNLI ID dev set (Section 4.3). We observe a significant variance among the experts' predictions, which indicates that different experts tend to make different predictions.
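A minimal sketch of the kind of computation behind Figure 5 follows: average each expert's predicted label distribution over a dataset, then measure the variance of these averages across experts. The function names, the use of the population variance, and the toy data are our assumptions; the figure's exact computation may differ.

```python
from statistics import pvariance


def mean_expert_predictions(expert_probs_per_example):
    """expert_probs_per_example[n][k][y]: expert k's probability for label y on example n.

    Returns each expert's average predicted distribution over the dataset.
    """
    n = len(expert_probs_per_example)
    num_experts = len(expert_probs_per_example[0])
    num_labels = len(expert_probs_per_example[0][0])
    return [
        [sum(ex[k][y] for ex in expert_probs_per_example) / n for y in range(num_labels)]
        for k in range(num_experts)
    ]


def between_expert_variance(mean_preds):
    """Population variance across experts of each label's average probability."""
    return [pvariance([m[y] for m in mean_preds]) for y in range(len(mean_preds[0]))]


data = [
    [[1.0, 0.0], [0.0, 1.0]],  # example 1: the two experts disagree completely
    [[0.5, 0.5], [0.5, 0.5]],  # example 2: both experts are uncertain
]
means = mean_expert_predictions(data)
print(means)                           # [[0.75, 0.25], [0.25, 0.75]]
print(between_expert_variance(means))  # [0.0625, 0.0625]
```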
B Further Setup Details
Table 5 specifies the URLs of the datasets, pre-trained models, and code of the previous methods we introduced in Section 4.1.
Datasets | URL |
---|---|
MNLI | https://cims.nyu.edu/~sbowman/multinli/ |
HANS | https://github.com/tommccoy1/hans |
QQP and PAWS | https://github.com/google-research-datasets/paws |
FEVER and Symm. v1/v2 | https://github.com/TalSchuster/FeverSymmetric |
Pre-Trained Models | |
BERT | https://huggingface.co/bert-base-uncased |
DeBERTav3-large | https://huggingface.co/microsoft/deberta-v3-large |
Code | |
Conf-reg ♠self-debias* | https://github.com/UKPLab/emnlp2020-debiasing-unknown |
JTT* | https://github.com/YyzHarry/SubpopBench |
RISK | https://github.com/CuteyThyme/RISK |
EIIL | https://github.com/PluviophileYU/BAI |
BAI | https://github.com/PluviophileYU/BAI |
GroupDROlabel-group* | https://github.com/YyzHarry/SubpopBench |
ReWeightCRT* | https://github.com/YyzHarry/SubpopBench |
Following the fine-tuning hyperparameters of DeBERTav3-large (He et al., 2023), we set the learning rate to 5e-6 and used gradient clipping with the maximum gradient norm of 1.0 in the ablation study with DeBERTav3-large (Section 4.3).
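For readers reimplementing this setting: gradient clipping by global norm rescales all gradients jointly whenever their combined L2 norm exceeds the threshold. The sketch below uses a flat list of scalar gradients and a plain gradient-descent step for brevity. Both simplifications, and the names `clip_by_global_norm` and `sgd_step`, are our assumptions, and the plain step only stands in for the actual optimizer. In PyTorch, the corresponding utility is `torch.nn.utils.clip_grad_norm_(parameters, max_norm=1.0)`.

```python
import math

LEARNING_RATE = 5e-6  # learning rate stated above for DeBERTav3-large
MAX_GRAD_NORM = 1.0   # maximum gradient norm stated above


def clip_by_global_norm(grads, max_norm=MAX_GRAD_NORM):
    """Rescale all gradients jointly so that their global L2 norm is at most max_norm."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    if total_norm <= max_norm:
        return list(grads)
    scale = max_norm / total_norm
    return [g * scale for g in grads]


def sgd_step(params, grads, lr=LEARNING_RATE):
    """One plain gradient-descent step on clipped gradients (stand-in optimizer)."""
    clipped = clip_by_global_norm(grads)
    return [p - lr * g for p, g in zip(params, clipped)]


print(clip_by_global_norm([3.0, 4.0]))  # global norm 5.0 is rescaled to 1.0
print(sgd_step([0.0, 0.0], [3.0, 4.0]))
```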
Author notes
Action Editor: Jacob Eisenstein