Abstract
Text classification is a widely studied problem with broad applications. In many real-world problems, the number of texts for training classification models is limited, which renders these models prone to overfitting. To address this problem, we propose SSL-Reg, a data-dependent regularization approach based on self-supervised learning (SSL). SSL (Devlin et al., 2019a) is an unsupervised learning approach that defines auxiliary tasks on input data without using any human-provided labels and learns data representations by solving these auxiliary tasks. In SSL-Reg, a supervised classification task and an unsupervised SSL task are performed simultaneously. The SSL task is defined purely on input texts, without using any human-provided labels. Training a model with an SSL task can prevent the model from being overfitted to the limited number of class labels in the classification task. Experiments on 17 text classification datasets demonstrate the effectiveness of our proposed method. Code is available at https://github.com/UCSD-AI4H/SSReg.
1 Introduction
Text classification (Korde and Mahender, 2012; Lai et al., 2015; Wang et al., 2017; Howard and Ruder, 2018) is a widely studied problem in natural language processing and finds broad applications. For example, given the clinical notes of a patient, judge whether this patient has heart disease; given a scientific paper, judge whether it is about NLP. In many real-world text classification problems, texts available for training are oftentimes limited. For instance, it is difficult to obtain many clinical notes from hospitals due to concerns about patient privacy. It is well known that when training data is limited, models tend to overfit to training data and perform less well on test data.
To address overfitting problems in text classification, we propose a data-dependent regularizer called SSL-Reg based on self-supervised learning (SSL) (Devlin et al., 2019a; He et al., 2019; Chen et al., 2020) and use it to regularize the training of text classification models, where a supervised classification task and an unsupervised SSL task are performed simultaneously. SSL (Devlin et al., 2019a; He et al., 2019; Chen et al., 2020) is an unsupervised learning approach that defines auxiliary tasks on input data without using any human-provided labels and learns data representations by solving these auxiliary tasks. For example, BERT (Devlin et al., 2019a) is a typical SSL approach in which an auxiliary task is defined to predict masked tokens and a text encoder is learned by solving this task. In existing SSL approaches for NLP, the SSL task and the target task are performed sequentially: a text encoder is first trained by solving an SSL task defined on a large collection of unlabeled texts; this encoder is then used to initialize the encoder in a target task and is finetuned by solving the target task. A potential drawback of performing the SSL task and the target task sequentially is that the text encoder learned in the SSL task may be overridden after being finetuned on the target task. If the training data in the target task is small, the finetuned encoder has a high risk of being overfitted to this training data.
To address this problem, in SSL-Reg we perform the SSL task and the target task (which is classification) simultaneously. In SSL-Reg, an SSL loss serves as a regularization term and is optimized jointly with a classification loss. SSL-Reg enforces a text encoder to jointly solve two tasks: an unsupervised SSL task and a supervised text classification task. Due to the presence of the SSL task, models are less likely to be biased toward the classification task defined on small-sized training data. We perform experiments on 17 datasets, where experimental results demonstrate the effectiveness of SSL-Reg in alleviating overfitting and improving generalization performance.
The major contributions of this paper are:
- •
We propose SSL-Reg, a data-dependent regularizer based on SSL, to reduce the risk that a text encoder is biased toward a data-deficient classification task on small-sized training data.
- •
Experiments on 17 datasets demonstrate the effectiveness of our approaches.
2 Related Works
2.1 Self-supervised Learning for NLP
SSL aims to learn meaningful representations of input data without using human annotations. It creates auxiliary tasks solely using input data and forces deep networks to learn highly effective latent features by solving these auxiliary tasks. In NLP, various auxiliary tasks have been proposed for SSL, such as next token prediction in GPT (Radford et al.), masked token prediction in BERT (Devlin et al., 2019a), text denoising in BART (Lewis et al., 2019), and contrastive learning (Fang et al., 2020). These models have achieved substantial success in learning language representations. The GPT model (Radford et al.) is a language model based on the Transformer (Vaswani et al., 2017). Different from the Transformer, which defines a conditional probability on an output sequence given an input sequence, GPT defines a marginal probability on a single sequence. In GPT, the conditional probability of the next token given the preceding sequence is defined using a Transformer decoder, and weight parameters are learned by maximizing the likelihood of token sequences. BERT (Devlin et al., 2019a) aims to learn a Transformer encoder for representing texts. BERT's model architecture is a multi-layer bidirectional Transformer encoder that uses bidirectional self-attention. To train the encoder, BERT masks some percentage of input tokens at random and then predicts those masked tokens by feeding the hidden vectors (produced by the encoder) corresponding to the masked tokens into an output softmax over the word vocabulary. BERT-GPT (Wu et al., 2019) is a model for sequence-to-sequence modeling where a pretrained BERT is used to encode input texts and GPT is used to generate output texts. In BERT-GPT, pretraining of the BERT encoder and the GPT decoder is conducted separately, which may lead to inferior performance. Bidirectional and Auto-Regressive Transformers (BART) (Lewis et al., 2019) has an architecture similar to BERT-GPT but trains the BERT encoder and GPT decoder jointly. To pretrain BART weights, input texts are corrupted randomly (e.g., via token masking, token deletion, and text infilling), and the network is trained to reconstruct the original texts. ALBERT (Lan et al., 2019) uses parameter-reduction methods to reduce memory consumption and increase the training speed of BERT. It also introduces a self-supervised loss that models inter-sentence coherence. RoBERTa (Liu et al., 2019a) is a replication study of BERT pretraining. It shows that BERT's performance can be greatly improved by carefully tuning the training process, for example by (1) training models longer, with larger batches, over more data; (2) removing the next-sentence-prediction objective; and (3) training on longer sequences. XLNet (Yang et al., 2019) learns bidirectional contexts by maximizing the expected likelihood over all permutations of the factorization order and uses a generalized autoregressive pretraining mechanism to overcome the pretrain-finetune discrepancy of BERT. T5 (Raffel et al., 2019) compared pretraining objectives, architectures, unlabeled datasets, and transfer approaches on a wide range of language understanding tasks and proposed a unified framework that casts these tasks as text-to-text problems. ERNIE 2.0 (Sun et al., 2019) proposed a continual pretraining framework that incrementally builds and learns pretraining tasks through continual multi-task learning, in order to capture lexical, syntactic, and semantic information from training corpora. Gururangan et al. (2020) proposed task adaptive pretraining (TAPT) and domain adaptive pretraining (DAPT).
Given a RoBERTa model pretrained on large-scale corpora, TAPT continues to pretrain RoBERTa on the training dataset of the target task. DAPT continues to pretrain RoBERTa on datasets that have small domain differences from the data in the target task. The difference between our proposed SSL-Reg method and TAPT/DAPT is that SSL-Reg uses a self-supervised task (e.g., masked token prediction) to regularize the finetuning of RoBERTa, where the text classification task and the self-supervised task are performed jointly. In contrast, TAPT and DAPT use the self-supervised task for pretraining, where the text classification task and the self-supervised task are performed sequentially. The connection between our method and TAPT is that both leverage texts in target tasks to perform self-supervised learning, in addition to SSL on large-scale external corpora. Different from SSL-Reg and TAPT, DAPT uses domain-similar texts rather than target texts for additional SSL.
2.2 Self-supervised Learning in General
Self-supervised learning has been widely applied to other application domains, such as image classification (He et al., 2019; Chen et al., 2020), graph classification (Zeng and Xie, 2021), visual question answering (He et al., 2020a), and so forth, where various strategies have been proposed to construct auxiliary tasks, based on temporal correspondence (Li et al., 2019; Wang et al., 2019a), cross-modal consistency (Wang et al., 2019b), rotation prediction (Gidaris et al., 2018; Sun et al., 2020), image inpainting (Pathak et al., 2016), automatic colorization (Zhang et al., 2016), context prediction (Nathan Mundhenk et al., 2018), and so on. Some recent works studied self-supervised representation learning based on instance discrimination (Wu et al., 2018) with contrastive learning. Oord et al. (2018) proposed contrastive predictive coding, which predicts the future in latent space by using powerful autoregressive models, to extract useful representations from high-dimensional data. Bachman et al. (2019) proposed a self-supervised representation learning approach based on maximizing mutual information between features extracted from multiple views of a shared context. MoCo (He et al., 2019) and SimCLR (Chen et al., 2020) learned image encoders by predicting whether two augmented images were created from the same original image. Srinivas et al. (2020) proposed to learn contrastive unsupervised representations for reinforcement learning. Khosla et al. (2020) investigated supervised contrastive learning, where clusters of points belonging to the same class were pulled together in embedding space, while clusters of samples from different classes were pushed apart. Klein and Nabi (2020) proposed a contrastive self-supervised learning approach for commonsense reasoning. He et al. (2020b); Yang et al. (2020) proposed a Self-Trans approach which applied contrastive self-supervised learning on top of networks pretrained by transfer learning.
Compared with supervised learning, which requires each data example to be labeled by humans, or semi-supervised learning, which requires part of the data examples to be labeled, self-supervised learning is similar to unsupervised learning in that it does not need human-provided labels. The key difference between SSL and unsupervised learning is that SSL focuses on learning data representations by solving auxiliary tasks defined on unlabeled data, while unsupervised learning is more general and aims to discover latent structures in data, such as clustering, dimension reduction, and manifold embedding (Roweis and Saul, 2000).
2.3 Text Classification
Text classification (Minaee et al., 2020) is one of the key tasks in natural language processing and has a wide range of applications, such as sentiment analysis, spam detection, and tag suggestion. A number of approaches have been proposed for text classification, many of them based on RNNs. Liu et al. (2016) used multi-task learning to train RNNs, utilizing the correlation between tasks to improve text classification performance. Tai et al. (2015) generalized the sequential LSTM to a tree-structured LSTM to capture the syntax of sentences and achieve better classification performance. Compared with RNN-based models, CNN-based models are good at capturing local and position-invariant patterns. Kalchbrenner et al. (2014) proposed a dynamic CNN, which uses dynamic k-max pooling to explicitly capture short-term and long-range relations of words and phrases. Zhang et al. (2015) proposed a character-level CNN model for text classification, which can deal with out-of-vocabulary words. Hybrid methods combine RNNs and CNNs to exploit the advantages of both. Zhou et al. (2015) proposed a convolutional LSTM network, which uses a CNN to extract phrase-level representations and then feeds them into an LSTM network to represent the whole sentence.
3 Methods
To alleviate overfitting in text classification, we propose SSL-Reg, which is a regularization approach based on self-supervised learning (SSL), where an unsupervised SSL task and a supervised text classification task are performed jointly.
3.1 SSL-based Regularization
At the core of SSL-Reg is using SSL to learn a text encoder that is robust to overfitting. Our methods can be used to learn any text encoder. In this work, we perform the study using a Transformer encoder, while noting that other text encoders are also applicable.
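Concretely, the classification loss and the SSL loss are minimized jointly over the shared encoder parameters, with the SSL loss weighted by the regularization parameter λ tuned in Section 4.2.2. In schematic form (the notation below is ours, for illustration):

$$\min_{\theta}\;\; \mathcal{L}_{\mathrm{cls}}(\theta) \;+\; \lambda\,\mathcal{L}_{\mathrm{ssl}}(\theta),$$

where θ denotes the encoder parameters (together with the task-specific heads), $\mathcal{L}_{\mathrm{cls}}$ is the supervised classification loss on labeled training texts, $\mathcal{L}_{\mathrm{ssl}}$ is the loss of the SSL task (e.g., masked token prediction) defined on the same texts without using their labels, and λ ≥ 0 controls the regularization strength; λ = 0 recovers unregularized finetuning.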
3.2 Self-supervised Learning Tasks
In this work, we use two self-supervised learning tasks—masked token prediction (MTP) and sentence augmentation type prediction (SATP)—to perform our studies while noting that other SSL tasks are also applicable.
- •
Masked Token Prediction (MTP) This task is used in BERT. Some percentage of input tokens are masked at random. Texts with masked tokens are fed into a text encoder, which learns a latent representation for each token, including the masked ones. The task is to predict the masked tokens by feeding the hidden vectors (produced by the encoder) corresponding to the masked tokens into an output softmax over the word vocabulary.
- •
Sentence Augmentation Type Prediction (SATP) Given an original text o, we apply different types of augmentation methods to create augmented texts from o. We train a model to predict which type of augmentation was applied to an augmented text. We consider four types of augmentation operations used in Wei and Zou (2019), including synonym replacement, random insertion, random swap, and random deletion. Synonym replacement randomly chooses 10% of non-stop tokens from original texts and replaces each of them with a randomly selected synonym. In random insertion, for a randomly chosen non-stop token in a text, among the synonyms of this token, one randomly selected synonym is inserted into a random position in the text. This operation is performed for 10% of tokens. Synonyms for synonym replacement and random insertion are obtained from Synsets in NLTK (Bird and Loper, 2004) which are constructed based on WordNet (Miller, 1995). Synsets serve as a synonym dictionary containing groupings of synonymous words. Some words have only one Synset and some have several. In synonym replacement, if a selected word in a sentence has multiple synonyms, we randomly choose one of them, and replace all occurrences of this word in the sentence with this synonym. Random swap randomly chooses two tokens in a text and swaps their positions. This operation is performed for 10% of token pairs. Random deletion randomly removes a token with a probability of 0.1. In this SSL task, an augmented sentence is fed into a text encoder and the encoding is fed into a 4-way classification head to predict which operation was applied to generate this augmented sentence.
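The following is a minimal sketch of how the four SATP augmentation operations could be implemented with NLTK WordNet Synsets as the synonym dictionary, following the description above. The function names, the simplified stop-word set, and the fixed 10% rates are our own illustrative choices, not the authors' released code.

```python
# Requires: pip install nltk; nltk.download("wordnet")
import random
from nltk.corpus import wordnet

STOP_WORDS = {"the", "a", "an", "of", "to", "in", "and", "or", "is", "are"}  # simplified stop-word list

def synonyms(word):
    # All lemma names from the WordNet Synsets of a word, excluding the word itself.
    return sorted({l.replace("_", " ") for s in wordnet.synsets(word)
                   for l in s.lemma_names() if l.lower() != word.lower()})

def synonym_replacement(tokens, rate=0.1):
    out = list(tokens)
    candidates = [t for t in set(tokens) if t.lower() not in STOP_WORDS and synonyms(t)]
    n = min(len(candidates), max(1, int(rate * len(tokens))))
    for word in random.sample(candidates, n):
        rep = random.choice(synonyms(word))
        out = [rep if t == word else t for t in out]  # replace all occurrences of the chosen word
    return out

def random_insertion(tokens, rate=0.1):
    out = list(tokens)
    for _ in range(max(1, int(rate * len(tokens)))):
        candidates = [t for t in out if t.lower() not in STOP_WORDS and synonyms(t)]
        if not candidates:
            break
        out.insert(random.randrange(len(out) + 1),
                   random.choice(synonyms(random.choice(candidates))))
    return out

def random_swap(tokens, rate=0.1):
    out = list(tokens)
    if len(out) < 2:
        return out
    for _ in range(max(1, int(rate * len(tokens)))):
        i, j = random.sample(range(len(out)), 2)
        out[i], out[j] = out[j], out[i]
    return out

def random_deletion(tokens, p=0.1):
    kept = [t for t in tokens if random.random() > p]
    return kept if kept else [random.choice(tokens)]  # never return an empty sentence

OPERATIONS = [synonym_replacement, random_insertion, random_swap, random_deletion]

def make_satp_example(tokens):
    # The SSL target is the index of the operation applied to the original sentence (4-way label).
    label = random.randrange(len(OPERATIONS))
    return OPERATIONS[label](tokens), label
```

In this sketch, make_satp_example returns an augmented token list together with the 4-way label that the SATP classification head is trained to predict.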
3.3 Text Encoder
We use a Transformer encoder to perform the study while noting that other text encoders are also applicable. Different from sequence-to-sequence models (Sutskever et al., 2014) based on recurrent neural networks (e.g., LSTM [Hochreiter and Schmidhuber, 1997], GRU [Chung et al., 2014]), which process a sequence of tokens in a recurrent manner and hence are computationally inefficient, the Transformer eschews recurrent computation and instead uses self-attention, which not only captures dependencies between tokens but also is amenable to parallel computation with high efficiency. Self-attention calculates the correlation among every pair of tokens and uses these correlation scores to create "attentive" representations by taking a weighted summation of token embeddings. The Transformer is composed of building blocks, each consisting of a self-attention layer and a position-wise feed-forward layer. A residual connection (He et al., 2016) is applied around each of these two sub-layers, followed by layer normalization (Ba et al., 2016). Given an input sequence, an encoder, which is a stack of such building blocks, is applied to obtain a representation for each token.
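For concreteness, one such building block could be sketched in PyTorch as below. This is a generic post-norm Transformer encoder block following the description above; it is not the exact BERT/RoBERTa implementation used in our experiments, and the default dimensions are illustrative.

```python
import torch
import torch.nn as nn

class TransformerEncoderBlock(nn.Module):
    """One encoder building block: self-attention and a position-wise feed-forward layer,
    each wrapped in a residual connection followed by layer normalization."""

    def __init__(self, d_model=768, num_heads=12, d_ff=3072, dropout=0.1):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, num_heads, dropout=dropout, batch_first=True)
        self.ffn = nn.Sequential(nn.Linear(d_model, d_ff), nn.GELU(), nn.Linear(d_ff, d_model))
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.drop = nn.Dropout(dropout)

    def forward(self, x, key_padding_mask=None):
        # Self-attention sub-layer: every token attends to every other token.
        attn_out, _ = self.attn(x, x, x, key_padding_mask=key_padding_mask)
        x = self.norm1(x + self.drop(attn_out))
        # Position-wise feed-forward sub-layer.
        x = self.norm2(x + self.drop(self.ffn(x)))
        return x
```

A full encoder stacks several such blocks on top of token and position embeddings, yielding one hidden vector per input token.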
4 Experiments
4.1 Datasets
We evaluated our method on the datasets used in Gururangan et al. (2020), which are from various domains. For each dataset, we follow the train/development/test split specified in Gururangan et al. (2020). Dataset statistics are summarized in Table 1.
| Domain | Dataset | Label Type | Train | Dev | Test | Classes |
|---|---|---|---|---|---|---|
| BioMed | ChemProt | relation classification | 4169 | 2427 | 3469 | 13 |
| BioMed | RCT | abstract sent. roles | 180040 | 30212 | 30135 | 5 |
| CS | ACL-ARC | citation intent | 1688 | 114 | 139 | 6 |
| CS | SciERC | relation classification | 3219 | 455 | 974 | 7 |
| News | HyperPartisan | partisanship | 515 | 65 | 65 | 2 |
| News | AGNews | topic | 115000 | 5000 | 7600 | 4 |
| Reviews | Helpfulness | review helpfulness | 115251 | 5000 | 25000 | 2 |
| Reviews | IMDB | review sentiment | 20000 | 5000 | 25000 | 2 |
In addition, we performed experiments on the datasets in the GLUE benchmark (Wang et al., 2018). The General Language Understanding Evaluation (GLUE) benchmark has 10 tasks, including 2 single-sentence tasks, 3 similarity and paraphrase tasks, and 5 inference tasks. For each GLUE task, labels in the development sets are publicly available and those in the test sets are not released. We obtain performance on test sets by submitting inference results to the GLUE evaluation server. Table 2 shows the statistics of the data split in each task.
| | CoLA | RTE | QNLI | STS-B | MRPC | WNLI | SST-2 | MNLI (m/mm) | QQP | AX |
|---|---|---|---|---|---|---|---|---|---|---|
| Train | 8551 | 2490 | 104743 | 5749 | 3668 | 635 | 67349 | 392702 | 363871 | – |
| Dev | 1043 | 277 | 5463 | 1500 | 408 | 71 | 872 | 9815/9832 | 40432 | – |
| Test | 1063 | 3000 | 5463 | 1379 | 1725 | 146 | 1821 | 9796/9847 | 390965 | 1104 |
4.2 Experimental Setup
4.2.1 Baselines
For experiments on datasets used in Gururangan et al. (2020), text encoders in all methods are initialized using pretrained RoBERTa (Liu et al., 2019a). For experiments on GLUE datasets, text encoders are initialized using pretrained BERT (Devlin et al., 2019a) or pretrained RoBERTa. We compare our proposed SSL-Reg with the following baselines.
- •
Unregularized RoBERTa (Liu et al., 2019b). In this approach, the Transformer encoder is initialized with pretrained RoBERTa. The pretrained encoder and a classification head form a text classification model, which is then finetuned on a target classification task. The architecture of the classification model is the same as that in Liu et al. (2019b). Specifically, the representation of the [CLS] special token is passed to a feedforward layer for class prediction, with tanh as the nonlinear activation function. During finetuning, no SSL-based regularization is used. This approach is evaluated on all datasets used in Gururangan et al. (2020) and all datasets in GLUE.
- •
Unregularized BERT. This approach is the same as unregularized RoBERTa, except that the Transformer encoder is initialized by pretrained BERT (Devlin et al., 2019a) instead of RoBERTa. This approach is evaluated on all GLUE datasets.
- •
Task adaptive pretraining (TAPT) (Gururangan et al., 2020). In this approach, the Transformer encoder pretrained with RoBERTa or BERT on large-scale external corpora is further pretrained, using the RoBERTa or BERT objective, on the input texts of a target classification dataset (without using class labels). This further pretrained encoder is then used to initialize the encoder in the text classification model and is finetuned on the classification task, which uses both input texts and their class labels. Similar to SSL-Reg, TAPT also performs SSL on texts in the target classification dataset. The difference is that TAPT performs the SSL task and the classification task sequentially, while SSL-Reg performs these two tasks jointly. TAPT is studied for all datasets in this paper.
- •
Domain adaptive pretraining (DAPT) (Gururangan et al., 2020). In this approach, given an encoder pretrained on large-scale external corpora, the encoder is further pretrained on a small-scale corpus whose domain is similar to that of the texts in a target classification dataset. This further pretrained encoder is then finetuned on the classification task. DAPT is similar to TAPT, except that TAPT performs the second-stage pretraining on the texts T in the classification dataset, while DAPT performs the second-stage pretraining on external texts whose domain is similar to that of T rather than directly on T. The external dataset is usually much larger than T.
- •
TAPT+SSL-Reg. When finetuning the classification model, SSL-Reg is applied. The rest is the same as TAPT.
- •
DAPT+SSL-Reg. When finetuning the classification model, SSL-Reg is applied. The rest is the same as DAPT.
4.2.2 Hyperparameter Settings
Hyperparameters were tuned on development datasets.
Hyperparameter settings for RoBERTa on datasets used in Gururangan et al. (2020).
For a fair comparison, most of our hyperparameters are the same as those in Gururangan et al. (2020). The maximum text length was set to 512. Text encoders in all methods are initialized using RoBERTa (Liu et al., 2019a) pretrained on a large-scale external dataset. For TAPT, DAPT, TAPT+SSL-Reg, and DAPT+SSL-Reg, the second-stage pretraining, on the texts T in a target classification dataset or on external texts whose domain is similar to that of T, is based on the pretraining approach of RoBERTa. In SSL-Reg, the SSL task is masked token prediction. The SSL loss function considers only the prediction of masked tokens and ignores the prediction of non-masked tokens. The probability of masking a token is 0.15. If a token t is chosen to be masked, 80% of the time we replace t with the special token [MASK]; 10% of the time we replace t with a random word; and for the remaining 10% of the time we keep t unchanged. We set the regularization parameter in SSL-Reg to 0.2 for ACL-ARC; 0.1 for SciERC, ChemProt, AGNews, RCT, Helpfulness, and IMDB; and 0.01 for HyperPartisan. For ACL-ARC, ChemProt, RCT, SciERC, and HyperPartisan, we trained SSL-Reg for 10 epochs; for Helpfulness, 5 epochs; for AGNews, RCT, and IMDB, 3 epochs. For all datasets, we used a batch size of 16 with gradient accumulation. We used the AdamW optimizer (Loshchilov and Hutter, 2017) with a warm-up proportion of 0.06, a weight decay of 0.1, and an epsilon of 1e-6. In AdamW, β1 and β2 are set to 0.9 and 0.98, respectively. The maximum learning rate was 2e-5.
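The masking recipe described above (15% of tokens selected; of these, 80% replaced with [MASK], 10% replaced with a random word, 10% left unchanged) can be sketched as follows. This mirrors the standard BERT/RoBERTa dynamic-masking procedure; the function name is ours, and the -100 label is the usual PyTorch convention for excluding non-masked positions from the cross-entropy loss. Handling of special tokens (e.g., [CLS], [SEP], padding), which should never be masked, is omitted for brevity.

```python
import torch

def mask_tokens(input_ids, mask_token_id, vocab_size, mlm_prob=0.15):
    """Return (corrupted_input_ids, labels); labels are -100 except at masked positions."""
    input_ids = input_ids.clone()
    labels = input_ids.clone()
    # Select 15% of positions for prediction.
    selected = torch.bernoulli(torch.full(labels.shape, mlm_prob)).bool()
    labels[~selected] = -100  # the MTP loss ignores non-masked positions
    # 80% of selected positions are replaced with [MASK].
    to_mask = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & selected
    input_ids[to_mask] = mask_token_id
    # 10% of selected positions are replaced with a random word (half of the remaining 20%).
    to_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & selected & ~to_mask
    input_ids[to_random] = torch.randint(vocab_size, labels.shape)[to_random]
    # The remaining 10% keep the original token but are still predicted.
    return input_ids, labels
```

During SSL-Reg finetuning, the MTP loss computed from these labels is added to the classification loss with the regularization weight listed above.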
Hyperparameter settings for BERT on GLUE datasets.
The maximum text length was set to 128. Since external texts whose domains are similar to those of the GLUE texts are not available, we did not compare with DAPT or DAPT+SSL-Reg. For each method, the text encoder is initialized using BERT (Devlin et al., 2019a) (with 24 layers) pretrained on a large-scale external dataset. In TAPT, the second-stage pretraining is performed using BERT. As we show later, TAPT does not perform well on GLUE datasets; therefore, we did not further apply TAPT+SSL-Reg on these datasets. In SSL-Reg, we studied two SSL tasks: masked token prediction (MTP) and sentence augmentation type prediction (SATP). In MTP, we randomly mask 15% of the tokens in each text. The batch size was set to 32 with gradient accumulation. We used the AdamW optimizer (Loshchilov and Hutter, 2017) with a warm-up proportion of 0.1, a weight decay of 0.01, and an epsilon of 1e-8. In AdamW, β1 and β2 are set to 0.9 and 0.999, respectively. Other hyperparameter settings are presented in Table 3 and Table 4.
| Task | Epoch | Learning Rate | Regularization Parameter |
|---|---|---|---|
| CoLA | 10 | 3e-5 | 0.2 |
| SST-2 | 3 | 3e-5 | 0.05 |
| MRPC | 5 | 4e-5 | 0.05 |
| STS-B | 10 | 4e-5 | 0.1 |
| QQP | 5 | 3e-5 | 0.2 |
| MNLI | 3 | 3e-5 | 0.1 |
| QNLI | 4 | 4e-5 | 0.5 |
| RTE | 10 | 3e-5 | 0.1 |
| WNLI | 5 | 5e-5 | 2 |
| Task | Epoch | Learning Rate | Regularization Parameter |
|---|---|---|---|
| CoLA | 6 | 3e-5 | 0.4 |
| SST-2 | 3 | 3e-5 | 0.8 |
| MRPC | 5 | 4e-5 | 0.05 |
| STS-B | 10 | 4e-5 | 0.05 |
| QQP | 5 | 3e-5 | 0.4 |
| MNLI | 4 | 3e-5 | 0.5 |
| QNLI | 4 | 4e-5 | 0.05 |
| RTE | 8 | 3e-5 | 0.6 |
| WNLI | 5 | 5e-5 | 0.1 |
Hyperparameter settings for RoBERTa on GLUE datasets.
Most hyperparameter settings follow those in RoBERTa experiments performed on datasets used in Gururangan et al. (2020).
We set different learning rates and different epoch numbers for different datasets as guided by Liu et al. (2019b). In addition, we set different regularization parameters for different datasets. These hyperparameters are listed in Table 5.
| Task | Epoch | Learning Rate | Regularization Parameter |
|---|---|---|---|
| CoLA | 10 | 1e-5 | 0.8 |
| SST-2 | 3 | 1e-5 | 1.0 |
| MRPC | 10 | 1e-5 | 0.01 |
| STS-B | 10 | 2e-5 | 0.01 |
| QQP | 10 | 1e-5 | 0.1 |
| MNLI | 3 | 1e-5 | 0.1 |
| QNLI | 3 | 1e-5 | 0.1 |
| RTE | 10 | 2e-5 | 0.1 |
| WNLI | 10 | 2e-5 | 0.02 |
4.3 Results
4.3.1 Results on the Datasets used in Gururangan et al. (2020)
Performance of text classification on datasets used in Gururangan et al. (2020) is reported in Table 6. Following Gururangan et al. (2020), for ChemProt and RCT, we report micro-F1; for other datasets, we report macro-F1. From this table, we make the following observations. First, SSL-Reg outperforms unregularized RoBERTa significantly on all datasets. We used a double-sided t-test to perform significance tests. The p-values are less than 0.01, which indicate strong statistical significance. This demonstrates the effectiveness of our proposed SSL-Reg approach in alleviating overfitting and improving generalization performance. To further confirm this, we measure the difference between F1 scores on the training set and test set in Table 7. A larger difference implies more overfitting: performing well on the training set and less well on the test set. As can be seen, the train-test difference under SSL-Reg is smaller than that under RoBERTa. SSL-Reg encourages text encoders to solve an additional task based on SSL, which reduces the risk of overfitting to the data-deficient classification task on small-sized training data. In Figure 2, we compare the training dynamics of unregularized RoBERTa and SSL-Reg (denoted by “Regularized”). As can be seen, under a large regularization parameter λ = 1, our method achieves smaller differences between training accuracy and validation accuracy than unregularized RoBERTa; our method also achieves smaller differences between training accuracy and test accuracy than unregularized RoBERTa. These results show that our proposed SSL-Reg indeed acts as a regularizer which reduces the gap between performances on training set and validation/test set. Besides, when increasing λ from 0.1 to 1, the training accuracy of SSL-Reg decreases considerably. This also indicates that SSL-Reg acts as a regularizer which penalizes training performance.
| Dataset | RoBERTa | DAPT | TAPT | SSL-Reg | TAPT+SSL-Reg | DAPT+SSL-Reg |
|---|---|---|---|---|---|---|
| ChemProt | 81.9±1.0 | 84.2±0.2 | 82.6±0.4 | 83.1±0.5 | 83.5±0.1 | 84.4±0.3 |
| RCT | 87.2±0.1 | 87.6±0.1 | 87.7±0.1 | 87.4±0.1 | 87.7±0.1 | 87.7±0.1 |
| ACL-ARC | 63.0±5.8 | 75.4±2.5 | 67.4±1.8 | 69.3±4.9 | 68.1±2.0 | 75.7±1.4 |
| SciERC | 77.3±1.9 | 80.8±1.5 | 79.3±1.5 | 81.4±0.8 | 80.4±0.6 | 82.3±0.8 |
| HyperPartisan | 86.6±0.9 | 88.2±5.9 | 90.4±5.2 | 92.3±1.4 | 93.2±1.8 | 90.7±3.2 |
| AGNews | 93.9±0.2 | 93.9±0.2 | 94.5±0.1 | 94.2±0.1 | 94.4±0.1 | 94.0±0.1 |
| Helpfulness | 65.1±3.4 | 66.5±1.4 | 68.5±1.9 | 69.4±0.2 | 71.0±1.0 | 68.3±1.4 |
| IMDB | 95.0±0.2 | 95.4±0.1 | 95.5±0.1 | 95.7±0.1 | 96.1±0.1 | 95.4±0.1 |
| Dataset | RoBERTa | SSL-Reg |
|---|---|---|
| ChemProt | 13.05 | 13.57 |
| ACL-ARC | 28.67 | 25.24 |
| SciERC | 19.51 | 18.23 |
| HyperPartisan | 7.44 | 5.64 |
Second, on 6 out of the 8 datasets, SSL-Reg performs better than TAPT. On the other two datasets, SSL-Reg is on par with TAPT. This shows that SSL-Reg is more effective than TAPT. SSL-Reg and TAPT both leverage input texts in classification datasets for self-supervised learning. The difference is: TAPT uses these texts to pretrain the encoder while SSL-Reg uses these texts to regularize the encoder during finetuning. In SSL-Reg, the encoder is learned to perform classification tasks and SSL tasks simultaneously. Thus the encoder is not completely biased to classification tasks. In TAPT, the encoder is first learned by performing SSL tasks, then finetuned by performing classification tasks. There is a risk that after finetuning, the encoder is largely biased to classification tasks on small-sized training data, which leads to overfitting.
Third, on 5 out of the 8 datasets, SSL-Reg performs better than DAPT, although DAPT leverages additional external data. The reasons are two-fold: 1) similar to TAPT, DAPT performs the SSL task first and the classification task separately afterwards; as a result, the encoder may eventually be biased toward the classification task on small-sized training data; 2) the external data used in DAPT still has a domain shift with respect to the target dataset; this domain shift may render the text encoder pretrained on external data less suitable for the target task. To verify this, we measure the domain similarity between external texts and target texts by calculating the cosine similarity between the BERT embeddings of these texts. The similarity score is 0.14. As a reference, the similarity score between texts within the target dataset is 0.27. This shows that there is indeed a domain difference between external texts and target texts.
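The domain-similarity measurement mentioned above could be reproduced with a short script such as the one below. The paper does not specify the pooling scheme or the checkpoint, so this sketch assumes mean-pooled token embeddings from a base BERT model, averaged over a sample of texts from each corpus; the text lists are placeholders to be filled with real samples.

```python
import torch
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModel.from_pretrained("bert-base-uncased").eval()

@torch.no_grad()
def corpus_embedding(texts):
    # Mean-pool token embeddings over non-padding positions, then average over the sampled texts.
    batch = tokenizer(texts, padding=True, truncation=True, max_length=512, return_tensors="pt")
    hidden = model(**batch).last_hidden_state              # (B, T, H)
    mask = batch["attention_mask"].unsqueeze(-1).float()   # (B, T, 1)
    pooled = (hidden * mask).sum(dim=1) / mask.sum(dim=1)  # (B, H)
    return pooled.mean(dim=0)                              # (H,)

# Placeholder samples; replace with texts drawn from the DAPT corpus and the target dataset.
external_texts = ["an example sentence from the domain-similar external corpus"]
target_texts = ["an example sentence from the target classification dataset"]

similarity = F.cosine_similarity(corpus_embedding(external_texts),
                                 corpus_embedding(target_texts), dim=0).item()
print(f"domain similarity: {similarity:.2f}")
```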
Fourth, on 6 out of 8 datasets, TAPT+SSL-Reg performs better than TAPT. On the other two datasets, TAPT+SSL-Reg is on par with TAPT. This further demonstrates the effectiveness of SSL-Reg.
Fifth, on all eight datasets, DAPT+SSL-Reg performs better than DAPT. This again shows that SSL-Reg is effective.
Sixth, on 6 out of 8 datasets, TAPT+SSL-Reg performs better than SSL-Reg, indicating that it is beneficial to use both TAPT and SSL-Reg: first use the target texts to pretrain the encoder based on SSL, then apply SSL-based regularizer on these target texts during finetuning.
Seventh, DAPT+SSL-Reg performs better than SSL-Reg on 4 datasets, but worse on the other 4 datasets, indicating that with SSL-Reg used, DAPT is not necessarily useful.
Eighth, on smaller datasets, improvement achieved by SSL-Reg over baselines is larger. For example, on HyperPartisan which has only about 500 training examples, improvement of SSL-Reg over RoBERTa is 5.7% (absolute percentage). Relative improvement is 6.6%. As another example, on ACL-ARC which has about 1700 training examples, improvement of SSL-Reg over RoBERTa is 6.3% (absolute percentage). Relative improvement is 10%. In contrast, on large datasets such as RCT which contains about 180000 training examples, improvement of SSL-Reg over RoBERTa is 0.2% (absolute percentage). Relative improvement is 0.2%. On another large dataset AGNews which contains 115000 training examples, improvement of SSL-Reg over RoBERTa is 0.3% (absolute percentage). Relative improvement is 0.3%. The reason that SSL-Reg achieves better improvement on smaller datasets is that smaller datasets are more likely to lead to overfitting and SSL-Reg is more needed to alleviate this overfitting.
Figure 3 shows how the classification F1 score varies as we increase the regularization parameter λ from 0.01 to 1.0 in SSL-Reg. As can be seen, starting from 0.01, the F1 score increases as λ increases. This is because a larger λ imposes a stronger regularization effect, which helps to reduce overfitting. However, if λ becomes too large, the F1 score drops, because the regularization effect becomes so strong that it dominates the classification loss. Among these 4 datasets, the F1 score drops most dramatically on HyperPartisan as λ increases. This is probably because this dataset contains very long sequences, which makes MTP more difficult and therefore yields an excessively strong regularization effect that hurts classification performance. Compared with HyperPartisan, the F1 score is less sensitive on the other datasets because their sequences are relatively shorter.
4.3.2 Results on the GLUE Benchmark
Table 8 and Table 9 show results of BERT-based experiments on the development sets of GLUE. As mentioned in Devlin et al. (2019b), for the 24-layer version of BERT, finetuning is sometimes unstable on small datasets, so we run each method several times and report the median and best performance. Table 10 shows the best performance on test sets. Following Wang et al. (2018), we report Matthews correlation on CoLA, Pearson and Spearman correlation on STS-B, and accuracy and F1 on MRPC and QQP. For the rest of the datasets, we report accuracy. From these tables, we make the following observations. First, SSL-Reg methods, including SSL-Reg-SATP and SSL-Reg-MTP, outperform unregularized BERT (our run) on most datasets: 1) on test sets, SSL-Reg-SATP performs better than BERT on 7 out of 10 datasets and SSL-Reg-MTP performs better than BERT on 9 out of 10 datasets; 2) in terms of median results on development sets, SSL-Reg-SATP performs better than BERT (our run) on 7 out of 9 datasets and SSL-Reg-MTP performs better than BERT (our run) on 8 out of 9 datasets; 3) in terms of best results on development sets, SSL-Reg-SATP performs better than BERT (our run) on 8 out of 9 datasets and SSL-Reg-MTP performs better than BERT (our run) on 8 out of 9 datasets. This further demonstrates the effectiveness of SSL-Reg in improving generalization performance.
| | CoLA (Matthews Corr.) | SST-2 (Accuracy) | RTE (Accuracy) | QNLI (Accuracy) | MRPC (Accuracy/F1) |
|---|---|---|---|---|---|
| The median result | | | | | |
| BERT, Lan et al., 2019 | 60.6 | 93.2 | 70.4 | 92.3 | 88.0/– |
| BERT, our run | 62.1 | 93.1 | 74.0 | 92.1 | 86.8/90.8 |
| TAPT | 61.2 | 93.1 | 74.0 | 92.0 | 85.3/89.8 |
| SSL-Reg (SATP) | 63.7 | 93.9 | 74.7 | 92.3 | 86.5/90.3 |
| SSL-Reg (MTP) | 63.8 | 93.8 | 74.7 | 92.6 | 87.3/90.9 |
| The best result | | | | | |
| BERT, our run | 63.9 | 93.3 | 75.8 | 92.5 | 89.5/92.6 |
| TAPT | 62.0 | 93.9 | 76.2 | 92.4 | 86.5/90.7 |
| SSL-Reg (SATP) | 65.3 | 94.6 | 78.0 | 92.8 | 88.5/91.9 |
| SSL-Reg (MTP) | 66.3 | 94.7 | 78.0 | 93.1 | 89.5/92.4 |
| | MNLI-m/mm (Accuracy) | QQP (Accuracy/F1) | STS-B (Pearson Corr./Spearman Corr.) | WNLI (Accuracy) |
|---|---|---|---|---|
| The median result | | | | |
| BERT, Lan et al., 2019 | 86.6/– | 91.3/– | 90.0/– | – |
| BERT, our run | 86.2/86.0 | 91.3/88.3 | 90.4/90.0 | 56.3 |
| TAPT | 85.6/85.5 | 91.5/88.7 | 90.6/90.2 | 53.5 |
| SSL-Reg (SATP) | 86.2/86.2 | 91.6/88.8 | 90.7/90.4 | 56.3 |
| SSL-Reg (MTP) | 86.6/86.6 | 91.8/89.0 | 90.7/90.3 | 56.3 |
| The best result | | | | |
| BERT, our run | 86.4/86.3 | 91.4/88.4 | 90.9/90.5 | 56.3 |
| TAPT | 85.7/85.7 | 91.7/89.0 | 90.8/90.4 | 56.3 |
| SSL-Reg (SATP) | 86.4/86.5 | 91.8/88.9 | 91.1/90.8 | 59.2 |
| SSL-Reg (MTP) | 86.9/86.9 | 91.9/89.1 | 91.1/90.8 | 57.7 |
| | BERT | TAPT | SSL-Reg (SATP) | SSL-Reg (MTP) |
|---|---|---|---|---|
| CoLA (Matthews Corr.) | 60.5 | 61.3 | 63.0 | 61.2 |
| SST-2 (Accuracy) | 94.9 | 94.4 | 95.1 | 95.2 |
| RTE (Accuracy) | 70.1 | 70.3 | 71.2 | 72.7 |
| QNLI (Accuracy) | 92.7 | 92.4 | 92.5 | 93.2 |
| MRPC (Accuracy/F1) | 85.4/89.3 | 85.9/89.5 | 85.3/89.3 | 86.1/89.8 |
| MNLI-m/mm (Accuracy) | 86.7/85.9 | 85.7/84.4 | 86.2/85.4 | 86.6/86.1 |
| QQP (Accuracy/F1) | 89.3/72.1 | 89.3/71.6 | 89.6/72.2 | 89.7/72.5 |
| STS-B (Pearson Corr./Spearman Corr.) | 87.6/86.5 | 88.4/87.3 | 88.3/87.5 | 88.1/87.2 |
| WNLI (Accuracy) | 65.1 | 65.8 | 65.8 | 66.4 |
| AX (Matthews Corr.) | 39.6 | 39.3 | 40.2 | 40.3 |
| Average | 80.5 | 80.6 | 81.0 | 81.3 |
Second, on 7 out of 10 test sets, SSL-Reg-SATP outperforms TAPT; on 8 out of 10 test sets, SSL-Reg-MTP outperforms TAPT. On most development datasets, SSL-Reg-SATP and SSL-Reg-MTP outperform TAPT. The only exception is: on QQP development set, the best F1 of TAPT is slightly better than that of SSL-Reg-SATP. This further demonstrates that performing SSL-based regularization on target texts is more effective than using them for pretraining.
Third, overall, SSL-Reg-MTP performs better than SSL-Reg-SATP. For example, on 8 out of 10 test datasets, SSL-Reg-MTP performs better than SSL-Reg-SATP. MTP works better than SATP probably because it is a more challenging self-supervised learning task that encourages encoders to learn more powerful representations.
Fourth, improvement of SSL-Reg methods over BERT is more prominent on smaller training datasets, such as CoLA and RTE. This may be because smaller training datasets are more likely to lead to overfitting where the advantage of SSL-Reg in alleviating overfitting can be better played.
Tables 11 and 12 show results of RoBERTa-based experiments on the development sets of GLUE. From these two tables, we make observations similar to those from Table 8 and Table 9. In terms of median results, SSL-Reg (MTP) performs better than unregularized RoBERTa (our run) on 7 out of 9 datasets and achieves the same performance as RoBERTa (our run) on the remaining 2 datasets. In terms of best results, SSL-Reg (MTP) performs better than RoBERTa (our run) on 5 out of 9 datasets and achieves the same performance as RoBERTa (our run) on the remaining 4 datasets. This further demonstrates the effectiveness of our proposed SSL-Reg approach, which uses an MTP-based self-supervised task to regularize the finetuning of RoBERTa.
| | CoLA (Matthews Corr.) | SST-2 (Accuracy) | RTE (Accuracy) | QNLI (Accuracy) | MRPC (Accuracy/F1) |
|---|---|---|---|---|---|
| The median result | | | | | |
| RoBERTa, Liu et al., 2019b | 68.0 | 96.4 | 86.6 | 94.7 | 90.9/– |
| RoBERTa, our run | 68.7 | 96.1 | 84.8 | 94.6 | 89.5/92.3 |
| SSL-Reg (MTP) | 69.2 | 96.3 | 85.2 | 94.9 | 90.0/92.7 |
| The best result | | | | | |
| RoBERTa, our run | 69.2 | 96.7 | 86.6 | 94.7 | 90.4/93.1 |
| SSL-Reg (MTP) | 70.2 | 96.7 | 86.6 | 95.2 | 91.4/93.8 |
| | MNLI-m/mm (Accuracy) | QQP (Accuracy) | STS-B (Pearson Corr./Spearman Corr.) | WNLI (Accuracy) |
|---|---|---|---|---|
| The median result | | | | |
| RoBERTa, Liu et al., 2019b | 90.2/90.2 | 92.2 | 92.4/– | – |
| RoBERTa, our run | 90.5/90.5 | 91.6 | 92.0/92.0 | 56.3 |
| SSL-Reg (MTP) | 90.7/90.7 | 91.6 | 92.0/92.0 | 62.0 |
| The best result | | | | |
| RoBERTa, our run | 90.7/90.5 | 91.7 | 92.3/92.2 | 60.6 |
| SSL-Reg (MTP) | 90.7/90.5 | 91.8 | 92.3/92.2 | 66.2 |
For SSL-Reg (SATP), we perform an ablation study on different types of sentence augmentation. Results are shown in Table 13, where SR, RD, RI, and RS denote synonym replacement, random deletion, random insertion, and random swap, respectively. SR+RD+RI+RS means that we apply these four types of operations to augment sentences; given an augmented sentence a, we predict which of the four types of operations was applied to an original sentence to create a. SR+RD+RI and SR+RD are defined analogously. From this table, we make the following observations.
| | CoLA | SST-2 | RTE | QNLI | MRPC | STS-B |
|---|---|---|---|---|---|---|
| SR+RD+RI+RS | 63.6 | 94.0 | 74.8 | 92.2 | 86.8/90.6 | 90.6/90.3 |
| SR+RD+RI | 63.4 | 93.8 | 72.8 | 92.1 | 86.9/90.8 | 90.6/90.2 |
| SR+RD | 61.6 | 93.6 | 72.5 | 92.2 | 87.2/91.0 | 90.6/90.3 |
First, as the number of augmentation types increases from 2 (SR+RD) to 3 (SR+RD+RI) then to 4 (SR+RD+RI+RS), the performance increases in general. This shows that it is beneficial to have more augmentation types in SATP. The reason is that more types make the SATP task more challenging and solving a more challenging self-supervised learning task can enforce sentence encoders to learn more powerful representations.
Second, SR+RD+RI+RS outperforms SR+RD+RI on 5 out of 6 datasets. This demonstrates that leveraging random swap (RS) in SATP helps learn more effective sentence representations. The reason is that SR, RD, and RI change the collection of tokens in a sentence via synonym replacement, random deletion, and random insertion, whereas RS does not change the collection of tokens but changes their order; therefore, RS is complementary to the other three operations, and adding it brings in additional benefits.
Third, SR+RD+RI performs much better than SR+RD on CoLA and is on par with SR+RD on the rest five datasets. This shows that adding RI to SR+RD is beneficial. Unlike synonym replacement (SR) and random deletion (RD) which do not increase the number of tokens in a sentence, RI increases token number. Therefore, RI is complementary to SR and RD and can bring in additional benefits.
5 Conclusions and Future Work
In this paper, we propose to use self-supervised learning to alleviate overfitting in text classification problems. We propose SSL-Reg, a regularizer based on SSL, under which a text encoder is trained to simultaneously minimize a classification loss and a regularization loss. We demonstrate the effectiveness of our method on 17 text classification datasets.
For future work, we will use other self-supervised learning tasks to perform regularization, such as contrastive learning, which predicts whether two augmented sentences stem from the same original sentence.
Author notes
Equal contribution.