Abstract
Self-supervised learning (SSL) methods such as Word2vec, BERT, and GPT have shown great effectiveness in language understanding. Contrastive learning, a recent SSL approach, has attracted increasing attention in NLP; it learns data representations by predicting whether two augmented instances are generated from the same original data example. Previous contrastive learning methods perform data augmentation and contrastive learning separately. As a result, the augmented data may not be optimal for contrastive learning. To address this problem, we propose a four-level optimization framework that performs data augmentation and contrastive learning end-to-end, so that the augmented data are tailored to the contrastive learning task. This framework consists of four learning stages, including training machine translation models for sentence augmentation, pretraining a text encoder using contrastive learning, finetuning a text classification model, and updating weights of translation data by minimizing the validation loss of the classification model, which are performed in a unified way. Experiments on datasets in the GLUE benchmark (Wang et al., 2018a) and on datasets used in Gururangan et al. (2020) demonstrate the effectiveness of our method.
1 Introduction
Self-supervised learning (Bengio et al., 2000; Mikolov et al., 2013; Devlin et al., 2019; Radford et al., 2018; Lewis et al., 2020), which learns data representations by solving prediction tasks defined on input data without leveraging human-provided labels, has achieved broad success in NLP. Many NLP-specific self-supervised learning (SSL) methods have been proposed, such as neural language models (Bengio et al., 2000), Word2vec (Mikolov et al., 2013), BERT (Devlin et al., 2019), GPT (Radford et al., 2018), BART (Lewis et al., 2020), and so forth, with various SSL tasks defined. For example, in Word2vec and BERT, the SSL task is predicting the identities of masked tokens based on their contexts. In neural language models including GPT, the SSL task is language modeling: Given a history of tokens, predict the next token.
Recently, contrastive self-supervised learning (He et al., 2020; Chen et al., 2020) has been borrowed from vision domains into NLP and has shown promising success in predicting semantic textual similarity (Gao et al., 2021), machine translation (Pan et al., 2021), relation extraction (Su et al., 2021), and so on. The key idea of contrastive self-supervised learning (CSSL) is: Create augmented versions of original examples, then learn representations by predicting whether two augmented examples originate from the same original example. In existing CSSL approaches, data augmentation and contrastive learning are performed separately. As a result, augmented data may not be optimal for contrastive learning. For example, consider a back-translation (Sennrich et al., 2016)–based augmentation method: if the translation model is trained on news corpora, it may not be suitable for augmenting movie review data.
In this paper, we aim to address this issue. We propose a four-level optimization framework that performs data augmentation and contrastive learning end-to-end in a unified way, to allow the data augmentation models to be guided by the contrastive learning task and make augmented data suitable for performing contrastive learning. We assume the end task is text classification. Our framework consists of four learning stages. At the first stage, we train four translation models to perform sentence augmentation based on back translation. To account for the fact that translation data used at this stage and text classification data used in later stages have a domain discrepancy, we perform reweighting of translation data; these weights are tentatively fixed at this stage and will be updated at a later stage. At the second stage, we pretrain a text encoder using contrastive learning on augmented sentences created by the translation models. At the third stage, we finetune a text classification model, using the text encoder pretrained at the second stage as regularization. At the fourth stage, we measure the performance of the text classifier on a validation set and update weights of translation data by maximizing the validation performance. Each level of optimization problem in our framework corresponds to a learning stage. These stages are performed end-to-end. Experiments on datasets in the GLUE benchmark (Wang et al., 2018a) and on datasets used in Gururangan et al. (2020) demonstrate the effectiveness of our method.
The major contributions of this paper include:
We propose a four-level optimization framework to perform contrastive learning (CL) and data augmentation end-to-end. Our framework enables the training of augmentation models to be guided by the CL task and makes augmented data suitable for CL.
We demonstrate the effectiveness of our method on datasets in the GLUE benchmark (Wang et al., 2018a) and on datasets used in Gururangan et al. (2020).
2 Related Works
2.1 Contrastive Learning in NLP
Recently, contrastive learning has received increasing attention in NLP. Gao et al. (2021) proposed a simple contrastive learning–based sentence embedding method. In this method, the same input sentence is fed into a pretrained RoBERTa (Liu et al., 2019) model twice by applying different dropout masks and the resulting two embeddings are labeled as being similar. Embeddings of different sentences are labeled as dissimilar. Pan et al. (2021) proposed a contrastive learning method for many-to-many multilingual neural machine translation, where contrastive learning is leveraged to close the gap among representations of different languages. Su et al. (2021) developed a contrastive learning method for biomedical relation extraction, where linguistic knowledge is leveraged for data augmentation. Wang et al. (2021) proposed to construct semantically negative examples to perform contrastive learning, for the sake of improving the robustness against semantical adversarial attacks. Pan et al. (2022) proposed to perform contrastive learning on adversarial examples generated by perturbing word embeddings, in order to learn noise-invariant representations.
2.2 Contrastive Self-Supervised Learning in Non-NLP Domains
Contrastive self-supervised learning has been broadly studied recently in other domains besides NLP. Henaff (2020) proposed a contrastive predictive coding method for data-efficient classification. In this method, autoregressive models are leveraged to predict the future in a latent space. Khosla et al. (2020) proposed a supervised contrastive learning method. Data examples having the same class label are made close to each other in the latent space while examples with different class labels are separated farther apart. Laskin et al. (2020) proposed a method to learn contrastive unsupervised representations for reinforcement learning. In Klein and Nabi (2020), a contrastive self-supervised learning approach is proposed for commonsense reasoning.
2.3 Bi-level Optimization
Our framework is a multi-level optimization framework, which is an extension of bi-level optimization (BLO). BLO (Dempe, 2002) has been applied to many applications in NLP, such as neural architecture search (Liu et al., 2018), hyperparameter tuning (Feurer et al., 2015), data reweighting (Shu et al., 2019; Ren et al., 2020; Wang et al., 2020), label denoising (Zheng et al., 2021), learning rate adjustment (Baydin et al., 2018), meta learning (Finn et al., 2017), data generation (Such et al., 2020), and so forth. In these BLO-based methods, meta parameters (neural architectures, hyperparameters, importance weights of training data examples, etc.) are learned by minimizing a validation loss and weight parameters are optimized by minimizing a training loss.
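To make this bi-level pattern concrete, the sketch below shows a one-step-unrolled data reweighting update in the spirit of Shu et al. (2019) and Ren et al. (2020): per-example weights (the meta parameters) are updated through the validation loss of a model that is itself updated on the weighted training loss. The linear model, tensor shapes, and single-step approximation are illustrative assumptions, not the exact procedure of any cited method.

```python
import torch
import torch.nn.functional as F

def bilevel_reweight_step(W, weights, train_batch, val_batch,
                          inner_lr=0.1, meta_lr=0.1):
    """One illustrative bi-level update for data reweighting.

    W       : parameters of a linear classifier, shape (num_features, num_classes).
    weights : per-example meta weights on the training batch (requires_grad=True).
    """
    x_tr, y_tr = train_batch
    x_val, y_val = val_batch

    # Inner problem: one gradient step on the weighted training loss.
    per_example = F.cross_entropy(x_tr @ W, y_tr, reduction="none")
    train_loss = (torch.sigmoid(weights) * per_example).mean()
    (grad_W,) = torch.autograd.grad(train_loss, W, create_graph=True)
    W_virtual = W - inner_lr * grad_W          # one-step unrolled parameters

    # Outer problem: validation loss of the virtual model, differentiated
    # w.r.t. the example weights (the hypergradient).
    val_loss = F.cross_entropy(x_val @ W_virtual, y_val)
    (grad_w,) = torch.autograd.grad(val_loss, weights)
    with torch.no_grad():
        weights -= meta_lr * grad_w
    return val_loss.item()

# Toy usage: 32 training examples, 16 validation examples, 5 classes.
d, c = 20, 5
W = torch.randn(d, c, requires_grad=True)
weights = torch.zeros(32, requires_grad=True)
train_batch = (torch.randn(32, d), torch.randint(0, c, (32,)))
val_batch = (torch.randn(16, d), torch.randint(0, c, (16,)))
print(bilevel_reweight_step(W, weights, train_batch, val_batch))
```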
3 Method
In this section, we present our proposed end-to-end contrastive learning framework.
3.1 Overview
We use back-translation (Sennrich et al., 2016) to perform data augmentation of sentences. Then on augmented sentences, contrastive learning is performed. Two augmented sentences are labeled as similar if they originate from the same original sentence. Two augmented sentences are labeled as dissimilar if they originate from different original sentences. Contrastive losses (Hadsell et al., 2006) are defined on these similar and dissimilar pairs. A text encoder is pretrained by minimizing the contrastive losses.
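As an illustration of this kind of objective, the following sketch computes an InfoNCE-style contrastive loss on a batch of sentence embeddings, where two augmentations of the same sentence form the positive pair and augmentations of other sentences in the batch act as negatives. The in-batch negatives and the temperature value are simplifying assumptions; the framework itself uses MoCo with a queue of negatives (see Section 4.2).

```python
import torch
import torch.nn.functional as F

def info_nce_loss(z1, z2, temperature=0.07):
    """InfoNCE-style contrastive loss on two batches of sentence embeddings.

    z1[i] and z2[i] are embeddings of two augmentations of sentence i
    (a positive pair); all other rows of z2 serve as negatives for z1[i].
    """
    z1 = F.normalize(z1, dim=-1)
    z2 = F.normalize(z2, dim=-1)
    logits = z1 @ z2.t() / temperature          # (batch, batch) similarity matrix
    targets = torch.arange(z1.size(0))          # the diagonal holds the positives
    return F.cross_entropy(logits, targets)

# Toy usage with random 128-dimensional "embeddings" of a batch of 8 sentences.
z1, z2 = torch.randn(8, 128), torch.randn(8, 128)
print(info_nce_loss(z1, z2).item())
```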
We assume the end task is text classification. Our framework consists of the following learning stages, which are performed end-to-end. At the first stage, we train four translation models to perform data augmentation using back translation. Each translation pair in the training set is associated with a weight. At the second stage, on augmented sentences created by the translation models, we perform contrastive learning to train a text encoder. At the third stage, using the encoder trained at the second stage as regularization, we finetune a text classification model. At the fourth stage, the classification model trained at the third stage is evaluated on a validation classification dataset and the weights of translation pairs at the first stage are updated by minimizing the validation loss. The four stages are performed in a four-level optimization framework. Figure 1 illustrates our framework. Next, we describe the four stages in detail.
3.2 Stage I: Training Machine Translation Models for Sentence Augmentation
Figure 2 shows the workflow of data augmentation. For an input sentence x, we augment it using back-translation (Sennrich et al., 2016). In our experiments, the language of the classification data is English. We use an English-to-German machine translation (MT) model to translate x to y. Then we use a German-to-English MT model to translate y to x′. Then x′ is regarded as an augmented sentence of x. Similarly, we use an English-to-Chinese MT model and a Chinese-to-English MT model to obtain another augmented sentence x″. We use German and Chinese as the two auxiliary languages because 1) both of them are resource-rich languages that have abundant translation data for model training; 2) they are sufficiently different from each other to achieve higher diversity in augmented examples.
Given an original English sentence x, we feed it into the English-to-German model to get a translated German sentence, which is then fed into the German-to-English model to get a translated English sentence x′. Meanwhile, we feed x into the English-to-Chinese model to get a translated Chinese sentence, which is then fed into the Chinese-to-English model to get another translated English sentence x″. x′ and x″ are two augmented sentences of x. Since translation models are trained using in-domain translation examples that have large domain similarity with classification data, augmented examples generated by translation models are likely to be in the same domain as classification data as well; contrastive learning performed on these in-domain augmented examples is likely to produce latent representations that are suitable for representing classification data.
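As a hedged illustration of this round-trip augmentation, the snippet below back-translates a sentence through German and Chinese pivots using off-the-shelf MarianMT checkpoints from the Hugging Face transformers library. The checkpoint names are examples only; in our framework the four translation models are attentional LSTMs trained on (reweighted) WMT17 data rather than these pretrained models.

```python
from transformers import MarianMTModel, MarianTokenizer

def round_trip(sentence, src2tgt_name, tgt2src_name):
    """Back-translate a sentence through a pivot language and return the
    paraphrase; checkpoint names are supplied by the caller."""
    def translate(text, model_name):
        tok = MarianTokenizer.from_pretrained(model_name)
        model = MarianMTModel.from_pretrained(model_name)
        batch = tok([text], return_tensors="pt", padding=True)
        out = model.generate(**batch, max_new_tokens=128)
        return tok.batch_decode(out, skip_special_tokens=True)[0]

    pivot = translate(sentence, src2tgt_name)       # English -> pivot language
    return translate(pivot, tgt2src_name)           # pivot language -> English

x = "The movie was slow at first but ultimately rewarding."
# Two augmentations via two pivot languages (German and Chinese).
x_prime = round_trip(x, "Helsinki-NLP/opus-mt-en-de", "Helsinki-NLP/opus-mt-de-en")
x_dprime = round_trip(x, "Helsinki-NLP/opus-mt-en-zh", "Helsinki-NLP/opus-mt-zh-en")
print(x_prime)
print(x_dprime)
```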
3.3 Stage II: Contrastive Learning
3.4 Stage III: Finetuning Text Classifier
3.5 Stage IV: Updating Weights of Translation Data
3.6 Four-Level Optimization Framework
3.7 Reducing Memory and Computation Cost
3.8 Optimization Algorithm
4 Experiments
4.1 Tasks and Datasets
For text classification, we use two collections of datasets. The first collection is from the General Language Understanding Evaluation (GLUE) benchmark, which has 11 tasks: 2 single-sentence tasks, CoLA (Warstadt et al., 2019) and SST-2 (Socher et al., 2013); 3 similarity and paraphrase tasks, MRPC (Dolan and Brockett, 2005), QQP, and STS-B (Cer et al., 2017); and 5 inference tasks, MNLI (Williams et al., 2018), QNLI (Rajpurkar et al., 2016), RTE (Dagan et al., 2005), WNLI (Levesque et al., 2012), and the diagnostic dataset AX. Table 1 shows the split statistics of the GLUE datasets. The second collection is from Gururangan et al. (2020), including ChemProt (Kringelum et al., 2016), RCT (Dernoncourt and Lee, 2017), ACL-ARC (Jurgens et al., 2018), SciERC (Luan et al., 2018), HyperPartisan (Kiesel et al., 2019), AGNews (Zhang et al., 2015), Helpfulness (McAuley et al., 2015), and IMDB (Maas et al., 2011). In our method, we split the original training set into a new training set and a validation set with a ratio of 1:1; the new training set serves as the training data and the held-out half serves as the validation data in our framework.
Table 1: Split statistics of the GLUE datasets.

| | CoLA | RTE | QNLI | STS-B | MRPC | WNLI | SST-2 | MNLI (m/mm) | QQP | AX |
|---|---|---|---|---|---|---|---|---|---|---|
| Train | 8551 | 2490 | 104743 | 5749 | 3668 | 635 | 67349 | 392702 | 363871 | – |
| Dev | 1043 | 277 | 5463 | 1500 | 408 | 71 | 872 | 9815/9832 | 40432 | – |
| Test | 1063 | 3000 | 5463 | 1379 | 1725 | 146 | 1821 | 9796/9847 | 390965 | 1104 |
For machine translation, we use 3K English-Chinese and 3K English-German language pairs randomly sampled from WMT17. Contrastive learning is performed on all input texts (excluding labels) of the training sets of the 11 GLUE tasks.
4.2 Experimental Settings
For translation models, we use the architectures studied in Britz et al. (2017), which are encoder-decoder models with attention. The encoder and decoder are both 4-layer bi-directional LSTM networks with a hidden size of 512. The attention is additive, with a dimension of 512. For classifiers, BERT is used for GLUE and RoBERTa is used for datasets in Gururangan et al. (2020). The Gumbel-Softmax trick (Jang et al., 2017; Maddison et al., 2017) is leveraged to deal with the non-differentiability of discrete word selection.
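The snippet below is a minimal sketch of the Gumbel-Softmax relaxation itself (PyTorch provides it as F.gumbel_softmax). How the relaxed samples are wired into the translation decoder is omitted here, and the embedding-mixing step shown is an illustrative assumption.

```python
import torch
import torch.nn.functional as F

# Logits over a toy vocabulary of 10 tokens for a single decoding step.
logits = torch.randn(1, 10, requires_grad=True)

# Soft, differentiable sample from the categorical distribution.
soft_sample = F.gumbel_softmax(logits, tau=1.0, hard=False)

# Straight-through variant: the forward pass is one-hot, the backward
# pass uses the soft relaxation, so gradients still reach the logits.
hard_sample = F.gumbel_softmax(logits, tau=1.0, hard=True)

# A downstream embedding lookup can stay differentiable by mixing the
# (straight-through) sample with the embedding matrix.
embedding_matrix = torch.randn(10, 32)
token_embedding = hard_sample @ embedding_matrix
token_embedding.sum().backward()
print(logits.grad.shape)   # gradients flow back to the decoder logits
```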
We use MoCo (He et al., 2020) to implement the contrastive learning method. Text encoders are initialized using pretrained BERT (Devlin et al., 2019) or pretrained RoBERTa (Liu et al., 2019). In MoCo, the size of the queue (the hyperparameter K in Section 3.3) was set to 96606. The MoCo momentum coefficient for updating the key encoder was set to 0.999. The temperature parameter (the hyperparameter τ in Section 3.3) in the contrastive loss was set to 0.07. A multi-layer perceptron head was used. For MoCo training, a stochastic gradient descent solver with momentum was used. The minibatch size was set to 16. The initial learning rate was set to 4 · 10^-5 and adjusted using cosine scheduling. Weight decay was used with a coefficient of 1 · 10^-5.
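For reference, the following sketch shows the two MoCo mechanics that these hyperparameters control: the momentum update of the key encoder (coefficient 0.999) and the queue-based contrastive loss with temperature 0.07. The toy embedding dimensions and queue size are placeholders, not the values used in our experiments.

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def momentum_update(query_encoder, key_encoder, m=0.999):
    """MoCo-style momentum update of the key encoder's parameters."""
    for q, k in zip(query_encoder.parameters(), key_encoder.parameters()):
        k.data.mul_(m).add_(q.data, alpha=1.0 - m)

def moco_loss(q, k, queue, temperature=0.07):
    """q, k: L2-normalized embeddings of two augmentations, shape (batch, dim);
    queue: (K, dim) embeddings of past keys acting as negatives."""
    l_pos = (q * k).sum(dim=-1, keepdim=True)          # (batch, 1) positive logits
    l_neg = q @ queue.t()                               # (batch, K) negative logits
    logits = torch.cat([l_pos, l_neg], dim=1) / temperature
    labels = torch.zeros(q.size(0), dtype=torch.long)   # positives sit at index 0
    return F.cross_entropy(logits, labels)

# Toy usage with random normalized embeddings and a queue of 1024 keys.
q = F.normalize(torch.randn(16, 128), dim=-1)
k = F.normalize(torch.randn(16, 128), dim=-1)
queue = F.normalize(torch.randn(1024, 128), dim=-1)
print(moco_loss(q, k, queue).item())
```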
For classification on GLUE tasks, the classification head is a linear layer. The maximum sequence length was set to 128. The tradeoff parameter λ in Eq. (4) was set to 0.1. The minibatch size was set to 16. The learning rate was set to 3 · 10^-5 for CoLA, MNLI, and STS-B; 2 · 10^-5 for RTE, QNLI, MRPC, SST-2, and WNLI; and 1 · 10^-5 for QQP. The number of training epochs was set to 100.
Hyperparameter Tuning Details
For most hyperparameters in MoCo and the LSTM, we use the default values given in He et al. (2020) and Britz et al. (2017). The tradeoff parameter λ is tuned in {0.01, 0.05, 0.1, 0.5, 1} on the development set. For each configuration of λ, we run our method using one half of the training set as training data and the other half as validation data. Then we measure the performance of the trained model on the development set. The λ value yielding the best performance is selected. We tuned the hyperparameters of the baselines extensively; the tuning time for each baseline is roughly the same as that for our method.
4.3 Baselines
We compare our methods with the following baselines. Let Ours-SPS denote the proposed framework in Eq. (6) which performs soft parameter-sharing (SPS) between V and U via regularization, and let Ours-HPS denote the framework in Section 3.7 which performs hard parameter-sharing (HPS) where V and U are the same.
Vanilla RoBERTa (Liu et al., 2019). The Transformer-based encoder is initialized with pretrained RoBERTa. A text classification model is formed by stacking the pretrained encoder and a classification head, with the same architecture as in Liu et al. (2019). The classification head is a feedforward layer with a tanh activation. The learned encoding of the special token [CLS] is fed into the classification head to predict the class label. We then finetune the classification model on a classification dataset.
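A minimal sketch of this baseline's architecture, assuming the Hugging Face transformers implementation of RoBERTa and a roberta-large checkpoint (the exact finetuning code is not reproduced here):

```python
import torch
import torch.nn as nn
from transformers import RobertaModel, RobertaTokenizer

class RobertaClassifier(nn.Module):
    """Pretrained RoBERTa encoder plus a tanh feedforward head on the first token."""
    def __init__(self, num_labels, model_name="roberta-large"):
        super().__init__()
        self.encoder = RobertaModel.from_pretrained(model_name)
        hidden = self.encoder.config.hidden_size
        self.head = nn.Sequential(nn.Linear(hidden, hidden), nn.Tanh(),
                                  nn.Linear(hidden, num_labels))

    def forward(self, input_ids, attention_mask):
        out = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        cls = out.last_hidden_state[:, 0]   # encoding of the CLS-like first token
        return self.head(cls)

tok = RobertaTokenizer.from_pretrained("roberta-large")
model = RobertaClassifier(num_labels=2)
batch = tok(["a quietly moving film"], return_tensors="pt")
print(model(batch["input_ids"], batch["attention_mask"]).shape)  # (1, 2)
```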
Vanilla BERT (Devlin et al., 2019). This approach is similar to vanilla RoBERTa. The only difference is that the Transformer-based encoder is initialized by pretrained BERT (Devlin et al., 2019) instead of RoBERTa.
TAPT: Task Adaptive Pretraining (Gururangan et al., 2020). In this approach, given a target dataset Dt, BERT or RoBERTa (already pretrained on external data) is further pretrained on the input sentences in Dt by predicting masked tokens.
SimCSE (Gao et al., 2021). In this approach, the same input sentence is fed into a pretrained RoBERTa encoder twice by applying different dropout masks, to get two different embeddings. These two embeddings are labeled as being “similar”. Embeddings of different sentences are labeled as being “dissimilar”. Contrastive learning is performed on these “similar” and “dissimilar” pairs.
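A minimal sketch of the SimCSE-style positive-pair construction, assuming a Hugging Face roberta-base encoder and first-token pooling (both are illustrative choices): keeping the encoder in training mode makes the two forward passes use different dropout masks, so the two embeddings of the same sentence differ.

```python
import torch
from transformers import RobertaModel, RobertaTokenizer

tok = RobertaTokenizer.from_pretrained("roberta-base")
encoder = RobertaModel.from_pretrained("roberta-base")
encoder.train()   # keep dropout active so the two passes see different masks

batch = tok(["the acting is superb"], return_tensors="pt")
with torch.no_grad():
    z1 = encoder(**batch).last_hidden_state[:, 0]   # first pass
    z2 = encoder(**batch).last_hidden_state[:, 0]   # second pass, new dropout mask

# z1 and z2 form a "similar" (positive) pair for the contrastive loss;
# embeddings of other sentences would serve as "dissimilar" negatives.
print(torch.nn.functional.cosine_similarity(z1, z2).item())
```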
CSSL-Separate. In this approach, data augmentation, contrastive learning, and text classification are performed separately. We first train machine translation models and use them to perform sentence augmentation. Then on augmented sentences, we perform contrastive learning. Finally, using the text encoder pretrained by contrastive learning as initialization, we finetune the classification model. When performing contrastive learning, the text encoder is initialized using pretrained RoBERTa or BERT.
CSSL-MTL. This approach is similar to CSSL-Separate, except that the CSSL task and classification task are performed jointly in a multi-task learning (MTL) framework, by minimizing the weighted sum of their losses. The weight is 0.01 for CSSL loss and is 1 for classification loss.
4.4 Results
4.4.1 Results in BERT-Based Experiments
In BERT-based experiments, the text encoder is initialized using BERT before contrastive learning is performed. Tables 2 and 3 show the results on the GLUE test sets in BERT-based experiments. Both Ours-SPS and Ours-HPS outperform all baselines in average score. Out of the 11 tasks (MNLI-m and MNLI-mm are treated as two separate tasks), Ours-SPS outperforms all baselines on 8 tasks and Ours-HPS on 7 tasks. These results demonstrate the effectiveness of our end-to-end frameworks. Via parameter sharing, Ours-HPS has much smaller memory and computation costs than Ours-SPS, with a small sacrifice of classification performance. The inference costs of our methods are similar to those of the baselines: during inference, only V and H are needed, which are the same as in the baseline models. Figure 4 shows the accuracy curve of Ours-HPS on the RTE validation set under different runs. As can be seen, our algorithm converges well.
Table 2: Results on GLUE test sets in BERT-based experiments (first part).

| | Train Time (hours) | Memory (GB) | Train Data (millions) | Parameters (millions) | CoLA (Matthew) | SST-2 (Acc.) | RTE (Acc.) | QNLI (Acc.) | MRPC (Acc./F1) |
|---|---|---|---|---|---|---|---|---|---|
| BERT | 6.3 | 11.7 | 1.019 | 345 | 60.5 | 94.9 | 70.1 | 92.7 | 85.4/89.3 |
| TAPT | 13.5 | 11.8 | 1.019 | 345 | 61.3 | 94.4 | 70.3 | 92.4 | 85.9/89.5 |
| SimCSE | 16.4 | 12.1 | 1.019 | 690 | 59.5 | 94.3 | 71.2 | 92.9 | 85.9/89.8 |
| CSSL-Separate | 16.1 | 12.0 | 1.025 | 713 | 59.4 | 94.5 | 71.4 | 92.8 | 85.8/89.6 |
| CSSL-MTL | 16.9 | 12.2 | 1.025 | 713 | 59.7 | 94.7 | 71.2 | 92.5 | 86.0/89.6 |
| Ours-SPS | 28.2 | 20.5 | 1.025 | 713 | 63.0 | 95.8 | 72.5 | 93.2 | 86.1/89.9 |
| Ours-HPS | 17.4 | 12.7 | 1.025 | 356 | 62.4 | 95.3 | 72.4 | 92.5 | 86.3/89.9 |
Table 3: Results on GLUE test sets in BERT-based experiments (continued).

| | MNLI-m/mm (Accuracy) | QQP (Accuracy/F1) | STS-B (Pearson/Spearman) | WNLI (Accuracy) | AX (Matthew) | Average |
|---|---|---|---|---|---|---|
| BERT | 86.7/85.9 | 89.3/72.1 | 87.6/86.5 | 65.1 | 39.6 | 80.5 |
| TAPT | 85.7/84.4 | 89.6/71.9 | 88.1/87.0 | 65.8 | 39.3 | 80.6 |
| SimCSE | 87.1/86.4 | 90.5/72.5 | 87.8/86.9 | 65.8 | 39.6 | 80.8 |
| CSSL-Separate | 87.3/86.6 | 90.6/72.7 | 87.4/86.6 | 65.5 | 39.6 | 80.8 |
| CSSL-MTL | 87.4/86.8 | 90.9/72.9 | 87.3/86.6 | 65.4 | 39.6 | 80.8 |
| Ours-SPS | 86.7/86.2 | 90.0/72.9 | 88.2/87.3 | 66.9 | 40.3 | 81.7 |
| Ours-HPS | 86.8/86.2 | 89.8/72.8 | 88.3/87.3 | 66.1 | 40.2 | 81.4 |
We present the following analysis. First, our methods outperform CSSL-Separate and CSSL-MTL because, in our methods, data augmentation and contrastive learning are performed end-to-end: the training of the translation models (used for data augmentation) is guided by the contrastive learning performance, and the augmented sentences are encouraged to be suitable for the contrastive learning task. In contrast, in CSSL-Separate and CSSL-MTL, data augmentation and contrastive learning are performed separately, so the augmented data may not be optimal for contrastive learning. Second, Ours-SPS outperforms Ours-HPS because in Ours-SPS, while the classification model is regularized by the CSSL-pretrained text encoder, they are not exactly the same; this gives the classification model some flexibility in capturing the unique properties of the classification data. In contrast, in Ours-HPS, the classification model and the CSSL-pretrained text encoder are required to be exactly the same, which might be too restrictive. On the other hand, it is worth noting that the classification performance gap between Ours-HPS and Ours-SPS is not very large, while the memory and computation costs of Ours-HPS are much smaller.
Third, overall, CSSL-MTL works better than CSSL-Separate. Out of the 11 tasks, CSSL-MTL outperforms CSSL-Separate on 6 tasks and is on par with CSSL-Separate on 1 task. The reason is that in CSSL-MTL, contrastive learning and classification are performed jointly, which enables the two tasks to mutually benefit from each other. In contrast, in CSSL-Separate, contrastive learning and classification are performed separately: while contrastive learning influences classification, classification does not provide any feedback to contrastive learning. Fourth, CSSL-Separate and SimCSE are in general on par with each other. The only difference between these two methods is that CSSL-Separate uses back-translation for data augmentation while SimCSE uses dropout masks. This shows that the two augmentation methods work equally well for contrastive learning on texts. Fifth, CSSL-Separate and SimCSE outperform TAPT, whose self-supervised task is masked token prediction. This shows that contrastive learning is a more effective self-supervised learning method than masked token prediction. Sixth, CSSL-Separate and SimCSE perform better than BERT, which further demonstrates the effectiveness of contrastive learning. Seventh, the improvement achieved by our methods is more prominent on smaller datasets such as WNLI and RTE. This is because, for smaller datasets, it is more necessary to leverage unlabeled data via contrastive learning to learn overfitting-resilient representations. Eighth, for tasks that have similar data sizes, the improvement of our method is more prominent on similarity tasks than on other types of tasks. For example, QQP and MNLI have similar data sizes; our method achieves a larger improvement on QQP, a similarity task, than on MNLI, an inference task.
Tables 4 and 5 show the results of BERT-based experiments on the development sets of GLUE. Our methods outperform all baselines on the different tasks. The analysis of reasons is similar to that for Tables 2 and 3. Note that our method can be applied to text encoders other than BERT; for example, it can be applied to Turing-NLR-v5 (Bajaj et al., 2022), an encoder that achieves state-of-the-art performance on the GLUE leaderboard with an average score of 91.2.
Table 4: Results on GLUE development sets in BERT-based experiments, including ablation settings (first part).

| | CoLA (Matthew Corr.) | SST-2 (Accuracy) | RTE (Accuracy) | QNLI (Accuracy) | MRPC (Accuracy/F1) |
|---|---|---|---|---|---|
| BERT (from Lan et al.) | 60.6 | 93.2 | 70.4 | 92.3 | 88.0/– |
| BERT (our run) | 62.1 | 93.1 | 74.0 | 92.1 | 86.8/90.8 |
| TAPT | 61.2 | 93.1 | 74.0 | 92.0 | 85.3/89.8 |
| SimCSE | 62.0 | 93.5 | 73.5 | 92.2 | 86.8/90.9 |
| CSSL-Separate | 62.3 | 93.7 | 72.0 | 92.5 | 87.0/90.9 |
| CSSL-MTL | 62.7 | 93.7 | 72.6 | 92.3 | 87.1/90.9 |
| No-CL | 62.3 | 93.2 | 74.0 | 92.2 | 86.9/90.8 |
| No-FT | 62.4 | 93.6 | 72.3 | 92.4 | 87.0/90.9 |
| No-BT | 62.8 | 93.7 | 74.0 | 92.5 | 87.1/90.9 |
| Domain-Reweight | 62.7 | 93.8 | 72.5 | 92.6 | 87.1/90.9 |
| Fix-Weight-Separate | 62.7 | 93.8 | 72.8 | 92.6 | 87.1/90.9 |
| Fix-Weight-MTL | 62.9 | 93.8 | 73.2 | 92.4 | 87.2/90.9 |
| DANN | 62.4 | 93.7 | 72.2 | 92.4 | 87.1/90.9 |
| CAN | 62.3 | 93.6 | 72.3 | 92.4 | 87.2/90.9 |
| Ours-Transformer-MT | 63.2 | 93.9 | 74.2 | 92.6 | 87.2/90.9 |
| Ours-SPS | 63.6 | 94.0 | 74.8 | 92.9 | 87.8/91.1 |
| Ours-HPS | 63.3 | 93.9 | 74.4 | 92.7 | 87.3/90.9 |
Table 5: Results on GLUE development sets in BERT-based experiments, including ablation settings (continued).

| | MNLI-m/mm (Accuracy) | QQP (Accuracy/F1) | STS-B (Pearson Corr./Spearman Corr.) | WNLI (Accuracy) |
|---|---|---|---|---|
| BERT (from Lan et al.) | 86.6/– | 91.3/– | 90.0/– | – |
| BERT (our run) | 86.2/86.0 | 91.3/88.3 | 90.4/90.0 | 56.3 |
| TAPT | 85.6/85.5 | 91.5/88.7 | 90.6/90.2 | 53.5 |
| SimCSE | 86.4/86.1 | 91.2/88.1 | 90.5/90.1 | 56.3 |
| CSSL-Separate | 86.6/86.4 | 91.4/88.5 | 90.2/89.9 | 56.3 |
| CSSL-MTL | 86.6/86.5 | 91.5/88.6 | 90.4/90.1 | 56.3 |
| No-CL | 86.3/86.1 | 91.3/88.4 | 90.5/90.2 | 56.3 |
| No-FT | 86.6/86.5 | 91.4/88.4 | 90.4/90.0 | 56.3 |
| No-BT | 86.7/86.6 | 91.5/88.7 | 90.7/90.3 | 56.3 |
| Domain-Reweight | 86.6/86.5 | 91.5/88.6 | 90.4/90.0 | 56.3 |
| Fix-Weight-Separate | 86.7/86.6 | 91.5/88.7 | 90.4/90.0 | 56.3 |
| Fix-Weight-MTL | 86.7/86.6 | 91.6/88.8 | 90.5/90.2 | 56.3 |
| DANN | 86.6/86.4 | 91.5/88.5 | 90.3/90.1 | 56.3 |
| CAN | 86.6/86.5 | 91.6/88.5 | 90.2/90.1 | 56.3 |
| Ours-Transformer-MT | 86.7/86.8 | 91.6/88.9 | 90.7/90.3 | 56.3 |
| Ours-SPS | 86.9/87.0 | 91.9/89.2 | 91.0/90.8 | 56.3 |
| Ours-HPS | 86.9/86.9 | 91.7/89.1 | 90.8/90.5 | 56.3 |
4.4.2 Results in RoBERTa-Based Experiments
In RoBERTa-based experiments, text encoders are initialized using RoBERTa, before contrastive learning is performed. These experiments are conducted on datasets used in Gururangan et al. (2020). Table 6 shows the results. Our methods outperform all baselines. The analysis of reasons is similar to that for results in Tables 2 and 3.
Table 6: Results on datasets used in Gururangan et al. (2020) in RoBERTa-based experiments (mean±std).

| Dataset | RoBERTa | TAPT | SimCSE | CSSL-Separate | CSSL-MTL | Ours-SPS | Ours-HPS |
|---|---|---|---|---|---|---|---|
| ChemProt | 81.9±1.0 | 82.6±0.4 | 83.2±0.2 | 82.9±0.5 | 83.4±0.3 | 84.5±0.4 | 84.2±0.3 |
| RCT | 87.2±0.1 | 87.7±0.1 | 87.6±0.1 | 87.6±0.1 | 87.7±0.1 | 88.0±0.1 | 87.9±0.1 |
| ACL-ARC | 63.0±5.8 | 67.4±1.8 | 69.5±2.6 | 68.7±3.1 | 70.8±3.7 | 75.6±2.9 | |
| SciERC | 77.3±1.9 | 79.3±1.5 | 80.5±0.7 | 80.9±1.3 | 81.2±0.9 | 82.1±1.1 | 81.9±0.6 |
| HyperPartisan | 86.6±0.9 | 90.4±5.2 | 90.9±2.7 | 91.4±3.3 | 90.7±1.9 | 92.3±2.1 | 91.8±1.6 |
| AGNews | 93.9±0.2 | 94.5±0.1 | 94.7±0.1 | 94.3±0.1 | 94.5±0.1 | 95.1±0.1 | 94.9±0.1 |
| Helpfulness | 65.1±3.4 | 68.5±1.9 | 68.7±1.7 | 69.2±0.5 | 69.5±1.1 | 70.9±0.2 | 70.4±0.4 |
| IMDB | 95.0±0.2 | 95.5±0.1 | 95.7±0.1 | 95.6±0.1 | 95.3±0.1 | 96.0±0.1 | 95.8±0.1 |
4.4.3 Ablation Studies
To verify whether the individual components in our method are necessary, we perform the following ablation studies.
No contrastive learning (No-CL). Contrastive learning is removed from the framework. For each text-label pair (x, y), the input text x is fed into the four translation models to generate two augmentations x′ and x″. Then (x′, y) and (x″, y) are used as augmented data to train the classification model directly.
No finetuning (No-FT). At the third stage, the encoder V in the classification model is set to the encoder pretrained by contrastive learning at the second stage, without being finetuned.
Replacing back-translation with SimCSE (No-BT). Instead of learning machine translation models for text augmentation, we use the dropout-based augmentation method proposed in SimCSE (Gao et al., 2021).
Domain-Reweight. In CSSL-Separate, we reweight self-supervised training examples based on their domain relatedness to classification data and perform contrastive learning on the reweighted examples. Domain relatedness is calculated using an ℋ-divergence based metric (Elsahar and Gallé, 2019); a proxy computation of this metric is sketched after the list of ablation settings below.
Fix-Weight-Separate and Fix-Weight-MTL. We learn translation example weights using our method, fix them, and then run CSSL-Separate and CSSL-MTL on the reweighted translation examples.
DANN and CAN. We compare with two domain adaptation methods, DANN (Ganin et al., 2016) and CAN (Kang et al., 2019), which we use to align the domains of input sentences in the translation and classification datasets.
Ours-Transformer-MT. A variant of Ours-HPS that uses a Transformer (specifically, pretrained BERT) for machine translation instead of an attentional LSTM.
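For the Domain-Reweight baseline above, the ℋ-divergence between the translation and classification corpora can be approximated by a proxy A-distance: train a classifier to distinguish the two domains and convert its accuracy into a divergence score. The sketch below uses a TF-IDF logistic-regression probe on a toy corpus; the featurization and the exact metric of Elsahar and Gallé (2019) are assumptions here.

```python
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score

def proxy_h_divergence(corpus_a, corpus_b):
    """Proxy A-distance between two corpora: 2 * (1 - 2 * classifier error)."""
    texts = corpus_a + corpus_b
    labels = np.array([0] * len(corpus_a) + [1] * len(corpus_b))
    feats = TfidfVectorizer(max_features=5000).fit_transform(texts)
    acc = cross_val_score(LogisticRegression(max_iter=1000),
                          feats, labels, cv=3).mean()
    error = 1.0 - acc
    return 2.0 * (1.0 - 2.0 * error)

# Toy corpora standing in for the translation side and the review-domain side.
translation_side = ["The government called for balanced development.",
                    "Russia considers the agreements unjust.",
                    "The parliament passed a new trade bill."]
reviews = ["A quietly moving film with superb acting.",
           "The plot drags, but the ending is rewarding.",
           "An episodic movie of loosely linked stories."]
print(proxy_h_divergence(translation_side, reviews))
```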
Tables 4 and 5 show results on the GLUE development sets, using BERT for model initialization. We make the following observations. First, our full methods work better than No-CL, No-FT, and No-BT. In these three ablation baselines, one component (contrastive learning, classifier finetuning, and back translation, respectively) is removed from our full methods, yielding simpler methods. These results show that each component is useful and should not be removed, and that the simpler methods do not perform comparably well. Second, our methods work better than Domain-Reweight. The reason is that our methods perform reweighting of translation data together with the other tasks (training translation models, contrastive learning, finetuning, and validation) in an end-to-end manner; the weights of translation data are thus influenced by the other tasks and learned towards maximizing classification performance. In contrast, Domain-Reweight performs reweighting of self-supervised training examples separately from the other tasks, and the resulting weights are not guaranteed to be optimal for maximizing classification performance. On the other hand, Domain-Reweight outperforms CSSL-Separate, which shows that it is beneficial to reweight self-supervised training examples based on their domain relatedness to classification data. Third, our methods work better than Fix-Weight-Separate and Fix-Weight-MTL. The reason is that our methods generate augmented sentences and perform CL end-to-end, whereas Fix-Weight-Separate and Fix-Weight-MTL perform these two tasks separately. In our end-to-end framework, guided by the performance of CL, the training of the translation models changes dynamically to generate augmented sentences that are better for improving CL. In contrast, in Fix-Weight-Separate and Fix-Weight-MTL, the generation of augmented examples is not influenced by the CL task, so the generated augmentations may not be optimal for CL. On the other hand, Fix-Weight-Separate and Fix-Weight-MTL outperform CSSL-Separate and CSSL-MTL, which further demonstrates the benefit of reweighting translation examples based on their domain similarity to classification data and shows that the weights learned by our methods accurately reflect this similarity. Fourth, our methods work better than the two domain adaptation methods, DANN and CAN. The reason is that many translation examples have large domain discrepancies with the classification texts, making it difficult to adapt them into the domain of the classification data; our methods learn to remove such examples instead of forcefully adapting them. Fifth, comparing Ours-Transformer-MT (using BERT for translation) and Ours-HPS (using LSTM), we can see that BERT works slightly worse than LSTM. While BERT is more expressive, it has more weight parameters to learn than the LSTM, which incurs a higher risk of overfitting.
We also check whether our framework can improve machine translation (MT). Since the end goal of our work is improving text classification, we evaluate MT performance on selected translation examples that have large domain similarity to the classification data. Domain similarity is calculated using ℋ-divergence (Elsahar and Gallé, 2019). Translation examples whose normalized ℋ-divergence is smaller than 0.5 are selected. MT models are trained on the selected training examples and evaluated on the selected test examples. Table 7 compares the BLEU (Papineni et al., 2002) scores (on test sets) of the four MT models trained in our framework and those trained via MTL (minimizing the sum of the training losses of the four models) without using our framework. As can be seen, the models trained in our framework perform better. The reason is that our framework trains the translation models to generate linguistically meaningful augmented texts, which encourages them to produce translations of higher linguistic quality.
4.4.4 Parameter Sensitivity
Figure 5 shows how the classification accuracy on the development set of RTE changes with the tradeoff parameter λ of Ours-SPS. As can be seen, when λ increases from 0 to 0.1, the accuracy increases. This is because a larger λ encourages more knowledge transfer from the CSSL-pretrained encoder to the classification model; the representations learned in the contrastive SSL task help the classification model to learn. However, as we continue to increase λ, the accuracy decreases. This is because the classification model becomes too strongly biased toward the CSSL-pretrained encoder and is less tailored to the classification data.
4.4.5 Qualitative Results
Table 8 shows some randomly sampled translation examples whose learned importance weights (ai in Eq. (6)) are close to 0 or 1 when the classification task is SST-2 (35.2% of the translation data receives near-zero weights). Due to space limitations, we only show the English sentences of the translation pairs. As can be seen, translation sentences with near-zero weights have a large domain discrepancy with the SST-2 data: SST-2 mainly contains movie reviews, while these zero-weight sentences are mainly about politics. Due to this domain discrepancy, such translation data are not suitable for training data augmentation models for SST-2. Our framework can effectively identify such out-of-domain translation data and exclude them from the training process. This is another reason that our end-to-end framework achieves better performance than the baselines, which lack a mechanism for removing out-of-domain translation data. On the other hand, in Table 8, sentences whose weights are close to one are more relevant to movie reviews.
Table 8: Randomly sampled translation examples (English side) with learned importance weights close to 0 or 1, when the classification task is SST-2.

| Sentence | Weight |
|---|---|
| The government recently called for more balanced development, even proposing a “green index” to measure growth. | 0 |
| President-elect Donald Trump’s campaign narrative was based on the assumption that the US has fallen from its former greatness. | 0 |
| Russia considers the agreements from the 1990s unjust, based as they were on its weakness at the time, and it wants to revise them. | 0 |
| Publicity for the new film claims that it is “the first live-action film in the history of movies to star, and be told from the point of view of, a sentient animal.” | 1 |
| Gore told the world in his Academy Award-winning movie (recently labeled “one-sided” and containing “scientific errors” by a British judge) to expect 20-foot sea-level rises over this century. | 1 |
| Jia’s movie is episodic; four loosely linked stories about lone acts of extreme violence, mostly culled from contemporary newspaper stories. | 1 |
5 Conclusions, Discussion, and Future Work
In this paper, we propose an end-to-end framework for learning language representations based on contrastive learning. Different from existing contrastive learning methods, which perform data augmentation and contrastive learning separately and thus cannot guarantee that the augmented data are optimal for contrastive learning, our method performs data augmentation and contrastive learning end-to-end in a unified framework, so that the data augmentation models are trained specifically to be suitable for contrastive learning. Our framework consists of four learning stages: 1) training machine translation models for text augmentation; 2) contrastive learning; 3) training a classification model; and 4) updating weights of translation data by minimizing the validation loss of the classification model. We evaluate our framework on 11 English language understanding tasks in the GLUE benchmark and 8 datasets used in Gururangan et al. (2020). The experimental results on both test and development sets demonstrate the effectiveness of our method.
One major limitation of our method is its larger computational and memory costs, due to the extra overhead of solving a four-level optimization problem and storing MT models. To reduce these costs, in addition to tying parameters, we will explore other techniques in the future, such as reducing the update frequencies of MT models and MT data weights, applying diversity-promoting regularizers (Xie et al., 2017) to speed up convergence, and performing core-set based mini-batch selection (Sinha et al., 2020) to speed up convergence.
For future work, we plan to study more challenging loss functions for self-supervised learning. We are interested in investigating a ranking-based loss, where each sentence is augmented with a ranked list of sentences that have decreasing discrepancy with the original sentence. The auxiliary task is to predict the order given the augmented sentences. Predicting an order is presumably more challenging than binary classification (as adopted in existing contrastive SSL methods) and may facilitate the learning of better representations.