Understanding DistilBERT in Depth
DistilBERT was introduced in the paper “DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter” by Victor Sanh, Lysandre Debut, Julien Chaumond and Thomas Wolf of Hugging Face. Over the last few years, transfer learning has become widespread in natural language processing, and large pretrained models have become a basic tool for many NLP tasks. Although these models perform very well, they often have hundreds of millions of parameters, which makes them expensive to train and to run at inference time.
These larger models come with some challenges:
- These models are computationally expensive to train and therefore have a high environmental cost.
- Real-time inference with these models is also expensive.
So how do we put these models into production? How do we use them under low-latency constraints? Do we need to scale up GPUs just for inference? And how can we run them on smartphones or other lightweight devices?
All these questions come to mind when considering such models, and they boil down to one:
“How can we reduce the size of these large models without hurting their performance?”
There are three main ways to do this:
1. Model quantization: using lower-precision arithmetic for inference, for example converting floating-point weights to low-bit (e.g. 8-bit) integers.
2. Weight pruning: removing weights that are close to zero.
3. Model distillation: a large, complex model (the teacher) distills its knowledge and passes it on to train a smaller network (the student) to match its outputs.
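To make the first two techniques concrete, here is a toy sketch on a plain Python list of weights. It assumes a simple affine mapping of the weight range onto the integers 0 to 255 for quantization and a fixed magnitude threshold for pruning; real toolkits operate on tensors with per-layer or per-channel statistics, and the function names here are hypothetical.

```python
def quantize_uint8(weights):
    """Affine quantization: map [min, max] of the weights onto the integers 0..255."""
    lo, hi = min(weights), max(weights)
    scale = (hi - lo) / 255 or 1.0          # size of one integer step (avoid 0 for constant input)
    zero_point = round(-lo / scale)         # integer that represents the float 0.0
    q = [max(0, min(255, round(w / scale) + zero_point)) for w in weights]
    return q, scale, zero_point

def dequantize(q, scale, zero_point):
    """Map the 8-bit integers back to approximate float values."""
    return [(v - zero_point) * scale for v in q]

def magnitude_prune(weights, threshold):
    """Weight pruning: zero out every weight whose magnitude is below the threshold."""
    return [0.0 if abs(w) < threshold else w for w in weights]

weights = [0.82, -0.03, 0.41, -0.77, 0.001, 0.25]   # hypothetical layer weights
q, scale, zp = quantize_uint8(weights)
restored = dequantize(q, scale, zp)
print(q)
print([round(w, 3) for w in restored])
print(magnitude_prune(weights, threshold=0.05))
```

Each restored weight differs from the original by at most about one quantization step (`scale`), which is the precision traded for storing 8 bits instead of 32.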
In the DistilBERT paper, the authors showed that it is possible to reach similar performance on many downstream tasks using much smaller language models pre-trained with knowledge distillation. The resulting models are lighter and faster at inference time and also require a smaller computational training budget. DistilBERT is a small, fast, cheap and light Transformer model trained by distilling BERT base. It has 40% fewer parameters than bert-base-uncased and runs 60% faster while preserving over 95% of BERT’s performance.
Knowledge Distillation: A Model-Agnostic Approach
In machine learning, knowledge distillation refers to the process of compressing and transferring knowledge from a computationally expensive large model (the teacher model) to a smaller model (the student model) while preserving as much of its accuracy as possible.
Main components of the technique
Teacher model: A very large model, or an ensemble of separately trained models, trained with a strong regularizer such as dropout can serve as the teacher model.
Student model: A small model that relies on the teacher model’s distilled knowledge. It is trained with a different kind of training, called “distillation”, which transfers knowledge from the teacher to the student. The student model is better suited for deployment because it is faster and cheaper to run while keeping accuracy very close to the teacher’s.
How is distillation done?
1. Define a teacher network and a student network.
The teacher network usually has millions or even billions of parameters, while the student network has far fewer.
2. Train the teacher (larger) network fully until it converges. This is done on the full training data.
3. Next, run knowledge distillation training with the teacher model and the student model.
For distillation we use another set of training data, called the transfer set, which here is not part of the data used to train the teacher model. A forward pass is run through both the pretrained teacher model and the student model, and the loss is computed as follows:
$$L_{\text{total}} = \alpha \, L_{\text{student}} + (1 - \alpha) \, L_{\text{distill}}$$
Here alpha is the factor that balances the student loss against the distillation loss. Before going further, let’s look at the softmax function, which both losses rely on.
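The softmax with temperature is given by
$$\mathrm{softmax}_i(x; t) = \frac{\exp\left(z_i(x)/t\right)}{\sum_j \exp\left(z_j(x)/t\right)}$$
where t is the temperature and z_i(x) is the logit for class i. The ordinary softmax function corresponds to t = 1.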
- Distillation loss: we use the KL-divergence loss between the teacher model’s predictions and the student model’s predictions, both computed with the same temperature t > 1. The teacher’s softened outputs are the soft labels.
- Student loss: the cross-entropy loss between the student’s predictions at t = 1 and the ground truth of the transfer set, which provides the hard labels.
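The loss computation described above can be sketched in plain Python for a single example. The sketch assumes the logits are plain lists and uses the KL divergence from the teacher’s soft labels to the student’s. As an addition not described in this article, it multiplies the distillation term by t², the scaling recommended by Hinton et al. (2015) so that its gradients keep a comparable size as t changes. All names and values are hypothetical.

```python
import math

def softmax(logits, t=1.0):
    """Softmax with temperature t: exp(z_i / t) / sum_j exp(z_j / t)."""
    scaled = [z / t for z in logits]
    m = max(scaled)                      # subtract the max to avoid overflow
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def kl_divergence(p, q):
    """KL(p || q) = sum_i p_i * log(p_i / q_i); terms with p_i = 0 contribute nothing."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def distillation_total_loss(student_logits, teacher_logits, label, alpha=0.5, t=4.0):
    # Student loss: cross-entropy with the hard label, ordinary softmax (t = 1).
    student_loss = -math.log(softmax(student_logits)[label])
    # Distillation loss: KL divergence between the teacher's and the student's
    # soft labels, both at the same temperature t > 1, scaled by t**2 (assumption).
    soft_teacher = softmax(teacher_logits, t)
    soft_student = softmax(student_logits, t)
    distill_loss = t ** 2 * kl_divergence(soft_teacher, soft_student)
    return alpha * student_loss + (1 - alpha) * distill_loss

teacher_logits = [6.0, 2.0, 1.0]   # hypothetical outputs of the frozen teacher
student_logits = [2.0, 1.5, 0.5]   # hypothetical outputs of the student
print(distillation_total_loss(student_logits, teacher_logits, label=0))
```

With alpha = 1 the function reduces to the ordinary hard-label cross-entropy, and when the student’s logits equal the teacher’s the distillation term vanishes.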
Why set t > 1 in the softmax?
With t = 1, a well-trained teacher almost always produces the correct answer with very high confidence, and the probabilities it assigns to the other classes are so close to zero that they have very little influence on the cross-entropy cost function when knowledge is transferred from teacher to student. Raising the temperature softens the distribution, so the information about how the teacher ranks the incorrect classes actually reaches the student. That is why we set t > 1 in the softmax of both the teacher model and the student model when computing the distillation loss.
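The effect of the temperature is easy to see numerically. This sketch implements the softmax with temperature and applies it to a hypothetical set of teacher logits in which class 0 is the correct answer.

```python
import math

def softmax(logits, t=1.0):
    """Softmax with temperature t: exp(z_i / t) / sum_j exp(z_j / t)."""
    scaled = [z / t for z in logits]
    m = max(scaled)                      # subtracting the max leaves the result unchanged
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

teacher_logits = [6.0, 2.0, 1.0]         # hypothetical teacher logits, class 0 is correct
for t in (1.0, 4.0, 10.0):
    probs = softmax(teacher_logits, t)
    print(f"t={t:>4}: {[round(p, 3) for p in probs]}")
```

With these logits, t = 1 gives the two incorrect classes probabilities of only about 0.018 and 0.007, while t = 4 raises them to about 0.22 and 0.17. Their ranking is preserved, but now it carries enough probability mass to influence the student’s loss.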
With both terms in place, the total loss is computed. The gradients are then used to update only the student model; the teacher’s weights stay frozen. At test time, we evaluate the student model by comparing its predictions with the ground truth.
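Putting the steps together, the sketch below distills a hypothetical frozen linear teacher into a linear student using plain stochastic gradient descent. It is a toy stand-in for real networks: the teacher weights, the transfer set, the learning rate and the t² scaling of the distillation term are all assumptions. The gradient of the total loss with respect to each student logit is alpha·(q − y) for the hard part plus (1 − alpha)·t·(q_t − p_t) for the soft part, and only the student’s weights are ever changed.

```python
import math

def softmax(logits, t=1.0):
    scaled = [z / t for z in logits]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def linear_logits(weights, x):
    """One weight row per class: logit_k = w_k . x (no bias, for brevity)."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in weights]

def total_loss(teacher_w, student_w, x, label, alpha=0.5, t=4.0):
    """alpha * hard-label CE (t = 1) + (1 - alpha) * t**2 * KL(teacher_t || student_t)."""
    z_t, z_s = linear_logits(teacher_w, x), linear_logits(student_w, x)
    p_t, q_t = softmax(z_t, t), softmax(z_s, t)
    hard = -math.log(softmax(z_s)[label])
    soft = sum(p * math.log(p / q) for p, q in zip(p_t, q_t))
    return alpha * hard + (1 - alpha) * t * t * soft

def distill_step(teacher_w, student_w, x, label, alpha=0.5, t=4.0, lr=0.1):
    """One SGD step on a single example; only student_w is modified."""
    z_t, z_s = linear_logits(teacher_w, x), linear_logits(student_w, x)
    p_t, q_t, q_1 = softmax(z_t, t), softmax(z_s, t), softmax(z_s)
    for k in range(len(z_s)):
        # d(total loss)/d(student logit k): hard-label part + temperature-scaled soft part
        g = alpha * (q_1[k] - (1.0 if k == label else 0.0)) + (1 - alpha) * t * (q_t[k] - p_t[k])
        for j, xj in enumerate(x):
            student_w[k][j] -= lr * g * xj   # chain rule through the linear layer

teacher_w = [[2.0, 0.0], [0.0, 2.0], [-1.0, -1.0]]   # hypothetical trained teacher (frozen)
student_w = [[0.0, 0.0] for _ in range(3)]           # student starts from scratch
transfer_set = [([1.0, 0.2], 0), ([0.1, 1.0], 1), ([-1.0, -0.8], 2), ([0.9, -0.1], 0)]

def mean_loss():
    return sum(total_loss(teacher_w, student_w, x, y) for x, y in transfer_set) / len(transfer_set)

loss_before = mean_loss()
for epoch in range(100):
    for x, y in transfer_set:
        distill_step(teacher_w, student_w, x, y)
loss_after = mean_loss()
print(f"loss before: {loss_before:.3f}, after: {loss_after:.3f}")

def predict(weights, x):
    """Test-time prediction with the ordinary softmax (t = 1)."""
    probs = softmax(linear_logits(weights, x))
    return probs.index(max(probs))

test_set = [([0.8, 0.1], 0), ([0.0, 0.7], 1), ([-0.6, -0.9], 2)]   # hypothetical held-out data
accuracy = sum(predict(student_w, x) == y for x, y in test_set) / len(test_set)
print(f"student test accuracy: {accuracy:.2f}")
```

In a real setup the same structure holds: the teacher only runs forward passes, while the optimizer is given the student’s parameters alone.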
Knowledge Distillation in BERT
- Loss: For DistilBERT, the authors used a linear combination of the distillation loss (L_ce), the masked language modelling loss (L_mlm) and a cosine embedding loss (L_cos). They found that adding the cosine embedding loss helps align the directions of the teacher and student hidden-state vectors.
- Architecture: DistilBERT follows the same general architecture as BERT. The token-type embeddings and the pooler are removed, and the number of layers is reduced by a factor of 2.
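A minimal sketch of DistilBERT’s combined objective for a single masked position might look like the following. The equal weights and the temperature of 2 are placeholders rather than the paper’s settings, and the toy vocabulary and hidden vectors are hypothetical; in practice these losses are computed on batched tensors inside a deep learning framework.

```python
import math

def softmax(logits, t=1.0):
    scaled = [z / t for z in logits]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def loss_ce(student_logits, teacher_logits, t):
    """Distillation loss L_ce = -sum_i t_i * log(s_i) over the teacher's soft targets."""
    t_probs = softmax(teacher_logits, t)
    s_probs = softmax(student_logits, t)
    return -sum(ti * math.log(si) for ti, si in zip(t_probs, s_probs))

def loss_mlm(student_logits, original_token_id):
    """Masked-LM loss: cross-entropy against the token that was masked out."""
    return -math.log(softmax(student_logits)[original_token_id])

def loss_cos(h_student, h_teacher):
    """Cosine embedding loss: 1 - cos(h_s, h_t); 0 when the vectors point the same way."""
    dot = sum(a * b for a, b in zip(h_student, h_teacher))
    norms = math.sqrt(sum(a * a for a in h_student)) * math.sqrt(sum(b * b for b in h_teacher))
    return 1.0 - dot / norms

def distilbert_style_loss(s_logits, t_logits, token_id, h_s, h_t,
                          w_ce=1.0, w_mlm=1.0, w_cos=1.0, t=2.0):
    # Linear combination of the three losses; weights and t are placeholders.
    return (w_ce * loss_ce(s_logits, t_logits, t)
            + w_mlm * loss_mlm(s_logits, token_id)
            + w_cos * loss_cos(h_s, h_t))

# Hypothetical values for one masked position over a toy 4-token vocabulary.
student_logits = [1.0, 0.2, -0.5, 0.1]
teacher_logits = [2.5, 0.1, -1.0, 0.3]
h_student = [0.3, -0.1, 0.8]
h_teacher = [0.4, 0.0, 0.9]
print(distilbert_style_loss(student_logits, teacher_logits, 0, h_student, h_teacher))
```

The cosine term only looks at the angle between the hidden vectors, not their length, which matches its stated purpose of aligning directions.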
Why didn’t they reduce the hidden size?
Reducing the hidden size from 768 to 512 would cut the total number of parameters by roughly a factor of 2. However, in modern frameworks most operations are highly optimized, and changing the hidden dimension has a smaller impact on computational efficiency than changing other factors such as the number of layers. So in their architecture, the number of layers was the main factor they reduced.
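A quick parameter count supports both the 40% figure and the factor-of-2 estimate. The sketch assumes the bert-base-uncased configuration (30,522-token vocabulary, 512 positions, two token types, feed-forward width of 4× the hidden size) and counts every weight matrix, bias and LayerNorm parameter; `encoder_params` is a hypothetical helper, not a library call.

```python
def encoder_params(hidden, layers, vocab=30522, max_positions=512, type_vocab=2,
                   token_type=True, pooler=True):
    """Count weights, biases and LayerNorm parameters of a BERT-style encoder."""
    inter = 4 * hidden                      # feed-forward width, 4x hidden as in BERT
    layer_norm = 2 * hidden                 # gain and bias
    embeddings = (vocab + max_positions) * hidden + layer_norm
    if token_type:
        embeddings += type_vocab * hidden   # token-type (segment) embeddings
    attention = 4 * (hidden * hidden + hidden)            # Q, K, V and output projections
    feed_forward = (hidden * inter + inter) + (inter * hidden + hidden)
    per_layer = attention + layer_norm + feed_forward + layer_norm
    total = embeddings + layers * per_layer
    if pooler:
        total += hidden * hidden + hidden   # dense layer applied to the [CLS] vector
    return total

bert_base = encoder_params(768, 12)
distilbert = encoder_params(768, 6, token_type=False, pooler=False)
bert_h512 = encoder_params(512, 12)

print(f"BERT base:        {bert_base:,}")
print(f"DistilBERT-style: {distilbert:,} ({1 - distilbert / bert_base:.0%} fewer)")
print(f"BERT, hidden 512: {bert_h512:,} (ratio {bert_base / bert_h512:.2f})")
```

This gives 109,482,240 parameters for BERT base and 66,362,880 for the 6-layer model without token-type embeddings or pooler (about 39% fewer, in line with the roughly 110M and 66M usually quoted). A 12-layer model with hidden size 512 comes to 53,982,720, about half of BERT base.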
- Layer initialization: They initialized the student’s layers from the teacher by taking one layer out of two.
- Training and computation: For distillation training they used larger batches, dynamic masking (a different masking pattern in each epoch) and no next sentence prediction objective. They trained DistilBERT on the same corpus as the original BERT model: a concatenation of English Wikipedia and the Toronto Book Corpus. It was trained on 8 16GB V100 GPUs for approximately 90 hours.
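Two of the details above, layer initialization and dynamic masking, can be sketched in a few lines. Which teacher layers the authors kept is not stated here, so this sketch simply copies the even-numbered ones; the “layers” are hypothetical dictionaries standing in for real modules. The masking function is a simplified version of BERT-style masking that always substitutes the mask token; the 15% rate is BERT’s, while the token ids and mask id are made up.

```python
import copy
import random

def init_student_layers(teacher_layers):
    """Initialize the student by copying one teacher layer out of two (0, 2, 4, ...)."""
    return [copy.deepcopy(layer) for layer in teacher_layers[::2]]

def dynamic_mask(token_ids, mask_id, rng, prob=0.15):
    """Draw a fresh masking pattern on every call, so each epoch sees different masks.

    Simplified: chosen positions are always replaced by mask_id, whereas BERT
    sometimes keeps the original token or substitutes a random one instead.
    """
    inputs, targets = [], []
    for tok in token_ids:
        if rng.random() < prob:
            inputs.append(mask_id)
            targets.append(tok)     # the model must recover the original token here
        else:
            inputs.append(tok)
            targets.append(None)    # no prediction needed at unmasked positions
    return inputs, targets

# Hypothetical 12-layer teacher where each "layer" is a dict of toy weights.
teacher_layers = [{"index": i, "weights": [float(i)] * 4} for i in range(12)]
student_layers = init_student_layers(teacher_layers)
print([layer["index"] for layer in student_layers])

rng = random.Random(0)
sentence = [2023, 2003, 1037, 7099, 6251, 2005, 12098, 12367]  # hypothetical token ids
MASK_ID = 103                                                  # hypothetical mask id
for epoch in range(2):
    print(epoch, dynamic_mask(sentence, MASK_ID, rng)[0])
```

Because the masks are drawn inside the training loop rather than once during preprocessing, the same sentence is masked differently from one epoch to the next.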
Model Performance
They assessed DistilBERT’s language understanding and generalization capabilities on the General Language Understanding Evaluation (GLUE) benchmark, a collection of 9 datasets for evaluating natural language understanding systems. They compared its performance with BERT and with a baseline consisting of an ELMo encoder followed by two BiLSTMs. DistilBERT retains 97% of BERT’s performance on the GLUE benchmark. The results are as follows:
DistilBERT is also significantly smaller and consistently faster. The number of parameters and the inference time are shown below:
Conclusion
Hugging Face introduced DistilBERT, a general-purpose pre-trained version of BERT that is 40% smaller and 60% faster while retaining 97% of BERT’s language understanding capabilities. They showed that such a network can be produced through distillation and is suitable for edge devices.